├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── assets
│   ├── header_model_release.png
│   ├── intel_performance.jpg
│   ├── m2_performance.jpg
│   ├── tl1.png
│   └── tl2.png
├── docs
│   └── codegen.md
├── include
│   └── ggml-bitnet.h
├── media
│   ├── benchmark.png
│   └── demo.mp4
├── preset_kernels
│   ├── Llama3-8B-1.58-100B-tokens
│   │   ├── bitnet-lut-kernels-tl1.h
│   │   ├── bitnet-lut-kernels-tl2.h
│   │   ├── kernel_config_tl1.ini
│   │   └── kernel_config_tl2.ini
│   ├── bitnet_b1_58-3B
│   │   ├── bitnet-lut-kernels-tl1.h
│   │   ├── bitnet-lut-kernels-tl2.h
│   │   ├── kernel_config_tl1.ini
│   │   └── kernel_config_tl2.ini
│   └── bitnet_b1_58-large
│       ├── bitnet-lut-kernels-tl1.h
│       ├── bitnet-lut-kernels-tl2.h
│       ├── kernel_config_tl1.ini
│       └── kernel_config_tl2.ini
├── requirements.txt
├── run_inference.py
├── setup_env.py
├── src
│   ├── CMakeLists.txt
│   ├── ggml-bitnet-lut.cpp
│   └── ggml-bitnet-mad.cpp
└── utils
    ├── codegen_tl1.py
    ├── codegen_tl2.py
    ├── convert-hf-to-gguf-bitnet.py
    ├── convert-ms-to-gguf-bitnet.py
    ├── convert.py
    ├── e2e_benchmark.py
    ├── generate-dummy-bitnet-model.py
    └── kernel_tuning.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Extensions
2 |
3 | *.a
4 | *.bat
5 | *.bin
6 | *.dll
7 | *.dot
8 | *.etag
9 | *.exe
10 | *.gcda
11 | *.gcno
12 | *.gcov
13 | *.gguf
14 | *.gguf.json
15 | *.lastModified
16 | *.log
17 | *.metallib
18 | *.o
19 | *.so
20 | *.tmp
21 |
22 | # IDE / OS
23 |
24 | .cache/
25 | .ccls-cache/
26 | .direnv/
27 | .DS_Store
28 | .envrc
29 | .idea/
30 | .swiftpm
31 | .vs/
32 | .vscode/
33 | nppBackup
34 |
35 | # Models
36 | models/*
37 |
38 | # Python
39 |
40 | /.venv
41 | __pycache__/
42 | */poetry.lock
43 | poetry.toml
44 |
45 | build/
46 | logs/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/llama.cpp"]
2 | path = 3rdparty/llama.cpp
3 | url = https://github.com/Eddie-Wang1120/llama.cpp.git
4 | branch = merge-dev
5 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
2 | project("bitnet.cpp" C CXX)
3 | include(CheckIncludeFileCXX)
4 |
5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
6 |
7 | if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
8 | set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
9 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
10 | endif()
11 |
12 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
13 |
14 | # option list
15 | option(BITNET_ARM_TL1 "bitnet.cpp: use tl1 on arm platform" OFF)
16 | option(BITNET_X86_TL2 "bitnet.cpp: use tl2 on x86 platform" OFF)
17 |
18 |
19 | set(CMAKE_CXX_STANDARD_REQUIRED true)
20 | set(CMAKE_C_STANDARD 11)
21 | set(CMAKE_C_STANDARD_REQUIRED true)
22 | set(THREADS_PREFER_PTHREAD_FLAG ON)
23 |
24 | # override ggml options
25 | set(GGML_BITNET_ARM_TL1 ${BITNET_ARM_TL1})
26 | set(GGML_BITNET_X86_TL2 ${BITNET_X86_TL2})
27 |
28 | if (GGML_BITNET_ARM_TL1)
29 | add_compile_definitions(GGML_BITNET_ARM_TL1)
30 | endif()
31 | if (GGML_BITNET_X86_TL2)
32 | add_compile_definitions(GGML_BITNET_X86_TL2)
33 | endif()
34 |
35 | if (CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
36 | add_compile_options(-fpermissive)
37 | endif()
38 |
39 | find_package(Threads REQUIRED)
40 |
41 | add_subdirectory(src)
42 | add_subdirectory(3rdparty/llama.cpp)
43 |
44 | # install
45 |
46 | include(GNUInstallDirs)
47 | include(CMakePackageConfigHelpers)
48 |
49 | set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR}
50 | CACHE PATH "Location of header files")
51 | set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}
52 | CACHE PATH "Location of library files")
53 | set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR}
54 | CACHE PATH "Location of binary files")
55 | set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
56 | set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
57 | set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER})
58 |
59 | get_target_property(GGML_DIRECTORY ggml SOURCE_DIR)
60 | get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS)
61 | get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
62 | set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES})
63 | get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
64 |
65 | get_directory_property(LLAMA_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
66 |
67 | write_basic_package_version_file(
68 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
69 | VERSION ${LLAMA_INSTALL_VERSION}
70 | COMPATIBILITY SameMajorVersion)
71 |
72 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
73 | ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
74 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)
75 |
76 | set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
77 | install(TARGETS llama LIBRARY PUBLIC_HEADER)
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bitnet.cpp
2 | [](https://opensource.org/licenses/MIT)
3 | 
4 |
5 | [![model release header](assets/header_model_release.png)](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
6 |
7 | Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or [build and run](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) it on your own CPU.
8 |
9 | bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU (with NPU and GPU support coming next).
10 |
11 | The first release of bitnet.cpp supports inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** and **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details.
12 |
13 |
14 |
15 |
16 | >The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp.
17 |
18 | ## Demo
19 |
20 | A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
21 |
22 | https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
23 |
24 | ## What's New:
25 | - 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) 
26 | - 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
27 | - 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
28 | - 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)
29 | - 10/17/2024 bitnet.cpp 1.0 released.
30 | - 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf)
31 | - 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
32 | - 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453)
33 |
34 | ## Acknowledgements
35 |
36 | This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) framework. We would like to thank all the authors for their contributions to the open-source community. Also, bitnet.cpp's kernels are built on top of the Lookup Table methodologies pioneered in [T-MAC](https://github.com/microsoft/T-MAC/). For inference of general low-bit LLMs beyond ternary models, we recommend using T-MAC.
37 | ## Official Models
38 |
39 |
40 |
41 | | Model              | Parameters | CPU | I2_S | TL1 | TL2 |
42 | | ------------------ | ---------- | --- | ---- | --- | --- |
43 | | BitNet-b1.58-2B-4T | 2.4B       | x86 | ✅   | ❌  | ✅  |
44 | |                    |            | ARM | ✅   | ✅  | ❌  |
45 |
67 | ## Supported Models
68 | ❗️**We use existing 1-bit LLMs available on [Hugging Face](https://huggingface.co/) to demonstrate the inference capabilities of bitnet.cpp. We hope the release of bitnet.cpp will inspire the development of 1-bit LLMs in large-scale settings in terms of model size and training tokens.**
69 |
70 |
71 |
72 |
73 | | Model                      | Parameters | CPU | I2_S | TL1 | TL2 |
74 | | -------------------------- | ---------- | --- | ---- | --- | --- |
75 | | bitnet_b1_58-large         | 0.7B       | x86 | ✅   | ❌  | ✅  |
76 | |                            |            | ARM | ✅   | ✅  | ❌  |
77 | | bitnet_b1_58-3B            | 3.3B       | x86 | ❌   | ❌  | ✅  |
78 | |                            |            | ARM | ❌   | ✅  | ❌  |
79 | | Llama3-8B-1.58-100B-tokens | 8.0B       | x86 | ✅   | ❌  | ✅  |
80 | |                            |            | ARM | ✅   | ✅  | ❌  |
81 | | Falcon3 Family             | 1B-10B     | x86 | ✅   | ❌  | ✅  |
82 | |                            |            | ARM | ✅   | ✅  | ❌  |
142 |
143 | ## Installation
144 |
145 | ### Requirements
146 | - python>=3.9
147 | - cmake>=3.22
148 | - clang>=18
149 | - For Windows users, install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/). In the installer, toggle on at least the following options (this also automatically installs the required additional tools, like CMake):
150 |     - Desktop development with C++
151 |     - C++ CMake Tools for Windows
152 |     - Git for Windows
153 |     - C++ Clang Compiler for Windows
154 |     - MS-Build Support for LLVM-Toolset (clang)
155 | - For Debian/Ubuntu users, you can install clang with the [automatic installation script](https://apt.llvm.org/):
156 |
157 | `bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"`
158 | - conda (highly recommended)
159 |
160 | ### Build from source
161 |
162 | > [!IMPORTANT]
163 | > If you are using Windows, please remember to always use a Developer Command Prompt / PowerShell for VS2022 for the following commands. Please refer to the FAQs below if you see any issues.
164 |
165 | 1. Clone the repo
166 | ```bash
167 | git clone --recursive https://github.com/microsoft/BitNet.git
168 | cd BitNet
169 | ```
170 | 2. Install the dependencies
171 | ```bash
172 | # (Recommended) Create a new conda environment
173 | conda create -n bitnet-cpp python=3.9
174 | conda activate bitnet-cpp
175 |
176 | pip install -r requirements.txt
177 | ```
178 | 3. Build the project
179 | ```bash
180 | # Manually download the model and run with local path
181 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir models/BitNet-b1.58-2B-4T
182 | python setup_env.py -md models/BitNet-b1.58-2B-4T -q i2_s
183 |
184 | ```
185 |
186 | usage: setup_env.py [-h] [--hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}] [--model-dir MODEL_DIR] [--log-dir LOG_DIR] [--quant-type {i2_s,tl1}] [--quant-embd]
187 | [--use-pretuned]
188 |
189 | Setup the environment for running inference
190 |
191 | optional arguments:
192 | -h, --help show this help message and exit
193 | --hf-repo {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}, -hr {1bitLLM/bitnet_b1_58-large,1bitLLM/bitnet_b1_58-3B,HF1BitLLM/Llama3-8B-1.58-100B-tokens,tiiuae/Falcon3-1B-Instruct-1.58bit,tiiuae/Falcon3-3B-Instruct-1.58bit,tiiuae/Falcon3-7B-Instruct-1.58bit,tiiuae/Falcon3-10B-Instruct-1.58bit}
194 | Model used for inference
195 | --model-dir MODEL_DIR, -md MODEL_DIR
196 | Directory to save/load the model
197 | --log-dir LOG_DIR, -ld LOG_DIR
198 | Directory to save the logging info
199 | --quant-type {i2_s,tl1}, -q {i2_s,tl1}
200 | Quantization type
201 | --quant-embd Quantize the embeddings to f16
202 | --use-pretuned, -p Use the pretuned kernel parameters
203 |
204 | ## Usage
205 | ### Basic usage
206 | ```bash
207 | # Run inference with the quantized model
208 | python run_inference.py -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf -p "You are a helpful assistant" -cnv
209 | ```
210 |
211 | usage: run_inference.py [-h] [-m MODEL] [-n N_PREDICT] -p PROMPT [-t THREADS] [-c CTX_SIZE] [-temp TEMPERATURE] [-cnv]
212 |
213 | Run inference
214 |
215 | optional arguments:
216 | -h, --help show this help message and exit
217 | -m MODEL, --model MODEL
218 | Path to model file
219 | -n N_PREDICT, --n-predict N_PREDICT
220 | Number of tokens to predict when generating text
221 | -p PROMPT, --prompt PROMPT
222 | Prompt to generate text from
223 | -t THREADS, --threads THREADS
224 | Number of threads to use
225 | -c CTX_SIZE, --ctx-size CTX_SIZE
226 | Size of the prompt context
227 | -temp TEMPERATURE, --temperature TEMPERATURE
228 | Temperature, a hyperparameter that controls the randomness of the generated text
229 | -cnv, --conversation Whether to enable chat mode or not (for instruct models.)
230 | (When this option is turned on, the prompt specified by -p will be used as the system prompt.)
231 |
232 |
233 | ### Benchmark
234 | We provide a script to run the inference benchmark for a given model.
235 |
236 | ```
237 | usage: e2e_benchmark.py -m MODEL [-n N_TOKEN] [-p N_PROMPT] [-t THREADS]
238 |
239 | Setup the environment for running the inference
240 |
241 | required arguments:
242 | -m MODEL, --model MODEL
243 | Path to the model file.
244 |
245 | optional arguments:
246 | -h, --help
247 | Show this help message and exit.
248 | -n N_TOKEN, --n-token N_TOKEN
249 | Number of generated tokens.
250 | -p N_PROMPT, --n-prompt N_PROMPT
251 | Prompt to generate text from.
252 | -t THREADS, --threads THREADS
253 | Number of threads to use.
254 | ```
255 |
256 | Here's a brief explanation of each argument:
257 |
258 | - `-m`, `--model`: The path to the model file. This is a required argument that must be provided when running the script.
259 | - `-n`, `--n-token`: The number of tokens to generate during the inference. It is an optional argument with a default value of 128.
260 | - `-p`, `--n-prompt`: The number of prompt tokens to use for generating text. This is an optional argument with a default value of 512.
261 | - `-t`, `--threads`: The number of threads to use for running the inference. It is an optional argument with a default value of 2.
262 | - `-h`, `--help`: Show the help message and exit. Use this argument to display usage information.
263 |
264 | For example:
265 |
266 | ```sh
267 | python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4
268 | ```
269 |
270 | This command would run the inference benchmark using the model located at `/path/to/model`, generating 200 tokens from a 256 token prompt, utilizing 4 threads.
271 |
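If you want to compare several configurations (for example different thread counts), a small wrapper like the following can drive `utils/e2e_benchmark.py` repeatedly. This is only a sketch, not part of the repository; adjust the model path to a real `.gguf` file on your machine:

```python
# Hypothetical benchmark sweep: calls utils/e2e_benchmark.py once per thread count.
# Assumes it is run from the repository root.
import subprocess

MODEL = "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf"  # adjust to your model

for threads in (1, 2, 4, 8):
    cmd = [
        "python", "utils/e2e_benchmark.py",
        "-m", MODEL,
        "-n", "128",   # number of tokens to generate
        "-p", "512",   # number of prompt tokens to process
        "-t", str(threads),
    ]
    print("Running:", " ".join(cmd))
    subprocess.run(cmd, check=True)
```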
272 | For model layouts that are not supported by any public model, we provide a script to generate a dummy model with the given layout, so you can run the benchmark on your machine:
273 |
274 | ```bash
275 | python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile models/dummy-bitnet-125m.tl1.gguf --outtype tl1 --model-size 125M
276 |
277 | # Run the benchmark with the generated model; use -m to specify the model path, -p the number of prompt tokens to process, -n the number of tokens to generate
278 | python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128
279 | ```
280 | ### FAQ (Frequently Asked Questions)📌
281 |
282 | #### Q1: The build fails with errors about std::chrono in log.cpp when building llama.cpp. What should I do?
283 |
284 | **A:**
285 | This issue was introduced in a recent version of llama.cpp. Please refer to this [commit](https://github.com/tinglou/llama.cpp/commit/4e3db1e3d78cc1bcd22bcb3af54bd2a4628dd323) in this [discussion](https://github.com/abetlen/llama-cpp-python/issues/1942) to fix it.
286 |
287 | #### Q2: How do I build with clang in a conda environment on Windows?
288 |
289 | **A:**
290 | Before building the project, verify your clang installation and access to Visual Studio tools by running:
291 | ```
292 | clang -v
293 | ```
294 |
295 | This command checks that you are using the correct version of clang and that the Visual Studio tools are available. If you see an error message such as:
296 | ```
297 | 'clang' is not recognized as an internal or external command, operable program or batch file.
298 | ```
299 |
300 | It indicates that your command line window is not properly initialized for Visual Studio tools.
301 |
302 | • If you are using Command Prompt, run:
303 | ```
304 | "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
305 | ```
306 |
307 | • If you are using Windows PowerShell, run the following commands:
308 | ```
309 | Import-Module "C:\Program Files\Microsoft Visual Studio\2022\Professional\Common7\Tools\Microsoft.VisualStudio.DevShell.dll"; Enter-VsDevShell 3f0e31ad -SkipAutomaticLocation -DevCmdArguments "-arch=x64 -host_arch=x64"
310 | ```
311 |
312 | These steps will initialize your environment and allow you to use the correct Visual Studio tools.
313 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/assets/header_model_release.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/header_model_release.png
--------------------------------------------------------------------------------
/assets/intel_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/intel_performance.jpg
--------------------------------------------------------------------------------
/assets/m2_performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/m2_performance.jpg
--------------------------------------------------------------------------------
/assets/tl1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/tl1.png
--------------------------------------------------------------------------------
/assets/tl2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/assets/tl2.png
--------------------------------------------------------------------------------
/docs/codegen.md:
--------------------------------------------------------------------------------
1 | Codegen for TL1 and TL2
2 | ------------------------
3 |
4 | codegen_tl1.py and codegen_tl2.py use per-model parameters to generate kernel code for different devices, so that TL1 and TL2 achieve the fastest performance.
5 |
6 | We cut the weight into multiple compute blocks to best utilize the hardware capabilities.
7 |
8 | ### Example
9 | bitnet_b1_58-large:
10 |
11 | - Determine the matmul kernel shapes \
12 | For example, the bitnet_b1_58-large matmul kernel shapes are:\
13 | [1536, 4096]\
14 | [1536, 1536]\
15 | [4096, 1536]
16 |
17 | - Make sure that BM, BK, and bm for each kernel meet the requirements below
18 | - Generate the code\
19 | For example, for bitnet_b1_58-large, we can generate the code like this:
20 |
21 | ```bash
22 | # For TL1
23 | python utils/codegen_tl1.py --model bitnet_b1_58-large --BM 256,128,256 --BK 128,64,128 --bm 32,64,32
24 |
25 | # For TL2
26 | python utils/codegen_tl2.py --model bitnet_b1_58-large --BM 256,128,256 --BK 96,192,96 --bm 32,32,32
27 | ```
28 |
29 | ### TL1:
30 | 
31 |
32 | For TL1, we first cut the weight into M / BM blocks, each of shape (BM, K). Then we cut each block along K into K / BK sub-blocks of shape (BM, BK). Each (BM, BK) sub-block is cut in the same way into (bm, compute_num / bm) compute blocks, and the computation is finished within each compute block. A small sketch that checks the TL1 and TL2 constraints is given at the end of this page.
33 |
34 | Thus, we need to make sure
35 | - M % BM == 0
36 | - K % BK == 0
37 | - BM % bm == 0
38 | - bm is chosen from [32, 64]
39 |
40 | ### TL2:
41 | 
42 |
43 | For TL2, things are a little more complicated. Because TL2 requires BK % 6 == 0, we need to split K into three_K and two_K; the (M, three_K) part is computed with TL2 and the (M, two_K) part with TL1.
44 |
45 | Thus, we need to make sure
46 | - M % BM == 0
47 | - K % BK % 32 == 0
48 | - BM % bm == 0
49 | - bm is chosen from \[32\]
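
To make the requirements above concrete, here is a minimal Python sketch (not part of the repository; the shapes and BM/BK/bm values are taken from the bitnet_b1_58-large example above) that checks whether a set of parameters satisfies the TL1 and TL2 constraints before running the codegen scripts:

```python
# Hypothetical helper: validates TL1/TL2 block-size choices against the
# constraints listed in this document. Not part of the repo.

KERNEL_SHAPES = [(1536, 4096), (1536, 1536), (4096, 1536)]  # (M, K) of each matmul kernel

def check_tl1(M, K, BM, BK, bm):
    return M % BM == 0 and K % BK == 0 and BM % bm == 0 and bm in (32, 64)

def check_tl2(M, K, BM, BK, bm):
    # TL2 additionally needs BK % 6 == 0 so that K can be split into three_K / two_K.
    return M % BM == 0 and K % BK % 32 == 0 and BM % bm == 0 and bm == 32 and BK % 6 == 0

# Parameters matching the example commands above (--BM / --BK / --bm, one triple per kernel).
TL1_PARAMS = [(256, 128, 32), (128, 64, 64), (256, 128, 32)]
TL2_PARAMS = [(256, 96, 32), (128, 192, 32), (256, 96, 32)]

for (M, K), (BM, BK, bm) in zip(KERNEL_SHAPES, TL1_PARAMS):
    assert check_tl1(M, K, BM, BK, bm), f"TL1 constraints violated for [{M}, {K}]"
for (M, K), (BM, BK, bm) in zip(KERNEL_SHAPES, TL2_PARAMS):
    assert check_tl2(M, K, BM, BK, bm), f"TL2 constraints violated for [{M}, {K}]"
print("All TL1/TL2 block-size constraints are satisfied.")
```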
--------------------------------------------------------------------------------
/include/ggml-bitnet.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #ifdef __ARM_NEON
7 | #include <arm_neon.h>
8 | typedef float32_t bitnet_float_type;
9 | #else
10 | typedef float bitnet_float_type;
11 | #endif
12 |
13 | #ifdef __cplusplus
14 | extern "C" {
15 | #endif
16 |
17 | struct bitnet_tensor_extra {
18 | int lut_scales_size;
19 | int BK;
20 | int n_tile_num;
21 | uint8_t * qweights;
22 | bitnet_float_type * scales;
23 | };
24 |
25 | GGML_API void ggml_bitnet_init(void);
26 | GGML_API void ggml_bitnet_free(void);
27 | // src0->type == Q4_0/IQ2_XXS/IQ3_XXS
28 | // bitnet.cpp currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
29 | // If use i-quantization gguf models, the results will be wrong
30 | // TODO: add customized block types Q2_0/Q3_0
31 | GGML_API bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
32 | GGML_API size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
33 | GGML_API void ggml_bitnet_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
34 | GGML_API void ggml_bitnet_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
35 | GGML_API void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor);
36 | GGML_API int ggml_bitnet_get_type_bits(enum ggml_type type);
37 | GGML_API void ggml_bitnet_set_n_threads(int n_threads);
38 | #if defined(GGML_BITNET_ARM_TL1)
39 | GGML_API void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C);
40 | GGML_API void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT);
41 | #endif
42 | #if defined(GGML_BITNET_X86_TL2)
43 | GGML_API void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C);
44 | GGML_API void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT);
45 | #endif
46 |
47 | #ifdef __cplusplus
48 | }
49 | #endif
50 |
--------------------------------------------------------------------------------
/media/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/media/benchmark.png
--------------------------------------------------------------------------------
/media/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/media/demo.mp4
--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/bitnet-lut-kernels-tl1.h:
--------------------------------------------------------------------------------
1 | #if defined(GGML_BITNET_ARM_TL1)
2 | #include "ggml-bitnet.h"
3 | #define GGML_BITNET_MAX_NODES 8192
4 | static bool initialized = false;
5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
6 | static size_t bitnet_tensor_extras_index = 0;
7 | static void * aligned_malloc(size_t size) {{
8 | #if defined(_WIN32)
9 | return _aligned_malloc(size, 64);
10 | #else
11 | void * ptr = nullptr;
12 | posix_memalign(&ptr, 64, size);
13 | return ptr;
14 | #endif
15 | }}
16 | static void aligned_free(void * ptr) {{
17 | #if defined(_WIN32)
18 | _aligned_free(ptr);
19 | #else
20 | free(ptr);
21 | #endif
22 | }}
23 |
24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
26 | bitnet_float_type* b = (bitnet_float_type*)b_;
27 | #ifdef __ARM_NEON
28 | float32x4_t temp_max = vdupq_n_f32(0);
29 | for (int i=0; i < k / 4; i++) {{
30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);
31 | float32x4_t abssum = vabsq_f32(vec_bs);
32 | temp_max = vmaxq_f32(abssum, temp_max);
33 | }}
34 | float32_t scales = 127 / vmaxvq_f32(temp_max);
35 | *lut_scales = scales;
36 | #elif defined __AVX2__
37 | __m256 max_vec = _mm256_set1_ps(0.f);
38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);
39 | // #pragma unroll
40 | for (int i = 0; i < k / 8; i++) {{
41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);
42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
43 | max_vec = _mm256_max_ps(vec_babs, max_vec);
44 | }}
45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
48 | float scales = 127 / _mm_cvtss_f32(max1);
49 | *lut_scales = scales;
50 | #endif
51 | }}
52 |
53 | void partial_max_reset(void* lut_scales_) {{
54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
55 | *lut_scales = 0.0;
56 | }}
57 |
58 | #ifdef __ARM_NEON
59 | inline void Transpose_8_8(
60 | int16x8_t *v0,
61 | int16x8_t *v1,
62 | int16x8_t *v2,
63 | int16x8_t *v3,
64 | int16x8_t *v4,
65 | int16x8_t *v5,
66 | int16x8_t *v6,
67 | int16x8_t *v7)
68 | {{
69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);
70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);
71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);
72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);
73 |
74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
78 |
79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
83 |
84 | *v0 = q_fin_0.val[0];
85 | *v1 = q_fin_0.val[1];
86 | *v2 = q_fin_1.val[0];
87 | *v3 = q_fin_1.val[1];
88 | *v4 = q_fin_2.val[0];
89 | *v5 = q_fin_2.val[1];
90 | *v6 = q_fin_3.val[0];
91 | *v7 = q_fin_3.val[1];
92 | }}
93 | #endif
94 |
95 | template <int act_k>
96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
97 | #ifdef __ARM_NEON
98 | int16x8_t vec_lut[16];
99 | float32_t scales = *lut_scales;
100 | uint8_t tbl_mask[16];
101 | tbl_mask[0] = 0;
102 | tbl_mask[1] = 2;
103 | tbl_mask[2] = 4;
104 | tbl_mask[3] = 6;
105 | tbl_mask[4] = 8;
106 | tbl_mask[5] = 10;
107 | tbl_mask[6] = 12;
108 | tbl_mask[7] = 14;
109 | tbl_mask[8] = 1;
110 | tbl_mask[9] = 3;
111 | tbl_mask[10] = 5;
112 | tbl_mask[11] = 7;
113 | tbl_mask[12] = 9;
114 | tbl_mask[13] = 11;
115 | tbl_mask[14] = 13;
116 | tbl_mask[15] = 15;
117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
118 | #pragma unroll
119 | for (int k = 0; k < act_k / 16; ++k) {{
120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
136 | vec_lut[0] = vdupq_n_s16(0);
137 | vec_lut[0] = vec_lut[0] - vec_bs_0;
138 | vec_lut[0] = vec_lut[0] - vec_bs_1;
139 | vec_lut[1] = vdupq_n_s16(0);
140 | vec_lut[1] = vec_lut[1] - vec_bs_0;
141 | vec_lut[2] = vdupq_n_s16(0);
142 | vec_lut[2] = vec_lut[2] - vec_bs_0;
143 | vec_lut[2] = vec_lut[2] + vec_bs_1;
144 | vec_lut[3] = vdupq_n_s16(0);
145 | vec_lut[3] = vec_lut[3] - vec_bs_1;
146 | vec_lut[4] = vdupq_n_s16(0);
147 | vec_lut[5] = vec_bs_1;
148 | vec_lut[6] = vec_bs_0;
149 | vec_lut[6] = vec_lut[6] - vec_bs_1;
150 | vec_lut[7] = vec_bs_0;
151 | vec_lut[8] = vec_bs_0;
152 | vec_lut[8] = vec_lut[8] + vec_bs_1;
153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
157 | #pragma unroll
158 | for (int idx = 0; idx < 8; idx++) {{
159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
160 | int8x8_t q0_low = vget_low_s8(q0_s);
161 | int8x8_t q0_high = vget_high_s8(q0_s);
162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
163 | int8x8_t q1_low = vget_low_s8(q1_s);
164 | int8x8_t q1_high = vget_high_s8(q1_s);
165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
169 | }}
170 | }}
171 | #endif
172 | }}
173 |
174 | static bool is_type_supported(enum ggml_type type) {{
175 | if (type == GGML_TYPE_Q4_0 ||
176 | type == GGML_TYPE_TL1) {{
177 | return true;
178 | }} else {{
179 | return false;
180 | }}
181 | }}
182 | #include <arm_neon.h>
183 |
184 | #define BM14336_4096 256
185 | #define BBK14336_4096 128
186 | inline void tbl_impl_14336_4096(int32_t* c, int8_t* lut, uint8_t* a) {
187 | #ifdef __ARM_NEON
188 | const int KK = BBK14336_4096 / 2;
189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
191 | int8x16_t vec_lut[2 * KK];
192 | int16x8_t vec_c[8];
193 | #pragma unroll
194 | for (int k = 0; k < 2 * KK; k++) {
195 | vec_lut[k] = vld1q_s8(lut + k * 16);
196 | }
197 |
198 | #pragma unroll
199 | for (int i = 0; i < BM14336_4096; i += 64) {
200 | #pragma unroll
201 | for (int i=0; i<8; i++) {
202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
203 | }
204 |
205 | #pragma unroll
206 | for (int k = 0; k < KK / 2; k++) {
207 |
208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
217 | vec_c[0] += vec_v_left_0.val[0];
218 | vec_c[0] += vec_v_right_0.val[0];
219 | vec_c[1] += vec_v_left_0.val[1];
220 | vec_c[1] += vec_v_right_0.val[1];
221 |
222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
231 | vec_c[2] += vec_v_left_1.val[0];
232 | vec_c[2] += vec_v_right_1.val[0];
233 | vec_c[3] += vec_v_left_1.val[1];
234 | vec_c[3] += vec_v_right_1.val[1];
235 |
236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
245 | vec_c[4] += vec_v_left_2.val[0];
246 | vec_c[4] += vec_v_right_2.val[0];
247 | vec_c[5] += vec_v_left_2.val[1];
248 | vec_c[5] += vec_v_right_2.val[1];
249 |
250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
259 | vec_c[6] += vec_v_left_3.val[0];
260 | vec_c[6] += vec_v_right_3.val[0];
261 | vec_c[7] += vec_v_left_3.val[1];
262 | vec_c[7] += vec_v_right_3.val[1];
263 |
264 | }
265 |
266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
282 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
283 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
284 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
285 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
286 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
287 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
288 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
289 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
290 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
291 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
292 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
293 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
294 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
295 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
296 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
297 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
298 |
299 | }
300 | #endif
301 | }
302 |
303 | int32_t qgemm_lut_14336_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
304 | alignas(32) uint32_t CBits[BM14336_4096];
305 | memset(&(CBits[0]), 0, BM14336_4096 * sizeof(int32_t));
306 | #pragma unroll
307 | for (int32_t k_outer = 0; k_outer < 4096 / BBK14336_4096; ++k_outer) {
308 | tbl_impl_14336_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK14336_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK14336_4096 / 2 / 2 * BM14336_4096)])));
309 | }
310 | #pragma unroll
311 | for (int i = 0; i < BM14336_4096; i++) {
312 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
313 | }
314 | return 0;
315 | };
316 | #include <arm_neon.h>
317 |
318 | #define BM4096_14336 256
319 | #define BBK4096_14336 128
320 | inline void tbl_impl_4096_14336(int32_t* c, int8_t* lut, uint8_t* a) {
321 | #ifdef __ARM_NEON
322 | const int KK = BBK4096_14336 / 2;
323 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
324 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
325 | int8x16_t vec_lut[2 * KK];
326 | int16x8_t vec_c[4];
327 | #pragma unroll
328 | for (int k = 0; k < 2 * KK; k++) {
329 | vec_lut[k] = vld1q_s8(lut + k * 16);
330 | }
331 |
332 | #pragma unroll
333 | for (int i = 0; i < BM4096_14336; i += 32) {
334 | #pragma unroll
335 | for (int i=0; i<4; i++) {
336 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
337 | }
338 |
339 | #pragma unroll
340 | for (int k = 0; k < KK / 4; k++) {
341 |
342 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
343 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
344 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
345 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
346 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
347 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
348 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
349 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
350 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
351 | vec_c[0] += vec_v_left_0.val[0];
352 | vec_c[0] += vec_v_right_0.val[0];
353 | vec_c[1] += vec_v_left_0.val[1];
354 | vec_c[1] += vec_v_right_0.val[1];
355 |
356 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
357 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
358 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
359 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
360 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
361 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
362 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
363 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
364 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
365 | vec_c[0] += vec_v_left_1.val[0];
366 | vec_c[0] += vec_v_right_1.val[0];
367 | vec_c[1] += vec_v_left_1.val[1];
368 | vec_c[1] += vec_v_right_1.val[1];
369 |
370 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
371 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
372 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
373 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
374 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
375 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
376 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
377 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
378 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
379 | vec_c[2] += vec_v_left_2.val[0];
380 | vec_c[2] += vec_v_right_2.val[0];
381 | vec_c[3] += vec_v_left_2.val[1];
382 | vec_c[3] += vec_v_right_2.val[1];
383 |
384 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
385 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
386 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
387 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
388 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
389 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
390 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
391 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
392 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
393 | vec_c[2] += vec_v_left_3.val[0];
394 | vec_c[2] += vec_v_right_3.val[0];
395 | vec_c[3] += vec_v_left_3.val[1];
396 | vec_c[3] += vec_v_right_3.val[1];
397 |
398 | }
399 |
400 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
401 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
402 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
403 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
404 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
405 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
406 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
407 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
408 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
409 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
410 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
411 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
412 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
413 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
414 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
415 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
416 |
417 | }
418 | #endif
419 | }
420 |
421 | int32_t qgemm_lut_4096_14336(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
422 | alignas(32) uint32_t CBits[BM4096_14336];
423 | memset(&(CBits[0]), 0, BM4096_14336 * sizeof(int32_t));
424 | #pragma unroll
425 | for (int32_t k_outer = 0; k_outer < 14336 / BBK4096_14336; ++k_outer) {
426 | tbl_impl_4096_14336((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_14336 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_14336 / 2 / 2 * BM4096_14336)])));
427 | }
428 | #pragma unroll
429 | for (int i = 0; i < BM4096_14336; i++) {
430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
431 | }
432 | return 0;
433 | };
434 | #include <arm_neon.h>
435 |
436 | #define BM1024_4096 128
437 | #define BBK1024_4096 64
438 | inline void tbl_impl_1024_4096(int32_t* c, int8_t* lut, uint8_t* a) {
439 | #ifdef __ARM_NEON
440 | const int KK = BBK1024_4096 / 2;
441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
443 | int8x16_t vec_lut[2 * KK];
444 | int16x8_t vec_c[8];
445 | #pragma unroll
446 | for (int k = 0; k < 2 * KK; k++) {
447 | vec_lut[k] = vld1q_s8(lut + k * 16);
448 | }
449 |
450 | #pragma unroll
451 | for (int i = 0; i < BM1024_4096; i += 64) {
452 | #pragma unroll
453 | for (int i=0; i<8; i++) {
454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
455 | }
456 |
457 | #pragma unroll
458 | for (int k = 0; k < KK / 2; k++) {
459 |
460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
469 | vec_c[0] += vec_v_left_0.val[0];
470 | vec_c[0] += vec_v_right_0.val[0];
471 | vec_c[1] += vec_v_left_0.val[1];
472 | vec_c[1] += vec_v_right_0.val[1];
473 |
474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
483 | vec_c[2] += vec_v_left_1.val[0];
484 | vec_c[2] += vec_v_right_1.val[0];
485 | vec_c[3] += vec_v_left_1.val[1];
486 | vec_c[3] += vec_v_right_1.val[1];
487 |
488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
497 | vec_c[4] += vec_v_left_2.val[0];
498 | vec_c[4] += vec_v_right_2.val[0];
499 | vec_c[5] += vec_v_left_2.val[1];
500 | vec_c[5] += vec_v_right_2.val[1];
501 |
502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
511 | vec_c[6] += vec_v_left_3.val[0];
512 | vec_c[6] += vec_v_right_3.val[0];
513 | vec_c[7] += vec_v_left_3.val[1];
514 | vec_c[7] += vec_v_right_3.val[1];
515 |
516 | }
517 |
518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
534 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
535 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
536 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
537 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
538 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
539 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
540 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
541 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
542 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
543 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
544 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
545 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
546 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
547 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
548 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
549 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
550 |
551 | }
552 | #endif
553 | }
554 |
555 | int32_t qgemm_lut_1024_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
556 | alignas(32) uint32_t CBits[BM1024_4096];
557 | memset(&(CBits[0]), 0, BM1024_4096 * sizeof(int32_t));
558 | #pragma unroll
559 | for (int32_t k_outer = 0; k_outer < 4096 / BBK1024_4096; ++k_outer) {
560 | tbl_impl_1024_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1024_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1024_4096 / 2 / 2 * BM1024_4096)])));
561 | }
562 | #pragma unroll
563 | for (int i = 0; i < BM1024_4096; i++) {
564 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
565 | }
566 | return 0;
567 | };
568 | #include <arm_neon.h>
569 |
570 | #define BM4096_4096 128
571 | #define BBK4096_4096 64
572 | inline void tbl_impl_4096_4096(int32_t* c, int8_t* lut, uint8_t* a) {
573 | #ifdef __ARM_NEON
574 | const int KK = BBK4096_4096 / 2;
575 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
576 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
577 | int8x16_t vec_lut[2 * KK];
578 | int16x8_t vec_c[4];
579 | #pragma unroll
580 | for (int k = 0; k < 2 * KK; k++) {
581 | vec_lut[k] = vld1q_s8(lut + k * 16);
582 | }
583 |
584 | #pragma unroll
585 | for (int i = 0; i < BM4096_4096; i += 32) {
586 | #pragma unroll
587 | for (int i=0; i<4; i++) {
588 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
589 | }
590 |
591 | #pragma unroll
592 | for (int k = 0; k < KK / 4; k++) {
593 |
594 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
595 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
596 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
597 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
598 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
599 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
600 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
601 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
602 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
603 | vec_c[0] += vec_v_left_0.val[0];
604 | vec_c[0] += vec_v_right_0.val[0];
605 | vec_c[1] += vec_v_left_0.val[1];
606 | vec_c[1] += vec_v_right_0.val[1];
607 |
608 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
609 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
610 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
611 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
612 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
613 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
614 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
615 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
616 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
617 | vec_c[0] += vec_v_left_1.val[0];
618 | vec_c[0] += vec_v_right_1.val[0];
619 | vec_c[1] += vec_v_left_1.val[1];
620 | vec_c[1] += vec_v_right_1.val[1];
621 |
622 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
623 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
624 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
625 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
626 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
627 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
628 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
629 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
630 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
631 | vec_c[2] += vec_v_left_2.val[0];
632 | vec_c[2] += vec_v_right_2.val[0];
633 | vec_c[3] += vec_v_left_2.val[1];
634 | vec_c[3] += vec_v_right_2.val[1];
635 |
636 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
637 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
638 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
639 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
640 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
641 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
642 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
643 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
644 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
645 | vec_c[2] += vec_v_left_3.val[0];
646 | vec_c[2] += vec_v_right_3.val[0];
647 | vec_c[3] += vec_v_left_3.val[1];
648 | vec_c[3] += vec_v_right_3.val[1];
649 |
650 | }
651 |
652 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
653 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
654 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
655 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
656 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
657 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
658 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
659 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
660 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
661 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
662 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
663 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
664 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
665 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
666 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
667 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
668 |
669 | }
670 | #endif
671 | }
672 |
673 | int32_t qgemm_lut_4096_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
674 | alignas(32) uint32_t CBits[BM4096_4096];
675 | memset(&(CBits[0]), 0, BM4096_4096 * sizeof(int32_t));
676 | #pragma unroll
677 | for (int32_t k_outer = 0; k_outer < 4096 / BBK4096_4096; ++k_outer) {
678 | tbl_impl_4096_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_4096 / 2 / 2 * BM4096_4096)])));
679 | }
680 | #pragma unroll
681 | for (int i = 0; i < BM4096_4096; i++) {
682 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
683 | }
684 | return 0;
685 | };
686 |
687 | template <int K>
688 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{
689 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
690 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
691 |
692 | lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
693 | }}
694 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
695 | if (m == 14336 && k == 4096) {
696 | preprocessor_k<4096>(B, LUT_Scales, QLUT);
697 | }
698 | else if (m == 4096 && k == 14336) {
699 | preprocessor_k<14336>(B, LUT_Scales, QLUT);
700 | }
701 | else if (m == 1024 && k == 4096) {
702 | preprocessor_k<4096>(B, LUT_Scales, QLUT);
703 | }
704 | else if (m == 4096 && k == 4096) {
705 | preprocessor_k<4096>(B, LUT_Scales, QLUT);
706 | }
707 | }
708 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
709 | if (m == 14336 && k == 4096) {
710 | qgemm_lut_14336_4096(A, LUT, Scales, LUT_Scales, C);
711 | }
712 | else if (m == 4096 && k == 14336) {
713 | qgemm_lut_4096_14336(A, LUT, Scales, LUT_Scales, C);
714 | }
715 | else if (m == 1024 && k == 4096) {
716 | qgemm_lut_1024_4096(A, LUT, Scales, LUT_Scales, C);
717 | }
718 | else if (m == 4096 && k == 4096) {
719 | qgemm_lut_4096_4096(A, LUT, Scales, LUT_Scales, C);
720 | }
721 | }
722 |
723 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
724 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
725 | return;
726 | }
727 |
728 | int k = tensor->ne[0];
729 | int m = tensor->ne[1];
730 | const int lut_scales_size = 1;
731 | const int scales_size = 1;
732 | int bk = 0;
733 | int bm = 0;
734 |
735 | if (m == 14336 && k == 4096) {
736 | bm = BM14336_4096;
737 | bk = BBK14336_4096;
738 | }
739 | else if (m == 4096 && k == 14336) {
740 | bm = BM4096_14336;
741 | bk = BBK4096_14336;
742 | }
743 | else if (m == 1024 && k == 4096) {
744 | bm = BM1024_4096;
745 | bk = BBK1024_4096;
746 | }
747 | else if (m == 4096 && k == 4096) {
748 | bm = BM4096_4096;
749 | bk = BBK4096_4096;
750 | }
751 |
752 | const int n_tile_num = m / bm;
753 | const int BK = bk;
754 | uint8_t * qweights;
755 | bitnet_float_type * scales;
756 |
757 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
758 | qweights = (uint8_t *) tensor->data;
759 | float * i2_scales = (float * )(qweights + k * m / 4);
760 | scales[0] = (bitnet_float_type) i2_scales[0];
761 |
762 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
763 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
764 | /* .lut_scales_size = */ lut_scales_size,
765 | /* .scales_size = */ scales_size,
766 | /* .n_tile_num = */ n_tile_num,
767 | /* .qweights = */ qweights,
768 | /* .scales = */ scales
769 | };
770 | }
771 | #endif
--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 14336
3 | k = 4096
4 | bm = 256
5 | bk = 128
6 | bmm = 64
7 |
8 | [Kernels_1]
9 | m = 4096
10 | k = 14336
11 | bm = 256
12 | bk = 128
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 1024
17 | k = 4096
18 | bm = 128
19 | bk = 64
20 | bmm = 64
21 |
22 | [Kernels_3]
23 | m = 4096
24 | k = 4096
25 | bm = 128
26 | bk = 64
27 | bmm = 32
28 |
29 |
--------------------------------------------------------------------------------
/preset_kernels/Llama3-8B-1.58-100B-tokens/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 14336
3 | k = 4096
4 | bm = 256
5 | bk = 96
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 4096
10 | k = 14336
11 | bm = 128
12 | bk = 96
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 1024
17 | k = 4096
18 | bm = 256
19 | bk = 96
20 | bmm = 32
21 |
22 | [Kernels_3]
23 | m = 4096
24 | k = 4096
25 | bm = 128
26 | bk = 96
27 | bmm = 32
28 |
29 |
--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-3B/bitnet-lut-kernels-tl1.h:
--------------------------------------------------------------------------------
1 | #if defined(GGML_BITNET_ARM_TL1)
2 | #include "ggml-bitnet.h"
3 | #define GGML_BITNET_MAX_NODES 8192
4 | static bool initialized = false;
5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
6 | static size_t bitnet_tensor_extras_index = 0;
7 | static void * aligned_malloc(size_t size) {{
8 | #if defined(_WIN32)
9 | return _aligned_malloc(size, 64);
10 | #else
11 | void * ptr = nullptr;
12 | posix_memalign(&ptr, 64, size);
13 | return ptr;
14 | #endif
15 | }}
16 | static void aligned_free(void * ptr) {{
17 | #if defined(_WIN32)
18 | _aligned_free(ptr);
19 | #else
20 | free(ptr);
21 | #endif
22 | }}
23 |
24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
26 | bitnet_float_type* b = (bitnet_float_type*)b_;
27 | #ifdef __ARM_NEON
28 | float32x4_t temp_max = vdupq_n_f32(0);
29 | for (int i=0; i < k / 4; i++) {{
30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);
31 | float32x4_t abssum = vabsq_f32(vec_bs);
32 | temp_max = vmaxq_f32(abssum, temp_max);
33 | }}
34 | float32_t scales = 127 / vmaxvq_f32(temp_max);
35 | *lut_scales = scales;
36 | #elif defined __AVX2__
37 | __m256 max_vec = _mm256_set1_ps(0.f);
38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);
39 | // #pragma unroll
40 | for (int i = 0; i < k / 8; i++) {{
41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);
42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
43 | max_vec = _mm256_max_ps(vec_babs, max_vec);
44 | }}
45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
48 | float scales = 127 / _mm_cvtss_f32(max1);
49 | *lut_scales = scales;
50 | #endif
51 | }}
52 |
53 | void partial_max_reset(void* lut_scales_) {{
54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
55 | *lut_scales = 0.0;
56 | }}
57 |
58 | #ifdef __ARM_NEON
59 | inline void Transpose_8_8(
60 | int16x8_t *v0,
61 | int16x8_t *v1,
62 | int16x8_t *v2,
63 | int16x8_t *v3,
64 | int16x8_t *v4,
65 | int16x8_t *v5,
66 | int16x8_t *v6,
67 | int16x8_t *v7)
68 | {{
69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);
70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);
71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);
72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);
73 |
74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
78 |
79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
83 |
84 | *v0 = q_fin_0.val[0];
85 | *v1 = q_fin_0.val[1];
86 | *v2 = q_fin_1.val[0];
87 | *v3 = q_fin_1.val[1];
88 | *v4 = q_fin_2.val[0];
89 | *v5 = q_fin_2.val[1];
90 | *v6 = q_fin_3.val[0];
91 | *v7 = q_fin_3.val[1];
92 | }}
93 | #endif
94 |
95 | template <int act_k>
96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
97 | #ifdef __ARM_NEON
98 | int16x8_t vec_lut[16];
99 | float32_t scales = *lut_scales;
100 | uint8_t tbl_mask[16];
101 | tbl_mask[0] = 0;
102 | tbl_mask[1] = 2;
103 | tbl_mask[2] = 4;
104 | tbl_mask[3] = 6;
105 | tbl_mask[4] = 8;
106 | tbl_mask[5] = 10;
107 | tbl_mask[6] = 12;
108 | tbl_mask[7] = 14;
109 | tbl_mask[8] = 1;
110 | tbl_mask[9] = 3;
111 | tbl_mask[10] = 5;
112 | tbl_mask[11] = 7;
113 | tbl_mask[12] = 9;
114 | tbl_mask[13] = 11;
115 | tbl_mask[14] = 13;
116 | tbl_mask[15] = 15;
117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
118 | #pragma unroll
119 | for (int k = 0; k < act_k / 16; ++k) {{
120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
136 | vec_lut[0] = vdupq_n_s16(0);
137 | vec_lut[0] = vec_lut[0] - vec_bs_0;
138 | vec_lut[0] = vec_lut[0] - vec_bs_1;
139 | vec_lut[1] = vdupq_n_s16(0);
140 | vec_lut[1] = vec_lut[1] - vec_bs_0;
141 | vec_lut[2] = vdupq_n_s16(0);
142 | vec_lut[2] = vec_lut[2] - vec_bs_0;
143 | vec_lut[2] = vec_lut[2] + vec_bs_1;
144 | vec_lut[3] = vdupq_n_s16(0);
145 | vec_lut[3] = vec_lut[3] - vec_bs_1;
146 | vec_lut[4] = vdupq_n_s16(0);
147 | vec_lut[5] = vec_bs_1;
148 | vec_lut[6] = vec_bs_0;
149 | vec_lut[6] = vec_lut[6] - vec_bs_1;
150 | vec_lut[7] = vec_bs_0;
151 | vec_lut[8] = vec_bs_0;
152 | vec_lut[8] = vec_lut[8] + vec_bs_1;
153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
157 | #pragma unroll
158 | for (int idx = 0; idx < 8; idx++) {{
159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
160 | int8x8_t q0_low = vget_low_s8(q0_s);
161 | int8x8_t q0_high = vget_high_s8(q0_s);
162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
163 | int8x8_t q1_low = vget_low_s8(q1_s);
164 | int8x8_t q1_high = vget_high_s8(q1_s);
165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
169 | }}
170 | }}
171 | #endif
172 | }}
173 |
174 | static bool is_type_supported(enum ggml_type type) {{
175 | if (type == GGML_TYPE_Q4_0 ||
176 | type == GGML_TYPE_TL1) {{
177 | return true;
178 | }} else {{
179 | return false;
180 | }}
181 | }}
182 | #include <arm_neon.h>
183 |
184 | #define BM3200_8640 160
185 | #define BBK3200_8640 64
186 | inline void tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a) {
187 | #ifdef __ARM_NEON
188 | const int KK = BBK3200_8640 / 2;
189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
191 | int8x16_t vec_lut[2 * KK];
192 | int16x8_t vec_c[4];
193 | #pragma unroll
194 | for (int k = 0; k < 2 * KK; k++) {
195 | vec_lut[k] = vld1q_s8(lut + k * 16);
196 | }
197 |
198 | #pragma unroll
199 | for (int i = 0; i < BM3200_8640; i += 32) {
200 | #pragma unroll
201 | for (int i=0; i<4; i++) {
202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
203 | }
204 |
205 | #pragma unroll
206 | for (int k = 0; k < KK / 4; k++) {
207 |
208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
217 | vec_c[0] += vec_v_left_0.val[0];
218 | vec_c[0] += vec_v_right_0.val[0];
219 | vec_c[1] += vec_v_left_0.val[1];
220 | vec_c[1] += vec_v_right_0.val[1];
221 |
222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
231 | vec_c[0] += vec_v_left_1.val[0];
232 | vec_c[0] += vec_v_right_1.val[0];
233 | vec_c[1] += vec_v_left_1.val[1];
234 | vec_c[1] += vec_v_right_1.val[1];
235 |
236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
245 | vec_c[2] += vec_v_left_2.val[0];
246 | vec_c[2] += vec_v_right_2.val[0];
247 | vec_c[3] += vec_v_left_2.val[1];
248 | vec_c[3] += vec_v_right_2.val[1];
249 |
250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
259 | vec_c[2] += vec_v_left_3.val[0];
260 | vec_c[2] += vec_v_right_3.val[0];
261 | vec_c[3] += vec_v_left_3.val[1];
262 | vec_c[3] += vec_v_right_3.val[1];
263 |
264 | }
265 |
266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
282 |
283 | }
284 | #endif
285 | }
286 |
287 | int32_t qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
288 | alignas(32) uint32_t CBits[BM3200_8640];
289 | memset(&(CBits[0]), 0, BM3200_8640 * sizeof(int32_t));
290 | #pragma unroll
291 | for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) {
292 | tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 2 / 2 * BM3200_8640)])));
293 | }
294 | #pragma unroll
295 | for (int i = 0; i < BM3200_8640; i++) {
296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
297 | }
298 | return 0;
299 | };
300 | #include <arm_neon.h>
301 |
302 | #define BM3200_3200 320
303 | #define BBK3200_3200 128
304 | inline void tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a) {
305 | #ifdef __ARM_NEON
306 | const int KK = BBK3200_3200 / 2;
307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
309 | int8x16_t vec_lut[2 * KK];
310 | int16x8_t vec_c[8];
311 | #pragma unroll
312 | for (int k = 0; k < 2 * KK; k++) {
313 | vec_lut[k] = vld1q_s8(lut + k * 16);
314 | }
315 |
316 | #pragma unroll
317 | for (int i = 0; i < BM3200_3200; i += 64) {
318 | #pragma unroll
319 | for (int i=0; i<8; i++) {
320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
321 | }
322 |
323 | #pragma unroll
324 | for (int k = 0; k < KK / 2; k++) {
325 |
326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
335 | vec_c[0] += vec_v_left_0.val[0];
336 | vec_c[0] += vec_v_right_0.val[0];
337 | vec_c[1] += vec_v_left_0.val[1];
338 | vec_c[1] += vec_v_right_0.val[1];
339 |
340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
349 | vec_c[2] += vec_v_left_1.val[0];
350 | vec_c[2] += vec_v_right_1.val[0];
351 | vec_c[3] += vec_v_left_1.val[1];
352 | vec_c[3] += vec_v_right_1.val[1];
353 |
354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
363 | vec_c[4] += vec_v_left_2.val[0];
364 | vec_c[4] += vec_v_right_2.val[0];
365 | vec_c[5] += vec_v_left_2.val[1];
366 | vec_c[5] += vec_v_right_2.val[1];
367 |
368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
377 | vec_c[6] += vec_v_left_3.val[0];
378 | vec_c[6] += vec_v_right_3.val[0];
379 | vec_c[7] += vec_v_left_3.val[1];
380 | vec_c[7] += vec_v_right_3.val[1];
381 |
382 | }
383 |
384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
416 |
417 | }
418 | #endif
419 | }
420 |
421 | int32_t qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
422 | alignas(32) uint32_t CBits[BM3200_3200];
423 | memset(&(CBits[0]), 0, BM3200_3200 * sizeof(int32_t));
424 | #pragma unroll
425 | for (int32_t k_outer = 0; k_outer < 3200 / BBK3200_3200; ++k_outer) {
426 | tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 2 / 2 * BM3200_3200)])));
427 | }
428 | #pragma unroll
429 | for (int i = 0; i < BM3200_3200; i++) {
430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
431 | }
432 | return 0;
433 | };
434 | #include <arm_neon.h>
435 |
436 | #define BM8640_3200 320
437 | #define BBK8640_3200 64
438 | inline void tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a) {
439 | #ifdef __ARM_NEON
440 | const int KK = BBK8640_3200 / 2;
441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
443 | int8x16_t vec_lut[2 * KK];
444 | int16x8_t vec_c[4];
445 | #pragma unroll
446 | for (int k = 0; k < 2 * KK; k++) {
447 | vec_lut[k] = vld1q_s8(lut + k * 16);
448 | }
449 |
450 | #pragma unroll
451 | for (int i = 0; i < BM8640_3200; i += 32) {
452 | #pragma unroll
453 | for (int i=0; i<4; i++) {
454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
455 | }
456 |
457 | #pragma unroll
458 | for (int k = 0; k < KK / 4; k++) {
459 |
460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
469 | vec_c[0] += vec_v_left_0.val[0];
470 | vec_c[0] += vec_v_right_0.val[0];
471 | vec_c[1] += vec_v_left_0.val[1];
472 | vec_c[1] += vec_v_right_0.val[1];
473 |
474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
483 | vec_c[0] += vec_v_left_1.val[0];
484 | vec_c[0] += vec_v_right_1.val[0];
485 | vec_c[1] += vec_v_left_1.val[1];
486 | vec_c[1] += vec_v_right_1.val[1];
487 |
488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
497 | vec_c[2] += vec_v_left_2.val[0];
498 | vec_c[2] += vec_v_right_2.val[0];
499 | vec_c[3] += vec_v_left_2.val[1];
500 | vec_c[3] += vec_v_right_2.val[1];
501 |
502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
511 | vec_c[2] += vec_v_left_3.val[0];
512 | vec_c[2] += vec_v_right_3.val[0];
513 | vec_c[3] += vec_v_left_3.val[1];
514 | vec_c[3] += vec_v_right_3.val[1];
515 |
516 | }
517 |
518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
534 |
535 | }
536 | #endif
537 | }
538 |
539 | int32_t qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
540 | alignas(32) uint32_t CBits[BM8640_3200];
541 | memset(&(CBits[0]), 0, BM8640_3200 * sizeof(int32_t));
542 | #pragma unroll
543 | for (int32_t k_outer = 0; k_outer < 3200 / BBK8640_3200; ++k_outer) {
544 | tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 2 / 2 * BM8640_3200)])));
545 | }
546 | #pragma unroll
547 | for (int i = 0; i < BM8640_3200; i++) {
548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
549 | }
550 | return 0;
551 | };
552 |
553 | template <int K>
554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{
555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
557 |
558 | lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
559 | }}
560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
561 | if (m == 3200 && k == 8640) {
562 | preprocessor_k<8640>(B, LUT_Scales, QLUT);
563 | }
564 | else if (m == 3200 && k == 3200) {
565 | preprocessor_k<3200>(B, LUT_Scales, QLUT);
566 | }
567 | else if (m == 8640 && k == 3200) {
568 | preprocessor_k<3200>(B, LUT_Scales, QLUT);
569 | }
570 | }
571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
572 | if (m == 3200 && k == 8640) {
573 | qgemm_lut_3200_8640(A, LUT, Scales, LUT_Scales, C);
574 | }
575 | else if (m == 3200 && k == 3200) {
576 | qgemm_lut_3200_3200(A, LUT, Scales, LUT_Scales, C);
577 | }
578 | else if (m == 8640 && k == 3200) {
579 | qgemm_lut_8640_3200(A, LUT, Scales, LUT_Scales, C);
580 | }
581 | }
582 |
583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
585 | return;
586 | }
587 |
588 | int k = tensor->ne[0];
589 | int m = tensor->ne[1];
590 | const int lut_scales_size = 1;
591 | const int scales_size = 1;
592 | int bk = 0;
593 | int bm = 0;
594 |
595 | if (m == 3200 && k == 8640) {
596 | bm = BM3200_8640;
597 | bk = BBK3200_8640;
598 | }
599 | else if (m == 3200 && k == 3200) {
600 | bm = BM3200_3200;
601 | bk = BBK3200_3200;
602 | }
603 | else if (m == 8640 && k == 3200) {
604 | bm = BM8640_3200;
605 | bk = BBK8640_3200;
606 | }
607 |
608 | const int n_tile_num = m / bm;
609 | const int BK = bk;
610 | uint8_t * qweights;
611 | bitnet_float_type * scales;
612 |
613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
614 | qweights = (uint8_t *) tensor->data;
615 | float * i2_scales = (float * )(qweights + k * m / 4);
616 | scales[0] = (bitnet_float_type) i2_scales[0];
617 |
618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
620 | /* .lut_scales_size = */ lut_scales_size,
621 | /* .scales_size = */ scales_size,
622 | /* .n_tile_num = */ n_tile_num,
623 | /* .qweights = */ qweights,
624 | /* .scales = */ scales
625 | };
626 | }
627 | #endif
--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-3B/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 3200
3 | k = 8640
4 | bm = 160
5 | bk = 64
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 3200
10 | k = 3200
11 | bm = 320
12 | bk = 128
13 | bmm = 64
14 |
15 | [Kernels_2]
16 | m = 8640
17 | k = 3200
18 | bm = 320
19 | bk = 64
20 | bmm = 32
21 |
22 |
--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-3B/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 3200
3 | k = 8640
4 | bm = 160
5 | bk = 96
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 3200
10 | k = 3200
11 | bm = 320
12 | bk = 96
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 8640
17 | k = 3200
18 | bm = 320
19 | bk = 96
20 | bmm = 32
21 |
22 |
--------------------------------------------------------------------------------
/preset_kernels/bitnet_b1_58-large/bitnet-lut-kernels-tl1.h:
--------------------------------------------------------------------------------
1 | #if defined(GGML_BITNET_ARM_TL1)
2 | #include "ggml-bitnet.h"
3 | #define GGML_BITNET_MAX_NODES 8192
4 | static bool initialized = false;
5 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
6 | static size_t bitnet_tensor_extras_index = 0;
7 | static void * aligned_malloc(size_t size) {{
8 | #if defined(_WIN32)
9 | return _aligned_malloc(size, 64);
10 | #else
11 | void * ptr = nullptr;
12 | posix_memalign(&ptr, 64, size);
13 | return ptr;
14 | #endif
15 | }}
16 | static void aligned_free(void * ptr) {{
17 | #if defined(_WIN32)
18 | _aligned_free(ptr);
19 | #else
20 | free(ptr);
21 | #endif
22 | }}
23 |
24 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{
25 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
26 | bitnet_float_type* b = (bitnet_float_type*)b_;
27 | #ifdef __ARM_NEON
28 | float32x4_t temp_max = vdupq_n_f32(0);
29 | for (int i=0; i < k / 4; i++) {{
30 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);
31 | float32x4_t abssum = vabsq_f32(vec_bs);
32 | temp_max = vmaxq_f32(abssum, temp_max);
33 | }}
34 | float32_t scales = 127 / vmaxvq_f32(temp_max);
35 | *lut_scales = scales;
36 | #elif defined __AVX2__
37 | __m256 max_vec = _mm256_set1_ps(0.f);
38 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);
39 | // #pragma unroll
40 | for (int i = 0; i < k / 8; i++) {{
41 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);
42 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
43 | max_vec = _mm256_max_ps(vec_babs, max_vec);
44 | }}
45 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
46 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
47 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
48 | float scales = 127 / _mm_cvtss_f32(max1);
49 | *lut_scales = scales;
50 | #endif
51 | }}
52 |
53 | void partial_max_reset(void* lut_scales_) {{
54 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
55 | *lut_scales = 0.0;
56 | }}
57 |
58 | #ifdef __ARM_NEON
59 | inline void Transpose_8_8(
60 | int16x8_t *v0,
61 | int16x8_t *v1,
62 | int16x8_t *v2,
63 | int16x8_t *v3,
64 | int16x8_t *v4,
65 | int16x8_t *v5,
66 | int16x8_t *v6,
67 | int16x8_t *v7)
68 | {{
69 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);
70 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);
71 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);
72 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);
73 |
74 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
75 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
76 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
77 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
78 |
79 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
80 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
81 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
82 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
83 |
84 | *v0 = q_fin_0.val[0];
85 | *v1 = q_fin_0.val[1];
86 | *v2 = q_fin_1.val[0];
87 | *v3 = q_fin_1.val[1];
88 | *v4 = q_fin_2.val[0];
89 | *v5 = q_fin_2.val[1];
90 | *v6 = q_fin_3.val[0];
91 | *v7 = q_fin_3.val[1];
92 | }}
93 | #endif
94 |
95 | template <int act_k>
96 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{
97 | #ifdef __ARM_NEON
98 | int16x8_t vec_lut[16];
99 | float32_t scales = *lut_scales;
100 | uint8_t tbl_mask[16];
101 | tbl_mask[0] = 0;
102 | tbl_mask[1] = 2;
103 | tbl_mask[2] = 4;
104 | tbl_mask[3] = 6;
105 | tbl_mask[4] = 8;
106 | tbl_mask[5] = 10;
107 | tbl_mask[6] = 12;
108 | tbl_mask[7] = 14;
109 | tbl_mask[8] = 1;
110 | tbl_mask[9] = 3;
111 | tbl_mask[10] = 5;
112 | tbl_mask[11] = 7;
113 | tbl_mask[12] = 9;
114 | tbl_mask[13] = 11;
115 | tbl_mask[14] = 13;
116 | tbl_mask[15] = 15;
117 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
118 | #pragma unroll
119 | for (int k = 0; k < act_k / 16; ++k) {{
120 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
121 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
122 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
123 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
124 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
125 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
126 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
127 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
128 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
129 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
130 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
131 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
132 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
133 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
134 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
135 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
136 | vec_lut[0] = vdupq_n_s16(0);
137 | vec_lut[0] = vec_lut[0] - vec_bs_0;
138 | vec_lut[0] = vec_lut[0] - vec_bs_1;
139 | vec_lut[1] = vdupq_n_s16(0);
140 | vec_lut[1] = vec_lut[1] - vec_bs_0;
141 | vec_lut[2] = vdupq_n_s16(0);
142 | vec_lut[2] = vec_lut[2] - vec_bs_0;
143 | vec_lut[2] = vec_lut[2] + vec_bs_1;
144 | vec_lut[3] = vdupq_n_s16(0);
145 | vec_lut[3] = vec_lut[3] - vec_bs_1;
146 | vec_lut[4] = vdupq_n_s16(0);
147 | vec_lut[5] = vec_bs_1;
148 | vec_lut[6] = vec_bs_0;
149 | vec_lut[6] = vec_lut[6] - vec_bs_1;
150 | vec_lut[7] = vec_bs_0;
151 | vec_lut[8] = vec_bs_0;
152 | vec_lut[8] = vec_lut[8] + vec_bs_1;
153 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
154 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
155 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
156 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
157 | #pragma unroll
158 | for (int idx = 0; idx < 8; idx++) {{
159 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
160 | int8x8_t q0_low = vget_low_s8(q0_s);
161 | int8x8_t q0_high = vget_high_s8(q0_s);
162 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
163 | int8x8_t q1_low = vget_low_s8(q1_s);
164 | int8x8_t q1_high = vget_high_s8(q1_s);
165 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
166 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
167 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
168 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
169 | }}
170 | }}
171 | #endif
172 | }}
173 |
174 | static bool is_type_supported(enum ggml_type type) {{
175 | if (type == GGML_TYPE_Q4_0 ||
176 | type == GGML_TYPE_TL1) {{
177 | return true;
178 | }} else {{
179 | return false;
180 | }}
181 | }}
182 | #include <arm_neon.h>
183 |
184 | #define BM1536_4096 256
185 | #define BBK1536_4096 128
186 | inline void tbl_impl_1536_4096(int32_t* c, int8_t* lut, uint8_t* a) {
187 | #ifdef __ARM_NEON
188 | const int KK = BBK1536_4096 / 2;
189 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
190 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
191 | int8x16_t vec_lut[2 * KK];
192 | int16x8_t vec_c[4];
193 | #pragma unroll
194 | for (int k = 0; k < 2 * KK; k++) {
195 | vec_lut[k] = vld1q_s8(lut + k * 16);
196 | }
197 |
198 | #pragma unroll
199 | for (int i = 0; i < BM1536_4096; i += 32) {
200 | #pragma unroll
201 | for (int i=0; i<4; i++) {
202 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
203 | }
204 |
205 | #pragma unroll
206 | for (int k = 0; k < KK / 4; k++) {
207 |
208 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
209 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
210 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
211 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
212 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
213 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
214 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
215 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
216 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
217 | vec_c[0] += vec_v_left_0.val[0];
218 | vec_c[0] += vec_v_right_0.val[0];
219 | vec_c[1] += vec_v_left_0.val[1];
220 | vec_c[1] += vec_v_right_0.val[1];
221 |
222 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
223 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
224 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
225 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
226 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
227 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
228 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
229 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
230 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
231 | vec_c[0] += vec_v_left_1.val[0];
232 | vec_c[0] += vec_v_right_1.val[0];
233 | vec_c[1] += vec_v_left_1.val[1];
234 | vec_c[1] += vec_v_right_1.val[1];
235 |
236 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
237 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
238 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
239 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
240 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
241 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
242 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
243 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
244 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
245 | vec_c[2] += vec_v_left_2.val[0];
246 | vec_c[2] += vec_v_right_2.val[0];
247 | vec_c[3] += vec_v_left_2.val[1];
248 | vec_c[3] += vec_v_right_2.val[1];
249 |
250 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
251 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
252 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
253 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
254 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
255 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
256 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
257 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
258 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
259 | vec_c[2] += vec_v_left_3.val[0];
260 | vec_c[2] += vec_v_right_3.val[0];
261 | vec_c[3] += vec_v_left_3.val[1];
262 | vec_c[3] += vec_v_right_3.val[1];
263 |
264 | }
265 |
266 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
267 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
268 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
269 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
270 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
271 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
272 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
273 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
274 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
275 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
276 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
277 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
278 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
279 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
280 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
281 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
282 |
283 | }
284 | #endif
285 | }
286 |
287 | int32_t qgemm_lut_1536_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
288 | alignas(32) uint32_t CBits[BM1536_4096];
289 | memset(&(CBits[0]), 0, BM1536_4096 * sizeof(int32_t));
290 | #pragma unroll
291 | for (int32_t k_outer = 0; k_outer < 4096 / BBK1536_4096; ++k_outer) {
292 | tbl_impl_1536_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_4096 / 2 / 2 * BM1536_4096)])));
293 | }
294 | #pragma unroll
295 | for (int i = 0; i < BM1536_4096; i++) {
296 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
297 | }
298 | return 0;
299 | };
300 | #include <arm_neon.h>
301 |
302 | #define BM1536_1536 128
303 | #define BBK1536_1536 64
304 | inline void tbl_impl_1536_1536(int32_t* c, int8_t* lut, uint8_t* a) {
305 | #ifdef __ARM_NEON
306 | const int KK = BBK1536_1536 / 2;
307 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
308 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
309 | int8x16_t vec_lut[2 * KK];
310 | int16x8_t vec_c[8];
311 | #pragma unroll
312 | for (int k = 0; k < 2 * KK; k++) {
313 | vec_lut[k] = vld1q_s8(lut + k * 16);
314 | }
315 |
316 | #pragma unroll
317 | for (int i = 0; i < BM1536_1536; i += 64) {
318 | #pragma unroll
319 | for (int i=0; i<8; i++) {
320 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
321 | }
322 |
323 | #pragma unroll
324 | for (int k = 0; k < KK / 2; k++) {
325 |
326 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
327 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
328 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
329 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
330 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
331 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
332 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
333 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
334 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
335 | vec_c[0] += vec_v_left_0.val[0];
336 | vec_c[0] += vec_v_right_0.val[0];
337 | vec_c[1] += vec_v_left_0.val[1];
338 | vec_c[1] += vec_v_right_0.val[1];
339 |
340 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
341 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
342 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
343 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
344 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
345 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
346 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
347 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
348 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
349 | vec_c[2] += vec_v_left_1.val[0];
350 | vec_c[2] += vec_v_right_1.val[0];
351 | vec_c[3] += vec_v_left_1.val[1];
352 | vec_c[3] += vec_v_right_1.val[1];
353 |
354 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
355 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
356 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
357 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
358 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
359 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
360 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
361 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
362 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
363 | vec_c[4] += vec_v_left_2.val[0];
364 | vec_c[4] += vec_v_right_2.val[0];
365 | vec_c[5] += vec_v_left_2.val[1];
366 | vec_c[5] += vec_v_right_2.val[1];
367 |
368 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
369 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
370 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
371 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
372 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
373 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
374 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
375 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
376 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
377 | vec_c[6] += vec_v_left_3.val[0];
378 | vec_c[6] += vec_v_right_3.val[0];
379 | vec_c[7] += vec_v_left_3.val[1];
380 | vec_c[7] += vec_v_right_3.val[1];
381 |
382 | }
383 |
384 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
385 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
386 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
387 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
388 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
389 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
390 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
391 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
392 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
393 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
394 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
395 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
396 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
397 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
398 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
399 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
400 | int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
401 | int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
402 | vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
403 | vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
404 | int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
405 | int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
406 | vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
407 | vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
408 | int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
409 | int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
410 | vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
411 | vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
412 | int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
413 | int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
414 | vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
415 | vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
416 |
417 | }
418 | #endif
419 | }
420 |
421 | int32_t qgemm_lut_1536_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
422 | alignas(32) uint32_t CBits[BM1536_1536];
423 | memset(&(CBits[0]), 0, BM1536_1536 * sizeof(int32_t));
424 | #pragma unroll
425 | for (int32_t k_outer = 0; k_outer < 1536 / BBK1536_1536; ++k_outer) {
426 | tbl_impl_1536_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1536_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1536_1536 / 2 / 2 * BM1536_1536)])));
427 | }
428 | #pragma unroll
429 | for (int i = 0; i < BM1536_1536; i++) {
430 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
431 | }
432 | return 0;
433 | };
434 | #include <arm_neon.h>
435 |
436 | #define BM4096_1536 256
437 | #define BBK4096_1536 128
438 | inline void tbl_impl_4096_1536(int32_t* c, int8_t* lut, uint8_t* a) {
439 | #ifdef __ARM_NEON
440 | const int KK = BBK4096_1536 / 2;
441 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
442 | const int8x16_t vec_zero = vdupq_n_s16(0x0000);
443 | int8x16_t vec_lut[2 * KK];
444 | int16x8_t vec_c[4];
445 | #pragma unroll
446 | for (int k = 0; k < 2 * KK; k++) {
447 | vec_lut[k] = vld1q_s8(lut + k * 16);
448 | }
449 |
450 | #pragma unroll
451 | for (int i = 0; i < BM4096_1536; i += 32) {
452 | #pragma unroll
453 | for (int i=0; i<4; i++) {
454 | vec_c[i] = vandq_s16(vec_c[i], vec_zero);
455 | }
456 |
457 | #pragma unroll
458 | for (int k = 0; k < KK / 4; k++) {
459 |
460 | uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
461 | uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
462 | uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
463 | int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
464 | int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
465 | int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
466 | int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
467 | int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
468 | int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
469 | vec_c[0] += vec_v_left_0.val[0];
470 | vec_c[0] += vec_v_right_0.val[0];
471 | vec_c[1] += vec_v_left_0.val[1];
472 | vec_c[1] += vec_v_right_0.val[1];
473 |
474 | uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
475 | uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
476 | uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
477 | int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
478 | int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
479 | int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
480 | int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
481 | int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
482 | int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
483 | vec_c[0] += vec_v_left_1.val[0];
484 | vec_c[0] += vec_v_right_1.val[0];
485 | vec_c[1] += vec_v_left_1.val[1];
486 | vec_c[1] += vec_v_right_1.val[1];
487 |
488 | uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
489 | uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
490 | uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
491 | int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
492 | int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
493 | int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
494 | int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
495 | int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
496 | int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
497 | vec_c[2] += vec_v_left_2.val[0];
498 | vec_c[2] += vec_v_right_2.val[0];
499 | vec_c[3] += vec_v_left_2.val[1];
500 | vec_c[3] += vec_v_right_2.val[1];
501 |
502 | uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
503 | uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
504 | uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
505 | int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
506 | int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
507 | int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
508 | int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
509 | int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
510 | int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
511 | vec_c[2] += vec_v_left_3.val[0];
512 | vec_c[2] += vec_v_right_3.val[0];
513 | vec_c[3] += vec_v_left_3.val[1];
514 | vec_c[3] += vec_v_right_3.val[1];
515 |
516 | }
517 |
518 | int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
519 | int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
520 | vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
521 | vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
522 | int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
523 | int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
524 | vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
525 | vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
526 | int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
527 | int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
528 | vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
529 | vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
530 | int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
531 | int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
532 | vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
533 | vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
534 |
535 | }
536 | #endif
537 | }
538 |
539 | int32_t qgemm_lut_4096_1536(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
540 | alignas(32) uint32_t CBits[BM4096_1536];
541 | memset(&(CBits[0]), 0, BM4096_1536 * sizeof(int32_t));
542 | #pragma unroll
543 | for (int32_t k_outer = 0; k_outer < 1536 / BBK4096_1536; ++k_outer) {
544 | tbl_impl_4096_1536((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_1536 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_1536 / 2 / 2 * BM4096_1536)])));
545 | }
546 | #pragma unroll
547 | for (int i = 0; i < BM4096_1536; i++) {
548 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
549 | }
550 | return 0;
551 | };
552 |
553 | template<int K>
554 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {
555 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
556 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
557 |
558 | lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
559 | }
560 | void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
561 | if (m == 1536 && k == 4096) {
562 | preprocessor_k<4096>(B, LUT_Scales, QLUT);
563 | }
564 | else if (m == 1536 && k == 1536) {
565 | preprocessor_k<1536>(B, LUT_Scales, QLUT);
566 | }
567 | else if (m == 4096 && k == 1536) {
568 | preprocessor_k<1536>(B, LUT_Scales, QLUT);
569 | }
570 | }
571 | void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
572 | if (m == 1536 && k == 4096) {
573 | qgemm_lut_1536_4096(A, LUT, Scales, LUT_Scales, C);
574 | }
575 | else if (m == 1536 && k == 1536) {
576 | qgemm_lut_1536_1536(A, LUT, Scales, LUT_Scales, C);
577 | }
578 | else if (m == 4096 && k == 1536) {
579 | qgemm_lut_4096_1536(A, LUT, Scales, LUT_Scales, C);
580 | }
581 | }
582 |
583 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
584 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
585 | return;
586 | }
587 |
588 | int k = tensor->ne[0];
589 | int m = tensor->ne[1];
590 | const int lut_scales_size = 1;
591 | const int scales_size = 1;
592 | int bk = 0;
593 | int bm = 0;
594 |
595 | if (m == 1536 && k == 4096) {
596 | bm = BM1536_4096;
597 | bk = BBK1536_4096;
598 | }
599 | else if (m == 1536 && k == 1536) {
600 | bm = BM1536_1536;
601 | bk = BBK1536_1536;
602 | }
603 | else if (m == 4096 && k == 1536) {
604 | bm = BM4096_1536;
605 | bk = BBK4096_1536;
606 | }
607 |
608 | const int n_tile_num = m / bm;
609 | const int BK = bk;
610 | uint8_t * qweights;
611 | bitnet_float_type * scales;
612 |
613 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
614 | qweights = (uint8_t *) tensor->data;
615 | float * i2_scales = (float * )(qweights + k * m / 4);
616 | scales[0] = (bitnet_float_type) i2_scales[0];
617 |
618 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
619 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
620 | /* .lut_scales_size = */ lut_scales_size,
621 | /* .scales_size = */ scales_size,
622 | /* .n_tile_num = */ n_tile_num,
623 | /* .qweights = */ qweights,
624 | /* .scales = */ scales
625 | };
626 | }
627 | #endif
--------------------------------------------------------------------------------
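
Note on the generated header above: the TL1 path replaces multiply-adds with table lookups. Each byte of the packed weight tile A carries two 4-bit indices; every index selects a precomputed partial sum of activations from a 16-entry LUT (via vqtbl1q_s8), the results are accumulated in int16 lanes, widened to int32, and finally rescaled as C[i] = CBits[i] / LUT_Scales[0] * Scales[0]. A minimal scalar sketch of that flow, in Python with invented toy names and none of the NEON tiling:

    # Scalar sketch of the lookup-accumulate-rescale flow of the qgemm_lut_*
    # kernels above. Names and layout are simplified for illustration; the
    # real kernels operate on interleaved NEON tiles.
    def lut_gemm_row(packed_a, luts, scale, lut_scale):
        # packed_a: bytes, each byte = two 4-bit LUT indices (high and low nibble)
        # luts:     list of 16-entry integer tables, two tables consumed per byte
        acc = 0
        for pos, byte in enumerate(packed_a):
            hi, lo = byte >> 4, byte & 0x0F        # unpack the two indices
            acc += luts[2 * pos][hi]               # table lookup instead of multiply-add
            acc += luts[2 * pos + 1][lo]
        return acc / lut_scale * scale             # mirrors the CBits -> C rescale
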
/preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 1536
3 | k = 4096
4 | bm = 256
5 | bk = 128
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 1536
10 | k = 1536
11 | bm = 128
12 | bk = 64
13 | bmm = 64
14 |
15 | [Kernels_2]
16 | m = 4096
17 | k = 1536
18 | bm = 256
19 | bk = 128
20 | bmm = 32
21 |
22 |
--------------------------------------------------------------------------------
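
Each [Kernels_N] block above records the tile sizes that the TL1 kernels for one weight shape were generated with: m and k are the GEMM dimensions, bm and bk match the BM/BBK macros in the generated header, and bmm appears to be the inner micro-tile that setup_env.py passes as --bm. The same numbers show up in setup_env.py's codegen_tl1.py call for bitnet_b1_58-large (--BM 256,128,256 --BK 128,64,128 --bm 32,64,32). A small sketch of reading the file back (illustrative only):

    # Read the preset tile configuration; section/key names follow the .ini above.
    from configparser import ConfigParser

    cfg = ConfigParser()
    cfg.read("preset_kernels/bitnet_b1_58-large/kernel_config_tl1.ini")
    for section in cfg.sections():
        m, k = cfg.getint(section, "m"), cfg.getint(section, "k")
        BM, BK, bm = cfg.getint(section, "bm"), cfg.getint(section, "bk"), cfg.getint(section, "bmm")
        print(f"{section}: {m}x{k} GEMM -> BM={BM}, BK={BK}, bm={bm}")
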
/preset_kernels/bitnet_b1_58-large/kernel_config_tl2.ini:
--------------------------------------------------------------------------------
1 | [Kernels_0]
2 | m = 1536
3 | k = 4096
4 | bm = 256
5 | bk = 96
6 | bmm = 32
7 |
8 | [Kernels_1]
9 | m = 1536
10 | k = 1536
11 | bm = 128
12 | bk = 192
13 | bmm = 32
14 |
15 | [Kernels_2]
16 | m = 4096
17 | k = 1536
18 | bm = 256
19 | bk = 96
20 | bmm = 64
21 |
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # These requirements include all dependencies for all top-level python scripts
2 | # for llama.cpp. Avoid adding packages here directly.
3 | #
4 | # Package versions must stay compatible across all top-level python scripts.
5 | #
6 |
7 | -r 3rdparty/llama.cpp/requirements/requirements-convert_legacy_llama.txt
8 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt
9 | -r 3rdparty/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt
10 | -r 3rdparty/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt
11 | -r 3rdparty/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import signal
4 | import platform
5 | import argparse
6 | import subprocess
7 |
8 | def run_command(command, shell=False):
9 | """Run a system command and ensure it succeeds."""
10 | try:
11 | subprocess.run(command, shell=shell, check=True)
12 | except subprocess.CalledProcessError as e:
13 | print(f"Error occurred while running command: {e}")
14 | sys.exit(1)
15 |
16 | def run_inference():
17 | build_dir = "build"
18 | if platform.system() == "Windows":
19 | main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe")
20 | if not os.path.exists(main_path):
21 | main_path = os.path.join(build_dir, "bin", "llama-cli")
22 | else:
23 | main_path = os.path.join(build_dir, "bin", "llama-cli")
24 | command = [
25 | f'{main_path}',
26 | '-m', args.model,
27 | '-n', str(args.n_predict),
28 | '-t', str(args.threads),
29 | '-p', args.prompt,
30 | '-ngl', '0',
31 | '-c', str(args.ctx_size),
32 | '--temp', str(args.temperature),
33 | "-b", "1",
34 | ]
35 | if args.conversation:
36 | command.append("-cnv")
37 | run_command(command)
38 |
39 | def signal_handler(sig, frame):
40 | print("Ctrl+C pressed, exiting...")
41 | sys.exit(0)
42 |
43 | if __name__ == "__main__":
44 | signal.signal(signal.SIGINT, signal_handler)
45 | # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington."
46 | parser = argparse.ArgumentParser(description='Run inference')
47 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
48 | parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128)
49 | parser.add_argument("-p", "--prompt", type=str, help="Prompt to generate text from", required=True)
50 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
51 | parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048)
52 | parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8)
53 | parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode or not (for instruct models.)")
54 |
55 | args = parser.parse_args()
56 | run_inference()
--------------------------------------------------------------------------------
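
run_inference.py is a thin wrapper: it locates the llama-cli binary under build/bin (checking the Release/ subdirectory first on Windows), pins the batch size to 1, keeps everything on the CPU with -ngl 0, and forwards model, prompt, thread count, context size and temperature. A typical invocation mirroring the argparse defaults above (the prompt is the only required argument):

    # Programmatic equivalent of the documented usage line; values mirror the defaults.
    import subprocess, sys

    subprocess.run([
        sys.executable, "run_inference.py",
        "-m", "models/bitnet_b1_58-3B/ggml-model-i2_s.gguf",
        "-p", "Microsoft Corporation is an American multinational corporation",
        "-n", "128", "-t", "2", "-c", "2048", "--temperature", "0.8",
    ], check=True)
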
/setup_env.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import signal
3 | import sys
4 | import os
5 | import platform
6 | import argparse
7 | import logging
8 | import shutil
9 | from pathlib import Path
10 |
11 | logger = logging.getLogger("setup_env")
12 |
13 | SUPPORTED_HF_MODELS = {
14 | "1bitLLM/bitnet_b1_58-large": {
15 | "model_name": "bitnet_b1_58-large",
16 | },
17 | "1bitLLM/bitnet_b1_58-3B": {
18 | "model_name": "bitnet_b1_58-3B",
19 | },
20 | "HF1BitLLM/Llama3-8B-1.58-100B-tokens": {
21 | "model_name": "Llama3-8B-1.58-100B-tokens",
22 | },
23 | "tiiuae/Falcon3-7B-Instruct-1.58bit": {
24 | "model_name": "Falcon3-7B-Instruct-1.58bit",
25 | },
26 | "tiiuae/Falcon3-7B-1.58bit": {
27 | "model_name": "Falcon3-7B-1.58bit",
28 | },
29 | "tiiuae/Falcon3-10B-Instruct-1.58bit": {
30 | "model_name": "Falcon3-10B-Instruct-1.58bit",
31 | },
32 | "tiiuae/Falcon3-10B-1.58bit": {
33 | "model_name": "Falcon3-10B-1.58bit",
34 | },
35 | "tiiuae/Falcon3-3B-Instruct-1.58bit": {
36 | "model_name": "Falcon3-3B-Instruct-1.58bit",
37 | },
38 | "tiiuae/Falcon3-3B-1.58bit": {
39 | "model_name": "Falcon3-3B-1.58bit",
40 | },
41 | "tiiuae/Falcon3-1B-Instruct-1.58bit": {
42 | "model_name": "Falcon3-1B-Instruct-1.58bit",
43 | },
44 | "microsoft/BitNet-b1.58-2B-4T": {
45 | "model_name": "BitNet-b1.58-2B-4T",
46 | },
47 | }
48 |
49 | SUPPORTED_QUANT_TYPES = {
50 | "arm64": ["i2_s", "tl1"],
51 | "x86_64": ["i2_s", "tl2"]
52 | }
53 |
54 | COMPILER_EXTRA_ARGS = {
55 | "arm64": ["-DBITNET_ARM_TL1=ON"],
56 | "x86_64": ["-DBITNET_X86_TL2=ON"]
57 | }
58 |
59 | OS_EXTRA_ARGS = {
60 | "Windows":["-T", "ClangCL"],
61 | }
62 |
63 | ARCH_ALIAS = {
64 | "AMD64": "x86_64",
65 | "x86": "x86_64",
66 | "x86_64": "x86_64",
67 | "aarch64": "arm64",
68 | "arm64": "arm64",
69 | "ARM64": "arm64",
70 | }
71 |
72 | def system_info():
73 | return platform.system(), ARCH_ALIAS[platform.machine()]
74 |
75 | def get_model_name():
76 | if args.hf_repo:
77 | return SUPPORTED_HF_MODELS[args.hf_repo]["model_name"]
78 | return os.path.basename(os.path.normpath(args.model_dir))
79 |
80 | def run_command(command, shell=False, log_step=None):
81 | """Run a system command and ensure it succeeds."""
82 | if log_step:
83 | log_file = os.path.join(args.log_dir, log_step + ".log")
84 | with open(log_file, "w") as f:
85 | try:
86 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
87 | except subprocess.CalledProcessError as e:
88 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
89 | sys.exit(1)
90 | else:
91 | try:
92 | subprocess.run(command, shell=shell, check=True)
93 | except subprocess.CalledProcessError as e:
94 | logging.error(f"Error occurred while running command: {e}")
95 | sys.exit(1)
96 |
97 | def prepare_model():
98 | _, arch = system_info()
99 | hf_url = args.hf_repo
100 | model_dir = args.model_dir
101 | quant_type = args.quant_type
102 | quant_embd = args.quant_embd
103 | if hf_url is not None:
104 | # download the model
105 | model_dir = os.path.join(model_dir, SUPPORTED_HF_MODELS[hf_url]["model_name"])
106 | Path(model_dir).mkdir(parents=True, exist_ok=True)
107 | logging.info(f"Downloading model {hf_url} from HuggingFace to {model_dir}...")
108 | run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model")
109 | elif not os.path.exists(model_dir):
110 | logging.error(f"Model directory {model_dir} does not exist.")
111 | sys.exit(1)
112 | else:
113 | logging.info(f"Loading model from directory {model_dir}.")
114 | gguf_path = os.path.join(model_dir, "ggml-model-" + quant_type + ".gguf")
115 | if not os.path.exists(gguf_path) or os.path.getsize(gguf_path) == 0:
116 | logging.info(f"Converting HF model to GGUF format...")
117 | if quant_type.startswith("tl"):
118 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", quant_type, "--quant-embd"], log_step="convert_to_tl")
119 | else: # i2s
120 | # convert to f32
121 | run_command([sys.executable, "utils/convert-hf-to-gguf-bitnet.py", model_dir, "--outtype", "f32"], log_step="convert_to_f32_gguf")
122 | f32_model = os.path.join(model_dir, "ggml-model-f32.gguf")
123 | i2s_model = os.path.join(model_dir, "ggml-model-i2_s.gguf")
124 | # quantize to i2s
125 | if platform.system() != "Windows":
126 | if quant_embd:
127 | run_command(["./build/bin/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
128 | else:
129 | run_command(["./build/bin/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
130 | else:
131 | if quant_embd:
132 | run_command(["./build/bin/Release/llama-quantize", "--token-embedding-type", "f16", f32_model, i2s_model, "I2_S", "1", "1"], log_step="quantize_to_i2s")
133 | else:
134 | run_command(["./build/bin/Release/llama-quantize", f32_model, i2s_model, "I2_S", "1"], log_step="quantize_to_i2s")
135 |
136 | logging.info(f"GGUF model saved at {gguf_path}")
137 | else:
138 | logging.info(f"GGUF model already exists at {gguf_path}")
139 |
140 | def setup_gguf():
141 | # Install the pip package
142 | run_command([sys.executable, "-m", "pip", "install", "3rdparty/llama.cpp/gguf-py"], log_step="install_gguf")
143 |
144 | def gen_code():
145 | _, arch = system_info()
146 |
147 | llama3_f3_models = set([model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith("Falcon3") or model['model_name'].startswith("Llama")])
148 |
149 | if arch == "arm64":
150 | if args.use_pretuned:
151 | pretuned_kernels = os.path.join("preset_kernels", get_model_name())
152 | if not os.path.exists(pretuned_kernels):
153 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}")
154 | sys.exit(1)
155 | if args.quant_type == "tl1":
156 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl1.h"), "include/bitnet-lut-kernels.h")
157 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl1.ini"), "include/kernel_config.ini")
158 | elif args.quant_type == "tl2":
159 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
160 | shutil.copyfile(os.path.join(pretuned_kernels, "kernel_config_tl2.ini"), "include/kernel_config.ini")
161 | if get_model_name() == "bitnet_b1_58-large":
162 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "128,64,128", "--bm", "32,64,32"], log_step="codegen")
163 | elif get_model_name() in llama3_f3_models:
164 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "128,64,128,64", "--bm", "32,64,32,64"], log_step="codegen")
165 | elif get_model_name() == "bitnet_b1_58-3B":
166 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
167 | elif get_model_name() == "BitNet-b1.58-2B-4T":
168 | run_command([sys.executable, "utils/codegen_tl1.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "64,128,64", "--bm", "32,64,32"], log_step="codegen")
169 | else:
170 | raise NotImplementedError()
171 | else:
172 | if args.use_pretuned:
173 | # cp preset_kernels/<model_name>/bitnet-lut-kernels-tl2.h to include/bitnet-lut-kernels.h
174 | pretuned_kernels = os.path.join("preset_kernels", get_model_name())
175 | if not os.path.exists(pretuned_kernels):
176 | logging.error(f"Pretuned kernels not found for model {args.hf_repo}")
177 | sys.exit(1)
178 | shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h")
179 | if get_model_name() == "bitnet_b1_58-large":
180 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen")
181 | elif get_model_name() in llama3_f3_models:
182 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen")
183 | elif get_model_name() == "bitnet_b1_58-3B":
184 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
185 | elif get_model_name() == "BitNet-b1.58-2B-4T":
186 | run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-3B", "--BM", "160,320,320", "--BK", "96,96,96", "--bm", "32,32,32"], log_step="codegen")
187 | else:
188 | raise NotImplementedError()
189 |
190 |
191 | def compile():
192 | # Check if cmake is installed
193 | cmake_exists = subprocess.run(["cmake", "--version"], capture_output=True)
194 | if cmake_exists.returncode != 0:
195 | logging.error("Cmake is not available. Please install CMake and try again.")
196 | sys.exit(1)
197 | _, arch = system_info()
198 | if arch not in COMPILER_EXTRA_ARGS.keys():
199 | logging.error(f"Arch {arch} is not supported yet")
200 | exit(0)
201 | logging.info("Compiling the code using CMake.")
202 | run_command(["cmake", "-B", "build", *COMPILER_EXTRA_ARGS[arch], *OS_EXTRA_ARGS.get(platform.system(), [])], log_step="generate_build_files")
203 | # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
204 | run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
205 |
206 | def main():
207 | setup_gguf()
208 | gen_code()
209 | compile()
210 | prepare_model()
211 |
212 | def parse_args():
213 | _, arch = system_info()
214 | parser = argparse.ArgumentParser(description='Setup the environment for running the inference')
215 | parser.add_argument("--hf-repo", "-hr", type=str, help="Model used for inference", choices=SUPPORTED_HF_MODELS.keys())
216 | parser.add_argument("--model-dir", "-md", type=str, help="Directory to save/load the model", default="models")
217 | parser.add_argument("--log-dir", "-ld", type=str, help="Directory to save the logging info", default="logs")
218 | parser.add_argument("--quant-type", "-q", type=str, help="Quantization type", choices=SUPPORTED_QUANT_TYPES[arch], default="i2_s")
219 | parser.add_argument("--quant-embd", action="store_true", help="Quantize the embeddings to f16")
220 | parser.add_argument("--use-pretuned", "-p", action="store_true", help="Use the pretuned kernel parameters")
221 | return parser.parse_args()
222 |
223 | def signal_handler(sig, frame):
224 | logging.info("Ctrl+C pressed, exiting...")
225 | sys.exit(0)
226 |
227 | if __name__ == "__main__":
228 | signal.signal(signal.SIGINT, signal_handler)
229 | args = parse_args()
230 | Path(args.log_dir).mkdir(parents=True, exist_ok=True)
231 | logging.basicConfig(level=logging.INFO)
232 | main()
--------------------------------------------------------------------------------
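
setup_env.py chains four steps: install the bundled gguf-py package, generate (or copy a pretuned) bitnet-lut-kernels header for the chosen model, build llama-cli through CMake with BITNET_ARM_TL1 or BITNET_X86_TL2 enabled, and finally download/convert/quantize the model to GGUF. A typical call, using one of the SUPPORTED_HF_MODELS entries above (any other listed repo, with a quant type valid for the host architecture, works the same way):

    # End-to-end environment setup for one of the supported checkpoints.
    import subprocess, sys

    subprocess.run([
        sys.executable, "setup_env.py",
        "--hf-repo", "microsoft/BitNet-b1.58-2B-4T",
        "--quant-type", "i2_s",
    ], check=True)
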
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(GGML_HEADERS_BITNET ../include/ggml-bitnet.h)
2 | set(GGML_SOURCES_BITNET ggml-bitnet-mad.cpp)
3 | set(GGML_SOURCES_BITNET ggml-bitnet-lut.cpp)
4 |
5 | include_directories(3rdparty/llama.cpp/ggml/include)
6 |
7 | if (NOT (CMAKE_C_COMPILER_ID MATCHES "Clang" OR CMAKE_C_COMPILER_ID STREQUAL "GNU") OR
8 | NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
9 | message(FATAL_ERROR "Clang or GCC is required for Bitnet.cpp compilation")
10 | endif()
11 |
--------------------------------------------------------------------------------
/src/ggml-bitnet-lut.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdlib>
2 | #include <cstring>
3 |
4 | #include <algorithm>  // std::max
5 | #include <cmath>
6 | #include <cstddef>    // size_t
7 |
8 | #include "ggml-bitnet.h"
9 | #include "ggml-quants.h"
10 | #include "bitnet-lut-kernels.h"
11 |
12 | #if defined(GGML_BITNET_ARM_TL1)
13 |
14 | void ggml_bitnet_init(void) {
15 | // LOG(INFO) << "ggml_bitnet_init";
16 |
17 | if (initialized) {
18 | return;
19 | }
20 | initialized = true;
21 |
22 | // if (wrapper == nullptr) {
23 | // wrapper = new BITNET::BITNETGeMMWrapper();
24 | // }
25 | if (bitnet_tensor_extras == nullptr) {
26 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
27 | }
28 | bitnet_tensor_extras_index = 0;
29 | }
30 |
31 | void ggml_bitnet_free(void) {
32 | // LOG(INFO) << "ggml_bitnet_free";
33 |
34 | if (!initialized) {
35 | return;
36 | }
37 | initialized = false;
38 |
39 | // delete wrapper;
40 | // wrapper = nullptr;
41 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
42 | // aligned_free(bitnet_tensor_extras[i].qweights);
43 | // aligned_free(bitnet_tensor_extras[i].scales);
44 | }
45 | delete[] bitnet_tensor_extras;
46 | bitnet_tensor_extras = nullptr;
47 | }
48 |
49 | static bool do_permutate(enum ggml_type type) {
50 | if (type == GGML_TYPE_TL1) {
51 | // Add additional args to decide if permuted I2 or naive I2
52 | return false;
53 | } else {
54 | return true;
55 | }
56 | }
57 |
58 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
59 | if ((is_type_supported(src0->type)) &&
60 | src1->type == GGML_TYPE_F32 &&
61 | dst->type == GGML_TYPE_F32 &&
62 | src0->backend == GGML_BACKEND_TYPE_CPU) {
63 | if (src1->ne[1] <= 1) {
64 | return true;
65 | }
66 | }
67 | return false;
68 | }
69 |
70 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
71 | const size_t ne01 = src0->ne[1];
72 | const size_t ne10 = src1->ne[0];
73 | const size_t ne11 = src1->ne[1];
74 | const int bits = ggml_bitnet_get_type_bits(src0->type);
75 |
76 | size_t wsize = ne10 * ne11 * 15 * sizeof(int8_t) + 1 * ne11 * 2 * sizeof(bitnet_float_type);
77 | if (sizeof(bitnet_float_type) == 2) {
78 | // Need fp32 to fp16 conversion
79 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
80 | }
81 | wsize = ((wsize - 1) / 64 + 1) * 64;
82 | return wsize;
83 | }
84 |
85 | int ggml_bitnet_get_type_bits(enum ggml_type type) {
86 | switch (type) {
87 | case GGML_TYPE_TL1:
88 | return 2;
89 | case GGML_TYPE_Q4_0:
90 | return 4;
91 | default:
92 | return 0;
93 | }
94 | }
95 |
96 | #endif
97 | #if defined(GGML_BITNET_X86_TL2)
98 | void ggml_bitnet_init(void) {
99 | // LOG(INFO) << "ggml_bitnet_init";
100 |
101 | if (initialized) {
102 | return;
103 | }
104 | initialized = true;
105 |
106 | // if (wrapper == nullptr) {
107 | // wrapper = new BITNET::BITNETGeMMWrapper();
108 | // }
109 | if (bitnet_tensor_extras == nullptr) {
110 | bitnet_tensor_extras = new bitnet_tensor_extra[GGML_BITNET_MAX_NODES];
111 | }
112 | bitnet_tensor_extras_index = 0;
113 | }
114 |
115 | void ggml_bitnet_free(void) {
116 | // LOG(INFO) << "ggml_bitnet_free";
117 |
118 | if (!initialized) {
119 | return;
120 | }
121 | initialized = false;
122 |
123 | // delete wrapper;
124 | // wrapper = nullptr;
125 | for (size_t i = 0; i < bitnet_tensor_extras_index; i++) {
126 | // aligned_free(bitnet_tensor_extras[i].qweights);
127 | // aligned_free(bitnet_tensor_extras[i].scales);
128 | }
129 | delete[] bitnet_tensor_extras;
130 | bitnet_tensor_extras = nullptr;
131 | }
132 |
133 | bool ggml_bitnet_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
134 | if ((is_type_supported(src0->type)) &&
135 | src1->type == GGML_TYPE_F32 &&
136 | dst->type == GGML_TYPE_F32 &&
137 | src0->backend == GGML_BACKEND_TYPE_CPU) {
138 | return true;
139 | }
140 | return false;
141 | }
142 |
143 | size_t ggml_bitnet_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
144 | const size_t ne01 = src0->ne[1];
145 | const size_t ne10 = src1->ne[0];
146 | const size_t ne11 = src1->ne[1];
147 |
148 | size_t wsize = ne10 * ne11 * 11 * sizeof(int8_t) + 2 * ne11 * 2 * sizeof(bitnet_float_type);
149 | if (sizeof(bitnet_float_type) == 2) {
150 | // Need fp32 to fp16 conversion
151 | wsize += std::max(ne10, ne01) * ne11 * sizeof(bitnet_float_type);
152 | }
153 | wsize = ((wsize - 1) / 64 + 1) * 64;
154 | return wsize;
155 | }
156 |
157 | int ggml_bitnet_get_type_bits(enum ggml_type type) {
158 | switch (type) {
159 | case GGML_TYPE_TL2:
160 | return 2;
161 | case GGML_TYPE_Q4_0:
162 | return 4;
163 | default:
164 | return 0;
165 | }
166 | }
167 | #endif
--------------------------------------------------------------------------------
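
The workspace sizing in ggml_bitnet_mul_mat_get_wsize above reserves room for the quantized LUT (15 bytes per activation element on the TL1 path, 11 on TL2), a few per-column scale slots, an optional fp32-to-fp16 staging buffer when bitnet_float_type is half precision, and rounds the total up to a 64-byte multiple. A worked Python transcription of the TL1 branch, assuming an fp32 build (sizeof(bitnet_float_type) == 4):

    # Transcription of the TL1 wsize formula for illustration; ne01/ne10/ne11
    # are the same tensor dimensions used in the C++ code above.
    def tl1_wsize(ne01, ne10, ne11, float_size=4):
        wsize = ne10 * ne11 * 15 * 1 + 1 * ne11 * 2 * float_size
        if float_size == 2:                       # fp16 build: add the conversion buffer
            wsize += max(ne10, ne01) * ne11 * float_size
        return ((wsize - 1) // 64 + 1) * 64       # round up to a 64-byte boundary

    print(tl1_wsize(ne01=1536, ne10=4096, ne11=1))  # single-token decode, 1536x4096 weight
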
/src/ggml-bitnet-mad.cpp:
--------------------------------------------------------------------------------
1 | #include <cstring>   // memset
2 | #include <cstdlib>   // malloc / free
3 |
4 | #include "ggml-bitnet.h"
5 | #include "ggml-quants.h"
6 | #include <cmath>     // fmax, fabs
7 | #include <cstdint>   // int64_t, uint8_t
8 |
9 | #define QK_I2_S 128
10 | #define QK_I2 128
11 |
12 | #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
13 | #include <immintrin.h>
14 | // horizontally add 8 int32_t
15 | static inline int hsum_i32_8(const __m256i a) {
16 | const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
17 | const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
18 | const __m128i sum64 = _mm_add_epi32(hi64, sum128);
19 | const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
20 | return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
21 | }
22 | #elif defined(__loongarch_asx)
23 | // horizontally add 8 int32_t
24 | static inline int hsum_i32_8(const __m256i a) {
25 |
26 | __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
27 | __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
28 |
29 | __m128i tmp1_128 = lasx_extracti128_lo(tmp1);
30 | __m128i tmp2_128 = lasx_extracti128_lo(tmp2);
31 |
32 | __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
33 |
34 | __m128i ev = __lsx_vpickev_w(sum128, sum128);
35 | __m128i od = __lsx_vpickod_w(sum128, sum128);
36 | __m128i sum64 = __lsx_vadd_w(ev, od);
37 |
38 | int sum64_1, sum64_2;
39 | sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
40 | sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
41 |
42 | return sum64_1 + sum64_2;
43 | }
44 | #endif
45 |
46 | size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
47 | // 2 bits per weight
48 |
49 | size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
50 |
51 | int n = nrow * n_per_row;
52 |
53 | // f32 -> q8
54 | double max = 0;
55 | for (int i = 0; i < n; ++i) {
56 | max = fmax(max, (double)fabs((double)src[i]));
57 | }
58 | double i2_scale = max;
59 |
60 | uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t));
61 | for (int i = 0; i < n; i++) {
62 | if (fabs((double)(src[i])) < 1e-6) {
63 | q8[i] = 1;
64 | continue;
65 | }
66 | q8[i] = (double)src[i] * i2_scale > 0 ? 2 : 0;
67 | }
68 |
69 | memset(dst, 0, n * sizeof(uint8_t) / 4);
70 |
71 | // q8 -> 0, 1, 2
72 | // | | |
73 | // -1, 0, 1
74 |
75 | uint8_t* i2_weight = (uint8_t*)dst;
76 | for (int i = 0; i < n / QK_I2; i++) {
77 | for (int j = 0; j < QK_I2; j++) {
78 | int group_idx = j / 32;
79 | int group_pos = j % 32;
80 | uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx));
81 | i2_weight[i * 32 + group_pos] |= temp;
82 | }
83 | }
84 |
85 | float* scale_ptr = (float*)((char*)i2_weight + n / 4);
86 | scale_ptr[0] = i2_scale;
87 |
88 | free(q8);
89 |
90 | // 32B for alignment
91 | return nrow * row_size / 4 + 32;
92 | }
93 |
94 | void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
95 | const uint8_t * x = (uint8_t *)vx;
96 | const int8_t * y = (int8_t *)vy;
97 |
98 | const int nb = n / QK_I2_S;
99 | const int group32_num = nb / 32;
100 | const int la_num = nb % 32;
101 | const int groupla_num = nb % 32 != 0 ? 1 : 0;
102 |
103 | #if defined(__AVX2__)
104 |
105 | __m256i mask = _mm256_set1_epi8(0x03);
106 | __m256i accu = _mm256_setzero_si256();
107 |
108 | for (int i=0; i < group32_num; i++){
109 | __m256i accu32 = _mm256_setzero_si256();
110 | for (int j=0; j < 32; j++) {
111 | // 128 index
112 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32));
113 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
114 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
115 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
116 |
117 | // each 32 index
118 | xq8_3 = _mm256_and_si256(xq8_3, mask);
119 | xq8_2 = _mm256_and_si256(xq8_2, mask);
120 | xq8_1 = _mm256_and_si256(xq8_1, mask);
121 | xq8_0 = _mm256_and_si256(xq8_0, mask);
122 |
123 | // each 32 index
124 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0));
125 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32));
126 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64));
127 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96));
128 |
129 | // accumulate over the 128 indices:
130 | // the sum is split into 32 accumulation blocks;
131 | // within each 128-index group, every block accumulates 4 indices,
132 | // each index contributes at most 256,
133 | // so each block gains at most 4 * 256 per group,
134 | // and at most 127 * 256 over the accumulation,
135 | // hence every 32 groups (128 indices per group) must be widened to int32
136 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
137 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
138 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
139 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
140 |
141 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1));
142 | accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3));
143 | }
144 | accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu);
145 | }
146 |
147 | for (int i = 0; i < groupla_num; i++){
148 | __m256i accula = _mm256_setzero_si256();
149 | for (int j = 0; j < la_num; j++) {
150 | // 128 index
151 | __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32));
152 | __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2);
153 | __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4);
154 | __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6);
155 |
156 | // each 32 index
157 | xq8_3 = _mm256_and_si256(xq8_3, mask);
158 | xq8_2 = _mm256_and_si256(xq8_2, mask);
159 | xq8_1 = _mm256_and_si256(xq8_1, mask);
160 | xq8_0 = _mm256_and_si256(xq8_0, mask);
161 |
162 | // each 32 index
163 | __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0));
164 | __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32));
165 | __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64));
166 | __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96));
167 |
168 | // accumulate over the 128 indices:
169 | // the sum is split into 32 accumulation blocks;
170 | // within each 128-index group, every block accumulates 4 indices,
171 | // each index contributes at most 256,
172 | // so each block gains at most 4 * 256 per group,
173 | // and at most 127 * 256 over the accumulation,
174 | // hence every 32 groups (128 indices per group) must be widened to int32
175 | xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0);
176 | xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1);
177 | xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2);
178 | xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3);
179 |
180 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1));
181 | accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3));
182 | }
183 | accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1)));
184 | }
185 | int sumi = hsum_i32_8(accu);
186 | *s = (float)sumi;
187 |
188 | #elif defined(__ARM_NEON)
189 |
190 | int32x4_t accu_0 = vdupq_n_s32(0);
191 | int32x4_t accu_1 = vdupq_n_s32(0);
192 | int32x4_t accu_2 = vdupq_n_s32(0);
193 | int32x4_t accu_3 = vdupq_n_s32(0);
194 | const uint8x16_t mask = vdupq_n_u8(3);
195 |
196 | for (int i=0; i < group32_num; i++) {
197 |
198 | #if defined(__ARM_FEATURE_DOTPROD)
199 |
200 | #else
201 | int16x8_t accu32_0 = vdupq_n_s16(0);
202 | int16x8_t accu32_1 = vdupq_n_s16(0);
203 | int16x8_t accu32_2 = vdupq_n_s16(0);
204 | int16x8_t accu32_3 = vdupq_n_s16(0);
205 | #endif
206 |
207 | for (int j=0; j < 32; j++) {
208 | uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32);
209 | uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16);
210 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
211 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
212 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
213 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
214 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
215 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
216 |
217 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
218 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
219 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
220 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
221 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
222 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
223 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
224 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
225 |
226 | const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0);
227 | const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16);
228 | const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32);
229 | const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48);
230 | const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64);
231 | const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80);
232 | const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96);
233 | const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112);
234 |
235 | #if defined(__ARM_FEATURE_DOTPROD)
236 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
237 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
238 | accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
239 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
240 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
241 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
242 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
243 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
244 | #else
245 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
246 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
247 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
248 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
249 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
250 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
251 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
252 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
253 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
254 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
255 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
256 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
257 | accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
258 | accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
259 | accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
260 | accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
261 | #endif
262 | }
263 |
264 | #if defined(__ARM_FEATURE_DOTPROD)
265 |
266 | #else
267 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0)));
268 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0));
269 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1)));
270 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1));
271 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2)));
272 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2));
273 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3)));
274 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3));
275 | #endif
276 | }
277 |
278 | for (int i = 0; i < groupla_num; i++){
279 | #if defined(__ARM_FEATURE_DOTPROD)
280 |
281 | #else
282 | int16x8_t accula_0 = vdupq_n_s16(0);
283 | int16x8_t accula_1 = vdupq_n_s16(0);
284 | int16x8_t accula_2 = vdupq_n_s16(0);
285 | int16x8_t accula_3 = vdupq_n_s16(0);
286 | #endif
287 | for (int j = 0; j < la_num; j++) {
288 | uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32);
289 | uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16);
290 | uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2);
291 | uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2);
292 | uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4);
293 | uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4);
294 | uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6);
295 | uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6);
296 |
297 | int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
298 | int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
299 | int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
300 | int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
301 | int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask));
302 | int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask));
303 | int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask));
304 | int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask));
305 |
306 | const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0);
307 | const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16);
308 | const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 32);
309 | const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48);
310 | const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64);
311 | const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80);
312 | const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96);
313 | const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112);
314 |
315 | #if defined(__ARM_FEATURE_DOTPROD)
316 | accu_0 = vdotq_s32(accu_0, q8_0, yq8_0);
317 | accu_1 = vdotq_s32(accu_1, q8_1, yq8_1);
318 | accu_2 = vdotq_s32(accu_2, q8_2, yq8_2);
319 | accu_3 = vdotq_s32(accu_3, q8_3, yq8_3);
320 | accu_0 = vdotq_s32(accu_0, q8_4, yq8_4);
321 | accu_1 = vdotq_s32(accu_1, q8_5, yq8_5);
322 | accu_2 = vdotq_s32(accu_2, q8_6, yq8_6);
323 | accu_3 = vdotq_s32(accu_3, q8_7, yq8_7);
324 | #else
325 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0));
326 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0));
327 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1));
328 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1));
329 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2));
330 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2));
331 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3));
332 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3));
333 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4));
334 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4));
335 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5));
336 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5));
337 | accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6));
338 | accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6));
339 | accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7));
340 | accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7));
341 | #endif
342 | }
343 | #if defined(__ARM_FEATURE_DOTPROD)
344 |
345 | #else
346 | accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0)));
347 | accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0));
348 | accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1)));
349 | accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1));
350 | accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2)));
351 | accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2));
352 | accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3)));
353 | accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3));
354 | #endif
355 | }
356 | accu_0 = vaddq_s32(accu_0, accu_1);
357 | accu_2 = vaddq_s32(accu_2, accu_3);
358 | accu_0 = vaddq_s32(accu_0, accu_2);
359 | int sumi = vaddlvq_s32(accu_0);
360 | *s = (float)sumi;
361 |
362 | #endif
363 | }
--------------------------------------------------------------------------------
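
quantize_i2_s above keeps a single per-tensor scale (the maximum |w|) and stores the ternary weights as 2-bit codes, with 0/1/2 standing for -1/0/+1 as the inline comment notes. A block of QK_I2 = 128 codes packs into 32 bytes: code j lands in byte j % 32, in the 2-bit field selected by j / 32 (shift 6, 4, 2 or 0). ggml_vec_dot_i2_i8_s then unpacks four codes per byte and accumulates code * int8-activation products. A plain-Python reference of the same packing and of what the SIMD routine accumulates (unoptimized, for illustration only):

    QK_I2 = 128                                    # block size, as in the C file

    def pack_i2_block(w):                          # w: 128 ternary weights in {-1, 0, 1}
        out = [0] * 32
        for j, wj in enumerate(w):
            q = wj + 1                             # {-1,0,1} -> codes {0,1,2}
            group_idx, group_pos = j // 32, j % 32
            out[group_pos] |= q << (6 - 2 * group_idx)
        return bytes(out)

    def dot_i2_i8_block(packed, y):                # y: 128 int8 activations
        acc = 0                                    # mirrors what ggml_vec_dot_i2_i8_s sums
        for j in range(QK_I2):
            group_idx, group_pos = j // 32, j % 32
            q = (packed[group_pos] >> (6 - 2 * group_idx)) & 0x03
            acc += q * y[j]                        # raw codes, not yet shifted back to -1/0/+1
        return acc
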
/utils/codegen_tl1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from configparser import ConfigParser
4 |
5 | def gen_ctor_code():
6 | kernel_code = "\n\
7 | #include \"ggml-bitnet.h\"\n\
8 | #define GGML_BITNET_MAX_NODES 8192\n\
9 | static bool initialized = false;\n\
10 | static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
11 | static size_t bitnet_tensor_extras_index = 0;\n\
12 | static void * aligned_malloc(size_t size) {{\n\
13 | #if defined(_WIN32)\n\
14 | return _aligned_malloc(size, 64);\n\
15 | #else\n\
16 | void * ptr = nullptr;\n\
17 | posix_memalign(&ptr, 64, size);\n\
18 | return ptr;\n\
19 | #endif\n\
20 | }}\n\
21 | static void aligned_free(void * ptr) {{\n\
22 | #if defined(_WIN32)\n\
23 | _aligned_free(ptr);\n\
24 | #else\n\
25 | free(ptr);\n\
26 | #endif\n\
27 | }}\n\
28 | \n\
29 | void per_tensor_quant(int k, void* lut_scales_, void* b_) {{\n\
30 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
31 | bitnet_float_type* b = (bitnet_float_type*)b_;\n\
32 | #ifdef __ARM_NEON\n\
33 | float32x4_t temp_max = vdupq_n_f32(0);\n\
34 | for (int i=0; i < k / 4; i++) {{\n\
35 | float32x4_t vec_bs = vld1q_f32(b + 4 * i);\n\
36 | float32x4_t abssum = vabsq_f32(vec_bs);\n\
37 | temp_max = vmaxq_f32(abssum, temp_max);\n\
38 | }}\n\
39 | float32_t scales = 127 / vmaxvq_f32(temp_max);\n\
40 | *lut_scales = scales;\n\
41 | #elif defined __AVX2__\n\
42 | __m256 max_vec = _mm256_set1_ps(0.f);\n\
43 | const __m256 vec_sign = _mm256_set1_ps(-0.0f);\n\
44 | // #pragma unroll\n\
45 | for (int i = 0; i < k / 8; i++) {{\n\
46 | __m256 vec_b = _mm256_loadu_ps(b + i * 8);\n\
47 | __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);\n\
48 | max_vec = _mm256_max_ps(vec_babs, max_vec);\n\
49 | }}\n\
50 | __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));\n\
51 | max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));\n\
52 | max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));\n\
53 | float scales = 127 / _mm_cvtss_f32(max1);\n\
54 | *lut_scales = scales;\n\
55 | #endif\n\
56 | }}\n\
57 | \n\
58 | void partial_max_reset(void* lut_scales_) {{\n\
59 | bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;\n\
60 | *lut_scales = 0.0;\n\
61 | }}\n\
62 | \n\
63 | #ifdef __ARM_NEON\n\
64 | inline void Transpose_8_8(\n\
65 | int16x8_t *v0,\n\
66 | int16x8_t *v1,\n\
67 | int16x8_t *v2,\n\
68 | int16x8_t *v3,\n\
69 | int16x8_t *v4,\n\
70 | int16x8_t *v5,\n\
71 | int16x8_t *v6,\n\
72 | int16x8_t *v7)\n\
73 | {{\n\
74 | int16x8x2_t q04 = vzipq_s16(*v0, *v4);\n\
75 | int16x8x2_t q15 = vzipq_s16(*v1, *v5);\n\
76 | int16x8x2_t q26 = vzipq_s16(*v2, *v6);\n\
77 | int16x8x2_t q37 = vzipq_s16(*v3, *v7);\n\
78 | \n\
79 | int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);\n\
80 | int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);\n\
81 | int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);\n\
82 | int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);\n\
83 | \n\
84 | int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);\n\
85 | int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);\n\
86 | int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);\n\
87 | int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);\n\
88 | \n\
89 | *v0 = q_fin_0.val[0];\n\
90 | *v1 = q_fin_0.val[1];\n\
91 | *v2 = q_fin_1.val[0];\n\
92 | *v3 = q_fin_1.val[1];\n\
93 | *v4 = q_fin_2.val[0];\n\
94 | *v5 = q_fin_2.val[1];\n\
95 | *v6 = q_fin_3.val[0];\n\
96 | *v7 = q_fin_3.val[1];\n\
97 | }}\n\
98 | #endif\n\
99 | \n\
100 | template<int act_k>\n\
101 | inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {{\n\
102 | #ifdef __ARM_NEON\n\
103 | int16x8_t vec_lut[16];\n\
104 | float32_t scales = *lut_scales;\n\
105 | uint8_t tbl_mask[16];\n\
106 | tbl_mask[0] = 0;\n\
107 | tbl_mask[1] = 2;\n\
108 | tbl_mask[2] = 4;\n\
109 | tbl_mask[3] = 6;\n\
110 | tbl_mask[4] = 8;\n\
111 | tbl_mask[5] = 10;\n\
112 | tbl_mask[6] = 12;\n\
113 | tbl_mask[7] = 14;\n\
114 | tbl_mask[8] = 1;\n\
115 | tbl_mask[9] = 3;\n\
116 | tbl_mask[10] = 5;\n\
117 | tbl_mask[11] = 7;\n\
118 | tbl_mask[12] = 9;\n\
119 | tbl_mask[13] = 11;\n\
120 | tbl_mask[14] = 13;\n\
121 | tbl_mask[15] = 15;\n\
122 | uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);\n\
123 | #pragma unroll\n\
124 | for (int k = 0; k < act_k / 16; ++k) {{\n\
125 | float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);\n\
126 | float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);\n\
127 | float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);\n\
128 | float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);\n\
129 | float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);\n\
130 | float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);\n\
131 | int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);\n\
132 | int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);\n\
133 | int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);\n\
134 | int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);\n\
135 | int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);\n\
136 | int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);\n\
137 | int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);\n\
138 | int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);\n\
139 | int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);\n\
140 | int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);\n\
141 | vec_lut[0] = vdupq_n_s16(0);\n\
142 | vec_lut[0] = vec_lut[0] - vec_bs_0;\n\
143 | vec_lut[0] = vec_lut[0] - vec_bs_1;\n\
144 | vec_lut[1] = vdupq_n_s16(0);\n\
145 | vec_lut[1] = vec_lut[1] - vec_bs_0;\n\
146 | vec_lut[2] = vdupq_n_s16(0);\n\
147 | vec_lut[2] = vec_lut[2] - vec_bs_0;\n\
148 | vec_lut[2] = vec_lut[2] + vec_bs_1;\n\
149 | vec_lut[3] = vdupq_n_s16(0);\n\
150 | vec_lut[3] = vec_lut[3] - vec_bs_1;\n\
151 | vec_lut[4] = vdupq_n_s16(0);\n\
152 | vec_lut[5] = vec_bs_1;\n\
153 | vec_lut[6] = vec_bs_0;\n\
154 | vec_lut[6] = vec_lut[6] - vec_bs_1;\n\
155 | vec_lut[7] = vec_bs_0;\n\
156 | vec_lut[8] = vec_bs_0;\n\
157 | vec_lut[8] = vec_lut[8] + vec_bs_1;\n\
158 | Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),\n\
159 | &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));\n\
160 | Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),\n\
161 | &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));\n\
162 | #pragma unroll\n\
163 | for (int idx = 0; idx < 8; idx++) {{\n\
164 | int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);\n\
165 | int8x8_t q0_low = vget_low_s8(q0_s);\n\
166 | int8x8_t q0_high = vget_high_s8(q0_s);\n\
167 | int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);\n\
168 | int8x8_t q1_low = vget_low_s8(q1_s);\n\
169 | int8x8_t q1_high = vget_high_s8(q1_s);\n\
170 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);\n\
171 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);\n\
172 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);\n\
173 | vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);\n\
174 | }}\n\
175 | }}\n\
176 | #endif\n\
177 | }}\n\
178 | \n\
179 | static bool is_type_supported(enum ggml_type type) {{\n\
180 | if (type == GGML_TYPE_Q4_0 ||\n\
181 | type == GGML_TYPE_TL1) {{\n\
182 | return true;\n\
183 | }} else {{\n\
184 | return false;\n\
185 | }}\n\
186 | }}\n\
187 | "
188 | return kernel_code
189 |
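# Note on the generated ctor code above: per_tensor_quant derives one
# activation scale per tensor (127 / max|b|), and lut_ctor builds, for each
# pair of scaled activations (bs_0, bs_1), the nine partial sums a pair of
# ternary weights can produce: -bs_0-bs_1, -bs_0, -bs_0+bs_1, -bs_1, 0,
# +bs_1, bs_0-bs_1, bs_0, bs_0+bs_1 (vec_lut[0..8]); the Transpose_8_8 calls
# then interleave those tables so the kernels can fetch them with vqtbl1q_s8.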
190 | def gen_body_core_code(bm, by):
191 | length = 4
192 | all_code = ""
193 | for i in range(length):
194 | core_code = "\n\
195 | uint8x16_t vec_a_{0} = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + {0} * 16);\n\
196 | uint8x16_t vec_a{0}_top = vshrq_n_u8(vec_a_{0}, 4);\n\
197 | uint8x16_t vec_a{0}_bot = vandq_u8(vec_a_{0}, vec_mask);\n\
198 | int8x16_t vec_v_{0}_left_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {2}], vec_a{0}_top);\n\
199 | int8x16_t vec_v_{0}_left_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {3}], vec_a{0}_top);\n\
200 | int8x16_t vec_v_{0}_right_tmp0 = vqtbl1q_s8(vec_lut[{1} * k + {4}], vec_a{0}_bot);\n\
201 | int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\
202 | int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\
203 | int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\
204 | vec_c[{6}] += vec_v_left_{0}.val[0];\n\
205 | vec_c[{6}] += vec_v_right_{0}.val[0];\n\
206 | vec_c[{7}] += vec_v_left_{0}.val[1];\n\
207 | vec_c[{7}] += vec_v_right_{0}.val[1];\n\
208 | ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1)
209 |
210 | all_code = "".join([all_code, core_code])
211 |
212 | all_code = "".join([all_code, "\n }\n\n"])
213 |
214 | for i in range(bm // 8):
215 | core_code = "\
216 | int32x4_t vec_v_bot_low_low_{0} = vmovl_s16(vget_low_s16(vec_c[{0}]));\n\
217 | int32x4_t vec_v_bot_low_high_{0} = vmovl_high_s16(vec_c[{0}]);\n\
218 | vst1q_s32(c + i + {1}, vld1q_s32(c + i + {1}) + vec_v_bot_low_low_{0});\n\
219 | vst1q_s32(c + i + {2}, vld1q_s32(c + i + {2}) + vec_v_bot_low_high_{0});\n".format(i, i * 8, i * 8 + 4)
220 | all_code = "".join([all_code, core_code])
221 |
222 | return all_code
223 |
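# Generates tbl_impl_{pre} plus its qgemm_lut_{pre} wrapper for one (M, K) weight shape:
# the kernel walks a BM-row tile in steps of bm while keeping 2 * KK 16-byte LUT slices in
# registers; the wrapper loops k_outer over K / BBK blocks and finally rescales the int32
# accumulator CBits into C using the activation (LUT_Scales) and weight (Scales) scales.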
224 | def gen_tbl_impl(pre, BM, BK, bm, k):
225 |
226 | kernel_code = "\
227 | #include <arm_neon.h>\n\
228 | \n\
229 | #define BM{0} {1}\n\
230 | #define BBK{0} {2}\n\
231 | inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\
232 | #ifdef __ARM_NEON\n\
233 | const int KK = BBK{0} / 2;\n\
234 | const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\
235 | const int16x8_t vec_zero = vdupq_n_s16(0x0000);\n\
236 | int8x16_t vec_lut[2 * KK];\n\
237 | ".format(pre, BM, BK)
238 |
239 | kernel_code = "".join([kernel_code, " int16x8_t vec_c[{}];".format(bm // 8)])
240 |
241 | kernel_code = "".join([kernel_code, "\n\
242 | #pragma unroll\n\
243 | for (int k = 0; k < 2 * KK; k++) {\n\
244 | vec_lut[k] = vld1q_s8(lut + k * 16);\n\
245 | }\n"])
246 |
247 | pre_core_code = "\n\
248 | #pragma unroll\n\
249 | for (int i = 0; i < BM{}; i += {}) {{\n\
250 | #pragma unroll\n\
251 | for (int j = 0; j < {}; j++) {{\n\
252 | vec_c[j] = vandq_s16(vec_c[j], vec_zero);\n\
253 | }}\n".format(pre, bm, bm // 8)
254 |
255 | body_core_pre_code = "\n\
256 | #pragma unroll\n\
257 | for (int k = 0; k < KK / {}; k++) {{\n\
258 | ".format(256 // bm // 2)
259 |
260 | body_core_post_code = "\n\
261 | }\n\
262 | \
263 | #endif\n\
264 | }\n"
265 |
266 | kernel_code = "".join([kernel_code, pre_core_code, body_core_pre_code, gen_body_core_code(bm, 256 // bm), body_core_post_code])
267 |
268 | kernel_code = "".join([kernel_code, "\n\
269 | int32_t qgemm_lut_{0}(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
270 | alignas({1}) uint32_t CBits[BM{0}];\n\
271 | memset(&(CBits[0]), 0, BM{0} * sizeof(int32_t));\n\
272 | #pragma unroll\n\
273 | for (int32_t k_outer = 0; k_outer < {2} / BBK{0}; ++k_outer) {{\n\
274 | tbl_impl_{0}((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK{0} / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK{0} / 2 / 2 * BM{0})])));\n\
275 | }}\n\
276 | #pragma unroll\n\
277 | for (int i = 0; i < BM{0}; i++) {{\n\
278 | ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];\n\
279 | }}\n\
280 | return 0;\n\
281 | }};\n".format(pre, min(32, BK), k)])
282 |
283 | return kernel_code
284 |
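# Emits the shape-dispatch entry points ggml_preprocessor / ggml_qgemm_lut, which pick the
# preprocessor_k<K> instantiation and the qgemm_lut_{M}_{K} kernel from the runtime (m, k) pair.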
285 | def gen_top_api(kernel_shapes):
286 |
287 | kernel_code = "void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {{\n\
288 | if (m == {0} && k == {1}) {{\n\
289 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
290 | }}\n\
291 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])
292 | for i in range(1, len(kernel_shapes)):
293 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\
294 | preprocessor_k<{1}>(B, LUT_Scales, QLUT);\n\
295 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
296 | kernel_code = "".join([kernel_code, "}\n"])
297 | kernel_code = "".join([kernel_code, "void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {{\n\
298 | if (m == {0} && k == {1}) {{\n\
299 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
300 | }}\n\
301 | ".format(kernel_shapes[0][0], kernel_shapes[0][1])])
302 | for i in range(1, len(kernel_shapes)):
303 | kernel_code = "".join([kernel_code, " else if (m == {0} && k == {1}) {{\n\
304 | qgemm_lut_{0}_{1}(A, LUT, Scales, LUT_Scales, C);\n\
305 | }}\n\
306 | ".format(kernel_shapes[i][0], kernel_shapes[i][1])])
307 | kernel_code = "".join([kernel_code, "}\n"])
308 | return kernel_code
309 |
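# Emits preprocessor_k<K>: resets the running max, quantizes the activation block B per tensor
# into LUT_Scales, then builds the quantized lookup table QLUT via lut_ctor.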
310 | def gen_preprocess_code():
311 | kernel_code = "\n\
312 | template <int K>\n\
313 | void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {{\n\
314 | partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));\n\
315 | per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));\n\
316 | \n\
317 | lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));\n\
318 | }}\n"
319 | return kernel_code
320 |
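# Emits ggml_bitnet_transform_tensor, which attaches a bitnet_tensor_extras entry
# (tile count, BK, packed weights and scale) to each supported CPU tensor at load time.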
321 | def gen_transform_code(kernel_shapes):
322 | kernel_code = "\n\
323 | void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {\n\
324 | if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {\n\
325 | return;\n\
326 | }\n\
327 | \n\
328 | int k = tensor->ne[0];\n\
329 | int m = tensor->ne[1];\n\
330 | const int lut_scales_size = 1;\n\
331 | const int scales_size = 1;\n\
332 | int bk = 0;\n\
333 | int bm = 0;\n"
334 |
335 | kernel_code = "".join([kernel_code, "\n\
336 | if (m == {0} && k == {1}) {{\n\
337 | bm = BM{0}_{1};\n\
338 | bk = BBK{0}_{1};\n\
339 | }}\n".format(kernel_shapes[0][0], kernel_shapes[0][1])])
340 |
341 | for i in range(1, len(kernel_shapes)):
342 | kernel_code = "".join([kernel_code, "else if (m == {0} && k == {1}) {{\n\
343 | bm = BM{0}_{1};\n\
344 | bk = BBK{0}_{1};\n\
345 | }}\n".format(kernel_shapes[i][0], kernel_shapes[i][1])])
346 |
347 | kernel_code = "".join([kernel_code, "\n\
348 | const int n_tile_num = m / bm;\n\
349 | const int BK = bk;\n\
350 | uint8_t * qweights;\n\
351 | bitnet_float_type * scales;\n\
352 | \n\
353 | scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));\n\
354 | qweights = (uint8_t *) tensor->data;\n\
355 | float * i2_scales = (float * )(qweights + k * m / 4);\n\
356 | scales[0] = (bitnet_float_type) i2_scales[0];\n\
357 | \n\
358 | tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
359 | bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
360 | /* .lut_scales_size = */ lut_scales_size,\n\
361 | /* .BK = */ BK,\n\
362 | /* .n_tile_num = */ n_tile_num,\n\
363 | /* .qweights = */ qweights,\n\
364 | /* .scales = */ scales\n\
365 | };\n\
366 | }\n"])
367 |
368 | return kernel_code
369 |
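# Example invocation (editorial, illustrative tile sizes; any values that satisfy the
# divisibility asserts below will work, e.g. those shipped in preset_kernels/*/kernel_config_tl1.ini):
#   python utils/codegen_tl1.py --model bitnet_b1_58-3B --BM 160,320,320 --BK 64,128,64 --bm 32,64,32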
370 | if __name__ == "__main__":
371 | ModelShapeDict = {
372 | "bitnet_b1_58-large" : [[1536, 4096],
373 | [1536, 1536],
374 | [4096, 1536]],
375 | "bitnet_b1_58-3B" : [[3200, 8640],
376 | [3200, 3200],
377 | [8640, 3200]],
378 | "Llama3-8B-1.58-100B-tokens" : [[14336, 4096],
379 | [4096, 14336],
380 | [1024, 4096],
381 | [4096, 4096]]
382 | }
383 |
384 | parser = argparse.ArgumentParser(description='Generate TL1 LUT kernel implementations')
385 | parser.add_argument('--model',default="input", type=str, dest="model",
386 | help="choose from bitnet_b1_58-large/bitnet_b1_58-3B/Llama3-8B-1.58-100B-tokens.")
387 | parser.add_argument('--BM',default="input", type=str,
388 | help="block length along M when splitting a (M, K) weight into M / BM blocks of shape (BM, K).")
389 | parser.add_argument('--BK',default="input", type=str,
390 | help="block length along K when splitting a (M, K) weight into K / BK blocks of shape (M, BK).")
391 | parser.add_argument('--bm',default="input", type=str,
392 | help="SIMD tile size: each block is computed as (bm, 256 / bm) tiles; bm must be 32 or 64.")
393 | args = parser.parse_args()
394 |
395 | kernel_shapes = ModelShapeDict[args.model]
396 |
397 | BM_list = [int(item) for item in args.BM.split(',')]
398 | BK_list = [int(item) for item in args.BK.split(',')]
399 | bm_list = [int(item) for item in args.bm.split(',')]
400 |
401 | assert len(BM_list) == len(BK_list) == len(bm_list) == len(kernel_shapes), "number of BM / BK / bm values should be {}".format(len(kernel_shapes))
402 |
403 | for i in range(len(kernel_shapes)):
404 | assert kernel_shapes[i][0] % BM_list[i] == 0, "M %% BM should be 0"
405 | assert kernel_shapes[i][1] % BK_list[i] == 0, "K %% BK should be 0"
406 | assert bm_list[i] in [32, 64], "choose bm from [32, 64]"
407 |
408 | tbl_impl_code = []
409 |
410 | for i in range(len(kernel_shapes)):
411 | tbl_impl_code.append(
412 | gen_tbl_impl("{}_{}".format(kernel_shapes[i][0], kernel_shapes[i][1]), BM_list[i], BK_list[i], bm_list[i], kernel_shapes[i][1])
413 | )
414 | api_code = gen_top_api(kernel_shapes)
415 | pre_code = gen_preprocess_code()
416 | ctor_code = gen_ctor_code()
417 | trans_code = gen_transform_code(kernel_shapes)
418 |
419 | output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "include")
420 |
421 | with open(''.join([output_dir, "/bitnet-lut-kernels.h"]), 'w') as f:
422 | f.write(''.join("#if defined(GGML_BITNET_ARM_TL1)"))
423 | f.write(''.join(ctor_code))
424 | for code in tbl_impl_code:
425 | f.write(''.join(code))
426 | f.write(''.join(pre_code))
427 | f.write(''.join(api_code))
428 | f.write(''.join(trans_code))
429 | f.write(''.join("#endif"))
430 |
431 | config = ConfigParser()
432 |
433 | for i in range(len(kernel_shapes)):
434 | config.add_section('Kernels_{}'.format(i))
435 | config.set('Kernels_{}'.format(i), 'M', str(kernel_shapes[i][0]))
436 | config.set('Kernels_{}'.format(i), 'K', str(kernel_shapes[i][1]))
437 | config.set('Kernels_{}'.format(i), 'BM', str(BM_list[i]))
438 | config.set('Kernels_{}'.format(i), 'BK', str(BK_list[i]))
439 | config.set('Kernels_{}'.format(i), 'bmm', str(bm_list[i]))
440 |
441 | with open(''.join([output_dir, "/kernel_config.ini"]), 'w') as configfile:
442 | config.write(configfile)
--------------------------------------------------------------------------------
/utils/e2e_benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import argparse
5 | import platform
6 | import subprocess
7 |
8 | def run_command(command, shell=False, log_step=None):
9 | """Run a system command and ensure it succeeds."""
10 | if log_step:
11 | log_file = os.path.join(args.log_dir, log_step + ".log")
12 | with open(log_file, "w") as f:
13 | try:
14 | subprocess.run(command, shell=shell, check=True, stdout=f, stderr=f)
15 | except subprocess.CalledProcessError as e:
16 | logging.error(f"Error occurred while running command: {e}, check details in {log_file}")
17 | sys.exit(1)
18 | else:
19 | try:
20 | subprocess.run(command, shell=shell, check=True)
21 | except subprocess.CalledProcessError as e:
22 | logging.error(f"Error occurred while running command: {e}")
23 | sys.exit(1)
24 |
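# Builds and runs the llama-bench command line: -n generated tokens, -p prompt tokens,
# -t threads, -ngl 0 (CPU only), -b 1 (batch size 1), -r 5 (five repetitions per measurement).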
25 | def run_benchmark():
26 | build_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "build")
27 | if platform.system() == "Windows":
28 | bench_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe")
29 | if not os.path.exists(bench_path):
30 | bench_path = os.path.join(build_dir, "bin", "llama-bench")
31 | else:
32 | bench_path = os.path.join(build_dir, "bin", "llama-bench")
33 | if not os.path.exists(bench_path):
34 | logging.error(f"Benchmark binary not found, please build first.")
35 | sys.exit(1)
36 | command = [
37 | f'{bench_path}',
38 | '-m', args.model,
39 | '-n', str(args.n_token),
40 | '-ngl', '0',
41 | '-b', '1',
42 | '-t', str(args.threads),
43 | '-p', str(args.n_prompt),
44 | '-r', '5'
45 | ]
46 | run_command(command)
47 |
48 | def parse_args():
49 |     parser = argparse.ArgumentParser(description='Run the end-to-end inference benchmark with llama-bench')
50 | parser.add_argument("-m", "--model", type=str, help="Path to model file", required=True)
51 | parser.add_argument("-n", "--n-token", type=int, help="Number of generated tokens", required=False, default=128)
52 | parser.add_argument("-p", "--n-prompt", type=int, help="Prompt to generate text from", required=False, default=512)
53 | parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
54 | return parser.parse_args()
55 |
56 | if __name__ == "__main__":
57 | logging.basicConfig(level=logging.INFO)
58 | args = parse_args()
59 | run_benchmark()
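# Example usage (editorial; the model path is illustrative, point it at your converted GGUF file):
#   python utils/e2e_benchmark.py -m models/bitnet_b1_58-3B/ggml-model-i2_s.gguf -n 128 -p 512 -t 4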
--------------------------------------------------------------------------------
/utils/kernel_tuning.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/BitNet/c17d1c5d77c48af7d6fb29c9f28a3da0277fc394/utils/kernel_tuning.py
--------------------------------------------------------------------------------