├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── assets ├── header_model_release.png ├── intel_performance.jpg ├── m2_performance.jpg ├── tl1.png └── tl2.png ├── docs └── codegen.md ├── include └── ggml-bitnet.h ├── media ├── benchmark.png └── demo.mp4 ├── preset_kernels ├── Llama3-8B-1.58-100B-tokens │ ├── bitnet-lut-kernels-tl1.h │ ├── bitnet-lut-kernels-tl2.h │ ├── kernel_config_tl1.ini │ └── kernel_config_tl2.ini ├── bitnet_b1_58-3B │ ├── bitnet-lut-kernels-tl1.h │ ├── bitnet-lut-kernels-tl2.h │ ├── kernel_config_tl1.ini │ └── kernel_config_tl2.ini └── bitnet_b1_58-large │ ├── bitnet-lut-kernels-tl1.h │ ├── bitnet-lut-kernels-tl2.h │ ├── kernel_config_tl1.ini │ └── kernel_config_tl2.ini ├── requirements.txt ├── run_inference.py ├── setup_env.py ├── src ├── CMakeLists.txt ├── ggml-bitnet-lut.cpp └── ggml-bitnet-mad.cpp └── utils ├── codegen_tl1.py ├── codegen_tl2.py ├── convert-hf-to-gguf-bitnet.py ├── convert-ms-to-gguf-bitnet.py ├── convert.py ├── e2e_benchmark.py ├── generate-dummy-bitnet-model.py └── kernel_tuning.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Extensions 2 | 3 | *.a 4 | *.bat 5 | *.bin 6 | *.dll 7 | *.dot 8 | *.etag 9 | *.exe 10 | *.gcda 11 | *.gcno 12 | *.gcov 13 | *.gguf 14 | *.gguf.json 15 | *.lastModified 16 | *.log 17 | *.metallib 18 | *.o 19 | *.so 20 | *.tmp 21 | 22 | # IDE / OS 23 | 24 | .cache/ 25 | .ccls-cache/ 26 | .direnv/ 27 | .DS_Store 28 | .envrc 29 | .idea/ 30 | .swiftpm 31 | .vs/ 32 | .vscode/ 33 | nppBackup 34 | 35 | # Models 36 | models/* 37 | 38 | # Python 39 | 40 | /.venv 41 | __pycache__/ 42 | */poetry.lock 43 | poetry.toml 44 | 45 | build/ 46 | logs/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/llama.cpp"] 2 | path = 3rdparty/llama.cpp 3 | url = https://github.com/Eddie-Wang1120/llama.cpp.git 4 | branch = merge-dev 5 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. 
2 | project("bitnet.cpp" C CXX) 3 | include(CheckIncludeFileCXX) 4 | 5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 6 | 7 | if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) 8 | set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) 9 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") 10 | endif() 11 | 12 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 13 | 14 | # option list 15 | option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF) 16 | option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF) 17 | 18 | 19 | set(CMAKE_CXX_STANDARD_REQUIRED true) 20 | set(CMAKE_C_STANDARD 11) 21 | set(CMAKE_C_STANDARD_REQUIRED true) 22 | set(THREADS_PREFER_PTHREAD_FLAG ON) 23 | 24 | # override ggml options 25 | set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1}) 26 | set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2}) 27 | 28 | if (GGML_BITNET_ARM_TL1) 29 | add_compile_definitions(GGML_BITNET_ARM_TL1) 30 | endif() 31 | if (GGML_BITNET_X86_TL2) 32 | add_compile_definitions(GGML_BITNET_X86_TL2) 33 | endif() 34 | 35 | if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 36 | add_compile_options(-fpermissive) 37 | endif() 38 | 39 | find_package(Threads REQUIRED) 40 | 41 | add_subdirectory(src) 42 | add_subdirectory(3rdparty/llama.cpp) 43 | 44 | # install 45 | 46 | include(GNUInstallDirs) 47 | include(CMakePackageConfigHelpers) 48 | 49 | set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} 50 | CACHE PATH "Location of header files") 51 | set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} 52 | CACHE PATH "Location of library files") 53 | set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} 54 | CACHE PATH "Location of binary files") 55 | set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) 56 | set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) 57 | set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) 58 | 59 | get_target_property(GGML_DIRECTORY ggml SOURCE_DIR) 60 | get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS) 61 | get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS) 62 | set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES}) 63 | get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES) 64 | 65 | get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS) 66 | 67 | write_basic_package_version_file( 68 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake 69 | VERSION ${LLAMA_INSTALL_VERSION} 70 | COMPATIBILITY SameMajorVersion) 71 | 72 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake 73 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake 74 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) 75 | 76 | set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h) 77 | install(TARGETS llama LIBRARY PUBLIC_HEADER) -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bitnet.cpp 2 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) 3 | ![version](https://img.shields.io/badge/version-1.0-blue) 4 | 5 | [BitNet Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) 6 | 7 | Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or [build and run](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) it on your own CPU. 8 | 9 | bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels, that support **fast** and **lossless** inference of 1.58-bit models on CPU (with NPU and GPU support coming next). 10 | 11 | The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** to **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details. 12 | 13 | m2_performance 14 | m2_performance 15 | 16 | >The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp. 
17 | 18 | ## Demo 19 | 20 | A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2: 21 | 22 | https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1 23 | 24 | ## What's New: 25 | - 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ![NEW](https://img.shields.io/badge/NEW-red) 26 | - 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880) 27 | - 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965) 28 | - 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144) 29 | - 10/17/2024 bitnet.cpp 1.0 released. 30 | - 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf) 31 | - 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) 32 | - 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453) 33 | 34 | ## Acknowledgements 35 | 36 | This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) framework. We would like to thank all the authors for their contributions to the open-source community. Also, bitnet.cpp's kernels are built on top of the Lookup Table methodologies pioneered in [T-MAC](https://github.com/microsoft/T-MAC/). For inference of general low-bit LLMs beyond ternary models, we recommend using T-MAC. 37 | ## Official Models 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 |
| Model              | Parameters | CPU | I2_S | TL1 | TL2 |
|:-------------------|:-----------|:----|:-----|:----|:----|
| BitNet-b1.58-2B-4T | 2.4B       | x86 |      |     |     |
|                    |            | ARM |      |     |     |
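The released GGUF weights for this model can be pulled directly from Hugging Face; the command below is the same one used in the build steps later in this README:

```bash
# Download the official BitNet-b1.58-2B-4T GGUF weights into models/
huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T
```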
66 | 67 | ## Supported Models 68 | ❗️**We use existing 1-bit LLMs available on [Hugging Face](https://huggingface.co/) to demonstrate the inference capabilities of bitnet.cpp. We hope the release of bitnet.cpp will inspire the development of 1-bit LLMs in large-scale settings in terms of model size and training tokens.** 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 |
| Model                      | Parameters | CPU | I2_S | TL1 | TL2 |
|:---------------------------|:-----------|:----|:-----|:----|:----|
| bitnet_b1_58-large         | 0.7B       | x86 |      |     |     |
|                            |            | ARM |      |     |     |
| bitnet_b1_58-3B            | 3.3B       | x86 |      |     |     |
|                            |            | ARM |      |     |     |
| Llama3-8B-1.58-100B-tokens | 8.0B       | x86 |      |     |     |
|                            |            | ARM |      |     |     |
| Falcon3 Family             | 1B-10B     | x86 |      |     |     |
|                            |            | ARM |      |     |     |
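Any of the checkpoints above can be passed to `setup_env.py` via `--hf-repo` (the full usage output is shown after the build steps below). For example, a sketch for one of the Falcon3 models using the generic `i2_s` kernel:

```bash
# Example invocation: download tiiuae/Falcon3-7B-Instruct-1.58bit and quantize to i2_s
# (see the setup_env.py usage output below for the full list of accepted repos)
python setup_env.py --hf-repo tiiuae/Falcon3-7B-Instruct-1.58bit -q i2_s
```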
140 | 141 | 142 | 143 | ## Installation 144 | 145 | ### Requirements 146 | - python>=3.9 147 | - cmake>=3.22 148 | - clang>=18 149 | - For Windows users, install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/). In the installer, toggle on at least the following options(this also automatically installs the required additional tools like CMake): 150 | - Desktop-development with C++ 151 | - C++-CMake Tools for Windows 152 | - Git for Windows 153 | - C++-Clang Compiler for Windows 154 | - MS-Build Support for LLVM-Toolset (clang) 155 | - For Debian/Ubuntu users, you can download with [Automatic installation script](https://apt.llvm.org/) 156 | 157 | `bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"` 158 | - conda (highly recommend) 159 | 160 | ### Build from source 161 | 162 | > [!IMPORTANT] 163 | > If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands. Please refer to the FAQs below if you see any issues. 164 | 165 | 1. Clone the repo 166 | ```bash 167 | git clone --recursive https://github.com/microsoft/BitNet.git 168 | cd BitNet 169 | ``` 170 | 2. Install the dependencies 171 | ```bash 172 | # (Recommended) Create a new conda environment 173 | conda create -n bitnet-cpp python=3.9 174 | conda activate bitnet-cpp 175 | 176 | pip install -r requirements.txt 177 | ``` 178 | 3. Build the project 179 | ```bash 180 | # Manually download the model and run with local path 181 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T 182 | python setup_env.py -md models/BitNet-b1.58-2B-4T -q i2_s 183 | 184 | ``` 185 |
186 | usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
187 |                     [--use-pretuned]
188 | 
189 | Setup the environment for running inference
190 | 
191 | optional arguments:
192 |   -h, --help            show this help message and exit
193 |   --hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}, -hr {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}
194 |                         Model used for inference
195 |   --model-dir MODEL_DIR, -md MODEL_DIR
196 |                         Directory to save/load the model
197 |   --log-dir LOG_DIR, -ld LOG_DIR
198 |                         Directory to save the logging info
199 |   --quant-type {i2_s,tl1}, -q {i2_s,tl1}
200 |                         Quantization type
201 |   --quant-embd          Quantize the embeddings to f16
202 |   --use-pretuned, -p    Use the pretuned kernel parameters
203 | 
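For example, on an ARM machine the 0.7B community model can be downloaded and converted with the TL1 kernel, reusing the pretuned parameters shipped under `preset_kernels/`:

```bash
# Example invocation: fetch 1bitLLM/bitnet_b1_58-large, quantize to tl1 and use the
# pretuned kernel parameters (tl1 targets ARM; use i2_s on other platforms)
python setup_env.py --hf-repo 1bitLLM/bitnet_b1_58-large -q tl1 --use-pretuned
```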
204 | ## Usage 205 | ### Basic usage 206 | ```bash 207 | # Run inference with the quantized model 208 | python run_inference.py -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf -p "You are a helpful assistant" -cnv 209 | ``` 210 |
211 | usage: run_inference.py [-h] [-m MODEL] [-n N_PREDICT] -p PROMPT [-t THREADS] [-c CTX_SIZE] [-temp TEMPERATURE] [-cnv]
212 | 
213 | Run inference
214 | 
215 | optional arguments:
216 |   -h, --help            show this help message and exit
217 |   -m MODEL, --model MODEL
218 |                         Path to model file
219 |   -n N_PREDICT, --n-predict N_PREDICT
220 |                         Number of tokens to predict when generating text
221 |   -p PROMPT, --prompt PROMPT
222 |                         Prompt to generate text from
223 |   -t THREADS, --threads THREADS
224 |                         Number of threads to use
225 |   -c CTX_SIZE, --ctx-size CTX_SIZE
226 |                         Size of the prompt context
227 |   -temp TEMPERATURE, --temperature TEMPERATURE
228 |                         Temperature, a hyperparameter that controls the randomness of the generated text
229 |   -cnv, --conversation  Whether to enable chat mode or not (for instruct models).
230 |                         (When this option is turned on, the prompt specified by -p will be used as the system prompt.)
231 | 
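For example, combining several of these options for chat-mode generation with an explicit token budget, thread count, context size and temperature (values are illustrative, not tuned recommendations):

```bash
# -p is used as the system prompt because -cnv enables chat mode
python run_inference.py \
  -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf \
  -p "You are a helpful assistant" \
  -n 256 -t 4 -c 2048 -temp 0.7 -cnv
```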
232 | 233 | ### Benchmark 234 | We provide scripts to run the inference benchmark providing a model. 235 | 236 | ``` 237 | usage: e2e_benchmark.py -m MODEL [-n N_TOKEN] [-p N_PROMPT] [-t THREADS] 238 | 239 | Setup the environment for running the inference 240 | 241 | required arguments: 242 | -m MODEL, --model MODEL 243 | Path to the model file. 244 | 245 | optional arguments: 246 | -h, --help 247 | Show this help message and exit. 248 | -n N_TOKEN, --n-token N_TOKEN 249 | Number of generated tokens. 250 | -p N_PROMPT, --n-prompt N_PROMPT 251 | Prompt to generate text from. 252 | -t THREADS, --threads THREADS 253 | Number of threads to use. 254 | ``` 255 | 256 | Here's a brief explanation of each argument: 257 | 258 | - `-m`, `--model`: The path to the model file. This is a required argument that must be provided when running the script. 259 | - `-n`, `--n-token`: The number of tokens to generate during the inference. It is an optional argument with a default value of 128. 260 | - `-p`, `--n-prompt`: The number of prompt tokens to use for generating text. This is an optional argument with a default value of 512. 261 | - `-t`, `--threads`: The number of threads to use for running the inference. It is an optional argument with a default value of 2. 262 | - `-h`, `--help`: Show the help message and exit. Use this argument to display usage information. 263 | 264 | For example: 265 | 266 | ```sh 267 | python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4 268 | ``` 269 | 270 | This command would run the inference benchmark using the model located at `/path/to/model`, generating 200 tokens from a 256 token prompt, utilizing 4 threads. 271 | 272 | For the model layout that do not supported by any public model, we provide scripts to generate a dummy model with the given model layout, and run the benchmark on your machine: 273 | 274 | ```bash 275 | python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile models/dummy-bitnet-125m.tl1.gguf --outtype tl1 --model-size 125M 276 | 277 | # Run benchmark with the generated model, use -m to specify the model path, -p to specify the prompt processed, -n to specify the number of token to generate 278 | python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128 279 | ``` 280 | ### FAQ (Frequently Asked Questions)📌 281 | 282 | #### Q1: The build dies with errors building llama.cpp due to issues with std::chrono in log.cpp? 283 | 284 | **A:** 285 | This is an issue introduced in recent version of llama.cpp. Please refer to this [commit](https://github.com/tinglou/llama.cpp/commit/4e3db1e3d78cc1bcd22bcb3af54bd2a4628dd323) in the [discussion](https://github.com/abetlen/llama-cpp-python/issues/1942) to fix this issue. 286 | 287 | #### Q2: How to build with clang in conda environment on windows? 288 | 289 | **A:** 290 | Before building the project, verify your clang installation and access to Visual Studio tools by running: 291 | ``` 292 | clang -v 293 | ``` 294 | 295 | This command checks that you are using the correct version of clang and that the Visual Studio tools are available. If you see an error message such as: 296 | ``` 297 | 'clang' is not recognized as an internal or external command, operable program or batch file. 298 | ``` 299 | 300 | It indicates that your command line window is not properly initialized for Visual Studio tools. 
301 | 302 | • If you are using Command Prompt, run: 303 | ``` 304 | "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64 305 | ``` 306 | 307 | • If you are using Windows PowerShell, run the following commands: 308 | ``` 309 | Import-Module "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\Microsoft.VisualStudio.DevShell.dll" Enter-VsDevShell 3f0e31ad -SkipAutomaticLocation -DevCmdArguments "-arch=x64 -host_arch=x64" 310 | ``` 311 | 312 | These steps will initialize your environment and allow you to use the correct Visual Studio tools. 313 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /assets/header_model_release.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/header_model_release.png -------------------------------------------------------------------------------- /assets/intel_performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/intel_performance.jpg -------------------------------------------------------------------------------- /assets/m2_performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/m2_performance.jpg -------------------------------------------------------------------------------- /assets/tl1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/tl1.png -------------------------------------------------------------------------------- /assets/tl2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/tl2.png -------------------------------------------------------------------------------- /docs/codegen.md: -------------------------------------------------------------------------------- 1 | Codegen for TL1 and TL2 2 | ------------------------ 3 | 4 | codegen_tl1.py and codegen_tl2.py are using params to generate kernel codes in different devices to achieve fastest performance for TL1 and TL2. 5 | 6 | We cutting weight into multiple compute blocks to best utilize hardware capabilities. 7 | 8 | ### Example 9 | bitnet_b1_58-large: 10 | 11 | - Make sure Matmul kernels shapes \ 12 | For example, bitnet_b1_58-large Matmul kernel shapes are:\ 13 | [1536, 4096]\ 14 | [1536, 1536]\ 15 | [4096, 1536] 16 | 17 | - Make sure each BM, BK, bm for each kernel to meet the requirements below 18 | - Generate codes\ 19 | For example, for bitnet_b1_58-large, we can gencode like: 20 | 21 | ```bash 22 | # For TL1 23 | python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,64,32 24 | 25 | # For TL2 26 | python utils/codegen_tl2.py --model bitnet_b1_58-large --BM 256,128,256 --BK 96,192,96 --bm 32,32,32 27 | ``` 28 | 29 | ### TL1: 30 | ![TL1](../assets/tl1.png) 31 | 32 | For TL1, we cut weight into M / BM weights, each weight shape is (BM, K). Then we cut weight into K / BK weights, each weight shape is (BM, BK). As for (BM, BK) weight, we cut it the same way into (bm, compute_num / bm) compute blocks, and finish computing in it. 33 | 34 | Thus, we need to make sure 35 | - M % BM == 0 36 | - K % BK == 0 37 | - BM % bm == 0 38 | - bm choose in [32, 64] 39 | 40 | ### TL2: 41 | ![TL2](../assets/tl2.png) 42 | 43 | For TL2, things got a little more complicated. Due to TL2 needs BK % 6 == 0, we need to split K into threeK and twoK, in which compute in TL2 for (M, threeK), compute in TL1 for (M, two_K). 
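As a rough sketch of this split (hypothetical code, not the actual logic of codegen_tl2.py; `split_k`, `three_k` and `two_k` are names made up for this illustration), K can be partitioned into full TL2 blocks plus a TL1-handled remainder, which is exactly what the constraints listed next make precise:

```python
# Hypothetical illustration of the K split described above (not taken from
# utils/codegen_tl2.py): full BK-sized blocks are computed with TL2, the
# remaining columns fall back to TL1.
def split_k(K: int, BK: int):
    assert BK % 6 == 0, "TL2 requires BK % 6 == 0"
    three_k = (K // BK) * BK   # columns covered by whole TL2 compute blocks
    two_k = K - three_k        # leftover columns handled TL1-style
    assert two_k % 32 == 0, "i.e. K % BK % 32 == 0 (see the list below)"
    return three_k, two_k

# bitnet_b1_58-large: K = 4096 with --BK 96 gives (4032, 64)
print(split_k(4096, 96))
```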
44 | 45 | Thus, we needs to make sure 46 | - M % BM == 0 47 | - K % BK % 32 == 0 48 | - BM % bm == 0 49 | - bm choose in \[32\] -------------------------------------------------------------------------------- /include/ggml-bitnet.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #ifdef __ARM_NEON 7 | #include 8 | typedef float32_t bitnet_float_type; 9 | #else 10 | typedef float bitnet_float_type; 11 | #endif 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | struct bitnet_tensor_extra { 18 | int lut_scales_size; 19 | int BK; 20 | int n_tile_num; 21 | uint8_t * qweights; 22 | bitnet_float_type * scales; 23 | }; 24 | 25 | GGML_API void ggml_bitnet_init(void); 26 | GGML_API void ggml_bitnet_free(void); 27 | // src0->type == Q4_0/IQ2_XXS/IQ3_XXS 28 | // bitnet.cpp currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros) 29 | // If use i-quantization gguf models, the results will be wrong 30 | // TODO: add customized block types Q2_0/Q3_0 31 | GGML_API bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); 32 | GGML_API size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst); 33 | GGML_API void ggml_bitnet_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits); 34 | GGML_API void ggml_bitnet_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits); 35 | GGML_API void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor); 36 | GGML_API int ggml_bitnet_get_type_bits(enum ggml_type type); 37 | GGML_API void ggml_bitnet_set_n_threads(int n_threads); 38 | #if defined(GGML_BITNET_ARM_TL1) 39 | GGML_API void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C); 40 | GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT); 41 | #endif 42 | #if defined(GGML_BITNET_X86_TL2) 43 | GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C); 44 | GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT); 45 | #endif 46 | 47 | #ifdef __cplusplus 48 | } 49 | #endif 50 | -------------------------------------------------------------------------------- /media/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/media/benchmark.png -------------------------------------------------------------------------------- /media/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/media/demo.mp4 -------------------------------------------------------------------------------- /preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl1.h: -------------------------------------------------------------------------------- 1 | #if defined(GGML_BITNET_ARM_TL1) 2 | #include "ggml-bitnet.h" 3 | #define GGML_BITNET_MAX_NODES 8192 4 | static bool initialized = false; 5 | static 
bitnet_tensor_extra * bitnet_tensor_extras = nullptr; 6 | static size_t bitnet_tensor_extras_index = 0; 7 | static void * aligned_malloc(size_t size) {{ 8 | #if defined(_WIN32) 9 | return _aligned_malloc(size, 64); 10 | #else 11 | void * ptr = nullptr; 12 | posix_memalign(&ptr, 64, size); 13 | return ptr; 14 | #endif 15 | }} 16 | static void aligned_free(void * ptr) {{ 17 | #if defined(_WIN32) 18 | _aligned_free(ptr); 19 | #else 20 | free(ptr); 21 | #endif 22 | }} 23 | 24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ 25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 26 | bitnet_float_type* b = (bitnet_float_type*)b_; 27 | #ifdef __ARM_NEON 28 | float32x4_t temp_max = vdupq_n_f32(0); 29 | for (int i=0; i < k / 4; i++) {{ 30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i); 31 | float32x4_t abssum = vabsq_f32(vec_bs); 32 | temp_max = vmaxq_f32(abssum, temp_max); 33 | }} 34 | float32_t scales = 127 / vmaxvq_f32(temp_max); 35 | *lut_scales = scales; 36 | #elif defined __AVX2__ 37 | __m256 max_vec = _mm256_set1_ps(0.f); 38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f); 39 | // #pragma unroll 40 | for (int i = 0; i < k / 8; i++) {{ 41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8); 42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); 43 | max_vec = _mm256_max_ps(vec_babs, max_vec); 44 | }} 45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); 46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); 47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); 48 | float scales = 127 / _mm_cvtss_f32(max1); 49 | *lut_scales = scales; 50 | #endif 51 | }} 52 | 53 | void partial_max_reset(void* lut_scales_) {{ 54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 55 | *lut_scales = 0.0; 56 | }} 57 | 58 | #ifdef __ARM_NEON 59 | inline void Transpose_8_8( 60 | int16x8_t *v0, 61 | int16x8_t *v1, 62 | int16x8_t *v2, 63 | int16x8_t *v3, 64 | int16x8_t *v4, 65 | int16x8_t *v5, 66 | int16x8_t *v6, 67 | int16x8_t *v7) 68 | {{ 69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4); 70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5); 71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6); 72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7); 73 | 74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); 75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); 76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); 77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); 78 | 79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); 80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); 81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); 82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); 83 | 84 | *v0 = q_fin_0.val[0]; 85 | *v1 = q_fin_0.val[1]; 86 | *v2 = q_fin_1.val[0]; 87 | *v3 = q_fin_1.val[1]; 88 | *v4 = q_fin_2.val[0]; 89 | *v5 = q_fin_2.val[1]; 90 | *v6 = q_fin_3.val[0]; 91 | *v7 = q_fin_3.val[1]; 92 | }} 93 | #endif 94 | 95 | template 96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ 97 | #ifdef __ARM_NEON 98 | int16x8_t vec_lut[16]; 99 | float32_t scales = *lut_scales; 100 | uint8_t tbl_mask[16]; 101 | tbl_mask[0] = 0; 102 | tbl_mask[1] = 2; 103 | tbl_mask[2] = 4; 104 | tbl_mask[3] = 6; 105 | tbl_mask[4] = 8; 106 | tbl_mask[5] = 10; 107 | tbl_mask[6] = 12; 108 | tbl_mask[7] = 14; 109 | tbl_mask[8] = 1; 110 | tbl_mask[9] = 3; 111 | tbl_mask[10] = 5; 112 | tbl_mask[11] = 7; 113 | tbl_mask[12] = 9; 114 | tbl_mask[13] = 11; 115 | 
tbl_mask[14] = 13; 116 | tbl_mask[15] = 15; 117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); 118 | #pragma unroll 119 | for (int k = 0; k < act_k / 16; ++k) {{ 120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); 121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); 122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); 123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); 124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); 125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); 126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); 127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); 128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); 129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); 130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); 131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); 132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); 133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); 134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); 135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); 136 | vec_lut[0] = vdupq_n_s16(0); 137 | vec_lut[0] = vec_lut[0] - vec_bs_0; 138 | vec_lut[0] = vec_lut[0] - vec_bs_1; 139 | vec_lut[1] = vdupq_n_s16(0); 140 | vec_lut[1] = vec_lut[1] - vec_bs_0; 141 | vec_lut[2] = vdupq_n_s16(0); 142 | vec_lut[2] = vec_lut[2] - vec_bs_0; 143 | vec_lut[2] = vec_lut[2] + vec_bs_1; 144 | vec_lut[3] = vdupq_n_s16(0); 145 | vec_lut[3] = vec_lut[3] - vec_bs_1; 146 | vec_lut[4] = vdupq_n_s16(0); 147 | vec_lut[5] = vec_bs_1; 148 | vec_lut[6] = vec_bs_0; 149 | vec_lut[6] = vec_lut[6] - vec_bs_1; 150 | vec_lut[7] = vec_bs_0; 151 | vec_lut[8] = vec_bs_0; 152 | vec_lut[8] = vec_lut[8] + vec_bs_1; 153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), 154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); 155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), 156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); 157 | #pragma unroll 158 | for (int idx = 0; idx < 8; idx++) {{ 159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); 160 | int8x8_t q0_low = vget_low_s8(q0_s); 161 | int8x8_t q0_high = vget_high_s8(q0_s); 162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); 163 | int8x8_t q1_low = vget_low_s8(q1_s); 164 | int8x8_t q1_high = vget_high_s8(q1_s); 165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); 166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); 167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); 168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); 169 | }} 170 | }} 171 | #endif 172 | }} 173 | 174 | static bool is_type_supported(enum ggml_type type) {{ 175 | if (type == GGML_TYPE_Q4_0 || 176 | type == GGML_TYPE_TL1) {{ 177 | return true; 178 | }} else {{ 179 | return false; 180 | }} 181 | }} 182 | #include 183 | 184 | #define BM14336_4096 256 185 | #define BBK14336_4096 128 186 | inline void tbl_impl_14336_4096(int32_t* c, int8_t* lut, uint8_t* a) { 187 | #ifdef __ARM_NEON 188 | const int KK = BBK14336_4096 / 2; 189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 191 | int8x16_t vec_lut[2 * KK]; 192 | int16x8_t vec_c[8]; 193 | #pragma unroll 194 | for (int k = 0; k < 2 * KK; k++) { 195 | vec_lut[k] = vld1q_s8(lut + k * 16); 196 | } 197 | 198 | #pragma unroll 199 | for (int i = 0; i < BM14336_4096; i += 64) { 200 | #pragma unroll 201 | for (int i=0; i<8; i++) { 202 | 
vec_c[i] = vandq_s16(vec_c[i], vec_zero); 203 | } 204 | 205 | #pragma unroll 206 | for (int k = 0; k < KK / 2; k++) { 207 | 208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); 212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); 213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); 214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); 215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 217 | vec_c[0] += vec_v_left_0.val[0]; 218 | vec_c[0] += vec_v_right_0.val[0]; 219 | vec_c[1] += vec_v_left_0.val[1]; 220 | vec_c[1] += vec_v_right_0.val[1]; 221 | 222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); 226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); 227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); 228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); 229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 231 | vec_c[2] += vec_v_left_1.val[0]; 232 | vec_c[2] += vec_v_right_1.val[0]; 233 | vec_c[3] += vec_v_left_1.val[1]; 234 | vec_c[3] += vec_v_right_1.val[1]; 235 | 236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); 240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); 241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); 242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); 243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 245 | vec_c[4] += vec_v_left_2.val[0]; 246 | vec_c[4] += vec_v_right_2.val[0]; 247 | vec_c[5] += vec_v_left_2.val[1]; 248 | vec_c[5] += vec_v_right_2.val[1]; 249 | 250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); 254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); 255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); 256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); 257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 259 | vec_c[6] += vec_v_left_3.val[0]; 260 | vec_c[6] += vec_v_right_3.val[0]; 261 | vec_c[7] += vec_v_left_3.val[1]; 262 | vec_c[7] += vec_v_right_3.val[1]; 263 | 264 | } 265 | 266 | int32x4_t vec_v_bot_low_low_0 = 
vmovl_s16(vget_low_s16(vec_c[0])); 267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 282 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); 283 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); 284 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); 285 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); 286 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); 287 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); 288 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); 289 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); 290 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); 291 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); 292 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); 293 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); 294 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); 295 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); 296 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); 297 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); 298 | 299 | } 300 | #endif 301 | } 302 | 303 | int32_t qgemm_lut_14336_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 304 | alignas(32) uint32_t CBits[BM14336_4096]; 305 | memset(&(CBits[0]), 0, BM14336_4096 * sizeof(int32_t)); 306 | #pragma unroll 307 | for (int32_t k_outer = 0; k_outer < 4096 / BBK14336_4096; ++k_outer) { 308 | tbl_impl_14336_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK14336_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK14336_4096 / 2 / 2 * BM14336_4096)]))); 309 | } 310 | #pragma unroll 311 | for (int i = 0; i < BM14336_4096; i++) { 312 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 313 | } 314 | return 0; 315 | }; 316 | #include 317 | 318 | #define BM4096_14336 256 319 | #define BBK4096_14336 128 320 | inline void tbl_impl_4096_14336(int32_t* c, int8_t* lut, uint8_t* a) { 321 | #ifdef __ARM_NEON 322 | const int KK = BBK4096_14336 / 2; 323 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 324 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 325 | int8x16_t vec_lut[2 * KK]; 326 | int16x8_t vec_c[4]; 327 | #pragma unroll 328 | for (int k = 0; k < 2 * KK; k++) { 329 | vec_lut[k] = vld1q_s8(lut + k * 16); 330 | } 331 | 332 | #pragma unroll 333 
| for (int i = 0; i < BM4096_14336; i += 32) { 334 | #pragma unroll 335 | for (int i=0; i<4; i++) { 336 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 337 | } 338 | 339 | #pragma unroll 340 | for (int k = 0; k < KK / 4; k++) { 341 | 342 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 343 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 344 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 345 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 346 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 347 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 348 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 349 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 350 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 351 | vec_c[0] += vec_v_left_0.val[0]; 352 | vec_c[0] += vec_v_right_0.val[0]; 353 | vec_c[1] += vec_v_left_0.val[1]; 354 | vec_c[1] += vec_v_right_0.val[1]; 355 | 356 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 357 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 358 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 359 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 360 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 361 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 362 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 363 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 364 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 365 | vec_c[0] += vec_v_left_1.val[0]; 366 | vec_c[0] += vec_v_right_1.val[0]; 367 | vec_c[1] += vec_v_left_1.val[1]; 368 | vec_c[1] += vec_v_right_1.val[1]; 369 | 370 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 371 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 372 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 373 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 374 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 375 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 376 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 377 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 378 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 379 | vec_c[2] += vec_v_left_2.val[0]; 380 | vec_c[2] += vec_v_right_2.val[0]; 381 | vec_c[3] += vec_v_left_2.val[1]; 382 | vec_c[3] += vec_v_right_2.val[1]; 383 | 384 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 385 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 386 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 387 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 388 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 389 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 390 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 391 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 392 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 393 | vec_c[2] += vec_v_left_3.val[0]; 394 | vec_c[2] += vec_v_right_3.val[0]; 395 | vec_c[3] += vec_v_left_3.val[1]; 396 | vec_c[3] 
+= vec_v_right_3.val[1]; 397 | 398 | } 399 | 400 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 401 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 402 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 403 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 404 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 405 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 406 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 407 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 408 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 409 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 410 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 411 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 412 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 413 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 414 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 415 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 416 | 417 | } 418 | #endif 419 | } 420 | 421 | int32_t qgemm_lut_4096_14336(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 422 | alignas(32) uint32_t CBits[BM4096_14336]; 423 | memset(&(CBits[0]), 0, BM4096_14336 * sizeof(int32_t)); 424 | #pragma unroll 425 | for (int32_t k_outer = 0; k_outer < 14336 / BBK4096_14336; ++k_outer) { 426 | tbl_impl_4096_14336((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_14336 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_14336 / 2 / 2 * BM4096_14336)]))); 427 | } 428 | #pragma unroll 429 | for (int i = 0; i < BM4096_14336; i++) { 430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 431 | } 432 | return 0; 433 | }; 434 | #include 435 | 436 | #define BM1024_4096 128 437 | #define BBK1024_4096 64 438 | inline void tbl_impl_1024_4096(int32_t* c, int8_t* lut, uint8_t* a) { 439 | #ifdef __ARM_NEON 440 | const int KK = BBK1024_4096 / 2; 441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 443 | int8x16_t vec_lut[2 * KK]; 444 | int16x8_t vec_c[8]; 445 | #pragma unroll 446 | for (int k = 0; k < 2 * KK; k++) { 447 | vec_lut[k] = vld1q_s8(lut + k * 16); 448 | } 449 | 450 | #pragma unroll 451 | for (int i = 0; i < BM1024_4096; i += 64) { 452 | #pragma unroll 453 | for (int i=0; i<8; i++) { 454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 455 | } 456 | 457 | #pragma unroll 458 | for (int k = 0; k < KK / 2; k++) { 459 | 460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); 464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); 465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); 466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); 467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 469 | vec_c[0] += vec_v_left_0.val[0]; 470 | vec_c[0] += vec_v_right_0.val[0]; 471 | vec_c[1] += vec_v_left_0.val[1]; 472 | vec_c[1] += vec_v_right_0.val[1]; 473 | 
474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); 478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); 479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); 480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); 481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 483 | vec_c[2] += vec_v_left_1.val[0]; 484 | vec_c[2] += vec_v_right_1.val[0]; 485 | vec_c[3] += vec_v_left_1.val[1]; 486 | vec_c[3] += vec_v_right_1.val[1]; 487 | 488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); 492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); 493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); 494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); 495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 497 | vec_c[4] += vec_v_left_2.val[0]; 498 | vec_c[4] += vec_v_right_2.val[0]; 499 | vec_c[5] += vec_v_left_2.val[1]; 500 | vec_c[5] += vec_v_right_2.val[1]; 501 | 502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); 506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); 507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); 508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); 509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 511 | vec_c[6] += vec_v_left_3.val[0]; 512 | vec_c[6] += vec_v_right_3.val[0]; 513 | vec_c[7] += vec_v_left_3.val[1]; 514 | vec_c[7] += vec_v_right_3.val[1]; 515 | 516 | } 517 | 518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 532 | vst1q_s32(c + i + 24, 
vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 534 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); 535 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); 536 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); 537 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); 538 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); 539 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); 540 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); 541 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); 542 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); 543 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); 544 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); 545 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); 546 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); 547 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); 548 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); 549 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); 550 | 551 | } 552 | #endif 553 | } 554 | 555 | int32_t qgemm_lut_1024_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 556 | alignas(32) uint32_t CBits[BM1024_4096]; 557 | memset(&(CBits[0]), 0, BM1024_4096 * sizeof(int32_t)); 558 | #pragma unroll 559 | for (int32_t k_outer = 0; k_outer < 4096 / BBK1024_4096; ++k_outer) { 560 | tbl_impl_1024_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1024_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1024_4096 / 2 / 2 * BM1024_4096)]))); 561 | } 562 | #pragma unroll 563 | for (int i = 0; i < BM1024_4096; i++) { 564 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 565 | } 566 | return 0; 567 | }; 568 | #include 569 | 570 | #define BM4096_4096 128 571 | #define BBK4096_4096 64 572 | inline void tbl_impl_4096_4096(int32_t* c, int8_t* lut, uint8_t* a) { 573 | #ifdef __ARM_NEON 574 | const int KK = BBK4096_4096 / 2; 575 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 576 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 577 | int8x16_t vec_lut[2 * KK]; 578 | int16x8_t vec_c[4]; 579 | #pragma unroll 580 | for (int k = 0; k < 2 * KK; k++) { 581 | vec_lut[k] = vld1q_s8(lut + k * 16); 582 | } 583 | 584 | #pragma unroll 585 | for (int i = 0; i < BM4096_4096; i += 32) { 586 | #pragma unroll 587 | for (int i=0; i<4; i++) { 588 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 589 | } 590 | 591 | #pragma unroll 592 | for (int k = 0; k < KK / 4; k++) { 593 | 594 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 595 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 596 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 597 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 598 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 599 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 600 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 601 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 602 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 603 | vec_c[0] += vec_v_left_0.val[0]; 604 | vec_c[0] += vec_v_right_0.val[0]; 605 | vec_c[1] 
+= vec_v_left_0.val[1]; 606 | vec_c[1] += vec_v_right_0.val[1]; 607 | 608 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 609 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 610 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 611 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 612 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 613 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 614 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 615 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 616 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 617 | vec_c[0] += vec_v_left_1.val[0]; 618 | vec_c[0] += vec_v_right_1.val[0]; 619 | vec_c[1] += vec_v_left_1.val[1]; 620 | vec_c[1] += vec_v_right_1.val[1]; 621 | 622 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 623 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 624 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 625 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 626 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 627 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 628 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 629 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 630 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 631 | vec_c[2] += vec_v_left_2.val[0]; 632 | vec_c[2] += vec_v_right_2.val[0]; 633 | vec_c[3] += vec_v_left_2.val[1]; 634 | vec_c[3] += vec_v_right_2.val[1]; 635 | 636 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 637 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 638 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 639 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 640 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 641 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 642 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 643 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 644 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 645 | vec_c[2] += vec_v_left_3.val[0]; 646 | vec_c[2] += vec_v_right_3.val[0]; 647 | vec_c[3] += vec_v_left_3.val[1]; 648 | vec_c[3] += vec_v_right_3.val[1]; 649 | 650 | } 651 | 652 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 653 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 654 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 655 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 656 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 657 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 658 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 659 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 660 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 661 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 662 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 663 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 664 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 665 | int32x4_t 
vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 666 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 667 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 668 | 669 | } 670 | #endif 671 | } 672 | 673 | int32_t qgemm_lut_4096_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 674 | alignas(32) uint32_t CBits[BM4096_4096]; 675 | memset(&(CBits[0]), 0, BM4096_4096 * sizeof(int32_t)); 676 | #pragma unroll 677 | for (int32_t k_outer = 0; k_outer < 4096 / BBK4096_4096; ++k_outer) { 678 | tbl_impl_4096_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_4096 / 2 / 2 * BM4096_4096)]))); 679 | } 680 | #pragma unroll 681 | for (int i = 0; i < BM4096_4096; i++) { 682 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 683 | } 684 | return 0; 685 | }; 686 | 687 | template 688 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ 689 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); 690 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); 691 | 692 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); 693 | }} 694 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { 695 | if (m == 14336 && k == 4096) { 696 | preprocessor_k<4096>(B, LUT_Scales, QLUT); 697 | } 698 | else if (m == 4096 && k == 14336) { 699 | preprocessor_k<14336>(B, LUT_Scales, QLUT); 700 | } 701 | else if (m == 1024 && k == 4096) { 702 | preprocessor_k<4096>(B, LUT_Scales, QLUT); 703 | } 704 | else if (m == 4096 && k == 4096) { 705 | preprocessor_k<4096>(B, LUT_Scales, QLUT); 706 | } 707 | } 708 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 709 | if (m == 14336 && k == 4096) { 710 | qgemm_lut_14336_4096(A, LUT, Scales, LUT_Scales, C); 711 | } 712 | else if (m == 4096 && k == 14336) { 713 | qgemm_lut_4096_14336(A, LUT, Scales, LUT_Scales, C); 714 | } 715 | else if (m == 1024 && k == 4096) { 716 | qgemm_lut_1024_4096(A, LUT, Scales, LUT_Scales, C); 717 | } 718 | else if (m == 4096 && k == 4096) { 719 | qgemm_lut_4096_4096(A, LUT, Scales, LUT_Scales, C); 720 | } 721 | } 722 | 723 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { 724 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { 725 | return; 726 | } 727 | 728 | int k = tensor->ne[0]; 729 | int m = tensor->ne[1]; 730 | const int lut_scales_size = 1; 731 | const int scales_size = 1; 732 | int bk = 0; 733 | int bm = 0; 734 | 735 | if (m == 14336 && k == 4096) { 736 | bm = BM14336_4096; 737 | bk = BBK14336_4096; 738 | } 739 | else if (m == 4096 && k == 14336) { 740 | bm = BM4096_14336; 741 | bk = BBK4096_14336; 742 | } 743 | else if (m == 1024 && k == 4096) { 744 | bm = BM1024_4096; 745 | bk = BBK1024_4096; 746 | } 747 | else if (m == 4096 && k == 4096) { 748 | bm = BM4096_4096; 749 | bk = BBK4096_4096; 750 | } 751 | 752 | const int n_tile_num = m / bm; 753 | const int BK = bk; 754 | uint8_t * qweights; 755 | bitnet_float_type * scales; 756 | 757 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); 758 | qweights = (uint8_t *) tensor->data; 759 | float * i2_scales = (float * )(qweights + k * m / 4); 760 | scales[0] = (bitnet_float_type) i2_scales[0]; 761 | 762 | tensor->extra = 
bitnet_tensor_extras + bitnet_tensor_extras_index; 763 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = { 764 | /* .lut_scales_size = */ lut_scales_size, 765 | /* .scales_size = */ scales_size, 766 | /* .n_tile_num = */ n_tile_num, 767 | /* .qweights = */ qweights, 768 | /* .scales = */ scales 769 | }; 770 | } 771 | #endif -------------------------------------------------------------------------------- /preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 14336 3 | k = 4096 4 | bm = 256 5 | bk = 128 6 | bmm = 64 7 | 8 | [Kernels_1] 9 | m = 4096 10 | k = 14336 11 | bm = 256 12 | bk = 128 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 1024 17 | k = 4096 18 | bm = 128 19 | bk = 64 20 | bmm = 64 21 | 22 | [Kernels_3] 23 | m = 4096 24 | k = 4096 25 | bm = 128 26 | bk = 64 27 | bmm = 32 28 | 29 | -------------------------------------------------------------------------------- /preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 14336 3 | k = 4096 4 | bm = 256 5 | bk = 96 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 4096 10 | k = 14336 11 | bm = 128 12 | bk = 96 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 1024 17 | k = 4096 18 | bm = 256 19 | bk = 96 20 | bmm = 32 21 | 22 | [Kernels_3] 23 | m = 4096 24 | k = 4096 25 | bm = 128 26 | bk = 96 27 | bmm = 32 28 | 29 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h: -------------------------------------------------------------------------------- 1 | #if defined(GGML_BITNET_ARM_TL1) 2 | #include "ggml-bitnet.h" 3 | #define GGML_BITNET_MAX_NODES 8192 4 | static bool initialized = false; 5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; 6 | static size_t bitnet_tensor_extras_index = 0; 7 | static void * aligned_malloc(size_t size) {{ 8 | #if defined(_WIN32) 9 | return _aligned_malloc(size, 64); 10 | #else 11 | void * ptr = nullptr; 12 | posix_memalign(&ptr, 64, size); 13 | return ptr; 14 | #endif 15 | }} 16 | static void aligned_free(void * ptr) {{ 17 | #if defined(_WIN32) 18 | _aligned_free(ptr); 19 | #else 20 | free(ptr); 21 | #endif 22 | }} 23 | 24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ 25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 26 | bitnet_float_type* b = (bitnet_float_type*)b_; 27 | #ifdef __ARM_NEON 28 | float32x4_t temp_max = vdupq_n_f32(0); 29 | for (int i=0; i < k / 4; i++) {{ 30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i); 31 | float32x4_t abssum = vabsq_f32(vec_bs); 32 | temp_max = vmaxq_f32(abssum, temp_max); 33 | }} 34 | float32_t scales = 127 / vmaxvq_f32(temp_max); 35 | *lut_scales = scales; 36 | #elif defined __AVX2__ 37 | __m256 max_vec = _mm256_set1_ps(0.f); 38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f); 39 | // #pragma unroll 40 | for (int i = 0; i < k / 8; i++) {{ 41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8); 42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); 43 | max_vec = _mm256_max_ps(vec_babs, max_vec); 44 | }} 45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); 46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); 47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); 48 | float scales = 127 / _mm_cvtss_f32(max1); 49 | *lut_scales = scales; 50 | #endif 51 | }} 52 | 53 | void 
partial_max_reset(void* lut_scales_) {{ 54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 55 | *lut_scales = 0.0; 56 | }} 57 | 58 | #ifdef __ARM_NEON 59 | inline void Transpose_8_8( 60 | int16x8_t *v0, 61 | int16x8_t *v1, 62 | int16x8_t *v2, 63 | int16x8_t *v3, 64 | int16x8_t *v4, 65 | int16x8_t *v5, 66 | int16x8_t *v6, 67 | int16x8_t *v7) 68 | {{ 69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4); 70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5); 71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6); 72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7); 73 | 74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); 75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); 76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); 77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); 78 | 79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); 80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); 81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); 82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); 83 | 84 | *v0 = q_fin_0.val[0]; 85 | *v1 = q_fin_0.val[1]; 86 | *v2 = q_fin_1.val[0]; 87 | *v3 = q_fin_1.val[1]; 88 | *v4 = q_fin_2.val[0]; 89 | *v5 = q_fin_2.val[1]; 90 | *v6 = q_fin_3.val[0]; 91 | *v7 = q_fin_3.val[1]; 92 | }} 93 | #endif 94 | 95 | template 96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ 97 | #ifdef __ARM_NEON 98 | int16x8_t vec_lut[16]; 99 | float32_t scales = *lut_scales; 100 | uint8_t tbl_mask[16]; 101 | tbl_mask[0] = 0; 102 | tbl_mask[1] = 2; 103 | tbl_mask[2] = 4; 104 | tbl_mask[3] = 6; 105 | tbl_mask[4] = 8; 106 | tbl_mask[5] = 10; 107 | tbl_mask[6] = 12; 108 | tbl_mask[7] = 14; 109 | tbl_mask[8] = 1; 110 | tbl_mask[9] = 3; 111 | tbl_mask[10] = 5; 112 | tbl_mask[11] = 7; 113 | tbl_mask[12] = 9; 114 | tbl_mask[13] = 11; 115 | tbl_mask[14] = 13; 116 | tbl_mask[15] = 15; 117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); 118 | #pragma unroll 119 | for (int k = 0; k < act_k / 16; ++k) {{ 120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); 121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); 122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); 123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales); 124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); 125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); 126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); 127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); 128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); 129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); 130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); 131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); 132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); 133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); 134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); 135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); 136 | vec_lut[0] = vdupq_n_s16(0); 137 | vec_lut[0] = vec_lut[0] - vec_bs_0; 138 | vec_lut[0] = vec_lut[0] - vec_bs_1; 139 | vec_lut[1] = vdupq_n_s16(0); 140 | vec_lut[1] = vec_lut[1] - vec_bs_0; 141 | vec_lut[2] = vdupq_n_s16(0); 142 | vec_lut[2] = vec_lut[2] - vec_bs_0; 143 | vec_lut[2] = vec_lut[2] + vec_bs_1; 144 | vec_lut[3] = vdupq_n_s16(0); 145 | vec_lut[3] = vec_lut[3] - vec_bs_1; 146 | vec_lut[4] = vdupq_n_s16(0); 147 | vec_lut[5] = vec_bs_1; 148 | vec_lut[6] = vec_bs_0; 149 | vec_lut[6] = vec_lut[6] - vec_bs_1; 150 | vec_lut[7] = vec_bs_0; 151 | vec_lut[8] = vec_bs_0; 152 | 
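/* vec_lut[0..8] enumerate every sum a pair of ternary weights (w0, w1) in {-1, 0, +1}
   can contribute for the current activation pair (vec_bs_0, vec_bs_1):
   -b0-b1, -b0, -b0+b1, -b1, 0, +b1, b0-b1, b0, b0+b1.
   The tbl_impl_* kernels below then replace multiply-accumulate with vqtbl1q_s8
   lookups into this 9-entry table, indexed by 4-bit codes that each pack two
   ternary weights. */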
vec_lut[8] = vec_lut[8] + vec_bs_1; 153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), 154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); 155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), 156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); 157 | #pragma unroll 158 | for (int idx = 0; idx < 8; idx++) {{ 159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); 160 | int8x8_t q0_low = vget_low_s8(q0_s); 161 | int8x8_t q0_high = vget_high_s8(q0_s); 162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); 163 | int8x8_t q1_low = vget_low_s8(q1_s); 164 | int8x8_t q1_high = vget_high_s8(q1_s); 165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); 166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); 167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); 168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); 169 | }} 170 | }} 171 | #endif 172 | }} 173 | 174 | static bool is_type_supported(enum ggml_type type) {{ 175 | if (type == GGML_TYPE_Q4_0 || 176 | type == GGML_TYPE_TL1) {{ 177 | return true; 178 | }} else {{ 179 | return false; 180 | }} 181 | }} 182 | #include 183 | 184 | #define BM3200_8640 160 185 | #define BBK3200_8640 64 186 | inline void tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a) { 187 | #ifdef __ARM_NEON 188 | const int KK = BBK3200_8640 / 2; 189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 191 | int8x16_t vec_lut[2 * KK]; 192 | int16x8_t vec_c[4]; 193 | #pragma unroll 194 | for (int k = 0; k < 2 * KK; k++) { 195 | vec_lut[k] = vld1q_s8(lut + k * 16); 196 | } 197 | 198 | #pragma unroll 199 | for (int i = 0; i < BM3200_8640; i += 32) { 200 | #pragma unroll 201 | for (int i=0; i<4; i++) { 202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 203 | } 204 | 205 | #pragma unroll 206 | for (int k = 0; k < KK / 4; k++) { 207 | 208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 217 | vec_c[0] += vec_v_left_0.val[0]; 218 | vec_c[0] += vec_v_right_0.val[0]; 219 | vec_c[1] += vec_v_left_0.val[1]; 220 | vec_c[1] += vec_v_right_0.val[1]; 221 | 222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 231 | vec_c[0] += 
vec_v_left_1.val[0]; 232 | vec_c[0] += vec_v_right_1.val[0]; 233 | vec_c[1] += vec_v_left_1.val[1]; 234 | vec_c[1] += vec_v_right_1.val[1]; 235 | 236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 245 | vec_c[2] += vec_v_left_2.val[0]; 246 | vec_c[2] += vec_v_right_2.val[0]; 247 | vec_c[3] += vec_v_left_2.val[1]; 248 | vec_c[3] += vec_v_right_2.val[1]; 249 | 250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 259 | vec_c[2] += vec_v_left_3.val[0]; 260 | vec_c[2] += vec_v_right_3.val[0]; 261 | vec_c[3] += vec_v_left_3.val[1]; 262 | vec_c[3] += vec_v_right_3.val[1]; 263 | 264 | } 265 | 266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 282 | 283 | } 284 | #endif 285 | } 286 | 287 | int32_t qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 288 | alignas(32) uint32_t CBits[BM3200_8640]; 289 | memset(&(CBits[0]), 0, BM3200_8640 * sizeof(int32_t)); 290 | #pragma unroll 291 | for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) { 292 | tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 2 / 2 * BM3200_8640)]))); 293 | } 294 | #pragma unroll 295 | for (int i = 0; i < 
BM3200_8640; i++) { 296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 297 | } 298 | return 0; 299 | }; 300 | #include 301 | 302 | #define BM3200_3200 320 303 | #define BBK3200_3200 128 304 | inline void tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a) { 305 | #ifdef __ARM_NEON 306 | const int KK = BBK3200_3200 / 2; 307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 309 | int8x16_t vec_lut[2 * KK]; 310 | int16x8_t vec_c[8]; 311 | #pragma unroll 312 | for (int k = 0; k < 2 * KK; k++) { 313 | vec_lut[k] = vld1q_s8(lut + k * 16); 314 | } 315 | 316 | #pragma unroll 317 | for (int i = 0; i < BM3200_3200; i += 64) { 318 | #pragma unroll 319 | for (int i=0; i<8; i++) { 320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 321 | } 322 | 323 | #pragma unroll 324 | for (int k = 0; k < KK / 2; k++) { 325 | 326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); 330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); 331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); 332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); 333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 335 | vec_c[0] += vec_v_left_0.val[0]; 336 | vec_c[0] += vec_v_right_0.val[0]; 337 | vec_c[1] += vec_v_left_0.val[1]; 338 | vec_c[1] += vec_v_right_0.val[1]; 339 | 340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); 344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); 345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); 346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); 347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 349 | vec_c[2] += vec_v_left_1.val[0]; 350 | vec_c[2] += vec_v_right_1.val[0]; 351 | vec_c[3] += vec_v_left_1.val[1]; 352 | vec_c[3] += vec_v_right_1.val[1]; 353 | 354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); 358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); 359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); 360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); 361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 363 | vec_c[4] += vec_v_left_2.val[0]; 364 | vec_c[4] += vec_v_right_2.val[0]; 365 | vec_c[5] += vec_v_left_2.val[1]; 366 | vec_c[5] += vec_v_right_2.val[1]; 367 | 368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 369 | uint8x16_t 
vec_a3_top = vshrq_n_u8(vec_a_3, 4); 370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); 372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); 373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); 374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); 375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 377 | vec_c[6] += vec_v_left_3.val[0]; 378 | vec_c[6] += vec_v_right_3.val[0]; 379 | vec_c[7] += vec_v_left_3.val[1]; 380 | vec_c[7] += vec_v_right_3.val[1]; 381 | 382 | } 383 | 384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); 401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); 402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); 403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); 404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); 405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]); 406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); 407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); 408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); 409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); 410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); 411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); 412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); 413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); 414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); 415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); 416 | 417 | } 418 | #endif 419 | } 420 | 421 | int32_t qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 422 | alignas(32) uint32_t CBits[BM3200_3200]; 423 | memset(&(CBits[0]), 0, BM3200_3200 * sizeof(int32_t)); 424 | #pragma unroll 425 | for (int32_t k_outer = 0; k_outer < 3200 / BBK3200_3200; ++k_outer) { 426 | tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 2 / 2 * 
BM3200_3200)]))); 427 | } 428 | #pragma unroll 429 | for (int i = 0; i < BM3200_3200; i++) { 430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 431 | } 432 | return 0; 433 | }; 434 | #include 435 | 436 | #define BM8640_3200 320 437 | #define BBK8640_3200 64 438 | inline void tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a) { 439 | #ifdef __ARM_NEON 440 | const int KK = BBK8640_3200 / 2; 441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 443 | int8x16_t vec_lut[2 * KK]; 444 | int16x8_t vec_c[4]; 445 | #pragma unroll 446 | for (int k = 0; k < 2 * KK; k++) { 447 | vec_lut[k] = vld1q_s8(lut + k * 16); 448 | } 449 | 450 | #pragma unroll 451 | for (int i = 0; i < BM8640_3200; i += 32) { 452 | #pragma unroll 453 | for (int i=0; i<4; i++) { 454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 455 | } 456 | 457 | #pragma unroll 458 | for (int k = 0; k < KK / 4; k++) { 459 | 460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 469 | vec_c[0] += vec_v_left_0.val[0]; 470 | vec_c[0] += vec_v_right_0.val[0]; 471 | vec_c[1] += vec_v_left_0.val[1]; 472 | vec_c[1] += vec_v_right_0.val[1]; 473 | 474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 483 | vec_c[0] += vec_v_left_1.val[0]; 484 | vec_c[0] += vec_v_right_1.val[0]; 485 | vec_c[1] += vec_v_left_1.val[1]; 486 | vec_c[1] += vec_v_right_1.val[1]; 487 | 488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 497 | vec_c[2] += vec_v_left_2.val[0]; 498 | vec_c[2] += vec_v_right_2.val[0]; 499 | vec_c[3] += vec_v_left_2.val[1]; 500 | vec_c[3] += vec_v_right_2.val[1]; 501 | 502 | uint8x16_t 
vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 511 | vec_c[2] += vec_v_left_3.val[0]; 512 | vec_c[2] += vec_v_right_3.val[0]; 513 | vec_c[3] += vec_v_left_3.val[1]; 514 | vec_c[3] += vec_v_right_3.val[1]; 515 | 516 | } 517 | 518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 534 | 535 | } 536 | #endif 537 | } 538 | 539 | int32_t qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 540 | alignas(32) uint32_t CBits[BM8640_3200]; 541 | memset(&(CBits[0]), 0, BM8640_3200 * sizeof(int32_t)); 542 | #pragma unroll 543 | for (int32_t k_outer = 0; k_outer < 3200 / BBK8640_3200; ++k_outer) { 544 | tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 2 / 2 * BM8640_3200)]))); 545 | } 546 | #pragma unroll 547 | for (int i = 0; i < BM8640_3200; i++) { 548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 549 | } 550 | return 0; 551 | }; 552 | 553 | template 554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ 555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); 556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); 557 | 558 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); 559 | }} 560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { 561 | if (m == 3200 && k == 8640) { 562 | preprocessor_k<8640>(B, LUT_Scales, QLUT); 563 | } 564 | else if (m == 3200 && k == 3200) { 565 | preprocessor_k<3200>(B, LUT_Scales, QLUT); 566 | } 567 | else if (m == 8640 && k == 3200) { 568 | preprocessor_k<3200>(B, LUT_Scales, QLUT); 569 | } 570 | } 571 | void ggml_qgemm_lut(int m, int k, void* 
A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 572 | if (m == 3200 && k == 8640) { 573 | qgemm_lut_3200_8640(A, LUT, Scales, LUT_Scales, C); 574 | } 575 | else if (m == 3200 && k == 3200) { 576 | qgemm_lut_3200_3200(A, LUT, Scales, LUT_Scales, C); 577 | } 578 | else if (m == 8640 && k == 3200) { 579 | qgemm_lut_8640_3200(A, LUT, Scales, LUT_Scales, C); 580 | } 581 | } 582 | 583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { 584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { 585 | return; 586 | } 587 | 588 | int k = tensor->ne[0]; 589 | int m = tensor->ne[1]; 590 | const int lut_scales_size = 1; 591 | const int scales_size = 1; 592 | int bk = 0; 593 | int bm = 0; 594 | 595 | if (m == 3200 && k == 8640) { 596 | bm = BM3200_8640; 597 | bk = BBK3200_8640; 598 | } 599 | else if (m == 3200 && k == 3200) { 600 | bm = BM3200_3200; 601 | bk = BBK3200_3200; 602 | } 603 | else if (m == 8640 && k == 3200) { 604 | bm = BM8640_3200; 605 | bk = BBK8640_3200; 606 | } 607 | 608 | const int n_tile_num = m / bm; 609 | const int BK = bk; 610 | uint8_t * qweights; 611 | bitnet_float_type * scales; 612 | 613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); 614 | qweights = (uint8_t *) tensor->data; 615 | float * i2_scales = (float * )(qweights + k * m / 4); 616 | scales[0] = (bitnet_float_type) i2_scales[0]; 617 | 618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; 619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = { 620 | /* .lut_scales_size = */ lut_scales_size, 621 | /* .scales_size = */ scales_size, 622 | /* .n_tile_num = */ n_tile_num, 623 | /* .qweights = */ qweights, 624 | /* .scales = */ scales 625 | }; 626 | } 627 | #endif -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 3200 3 | k = 8640 4 | bm = 160 5 | bk = 64 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 3200 10 | k = 3200 11 | bm = 320 12 | bk = 128 13 | bmm = 64 14 | 15 | [Kernels_2] 16 | m = 8640 17 | k = 3200 18 | bm = 320 19 | bk = 64 20 | bmm = 32 21 | 22 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 3200 3 | k = 8640 4 | bm = 160 5 | bk = 96 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 3200 10 | k = 3200 11 | bm = 320 12 | bk = 96 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 8640 17 | k = 3200 18 | bm = 320 19 | bk = 96 20 | bmm = 32 21 | 22 | -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h: -------------------------------------------------------------------------------- 1 | #if defined(GGML_BITNET_ARM_TL1) 2 | #include "ggml-bitnet.h" 3 | #define GGML_BITNET_MAX_NODES 8192 4 | static bool initialized = false; 5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr; 6 | static size_t bitnet_tensor_extras_index = 0; 7 | static void * aligned_malloc(size_t size) {{ 8 | #if defined(_WIN32) 9 | return _aligned_malloc(size, 64); 10 | #else 11 | void * ptr = nullptr; 12 | posix_memalign(&ptr, 64, size); 13 | return ptr; 14 | #endif 15 | }} 16 | static void aligned_free(void * ptr) {{ 17 | #if defined(_WIN32) 18 | 
_aligned_free(ptr); 19 | #else 20 | free(ptr); 21 | #endif 22 | }} 23 | 24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{ 25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 26 | bitnet_float_type* b = (bitnet_float_type*)b_; 27 | #ifdef __ARM_NEON 28 | float32x4_t temp_max = vdupq_n_f32(0); 29 | for (int i=0; i < k / 4; i++) {{ 30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i); 31 | float32x4_t abssum = vabsq_f32(vec_bs); 32 | temp_max = vmaxq_f32(abssum, temp_max); 33 | }} 34 | float32_t scales = 127 / vmaxvq_f32(temp_max); 35 | *lut_scales = scales; 36 | #elif defined __AVX2__ 37 | __m256 max_vec = _mm256_set1_ps(0.f); 38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f); 39 | // #pragma unroll 40 | for (int i = 0; i < k / 8; i++) {{ 41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8); 42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b); 43 | max_vec = _mm256_max_ps(vec_babs, max_vec); 44 | }} 45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec)); 46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1)); 47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1)); 48 | float scales = 127 / _mm_cvtss_f32(max1); 49 | *lut_scales = scales; 50 | #endif 51 | }} 52 | 53 | void partial_max_reset(void* lut_scales_) {{ 54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_; 55 | *lut_scales = 0.0; 56 | }} 57 | 58 | #ifdef __ARM_NEON 59 | inline void Transpose_8_8( 60 | int16x8_t *v0, 61 | int16x8_t *v1, 62 | int16x8_t *v2, 63 | int16x8_t *v3, 64 | int16x8_t *v4, 65 | int16x8_t *v5, 66 | int16x8_t *v6, 67 | int16x8_t *v7) 68 | {{ 69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4); 70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5); 71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6); 72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7); 73 | 74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]); 75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]); 76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]); 77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]); 78 | 79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]); 80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]); 81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]); 82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]); 83 | 84 | *v0 = q_fin_0.val[0]; 85 | *v1 = q_fin_0.val[1]; 86 | *v2 = q_fin_1.val[0]; 87 | *v3 = q_fin_1.val[1]; 88 | *v4 = q_fin_2.val[0]; 89 | *v5 = q_fin_2.val[1]; 90 | *v6 = q_fin_3.val[0]; 91 | *v7 = q_fin_3.val[1]; 92 | }} 93 | #endif 94 | 95 | template 96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{ 97 | #ifdef __ARM_NEON 98 | int16x8_t vec_lut[16]; 99 | float32_t scales = *lut_scales; 100 | uint8_t tbl_mask[16]; 101 | tbl_mask[0] = 0; 102 | tbl_mask[1] = 2; 103 | tbl_mask[2] = 4; 104 | tbl_mask[3] = 6; 105 | tbl_mask[4] = 8; 106 | tbl_mask[5] = 10; 107 | tbl_mask[6] = 12; 108 | tbl_mask[7] = 14; 109 | tbl_mask[8] = 1; 110 | tbl_mask[9] = 3; 111 | tbl_mask[10] = 5; 112 | tbl_mask[11] = 7; 113 | tbl_mask[12] = 9; 114 | tbl_mask[13] = 11; 115 | tbl_mask[14] = 13; 116 | tbl_mask[15] = 15; 117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask); 118 | #pragma unroll 119 | for (int k = 0; k < act_k / 16; ++k) {{ 120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16); 121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8); 122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales); 123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], 
scales); 124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales); 125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales); 126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0); 127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1); 128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2); 129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3); 130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0); 131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1); 132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2); 133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3); 134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2); 135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3); 136 | vec_lut[0] = vdupq_n_s16(0); 137 | vec_lut[0] = vec_lut[0] - vec_bs_0; 138 | vec_lut[0] = vec_lut[0] - vec_bs_1; 139 | vec_lut[1] = vdupq_n_s16(0); 140 | vec_lut[1] = vec_lut[1] - vec_bs_0; 141 | vec_lut[2] = vdupq_n_s16(0); 142 | vec_lut[2] = vec_lut[2] - vec_bs_0; 143 | vec_lut[2] = vec_lut[2] + vec_bs_1; 144 | vec_lut[3] = vdupq_n_s16(0); 145 | vec_lut[3] = vec_lut[3] - vec_bs_1; 146 | vec_lut[4] = vdupq_n_s16(0); 147 | vec_lut[5] = vec_bs_1; 148 | vec_lut[6] = vec_bs_0; 149 | vec_lut[6] = vec_lut[6] - vec_bs_1; 150 | vec_lut[7] = vec_bs_0; 151 | vec_lut[8] = vec_bs_0; 152 | vec_lut[8] = vec_lut[8] + vec_bs_1; 153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]), 154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7])); 155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]), 156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15])); 157 | #pragma unroll 158 | for (int idx = 0; idx < 8; idx++) {{ 159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q); 160 | int8x8_t q0_low = vget_low_s8(q0_s); 161 | int8x8_t q0_high = vget_high_s8(q0_s); 162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q); 163 | int8x8_t q1_low = vget_low_s8(q1_s); 164 | int8x8_t q1_high = vget_high_s8(q1_s); 165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high); 166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high); 167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low); 168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low); 169 | }} 170 | }} 171 | #endif 172 | }} 173 | 174 | static bool is_type_supported(enum ggml_type type) {{ 175 | if (type == GGML_TYPE_Q4_0 || 176 | type == GGML_TYPE_TL1) {{ 177 | return true; 178 | }} else {{ 179 | return false; 180 | }} 181 | }} 182 | #include 183 | 184 | #define BM1536_4096 256 185 | #define BBK1536_4096 128 186 | inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) { 187 | #ifdef __ARM_NEON 188 | const int KK = BBK1536_4096 / 2; 189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 191 | int8x16_t vec_lut[2 * KK]; 192 | int16x8_t vec_c[4]; 193 | #pragma unroll 194 | for (int k = 0; k < 2 * KK; k++) { 195 | vec_lut[k] = vld1q_s8(lut + k * 16); 196 | } 197 | 198 | #pragma unroll 199 | for (int i = 0; i < BM1536_4096; i += 32) { 200 | #pragma unroll 201 | for (int i=0; i<4; i++) { 202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 203 | } 204 | 205 | #pragma unroll 206 | for (int k = 0; k < KK / 4; k++) { 207 | 208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 212 | int8x16_t 
vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 217 | vec_c[0] += vec_v_left_0.val[0]; 218 | vec_c[0] += vec_v_right_0.val[0]; 219 | vec_c[1] += vec_v_left_0.val[1]; 220 | vec_c[1] += vec_v_right_0.val[1]; 221 | 222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 231 | vec_c[0] += vec_v_left_1.val[0]; 232 | vec_c[0] += vec_v_right_1.val[0]; 233 | vec_c[1] += vec_v_left_1.val[1]; 234 | vec_c[1] += vec_v_right_1.val[1]; 235 | 236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 245 | vec_c[2] += vec_v_left_2.val[0]; 246 | vec_c[2] += vec_v_right_2.val[0]; 247 | vec_c[3] += vec_v_left_2.val[1]; 248 | vec_c[3] += vec_v_right_2.val[1]; 249 | 250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 259 | vec_c[2] += vec_v_left_3.val[0]; 260 | vec_c[2] += vec_v_right_3.val[0]; 261 | vec_c[3] += vec_v_left_3.val[1]; 262 | vec_c[3] += vec_v_right_3.val[1]; 263 | 264 | } 265 | 266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) 
+ vec_v_bot_low_low_1); 273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 282 | 283 | } 284 | #endif 285 | } 286 | 287 | int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 288 | alignas(32) uint32_t CBits[BM1536_4096]; 289 | memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t)); 290 | #pragma unroll 291 | for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) { 292 | tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)]))); 293 | } 294 | #pragma unroll 295 | for (int i = 0; i < BM1536_4096; i++) { 296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 297 | } 298 | return 0; 299 | }; 300 | #include 301 | 302 | #define BM1536_1536 128 303 | #define BBK1536_1536 64 304 | inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) { 305 | #ifdef __ARM_NEON 306 | const int KK = BBK1536_1536 / 2; 307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 309 | int8x16_t vec_lut[2 * KK]; 310 | int16x8_t vec_c[8]; 311 | #pragma unroll 312 | for (int k = 0; k < 2 * KK; k++) { 313 | vec_lut[k] = vld1q_s8(lut + k * 16); 314 | } 315 | 316 | #pragma unroll 317 | for (int i = 0; i < BM1536_1536; i += 64) { 318 | #pragma unroll 319 | for (int i=0; i<8; i++) { 320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 321 | } 322 | 323 | #pragma unroll 324 | for (int k = 0; k < KK / 2; k++) { 325 | 326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top); 330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top); 331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot); 332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot); 333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 335 | vec_c[0] += vec_v_left_0.val[0]; 336 | vec_c[0] += vec_v_right_0.val[0]; 337 | vec_c[1] += vec_v_left_0.val[1]; 338 | vec_c[1] += vec_v_right_0.val[1]; 339 | 340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top); 344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top); 345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot); 346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot); 347 | 
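/* vzipq_s8 below re-interleaves the low- and high-byte planes of the 16-bit LUT
   entries (stored separately by lut_ctor) so they can be accumulated into the
   int16 partial sums vec_c[]. */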
int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 349 | vec_c[2] += vec_v_left_1.val[0]; 350 | vec_c[2] += vec_v_right_1.val[0]; 351 | vec_c[3] += vec_v_left_1.val[1]; 352 | vec_c[3] += vec_v_right_1.val[1]; 353 | 354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top); 358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top); 359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot); 360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot); 361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 363 | vec_c[4] += vec_v_left_2.val[0]; 364 | vec_c[4] += vec_v_right_2.val[0]; 365 | vec_c[5] += vec_v_left_2.val[1]; 366 | vec_c[5] += vec_v_right_2.val[1]; 367 | 368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top); 372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top); 373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot); 374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot); 375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 377 | vec_c[6] += vec_v_left_3.val[0]; 378 | vec_c[6] += vec_v_right_3.val[0]; 379 | vec_c[7] += vec_v_left_3.val[1]; 380 | vec_c[7] += vec_v_right_3.val[1]; 381 | 382 | } 383 | 384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4])); 401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]); 402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4); 403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4); 404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5])); 405 | int32x4_t vec_v_bot_low_high_5 
= vmovl_high_s16(vec_c[5]); 406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5); 407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5); 408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6])); 409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]); 410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6); 411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6); 412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7])); 413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]); 414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7); 415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7); 416 | 417 | } 418 | #endif 419 | } 420 | 421 | int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 422 | alignas(32) uint32_t CBits[BM1536_1536]; 423 | memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t)); 424 | #pragma unroll 425 | for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) { 426 | tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)]))); 427 | } 428 | #pragma unroll 429 | for (int i = 0; i < BM1536_1536; i++) { 430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 431 | } 432 | return 0; 433 | }; 434 | #include 435 | 436 | #define BM4096_1536 256 437 | #define BBK4096_1536 128 438 | inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) { 439 | #ifdef __ARM_NEON 440 | const int KK = BBK4096_1536 / 2; 441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f); 442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000); 443 | int8x16_t vec_lut[2 * KK]; 444 | int16x8_t vec_c[4]; 445 | #pragma unroll 446 | for (int k = 0; k < 2 * KK; k++) { 447 | vec_lut[k] = vld1q_s8(lut + k * 16); 448 | } 449 | 450 | #pragma unroll 451 | for (int i = 0; i < BM4096_1536; i += 32) { 452 | #pragma unroll 453 | for (int i=0; i<4; i++) { 454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero); 455 | } 456 | 457 | #pragma unroll 458 | for (int k = 0; k < KK / 4; k++) { 459 | 460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16); 461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4); 462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask); 463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top); 464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top); 465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot); 466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot); 467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0); 468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0); 469 | vec_c[0] += vec_v_left_0.val[0]; 470 | vec_c[0] += vec_v_right_0.val[0]; 471 | vec_c[1] += vec_v_left_0.val[1]; 472 | vec_c[1] += vec_v_right_0.val[1]; 473 | 474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16); 475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4); 476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask); 477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top); 478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top); 479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot); 480 | int8x16_t 
vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot); 481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0); 482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0); 483 | vec_c[0] += vec_v_left_1.val[0]; 484 | vec_c[0] += vec_v_right_1.val[0]; 485 | vec_c[1] += vec_v_left_1.val[1]; 486 | vec_c[1] += vec_v_right_1.val[1]; 487 | 488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16); 489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4); 490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask); 491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top); 492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top); 493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot); 494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot); 495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0); 496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0); 497 | vec_c[2] += vec_v_left_2.val[0]; 498 | vec_c[2] += vec_v_right_2.val[0]; 499 | vec_c[3] += vec_v_left_2.val[1]; 500 | vec_c[3] += vec_v_right_2.val[1]; 501 | 502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16); 503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4); 504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask); 505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top); 506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top); 507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot); 508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot); 509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0); 510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0); 511 | vec_c[2] += vec_v_left_3.val[0]; 512 | vec_c[2] += vec_v_right_3.val[0]; 513 | vec_c[3] += vec_v_left_3.val[1]; 514 | vec_c[3] += vec_v_right_3.val[1]; 515 | 516 | } 517 | 518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0])); 519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]); 520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0); 521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0); 522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1])); 523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]); 524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1); 525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1); 526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2])); 527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]); 528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2); 529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2); 530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3])); 531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]); 532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3); 533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3); 534 | 535 | } 536 | #endif 537 | } 538 | 539 | int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 540 | alignas(32) uint32_t CBits[BM4096_1536]; 541 | memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t)); 542 | #pragma unroll 543 | for (int32_t k_outer = 0; k_outer < 1536 / 
BBK4096_1536; ++k_outer) { 544 | tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)]))); 545 | } 546 | #pragma unroll 547 | for (int i = 0; i < BM4096_1536; i++) { 548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0]; 549 | } 550 | return 0; 551 | }; 552 | 553 | template 554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{ 555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0]))); 556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0]))); 557 | 558 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0]))); 559 | }} 560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) { 561 | if (m == 1536 && k == 4096) { 562 | preprocessor_k<4096>(B, LUT_Scales, QLUT); 563 | } 564 | else if (m == 1536 && k == 1536) { 565 | preprocessor_k<1536>(B, LUT_Scales, QLUT); 566 | } 567 | else if (m == 4096 && k == 1536) { 568 | preprocessor_k<1536>(B, LUT_Scales, QLUT); 569 | } 570 | } 571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) { 572 | if (m == 1536 && k == 4096) { 573 | qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C); 574 | } 575 | else if (m == 1536 && k == 1536) { 576 | qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C); 577 | } 578 | else if (m == 4096 && k == 1536) { 579 | qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C); 580 | } 581 | } 582 | 583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) { 584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) { 585 | return; 586 | } 587 | 588 | int k = tensor->ne[0]; 589 | int m = tensor->ne[1]; 590 | const int lut_scales_size = 1; 591 | const int scales_size = 1; 592 | int bk = 0; 593 | int bm = 0; 594 | 595 | if (m == 1536 && k == 4096) { 596 | bm = BM1536_4096; 597 | bk = BBK1536_4096; 598 | } 599 | else if (m == 1536 && k == 1536) { 600 | bm = BM1536_1536; 601 | bk = BBK1536_1536; 602 | } 603 | else if (m == 4096 && k == 1536) { 604 | bm = BM4096_1536; 605 | bk = BBK4096_1536; 606 | } 607 | 608 | const int n_tile_num = m / bm; 609 | const int BK = bk; 610 | uint8_t * qweights; 611 | bitnet_float_type * scales; 612 | 613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type)); 614 | qweights = (uint8_t *) tensor->data; 615 | float * i2_scales = (float * )(qweights + k * m / 4); 616 | scales[0] = (bitnet_float_type) i2_scales[0]; 617 | 618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index; 619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = { 620 | /* .lut_scales_size = */ lut_scales_size, 621 | /* .scales_size = */ scales_size, 622 | /* .n_tile_num = */ n_tile_num, 623 | /* .qweights = */ qweights, 624 | /* .scales = */ scales 625 | }; 626 | } 627 | #endif -------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 1536 3 | k = 4096 4 | bm = 256 5 | bk = 128 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 1536 10 | k = 1536 11 | bm = 128 12 | bk = 64 13 | bmm = 64 14 | 15 | [Kernels_2] 16 | m = 4096 17 | k = 1536 18 | bm = 256 19 | bk = 128 20 | bmm = 32 21 | 22 | 
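The kernel_config_tl1.ini above records, per weight shape, the tile sizes that utils/codegen_tl1.py bakes into the matching bitnet-lut-kernels-tl1.h: each [Kernels_i] section's m and k identify the matmul, bm and bk become the BM{m}_{k} and BBK{m}_{k} block macros, and bmm is the inner SIMD tile width passed as --bm. A minimal reader-side sketch of that mapping, assuming only the keys shown in the file (the helper name and printed format are illustrative, not part of the repository):

# Sketch: map a kernel_config_tl1.ini section onto the BM{m}_{k} / BBK{m}_{k}
# macros that appear in the generated bitnet-lut-kernels-tl1.h.
# The function name and output format are assumptions for this illustration.
from configparser import ConfigParser

def show_kernel_macros(ini_path: str) -> None:
    config = ConfigParser()
    config.read(ini_path)
    for section in config.sections():        # e.g. "Kernels_2"
        m = config.getint(section, "m")      # output rows of the weight (M)
        k = config.getint(section, "k")      # reduction dimension (K)
        bm = config.getint(section, "bm")    # row-block size  -> BM{m}_{k}
        bk = config.getint(section, "bk")    # K-block size    -> BBK{m}_{k}
        bmm = config.getint(section, "bmm")  # inner SIMD tile handled per loop body
        print(f"#define BM{m}_{k} {bm}")
        print(f"#define BBK{m}_{k} {bk}")
        print(f"// inner tile (bmm): {bmm}")

For example, running this over the file above should print "#define BM4096_1536 256" and "#define BBK4096_1536 128" for [Kernels_2], which matches the definitions visible in the preset bitnet-lut-kernels-tl1.h above.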
-------------------------------------------------------------------------------- /preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini: -------------------------------------------------------------------------------- 1 | [Kernels_0] 2 | m = 1536 3 | k = 4096 4 | bm = 256 5 | bk = 96 6 | bmm = 32 7 | 8 | [Kernels_1] 9 | m = 1536 10 | k = 1536 11 | bm = 128 12 | bk = 192 13 | bmm = 32 14 | 15 | [Kernels_2] 16 | m = 4096 17 | k = 1536 18 | bm = 256 19 | bk = 96 20 | bmm = 64 21 | 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # These requirements include all dependencies for all top-level python scripts 2 | # for llama.cpp. Avoid adding packages here directly. 3 | # 4 | # Package versions must stay compatible across all top-level python scripts. 5 | # 6 | 7 | -r 3rdparty/llama.cpp/requirements/requirements-convert_legacy_llama.txt 8 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt 9 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt 10 | -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt 11 | -r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import signal 4 | import platform 5 | import argparse 6 | import subprocess 7 | 8 | def run_command(command, shell=False): 9 | """Run a system command and ensure it succeeds.""" 10 | try: 11 | subprocess.run(command, shell=shell, check=True) 12 | except subprocess.CalledProcessError as e: 13 | print(f"Error occurred while running command: {e}") 14 | sys.exit(1) 15 | 16 | def run_inference(): 17 | build_dir = "build" 18 | if platform.system() == "Windows": 19 | main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe") 20 | if not os.path.exists(main_path): 21 | main_path = os.path.join(build_dir, "bin", "llama-cli") 22 | else: 23 | main_path = os.path.join(build_dir, "bin", "llama-cli") 24 | command = [ 25 | f'{main_path}', 26 | '-m', args.model, 27 | '-n', str(args.n_predict), 28 | '-t', str(args.threads), 29 | '-p', args.prompt, 30 | '-ngl', '0', 31 | '-c', str(args.ctx_size), 32 | '--temp', str(args.temperature), 33 | "-b", "1", 34 | ] 35 | if args.conversation: 36 | command.append("-cnv") 37 | run_command(command) 38 | 39 | def signal_handler(sig, frame): 40 | print("Ctrl+C pressed, exiting...") 41 | sys.exit(0) 42 | 43 | if __name__ == "__main__": 44 | signal.signal(signal.SIGINT, signal_handler) 45 | # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington." 
46 | parser = argparse.ArgumentParser(description='Run inference') 47 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf") 48 | parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128) 49 | parser.add_argument("-p", "--prompt", type=str, help="Prompt to generate text from", required=True) 50 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) 51 | parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048) 52 | parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8) 53 | parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode or not (for instruct models.)") 54 | 55 | args = parser.parse_args() 56 | run_inference() -------------------------------------------------------------------------------- /setup_env.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import signal 3 | import sys 4 | import os 5 | import platform 6 | import argparse 7 | import logging 8 | import shutil 9 | from pathlib import Path 10 | 11 | logger = logging.getLogger("setup_env") 12 | 13 | SUPPORTED_HF_MODELS = { 14 | "1bitLLM/bitnet_b1_58-large": { 15 | "model_name": "bitnet_b1_58-large", 16 | }, 17 | "1bitLLM/bitnet_b1_58-3B": { 18 | "model_name": "bitnet_b1_58-3B", 19 | }, 20 | "HF1BitLLM/Llama3-8B-1.58-100B-tokens": { 21 | "model_name": "Llama3-8B-1.58-100B-tokens", 22 | }, 23 | "tiiuae/Falcon3-7B-Instruct-1.58bit": { 24 | "model_name": "Falcon3-7B-Instruct-1.58bit", 25 | }, 26 | "tiiuae/Falcon3-7B-1.58bit": { 27 | "model_name": "Falcon3-7B-1.58bit", 28 | }, 29 | "tiiuae/Falcon3-10B-Instruct-1.58bit": { 30 | "model_name": "Falcon3-10B-Instruct-1.58bit", 31 | }, 32 | "tiiuae/Falcon3-10B-1.58bit": { 33 | "model_name": "Falcon3-10B-1.58bit", 34 | }, 35 | "tiiuae/Falcon3-3B-Instruct-1.58bit": { 36 | "model_name": "Falcon3-3B-Instruct-1.58bit", 37 | }, 38 | "tiiuae/Falcon3-3B-1.58bit": { 39 | "model_name": "Falcon3-3B-1.58bit", 40 | }, 41 | "tiiuae/Falcon3-1B-Instruct-1.58bit": { 42 | "model_name": "Falcon3-1B-Instruct-1.58bit", 43 | }, 44 | "microsoft/BitNet-b1.58-2B-4T": { 45 | "model_name": "BitNet-b1.58-2B-4T", 46 | }, 47 | } 48 | 49 | SUPPORTED_QUANT_TYPES = { 50 | "arm64": ["i2_s", "tl1"], 51 | "x86_64": ["i2_s", "tl2"] 52 | } 53 | 54 | COMPILER_EXTRA_ARGS = { 55 | "arm64": ["-DBITNET_ARM_TL1=ON"], 56 | "x86_64": ["-DBITNET_X86_TL2=ON"] 57 | } 58 | 59 | OS_EXTRA_ARGS = { 60 | "Windows":["-T", "ClangCL"], 61 | } 62 | 63 | ARCH_ALIAS = { 64 | "AMD64": "x86_64", 65 | "x86": "x86_64", 66 | "x86_64": "x86_64", 67 | "aarch64": "arm64", 68 | "arm64": "arm64", 69 | "ARM64": "arm64", 70 | } 71 | 72 | def system_info(): 73 | return platform.system(), ARCH_ALIAS[platform.machine()] 74 | 75 | def get_model_name(): 76 | if args.hf_repo: 77 | return SUPPORTED_HF_MODELS[args.hf_repo]["model_name"] 78 | return os.path.basename(os.path.normpath(args.model_dir)) 79 | 80 | def run_command(command, shell=False, log_step=None): 81 | """Run a system command and ensure it succeeds.""" 82 | if log_step: 83 | log_file = os.path.join(args.log_dir, log_step + ".log") 84 | with open(log_file, "w") as f: 85 | try: 86 | 
subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f) 87 | except subprocess.CalledProcessError as e: 88 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}") 89 | sys.exit(1) 90 | else: 91 | try: 92 | subprocess.run(command, shell=shell, check=True) 93 | except subprocess.CalledProcessError as e: 94 | logging.error(f"Error occurred while running command: {e}") 95 | sys.exit(1) 96 | 97 | def prepare_model(): 98 | _, arch = system_info() 99 | hf_url = args.hf_repo 100 | model_dir = args.model_dir 101 | quant_type = args.quant_type 102 | quant_embd = args.quant_embd 103 | if hf_url is not None: 104 | # download the model 105 | model_dir = os.path.join(model_dir, SUPPORTED_HF_MODELS[hf_url]["model_name"]) 106 | Path(model_dir).mkdir(parents=True, exist_ok=True) 107 | logging.info(f"Downloading model {hf_url} from HuggingFace to {model_dir}...") 108 | run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model") 109 | elif not os.path.exists(model_dir): 110 | logging.error(f"Model directory {model_dir} does not exist.") 111 | sys.exit(1) 112 | else: 113 | logging.info(f"Loading model from directory {model_dir}.") 114 | gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf") 115 | if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0: 116 | logging.info(f"Converting HF model to GGUF format...") 117 | if quant_type.startswith("tl"): 118 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl") 119 | else: # i2s 120 | # convert to f32 121 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf") 122 | f32_model = os.path.join(model_dir, "ggml-model-f32.gguf") 123 | i2s_model = os.path.join(model_dir, "ggml-model-i2_s.gguf") 124 | # quantize to i2s 125 | if platform.system() != "Windows": 126 | if quant_embd: 127 | run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s") 128 | else: 129 | run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s") 130 | else: 131 | if quant_embd: 132 | run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s") 133 | else: 134 | run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s") 135 | 136 | logging.info(f"GGUF model saved at {gguf_path}") 137 | else: 138 | logging.info(f"GGUF model already exists at {gguf_path}") 139 | 140 | def setup_gguf(): 141 | # Install the pip package 142 | run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf") 143 | 144 | def gen_code(): 145 | _, arch = system_info() 146 | 147 | llama3_f3_models = set([model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith("Falcon3") or model['model_name'].startswith("Llama")]) 148 | 149 | if arch == "arm64": 150 | if args.use_pretuned: 151 | pretuned_kernels = os.path.join("preset_kernels", get_model_name()) 152 | if not os.path.exists(pretuned_kernels): 153 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}") 154 | sys.exit(1) 155 | if args.quant_type == "tl1": 156 | shutil.copyfile(os.path.join(pretuned_kernels, 
"bitnet-lut-kernels-tl1.h"), "include/bitnet-lut-kernels.h") 157 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl1.ini"), "include/kernel_config.ini") 158 | elif args.quant_type == "tl2": 159 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") 160 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini") 161 | if get_model_name() == "bitnet_b1_58-large": 162 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen") 163 | elif get_model_name() in llama3_f3_models: 164 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen") 165 | elif get_model_name() == "bitnet_b1_58-3B": 166 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen") 167 | elif get_model_name() == "BitNet-b1.58-2B-4T": 168 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen") 169 | else: 170 | raise NotImplementedError() 171 | else: 172 | if args.use_pretuned: 173 | # cp preset_kernels/model_name/bitnet-lut-kernels_tl1.h to include/bitnet-lut-kernels.h 174 | pretuned_kernels = os.path.join("preset_kernels", get_model_name()) 175 | if not os.path.exists(pretuned_kernels): 176 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}") 177 | sys.exit(1) 178 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") 179 | if get_model_name() == "bitnet_b1_58-large": 180 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen") 181 | elif get_model_name() in llama3_f3_models: 182 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen") 183 | elif get_model_name() == "bitnet_b1_58-3B": 184 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen") 185 | elif get_model_name() == "BitNet-b1.58-2B-4T": 186 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen") 187 | else: 188 | raise NotImplementedError() 189 | 190 | 191 | def compile(): 192 | # Check if cmake is installed 193 | cmake_exists = subprocess.run(["cmake", "--version"], capture_output=True) 194 | if cmake_exists.returncode != 0: 195 | logging.error("Cmake is not available. 
Please install CMake and try again.") 196 | sys.exit(1) 197 | _, arch = system_info() 198 | if arch not in COMPILER_EXTRA_ARGS.keys(): 199 | logging.error(f"Arch {arch} is not supported yet") 200 | exit(0) 201 | logging.info("Compiling the code using CMake.") 202 | run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files") 203 | # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"]) 204 | run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile") 205 | 206 | def main(): 207 | setup_gguf() 208 | gen_code() 209 | compile() 210 | prepare_model() 211 | 212 | def parse_args(): 213 | _, arch = system_info() 214 | parser = argparse.ArgumentParser(description='Setup the environment for running the inference') 215 | parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys()) 216 | parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models") 217 | parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs") 218 | parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s") 219 | parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16") 220 | parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters") 221 | return parser.parse_args() 222 | 223 | def signal_handler(sig, frame): 224 | logging.info("Ctrl+C pressed, exiting...") 225 | sys.exit(0) 226 | 227 | if __name__ == "__main__": 228 | signal.signal(signal.SIGINT, signal_handler) 229 | args = parse_args() 230 | Path(args.log_dir).mkdir(parents=True, exist_ok=True) 231 | logging.basicConfig(level=logging.INFO) 232 | main() -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(GGML_HEADERS_BITNET ../include/ggml-bitnet.h) 2 | set(GGML_SOURCES_BITNET ggml-bitnet-mad.cpp) 3 | set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp) 4 | 5 | include_directories(3rdparty/llama.cpp/ggml/include) 6 | 7 | if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR 8 | NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) 9 | message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation") 10 | endif() 11 | -------------------------------------------------------------------------------- /src/ggml-bitnet-lut.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "ggml-bitnet.h" 9 | #include "ggml-quants.h" 10 | #include "bitnet-lut-kernels.h" 11 | 12 | #if defined(GGML_BITNET_ARM_TL1) 13 | 14 | void ggml_bitnet_init(void) { 15 | // LOG(INFO) << "ggml_bitnet_init"; 16 | 17 | if (initialized) { 18 | return; 19 | } 20 | initialized = true; 21 | 22 | // if (wrapper == nullptr) { 23 | // wrapper = new BITNET::BITNETGeMMWrapper(); 24 | // } 25 | if (bitnet_tensor_extras == nullptr) { 26 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES]; 27 | } 28 | bitnet_tensor_extras_index = 0; 29 | } 30 | 31 | void ggml_bitnet_free(void) { 32 | // LOG(INFO) << "ggml_bitnet_free"; 33 | 34 | if 
(!initialized) { 35 | return; 36 | } 37 | initialized = false; 38 | 39 | // delete wrapper; 40 | // wrapper = nullptr; 41 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) { 42 | // aligned_free(bitnet_tensor_extras[i].qweights); 43 | // aligned_free(bitnet_tensor_extras[i].scales); 44 | } 45 | delete[] bitnet_tensor_extras; 46 | bitnet_tensor_extras = nullptr; 47 | } 48 | 49 | static bool do_permutate(enum ggml_type type) { 50 | if (type == GGML_TYPE_TL1) { 51 | // Add additional args to decide if permuted I2 or naive I2 52 | return false; 53 | } else { 54 | return true; 55 | } 56 | } 57 | 58 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 59 | if ((is_type_supported(src0->type)) && 60 | src1->type == GGML_TYPE_F32 && 61 | dst->type == GGML_TYPE_F32 && 62 | src0->backend == GGML_BACKEND_TYPE_CPU) { 63 | if (src1->ne[1] <= 1) { 64 | return true; 65 | } 66 | } 67 | return false; 68 | } 69 | 70 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 71 | const size_t ne01 = src0->ne[1]; 72 | const size_t ne10 = src1->ne[0]; 73 | const size_t ne11 = src1->ne[1]; 74 | const int bits = ggml_bitnet_get_type_bits(src0->type); 75 | 76 | size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type); 77 | if (sizeof(bitnet_float_type) == 2) { 78 | // Need fp32 to fp16 conversion 79 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); 80 | } 81 | wsize = ((wsize - 1) / 64 + 1) * 64; 82 | return wsize; 83 | } 84 | 85 | int ggml_bitnet_get_type_bits(enum ggml_type type) { 86 | switch (type) { 87 | case GGML_TYPE_TL1: 88 | return 2; 89 | case GGML_TYPE_Q4_0: 90 | return 4; 91 | default: 92 | return 0; 93 | } 94 | } 95 | 96 | #endif 97 | #if defined(GGML_BITNET_X86_TL2) 98 | void ggml_bitnet_init(void) { 99 | // LOG(INFO) << "ggml_bitnet_init"; 100 | 101 | if (initialized) { 102 | return; 103 | } 104 | initialized = true; 105 | 106 | // if (wrapper == nullptr) { 107 | // wrapper = new BITNET::BITNETGeMMWrapper(); 108 | // } 109 | if (bitnet_tensor_extras == nullptr) { 110 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES]; 111 | } 112 | bitnet_tensor_extras_index = 0; 113 | } 114 | 115 | void ggml_bitnet_free(void) { 116 | // LOG(INFO) << "ggml_bitnet_free"; 117 | 118 | if (!initialized) { 119 | return; 120 | } 121 | initialized = false; 122 | 123 | // delete wrapper; 124 | // wrapper = nullptr; 125 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) { 126 | // aligned_free(bitnet_tensor_extras[i].qweights); 127 | // aligned_free(bitnet_tensor_extras[i].scales); 128 | } 129 | delete[] bitnet_tensor_extras; 130 | bitnet_tensor_extras = nullptr; 131 | } 132 | 133 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 134 | if ((is_type_supported(src0->type)) && 135 | src1->type == GGML_TYPE_F32 && 136 | dst->type == GGML_TYPE_F32 && 137 | src0->backend == GGML_BACKEND_TYPE_CPU) { 138 | return true; 139 | } 140 | return false; 141 | } 142 | 143 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) { 144 | const size_t ne01 = src0->ne[1]; 145 | const size_t ne10 = src1->ne[0]; 146 | const size_t ne11 = src1->ne[1]; 147 | 148 | size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * 
sizeof(bitnet_float_type); 149 | if (sizeof(bitnet_float_type) == 2) { 150 | // Need fp32 to fp16 conversion 151 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type); 152 | } 153 | wsize = ((wsize - 1) / 64 + 1) * 64; 154 | return wsize; 155 | } 156 | 157 | int ggml_bitnet_get_type_bits(enum ggml_type type) { 158 | switch (type) { 159 | case GGML_TYPE_TL2: 160 | return 2; 161 | case GGML_TYPE_Q4_0: 162 | return 4; 163 | default: 164 | return 0; 165 | } 166 | } 167 | #endif -------------------------------------------------------------------------------- /src/ggml-bitnet-mad.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ggml-bitnet.h" 5 | #include "ggml-quants.h" 6 | #include 7 | #include 8 | 9 | #define QK_I2_S 128 10 | #define QK_I2 128 11 | 12 | #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) 13 | #include 14 | // horizontally add 8 int32_t 15 | static inline int hsum_i32_8(const __m256i a) { 16 | const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); 17 | const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); 18 | const __m128i sum64 = _mm_add_epi32(hi64, sum128); 19 | const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); 20 | return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); 21 | } 22 | #elif defined(__loongarch_asx) 23 | // horizontally add 8 int32_t 24 | static inline int hsum_i32_8(const __m256i a) { 25 | 26 | __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); 27 | __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); 28 | 29 | __m128i tmp1_128 = lasx_extracti128_lo(tmp1); 30 | __m128i tmp2_128 = lasx_extracti128_lo(tmp2); 31 | 32 | __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); 33 | 34 | __m128i ev = __lsx_vpickev_w(sum128, sum128); 35 | __m128i od = __lsx_vpickod_w(sum128, sum128); 36 | __m128i sum64 = __lsx_vadd_w(ev, od); 37 | 38 | int sum64_1, sum64_2; 39 | sum64_1 = __lsx_vpickve2gr_w(sum64, 0); 40 | sum64_2 = __lsx_vpickve2gr_w(sum64, 1); 41 | 42 | return sum64_1 + sum64_2; 43 | } 44 | #endif 45 | 46 | size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { 47 | // 2 bits per weight 48 | 49 | size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); 50 | 51 | int n = nrow * n_per_row; 52 | 53 | // f32 -> q8 54 | double max = 0; 55 | for (int i = 0; i < n; ++i) { 56 | max = fmax(max, (double)fabs((double)src[i])); 57 | } 58 | double i2_scale = max; 59 | 60 | uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t)); 61 | for (int i=0; i 0 ? 
2 : 0; 67 | } 68 | 69 | memset(dst, 0, n * sizeof(uint8_t) / 4); 70 | 71 | // q8 -> 0, 1, 2 72 | // | | | 73 | // -1, 0, 1 74 | 75 | uint8_t* i2_weight = (uint8_t*)dst; 76 | for (int i = 0; i < n / QK_I2; i++) { 77 | for (int j = 0; j < QK_I2; j++) { 78 | int group_idx = j / 32; 79 | int group_pos = j % 32; 80 | uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx)); 81 | i2_weight[i * 32 + group_pos] |= temp; 82 | } 83 | } 84 | 85 | float* scale_ptr = (float*)((char*)i2_weight + n / 4); 86 | scale_ptr[0] = i2_scale; 87 | 88 | free(q8); 89 | 90 | // 32B for alignment 91 | return nrow * row_size / 4 + 32; 92 | } 93 | 94 | void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { 95 | const uint8_t * x = (uint8_t *)vx; 96 | const int8_t * y = (int8_t *)vy; 97 | 98 | const int nb = n / QK_I2_S; 99 | const int group32_num = nb / 32; 100 | const int la_num = nb % 32; 101 | const int groupla_num = nb % 32 != 0 ? 1 : 0; 102 | 103 | #if defined(__AVX2__) 104 | 105 | __m256i mask = _mm256_set1_epi8(0x03); 106 | __m256i accu = _mm256_setzero_si256(); 107 | 108 | for (int i=0; i < group32_num; i++){ 109 | __m256i accu32 = _mm256_setzero_si256(); 110 | for (int j=0; j < 32; j++) { 111 | // 128 index 112 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32)); 113 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); 114 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); 115 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); 116 | 117 | // each 32 index 118 | xq8_3 = _mm256_and_si256(xq8_3, mask); 119 | xq8_2 = _mm256_and_si256(xq8_2, mask); 120 | xq8_1 = _mm256_and_si256(xq8_1, mask); 121 | xq8_0 = _mm256_and_si256(xq8_0, mask); 122 | 123 | // each 32 index 124 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0)); 125 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32)); 126 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64)); 127 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96)); 128 | 129 | // 128 index accumulation add 130 | // split into 32 accumulation block 131 | // each block each 128 index accumulated 4index 132 | // each index maximum 256 133 | // each block maximum 4 * 256 134 | // each block accumulation maximum 127 * 256 135 | // each 32 group index (128 index in one group) needs cast to int32 136 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); 137 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); 138 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); 139 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); 140 | 141 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1)); 142 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3)); 143 | } 144 | accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu); 145 | } 146 | 147 | for (int i = 0; i < groupla_num; i++){ 148 | __m256i accula = _mm256_setzero_si256(); 149 | for (int j = 0; j < la_num; j++) { 150 | // 128 index 151 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32)); 152 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); 153 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); 154 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); 155 | 156 | // each 32 index 157 | xq8_3 = _mm256_and_si256(xq8_3, mask); 158 | xq8_2 = _mm256_and_si256(xq8_2, mask); 159 | xq8_1 = _mm256_and_si256(xq8_1, mask); 160 | xq8_0 = _mm256_and_si256(xq8_0, mask); 161 | 162 | // each 32 index 
163 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0)); 164 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32)); 165 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64)); 166 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96)); 167 | 168 | // 128 index accumulation add 169 | // split into 32 accumulation block 170 | // each block each 128 index accumulated 4index 171 | // each index maximum 256 172 | // each block maximum 4 * 256 173 | // each block accumulation maximum 127 * 256 174 | // each 32 group index (128 index in one group) needs cast to int32 175 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); 176 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); 177 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); 178 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); 179 | 180 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1)); 181 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3)); 182 | } 183 | accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1))); 184 | } 185 | int sumi = hsum_i32_8(accu); 186 | *s = (float)sumi; 187 | 188 | #elif defined(__ARM_NEON) 189 | 190 | int32x4_t accu_0 = vdupq_n_s32(0); 191 | int32x4_t accu_1 = vdupq_n_s32(0); 192 | int32x4_t accu_2 = vdupq_n_s32(0); 193 | int32x4_t accu_3 = vdupq_n_s32(0); 194 | const uint8x16_t mask = vdupq_n_u8(3); 195 | 196 | for (int i=0; i < group32_num; i++) { 197 | 198 | #if defined(__ARM_FEATURE_DOTPROD) 199 | 200 | #else 201 | int16x8_t accu32_0 = vdupq_n_s16(0); 202 | int16x8_t accu32_1 = vdupq_n_s16(0); 203 | int16x8_t accu32_2 = vdupq_n_s16(0); 204 | int16x8_t accu32_3 = vdupq_n_s16(0); 205 | #endif 206 | 207 | for (int j=0; j < 32; j++) { 208 | uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32); 209 | uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16); 210 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); 211 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); 212 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); 213 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); 214 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); 215 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); 216 | 217 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); 218 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); 219 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); 220 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); 221 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); 222 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); 223 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); 224 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); 225 | 226 | const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0); 227 | const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16); 228 | const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32); 229 | const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48); 230 | const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64); 231 | const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80); 232 | const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96); 233 | const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112); 234 | 235 | #if defined(__ARM_FEATURE_DOTPROD) 236 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); 237 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); 238 | accu_2 = 
vdotq_s32(accu_2, q8_2, yq8_2); 239 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); 240 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); 241 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); 242 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); 243 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); 244 | #else 245 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); 246 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); 247 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); 248 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); 249 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); 250 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); 251 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); 252 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); 253 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); 254 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); 255 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); 256 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); 257 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); 258 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); 259 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); 260 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); 261 | #endif 262 | } 263 | 264 | #if defined(__ARM_FEATURE_DOTPROD) 265 | 266 | #else 267 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0))); 268 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0)); 269 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1))); 270 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1)); 271 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2))); 272 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2)); 273 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3))); 274 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3)); 275 | #endif 276 | } 277 | 278 | for (int i = 0; i < groupla_num; i++){ 279 | #if defined(__ARM_FEATURE_DOTPROD) 280 | 281 | #else 282 | int16x8_t accula_0 = vdupq_n_s16(0); 283 | int16x8_t accula_1 = vdupq_n_s16(0); 284 | int16x8_t accula_2 = vdupq_n_s16(0); 285 | int16x8_t accula_3 = vdupq_n_s16(0); 286 | #endif 287 | for (int j = 0; j < la_num; j++) { 288 | uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32); 289 | uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16); 290 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); 291 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); 292 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); 293 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); 294 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); 295 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); 296 | 297 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); 298 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); 299 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); 300 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); 301 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); 302 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); 303 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); 304 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); 305 | 306 | const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0); 307 | const int8x16_t 
yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16); 308 | const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32); 309 | const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48); 310 | const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64); 311 | const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80); 312 | const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96); 313 | const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112); 314 | 315 | #if defined(__ARM_FEATURE_DOTPROD) 316 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); 317 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); 318 | accu_2 = vdotq_s32(accu_2, q8_2, yq8_2); 319 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); 320 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); 321 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); 322 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); 323 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); 324 | #else 325 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); 326 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); 327 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); 328 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); 329 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); 330 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); 331 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); 332 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); 333 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); 334 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); 335 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); 336 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); 337 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); 338 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); 339 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); 340 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); 341 | #endif 342 | } 343 | #if defined(__ARM_FEATURE_DOTPROD) 344 | 345 | #else 346 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0))); 347 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0)); 348 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1))); 349 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1)); 350 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2))); 351 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2)); 352 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3))); 353 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3)); 354 | #endif 355 | } 356 | accu_0 = vaddq_s32(accu_0, accu_1); 357 | accu_2 = vaddq_s32(accu_2, accu_3); 358 | accu_0 = vaddq_s32(accu_0, accu_2); 359 | int sumi = vaddlvq_s32(accu_0); 360 | *s = (float)sumi; 361 | 362 | #endif 363 | } -------------------------------------------------------------------------------- /utils/codegen_tl1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from configparser import ConfigParser 4 | 5 | def gen_ctor_code(): 6 | kernel_code = "\n\ 7 | #include \"ggml-bitnet.h\"\n\ 8 | #define GGML_BITNET_MAX_NODES 8192\n\ 9 | static bool initialized = false;\n\ 10 | static bitnet_tensor_extra * bitnet_tensor_extras = 
nullptr;\n\ 11 | static size_t bitnet_tensor_extras_index = 0;\n\ 12 | static void * aligned_malloc(size_t size) {{\n\ 13 | #if defined(_WIN32)\n\ 14 | return _aligned_malloc(size, 64);\n\ 15 | #else\n\ 16 | void * ptr = nullptr;\n\ 17 | posix_memalign(&ptr, 64, size);\n\ 18 | return ptr;\n\ 19 | #endif\n\ 20 | }}\n\ 21 | static void aligned_free(void * ptr) {{\n\ 22 | #if defined(_WIN32)\n\ 23 | _aligned_free(ptr);\n\ 24 | #else\n\ 25 | free(ptr);\n\ 26 | #endif\n\ 27 | }}\n\ 28 | \n\ 29 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\ 30 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ 31 | bitnet_float_type* b = (bitnet_float_type*)b_;\n\ 32 | #ifdef __ARM_NEON\n\ 33 | float32x4_t temp_max = vdupq_n_f32(0);\n\ 34 | for (int i=0; i < k / 4; i++) {{\n\ 35 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\ 36 | float32x4_t abssum = vabsq_f32(vec_bs);\n\ 37 | temp_max = vmaxq_f32(abssum, temp_max);\n\ 38 | }}\n\ 39 | float32_t scales = 127 / vmaxvq_f32(temp_max);\n\ 40 | *lut_scales = scales;\n\ 41 | #elif defined __AVX2__\n\ 42 | __m256 max_vec = _mm256_set1_ps(0.f);\n\ 43 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\ 44 | // #pragma unroll\n\ 45 | for (int i = 0; i < k / 8; i++) {{\n\ 46 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\ 47 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\ 48 | max_vec = _mm256_max_ps(vec_babs, max_vec);\n\ 49 | }}\n\ 50 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\ 51 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\ 52 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\ 53 | float scales = 127 / _mm_cvtss_f32(max1);\n\ 54 | *lut_scales = scales;\n\ 55 | #endif\n\ 56 | }}\n\ 57 | \n\ 58 | void partial_max_reset(void* lut_scales_) {{\n\ 59 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\ 60 | *lut_scales = 0.0;\n\ 61 | }}\n\ 62 | \n\ 63 | #ifdef __ARM_NEON\n\ 64 | inline void Transpose_8_8(\n\ 65 | int16x8_t *v0,\n\ 66 | int16x8_t *v1,\n\ 67 | int16x8_t *v2,\n\ 68 | int16x8_t *v3,\n\ 69 | int16x8_t *v4,\n\ 70 | int16x8_t *v5,\n\ 71 | int16x8_t *v6,\n\ 72 | int16x8_t *v7)\n\ 73 | {{\n\ 74 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\ 75 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\ 76 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\ 77 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\ 78 | \n\ 79 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\ 80 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\ 81 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\ 82 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\ 83 | \n\ 84 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\ 85 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\ 86 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\ 87 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\ 88 | \n\ 89 | *v0 = q_fin_0.val[0];\n\ 90 | *v1 = q_fin_0.val[1];\n\ 91 | *v2 = q_fin_1.val[0];\n\ 92 | *v3 = q_fin_1.val[1];\n\ 93 | *v4 = q_fin_2.val[0];\n\ 94 | *v5 = q_fin_2.val[1];\n\ 95 | *v6 = q_fin_3.val[0];\n\ 96 | *v7 = q_fin_3.val[1];\n\ 97 | }}\n\ 98 | #endif\n\ 99 | \n\ 100 | template\n\ 101 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\ 102 | #ifdef __ARM_NEON\n\ 103 | int16x8_t vec_lut[16];\n\ 104 | float32_t scales = *lut_scales;\n\ 105 | uint8_t tbl_mask[16];\n\ 106 | tbl_mask[0] = 0;\n\ 107 | tbl_mask[1] = 2;\n\ 108 | tbl_mask[2] = 4;\n\ 
109 | tbl_mask[3] = 6;\n\ 110 | tbl_mask[4] = 8;\n\ 111 | tbl_mask[5] = 10;\n\ 112 | tbl_mask[6] = 12;\n\ 113 | tbl_mask[7] = 14;\n\ 114 | tbl_mask[8] = 1;\n\ 115 | tbl_mask[9] = 3;\n\ 116 | tbl_mask[10] = 5;\n\ 117 | tbl_mask[11] = 7;\n\ 118 | tbl_mask[12] = 9;\n\ 119 | tbl_mask[13] = 11;\n\ 120 | tbl_mask[14] = 13;\n\ 121 | tbl_mask[15] = 15;\n\ 122 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\ 123 | #pragma unroll\n\ 124 | for (int k = 0; k < act_k / 16; ++k) {{\n\ 125 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\ 126 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\ 127 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\ 128 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\ 129 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\ 130 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\ 131 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\ 132 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\ 133 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\ 134 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\ 135 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\ 136 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\ 137 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\ 138 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\ 139 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\ 140 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\ 141 | vec_lut[0] = vdupq_n_s16(0);\n\ 142 | vec_lut[0] = vec_lut[0] - vec_bs_0;\n\ 143 | vec_lut[0] = vec_lut[0] - vec_bs_1;\n\ 144 | vec_lut[1] = vdupq_n_s16(0);\n\ 145 | vec_lut[1] = vec_lut[1] - vec_bs_0;\n\ 146 | vec_lut[2] = vdupq_n_s16(0);\n\ 147 | vec_lut[2] = vec_lut[2] - vec_bs_0;\n\ 148 | vec_lut[2] = vec_lut[2] + vec_bs_1;\n\ 149 | vec_lut[3] = vdupq_n_s16(0);\n\ 150 | vec_lut[3] = vec_lut[3] - vec_bs_1;\n\ 151 | vec_lut[4] = vdupq_n_s16(0);\n\ 152 | vec_lut[5] = vec_bs_1;\n\ 153 | vec_lut[6] = vec_bs_0;\n\ 154 | vec_lut[6] = vec_lut[6] - vec_bs_1;\n\ 155 | vec_lut[7] = vec_bs_0;\n\ 156 | vec_lut[8] = vec_bs_0;\n\ 157 | vec_lut[8] = vec_lut[8] + vec_bs_1;\n\ 158 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\ 159 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\ 160 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\ 161 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\ 162 | #pragma unroll\n\ 163 | for (int idx = 0; idx < 8; idx++) {{\n\ 164 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\ 165 | int8x8_t q0_low = vget_low_s8(q0_s);\n\ 166 | int8x8_t q0_high = vget_high_s8(q0_s);\n\ 167 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\ 168 | int8x8_t q1_low = vget_low_s8(q1_s);\n\ 169 | int8x8_t q1_high = vget_high_s8(q1_s);\n\ 170 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\ 171 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\ 172 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\ 173 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\ 174 | }}\n\ 175 | }}\n\ 176 | #endif\n\ 177 | }}\n\ 178 | \n\ 179 | static bool is_type_supported(enum ggml_type type) {{\n\ 180 | if (type == GGML_TYPE_Q4_0 ||\n\ 181 | type == GGML_TYPE_TL1) {{\n\ 182 | return true;\n\ 183 | }} else {{\n\ 184 | return false;\n\ 185 | }}\n\ 186 | }}\n\ 187 | " 188 | return kernel_code 189 | 190 | def gen_body_core_code(bm, by): 191 | length = 4 192 | all_code = "" 193 | for i in 
range(length): 194 | core_code = "\n\ 195 | uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\ 196 | uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\ 197 | uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\ 198 | int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\ 199 | int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\ 200 | int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\ 201 | int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\ 202 | int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\ 203 | int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\ 204 | vec_c[{6}] += vec_v_left_{0}.val[0];\n\ 205 | vec_c[{6}] += vec_v_right_{0}.val[0];\n\ 206 | vec_c[{7}] += vec_v_left_{0}.val[1];\n\ 207 | vec_c[{7}] += vec_v_right_{0}.val[1];\n\ 208 | ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1) 209 | 210 | all_code = "".join([all_code, core_code]) 211 | 212 | all_code = "".join([all_code, "\n }\n\n"]) 213 | 214 | for i in range(bm // 8): 215 | core_code = "\ 216 | int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\ 217 | int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\ 218 | vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\ 219 | vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4) 220 | all_code = "".join([all_code, core_code]) 221 | 222 | return all_code 223 | 224 | def gen_tbl_impl(pre, BM, BK, bm, k): 225 | 226 | kernel_code = "\ 227 | #include \n\ 228 | \n\ 229 | #define BM{0} {1}\n\ 230 | #define BBK{0} {2}\n\ 231 | inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ 232 | #ifdef __ARM_NEON\n\ 233 | const int KK = BBK{0} / 2;\n\ 234 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\ 235 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\ 236 | int8x16_t vec_lut[2 * KK];\n\ 237 | ".format(pre, BM, BK) 238 | 239 | kernel_code = "".join([kernel_code, " int16x8_t vec_c[{}];".format(bm // 8)]) 240 | 241 | kernel_code = "".join([kernel_code, "\n\ 242 | #pragma unroll\n\ 243 | for (int k = 0; k < 2 * KK; k++) {\n\ 244 | vec_lut[k] = vld1q_s8(lut + k * 16);\n\ 245 | }\n"]) 246 | 247 | pre_core_code = "\n\ 248 | #pragma unroll\n\ 249 | for (int i = 0; i < BM{}; i += {}) {{\n\ 250 | #pragma unroll\n\ 251 | for (int i=0; i<{}; i++) {{\n\ 252 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\ 253 | }}\n".format(pre, bm, bm // 8) 254 | 255 | body_core_pre_code = "\n\ 256 | #pragma unroll\n\ 257 | for (int k = 0; k < KK / {}; k++) {{\n\ 258 | ".format(256 // bm // 2) 259 | 260 | body_core_post_code = "\n\ 261 | }\n\ 262 | \ 263 | #endif\n\ 264 | }\n" 265 | 266 | kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code]) 267 | 268 | kernel_code = "".join([kernel_code, "\n\ 269 | int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ 270 | alignas({1}) uint32_t CBits[BM{0}];\n\ 271 | memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\ 272 | #pragma unroll\n\ 273 | for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\ 274 | tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer 
* BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\ 275 | }}\n\ 276 | #pragma unroll\n\ 277 | for (int i = 0; i < BM{0}; i++) {{\n\ 278 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\ 279 | }}\n\ 280 | return 0;\n\ 281 | }};\n".format(pre, min(32, BK), k)]) 282 | 283 | return kernel_code 284 | 285 | def gen_top_api(kernel_shapes): 286 | 287 | kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\ 288 | if (m == {0} && k == {1}) {{\n\ 289 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\ 290 | }}\n\ 291 | ".format(kernel_shapes[0][0], kernel_shapes[0][1]) 292 | for i in range(1, len(kernel_shapes)): 293 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ 294 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\ 295 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) 296 | kernel_code = "".join([kernel_code, "}\n"]) 297 | kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\ 298 | if (m == {0} && k == {1}) {{\n\ 299 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\ 300 | }}\n\ 301 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])]) 302 | for i in range(1, len(kernel_shapes)): 303 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\ 304 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\ 305 | }}\n\ 306 | ".format(kernel_shapes[i][0], kernel_shapes[i][1])]) 307 | kernel_code = "".join([kernel_code, "}\n"]) 308 | return kernel_code 309 | 310 | def gen_preprocess_code(): 311 | kernel_code = "\n\ 312 | template\n\ 313 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\ 314 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\ 315 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\ 316 | \n\ 317 | lut_ctor((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\ 318 | }}\n" 319 | return kernel_code 320 | 321 | def gen_transform_code(kernel_shape): 322 | kernel_code = "\n\ 323 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\ 324 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\ 325 | return;\n\ 326 | }\n\ 327 | \n\ 328 | int k = tensor->ne[0];\n\ 329 | int m = tensor->ne[1];\n\ 330 | const int lut_scales_size = 1;\n\ 331 | const int scales_size = 1;\n\ 332 | int bk = 0;\n\ 333 | int bm = 0;\n" 334 | 335 | kernel_code = "".join([kernel_code, "\n\ 336 | if (m == {0} && k == {1}) {{\n\ 337 | bm = BM{0}_{1};\n\ 338 | bk = BBK{0}_{1};\n\ 339 | }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])]) 340 | 341 | for i in range(1, len(kernel_shapes)): 342 | kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\ 343 | bm = BM{0}_{1};\n\ 344 | bk = BBK{0}_{1};\n\ 345 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])]) 346 | 347 | kernel_code = "".join([kernel_code, "\n\ 348 | const int n_tile_num = m / bm;\n\ 349 | const int BK = bk;\n\ 350 | uint8_t * qweights;\n\ 351 | bitnet_float_type * scales;\n\ 352 | \n\ 353 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\ 354 | qweights = (uint8_t *) tensor->data;\n\ 355 | float * i2_scales = (float * )(qweights + k * m / 4);\n\ 356 | scales[0] = (bitnet_float_type) i2_scales[0];\n\ 357 | \n\ 358 | 
321 | def gen_transform_code(kernel_shape):
322 | kernel_code = "\n\
323 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
324 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
325 | return;\n\
326 | }\n\
327 | \n\
328 | int k = tensor->ne[0];\n\
329 | int m = tensor->ne[1];\n\
330 | const int lut_scales_size = 1;\n\
331 | const int scales_size = 1;\n\
332 | int bk = 0;\n\
333 | int bm = 0;\n"
334 |
335 | kernel_code = "".join([kernel_code, "\n\
336 | if (m == {0} && k == {1}) {{\n\
337 | bm = BM{0}_{1};\n\
338 | bk = BBK{0}_{1};\n\
339 | }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])
340 |
341 | for i in range(1, len(kernel_shapes)):
342 | kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
343 | bm = BM{0}_{1};\n\
344 | bk = BBK{0}_{1};\n\
345 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
346 |
347 | kernel_code = "".join([kernel_code, "\n\
348 | const int n_tile_num = m / bm;\n\
349 | const int BK = bk;\n\
350 | uint8_t * qweights;\n\
351 | bitnet_float_type * scales;\n\
352 | \n\
353 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
354 | qweights = (uint8_t *) tensor->data;\n\
355 | float * i2_scales = (float * )(qweights + k * m / 4);\n\
356 | scales[0] = (bitnet_float_type) i2_scales[0];\n\
357 | \n\
358 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
359 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
360 | /* .lut_scales_size = */ lut_scales_size,\n\
361 | /* .BK = */ BK,\n\
362 | /* .n_tile_num = */ n_tile_num,\n\
363 | /* .qweights = */ qweights,\n\
364 | /* .scales = */ scales\n\
365 | };\n\
366 | }\n"])
367 |
368 | return kernel_code
369 |
370 | if __name__ == "__main__":
371 | ModelShapeDict = {
372 | "bitnet_b1_58-large" : [[1536, 4096],
373 | [1536, 1536],
374 | [4096, 1536]],
375 | "bitnet_b1_58-3B" : [[3200, 8640],
376 | [3200, 3200],
377 | [8640, 3200]],
378 | "Llama3-8B-1.58-100B-tokens" : [[14336, 4096],
379 | [4096, 14336],
380 | [1024, 4096],
381 | [4096, 4096]]
382 | }
383 |
384 | parser = argparse.ArgumentParser(description='gen impl')
385 | parser.add_argument('--model',default="input", type=str, dest="model",
386 | help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
387 | parser.add_argument('--BM',default="input", type=str,
388 | help="block length when cutting one weight (M, K) into M / BM weights (BM, K).")
389 | parser.add_argument('--BK',default="input", type=str,
390 | help="block length when cutting one weight (M, K) into K / BK weights (M, BK).")
391 | parser.add_argument('--bm',default="input", type=str,
392 | help="using simd instructions to compute (bm, 256 / bm) in one block")
393 | args = parser.parse_args()
394 |
395 | kernel_shapes = ModelShapeDict[args.model]
396 |
397 | BM_list = [int(item) for item in args.BM.split(',')]
398 | BK_list = [int(item) for item in args.BK.split(',')]
399 | bm_list = [int(item) for item in args.bm.split(',')]
400 |
401 | assert(len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes)), "number of BM / BK / bm should be {}".format(len(kernel_shapes))
402 |
403 | for i in range(len(kernel_shapes)):
404 | assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0"
405 | assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0"
406 | assert bm_list[i] in [32, 64], "choose bm from [32, 64]"
407 |
408 | tbl_impl_code = []
409 |
410 | for i in range(len(kernel_shapes)):
411 | tbl_impl_code.append(
412 | gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1])
413 | )
414 | api_code = gen_top_api(kernel_shapes)
415 | pre_code = gen_preprocess_code()
416 | ctor_code = gen_ctor_code()
417 | trans_code = gen_transform_code(kernel_shapes)
418 |
419 | output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")
420 |
421 | with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f:
422 | f.write(''.join("#if defined(GGML_BITNET_ARM_TL1)"))
423 | f.write(''.join(ctor_code))
424 | for code in tbl_impl_code:
425 | f.write(''.join(code))
426 | f.write(''.join(pre_code))
427 | f.write(''.join(api_code))
428 | f.write(''.join(trans_code))
429 | f.write(''.join("#endif"))
430 |
431 | config = ConfigParser()
432 |
433 | for i in range(len(kernel_shapes)):
434 | config.add_section('Kernels_{}'.format(i))
435 | config.set('Kernels_{}'.format(i), 'M'.format(i), str(kernel_shapes[i][0]))
436 | config.set('Kernels_{}'.format(i), 'K'.format(i), str(kernel_shapes[i][1]))
437 | config.set('Kernels_{}'.format(i), 'BM'.format(i), str(BM_list[i]))
438 | config.set('Kernels_{}'.format(i), 'BK'.format(i), str(BK_list[i]))
439 | config.set('Kernels_{}'.format(i), 'bmm'.format(i), str(bm_list[i]))
440 |
441 | with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile:
442 | config.write(configfile)
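The --BM/--BK/--bm values passed to codegen_tl1.py only need to satisfy the asserts in the __main__ block above. A small validation sketch with example values (MODEL_SHAPES, valid_blocks, and the concrete block sizes are illustrative, not tuned or recommended settings):

# Sketch: check block-size choices against the same constraints asserted above.
# MODEL_SHAPES, valid_blocks and the example values are illustrative only.
MODEL_SHAPES = {
    "bitnet_b1_58-large": [[1536, 4096], [1536, 1536], [4096, 1536]],
}

def valid_blocks(model, BM_list, BK_list, bm_list):
    shapes = MODEL_SHAPES[model]
    if not (len(BM_list) == len(BK_list) == len(bm_list) == len(shapes)):
        return False
    return all(
        M % BM == 0 and K % BK == 0 and bm in (32, 64)
        for (M, K), BM, BK, bm in zip(shapes, BM_list, BK_list, bm_list)
    )

# e.g.: python utils/codegen_tl1.py --model bitnet_b1_58-large \
#           --BM 256,128,256 --BK 128,64,128 --bm 32,64,32
print(valid_blocks("bitnet_b1_58-large", [256, 128, 256], [128, 64, 128], [32, 64, 32]))  # True
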
"/kernel_config.ini"]), 'w') as configfile: 442 | config.write(configfile) -------------------------------------------------------------------------------- /utils/e2e_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import argparse 5 | import platform 6 | import subprocess 7 | 8 | def run_command(command, shell=False, log_step=None): 9 | """Run a system command and ensure it succeeds.""" 10 | if log_step: 11 | log_file = os.path.join(args.log_dir, log_step + ".log") 12 | with open(log_file, "w") as f: 13 | try: 14 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f) 15 | except subprocess.CalledProcessError as e: 16 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}") 17 | sys.exit(1) 18 | else: 19 | try: 20 | subprocess.run(command, shell=shell, check=True) 21 | except subprocess.CalledProcessError as e: 22 | logging.error(f"Error occurred while running command: {e}") 23 | sys.exit(1) 24 | 25 | def run_benchmark(): 26 | build_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build") 27 | if platform.system() == "Windows": 28 | bench_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe") 29 | if not os.path.exists(bench_path): 30 | bench_path = os.path.join(build_dir, "bin", "llama-bench") 31 | else: 32 | bench_path = os.path.join(build_dir, "bin", "llama-bench") 33 | if not os.path.exists(bench_path): 34 | logging.error(f"Benchmark binary not found, please build first.") 35 | sys.exit(1) 36 | command = [ 37 | f'{bench_path}', 38 | '-m', args.model, 39 | '-n', str(args.n_token), 40 | '-ngl', '0', 41 | '-b', '1', 42 | '-t', str(args.threads), 43 | '-p', str(args.n_prompt), 44 | '-r', '5' 45 | ] 46 | run_command(command) 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser(description='Setup the environment for running the inference') 50 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True) 51 | parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128) 52 | parser.add_argument("-p", "--n-prompt", type=int, help="Prompt to generate text from", required=False, default=512) 53 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2) 54 | return parser.parse_args() 55 | 56 | if __name__ == "__main__": 57 | logging.basicConfig(level=logging.INFO) 58 | args = parse_args() 59 | run_benchmark() -------------------------------------------------------------------------------- /utils/kernel_tuning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/utils/kernel_tuning.py --------------------------------------------------------------------------------