├── .github ├── 8raspi.jpg ├── 8raspi2.jpg ├── cover.png └── workflows │ └── main.yml ├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── Makefile ├── README.md ├── converter ├── .gitignore ├── convert-hf.py ├── convert-llama.py ├── convert-tokenizer-hf.py ├── convert-tokenizer-llama2.py ├── convert-tokenizer-llama3.py ├── requirements.txt ├── tokenizer-writer.py ├── writer-test.py └── writer.py ├── docs ├── HUGGINGFACE.md └── LLAMA.md ├── examples ├── chat-api-client.js ├── macbeth.sh └── n-workers.sh ├── launch.py ├── report └── report.pdf └── src ├── api-types.hpp ├── app.cpp ├── app.hpp ├── dllama-api.cpp ├── dllama.cpp ├── json.hpp ├── llm.cpp ├── llm.hpp ├── mmap.hpp ├── nn ├── llamafile │ ├── sgemm.cpp │ └── sgemm.hpp ├── nn-config-builder.hpp ├── nn-core.cpp ├── nn-core.hpp ├── nn-cpu-ops-test.cpp ├── nn-cpu-ops.cpp ├── nn-cpu-ops.hpp ├── nn-cpu-test.cpp ├── nn-cpu.cpp ├── nn-cpu.hpp ├── nn-executor.cpp ├── nn-executor.hpp ├── nn-network.cpp ├── nn-network.hpp ├── nn-quants.cpp ├── nn-quants.hpp ├── nn-vulkan-test.cpp ├── nn-vulkan.cpp ├── nn-vulkan.hpp ├── pthread.h └── vulkan │ ├── cast-forward-f32-f32.comp │ ├── cast-forward-f32-q80.comp │ ├── embedding-forward-f32-f32.comp │ ├── inv-rms-forward-f32-f32.comp │ ├── matmul-forward-f32-f32-f32.comp │ ├── matmul-forward-q80-q40-f32.comp │ ├── merge-add-forward-f32-f32.comp │ ├── merge-add-forward-q80-f32.comp │ ├── mul-forward-f32-f32.comp │ ├── multi-head-att-forward-f32-f32.comp │ ├── rms-norm-forward-f32-f32-f32.comp │ ├── rope-forward-f32-f32.comp │ ├── shift-forward-f32-f32.comp │ └── silu-forward-f32-f32.comp ├── tokenizer-test.cpp ├── tokenizer.cpp └── tokenizer.hpp /.github/8raspi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/b4rtaz/distributed-llama/a16d2f03e66437088dce2ba4b82304a8101c074f/.github/8raspi.jpg -------------------------------------------------------------------------------- /.github/8raspi2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/b4rtaz/distributed-llama/a16d2f03e66437088dce2ba4b82304a8101c074f/.github/8raspi2.jpg -------------------------------------------------------------------------------- /.github/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/b4rtaz/distributed-llama/a16d2f03e66437088dce2ba4b82304a8101c074f/.github/cover.png -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: main 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | - feat/nn 7 | push: 8 | branches: 9 | - main 10 | - feat/nn 11 | jobs: 12 | build-linux: 13 | name: Linux 14 | runs-on: ${{matrix.os}} 15 | strategy: 16 | matrix: 17 | os: 18 | - ubuntu-latest 19 | platforms: 20 | - linux/arm64 21 | - linux/amd64 22 | steps: 23 | - name: Checkout Repo 24 | uses: actions/checkout@v3 25 | - name: Dependencies 26 | id: dependencies 27 | run: sudo apt-get update && sudo apt-get install build-essential 28 | - name: Build 29 | id: build 30 | run: | 31 | make dllama 32 | make nn-cpu-test 33 | make nn-cpu-ops-test 34 | make tokenizer-test 35 | - name: nn-cpu-test 36 | run: ./nn-cpu-test 37 | - name: nn-cpu-ops-test 38 | run: ./nn-cpu-ops-test 39 | - name: tokenizer-test 40 | run: ./tokenizer-test 41 | 42 | build-windows: 43 | name: Windows 44 | runs-on: 
windows-latest 45 | steps: 46 | - name: Checkout Repo 47 | uses: actions/checkout@v3 48 | - name: Dependencies 49 | id: dependencies 50 | run: choco install make 51 | - name: Build 52 | id: build 53 | run: | 54 | make dllama 55 | make nn-cpu-test 56 | make nn-cpu-ops-test 57 | make tokenizer-test 58 | - name: nn-cpu-test 59 | run: ./nn-cpu-test 60 | - name: nn-cpu-ops-test 61 | run: ./nn-cpu-ops-test 62 | - name: tokenizer-test 63 | run: ./tokenizer-test 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/settings.json 2 | 3 | *.o 4 | *.0 5 | *.dSYM 6 | *.data 7 | *.temp 8 | *.tmp 9 | __pycache__ 10 | 11 | *-test 12 | /models 13 | main 14 | run*.sh 15 | server 16 | /dllama 17 | /dllama-* 18 | *.exe 19 | *.spv -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "main", 6 | "type": "cppdbg", 7 | "request": "launch", 8 | "program": "${workspaceFolder}/main", 9 | "args": [], 10 | "stopAtEntry": false, 11 | "cwd": "${workspaceFolder}", 12 | "environment": [], 13 | "externalConsole": false, 14 | "MIMode": "lldb" 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2024 Bartłomiej Tadych (b4rtaz) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CXX = g++ 2 | CXXFLAGS = -std=c++11 -Werror -Wformat -Werror=format-security 3 | 4 | ifndef TERMUX_VERSION 5 | CXXFLAGS += -march=native -mtune=native 6 | endif 7 | 8 | ifdef DEBUG 9 | CXXFLAGS += -g -fsanitize=address 10 | else 11 | CXXFLAGS += -O3 12 | endif 13 | 14 | ifdef WVLA 15 | CXXFLAGS += -Wvla-extension 16 | endif 17 | 18 | ifdef DLLAMA_VULKAN 19 | CGLSLC = glslc 20 | 21 | ifeq ($(OS),Windows_NT) 22 | LIBS += -L$(VK_SDK_PATH)\lib -lvulkan-1 23 | CXXFLAGS += -DDLLAMA_VULKAN -I$(VK_SDK_PATH)\include 24 | else 25 | LIBS += -lvulkan 26 | CXXFLAGS += -DDLLAMA_VULKAN 27 | endif 28 | 29 | DEPS += nn-vulkan.o 30 | endif 31 | 32 | ifeq ($(OS),Windows_NT) 33 | LIBS += -lws2_32 34 | DELETE_CMD = del /f 35 | else 36 | LIBS += -lpthread 37 | DELETE_CMD = rm -fv 38 | endif 39 | 40 | .PHONY: clean dllama 41 | 42 | clean: 43 | $(DELETE_CMD) *.o dllama dllama-* socket-benchmark mmap-buffer-* *-test *.exe 44 | 45 | # nn 46 | nn-quants.o: src/nn/nn-quants.cpp 47 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 48 | nn-core.o: src/nn/nn-core.cpp 49 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 50 | nn-executor.o: src/nn/nn-executor.cpp 51 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 52 | nn-network.o: src/nn/nn-network.cpp 53 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 54 | llamafile-sgemm.o: src/nn/llamafile/sgemm.cpp 55 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 56 | nn-cpu-ops.o: src/nn/nn-cpu-ops.cpp 57 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 58 | nn-cpu.o: src/nn/nn-cpu.cpp 59 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 60 | nn-cpu-test: src/nn/nn-cpu-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o 61 | $(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS) 62 | nn-cpu-ops-test: src/nn/nn-cpu-ops-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu.o 63 | $(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS) 64 | nn-vulkan.o: src/nn/nn-vulkan.cpp 65 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 66 | 67 | ifdef DLLAMA_VULKAN 68 | VULKAN_SHADER_SRCS := $(wildcard src/nn/vulkan/*.comp) 69 | VULKAN_SHADER_BINS := $(VULKAN_SHADER_SRCS:.comp=.spv) 70 | DEPS += $(VULKAN_SHADER_BINS) 71 | 72 | %.spv: %.comp 73 | $(CGLSLC) -c $< -o $@ 74 | nn-vulkan-test: src/nn/nn-vulkan-test.cpp nn-quants.o nn-core.o nn-executor.o nn-vulkan.o ${DEPS} 75 | $(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS) 76 | endif 77 | 78 | # llm 79 | tokenizer.o: src/tokenizer.cpp 80 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 81 | llm.o: src/llm.cpp 82 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 83 | app.o: src/app.cpp 84 | $(CXX) $(CXXFLAGS) -c $^ -o $@ 85 | tokenizer-test: src/tokenizer-test.cpp nn-quants.o nn-core.o llamafile-sgemm.o nn-cpu-ops.o tokenizer.o 86 | $(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS) 87 | dllama: src/dllama.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS} 88 | $(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS) 89 | dllama-api: src/dllama-api.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS} 90 | $(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS) 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Distributed Llama](.github/cover.png) 2 | 3 | # Distributed Llama 4 | 5 | [![GitHub Actions Workflow 
Status](https://img.shields.io/github/actions/workflow/status/b4rtaz/distributed-llama/.github%2Fworkflows%2Fmain.yml?style=flat-square)](https://github.com/b4rtaz/distributed-llama/actions) [![License: MIT](https://img.shields.io/github/license/mashape/apistatus.svg?style=flat-square)](/LICENSE) [![Support this project](https://img.shields.io/github/sponsors/b4rtaz?style=flat-square&label=support%20this%20project&color=green)](https://github.com/sponsors/b4rtaz) [![Discord](https://discordapp.com/api/guilds/1245814812353495070/widget.png?style=shield)](https://discord.com/widget?id=1245814812353495070&theme=dark) 6 | 7 | Connect home devices into a powerful cluster to accelerate LLM inference. More devices mean faster performance, leveraging tensor parallelism and high-speed synchronization over Ethernet. 8 | 9 | Supports Linux, macOS, and Windows. Optimized for ARM and x86_64 AVX2 CPUs. 10 | 11 | **News** 12 | - 23 Mar 2025 - [🌋 Experimental Vulkan support](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.13.0) 13 | - 12 Feb 2025 - 🚧 Merged the [fundamental codebase refactor](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.12.0) 14 | - 9 Jan 2025 - [🍎 Llama 3.3 70B on 4 x Mac Mini M4 Pro 24GB RAM](https://github.com/b4rtaz/distributed-llama/discussions/147) 15 | - 28 Jul 2024 - [🌳 How to Run Llama 3.1 405B on Home Devices? Build AI Cluster!](https://medium.com/@b4rtaz/how-to-run-llama-3-405b-on-home-devices-build-ai-cluster-ad0d5ad3473b) 16 | 17 | 18 | ### 🔥 Setup Root Node by Single Command 19 | 20 | Python 3 and C++ compiler required. The command will download the model and the tokenizer. 21 | 22 | | Model | Size | Command | 23 | | --------------------------------- | -------- | ---------------------------------------------------- | 24 | | Llama 3.1 8B Instruct Q40 | 6.32 GB | `python launch.py llama3_1_8b_instruct_q40` | 25 | | Llama 3.1 405B Instruct Q40. | 238 GB | `python launch.py llama3_1_405b_instruct_q40`. | 26 | | Llama 3.2 1B Instruct Q40 | 1.7 GB | `python launch.py llama3_2_1b_instruct_q40` | 27 | | Llama 3.2 3B Instruct Q40 | 3.4 GB | `python launch.py llama3_2_3b_instruct_q40` | 28 | | Llama 3.3 70B Instruct Q40 | 40 GB | `python launch.py llama3_3_70b_instruct_q40` | 29 | | DeepSeek R1 Distill Llama 8B Q40 | 6.32 GB | `python launch.py deepseek_r1_distill_llama_8b_q40` | 30 | 31 | ### 🛠️ Convert Model Manually 32 | 33 | Supported architectures: Llama. 34 | 35 | * [How to Convert Llama 3.1](./docs/LLAMA.md) 36 | * [How to Convert Hugging Face Model](./docs/HUGGINGFACE.md) 37 | 38 | ### 🚧 Known Limitations 39 | 40 | * You can run Distributed Llama only on 1, 2, 4... 2^n nodes. 41 | * The maximum number of nodes is equal to the number of KV heads in the model [#70](https://github.com/b4rtaz/distributed-llama/issues/70). 42 | * Only the following quantizations are supported [#183](https://github.com/b4rtaz/distributed-llama/issues/183): 43 | * `q40` model with `q80` `buffer-float-type` 44 | * `f32` model with `f32` `buffer-float-type` 45 | 46 | ### 👷 Architecture 47 | 48 | The project is split up into two parts: 49 | * **Root node** - it's responsible for loading the model and weights and forward them to workers. Also, it synchronizes the state of the neural network. The root node is also a worker, it processes own slice of the neural network. 50 | * **Worker node** - it processes own slice of the neural network. It doesn't require any configuration related to the model. 
51 | 52 | You always need the root node, and you can add 2^n - 1 worker nodes to speed up inference. The RAM usage of the neural network is split across all nodes. The root node requires slightly more RAM than the worker nodes. 53 | 54 | ### 🎹 Commands 55 | 56 | * `dllama inference` - run the inference with a simple benchmark, 57 | * `dllama chat` - run the CLI chat, 58 | * `dllama worker` - run the worker node, 59 | * `dllama-api` - run the API server. 60 |
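For example, a root node that is started for `inference` in the setup sections below can instead be run in chat mode with the same model and tokenizer files. This is only a sketch; all supported arguments are described in the next section, and the file names and worker address are the ones used in the setup examples.

```sh
./dllama chat --model dllama_model_meta-llama-3-8b_q40.m --tokenizer dllama_tokenizer_llama3.t --buffer-float-type q80 --nthreads 4 --workers 10.0.0.2:9999
```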
62 | 63 | 🎹 Supported Arguments 64 | 65 |
Inference, Chat, API 66 | 67 | | Argument | Description | Example | 68 | | ---------------------------- | ---------------------------------------------------------------- | -------------------------------------- | 69 | | `--model <path>` | Path to the model file. | `dllama_model_meta-llama-3-8b_q40.m` | 70 | | `--tokenizer <path>` | Path to the tokenizer file. | `dllama_tokenizer_llama3.t` | 71 | | `--buffer-float-type <type>` | Float precision of synchronization. | `q80` | 72 | | `--workers <addresses>` | Addresses of workers (ip:port), separated by spaces. | `10.0.0.1:9999 10.0.0.2:9999` | 73 | | `--max-seq-len <n>` | The maximum sequence length; it helps to reduce RAM usage. | `4096` | 74 | 75 | Inference, Chat, Worker, API 76 | 77 | | Argument | Description | Example | 78 | | ---------------------------- | --------------------------------------------------------------------- | ----------------------------------- | 79 | | `--nthreads <n>` | Number of threads. Don't set a value higher than the number of CPU cores. | `4` | 80 | 81 | Worker, API 82 | 83 | | Argument | Description | Example | 84 | | ---------------------------- | --------------------------------- | ----------------- | 85 | | `--port <port>` | Binding port. | `9999` | 86 | 87 | Inference 88 | 89 | | Argument | Description | Example | 90 | | ---------------------------- | ------------------------------ | ------------------ | 91 | | `--prompt <text>` | Initial prompt. | `"Hello World"` | 92 | | `--steps <n>` | Number of tokens to generate. | `256` | 93 | 94 |
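As an illustration of how these arguments combine, the sketch below starts the API server on the root node and queries it with `curl`. The model and tokenizer file names are the ones used in the setup sections further down, and the request shape mirrors `examples/chat-api-client.js`; treat the exact JSON fields as an example rather than a full API reference.

```sh
# Start the API server on the root node (add --workers ... if worker nodes are running).
./dllama-api --model dllama_model_meta-llama-3-8b_q40.m --tokenizer dllama_tokenizer_llama3.t --buffer-float-type q80 --port 9999 --nthreads 4

# Send an OpenAI-style chat completion request to it.
curl http://127.0.0.1:9999/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "What is 1 + 2?"}], "temperature": 0.7, "max_tokens": 128}'
```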
95 | 96 | ## 📊 Measurements 97 | 98 | Please check the [discussions](https://github.com/b4rtaz/distributed-llama/discussions) section, where many measurements were published on different configurations. 99 | 100 | ## 🚀 Setup 101 | 102 | Select and expand one of the sections below: 103 | 104 |
105 | 106 | 💻 MacOS, Linux, or Windows 107 | 108 |
You need x86_64 AVX2 CPUs or ARM CPUs. Different devices may have different CPUs. 109 | 110 | #### MacOS or Linux 111 | 112 | The below instructions are for Debian-based distributions but you can easily adapt them to your distribution, macOS. 113 | 114 | 1. Install Git and GCC: 115 | ```sh 116 | sudo apt install git build-essential 117 | ``` 118 | 2. Clone this repository and compile Distributed Llama on all computers: 119 | ```sh 120 | git clone https://github.com/b4rtaz/distributed-llama.git 121 | cd distributed-llama 122 | make dllama 123 | make dllama-api 124 | ``` 125 | 126 | Continue to point 3. 127 | 128 | #### Windows 129 | 130 | 1. Install Git and Mingw (via [Chocolatey](https://chocolatey.org/install)): 131 | ```powershell 132 | choco install mingw 133 | ``` 134 | 2. Clone this repository and compile Distributed Llama on all computers: 135 | ```sh 136 | git clone https://github.com/b4rtaz/distributed-llama.git 137 | cd distributed-llama 138 | make dllama 139 | make dllama-api 140 | ``` 141 | 142 | Continue to point 3. 143 | 144 | #### Run Cluster 145 | 146 | 3. Transfer weights and the tokenizer file to the root computer. 147 | 4. Run worker nodes on worker computers: 148 | ```sh 149 | ./dllama worker --port 9999 --nthreads 4 150 | ``` 151 | 5. Run root node on the root computer: 152 | ```sh 153 | ./dllama inference --model dllama_model_meta-llama-3-8b_q40.m --tokenizer dllama_tokenizer_llama3.t --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 192.168.0.1:9999 154 | ``` 155 | 156 | To add more worker nodes, just add more addresses to the `--workers` argument. 157 | 158 | ``` 159 | ./dllama inference ... --workers 192.168.0.1:9999 192.168.0.2:9999 192.168.0.3:9999 160 | ``` 161 | 162 |
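The root node can also be started as an API server instead of the one-shot `inference` command. The sketch below is one way to try it with the bundled example client (`examples/chat-api-client.js`), assuming the root computer's IP address is 192.168.0.100 (replace it with your own); `HOST` and `PORT` are the environment variables read by that script.

```sh
# On the root computer: serve the OpenAI-compatible API on port 9999.
./dllama-api --model dllama_model_meta-llama-3-8b_q40.m --tokenizer dllama_tokenizer_llama3.t --buffer-float-type q80 --port 9999 --nthreads 4 --workers 192.168.0.1:9999

# On any machine with Node.js: send two sample chat requests.
HOST=192.168.0.100 PORT=9999 node examples/chat-api-client.js
```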
163 | 164 |
165 | 166 | 📟 Raspberry Pi 167 | 168 |
169 | 170 | 1. Install `Raspberry Pi OS Lite (64 bit)` on your Raspberry Pi devices. This OS doesn't have a desktop environment. 171 | 2. Connect all devices to your switch or router. 172 | 3. Connect to all devices via SSH. 173 | ``` 174 | ssh user@raspberrypi1.local 175 | ssh user@raspberrypi2.local 176 | ``` 177 | 4. Install Git: 178 | ```sh 179 | sudo apt install git 180 | ``` 181 | 5. Clone this repository and compile Distributed Llama on all devices: 182 | ```sh 183 | git clone https://github.com/b4rtaz/distributed-llama.git 184 | cd distributed-llama 185 | make dllama 186 | make dllama-api 187 | ``` 188 | 6. Transfer weights and the tokenizer file to the root device. 189 | 7. Optional: assign static IP addresses. 190 | ```sh 191 | sudo ip addr add 10.0.0.1/24 dev eth0 # 1st device 192 | sudo ip addr add 10.0.0.2/24 dev eth0 # 2nd device 193 | ``` 194 | 8. Run worker nodes on worker devices: 195 | ```sh 196 | sudo nice -n -20 ./dllama worker --port 9999 --nthreads 4 197 | ``` 198 | 9. Run root node on the root device: 199 | ```sh 200 | sudo nice -n -20 ./dllama inference --model dllama_model_meta-llama-3-8b_q40.m --tokenizer dllama_tokenizer_llama3.t --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 10.0.0.2:9999 201 | ``` 202 | 203 | To add more worker nodes, just add more addresses to the `--workers` argument. 204 | 205 | ``` 206 | ./dllama inference ... --workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999 207 | ``` 208 | 209 |
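If a device runs out of memory, the `--max-seq-len` argument (see the argument tables above) reduces RAM usage at the cost of a shorter context. Below is a sketch of the same root-node command with a smaller limit; the value 1024 is only an example.

```sh
# Same root-node command as above, but with a capped context length to save RAM.
sudo nice -n -20 ./dllama inference --model dllama_model_meta-llama-3-8b_q40.m --tokenizer dllama_tokenizer_llama3.t --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --max-seq-len 1024 --workers 10.0.0.2:9999
```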
210 | 211 | ## ✋ Contribution 212 | 213 | Feel free to contribute to this project. For small changes, simply create a new merge request. For larger changes, please create an issue to discuss your plans. Please follow these guidelines when contributing: 214 | 215 | * Make only minimal changes and avoid modifying files that are not necessary. 216 | * Ensure the code is compatible across all supported systems and CPUs. 217 | * This repository is maintained in English. 218 | 219 | ## 💡 License 220 | 221 | This project is released under the MIT license. 222 | 223 | ## 📖 Citation 224 | 225 | ``` 226 | @misc{dllama, 227 | author = {Bartłomiej Tadych}, 228 | title = {Distributed Llama}, 229 | year = {2024}, 230 | publisher = {GitHub}, 231 | journal = {GitHub repository}, 232 | howpublished = {\url{https://github.com/b4rtaz/distributed-llama}}, 233 | commit = {7eb77ca93ec0d502e28d36b6fb20039b449cbea4} 234 | } 235 | ``` 236 | -------------------------------------------------------------------------------- /converter/.gitignore: -------------------------------------------------------------------------------- 1 | *.t 2 | *.m 3 | *.bin 4 | */ 5 | -------------------------------------------------------------------------------- /converter/convert-hf.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import sys 4 | import os 5 | from writer import parseFloatType, writeTensor, writeHeader, FloatType 6 | from safetensors import safe_open 7 | 8 | class ArchType: 9 | LLAMA = 0xABCD00 10 | 11 | def permute(tensor, nHeads: int, nKvHeads: int): 12 | if nHeads != nKvHeads: 13 | nHeads = nKvHeads 14 | return (tensor.reshape(nHeads, 2, tensor.shape[0] // nHeads // 2, *tensor.shape[1:]).swapaxes(1, 2).reshape(tensor.shape)) 15 | 16 | class Processor: 17 | def __init__(self, config): 18 | self.config = config 19 | self.currentModelIndex = None 20 | self.currentModel = None 21 | self.currentModelKeys = None 22 | self.layerMap = {} 23 | self.plan = [] 24 | 25 | def __unloadModel(self): 26 | if self.currentModel: 27 | del self.currentModel 28 | self.currentModel = None 29 | gc.collect() 30 | 31 | def __loadModel(self, index: int): 32 | if (self.currentModelIndex == index): 33 | return 34 | self.__unloadModel() 35 | filePath = self.config['files'][index] 36 | fileName = os.path.basename(filePath) 37 | print(f'💿 Loading file {fileName}...') 38 | self.currentModel = safe_open(filePath, framework='pt', device='cpu') 39 | self.currentModelKeys = list(self.currentModel.keys()) 40 | for key in self.currentModelKeys: 41 | self.layerMap[key] = index 42 | print(f'Found {len(self.currentModelKeys)} layers') 43 | self.currentModelIndex = index 44 | 45 | def __permuteQ(self, tensor): 46 | return permute(tensor, self.config['n_heads'], self.config['n_heads']) 47 | 48 | def __permuteK(self, tensor): 49 | return permute(tensor, self.config['n_heads'], self.config['n_kv_heads']) 50 | 51 | def __preparePlan(self): 52 | wt = self.config['weights_float_type'] 53 | p = self.plan 54 | p.append([FloatType.F32, 55 | 'model.embed_tokens.weight']) 56 | for l in range(0, self.config['n_layers']): 57 | p.append([wt, self.__permuteQ, 58 | f'model.layers.{l}.self_attn.q_proj.weight']) 59 | p.append([wt, self.__permuteK, 60 | f'model.layers.{l}.self_attn.k_proj.weight']) 61 | p.append([wt, 62 | f'model.layers.{l}.self_attn.v_proj.weight']) 63 | p.append([wt, 64 | f'model.layers.{l}.self_attn.o_proj.weight']) 65 | 66 | if (self.config['n_experts'] > 0): 67 | for e in 
range(self.config['n_experts']): 68 | p.append([wt, 69 | f'model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight']) # up 70 | p.append([wt, 71 | f'model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight']) # gate 72 | p.append([wt, 73 | f'model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight']) # down 74 | else: 75 | p.append([wt, 76 | f'model.layers.{l}.mlp.gate_proj.weight']) # gate 77 | p.append([wt, 78 | f'model.layers.{l}.mlp.down_proj.weight']) # down 79 | p.append([wt, 80 | f'model.layers.{l}.mlp.up_proj.weight']) # up 81 | 82 | p.append([FloatType.F32, 83 | f'model.layers.{l}.input_layernorm.weight']) 84 | p.append([FloatType.F32, 85 | f'model.layers.{l}.post_attention_layernorm.weight']) 86 | p.append([FloatType.F32, 87 | 'model.norm.weight']) 88 | p.append([wt, 89 | 'lm_head.weight', 'model.embed_tokens.weight']) 90 | 91 | def write(self, outputFile: str): 92 | self.__preparePlan() 93 | for planItem in self.plan: 94 | lookup = planItem[1:] 95 | transform = None 96 | if (callable(lookup[0])): 97 | transform = lookup[0] 98 | lookup = lookup[1:] 99 | 100 | if (self.currentModelIndex == None): 101 | modelIndex = 0 102 | else: 103 | modelIndex = None 104 | for layerName in lookup: 105 | if (layerName in self.layerMap): 106 | modelIndex = self.layerMap[layerName] 107 | break 108 | if (modelIndex is None): 109 | modelIndex = self.currentModelIndex + 1 110 | self.__loadModel(modelIndex) 111 | 112 | tensor = None 113 | for layerName in lookup: 114 | if (layerName in self.currentModelKeys): 115 | tensor = self.currentModel.get_tensor(layerName) 116 | break 117 | if tensor is None: 118 | raise Exception(f'Layer {lookup[0]} not found') 119 | print(f'🔶 Writing tensor {layerName} {tensor.shape}...') 120 | 121 | floatType = planItem[0] 122 | if (transform): 123 | tensor = transform(tensor) 124 | writeTensor(outputFile, tensor, floatType) 125 | 126 | def parseArchType(type: str): 127 | archType = { 128 | 'llama': ArchType.LLAMA, 129 | 'mistral': ArchType.LLAMA, 130 | }.get(type) 131 | if (archType is None): 132 | raise Exception(f'Unsupported arch type: {type}') 133 | return archType 134 | 135 | def parseHiddenAct(act: str): 136 | hiddenAct = { 137 | 'gelu': 0, 138 | 'silu': 1 139 | }.get(act) 140 | if (hiddenAct is None): 141 | raise Exception(f'Unsupported hidden act: {act}') 142 | return hiddenAct 143 | 144 | def parseRopeType(rt: str): 145 | ropeType = { 146 | 'llama3': 2, # LLAMA3_1 147 | }.get(rt) 148 | if (ropeType is None): 149 | raise Exception(f'Unsupported rope type: {ropeType}') 150 | return ropeType 151 | 152 | def loadConfig(folderPath: str, weightsFloatType: int): 153 | allFiles = os.listdir(folderPath) 154 | allFiles.sort() 155 | with open(os.path.join(folderPath, 'config.json')) as fc: 156 | config = json.load(fc) 157 | files = [] 158 | for fileName in allFiles: 159 | if fileName.endswith('.safetensors') and not fileName.startswith('.'): 160 | files.append(os.path.join(folderPath, fileName)) 161 | if (len(files) == 0): 162 | raise Exception('Not found any model file') 163 | 164 | result = { 165 | 'version': 0, 166 | 'arch_type': parseArchType(config['model_type']), 167 | 'hidden_act': parseHiddenAct(config['hidden_act']), 168 | 'dim': config['hidden_size'], 169 | 'hidden_dim': config['intermediate_size'], 170 | 'n_layers': config['num_hidden_layers'], 171 | 'n_heads': config['num_attention_heads'], 172 | 'n_kv_heads': config['num_key_value_heads'], 173 | 'weights_float_type': weightsFloatType, 174 | 'max_seq_len': config['max_position_embeddings'], 175 | 'vocab_size': 
config['vocab_size'], 176 | 'files': files, 177 | } 178 | 179 | nExperts = config.get('num_local_experts') 180 | nActiveExperts = config.get('num_active_local_experts') or config.get('num_experts_per_tok') 181 | result['n_experts'] = int(nExperts) if nExperts is not None else 0 182 | result['n_active_experts'] = int(nActiveExperts) if nActiveExperts is not None else 0 183 | 184 | ropeTheta = config.get('rope_theta') 185 | if (ropeTheta is not None): 186 | result['rope_theta'] = int(ropeTheta) 187 | 188 | ropeScaling = config.get('rope_scaling') 189 | if (ropeScaling is not None): 190 | result['rope_scaling_factor'] = int(ropeScaling['factor']) 191 | result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor']) 192 | result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor']) 193 | result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings']) 194 | result['rope_type'] = parseRopeType(ropeScaling['rope_type']) 195 | return result 196 | 197 | def printUsage(): 198 | print('Usage: python convert-hf.py ') 199 | print() 200 | print('Options:') 201 | print(' The path to the folder containing the model files') 202 | print(' The float type of the weights (e.g. "q40")') 203 | print(' The name of the model (e.g. "llama3")') 204 | 205 | if __name__ == '__main__': 206 | if (len(sys.argv) < 4): 207 | printUsage() 208 | exit(1) 209 | 210 | sourceFolderPath = sys.argv[1] 211 | weightsFloatType = parseFloatType(sys.argv[2]) 212 | name = sys.argv[3] 213 | outputFileName = f'dllama_model_{name}_{sys.argv[2]}.m' 214 | 215 | print(f'Output file: {outputFileName}') 216 | 217 | config = loadConfig(sourceFolderPath, weightsFloatType) 218 | 219 | with open(outputFileName, 'wb') as outputFile: 220 | writeHeader(outputFile, config) 221 | processor = Processor(config) 222 | processor.write(outputFile) 223 | 224 | print(f'✅ {outputFileName} created successfully') -------------------------------------------------------------------------------- /converter/convert-llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import math 6 | import numpy as np 7 | from writer import writeTensor, writeHeader, parseFloatType, strFloatType, FloatType 8 | from pathlib import Path 9 | 10 | LAYER_CHUNK_SIZE = 48 11 | 12 | def convert(modelPath, outputPath, targetFloatType): 13 | paramsPath = os.path.join(modelPath, 'params.json') 14 | with open(paramsPath) as f: 15 | params = json.load(f) 16 | if (params['vocab_size'] < 1): 17 | raise Exception('vocab_size is invalid, please update params.json file') 18 | if (params.get('max_seq_len') is None): 19 | raise Exception('max_seq_len is required, please update params.json file') 20 | params['n_kv_heads'] = params.get('n_kv_heads') or params['n_heads'] 21 | params['head_size'] = params['dim'] / params['n_heads'] 22 | params['arch_type'] = 0xABCD00 23 | params['n_experts'] = 0 24 | params['n_active_experts'] = 0 25 | params['weights_float_type'] = targetFloatType 26 | if ('rope_theta' in params): 27 | params['rope_theta'] = int(params['rope_theta']) 28 | 29 | modelPaths = sorted(list(Path(modelPath).glob('consolidated.*.pth'))) 30 | nSlices = len(modelPaths) 31 | 32 | layers = [] 33 | layers.append('tok_embeddings.weight') 34 | for layerIndex in range(0, params['n_layers']): 35 | layers.append(f'layers.{layerIndex}.attention.wq.weight') 36 | layers.append(f'layers.{layerIndex}.attention.wk.weight') 37 | 
layers.append(f'layers.{layerIndex}.attention.wv.weight') 38 | layers.append(f'layers.{layerIndex}.attention.wo.weight') 39 | layers.append(f'layers.{layerIndex}.feed_forward.w1.weight') 40 | layers.append(f'layers.{layerIndex}.feed_forward.w2.weight') 41 | layers.append(f'layers.{layerIndex}.feed_forward.w3.weight') 42 | layers.append(f'layers.{layerIndex}.attention_norm.weight') 43 | layers.append(f'layers.{layerIndex}.ffn_norm.weight') 44 | layers.append('norm.weight') 45 | layers.append('output.weight') 46 | 47 | isHeaderWrote = False 48 | outFile = open(outputPath, 'wb') 49 | 50 | nChunks = math.ceil(len(layers) / LAYER_CHUNK_SIZE) 51 | for chunkIndex in range(0, nChunks): 52 | chunkLayerNames = layers[LAYER_CHUNK_SIZE * chunkIndex:LAYER_CHUNK_SIZE * (chunkIndex + 1)] 53 | models = {} 54 | for layerName in chunkLayerNames: 55 | models[layerName] = [] 56 | 57 | print(f'💿 Chunking model {chunkIndex + 1}/{nChunks}...') 58 | 59 | for modelPath in modelPaths: 60 | model = torch.load(modelPath, map_location='cpu') 61 | for modelKey in model: 62 | if (modelKey in chunkLayerNames): 63 | models[modelKey].append(model[modelKey]) 64 | if not isHeaderWrote: 65 | params['hidden_dim'] = model['layers.0.feed_forward.w1.weight'].shape[0] * nSlices 66 | writeHeader(outFile, params) 67 | isHeaderWrote = True 68 | del model 69 | 70 | for layerName in chunkLayerNames: 71 | if layerName == 'rope.freqs': 72 | continue 73 | 74 | isAxis1 = ( 75 | layerName == 'tok_embeddings.weight' or 76 | layerName.endswith('.attention.wo.weight') or 77 | layerName.endswith('.feed_forward.w2.weight') 78 | ) 79 | isAlwaysF32 = ( 80 | layerName == 'tok_embeddings.weight' or 81 | layerName.endswith('.attention_norm.weight') or 82 | layerName.endswith('.ffn_norm.weight') or 83 | layerName == 'norm.weight' 84 | ) 85 | floatType = FloatType.F32 if isAlwaysF32 else targetFloatType 86 | 87 | tensors = models[layerName] 88 | if len(tensors) == 1 or len(tensors[0].shape) == 1: 89 | tensor = tensors[0] 90 | else: 91 | tensor = torch.cat(tensors, dim=(1 if isAxis1 else 0)) 92 | 93 | print(f'🔶 Exporting {layerName} {tensor.shape}...') 94 | writeTensor(outFile, tensor, floatType) 95 | 96 | del models 97 | 98 | outFile.close() 99 | 100 | def usage(): 101 | print('Usage: python convert-llama.py ') 102 | exit(1) 103 | 104 | if __name__ == '__main__': 105 | if (len(sys.argv) < 3): 106 | usage() 107 | 108 | modelPath = sys.argv[1] 109 | targetFloatType = parseFloatType(sys.argv[2]) 110 | targetFloatTypeStr = strFloatType(targetFloatType) 111 | 112 | modelName = os.path.basename(modelPath) 113 | outputFileName = f'dllama_model_{modelName.lower()}_{targetFloatTypeStr}.m' 114 | 115 | print(f'Model name: {modelName}') 116 | print(f'Target float type: {targetFloatTypeStr}') 117 | print(f'Target file: {outputFileName}') 118 | 119 | convert(modelPath, outputFileName, targetFloatType) 120 | 121 | print('Done!') 122 | -------------------------------------------------------------------------------- /converter/convert-tokenizer-hf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import os 4 | from sentencepiece import SentencePieceProcessor 5 | from transformers import PreTrainedTokenizerFast 6 | writer = __import__('tokenizer-writer') 7 | 8 | def openJson(path): 9 | with open(path, 'r', encoding='utf-8') as file: 10 | return json.load(file) 11 | 12 | def unicodeToBytes(): 13 | # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9 14 | bs 
= list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 15 | cs = bs[:] 16 | n = 0 17 | for b in range(2 ** 8): 18 | if b not in bs: 19 | bs.append(b) 20 | cs.append(2 ** 8 + n) 21 | n += 1 22 | cs = [chr(n) for n in cs] 23 | return dict(zip(cs, bs)) 24 | 25 | class TokensResolver: 26 | def __init__(self, dirPath, tokenizerConfig): 27 | self.dirPath = dirPath 28 | self.tokenizerConfig = tokenizerConfig 29 | self.bosId = None 30 | self.eosIds = None 31 | self.tokens = [] 32 | self.scores = [] 33 | 34 | def resolvePreTrainedTokenizerFast(self): 35 | utb = unicodeToBytes() 36 | tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(self.dirPath, 'tokenizer.json')) 37 | vocabLen = len(tokenizer.get_vocab()) 38 | for i in range(vocabLen): 39 | tokenChars = list(tokenizer.convert_ids_to_tokens([i])[0]) 40 | tokenBytes = [] 41 | for chr in tokenChars: 42 | if (chr in utb): 43 | tokenBytes.append(utb[chr]) 44 | else: 45 | tokenBytes += list(chr.encode('utf-8')) 46 | self.tokens.append(bytes(tokenBytes)) 47 | self.scores.append(-float(i)) 48 | 49 | self.bosId = tokenizer.bos_token_id 50 | if (tokenizer.eos_token_id): 51 | self.eosIds = [tokenizer.eos_token_id] 52 | if (self.bosId is None or self.eosId is None): 53 | config = openJson(os.path.join(self.dirPath, 'config.json')) 54 | if (self.bosId is None): 55 | self.bosId = config['bos_token_id'] 56 | if (self.eosIds is None): 57 | self.eosIds = config['eos_token_id'] 58 | if isinstance(self.eosIds, list): 59 | self.eosIds = self.eosIds[:2] # TODO: add support more than 2 eos ids 60 | else: 61 | self.eosIds = [self.eosIds] 62 | 63 | def resolveLlamaTokenizer(self): 64 | modelPath = os.path.join(self.dirPath, 'tokenizer.model') 65 | processor = SentencePieceProcessor(model_file=modelPath) 66 | 67 | assert processor.vocab_size() == processor.get_piece_size() 68 | self.bosId = processor.bos_id() 69 | self.eosIds = [processor.eos_id()] 70 | vocabSize = processor.vocab_size() 71 | for i in range(vocabSize): 72 | t = processor.id_to_piece(i) 73 | s = processor.get_score(i) 74 | t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace 75 | # Check for byte characters 76 | if len(t) == 6 and t.startswith('<0x') and t.endswith('>'): 77 | # For example, "<0x0A>"" is a newline character 78 | b = bytearray.fromhex(t[3:-1]) 79 | else: 80 | b = t.encode('utf-8') 81 | self.tokens.append(b) 82 | self.scores.append(s) 83 | 84 | def resolve(self): 85 | cls = self.tokenizerConfig['tokenizer_class'] 86 | if (cls == 'PreTrainedTokenizerFast' or cls == 'LlamaTokenizerFast'): 87 | return self.resolvePreTrainedTokenizerFast() 88 | if (cls == 'LlamaTokenizer'): 89 | return self.resolveLlamaTokenizer() 90 | raise Exception(f'Tokenizer {cls} is not supported') 91 | 92 | 93 | def printUsage(): 94 | print('Usage: python convert-tokenizer-hf.py ') 95 | print() 96 | print('Options:') 97 | print(' The path to the folder with tokenizer_config.json') 98 | print(' The name of the tokenizer (e.g. 
"llama3")') 99 | 100 | if __name__ == '__main__': 101 | if (len(sys.argv) < 2): 102 | printUsage() 103 | exit(1) 104 | 105 | dirPath = sys.argv[1] 106 | name = sys.argv[2] 107 | tokenizerConfig = openJson(os.path.join(dirPath, 'tokenizer_config.json')) 108 | 109 | resolver = TokensResolver(dirPath, tokenizerConfig) 110 | resolver.resolve() 111 | 112 | if (resolver.bosId is None or resolver.eosIds is None): 113 | raise Exception('Cannot resolve bosId or eosIds') 114 | print(f'bosId: {resolver.bosId} ({resolver.tokens[resolver.bosId]})') 115 | for eosId in resolver.eosIds: 116 | print(f'eosId: {eosId} ({resolver.tokens[eosId]})') 117 | 118 | chatTemplate = None 119 | chatExtraStop = None 120 | if ('chat_template' in tokenizerConfig): 121 | chatTemplate = tokenizerConfig['chat_template'].encode('utf-8') 122 | input = input('⏩ Enter value for chat extra stop (enter to skip): ') 123 | if (input != ''): 124 | chatExtraStop = input.encode('utf-8') 125 | 126 | outputFileName = f'dllama_tokenizer_{name}.t' 127 | with open(outputFileName, 'wb') as outputFile: 128 | writer.writeTokenizer(outputFile, { 129 | 'bos_id': resolver.bosId, 130 | 'eos_id': resolver.eosIds[0], 131 | 'chat_eos_id': resolver.eosIds[1 if len(resolver.eosIds) > 1 else 0], 132 | }, resolver.tokens, resolver.scores, chatTemplate, chatExtraStop) 133 | print(f'✅ Created {outputFileName}') 134 | -------------------------------------------------------------------------------- /converter/convert-tokenizer-llama2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from sentencepiece import SentencePieceProcessor 4 | writer = __import__('tokenizer-writer') 5 | 6 | chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" 7 | 8 | def printUsage(): 9 | print('Usage: python convert-tokenizer-llama2.py ') 10 | print() 11 | print('Options:') 12 | print(' The path to the folder with llama2 folder path') 13 | 14 | if __name__ == '__main__': 15 | if (len(sys.argv) < 2): 16 | printUsage() 17 | exit(1) 18 | 19 | dirPath = sys.argv[1] 20 | modelPath = os.path.join(dirPath, 'tokenizer.model') 21 | processor = SentencePieceProcessor(model_file=modelPath) 22 | 23 | vocabSize = processor.vocab_size() 24 | tokens = [] 25 | scores = [] 26 | for i in range(vocabSize): 27 | t = processor.id_to_piece(i) 28 | s = processor.get_score(i) 29 | t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace 30 | b = t.encode('utf-8') 31 | tokens.append(b) 32 | scores.append(s) 33 | 34 | outputFileName = 'dllama_tokenizer_llama2.t' 35 | with open(outputFileName, 'wb') as outputFile: 36 | writer.writeTokenizer(outputFile, { 37 | 'bos_id': processor.bos_id(), 38 | 'eos_id': processor.eos_id(), 39 | 'chat_eos_id': processor.eos_id(), 40 | 
}, tokens, scores, chatTemplate.encode('utf-8'), None) 41 | 42 | print(f'✅ Created {outputFileName}') 43 | -------------------------------------------------------------------------------- /converter/convert-tokenizer-llama3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import base64 3 | writer = __import__('tokenizer-writer') 4 | 5 | # Format of input file: 6 | # ``` 7 | # IQ== 0 8 | # Ig== 1 9 | # Iw== 2 10 | # ... 11 | # ``` 12 | 13 | nSpecialTokens = 256 14 | specialTokens = [ 15 | '<|begin_of_text|>', 16 | '<|end_of_text|>', 17 | '<|reserved_special_token_0|>', 18 | '<|reserved_special_token_1|>', 19 | '<|reserved_special_token_2|>', 20 | '<|reserved_special_token_3|>', 21 | '<|start_header_id|>', 22 | '<|end_header_id|>', 23 | '<|reserved_special_token_4|>', 24 | '<|eot_id|>', 25 | ] + [ 26 | f'<|reserved_special_token_{i}|>' 27 | for i in range(5, nSpecialTokens - 5) 28 | ] 29 | bosId = 128000 30 | eosId = 128001 31 | chatEosId = 128009 32 | chatTemplate = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" 33 | 34 | def printUsage(): 35 | print('Usage: python convert-tokenizer-llama3.py ') 36 | print() 37 | print('Options:') 38 | print(' The path to the Llama 3 tokenizer model (tokenizer.model)') 39 | 40 | if __name__ == '__main__': 41 | if (len(sys.argv) < 2): 42 | printUsage() 43 | exit(1) 44 | 45 | modelPath = sys.argv[1] 46 | outputFileName = 'dllama_tokenizer_llama3.t' 47 | 48 | with open(modelPath, 'r') as inputFile: 49 | with open(outputFileName, 'wb') as outputFile: 50 | inputLines = inputFile.readlines() 51 | nLines = len(inputLines) 52 | 53 | tokens = [] 54 | scores = [] 55 | for line in inputLines: 56 | s = line.split(' ') 57 | bytes = base64.b64decode(s[0]) 58 | score = -float(s[1]) 59 | tokens.append(bytes) 60 | scores.append(score) 61 | 62 | specialTokenIndex = nLines 63 | for token in specialTokens: 64 | bytes = token.encode('utf-8') 65 | score = -float(specialTokenIndex) 66 | tokens.append(bytes) 67 | scores.append(score) 68 | specialTokenIndex += 1 69 | 70 | writer.writeTokenizer(outputFile, { 71 | 'bos_id': bosId, 72 | 'eos_id': eosId, 73 | 'chat_eos_id': chatEosId, 74 | }, tokens, scores, chatTemplate.encode('utf-8'), None) 75 | 76 | print(f'✅ Created {outputFileName}') 77 | -------------------------------------------------------------------------------- /converter/requirements.txt: -------------------------------------------------------------------------------- 1 | python>=3.9 2 | numpy==1.23.5 3 | pytorch==2.0.1 4 | safetensors==0.4.2 5 | sentencepiece==0.1.99 -------------------------------------------------------------------------------- /converter/tokenizer-writer.py: -------------------------------------------------------------------------------- 1 | import struct 2 | 3 | def writeTokenizer(file, params, tokens, scores, chatTemplate, chatExtraStop): 4 | assert(params['eos_id'] is not None) 5 | assert(params['bos_id'] is not None) 6 | 7 | headerKeys = { 8 | 'version': 0, 9 | 'vocab_size': 1, 10 | 'max_token_length': 2, 11 | 'bos_id': 3, 12 | 'eos_id': 4, 13 | 'pad_id': 5, 14 | 'chat_eos_id': 6, 15 | 'chat_template': 7, 16 | 'chat_stop': 8 17 | } 18 | header = 
struct.pack('i', 0x567124) 19 | 20 | nTokens = len(tokens) 21 | maxTokenLength = max(len(t) for t in tokens) 22 | 23 | params['version'] = 1 24 | params['vocab_size'] = nTokens 25 | params['max_token_length'] = maxTokenLength 26 | if (chatTemplate): 27 | params['chat_template'] = len(chatTemplate) 28 | if (chatExtraStop): 29 | params['chat_stop'] = len(chatExtraStop) 30 | 31 | data = b'' 32 | for key in params: 33 | value = params[key] 34 | if value is None: 35 | continue 36 | if key in headerKeys: 37 | data += struct.pack('ii', headerKeys[key], params[key]) 38 | else: 39 | print(f'Unknown header key: {key}') 40 | 41 | print('⭐ Params:') 42 | print(params) 43 | if (chatTemplate): 44 | print('⭐ Chat template:') 45 | print(chatTemplate) 46 | 47 | header += struct.pack('i', len(header) * 2 + len(data)) 48 | file.write(header) 49 | file.write(data) 50 | if chatTemplate: 51 | file.write(chatTemplate) 52 | if chatExtraStop: 53 | file.write(chatExtraStop) 54 | 55 | for i in range(0, nTokens): 56 | size = len(tokens[i]) 57 | assert(size > 0) 58 | file.write(struct.pack('fI', scores[i], size)) 59 | file.write(tokens[i]) 60 | -------------------------------------------------------------------------------- /converter/writer-test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import torch 4 | from writer import writeQuantizedQ40Tensor 5 | 6 | TEMP_FILE_NAME = 'writer-test.temp' 7 | 8 | def readBase64FromFile(path): 9 | with open(path, 'rb') as file: 10 | return file.read().hex() 11 | 12 | def testWriteQuantizedQ40Tensor(): 13 | EXPECTED_OUTPUT = '7e346345a692b89665b2c5790537876e598aaa366d988876a898b8d788a98868ce660c66f6b3a88cba5ce9a871987ba9cc5bcaaa760c1eb556a4455b747b6b9504968828ef2a8d7c1db5c6be3764799e66db6d8e76463126a30e4333cad7a4f645947c6cf97f9de086d468c8d535a6ba7dc799d3d0c657bab6799468cad8bb349eb7d7635c7c798998696bb38e4085a9eb34444ba96a7f8ba7b2b42d746a96cf9660aeb4499d8708ad5c7b9a7558947645f3bbb6b0346a656887ad9a86059baac5c596ab781c703569bb8a4356a4bd58cb78736ba09759bb0e34a6274e827b957d7a67dfa86846955660d234b6d9d78a378094a8a8708a7a774ae92f8a36b8c999a9b77a7d958a69747c807963941235379886d69a7a8767b3a6a4ac71999760' 14 | 15 | torch.manual_seed(seed=1) 16 | tensor = torch.randn(32, 16) 17 | 18 | with open(TEMP_FILE_NAME, 'wb') as file: 19 | writeQuantizedQ40Tensor(file, tensor) 20 | 21 | contentBase64 = readBase64FromFile(TEMP_FILE_NAME) 22 | assert contentBase64 == EXPECTED_OUTPUT, f'Received: {contentBase64}' 23 | print('✅ writeQuantizedQ40Tensor') 24 | 25 | def runWriteQuantizedQ40TensorBenchmark(): 26 | tensor = torch.randn(8192, 4096) 27 | t0 = time.time() 28 | with open(TEMP_FILE_NAME, 'wb') as file: 29 | writeQuantizedQ40Tensor(file, tensor) 30 | t1 = time.time() 31 | print(f'🕐 writeQuantizedQ40Tensor: {t1 - t0:.4f}s') 32 | 33 | if __name__ == '__main__': 34 | testWriteQuantizedQ40Tensor() 35 | runWriteQuantizedQ40TensorBenchmark() 36 | -------------------------------------------------------------------------------- /converter/writer.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import torch 3 | import time 4 | import numpy as np 5 | 6 | class FloatType: 7 | F32 = 0 8 | F16 = 1 9 | Q40 = 2 10 | Q80 = 3 11 | 12 | floatTypeMap = { 13 | 'f32': FloatType.F32, 14 | 'f16': FloatType.F16, 15 | 'q40': FloatType.Q40, 16 | 'q80': FloatType.Q80, 17 | } 18 | floatTypeNames = list(floatTypeMap.keys()) 19 | 20 | def parseFloatType(type): 21 | floatType = floatTypeMap.get(type) 22 
| if floatType is not None: 23 | return floatType 24 | raise Exception(f'{type} is not supported') 25 | 26 | def strFloatType(type): 27 | return floatTypeNames[type] 28 | 29 | def writeQuantizedQ40Tensor(file, x): 30 | x = x.to(torch.float32).numpy().astype(np.float32) 31 | blockSize = 32 32 | blockHalfSize = blockSize // 2 33 | assert(x.shape[0] % blockSize == 0) 34 | groups = x.reshape(-1, blockSize) 35 | gmax = np.max(groups, axis=1) 36 | gmin = np.min(groups, axis=1) 37 | deltas = np.divide(np.where(-gmin > gmax, gmin, gmax), -8) 38 | deltas16 = deltas.astype(np.float16) 39 | ids = np.where(deltas != 0, 1.0 / deltas, 0) 40 | groups = np.add(groups * ids[:, np.newaxis], 8.5) 41 | groups = np.clip(groups, 0, 15).astype(int) 42 | 43 | gLow = groups[:, :blockHalfSize] & 0xF 44 | gHigh = (groups[:, blockHalfSize:] & 0xF) << 4 45 | gCombined = gLow | gHigh 46 | 47 | nBytes = 0 48 | for groupIndex in range(0, len(groups)): 49 | delta16 = deltas16[groupIndex] 50 | buffer = struct.pack(f'e{blockHalfSize}B', delta16, *gCombined[groupIndex]) 51 | file.write(buffer) 52 | nBytes += len(buffer) 53 | return nBytes 54 | 55 | def writeQuantizedQ80Tensor(file, x): 56 | x = x.to(torch.float32).numpy().astype(np.float32) 57 | blockSize = 32 58 | assert(x.shape[0] % blockSize == 0) 59 | groups = x.reshape(-1, blockSize) 60 | gmax = np.max(groups, axis=1) 61 | gmin = np.min(groups, axis=1) 62 | gabsMax = np.where(-gmin > gmax, -gmin, gmax) 63 | deltas = gabsMax / ((1 << 7) - 1) 64 | deltas16 = deltas.astype(np.float16) 65 | ids = np.where(deltas != 0, 1.0 / deltas, 0) 66 | groups = groups * ids[:, np.newaxis] 67 | groups8 = np.round(groups).astype(np.int8) 68 | 69 | nBytes = 0 70 | for groupIndex in range(0, len(groups)): 71 | buffer = struct.pack(f'e{blockSize}b', deltas16[groupIndex], *groups8[groupIndex]) 72 | file.write(buffer) 73 | nBytes += len(buffer) 74 | return nBytes 75 | 76 | def writeF32Tensor(file, d): 77 | chunkSize = 10000 78 | nBytes = 0 79 | for i in range(0, len(d), chunkSize): 80 | chunk = d[i:i+chunkSize].to(torch.float32).numpy().astype(np.float32) 81 | b = struct.pack(f'{len(chunk)}f', *chunk) 82 | nBytes += len(b) 83 | file.write(b) 84 | return nBytes 85 | 86 | def writeF16Tensor(file, d): 87 | d = d.to(torch.float16).numpy().astype(np.float16) 88 | b = struct.pack(f'{len(d)}e', *d) 89 | file.write(b) 90 | return len(b) 91 | 92 | def writeTensor(file, tensor, floatType): 93 | d = tensor.detach().cpu().view(-1) 94 | t0 = time.time() 95 | nBytes = 0 96 | if (floatType == FloatType.F16): 97 | nBytes = writeF16Tensor(file, d) 98 | elif (floatType == FloatType.F32): 99 | nBytes = writeF32Tensor(file, d) 100 | elif (floatType == FloatType.Q40): 101 | nBytes = writeQuantizedQ40Tensor(file, d) 102 | elif (floatType == FloatType.Q80): 103 | nBytes = writeQuantizedQ80Tensor(file, d) 104 | else: 105 | raise Exception(f'Unknown float type') 106 | t1 = time.time() 107 | print(f'Saved {strFloatType(floatType)} tensor in {t1 - t0:.2f}s, {nBytes} bytes') 108 | 109 | def writeHeader(file, params): 110 | headerKeys = { 111 | 'version': 0, 112 | 'arch_type': 1, 113 | 'dim': 2, 114 | 'hidden_dim': 3, 115 | 'n_layers': 4, 116 | 'n_heads': 5, 117 | 'n_kv_heads': 6, 118 | 'n_experts': 7, 119 | 'n_active_experts': 8, 120 | 'vocab_size': 9, 121 | 'max_seq_len': 10, 122 | 'hidden_act': 11, 123 | 'rope_theta': 12, 124 | 'weights_float_type': 13, 125 | 'rope_scaling_factor': 14, 126 | 'rope_scaling_low_freq_factor': 15, 127 | 'rope_scaling_high_freq_factory': 16, 128 | 'rope_scaling_orig_max_seq_len': 17, 129 | 
'rope_type': 18, 130 | } 131 | header = struct.pack('i', 0xA00ABCD) 132 | 133 | data = b'' 134 | for key in params: 135 | if key in headerKeys: 136 | data += struct.pack('ii', headerKeys[key], params[key]) 137 | else: 138 | print(f'Warning: Unknown header key: {key}') 139 | 140 | header += struct.pack('i', len(header) * 2 + len(data)) 141 | file.write(header) 142 | file.write(data) 143 | for key in params: 144 | print(f'🎓 {key}: {params[key]}') 145 | print() 146 | -------------------------------------------------------------------------------- /docs/HUGGINGFACE.md: -------------------------------------------------------------------------------- 1 | # How to Run Hugging Face 🤗 Model 2 | 3 | Currently, Distributed Llama supports three types of Hugging Face models: `llama`, `mistral`, and `mixtral`. You can try to convert any compatible Hugging Face model and run it with Distributed Llama. 4 | 5 | > [!IMPORTANT] 6 | > All converters are in the early stages of development. After conversion, the model may not work correctly. 7 | 8 | 1. Download a model, for example: [Mistral-7B-v0.3](https://huggingface.co/mistralai/Mistral-7B-v0.3/tree/main). 9 | 2. The downloaded model should contain `config.json`, `tokenizer.json`, `tokenizer_config.json` and `tokenizer.model` and safetensor files. 10 | 3. Run the converter of the model: 11 | ```sh 12 | cd converter 13 | python convert-hf.py path/to/hf/model q40 mistral-7b-0.3 14 | ``` 15 | 4. Run the converter of the tokenizer: 16 | ```sh 17 | python convert-tokenizer-hf.py path/to/hf/model mistral-7b-0.3 18 | ``` 19 | 5. That's it! Now you can run the Distributed Llama. 20 | ``` 21 | ./dllama inference --model dllama_model_mistral-7b-0.3_q40.m --tokenizer dllama_tokenizer_mistral-7b-0.3.t --buffer-float-type q80 --prompt "Hello world" 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/LLAMA.md: -------------------------------------------------------------------------------- 1 | # How to Run Llama 2 | 3 | ## How to Run Llama 2 4 | 5 | 1. Download [Llama 2](https://github.com/facebookresearch/llama) weights from Meta. This project supports 7B, 7B-chat, 13B, 13B-chat, 70B and 70B-chat models. 6 | 2. Open the `llama-2-7b/params.json` file: 7 | * replace `"vocab_size": -1` to `"vocab_size": 32000`, 8 | * add a new property: `"max_seq_len": 2048`. 9 | 3. Install dependencies of the converter: 10 | ```sh 11 | cd converter && pip install -r requirements.txt 12 | ``` 13 | 4. Convert weights to Distributed Llama format. This will take a bit of time. The script requires Python 3. 14 | ```sh 15 | python convert-llama.py /path/to/meta/llama-2-7b q40 16 | ``` 17 | 5. Download the tokenizer for Llama 2: 18 | ``` 19 | wget https://huggingface.co/b4rtaz/Llama-2-Tokenizer-Distributed-Llama/resolve/main/dllama_tokenizer_llama2.t 20 | ``` 21 | 6. Build the project: 22 | ```bash 23 | make dllama 24 | make dllama-api 25 | ``` 26 | 7. Run: 27 | ```bash 28 | ./dllama inference --model dllama_llama-2-7b_q40.bin --tokenizer dllama-llama2-tokenizer.t --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 29 | ``` 30 | 31 | In the table below, you can find the expected size of the converted weights with different floating-point types. 
32 | 33 | | Model | Original size | Float32 | Float16 | Q40 | 34 | |-------------|---------------|----------|----------|----------| 35 | | Llama 2 7B | 13.48 GB | 25.10GB | | 3.95 GB | 36 | | Llama 2 13B | 26.03 GB | | | 7.35 GB | 37 | | Llama 2 70B | 137.97 GB | | | 36.98 GB | 38 | 39 | ## How to Run Llama 3 40 | 41 | 1. Get an access to the model on [Llama 3 website](https://llama.meta.com/llama-downloads). 42 | 2. Clone the `https://github.com/meta-llama/llama3` repository. 43 | 3. Run the `download.sh` script to download the model. 44 | 4. For Llama 3 8B model you should have the following files: 45 | - `Meta-Llama-3-8B/consolidated.00.pth` 46 | - `Meta-Llama-3-8B/params.json` 47 | - `Meta-Llama-3-8B/tokenizer.model` 48 | 5. Open `params.json` and add a new property: `"max_seq_len": 8192`. 49 | 6. Clone the `https://github.com/b4rtaz/distributed-llama.git` repository. 50 | 7. Install dependencies of the converter: 51 | ```sh 52 | cd converter && pip install -r requirements.txt 53 | ``` 54 | 8. Convert the model to the Distributed Llama format: 55 | ```bash 56 | python converter/convert-llama.py path/to/Meta-Llama-3-8B q40 57 | ``` 58 | 9. Convert the tokenizer to the Distributed Llama format: 59 | ```bash 60 | python converter/convert-tokenizer-llama3.py path/to/tokenizer.model 61 | ``` 62 | 10. Build the project: 63 | ```bash 64 | make dllama 65 | make dllama-api 66 | ``` 67 | 11. Run the Distributed Llama: 68 | ```bash 69 | ./dllama inference --weights-float-type q40 --buffer-float-type q80 --prompt "My name is" --steps 128 --nthreads 8 --model dllama_meta-llama-3-8b_q40.bin --tokenizer llama3-tokenizer.t 70 | ``` 71 | -------------------------------------------------------------------------------- /examples/chat-api-client.js: -------------------------------------------------------------------------------- 1 | // This is a simple client for dllama-api. 2 | // 3 | // Usage: 4 | // 5 | // 1. Start the server, how to do it is described in the `src/apps/dllama-api/README.md` file. 6 | // 2. Run this script: `node examples/chat-api-client.js` 7 | 8 | const HOST = process.env.HOST ? process.env.HOST : '127.0.0.1'; 9 | const PORT = process.env.PORT ? Number(process.env.PORT) : 9999; 10 | 11 | async function chat(messages, maxTokens) { 12 | const response = await fetch(`http://${HOST}:${PORT}/v1/chat/completions`, { 13 | method: 'POST', 14 | headers: { 15 | 'Content-Type': 'application/json', 16 | }, 17 | body: JSON.stringify({ 18 | messages, 19 | temperature: 0.7, 20 | stop: ['<|eot_id|>'], 21 | max_tokens: maxTokens 22 | }), 23 | }); 24 | return await response.json(); 25 | } 26 | 27 | async function ask(system, user, maxTokens) { 28 | console.log(`> system: ${system}`); 29 | console.log(`> user: ${user}`); 30 | const response = await chat([ 31 | { 32 | role: 'system', 33 | content: system 34 | }, 35 | { 36 | role: 'user', 37 | content: user 38 | } 39 | ], maxTokens); 40 | console.log(response.usage); 41 | console.log(response.choices[0].message.content); 42 | } 43 | 44 | async function main() { 45 | await ask('You are an excellent math teacher.', 'What is 1 + 2?', 128); 46 | await ask('You are a romantic.', 'Where is Europe?', 128); 47 | } 48 | 49 | main(); 50 | -------------------------------------------------------------------------------- /examples/macbeth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This is a simple test of generating a sequence that fulfills the KV cache. 
4 | # 5 | # Used model & tokenizer: https://huggingface.co/b4rtaz/llama-3-8b-distributed-llama 6 | # Probably, this test will be working correctly only on MacBook Pro M1, due to differences in float multiplication on different CPUs. 7 | 8 | cd "$(dirname "$0")" 9 | cd .. 10 | 11 | # Source: https://www.opensourceshakespeare.org/views/plays/play_view.php?WorkID=macbeth&Scope=entire 12 | PROMPT="Duncan. What bloody man is that? He can report, 13 | As seemeth by his plight, of the revolt 14 | The newest state. 20 15 | 16 | Malcolm. This is the sergeant 17 | Who like a good and hardy soldier fought 18 | 'Gainst my captivity. Hail, brave friend! 19 | Say to the king the knowledge of the broil 20 | As thou didst leave it. 25 21 | 22 | Sergeant. Doubtful it stood; 23 | As two spent swimmers, that do cling together 24 | And choke their art. The merciless Macdonwald— 25 | Worthy to be a rebel, for to that 26 | The multiplying villanies of nature 30 27 | Do swarm upon him—from the western isles 28 | Of kerns and gallowglasses is supplied; 29 | And fortune, on his damned quarrel smiling, 30 | Show'd like a rebel's whore: but all's too weak: 31 | For brave Macbeth—well he deserves that name— 35 32 | Disdaining fortune, with his brandish'd steel, 33 | Which smoked with bloody execution, 34 | Like valour's minion carved out his passage 35 | Till he faced the slave; 36 | Which ne'er shook hands, nor bade farewell to him, 40 37 | Till he unseam'd him from the nave to the chaps, 38 | And fix'd his head upon our battlements. 39 | 40 | Duncan. O valiant cousin! worthy gentleman! 41 | 42 | Sergeant. As whence the sun 'gins his reflection 43 | Shipwrecking storms and direful thunders break, 45 44 | So from that spring whence comfort seem'd to come 45 | Discomfort swells. Mark, king of Scotland, mark: 46 | No sooner justice had with valour arm'd 47 | Compell'd these skipping kerns to trust their heels, 48 | But the Norweyan lord surveying vantage, 50 49 | With furbish'd arms and new supplies of men 50 | Began a fresh assault. 51 | 52 | Duncan. Dismay'd not this 53 | Our captains, Macbeth and Banquo? 54 | 55 | Sergeant. Yes; 55 56 | As sparrows eagles, or the hare the lion. 57 | If I say sooth, I must report they were 58 | As cannons overcharged with double cracks, so they 59 | Doubly redoubled strokes upon the foe: 60 | Except they meant to bathe in reeking wounds, 60 61 | Or memorise another Golgotha, 62 | I cannot tell. 63 | But I am faint, my gashes cry for help. 64 | 65 | Duncan. So well thy words become thee as thy wounds; 66 | They smack of honour both. Go get him surgeons. 65 67 | [Exit Sergeant, attended] 68 | Who comes here?" 69 | 70 | GENERATED="Malcolm. The worthy Thane of Ross. 71 | Duncan. What a haste looks through a duel's wounds! 70 72 | Some must be pac'd. 73 | [Exit Ross] 74 | See this encounter is like to the poring 75 | On of a beggar's story, told by one 76 | That means to pluck upon the heart the strings 77 | And draw the tears thriftily. 75 78 | [Enter Lennox] 79 | How goes the night, boy? 80 | 81 | Lennox. The night is long that none should wake. 82 | 83 | Duncan. You do not need to stare. The Moor 84 | To know the man. 'Tis the Moors devices. 80 85 | [Exit Lennox] 86 | By the happy right of mine own hands, 87 | Strike all that live in this poor thing of mine. 88 | 'Tis calld the Eyrie, and I am sick at heart. 
89 | As hellish-devils do the damned souls 90 | O'their bad lives, thus ill-breveted, linger 91 | O'er lamps and forks and other instruments 92 | That prove the stages of the night. 90 93 | Good sir, take note; I bid you farewell: 94 | Come sleep, and cut short this nitty romance. 95 | [He sleeps.] 96 | If cravens, I bear them like the Minion of the moon, 97 | With tiptoe foot he sneaks and starts to be a man. 95 98 | And when he is found asleep, awake him with this armed' s address: 99 | That sleep which th'assassin hallowed, 100 | Scotland, awake; your king is murder'd, sleep no more. 100 101 | *Furbish'd. Weapons polished for battle. 102 | *Thriftily. Fastidiously, thoughtfully. 103 | *Eyrie. Fortress; the lair of birds of prey. 104 | *Minion. A braggart, a coward. 105 | 106 | 1.5 107 | 108 | Macbeth. So foul and fair a day I have not seen. 5 109 | Ross. Good morning, noble Macbeth. I come from Inverness, 110 | And find our throne void, the arm'd rest you; 10 111 | My Lord of Cassil has resigned his life. 112 | Macbeth. Whate'er you owe, in time repay, fair friends. 113 | Note you the words; I pray you do. 114 | Ross. I am your faithful servant, and will keep 115 | My sworn reward upon your life; my lord. 116 | Macbeth. You shall be well rewarded; stay the press, 20 117 | And I'll not fail. How now, good fellow? 118 | Servant. Sir, his schoolmaster. 25 119 | Macbeth. Well, good, though, old. 120 | Tell me, good fellow, how goes the night? 30 121 | Servant. There's marrygold and fire in your veins, my lord. 122 | Macbeth. He does commend you; the weight of this old night's embargoes 35 123 | Did one hour's waste of time lay upon him. 124 | I know when we are too safe, 'tis dangerous to be secure; 125 | Therefore our fearful parts do brave the danger 40 126 | Which knows it not. I see you are a gentleman. 127 | And a laudable one too; I am most off obliged. 128 | Servant. I should be sorry, my good lord, to have had the labour 45 129 | To outlive this damned hour. 50 130 | Macbeth. What's done cannot be undone. To bed, to bed, to bed. 131 | Servant. Will it please you to lie still? 55 132 | Macbeth. Lord, lord, my heart is in my mouth. All's true that ends well. 133 | Servant. I thank you, fair, and leave you to the content. 60 134 | Macbeth. You see, my lord, it smokes, and shows no cause 135 | Why the drone dies. 65 136 | Servant. Grief fills the room up of one vast stair, 137 | And downs our vaults to the inconstant man above. 70 138 | Macbeth. Go bid thy masters and thy mistress say, 75 139 | I have power in earth to do so much. 140 | There's comfort yet. They are assailable. Then say I, 141 | Thus ye may answer. 142 | Servant. He cannot be wronged; or being wronged, 80 143 | I cannot help him. 85 144 | Macbeth. You know but by this; as this, 90 145 | The Jew foole is hang'd. 95 146 | Servant. No more today, my lord. 100 147 | Macbeth. He does shame to tell him he loves him, but not remove him 105 148 | From his true place; no. 149 | Servant. That's true, and now I remember the story 110 150 | Of that sign in Leo four diurnal courses 151 | Returning in a constant motion were within 115 152 | A boare that had on Taurus' back tetracted; 120 153 | Or neuer, or but once in modulated accidence. 125 154 | Macbeth. Thou climd'st alone, ty'd to the stag's horn. 155 | Servant. I was a bull, for this the goodly year. 130 156 | Come, put me in my place. 157 | Macbeth. Now go to sleep. 135 158 | Servant. 
The west neuer sett before the equinox 140 159 | Till now; and sunnes look'd not theyr frequencie 145 160 | Upon our lappe till now, my lord. 150 161 | Macbeth. This game of chance you term a gong. 162 | Servant. A gong is a scotch word for an egg. 155 163 | Macbeth. Peace, be still. 160 164 | Servant. I coniecture I smell the blood of an Englishman. 165 165 | Macbeth. The faith is murthered. 166 | Servant. That murder'd in his sleep. 170 167 | Macbeth. And sleeping murdered. 175 168 | Servant. In the fair queen heere in his royal court. 180 169 | Macbeth. So great a mercy that it may last eternally. 170 | Servant. The earth hath bubbles as the water hath, 185 171 | And these are of them. Whate'er we will do 190 172 | To mend the trespasses of the comming time 195 173 | Shall be the seedes of new mischefe, and shall beget 200 174 | The formes of the extinctnese, which we are now. 205 175 | Macbeth. We have scorch'd the snake, not kill'd it. 210 176 | Servant. They hunt it in the morn. Good gally, good lord! 215 177 | It weares a gilded snout. 220 178 | Macbeth. It is the very painting of your fear. 225 179 | Servant. This is the worst. 230 180 | Macbeth. A fair quater of a mile is yet to go. 235 181 | Servant. A mile and half. 240 182 | Macbeth. I have run fifteen miles to-day. 183 | Servant. A calender's date. 184 | Macbeth. A bigger patch, a bigger patch. 245 185 | Servant. Thirteen of more. 250 186 | Macbeth. Wast thou with him? 255 187 | Servant. No, nor he to night. 260 188 | Macbeth. Thou seest the moon" 189 | 190 | echo "Generating, it can take a while..." 191 | 192 | OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 2 --steps 2048 --model models/llama3_8b_q40/dllama_model_llama3_8b_q40.m --tokenizer models/llama3_8b_q40/dllama_tokenizer_llama3_8b_q40.t --workers 127.0.0.1:9999 127.0.0.1:9998 127.0.0.1:9997 ) 2>&1) 193 | 194 | echo "$OUTPUT" 195 | 196 | if [[ $OUTPUT == *"$GENERATED"* ]]; then 197 | echo "✅ Output is same" 198 | else 199 | echo "❌ Output is different" 200 | fi 201 | -------------------------------------------------------------------------------- /examples/n-workers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script starts N workers from a single command. Mainly useful for testing and debugging. 4 | # Usage: 5 | # 6 | # W=7 T=2 bash n-workers.sh start 7 | # W=7 bash n-workers.sh stop 8 | # 9 | # Env vars: 10 | # W - n workers 11 | # T - n threads per worker 12 | 13 | cd "$(dirname "$0")" 14 | 15 | if [ -z "$W" ]; then 16 | W=3 17 | fi 18 | if [ -z "$T" ]; then 19 | T=1 20 | fi 21 | 22 | if [ "$1" == "start" ]; then 23 | for (( w = 0; w < $W ; w += 1 )); 24 | do 25 | PORT=$(expr 9999 - $w) 26 | PROC_ID=$(lsof -ti:$PORT) 27 | if [ -n "$PROC_ID" ]; then 28 | kill -9 $PROC_ID 29 | echo "Killed process $PROC_ID" 30 | fi 31 | 32 | mkdir -p dllama_worker_$w # macOs does not support -Logfile argument, so we place logs inside different directories 33 | cd dllama_worker_$w 34 | screen -d -L -S dllama_worker_$w -m ../../dllama worker --port $PORT --nthreads $T 35 | cd .. 
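# In the screen command above: -d -m starts the worker in a detached session, -S names the session dllama_worker_$w so the stop branch below can close it with `screen -S dllama_worker_$w -X quit`, and -L logs output to a screenlog file inside the per-worker directory created above.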
36 | echo "Started worker $w on port $PORT" 37 | done 38 | 39 | sleep 2 40 | elif [ "$1" == "stop" ]; then 41 | for (( w = 0; w < $W ; w += 1 )); 42 | do 43 | screen -S dllama_worker_$w -X quit 44 | done 45 | 46 | echo "Stopped $W workers" 47 | else 48 | echo "Usage: $0 [start|stop]" 49 | fi 50 | 51 | echo "> screen -ls" 52 | screen -ls 53 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import multiprocessing 5 | from urllib.request import urlopen 6 | 7 | def parts(length): 8 | result = [] 9 | for i in range(length): 10 | a = chr(97 + (i // 26)) 11 | b = chr(97 + (i % 26)) 12 | result.append(a + b) 13 | return result 14 | 15 | # [['model-url-0', 'model-url-1', ...], 'tokenizer-url', 'weights-float-type', 'buffer-float-type', 'model-type'] 16 | MODELS = { 17 | 'llama3_1_8b_instruct_q40': [ 18 | ['https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true'], 19 | 'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true', 20 | 'q40', 'q80', 'chat', '--max-seq-len 4096' 21 | ], 22 | 'llama3_1_405b_instruct_q40': [ 23 | list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama31_405b_q40_{suffix}?download=true', parts(56))), 24 | 'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true', 25 | 'q40', 'q80', 'chat', '--max-seq-len 4096' 26 | ], 27 | 'llama3_2_1b_instruct_q40': [ 28 | ['https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-1b-instruct_q40.m?download=true'], 29 | 'https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true', 30 | 'q40', 'q80', 'chat', '--max-seq-len 4096' 31 | ], 32 | 'llama3_2_3b_instruct_q40': [ 33 | ['https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-3b-instruct_q40.m?download=true'], 34 | 'https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true', 35 | 'q40', 'q80', 'chat', '--max-seq-len 4096' 36 | ], 37 | 'llama3_3_70b_instruct_q40': [ 38 | list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama-3.3-70b_q40{suffix}?download=true', parts(11))), 39 | 'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama-3.3-70b.t?download=true', 40 | 'q40', 'q80', 'chat', '--max-seq-len 4096' 41 | ], 42 | 'deepseek_r1_distill_llama_8b_q40': [ 43 | ['https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_model_deepseek-r1-distill-llama-8b_q40.m?download=true'], 44 | 'https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_tokenizer_deepseek-r1-distill-llama-8b.t?download=true', 45 | 'q40', 'q80', 'chat', '--max-seq-len 4096' 46 | ], 47 | } 48 | 49 | def confirm(message: str): 50 | result = input(f'❓ {message} ("Y" if yes): ').upper() 51 | return result == 'Y' or result == 'YES' 52 | 53 | def downloadFile(urls, path: str): 54 | if 
os.path.isfile(path): 55 | fileName = os.path.basename(path) 56 | if not confirm(f'{fileName} already exists, do you want to download again?'): 57 | return 58 | 59 | lastSizeMb = 0 60 | with open(path, 'wb') as file: 61 | for url in urls: 62 | startPosition = file.tell() 63 | success = False 64 | for attempt in range(8): 65 | print(f'📄 {url} (attempt: {attempt})') 66 | try: 67 | with urlopen(url) as response: 68 | while True: 69 | chunk = response.read(4096) 70 | if not chunk: 71 | break 72 | file.write(chunk) 73 | sizeMb = file.tell() // (1024 * 1024) 74 | if sizeMb != lastSizeMb: 75 | sys.stdout.write("\rDownloaded %i MB" % sizeMb) 76 | lastSizeMb = sizeMb 77 | sys.stdout.write('\n') 78 | success = True 79 | break 80 | except Exception as e: 81 | print(f'\n❌ Error downloading {url}: {e}') 82 | file.seek(startPosition) 83 | file.truncate() 84 | time.sleep(1 * attempt) 85 | if not success: 86 | raise Exception(f'Failed to download {url}') 87 | sys.stdout.write(' ✅\n') 88 | 89 | def download(modelName: str, model: list): 90 | dirPath = os.path.join('models', modelName) 91 | print(f'📀 Downloading {modelName} to {dirPath}...') 92 | os.makedirs(dirPath, exist_ok=True) 93 | modelUrls = model[0] 94 | tokenizerUrl = model[1] 95 | modelPath = os.path.join(dirPath, f'dllama_model_{modelName}.m') 96 | tokenizerPath = os.path.join(dirPath, f'dllama_tokenizer_{modelName}.t') 97 | downloadFile(modelUrls, modelPath) 98 | downloadFile([tokenizerUrl], tokenizerPath) 99 | print('📀 All files are downloaded') 100 | return (modelPath, tokenizerPath) 101 | 102 | def writeRunFile(modelName: str, command: str): 103 | filePath = f'run_{modelName}.sh' 104 | with open(filePath, 'w') as file: 105 | file.write('#!/bin/sh\n') 106 | file.write('\n') 107 | file.write(f'{command}\n') 108 | return filePath 109 | 110 | def printUsage(): 111 | print('Usage: python download-model.py ') 112 | print() 113 | print('Options:') 114 | print(' The name of the model to download') 115 | print(' --run Run the model after download') 116 | print() 117 | print('Available models:') 118 | for model in MODELS: 119 | print(f' {model}') 120 | 121 | if __name__ == '__main__': 122 | if (len(sys.argv) < 2): 123 | printUsage() 124 | exit(1) 125 | 126 | os.chdir(os.path.dirname(__file__)) 127 | 128 | modelName = sys.argv[1].replace('-', '_') 129 | if modelName not in MODELS: 130 | print(f'Model is not supported: {modelName}') 131 | exit(1) 132 | runAfterDownload = sys.argv.count('--run') > 0 133 | 134 | model = MODELS[modelName] 135 | (modelPath, tokenizerPath) = download(modelName, model) 136 | 137 | nThreads = multiprocessing.cpu_count() 138 | if (model[4] == 'chat'): 139 | command = './dllama chat' 140 | else: 141 | command = './dllama inference --steps 64 --prompt "Hello world"' 142 | command += f' --model {modelPath} --tokenizer {tokenizerPath} --buffer-float-type {model[3]} --nthreads {nThreads}' 143 | if (len(model) > 5): 144 | command += f' {model[5]}' 145 | 146 | print('To run Distributed Llama you need to execute:') 147 | print('--- copy start ---') 148 | print() 149 | print('\033[96m' + command + '\033[0m') 150 | print() 151 | print('--- copy end -----') 152 | 153 | runFilePath = writeRunFile(modelName, command) 154 | print(f'🌻 Created {runFilePath} script to easy run') 155 | 156 | if (not runAfterDownload): 157 | runAfterDownload = confirm('Do you want to run Distributed Llama?') 158 | if (runAfterDownload): 159 | if (not os.path.isfile('dllama')): 160 | os.system('make dllama') 161 | os.system(command) 162 | 
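# Usage sketch (based on the MODELS dict and main() above; the model name must be one of the keys listed by printUsage): `python launch.py llama3_2_3b_instruct_q40 --run` downloads the model and tokenizer into models/llama3_2_3b_instruct_q40/, prints the inference/chat command, writes a run_llama3_2_3b_instruct_q40.sh helper script, and with --run builds the dllama binary if it is missing and starts it.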
-------------------------------------------------------------------------------- /report/report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/b4rtaz/distributed-llama/a16d2f03e66437088dce2ba4b82304a8101c074f/report/report.pdf -------------------------------------------------------------------------------- /src/api-types.hpp: -------------------------------------------------------------------------------- 1 | #ifndef API_TYPES_HPP 2 | #define API_TYPES_HPP 3 | 4 | #include 5 | 6 | #include "json.hpp" 7 | 8 | using json = nlohmann::json; 9 | 10 | struct ChatMessageDelta { 11 | std::string role; 12 | std::string content; 13 | 14 | ChatMessageDelta() : role(""), content("") {} 15 | ChatMessageDelta(const std::string& role_, const std::string& content_) : role(role_), content(content_) {} 16 | }; 17 | 18 | struct ChatMessage { 19 | std::string role; 20 | std::string content; 21 | 22 | ChatMessage() : role(""), content("") {} 23 | ChatMessage(const std::string& role_, const std::string& content_) : role(role_), content(content_) {} 24 | }; 25 | 26 | struct ChunkChoice { 27 | int index; 28 | ChatMessageDelta delta; 29 | std::string finish_reason; 30 | 31 | ChunkChoice() : index(0) {} 32 | }; 33 | 34 | 35 | struct Choice { 36 | int index; 37 | ChatMessage message; 38 | std::string finish_reason; 39 | 40 | Choice() : finish_reason("") {} 41 | Choice(ChatMessage &message_) : message(message_), finish_reason("") {} 42 | Choice(const std::string &reason_) : finish_reason(reason_) {} 43 | }; 44 | 45 | struct ChatCompletionChunk { 46 | std::string id; 47 | std::string object; 48 | long long created; 49 | std::string model; 50 | std::vector choices; 51 | 52 | ChatCompletionChunk(ChunkChoice &choice_) 53 | : id("cmpl-c0"), object("chat.completion"), model("Distributed Model") { 54 | created = std::time(nullptr); // Set created to current Unix timestamp 55 | choices.push_back(choice_); 56 | } 57 | }; 58 | 59 | // Struct to represent the usage object 60 | struct ChatUsage { 61 | int prompt_tokens; 62 | int completion_tokens; 63 | int total_tokens; 64 | 65 | ChatUsage() : prompt_tokens(0), completion_tokens(0), total_tokens(0) {} 66 | ChatUsage(int pt, int ct, int tt) : prompt_tokens(pt), completion_tokens(ct), total_tokens(tt) {} 67 | }; 68 | 69 | struct ChatCompletion { 70 | std::string id; 71 | std::string object; 72 | long long created; // Unix timestamp 73 | std::string model; 74 | std::vector choices; 75 | ChatUsage usage; 76 | 77 | ChatCompletion() : id(), object(), model() {} 78 | ChatCompletion(const Choice &choice_, const ChatUsage& usage_) 79 | : id("cmpl-j0"), object("chat.completion"), model("Distributed Model"), usage(usage_) { 80 | created = std::time(nullptr); // Set created to current Unix timestamp 81 | choices.push_back(choice_); 82 | } 83 | }; 84 | 85 | struct Model { 86 | std::string id; 87 | std::string object; 88 | long long created; 89 | std::string owned_by; 90 | 91 | Model() : id(), object(), created(0), owned_by() {} 92 | Model(const std::string &id_) : id(id_), object("model"), created(0), owned_by("user") {} 93 | }; 94 | 95 | struct ModelList { 96 | std::string object; 97 | std::vector data; 98 | ModelList(): object("list") {} 99 | ModelList(const Model &model_) : object("list") { 100 | data.push_back(model_); 101 | } 102 | }; 103 | 104 | struct InferenceParams { 105 | std::vector messages; 106 | int max_tokens; 107 | float temperature; 108 | float top_p; 109 | std::vector stop; 110 | bool stream; 111 | 
unsigned long long seed; 112 | }; 113 | 114 | // Define to_json for Delta struct 115 | void to_json(json& j, const ChatMessageDelta& msg) { 116 | j = json{{"role", msg.role}, {"content", msg.content}}; 117 | } 118 | 119 | void to_json(json& j, const ChatMessage& msg) { 120 | j = json{{"role", msg.role}, {"content", msg.content}}; 121 | } 122 | 123 | void to_json(json& j, const ChunkChoice& choice) { 124 | j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}}; 125 | } 126 | 127 | void to_json(json& j, const Choice& choice) { 128 | j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}}; 129 | } 130 | 131 | void to_json(json& j, const ChatCompletionChunk& completion) { 132 | j = json{{"id", completion.id}, 133 | {"object", completion.object}, 134 | {"created", completion.created}, 135 | {"model", completion.model}, 136 | {"choices", completion.choices}}; 137 | } 138 | 139 | void to_json(json& j, const ChatUsage& usage) { 140 | j = json{{"completion_tokens", usage.completion_tokens}, 141 | {"prompt_tokens", usage.prompt_tokens}, 142 | {"total_tokens", usage.total_tokens}}; 143 | } 144 | 145 | void to_json(json& j, const ChatCompletion& completion) { 146 | j = json{{"id", completion.id}, 147 | {"object", completion.object}, 148 | {"created", completion.created}, 149 | {"model", completion.model}, 150 | {"usage", completion.usage}, 151 | {"choices", completion.choices}}; 152 | } 153 | 154 | void to_json(json& j, const Model& model) { 155 | j = json{{"id", model.id}, 156 | {"object", model.object}, 157 | {"created", model.created}, 158 | {"owned_by", model.owned_by}}; 159 | } 160 | 161 | void to_json(json& j, const ModelList& models) { 162 | j = json{{"object", models.object}, 163 | {"data", models.data}}; 164 | } 165 | 166 | std::vector parseChatMessages(json &json){ 167 | std::vector messages; 168 | messages.reserve(json.size()); 169 | 170 | for (const auto& item : json) { 171 | messages.emplace_back( 172 | item["role"].template get(), 173 | item["content"].template get() 174 | ); 175 | } 176 | return messages; 177 | } 178 | 179 | #endif 180 | -------------------------------------------------------------------------------- /src/app.hpp: -------------------------------------------------------------------------------- 1 | #ifndef APP_HPP 2 | #define APP_HPP 3 | 4 | #include 5 | #include "nn/nn-core.hpp" 6 | #include "nn/nn-cpu.hpp" 7 | #include "tokenizer.hpp" 8 | #include "llm.hpp" 9 | 10 | class AppCliArgs { 11 | public: 12 | char *mode; 13 | NnUint nThreads; 14 | NnUint nBatches; 15 | bool help; 16 | 17 | // inference 18 | char *modelPath; 19 | char *tokenizerPath; 20 | char *prompt; 21 | NnFloatType syncType; 22 | NnUint nWorkers; 23 | char **workerHosts; 24 | NnUint *workerPorts; 25 | float temperature; 26 | float topp; 27 | NnUint steps; 28 | bool benchmark; 29 | unsigned long long seed; 30 | ChatTemplateType chatTemplateType; 31 | NnUint maxSeqLen; 32 | bool netTurbo; 33 | int gpuIndex; 34 | int gpuSegmentFrom; 35 | int gpuSegmentTo; 36 | 37 | // worker 38 | NnUint port; 39 | 40 | static AppCliArgs parse(int argc, char **argv, bool hasMode); 41 | ~AppCliArgs(); 42 | }; 43 | 44 | typedef struct { 45 | NnUint position; 46 | NnUint batchSize; // 0 = stop signal 47 | } LlmControlPacket; 48 | 49 | class RootLlmInference { 50 | public: 51 | float *logitsPipe; 52 | private: 53 | float *tokenPipe; 54 | float *positionPipe; 55 | LlmHeader *header; 56 | NnNetExecution *execution; 57 | NnExecutor *executor; 
58 | NnNetwork *network; 59 | LlmControlPacket controlPacket; 60 | public: 61 | RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network); 62 | void setBatchSize(NnUint batchSize); 63 | void setPosition(NnUint position); 64 | void setToken(NnUint batchIndex, NnUint token); 65 | void forward(); 66 | void finish(); 67 | }; 68 | 69 | class WorkerLlmInference { 70 | public: 71 | bool isFinished; 72 | private: 73 | float *positionPipe; 74 | NnNetExecution *execution; 75 | NnNetwork *network; 76 | LlmControlPacket controlPacket; 77 | public: 78 | WorkerLlmInference(NnNetExecution *execution, NnNetwork *network); 79 | bool tryReadControlPacket(); 80 | }; 81 | 82 | typedef struct { 83 | AppCliArgs *args; 84 | LlmHeader *header; 85 | RootLlmInference *inference; 86 | Tokenizer *tokenizer; 87 | Sampler *sampler; 88 | NnNetwork *network; 89 | NnExecutor *executor; 90 | } AppInferenceContext; 91 | 92 | void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context)); 93 | void runWorkerApp(AppCliArgs *args); 94 | 95 | #endif 96 | -------------------------------------------------------------------------------- /src/dllama.cpp: -------------------------------------------------------------------------------- 1 | #include "nn/nn-core.hpp" 2 | #include "nn/nn-config-builder.hpp" 3 | #include "nn/nn-cpu.hpp" 4 | #include "nn/nn-network.hpp" 5 | #include "nn/nn-executor.hpp" 6 | #include "llm.hpp" 7 | #include "tokenizer.hpp" 8 | #include "app.hpp" 9 | #include 10 | 11 | static void inference(AppInferenceContext *context) { 12 | if (context->args->prompt == nullptr) 13 | throw std::runtime_error("Prompt is required"); 14 | if (context->args->steps == 0) 15 | throw std::runtime_error("Number of steps is required"); 16 | 17 | std::vector inputTokensVec(std::strlen(context->args->prompt) + 3); 18 | int *inputTokens = inputTokensVec.data(); 19 | 20 | NnUint pos = 0; 21 | int token; 22 | int nInputTokens; 23 | context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, false); 24 | 25 | if (nInputTokens > context->header->seqLen) 26 | throw std::runtime_error("The number of prompt tokens is greater than the sequence length"); 27 | if (nInputTokens > context->args->steps) 28 | throw std::runtime_error("The number of prompt tokens is greater than the number of steps"); 29 | 30 | NnSize sentBytes = 0; 31 | NnSize recvBytes = 0; 32 | NnUint evalTotalTime = 0; 33 | NnUint predTotalTime = 0; 34 | 35 | printf("%s\n", context->args->prompt); 36 | for (;;) { 37 | long remainingTokens = nInputTokens - 1 - (long)pos; 38 | if (remainingTokens <= 0) 39 | break; 40 | NnUint batchSize = remainingTokens < context->args->nBatches 41 | ? 
remainingTokens 42 | : context->args->nBatches; 43 | 44 | context->inference->setBatchSize(batchSize); 45 | context->inference->setPosition(pos); 46 | for (NnUint i = 0; i < batchSize; i++) 47 | context->inference->setToken(i, inputTokens[pos + i]); 48 | 49 | context->inference->forward(); 50 | 51 | pos += batchSize; 52 | token = inputTokens[pos + 1]; 53 | 54 | if (context->network != nullptr) 55 | context->network->getStats(&sentBytes, &recvBytes); 56 | 57 | NnUint evalTime = context->executor->getTotalTime(STEP_EXECUTE_OP); 58 | NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES); 59 | printf("🔷️ Eval%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | (%d tokens)\n", 60 | evalTime / 1000, 61 | syncTime / 1000, 62 | sentBytes / 1024, 63 | recvBytes / 1024, 64 | batchSize); 65 | evalTotalTime += evalTime + syncTime; 66 | } 67 | 68 | fflush(stdout); 69 | 70 | context->inference->setBatchSize(1); 71 | context->tokenizer->resetDecoder(); 72 | 73 | const NnUint maxPos = std::min(context->header->seqLen, context->args->steps); 74 | for (; pos < maxPos; pos++) { 75 | context->inference->setPosition(pos); 76 | context->inference->setToken(0, token); 77 | context->inference->forward(); 78 | 79 | token = context->sampler->sample(context->inference->logitsPipe); 80 | 81 | char *piece = context->tokenizer->decode(token); 82 | 83 | if (context->network != nullptr) 84 | context->network->getStats(&sentBytes, &recvBytes); 85 | 86 | NnUint predTime = context->executor->getTotalTime(STEP_EXECUTE_OP); 87 | NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES); 88 | printf("🔶 Pred%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | %s\n", 89 | predTime / 1000, 90 | syncTime / 1000, 91 | sentBytes / 1024, 92 | recvBytes / 1024, 93 | piece == nullptr ? "~" : piece); 94 | fflush(stdout); 95 | predTotalTime += predTime + syncTime; 96 | } 97 | 98 | NnUint nEvalTokens = nInputTokens - 1; 99 | NnUint nPredTokens = pos - nEvalTokens; 100 | float evalTotalTimeMs = evalTotalTime / 1000.0; 101 | float predTotalTimeMs = predTotalTime / 1000.0; 102 | printf("\n"); 103 | printf("Evaluation\n"); 104 | printf(" nBatches: %d\n", context->args->nBatches); 105 | printf(" nTokens: %d\n", nEvalTokens); 106 | printf(" tokens/s: %3.2f (%3.2f ms/tok)\n", 107 | (nEvalTokens * 1000) / evalTotalTimeMs, 108 | evalTotalTimeMs / ((float) nEvalTokens)); 109 | printf("Prediction\n"); 110 | printf(" nTokens: %d\n", nPredTokens); 111 | printf(" tokens/s: %3.2f (%3.2f ms/tok)\n", 112 | (nPredTokens * 1000) / predTotalTimeMs, 113 | predTotalTimeMs / ((float) nPredTokens)); 114 | } 115 | 116 | static NnUint readStdin(const char *guide, char *buffer, NnUint size) { 117 | std::fflush(stdin); 118 | std::printf("%s", guide); 119 | if (std::fgets(buffer, size, stdin) != NULL) { 120 | NnUint length = std::strlen(buffer); 121 | if (length > 0 && buffer[length - 1] == '\n') { 122 | buffer[length - 1] = '\0'; 123 | length--; 124 | } 125 | return length; 126 | } 127 | return 0; 128 | } 129 | 130 | static void chat(AppInferenceContext *context) { 131 | const NnUint seqLen = context->header->seqLen; 132 | char prompt[2048]; 133 | 134 | TokenizerChatStops stops(context->tokenizer); 135 | ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]); 136 | EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength); 137 | 138 | const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, 
sizeof(prompt)); 139 | std::vector deltaItems; 140 | if (sysPromptLength > 0) 141 | deltaItems.push_back(ChatItem{"system", prompt}); 142 | 143 | NnUint pos = 0; 144 | NnUint userPromptLength; 145 | int token; 146 | int nInputTokens; 147 | do { 148 | do { 149 | userPromptLength = readStdin("\n👱 User\n> ", prompt, sizeof(prompt)); 150 | } while (userPromptLength == 0); 151 | 152 | deltaItems.push_back(ChatItem{"user", prompt}); 153 | 154 | GeneratedChat inputPrompt = templateGenerator.generate(deltaItems.size(), deltaItems.data(), true); 155 | std::unique_ptr inputTokensPtr(new int[inputPrompt.length + 2]); 156 | int *inputTokens = inputTokensPtr.get(); 157 | 158 | bool addBos = pos == 0; 159 | context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, addBos, true); 160 | 161 | NnUint userPromptEndPos = (NnUint)std::min(seqLen, pos + nInputTokens - 1); 162 | for (NnUint i = 0; ;) { 163 | int remainingTokens = userPromptEndPos - pos; 164 | if (remainingTokens <= 0) 165 | break; 166 | NnUint batchSize = remainingTokens < context->args->nBatches 167 | ? remainingTokens 168 | : context->args->nBatches; 169 | 170 | context->inference->setBatchSize(batchSize); 171 | context->inference->setPosition(pos); 172 | for (NnUint j = 0; j < batchSize; j++) 173 | context->inference->setToken(j, inputTokens[i + j]); 174 | 175 | context->inference->forward(); 176 | 177 | i += batchSize; 178 | pos += batchSize; 179 | token = inputTokens[i + 1]; 180 | } 181 | 182 | context->inference->setBatchSize(1); 183 | context->tokenizer->resetDecoder(); 184 | 185 | printf("\n🤖 Assistant\n"); 186 | if (inputPrompt.publicPrompt != nullptr) 187 | printf("%s", inputPrompt.publicPrompt); 188 | 189 | while (pos < seqLen) { 190 | context->inference->setPosition(pos); 191 | context->inference->setToken(0, token); 192 | context->inference->forward(); 193 | 194 | token = context->sampler->sample(context->inference->logitsPipe); 195 | 196 | char *piece = context->tokenizer->decode(token); 197 | EosDetectorType eosType = eosDetector.append(token, piece); 198 | if (eosType == NOT_EOS || eosType == EOS) { 199 | char *delta = eosDetector.getDelta(); 200 | if (delta != nullptr) { 201 | printf("%s", delta); 202 | fflush(stdout); 203 | } 204 | eosDetector.reset(); 205 | } 206 | pos++; 207 | if (eosType == EOS) break; 208 | } 209 | 210 | deltaItems.clear(); 211 | } while (pos < seqLen); 212 | 213 | printf("(end of context)\n"); 214 | } 215 | 216 | int main(int argc, char **argv) { 217 | initQuants(); 218 | initSockets(); 219 | 220 | int returnCode = EXIT_SUCCESS; 221 | try { 222 | AppCliArgs args = AppCliArgs::parse(argc, argv, true); 223 | if (std::strcmp(args.mode, "inference") == 0) { 224 | args.benchmark = true; 225 | runInferenceApp(&args, &inference); 226 | } else if (std::strcmp(args.mode, "chat") == 0) 227 | runInferenceApp(&args, &chat); 228 | else if (std::strcmp(args.mode, "worker") == 0) 229 | runWorkerApp(&args); 230 | else 231 | throw std::runtime_error("Unsupported mode"); 232 | } catch (std::exception &e) { 233 | printf("🚨 Critical error: %s\n", e.what()); 234 | returnCode = EXIT_FAILURE; 235 | } 236 | 237 | cleanupSockets(); 238 | return returnCode; 239 | } 240 | -------------------------------------------------------------------------------- /src/llm.hpp: -------------------------------------------------------------------------------- 1 | #ifndef LLM_HPP 2 | #define LLM_HPP 3 | 4 | #include "nn/nn-core.hpp" 5 | #include "nn/nn-executor.hpp" 6 | #include "nn/nn-network.hpp" 7 | 8 | enum 
LlmHeaderKey { 9 | VERSION = 0, 10 | ARCH_TYPE = 1, 11 | DIM = 2, 12 | HIDDEN_DIM = 3, 13 | N_LAYERS = 4, 14 | N_HEADS = 5, 15 | N_KV_HEADS = 6, 16 | N_EXPERTS = 7, 17 | N_ACTIVE_EXPERTS = 8, 18 | VOCAB_SIZE = 9, 19 | SEQ_LEN = 10, 20 | HIDDEN_ACT = 11, 21 | ROPE_THETA = 12, 22 | WEIGHT_FLOAT_TYPE = 13, 23 | ROPE_SCALING_FACTOR = 14, 24 | ROPE_SCALING_LOW_FREQ_FACTOR = 15, 25 | ROPE_SCALING_HIGH_FREQ_FACTORY = 16, 26 | ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17, 27 | ROPE_TYPE = 18, 28 | }; 29 | 30 | enum LlmHiddenAct { 31 | HIDDEN_ACT_GELU, 32 | HIDDEN_ACT_SILU, 33 | }; 34 | 35 | enum LlmArchType { 36 | LLAMA = 0xABCD00, 37 | }; 38 | 39 | typedef struct { 40 | NnSize headerSize; 41 | NnSize fileSize; 42 | int version; 43 | LlmArchType archType; 44 | NnUint dim; 45 | NnUint nLayers; 46 | NnUint nHeads; 47 | NnUint headSize; 48 | NnUint nKvHeads; 49 | NnUint nExperts; 50 | NnUint nActiveExperts; 51 | NnUint origSeqLen; // Original model context length 52 | NnUint seqLen; // Limited context length by the `--max-seq-len` argument 53 | NnUint hiddenDim; 54 | LlmHiddenAct hiddenAct; 55 | NnUint kvDim; 56 | NnUint vocabSize; 57 | float ropeTheta; 58 | NnRopeType ropeType; 59 | float ropeScalingFactor; 60 | float ropeScalingLowFreqFactor; 61 | float ropeScalingHighFreqFactory; 62 | NnUint ropeScalingOrigMaxSeqLen; 63 | float normEpsilon; 64 | 65 | NnFloatType weightType; 66 | NnFloatType syncType; 67 | } LlmHeader; 68 | 69 | typedef struct { 70 | LlmHeader *header; 71 | NnNetConfig netConfig; 72 | NnNodeConfig *nodeConfigs; 73 | NnRowMatmulSlice qSlice; 74 | NnRowMatmulSlice kSlice; 75 | NnRowMatmulSlice vSlice; 76 | NnColMatmulSlice woSlice; 77 | NnRowMatmulSlice w1Slice; 78 | NnColMatmulSlice w2Slice; 79 | NnRowMatmulSlice w3Slice; 80 | NnRowMatmulSlice wclsSlice; 81 | NnUint positionPipeIndex; 82 | NnUint tokenPipeIndex; 83 | NnUint xPipeIndex; 84 | NnUint logitsPipeIndex; 85 | NnSize2D tokenEmbeddingSize; 86 | NnSize2D rmsNormSize; 87 | } LlmNet; 88 | 89 | LlmHeader loadLlmHeader(const char* path, const unsigned int maxSeqLen, NnFloatType syncType); 90 | void printLlmHeader(LlmHeader *header); 91 | LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches); 92 | void releaseLlmNet(LlmNet *net); 93 | void loadLlmNetWeight(const char* path, LlmNet *net, NnRootWeightLoader *loader); 94 | 95 | #endif -------------------------------------------------------------------------------- /src/mmap.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MMAP_HPP 2 | #define MMAP_HPP 3 | 4 | #include 5 | #include 6 | #ifdef _WIN32 7 | #include 8 | #else 9 | #include 10 | #include 11 | #include 12 | #endif 13 | 14 | struct MmapFile { 15 | void* data; 16 | size_t size; 17 | #ifdef _WIN32 18 | HANDLE hFile; 19 | HANDLE hMapping; 20 | #else 21 | int fd; 22 | #endif 23 | }; 24 | 25 | long seekToEnd(FILE* file) { 26 | #ifdef _WIN32 27 | _fseeki64(file, 0, SEEK_END); 28 | return _ftelli64(file); 29 | #else 30 | fseek(file, 0, SEEK_END); 31 | return ftell(file); 32 | #endif 33 | } 34 | 35 | void openMmapFile(MmapFile *file, const char *path, size_t size) { 36 | file->size = size; 37 | #ifdef _WIN32 38 | file->hFile = CreateFileA(path, GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); 39 | if (file->hFile == INVALID_HANDLE_VALUE) { 40 | printf("Cannot open file %s\n", path); 41 | exit(EXIT_FAILURE); 42 | } 43 | 44 | file->hMapping = CreateFileMappingA(file->hFile, NULL, PAGE_READONLY, 0, 0, NULL); 45 | if (file->hMapping == NULL) { 46 | 
printf("CreateFileMappingA failed, error: %lu\n", GetLastError()); 47 | CloseHandle(file->hFile); 48 | exit(EXIT_FAILURE); 49 | } 50 | 51 | file->data = (void *)MapViewOfFile(file->hMapping, FILE_MAP_READ, 0, 0, 0); 52 | if (file->data == NULL) { 53 | printf("MapViewOfFile failed!\n"); 54 | CloseHandle(file->hMapping); 55 | CloseHandle(file->hFile); 56 | exit(EXIT_FAILURE); 57 | } 58 | #else 59 | file->fd = open(path, O_RDONLY); 60 | if (file->fd == -1) { 61 | throw std::runtime_error("Cannot open file"); 62 | } 63 | 64 | file->data = mmap(NULL, size, PROT_READ, MAP_PRIVATE, file->fd, 0); 65 | if (file->data == MAP_FAILED) { 66 | close(file->fd); 67 | throw std::runtime_error("Mmap failed"); 68 | } 69 | #endif 70 | } 71 | 72 | void closeMmapFile(MmapFile *file) { 73 | #ifdef _WIN32 74 | UnmapViewOfFile(file->data); 75 | CloseHandle(file->hMapping); 76 | CloseHandle(file->hFile); 77 | #else 78 | munmap(file->data, file->size); 79 | close(file->fd); 80 | #endif 81 | } 82 | 83 | #endif -------------------------------------------------------------------------------- /src/nn/llamafile/sgemm.hpp: -------------------------------------------------------------------------------- 1 | #ifndef LLAMAFILE_SGEMM_H 2 | #define LLAMAFILE_SGEMM_H 3 | 4 | #include 5 | 6 | bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, 7 | int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype); 8 | 9 | #endif 10 | -------------------------------------------------------------------------------- /src/nn/nn-config-builder.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_CONFIG_BUILDER_H 2 | #define NN_CONFIG_BUILDER_H 3 | 4 | #include "nn-core.hpp" 5 | #include 6 | #include 7 | 8 | static char *cloneString(const char *str) { 9 | NnUint len = std::strlen(str); 10 | char *copy = new char[len + 1]; 11 | std::memcpy(copy, str, len + 1); 12 | return copy; 13 | } 14 | 15 | class NnNetConfigBuilder { 16 | public: 17 | NnUint nNodes; 18 | NnUint nBatches; 19 | std::list pipes; 20 | std::list preSyncs; 21 | 22 | NnNetConfigBuilder(NnUint nNodes, NnUint nBatches) { 23 | this->nNodes = nNodes; 24 | this->nBatches = nBatches; 25 | } 26 | 27 | NnUint addPipe(const char *name, NnSize2D size) { 28 | NnUint pipeIndex = pipes.size(); 29 | pipes.push_back({ cloneString(name), size }); 30 | return pipeIndex; 31 | } 32 | 33 | void addPreSync(NnUint pipeIndex) { 34 | preSyncs.push_back({ pipeIndex }); 35 | } 36 | 37 | NnNetConfig build() { 38 | NnNetConfig config; 39 | config.nNodes = nNodes; 40 | config.nBatches = nBatches; 41 | config.nPipes = pipes.size(); 42 | config.pipes = new NnPipeConfig[config.nPipes]; 43 | std::copy(pipes.begin(), pipes.end(), config.pipes); 44 | config.nPreSyncs = preSyncs.size(); 45 | if (config.nPreSyncs > 0) { 46 | config.preSyncs = new NnPreSyncConfig[config.nPreSyncs]; 47 | std::copy(preSyncs.begin(), preSyncs.end(), config.preSyncs); 48 | } else { 49 | config.preSyncs = nullptr; 50 | } 51 | return config; 52 | } 53 | }; 54 | 55 | class NnNodeConfigBuilder { 56 | public: 57 | NnUint nodeIndex; 58 | std::list buffers; 59 | std::list segments; 60 | 61 | NnNodeConfigBuilder(NnUint nodeIndex) { 62 | this->nodeIndex = nodeIndex; 63 | } 64 | 65 | NnUint addBuffer(const char *name, NnSize2D size) { 66 | NnUint bufferIndex = buffers.size(); 67 | buffers.push_back({ cloneString(name), size }); 68 | return bufferIndex; 69 | } 70 | 71 | void addSegment(NnSegmentConfig segment) { 72 | 
segments.push_back(segment); 73 | } 74 | 75 | NnNodeConfig build() { 76 | NnNodeConfig config; 77 | config.nodeIndex = nodeIndex; 78 | config.nBuffers = buffers.size(); 79 | if (config.nBuffers > 0) { 80 | config.buffers = new NnBufferConfig[config.nBuffers]; 81 | std::copy(buffers.begin(), buffers.end(), config.buffers); 82 | } else { 83 | config.buffers = nullptr; 84 | } 85 | 86 | config.nSegments = segments.size(); 87 | assert(config.nSegments > 0); 88 | config.segments = new NnSegmentConfig[config.nSegments]; 89 | std::copy(segments.begin(), segments.end(), config.segments); 90 | return config; 91 | } 92 | }; 93 | 94 | class NnSegmentConfigBuilder { 95 | private: 96 | std::list ops; 97 | std::list syncs; 98 | 99 | public: 100 | template 101 | void addOp(NnOpCode code, const char *name, NnUint index, NnPointerConfig input, NnPointerConfig output, NnSize2D weightSize, T config) { 102 | NnUint configSize = sizeof(T); 103 | NnByte *configCopy = new NnByte[configSize]; 104 | std::memcpy(configCopy, &config, configSize); 105 | ops.push_back({ 106 | code, 107 | cloneString(name), 108 | index, 109 | input, 110 | output, 111 | weightSize, 112 | configCopy, 113 | configSize 114 | }); 115 | }; 116 | 117 | void addSync(NnUint pipeIndex, NnSyncType syncType) { 118 | syncs.push_back({ pipeIndex, syncType }); 119 | } 120 | 121 | NnSegmentConfig build() { 122 | NnSegmentConfig segment; 123 | segment.nOps = ops.size(); 124 | if (segment.nOps > 0) { 125 | segment.ops = new NnOpConfig[segment.nOps]; 126 | std::copy(ops.begin(), ops.end(), segment.ops); 127 | } 128 | segment.nSyncs = syncs.size(); 129 | if (segment.nSyncs > 0) { 130 | segment.syncs = new NnSyncConfig[segment.nSyncs]; 131 | std::copy(syncs.begin(), syncs.end(), segment.syncs); 132 | } 133 | return segment; 134 | } 135 | }; 136 | 137 | #endif -------------------------------------------------------------------------------- /src/nn/nn-core.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_CORE_H 2 | #define NN_CORE_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "nn-quants.hpp" 9 | 10 | // primitives 11 | 12 | typedef struct { 13 | NnFloatType floatType; 14 | NnUint y; 15 | NnUint x; 16 | NnSize length; 17 | NnSize nBytes; 18 | } NnSize2D; 19 | 20 | // slices 21 | 22 | typedef struct { 23 | NnUint kvDim0; 24 | NnSize2D keySize; 25 | NnSize2D valueSize; 26 | } NnKvCacheSlice; 27 | 28 | typedef struct { 29 | NnFloatType type; 30 | NnUint nNodes; 31 | NnUint d0; 32 | NnUint n; 33 | NnSize2D size; 34 | NnSize2D sliceSize; 35 | } NnRowMatmulSlice; 36 | 37 | typedef struct { 38 | NnFloatType type; 39 | NnUint nNodes; 40 | NnUint n; 41 | NnUint n0; 42 | NnUint d; 43 | NnSize2D size; 44 | NnSize2D sliceSize; 45 | } NnColMatmulSlice; 46 | 47 | typedef struct { 48 | NnUint qDim0; 49 | NnUint qDimStart; 50 | NnUint qDimEnd; 51 | NnUint qShift; 52 | NnUint kvDim; 53 | NnUint kvDim0; 54 | NnUint kvDimStart; 55 | NnUint sliceDim; 56 | NnUint seqLen; 57 | NnUint headSize; 58 | NnUint nKvHeads; 59 | float ropeTheta; 60 | NnSize2D cacheSize; 61 | } NnRopeSlice; 62 | 63 | typedef struct { 64 | NnUint nHeads; 65 | NnUint nHeads0; 66 | NnSize2D attSize; 67 | } NnMultiHeadAttSlice; 68 | 69 | // base enums 70 | 71 | enum NnOpCode { 72 | OP_MERGE_ADD, 73 | OP_EMBEDDING, 74 | OP_INV_RMS, 75 | OP_RMS_NORM, 76 | OP_MATMUL, 77 | OP_ROPE_LLAMA, 78 | OP_MULTIHEAD_ATT, 79 | OP_GELU, 80 | OP_SILU, 81 | OP_MUL, 82 | OP_CAST, 83 | OP_SHIFT, 84 | }; 85 | 86 | enum NnOpQuantType { 87 | // __ 88 | 
F32_F32_F32, 89 | F32_Q40_F32, 90 | F32_Q40_Q80, 91 | F32_F32_Q80, 92 | Q80_Q80_Q80, 93 | Q80_Q80_F32, 94 | Q80_Q40_F32, 95 | Q80_F32_F32, 96 | }; 97 | 98 | #define N_OP_CODES (OP_SHIFT + 1) 99 | #define N_OP_QUANTS (Q80_F32_F32 + 1) 100 | 101 | enum NnPointerSource { 102 | SRC_PIPE, 103 | SRC_BUFFER, 104 | }; 105 | 106 | enum NnPointerType { 107 | PNTR_RAW, 108 | PNTR_BATCH, 109 | PNTR_BATCHED_SLICE 110 | }; 111 | 112 | enum NnSyncType { 113 | SYNC_WITH_ROOT, // whole pipe to all nodes 114 | SYNC_NODE_SLICES, // my slice of pipe to all nodes 115 | SYNC_NODE_SLICES_EXCEPT_ROOT, // only workers send slices to root, root does not send 116 | }; 117 | 118 | enum NnRopeType { 119 | ROPE_LLAMA = 0, 120 | ROPE_FALCON = 1, 121 | ROPE_LLAMA3_1 = 2, 122 | }; 123 | 124 | // base configs 125 | 126 | typedef struct { 127 | char *name; 128 | NnSize2D size; 129 | } NnPipeConfig; 130 | 131 | typedef struct { 132 | char *name; 133 | NnSize2D size; 134 | } NnBufferConfig; 135 | 136 | typedef struct { 137 | NnPointerSource source; 138 | NnUint pointerIndex; 139 | NnPointerType type; 140 | } NnPointerConfig; 141 | 142 | typedef struct { 143 | NnOpCode code; 144 | char *name; 145 | NnUint index; 146 | NnPointerConfig input; 147 | NnPointerConfig output; 148 | NnSize2D weightSize; 149 | NnByte *config; 150 | NnUint configSize; 151 | } NnOpConfig; 152 | 153 | typedef struct { 154 | NnUint pipeIndex; 155 | } NnPreSyncConfig; 156 | 157 | typedef struct { 158 | NnUint pipeIndex; 159 | NnSyncType syncType; 160 | } NnSyncConfig; 161 | 162 | typedef struct { 163 | NnUint nOps; 164 | NnOpConfig *ops; 165 | NnUint nSyncs; 166 | NnSyncConfig *syncs; 167 | } NnSegmentConfig; 168 | 169 | typedef struct { 170 | NnUint nBatches; 171 | NnUint nNodes; 172 | NnUint nPipes; 173 | NnPipeConfig *pipes; 174 | NnUint nPreSyncs; 175 | NnPreSyncConfig *preSyncs; 176 | } NnNetConfig; 177 | 178 | typedef struct { 179 | NnUint nodeIndex; 180 | NnUint nBuffers; 181 | NnBufferConfig *buffers; 182 | NnUint nSegments; 183 | NnSegmentConfig *segments; 184 | } NnNodeConfig; 185 | 186 | // op configs 187 | 188 | typedef struct { 189 | // empty 190 | } NnEmbeddingOpConfig; 191 | 192 | typedef struct { 193 | float epsilon; 194 | } NnInvRmsOpConfig; 195 | 196 | typedef struct { 197 | NnUint invRmsBufferIndex; 198 | } NnRmsNormOpConfig; 199 | 200 | typedef struct { 201 | // empty 202 | } NnMatmulOpConfig; 203 | 204 | typedef struct { 205 | bool isQ; 206 | NnUint positionPipeIndex; 207 | NnUint ropeCacheBufferIndex; 208 | float ropeScalingFactor; 209 | float ropeScalingLowFreqFactor; 210 | float ropeScalingHighFreqFactor; 211 | NnUint ropeScalingOrigMaxSeqLen; 212 | NnRopeSlice slice; 213 | } NnRopeLlamaOpConfig; 214 | 215 | typedef struct { 216 | NnUint nHeads; 217 | NnUint nHeads0; 218 | NnUint nKvHeads; 219 | NnUint headSize; 220 | NnUint seqLen; 221 | NnUint qSliceD0; 222 | NnUint kvDim0; 223 | NnUint positionPipeIndex; 224 | NnUint queryBufferIndex; 225 | NnUint keyCacheBufferIndex; 226 | NnUint valueCacheBufferIndex; 227 | NnUint attBufferIndex; 228 | } NnMultiHeadAttOpConfig; 229 | 230 | typedef struct { 231 | // empty 232 | } NnMergeAddOpCodeConfig; 233 | 234 | typedef struct { 235 | // empty 236 | } NnSiluOpCodeConfig; 237 | 238 | typedef struct { 239 | NnUint multiplierBufferIndex; 240 | } NnMulOpCodeConfig; 241 | 242 | typedef struct { 243 | // empty 244 | } NnCastOpCodeConfig; 245 | 246 | typedef struct { 247 | NnUint indexPipeIndex; 248 | } NnShiftOpCodeConfig; 249 | 250 | // utility functions 251 | 252 | const char 
*opCodeToString(NnOpCode code); 253 | const char *opQuantTypeToString(NnOpQuantType type); 254 | 255 | NnSize getBytes(NnFloatType floatType, NnSize n); 256 | NnSize getBlockSize(NnFloatType floatType); 257 | NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output); 258 | NnSize2D size0(); 259 | NnSize2D size1D(NnFloatType floatType, NnUint x); 260 | NnSize2D size2D(NnFloatType floatType, NnUint y, NnUint x); 261 | NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index); 262 | NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index); 263 | NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index); 264 | bool hasPointerContinuousMemory(NnPointerConfig *config); 265 | 266 | void releaseNetConfig(NnNetConfig *netConfig); 267 | void releaseNodeConfig(NnNodeConfig *nodeConfig); 268 | 269 | void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig); 270 | 271 | class Timer { 272 | private: 273 | std::chrono::time_point startTime; 274 | public: 275 | Timer(); 276 | void reset(); 277 | NnUint elapsedMiliseconds(); 278 | NnUint elapsedMicroseconds(); 279 | }; 280 | 281 | // slicers 282 | 283 | NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes); 284 | NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d); 285 | NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d); 286 | NnRopeSlice sliceRope(NnUint dim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headSize, float ropeTheta, NnUint nodeIndex); 287 | NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches); 288 | 289 | // splitters 290 | 291 | NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0); 292 | NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0); 293 | 294 | // rope 295 | 296 | void fullfillRopeLlama3Cache(const NnRopeLlamaOpConfig *config, float *cache); 297 | 298 | #endif 299 | -------------------------------------------------------------------------------- /src/nn/nn-cpu-ops-test.cpp: -------------------------------------------------------------------------------- 1 | #include "nn-cpu-ops.cpp" 2 | #include 3 | 4 | // framework 5 | 6 | void rand(float *o, const NnUint n, const NnUint seed) { 7 | srand(seed + 123456); 8 | for (NnUint i = 0; i < n; i++) { 9 | float v = (float)(rand() / RAND_MAX); 10 | o[i] = v * 2.0f - 1.0f; 11 | } 12 | } 13 | 14 | void compare_F32(const char *name, const float *a, const float *b, const NnUint n, const float epsilon) { 15 | for (NnUint i = 0; i < n; i++) { 16 | float error = fabs(a[i] - b[i]); 17 | if (error > epsilon) { 18 | printf("❌ %s failed\n", name); 19 | for (NnUint j = i; j < i + 16 && j < n; j++) 20 | printf(" [%3d] %f != %f\n", j, a[j], b[j]); 21 | exit(1); 22 | } 23 | } 24 | printf("✅ %24s passed\n", name); 25 | } 26 | 27 | // tests 28 | 29 | void testSplitThreads() { 30 | // <0; 32> across 3 threads 31 | { 32 | SPLIT_THREADS(a0Start, a0End, 32, 3, 0); // thread 0 33 | assert(a0Start == 0); 34 | assert(a0End == 11); 35 | } 36 | { 37 | SPLIT_THREADS(a1Start, a1End, 32, 3, 1); // thread 1 38 | assert(a1Start == 11); 39 | assert(a1End == 22); 40 | } 41 | { 42 | SPLIT_THREADS(a2Start, a2End, 32, 3, 2); // thread 2 43 | assert(a2Start == 22); 44 | assert(a2End == 32); 45 | } 46 | 47 | // <0; 4> across 8 threads 48 | { 49 | SPLIT_THREADS(b0Start, b0End, 
4, 8, 0); // thread 0 50 | assert(b0Start == 0); 51 | assert(b0End == 1); 52 | } 53 | { 54 | SPLIT_THREADS(b0Start, b0End, 4, 8, 3); // thread 3 55 | assert(b0Start == 3); 56 | assert(b0End == 4); 57 | } 58 | { 59 | SPLIT_THREADS(b0Start, b0End, 4, 8, 4); // thread 4 60 | assert(b0Start == 4); 61 | assert(b0End == 4); 62 | } 63 | { 64 | SPLIT_THREADS(b0Start, b0End, 4, 8, 7); // thread 7 65 | assert(b0Start == 4); 66 | assert(b0End == 4); 67 | } 68 | 69 | printf("✅ %24s passed\n", "splitThreads"); 70 | } 71 | 72 | void testConvertF32toF16() { 73 | float x[] = {0.0f, 0.25f, 0.3456f, 1.0f}; 74 | for (NnUint i = 0; i < sizeof(x) / sizeof(float); i++) { 75 | NnFp16 f16 = CONVERT_F32_TO_F16(x[i]); 76 | float f32 = CONVERT_F16_TO_F32(f16); 77 | compare_F32("convertF32toF16", &x[i], &f32, 1, 0.0005); 78 | } 79 | } 80 | 81 | // quantization 82 | void testQuantization(const NnUint m) { 83 | std::vector a(m * Q40_BLOCK_SIZE); 84 | std::vector aTemp(m * Q40_BLOCK_SIZE); 85 | std::vector aQ40(m); 86 | std::vector aQ80(m); 87 | 88 | rand(a.data(), m * Q40_BLOCK_SIZE, m); 89 | 90 | quantizeF32toQ40(a.data(), aQ40.data(), m * Q40_BLOCK_SIZE, 1, 0); 91 | dequantizeQ40toF32(aQ40.data(), aTemp.data(), m * Q40_BLOCK_SIZE, 1, 0); 92 | 93 | compare_F32("testQuantization_Q40", a.data(), aTemp.data(), m * Q40_BLOCK_SIZE, 0.13); 94 | 95 | quantizeF32toQ80(a.data(), aQ80.data(), m * Q80_BLOCK_SIZE, 1, 0); 96 | dequantizeQ80toF32(aQ80.data(), aTemp.data(), m * Q80_BLOCK_SIZE, 1, 0); 97 | 98 | compare_F32("testQuantization_Q80", a.data(), aTemp.data(), m * Q80_BLOCK_SIZE, 0.01); 99 | } 100 | 101 | // invRms 102 | void testInvRms() { 103 | const float epsilon = 0.00001; 104 | 105 | std::vector x(8); 106 | x[0] = 0.1f; 107 | x[1] = 0.3f; 108 | x[2] = 0.2f; 109 | x[3] = 0.4f; 110 | x[4] = 0.6f; 111 | x[5] = 0.5f; 112 | x[6] = 0.0f; 113 | x[7] = 0.8f; 114 | 115 | const float y0 = invRms_F32(x.data(), 8, epsilon); 116 | float ev0 = 1.0f / 0.4402f; 117 | compare_F32("rms_F32", &y0, &ev0, 1, 0.001f); 118 | } 119 | 120 | // rmsNorm 121 | void testRmsNorm(const NnUint m) { 122 | std::vector x(m); 123 | std::vector xQ80(m / Q80_BLOCK_SIZE); 124 | std::vector w(m); 125 | std::vector y(m); 126 | std::vector yTemp(m); 127 | 128 | rand(x.data(), m, m); 129 | rand(w.data(), m, m * m); 130 | quantizeF32toQ80(x.data(), xQ80.data(), m, 1, 0); 131 | const float rms = invRms_F32(x.data(), m, 1e-5f); 132 | 133 | rmsNorm_F32(y.data(), x.data(), rms, w.data(), m, 1, 0); 134 | rmsNorm_Q80_F32_F32(yTemp.data(), xQ80.data(), rms, w.data(), m, 1, 0); 135 | 136 | compare_F32("rmsNorm_Q80_F32_F32", y.data(), yTemp.data(), m, 0.01); 137 | } 138 | 139 | // a *= b 140 | void testMul(const NnUint m) { 141 | const NnUint n = Q80_BLOCK_SIZE * m; 142 | 143 | std::vector a0(n); 144 | std::vector b0(n); 145 | 146 | std::vector aQ(n); 147 | std::vector b1(n / Q80_BLOCK_SIZE); 148 | 149 | rand(a0.data(), n, m); 150 | rand(aQ.data(), n, m); 151 | rand(b0.data(), n, m); 152 | quantizeF32toQ80(b0.data(), b1.data(), n, 1, 0); 153 | 154 | mul_F32(a0.data(), a0.data(), b0.data(), n, 1, 0); 155 | mul_Q80_F32(aQ.data(), aQ.data(), b1.data(), n, 1, 0); 156 | 157 | compare_F32("mul_Q80_F32", a0.data(), aQ.data(), n, 0.005); 158 | } 159 | 160 | // y += x 161 | void testAdd(const NnUint m) { 162 | const NnUint n = Q80_BLOCK_SIZE * m; 163 | 164 | std::vector y(n); 165 | std::vector yTemp(n); 166 | std::vector x(n); 167 | std::vector xQ80(n / Q80_BLOCK_SIZE); 168 | 169 | rand(y.data(), n, m); 170 | rand(yTemp.data(), n, m); 171 | rand(x.data(), n, m); 172 | 
quantizeF32toQ80(x.data(), xQ80.data(), n, 1, 0); 173 | 174 | add_F32(y.data(), x.data(), n, 1, 0); 175 | add_Q80_F32(yTemp.data(), xQ80.data(), n, 1, 0); 176 | 177 | compare_F32("add_Q80_F32", y.data(), yTemp.data(), n, 0.01); 178 | } 179 | 180 | void testSoftmax() { 181 | std::vector y(8); 182 | for (NnUint i = 0; i < 8; i++) 183 | y[i] = i / 8.0f; 184 | 185 | softmax_F32(y.data(), 8); 186 | 187 | float expectedOutput[8] = { 188 | 0.077399f, 189 | 0.087780f, 190 | 0.099500f, 191 | 0.112761f, 192 | 0.127778f, 193 | 0.144793f, 194 | 0.164072f, 195 | 0.185917f 196 | }; 197 | compare_F32("softmax_F32", y.data(), expectedOutput, 8, 0.001); 198 | } 199 | 200 | void testSilu() { 201 | std::vector y(8); 202 | for (NnUint i = 0; i < 8; i++) 203 | y[i] = i / 8.0f; 204 | 205 | silu_F32(y.data(), 8, 1, 0); 206 | 207 | float expectedOutput[8] = { 208 | 0.000000f, 209 | 0.066401f, 210 | 0.140544f, 211 | 0.222250f, 212 | 0.311233f, 213 | 0.407116f, 214 | 0.509461f, 215 | 0.617802f 216 | }; 217 | compare_F32("silu_F32", y.data(), expectedOutput, 8, 0.001); 218 | } 219 | 220 | // matmul 221 | void testMatmul_F32_Q40_F32(const NnUint m = 2) { 222 | const NnUint n = Q80_BLOCK_SIZE * m; 223 | const NnUint d = Q80_BLOCK_SIZE * m; 224 | 225 | std::vector x(n); 226 | std::vector w(n * d); 227 | std::vector o(d); 228 | std::vector oTemp(d); 229 | std::vector xQ80(n / Q80_BLOCK_SIZE); 230 | std::vector wQ40((n * d) / Q40_BLOCK_SIZE); 231 | 232 | rand(x.data(), n, m); 233 | rand(w.data(), n * d, m); 234 | quantizeF32toQ40(w.data(), wQ40.data(), n * d, 1, 0); 235 | quantizeF32toQ80(x.data(), xQ80.data(), n, 1, 0); 236 | 237 | matmul_F32_F32_F32(o.data(), x.data(), w.data(), n, d, 1, 0); 238 | 239 | matmul_Q80_Q40_F32(oTemp.data(), xQ80.data(), wQ40.data(), n, d, 1, 0); 240 | compare_F32("matmul_Q80_Q40_F32", o.data(), oTemp.data(), d, 4.0f); 241 | } 242 | 243 | void testLlamafileSgemm() { 244 | const NnUint batchSize = 8; 245 | const NnUint n = 256; 246 | const NnUint d = 128; 247 | 248 | std::vector x(n * batchSize); 249 | std::vector xQ((n * batchSize) / Q80_BLOCK_SIZE); 250 | std::vector w(n * d); 251 | std::vector wQ((n * d) / Q40_BLOCK_SIZE); 252 | std::vector o(d * batchSize); 253 | std::vector oTemp(d * batchSize); 254 | 255 | rand(x.data(), n * batchSize, 12345); 256 | rand(w.data(), n * d, 23456); 257 | 258 | quantizeF32toQ80(x.data(), xQ.data(), n * batchSize, 1, 0); 259 | quantizeF32toQ40(w.data(), wQ.data(), n * d, 1, 0); 260 | 261 | // f32 262 | 263 | for (NnUint i = 0; i < batchSize; i++) { 264 | matmul_F32_F32_F32(o.data() + i * d, x.data() + i * n, w.data(), n, d, 1, 0); 265 | } 266 | 267 | assert(llamafile_sgemm( 268 | d, batchSize, n, 269 | w.data(), n, 270 | x.data(), n, 271 | oTemp.data(), d, 272 | 0, 1, 0, 273 | F_32, F_32, F_32 274 | )); 275 | 276 | compare_F32("llamafileSgemm_F32", o.data(), oTemp.data(), d * batchSize, 0.01f); 277 | 278 | // q40ᵀ * q80 279 | 280 | assert(llamafile_sgemm( 281 | d, batchSize, n / Q80_BLOCK_SIZE, 282 | wQ.data(), n / Q80_BLOCK_SIZE, 283 | xQ.data(), n / Q80_BLOCK_SIZE, 284 | oTemp.data(), d, 285 | 0, 1, 0, 286 | F_Q40, F_Q80, F_32 287 | )); 288 | 289 | compare_F32("llamafileSgemm_Q80_Q40", o.data(), oTemp.data(), d * batchSize, 1.5f); 290 | } 291 | 292 | int main() { 293 | initQuants(); 294 | 295 | printCpuInstructionSet(); 296 | testSplitThreads(); 297 | testConvertF32toF16(); 298 | testQuantization(32); 299 | testQuantization(2); 300 | testQuantization(1); 301 | testInvRms(); 302 | testRmsNorm(128); 303 | testMul(32); 304 | testMul(2); 305 | testMul(1); 306 
| testAdd(32); 307 | testAdd(2); 308 | testAdd(1); 309 | testSoftmax(); 310 | testSilu(); 311 | testMatmul_F32_Q40_F32(32); 312 | testMatmul_F32_Q40_F32(2); 313 | testMatmul_F32_Q40_F32(1); 314 | testLlamafileSgemm(); 315 | return 0; 316 | } 317 | -------------------------------------------------------------------------------- /src/nn/nn-cpu-ops.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_CPU_OPS_H 2 | #define NN_CPU_OPS_H 3 | 4 | #include "nn-core.hpp" 5 | 6 | #define ASSERT_EQ(a, b) \ 7 | if (a != b) { \ 8 | printf("Assertion failed: %d != %d (%s:%d)\n", a, b, __FILE__, __LINE__); \ 9 | exit(-1); \ 10 | } 11 | 12 | typedef struct { 13 | const char *name; 14 | NnByte nBatches; 15 | NnByte *bufferFlags; 16 | NnByte **buffers; 17 | NnBufferConfig *bufferConfigs; 18 | NnByte **pipes; 19 | NnPipeConfig *pipeConfigs; 20 | void *opConfig; 21 | 22 | NnByte **input; 23 | NnSize2D inputSize; 24 | bool hasInputContinuousMemory; 25 | 26 | NnByte **output; 27 | NnSize2D outputSize; 28 | bool hasOutputContinuousMemory; 29 | 30 | NnByte *weight; 31 | NnSize2D weightSize; 32 | } NnCpuOpContext; 33 | 34 | typedef void (*NnCpuOpForwardInit)(NnCpuOpContext *context); 35 | typedef void (*NnCpuOpForward)(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context); 36 | 37 | void printCpuInstructionSet(); 38 | NnCpuOpForwardInit getCpuOpForwardInit(NnOpCode code, NnOpQuantType quantType); 39 | NnCpuOpForward getCpuOpForward(NnOpCode code, NnOpQuantType quantType); 40 | 41 | void softmax_F32(float *x, const NnUint size); 42 | 43 | #endif -------------------------------------------------------------------------------- /src/nn/nn-cpu-test.cpp: -------------------------------------------------------------------------------- 1 | #include "nn-core.hpp" 2 | #include "nn-config-builder.hpp" 3 | #include "nn-cpu.hpp" 4 | #include 5 | 6 | #define DIM 32 7 | #define N_BATCHES 2 8 | 9 | void buildConfig(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) { 10 | NnUint nNodes = 1; 11 | NnNetConfigBuilder netBuilder(nNodes, N_BATCHES); 12 | NnUint xPipeIndex = netBuilder.addPipe("X", size2D(F_32, N_BATCHES, DIM)); 13 | 14 | NnNodeConfigBuilder nodeBuilder(0); 15 | NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1)); 16 | NnSegmentConfigBuilder segmentBuilder; 17 | segmentBuilder.addSync(xPipeIndex, SYNC_NODE_SLICES_EXCEPT_ROOT); 18 | 19 | segmentBuilder.addOp(OP_INV_RMS, "inv_rms", 0, 20 | pointerBatchConfig(SRC_PIPE, xPipeIndex), 21 | pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex), 22 | size0(), 23 | NnInvRmsOpConfig{1e-5f}); 24 | 25 | segmentBuilder.addOp(OP_RMS_NORM, "rms_norm", 0, 26 | pointerBatchConfig(SRC_PIPE, xPipeIndex), 27 | pointerBatchConfig(SRC_PIPE, xPipeIndex), 28 | size1D(F_32, DIM), 29 | NnRmsNormOpConfig{invRmsBufferIndex}); 30 | 31 | nodeBuilder.addSegment(segmentBuilder.build()); 32 | 33 | *netConfig = netBuilder.build(); 34 | *nodeConfig = nodeBuilder.build(); 35 | } 36 | 37 | void print2D(const char *name, NnUint x, NnUint y, float *w) { 38 | for (NnUint i = 0; i < y; i++) { 39 | printf("%s[%d] = ", name, i); 40 | for (NnUint j = 0; j < x; j++) 41 | printf("%f ", w[i * x + j]); 42 | printf("\n"); 43 | } 44 | } 45 | 46 | int main() { 47 | initQuants(); 48 | 49 | NnUint nThreads = 2; 50 | NnNetConfig netConfig; 51 | NnNodeConfig nodeConfig; 52 | buildConfig(&netConfig, &nodeConfig); 53 | 54 | NnNetExecution execution(nThreads, &netConfig); 55 | float *x = (float *)execution.pipes[0]; 56 | 
for (NnUint b = 0; b < N_BATCHES; b++) { 57 | for (NnUint i = 0; i < DIM; i++) 58 | x[b * DIM + i] = i / (float)DIM + (float)b; 59 | } 60 | 61 | print2D("x", DIM, N_BATCHES, x); 62 | 63 | float rmsNormWeight[DIM]; 64 | for (NnUint i = 0; i < DIM; i++) 65 | rmsNormWeight[i] = 0.5 + i / (float)DIM; 66 | 67 | NnCpuDevice *device = new NnCpuDevice(&netConfig, &nodeConfig, &execution); 68 | std::vector<NnExecutorDevice> devices; 69 | devices.push_back(NnExecutorDevice(device, -1, -1)); 70 | 71 | NnFakeNodeSynchronizer synchronizer; 72 | float *rms = (float *)device->buffers[0]; 73 | NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false); 74 | executor.loadWeight("rms_norm", 0, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight); 75 | 76 | execution.setBatchSize(2); 77 | executor.forward(); 78 | 79 | print2D("rms", N_BATCHES, 1, rms); 80 | print2D("x", DIM, N_BATCHES, x); 81 | 82 | releaseNetConfig(&netConfig); 83 | releaseNodeConfig(&nodeConfig); 84 | return 0; 85 | } -------------------------------------------------------------------------------- /src/nn/nn-cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "nn-cpu.hpp" 2 | #include "nn-cpu-ops.hpp" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #ifdef _WIN32 8 | #include 9 | #else 10 | #include 11 | #include 12 | #include 13 | #endif 14 | 15 | #define DEBUG_CPU_OP_QUANTS false 16 | 17 | #define BUFFER_ALIGNMENT 64 18 | 19 | static NnByte *allocAlignedBuffer(NnSize size) { 20 | NnByte *buffer; 21 | #ifdef _WIN32 22 | buffer = (NnByte *)_aligned_malloc(size, BUFFER_ALIGNMENT); 23 | if (buffer == NULL) 24 | throw std::runtime_error("_aligned_malloc failed"); 25 | #else 26 | if (posix_memalign((void **)&buffer, BUFFER_ALIGNMENT, size) != 0) 27 | throw std::runtime_error("posix_memalign failed"); 28 | mlock(buffer, size); 29 | #endif 30 | return buffer; 31 | } 32 | 33 | static void releaseAlignedBuffer(NnByte *buffer) { 34 | #ifdef _WIN32 35 | _aligned_free(buffer); 36 | #else 37 | free(buffer); 38 | #endif 39 | } 40 | 41 | NnCpuDevice::NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) { 42 | this->netConfig = netConfig; 43 | this->nodeConfig = nodeConfig; 44 | this->netExecution = netExecution; 45 | 46 | printCpuInstructionSet(); 47 | 48 | nBuffers = nodeConfig->nBuffers; 49 | buffers = new NnByte *[nBuffers]; 50 | for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) { 51 | NnBufferConfig *config = &nodeConfig->buffers[bufferIndex]; 52 | NnByte *buffer = allocAlignedBuffer(config->size.nBytes); 53 | buffers[bufferIndex] = buffer; 54 | } 55 | 56 | bufferFlags = new NnByte[nBuffers]; 57 | std::memset(bufferFlags, 0, nBuffers * sizeof(NnByte)); 58 | } 59 | 60 | NnCpuDevice::~NnCpuDevice() { 61 | for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) 62 | releaseAlignedBuffer(buffers[bufferIndex]); 63 | delete[] buffers; 64 | delete[] bufferFlags; 65 | } 66 | 67 | NnUint NnCpuDevice::maxNThreads() { 68 | return std::thread::hardware_concurrency(); 69 | } 70 | 71 | NnDeviceSegment *NnCpuDevice::createSegment(NnUint segmentIndex) { 72 | NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; 73 | assert(segmentConfig->nOps > 0); 74 | 75 | std::vector<NnOpQuantType> opQuants(segmentConfig->nOps); 76 | std::vector<NnCpuOpForward> opForwardLocal(segmentConfig->nOps); 77 | std::vector<NnSize2D> inputSizes(segmentConfig->nOps); 78 | std::vector<NnSize2D> outputSizes(segmentConfig->nOps); 79 | 80 | std::unique_ptr<NnByte *[]> inputsPtr(new NnByte
*[segmentConfig->nOps * netConfig->nBatches]); 81 | std::unique_ptr outputsPtr(new NnByte *[segmentConfig->nOps * netConfig->nBatches]); 82 | NnByte **inputs = inputsPtr.get(); 83 | NnByte **outputs = outputsPtr.get(); 84 | 85 | for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { 86 | NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; 87 | NnSize2D inputSize; 88 | NnSize2D outputSize; 89 | resolvePointer(&inputs[opIndex * netConfig->nBatches], &inputSize, &opConfig->input); 90 | resolvePointer(&outputs[opIndex * netConfig->nBatches], &outputSize, &opConfig->output); 91 | NnOpQuantType opQuant = getOpQuantType( 92 | inputSize.floatType, 93 | opConfig->weightSize.floatType, 94 | outputSize.floatType); 95 | #if DEBUG_CPU_OP_QUANTS 96 | printf("%20s %2d: %s\n", opConfig->name, opConfig->index, opQuantTypeToString(opQuant)); 97 | #endif 98 | NnCpuOpForward forward = getCpuOpForward(opConfig->code, opQuant); 99 | if (forward == nullptr) { 100 | throw std::invalid_argument( 101 | std::string("Unsupported CPU op code: ") + opCodeToString(opConfig->code) + 102 | ", quant: " + opQuantTypeToString(opQuant) + 103 | ", op name: " + opConfig->name); 104 | } 105 | inputSizes[opIndex] = inputSize; 106 | outputSizes[opIndex] = outputSize; 107 | opQuants[opIndex] = opQuant; 108 | opForwardLocal[opIndex] = forward; 109 | } 110 | 111 | inputsPtr.release(); 112 | outputsPtr.release(); 113 | 114 | NnCpuOpForward *opForward = new NnCpuOpForward[segmentConfig->nOps]; 115 | NnCpuOpContext *opContexts = new NnCpuOpContext[segmentConfig->nOps]; 116 | 117 | for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { 118 | NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; 119 | NnCpuOpContext *opContext = &opContexts[opIndex]; 120 | NnCpuOpForwardInit opInit = getCpuOpForwardInit(opConfig->code, opQuants[opIndex]); 121 | opContext->name = opConfig->name; 122 | opContext->opConfig = opConfig->config; 123 | opContext->weightSize = opConfig->weightSize; 124 | opContext->nBatches = netConfig->nBatches; 125 | opContext->pipes = netExecution->pipes; 126 | opContext->pipeConfigs = netConfig->pipes; 127 | opContext->buffers = buffers; 128 | opContext->bufferConfigs = nodeConfig->buffers; 129 | opContext->bufferFlags = bufferFlags; 130 | 131 | opContext->input = &inputs[opIndex * netConfig->nBatches]; 132 | opContext->inputSize = inputSizes[opIndex]; 133 | opContext->hasInputContinuousMemory = hasPointerContinuousMemory(&opConfig->input); 134 | 135 | opContext->output = &outputs[opIndex * netConfig->nBatches]; 136 | opContext->outputSize = outputSizes[opIndex]; 137 | opContext->hasOutputContinuousMemory = hasPointerContinuousMemory(&opConfig->output); 138 | 139 | #if not(DEBUG_USE_MMAP_FOR_WEIGHTS) 140 | if (opContext->weightSize.nBytes > 0) 141 | opContext->weight = allocAlignedBuffer(opContext->weightSize.nBytes); 142 | else 143 | opContext->weight = nullptr; 144 | #endif 145 | 146 | if (opInit != nullptr) 147 | opInit(opContext); 148 | opForward[opIndex] = opForwardLocal[opIndex]; 149 | } 150 | return new NnCpuDeviceSegment(opForward, opContexts, segmentConfig->nOps); 151 | } 152 | 153 | NnCpuDeviceSegment::~NnCpuDeviceSegment() { 154 | for (NnUint opIndex = 0; opIndex < nOps; opIndex++) { 155 | NnCpuOpContext *context = &opContexts[opIndex]; 156 | if (opIndex == 0) { 157 | delete[] context->input; 158 | delete[] context->output; 159 | } 160 | #if not(DEBUG_USE_MMAP_FOR_WEIGHTS) 161 | if (context->weightSize.nBytes > 0) 162 | releaseAlignedBuffer(context->weight); 163 | #endif 164 | } 165 | 
delete[] opForward; 166 | delete[] opContexts; 167 | } 168 | 169 | void NnCpuDevice::resolvePointer(NnByte **pntr, NnSize2D *pntrSize, NnPointerConfig *pointerConfig) { 170 | NnByte *source; 171 | NnSize2D *sourceSize; 172 | 173 | switch (pointerConfig->source) { 174 | case SRC_BUFFER: 175 | source = buffers[pointerConfig->pointerIndex]; 176 | sourceSize = &nodeConfig->buffers[pointerConfig->pointerIndex].size; 177 | break; 178 | case SRC_PIPE: 179 | source = netExecution->pipes[pointerConfig->pointerIndex]; 180 | sourceSize = &netConfig->pipes[pointerConfig->pointerIndex].size; 181 | break; 182 | default: 183 | throw std::invalid_argument("Unsupported pointer type"); 184 | } 185 | 186 | switch (pointerConfig->type) { 187 | case PNTR_RAW: { 188 | pntr[0] = source; 189 | *pntrSize = size2D(sourceSize->floatType, 1, sourceSize->length); 190 | return; 191 | } 192 | case PNTR_BATCH: 193 | case PNTR_BATCHED_SLICE: { 194 | ASSERT_EQ(sourceSize->y, netConfig->nBatches); 195 | 196 | NnSize batchBytes = getBytes(sourceSize->floatType, sourceSize->x); 197 | for (NnUint batchIndex = 0; batchIndex < netConfig->nBatches; batchIndex++) 198 | pntr[batchIndex] = &source[batchIndex * batchBytes]; 199 | *pntrSize = *sourceSize; 200 | 201 | if (pointerConfig->type == PNTR_BATCHED_SLICE) { 202 | assert(sourceSize->x % netConfig->nNodes == 0); 203 | NnUint xSlice = sourceSize->x / netConfig->nNodes; 204 | NnSize xSliceBytes = getBytes(sourceSize->floatType, xSlice); 205 | for (NnUint batchIndex = 0; batchIndex < netConfig->nBatches; batchIndex++) 206 | pntr[batchIndex] = &pntr[batchIndex][xSliceBytes * nodeConfig->nodeIndex]; 207 | *pntrSize = size2D(sourceSize->floatType, sourceSize->y, xSlice); 208 | } 209 | return; 210 | } 211 | default: 212 | throw std::invalid_argument("Unsupported pointer config"); 213 | } 214 | } 215 | 216 | void NnCpuDeviceSegment::loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) { 217 | assert(opIndex >= 0); 218 | assert(opIndex < nOps); 219 | NnCpuOpContext *context = &opContexts[opIndex]; 220 | assert(context->weightSize.nBytes == nBytes); 221 | #if DEBUG_USE_MMAP_FOR_WEIGHTS 222 | context->weight = weight; 223 | #else 224 | std::memcpy(context->weight, weight, nBytes); 225 | #endif 226 | } 227 | 228 | void NnCpuDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) { 229 | NnCpuOpContext *context = &opContexts[opIndex]; 230 | // printf("forward: %d %s (%d/%d)\n", opIndex, context->name, threadIndex + 1, nThreads); fflush(stdout); 231 | opForward[opIndex](nThreads, threadIndex, batchSize, context); 232 | } 233 | -------------------------------------------------------------------------------- /src/nn/nn-cpu.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_CPU_H 2 | #define NN_CPU_H 3 | 4 | #include 5 | #include "nn-executor.hpp" 6 | #include "nn-cpu-ops.hpp" 7 | 8 | #define DEBUG_USE_MMAP_FOR_WEIGHTS false 9 | 10 | class NnCpuDevice : public NnDevice { 11 | public: 12 | NnByte **buffers; 13 | private: 14 | NnNetConfig *netConfig; 15 | NnNodeConfig *nodeConfig; 16 | NnNetExecution *netExecution; 17 | NnUint nBuffers; 18 | NnByte *bufferFlags; 19 | public: 20 | NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution); 21 | ~NnCpuDevice() override; 22 | NnUint maxNThreads() override; 23 | NnDeviceSegment *createSegment(NnUint segmentIndex) override; 24 | void resolvePointer(NnByte **pntr, NnSize2D *pntrSize, NnPointerConfig *pointerConfig); 25 | }; 26 | 
27 | class NnCpuDeviceSegment : public NnDeviceSegment { 28 | public: 29 | NnUint nOps; 30 | NnCpuOpForward *opForward; 31 | NnCpuOpContext *opContexts; 32 | NnCpuDeviceSegment(NnCpuOpForward *opForward, NnCpuOpContext *opContexts, NnUint nOps) 33 | : opForward(opForward), opContexts(opContexts), nOps(nOps) {} 34 | ~NnCpuDeviceSegment() override; 35 | void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) override; 36 | void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override; 37 | }; 38 | 39 | #endif -------------------------------------------------------------------------------- /src/nn/nn-executor.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "nn-executor.hpp" 5 | 6 | void NnFakeNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) { 7 | // Nothing 8 | } 9 | 10 | NnNetExecution::NnNetExecution(NnUint nThreads, NnNetConfig *netConfig) { 11 | this->nThreads = nThreads; 12 | this->nBatches = netConfig->nBatches; 13 | this->nPipes = netConfig->nPipes; 14 | this->batchSize = 0; // This value must be overwritten before calling forward 15 | 16 | pipes = new NnByte *[netConfig->nPipes]; 17 | for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) { 18 | NnPipeConfig *pipeConfig = &netConfig->pipes[pipeIndex]; 19 | NnByte *pipe = new NnByte[pipeConfig->size.nBytes]; 20 | std::memset(pipe, 0, pipeConfig->size.nBytes); 21 | pipes[pipeIndex] = pipe; 22 | } 23 | } 24 | 25 | NnNetExecution::~NnNetExecution() { 26 | for (NnUint pipeIndex = 0; pipeIndex < nPipes; pipeIndex++) 27 | delete[] pipes[pipeIndex]; 28 | delete[] pipes; 29 | } 30 | 31 | void NnNetExecution::setBatchSize(NnUint batchSize) { 32 | assert(batchSize <= nBatches); 33 | this->batchSize = batchSize; 34 | } 35 | 36 | NnExecutorDevice::NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo) { 37 | this->device = std::unique_ptr<NnDevice>(device); 38 | this->segmentFrom = segmentFrom; 39 | this->segmentTo = segmentTo; 40 | } 41 | 42 | NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *devices, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark) 43 | : segments(nodeConfig->nSegments), steps() 44 | { 45 | NnUint maxNThreads = 0; 46 | for (NnExecutorDevice &d : *devices) { 47 | if (d.device->maxNThreads() > maxNThreads) 48 | maxNThreads = d.device->maxNThreads(); 49 | } 50 | if (netExecution->nThreads > maxNThreads) 51 | throw std::invalid_argument("This configuration supports max " + std::to_string(maxNThreads) + " threads"); 52 | 53 | this->netExecution = netExecution; 54 | this->nodeConfig = nodeConfig; 55 | 56 | bool useSynchronizer = netConfig->nNodes > 1; 57 | for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { 58 | NnDevice *device = nullptr; 59 | for (NnExecutorDevice &d : *devices) { 60 | if ( 61 | (d.segmentFrom == -1 && d.segmentTo == -1) || 62 | (segmentIndex >= d.segmentFrom && segmentIndex <= d.segmentTo) 63 | ) { 64 | device = d.device.get(); 65 | break; 66 | } 67 | } 68 | if (device == nullptr) 69 | throw std::invalid_argument("Cannot locate device for segment " + std::to_string(segmentIndex)); 70 | 71 | NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; 72 | if (segmentConfig->nOps > 0) { 73 | NnDeviceSegment *segment = device->createSegment(segmentIndex); 74 | segments[segmentIndex] = std::unique_ptr<NnDeviceSegment>(segment); 75 | 76 | for (NnUint
opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) 77 | steps.push_back(NnExecutorStep{ STEP_EXECUTE_OP, segment, opIndex, &segmentConfig->ops[opIndex] }); 78 | } 79 | if (useSynchronizer && segmentConfig->nSyncs > 0) 80 | steps.push_back(NnExecutorStep{ STEP_SYNC_NODES, nullptr, segmentIndex, nullptr }); 81 | } 82 | 83 | steps.shrink_to_fit(); 84 | 85 | context.nThreads = netExecution->nThreads; 86 | context.synchronizer = synchronizer; 87 | context.nSteps = (NnUint)steps.size(); 88 | context.steps = steps.data(); 89 | if (benchmark) 90 | context.timer = new Timer(); 91 | else 92 | context.timer = nullptr; 93 | 94 | threads = new NnExecutorThread[netExecution->nThreads]; 95 | for (NnUint threadIndex = 0; threadIndex < netExecution->nThreads; threadIndex++) { 96 | NnExecutorThread *thread = &threads[threadIndex]; 97 | thread->threadIndex = threadIndex; 98 | thread->context = &context; 99 | } 100 | } 101 | 102 | NnExecutor::~NnExecutor() { 103 | if (context.timer != nullptr) 104 | delete context.timer; 105 | delete[] threads; 106 | } 107 | 108 | void NnExecutor::loadWeight(const char *name, NnUint index, NnSize nBytes, NnByte *weight) { 109 | for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { 110 | NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex]; 111 | for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) { 112 | NnOpConfig *opConfig = &segmentConfig->ops[opIndex]; 113 | if (opConfig->index == index && std::strcmp(opConfig->name, name) == 0) { 114 | NnDeviceSegment *segment = segments[segmentIndex].get(); 115 | assert(segment != nullptr); 116 | segment->loadWeight(opIndex, nBytes, weight); 117 | return; 118 | } 119 | } 120 | } 121 | throw std::invalid_argument("Cannot locate op by name: " + std::string(name)); 122 | } 123 | 124 | inline void executeStep(NnExecutorStep *step, NnUint nThreads, NnExecutorThread *thread, NnExecutorContext *context) { 125 | if (step->type == STEP_EXECUTE_OP) { 126 | step->segment->forward(step->arg0, nThreads, thread->threadIndex, context->batchSize); 127 | } else if (step->type == STEP_SYNC_NODES) { 128 | context->synchronizer->sync(step->arg0, nThreads, thread->threadIndex); 129 | } else { 130 | throw std::invalid_argument("Unsupported step type"); 131 | } 132 | } 133 | 134 | static inline void *executorThreadHandler(void *arg) { 135 | NnExecutorThread *thread = (NnExecutorThread *)arg; 136 | NnExecutorContext *context = thread->context; 137 | NnUint nThreads = context->nThreads; 138 | NnUint doneCount = nThreads - 1; 139 | 140 | while (true) { 141 | const unsigned int currentStepIndex = context->currentStepIndex.load(); 142 | if (currentStepIndex == context->nSteps) 143 | break; 144 | 145 | NnExecutorStep *step = &context->steps[currentStepIndex]; 146 | executeStep(step, nThreads, thread, context); 147 | 148 | NnUint currentCount = context->doneThreadCount.fetch_add(1); 149 | if (currentCount == doneCount) { 150 | if (context->timer != nullptr) { 151 | NnUint time = context->timer->elapsedMicroseconds(); 152 | context->totalTime[step->type] += time; 153 | context->timer->reset(); 154 | } 155 | 156 | context->doneThreadCount.store(0); 157 | context->currentStepIndex.fetch_add(1); 158 | } else { 159 | while (context->currentStepIndex.load() == currentStepIndex); 160 | } 161 | } 162 | return nullptr; 163 | } 164 | 165 | void NnExecutor::forward() { 166 | assert(netExecution->batchSize > 0); 167 | 168 | NnUint nThreads = netExecution->nThreads; 169 | context.currentStepIndex.exchange(0); 170 
| context.doneThreadCount.exchange(0); 171 | context.batchSize = netExecution->batchSize; 172 | 173 | if (context.timer != nullptr) { 174 | std::memset(context.totalTime, 0, sizeof(context.totalTime)); 175 | context.timer->reset(); 176 | } 177 | 178 | NnUint threadIndex; 179 | for (threadIndex = 1; threadIndex < nThreads; threadIndex++) { 180 | int result = pthread_create(&threads[threadIndex].handler, NULL, (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]); 181 | if (result != 0) 182 | throw std::runtime_error("Failed to create thread"); 183 | } 184 | executorThreadHandler((void *)&threads[0]); 185 | for (threadIndex = 1; threadIndex < nThreads; threadIndex++) 186 | pthread_join(threads[threadIndex].handler, NULL); 187 | } 188 | 189 | NnUint NnExecutor::getTotalTime(NnExecutorStepType type) { 190 | assert((NnUint)type < N_STEP_TYPES); 191 | return context.totalTime[type]; 192 | } 193 | -------------------------------------------------------------------------------- /src/nn/nn-executor.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_EXECUTOR_H 2 | #define NN_EXECUTOR_H 3 | 4 | #include "nn-core.hpp" 5 | #include 6 | #include 7 | #include "pthread.h" 8 | 9 | class NnDeviceSegment { 10 | public: 11 | virtual ~NnDeviceSegment() {}; 12 | virtual void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) = 0; 13 | virtual void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) = 0; 14 | }; 15 | 16 | class NnDevice { 17 | public: 18 | virtual NnUint maxNThreads() = 0; 19 | virtual ~NnDevice() {} 20 | virtual NnDeviceSegment *createSegment(NnUint segmentIndex) = 0; 21 | }; 22 | 23 | class NnNodeSynchronizer { 24 | public: 25 | virtual ~NnNodeSynchronizer() {}; 26 | virtual void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) = 0; 27 | }; 28 | 29 | class NnFakeNodeSynchronizer : public NnNodeSynchronizer { 30 | public: 31 | ~NnFakeNodeSynchronizer() override {}; 32 | void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override; 33 | }; 34 | 35 | class NnNetExecution { 36 | public: 37 | NnUint nThreads; 38 | NnUint nPipes; 39 | NnByte **pipes; 40 | NnUint batchSize; 41 | NnUint nBatches; 42 | NnNetExecution(NnUint nThreads, NnNetConfig *netConfig); 43 | ~NnNetExecution(); 44 | void setBatchSize(NnUint batchSize); 45 | }; 46 | 47 | enum NnExecutorStepType { 48 | STEP_EXECUTE_OP, 49 | STEP_SYNC_NODES, 50 | }; 51 | 52 | #define N_STEP_TYPES STEP_SYNC_NODES + 1 53 | 54 | class NnExecutorDevice { 55 | public: 56 | std::unique_ptr device; 57 | int segmentFrom; 58 | int segmentTo; 59 | NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo); 60 | }; 61 | 62 | typedef struct { 63 | NnExecutorStepType type; 64 | NnDeviceSegment *segment; 65 | NnUint arg0; 66 | NnOpConfig *opConfig; 67 | } NnExecutorStep; 68 | 69 | typedef struct { 70 | NnUint nThreads; 71 | NnUint nSteps; 72 | NnExecutorStep *steps; 73 | NnNodeSynchronizer *synchronizer; 74 | std::atomic_uint currentStepIndex; 75 | std::atomic_uint doneThreadCount; 76 | NnUint batchSize; 77 | Timer *timer; 78 | NnUint totalTime[N_STEP_TYPES]; 79 | } NnExecutorContext; 80 | 81 | typedef struct { 82 | NnUint threadIndex; 83 | NnExecutorContext *context; 84 | PthreadHandler handler; 85 | } NnExecutorThread; 86 | 87 | class NnExecutor { 88 | private: 89 | NnNetExecution *netExecution; 90 | NnNodeConfig *nodeConfig; 91 | std::vector> segments; 92 | std::vector steps; 93 | NnExecutorThread *threads; 94 | NnExecutorContext 
context; 95 | public: 96 | NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark); 97 | ~NnExecutor(); 98 | void loadWeight(const char *name, NnUint index, NnSize nBytes, NnByte *weight); 99 | void forward(); 100 | NnUint getTotalTime(NnExecutorStepType type); 101 | }; 102 | 103 | #endif -------------------------------------------------------------------------------- /src/nn/nn-network.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_NETWORK_H 2 | #define NN_NETWORK_H 3 | 4 | #include "nn-executor.hpp" 5 | 6 | #define ROOT_SOCKET_INDEX 0 7 | 8 | void initSockets(); 9 | void cleanupSockets(); 10 | int acceptSocket(int serverSocket); 11 | void setReuseAddr(int socket); 12 | void writeSocket(int socket, const void* data, NnSize size); 13 | void readSocket(int socket, void* data, NnSize size); 14 | int createServerSocket(int port); 15 | void closeServerSocket(int serverSocket); 16 | 17 | class NnReadNetworkException : public std::exception { 18 | public: 19 | int code; 20 | const char *message; 21 | NnReadNetworkException(int code, const char *message); 22 | }; 23 | 24 | class NnWriteNetworkException : public std::exception { 25 | public: 26 | int code; 27 | const char *message; 28 | NnWriteNetworkException(int code, const char *message); 29 | }; 30 | 31 | struct NnSocketIo { 32 | NnUint socketIndex; 33 | const void *data; 34 | NnSize size; 35 | }; 36 | 37 | class NnNetwork { 38 | private: 39 | int *sockets; 40 | NnSize *sentBytes; 41 | NnSize *recvBytes; 42 | 43 | public: 44 | static std::unique_ptr serve(int port); 45 | static std::unique_ptr connect(NnUint nSockets, char **hosts, NnUint *ports); 46 | 47 | NnUint nSockets; 48 | 49 | NnNetwork(NnUint nSockets, int *sockets); 50 | ~NnNetwork(); 51 | 52 | void setTurbo(bool enabled); 53 | void write(const NnUint socketIndex, const void *data, const NnSize size); 54 | void read(const NnUint socketIndex, void *data, const NnSize size); 55 | void writeAck(const NnUint socketIndex); 56 | void readAck(const NnUint socketIndex); 57 | bool tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts); 58 | void writeMany(NnUint n, NnSocketIo *ios); 59 | void writeAll(void *data, NnSize size); 60 | void readMany(NnUint n, NnSocketIo *ios); 61 | void getStats(NnSize *sentBytes, NnSize *recvBytes); 62 | void resetStats(); 63 | }; 64 | 65 | class NnNetworkNodeSynchronizer : public NnNodeSynchronizer { 66 | private: 67 | NnNetwork *network; 68 | NnNetExecution *execution; 69 | NnNetConfig *netConfig; 70 | NnNodeConfig *nodeConfig; 71 | public: 72 | NnNetworkNodeSynchronizer(NnNetwork *network, NnNetExecution *execution, NnNetConfig *netConfig, NnNodeConfig *nodeConfig); 73 | ~NnNetworkNodeSynchronizer() override {}; 74 | void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override; 75 | }; 76 | 77 | class NnRootConfigWriter { 78 | private: 79 | NnNetwork *network; 80 | public: 81 | NnRootConfigWriter(NnNetwork *network); 82 | void writeNet(NnUint socketIndex, NnNetConfig *config); 83 | void writeNode(NnUint socketIndex, NnNodeConfig *config); 84 | void writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs); 85 | }; 86 | 87 | class NnWorkerConfigReader { 88 | private: 89 | NnNetwork *network; 90 | public: 91 | NnWorkerConfigReader(NnNetwork *network); 92 | NnNetConfig readNet(); 93 | NnNodeConfig readNode(); 94 | }; 95 | 96 | class 
NnRootWeightLoader { 97 | private: 98 | NnExecutor *executor; 99 | NnNetwork *network; 100 | NnUint nNodes; 101 | NnByte *temp; 102 | NnSize tempSize; 103 | public: 104 | NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes); 105 | ~NnRootWeightLoader(); 106 | void writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight); 107 | NnSize loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight); 108 | NnSize loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight); 109 | NnSize loadRowMatmulSlices(const char *opName, NnUint opIndex, NnRowMatmulSlice *slice, NnByte *weight); 110 | NnSize loadColMatmulSlices(const char *opName, NnUint opIndex, NnColMatmulSlice *slice, NnByte *weight); 111 | void finish(); 112 | private: 113 | void allocate(NnSize size);}; 114 | 115 | class NnWorkerWeightReader { 116 | private: 117 | NnExecutor *executor; 118 | NnNetwork *network; 119 | NnByte *temp; 120 | NnUint tempSize; 121 | public: 122 | NnWorkerWeightReader(NnExecutor *executor, NnNetwork *network); 123 | ~NnWorkerWeightReader(); 124 | void read(); 125 | private: 126 | void allocate(NnUint size); 127 | }; 128 | 129 | #endif 130 | -------------------------------------------------------------------------------- /src/nn/nn-quants.cpp: -------------------------------------------------------------------------------- 1 | #include "nn-quants.hpp" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #if defined(CONVERT_F16_TO_F32_LOOKUP) 9 | float f16ToF32Lookup[65536]; 10 | #endif 11 | 12 | void initQuants() { 13 | #if defined(CONVERT_F16_TO_F32_LOOKUP) 14 | for (NnUint i = 0; i < 65536; i++) 15 | f16ToF32Lookup[i] = convertF16toF32Impl((NnFp16)i); 16 | #endif 17 | } 18 | 19 | float convertF16toF32Impl(const NnFp16 value) { 20 | union Fl32 { 21 | uint32_t u; 22 | float f; 23 | }; 24 | const Fl32 magic = { (254U - 15U) << 23 }; 25 | const Fl32 infNan = { (127U + 16U) << 23 }; 26 | Fl32 result; 27 | result.u = (value & 0x7FFFU) << 13; 28 | result.f *= magic.f; 29 | if (result.f >= infNan.f) 30 | result.u |= 255U << 23; 31 | result.u |= (value & 0x8000U) << 16; 32 | return result.f; 33 | } 34 | 35 | NnFp16 convertF32ToF16Impl(const float x) { 36 | int i = *(int *)&x; 37 | int s = (i >> 16) & 0x00008000; 38 | int e = ((i >> 23) & 0x000000ff) - (127 - 15); 39 | int m = i & 0x007fffff; 40 | if (e <= 0) { 41 | if (e < -10) { 42 | return s; 43 | } 44 | m = m | 0x00800000; 45 | int t = 14 - e; 46 | int a = (1 << (t - 1)) - 1; 47 | int b = (m >> t) & 1; 48 | m = (m + a + b) >> t; 49 | return s | m; 50 | } 51 | if (e == 0xff - (127 - 15)) { 52 | if (m == 0) { 53 | return s | 0x7c00; 54 | } 55 | m >>= 13; 56 | return s | 0x7c00 | m | (m == 0); 57 | } 58 | m = m + 0x00000fff + ((m >> 13) & 1); 59 | if (m & 0x00800000) { 60 | m = 0; 61 | e += 1; 62 | } 63 | assert(e <= 30); 64 | return s | (e << 10) | (m >> 13); 65 | } 66 | 67 | void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { 68 | assert(n % Q80_BLOCK_SIZE == 0); 69 | const NnUint nBlocks = n / Q80_BLOCK_SIZE; 70 | SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); 71 | 72 | #if defined(__ARM_NEON) 73 | for (NnUint i = start; i < end; i++) { 74 | const float *x = &input[i * Q80_BLOCK_SIZE]; 75 | NnBlockQ80 *y = &output[i]; 76 | 77 | float32x4_t amaxVec = vdupq_n_f32(0.0f); 78 | for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) { 79 | const float32x4_t vec = vld1q_f32(&x[j]); 
80 | const float32x4_t abs_vec = vabsq_f32(vec); 81 | amaxVec = vmaxq_f32(amaxVec, abs_vec); 82 | } 83 | 84 | float amax = vmaxvq_f32(amaxVec); 85 | 86 | const float d = amax / 127.0f; 87 | const float id = d != 0.0f ? 1.0f / d : 0.0f; 88 | 89 | y->d = CONVERT_F32_TO_F16(d); 90 | 91 | const float32x4_t vid_vec = vdupq_n_f32(id); 92 | 93 | for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) { 94 | float32x4_t vec = vld1q_f32(&x[j]); 95 | vec = vmulq_f32(vec, vid_vec); 96 | 97 | const uint32x4_t sign_mask = vcgeq_f32(vec, vdupq_n_f32(0.0f)); 98 | const float32x4_t half = vbslq_f32(sign_mask, vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f)); 99 | vec = vaddq_f32(vec, half); 100 | 101 | const int32x4_t vec_i32 = vcvtq_s32_f32(vec); 102 | const int16x4_t vec_i16 = vqmovn_s32(vec_i32); 103 | const int8x8_t vec_i8 = vqmovn_s16(vcombine_s16(vec_i16, vec_i16)); 104 | 105 | vst1_lane_s32((int32_t *)(y->qs + j), vreinterpret_s32_s8(vec_i8), 0); 106 | } 107 | } 108 | #elif defined(__AVX2__) 109 | for (NnUint i = start; i < end; ++i) { 110 | const float *x = input + i * Q80_BLOCK_SIZE; 111 | NnBlockQ80 *y = output + i; 112 | 113 | __m256 max_abs = _mm256_setzero_ps(); 114 | for (int j = 0; j < Q80_BLOCK_SIZE; j += 8) { 115 | __m256 vec = _mm256_loadu_ps(x + j); 116 | __m256 abs_vec = _mm256_and_ps(vec, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF))); 117 | max_abs = _mm256_max_ps(max_abs, abs_vec); 118 | } 119 | __m128 max_hi = _mm256_extractf128_ps(max_abs, 1); 120 | __m128 max_lo = _mm256_castps256_ps128(max_abs); 121 | __m128 max_128 = _mm_max_ps(max_hi, max_lo); 122 | max_128 = _mm_max_ps(max_128, _mm_movehl_ps(max_128, max_128)); 123 | max_128 = _mm_max_ss(max_128, _mm_shuffle_ps(max_128, max_128, _MM_SHUFFLE(1, 1, 1, 1))); 124 | float amax = _mm_cvtss_f32(max_128); 125 | 126 | const float d = amax / 127.0f; 127 | const float id = (d != 0.0f) ? 1.0f / d : 0.0f; 128 | y->d = CONVERT_F32_TO_F16(d); 129 | 130 | const __m256 id_vec = _mm256_set1_ps(id); 131 | const __m128i shuffle_mask = _mm_set_epi8( 132 | -1, -1, -1, -1, -1, -1, -1, -1, 133 | -1, -1, -1, -1, 12, 8, 4, 0 134 | ); 135 | 136 | for (int j = 0; j < Q80_BLOCK_SIZE; j += 8) { 137 | __m256 vec = _mm256_loadu_ps(x + j); 138 | __m256 scaled = _mm256_mul_ps(vec, id_vec); 139 | __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); 140 | __m256i integers = _mm256_cvtps_epi32(rounded); 141 | 142 | __m128i low = _mm256_extracti128_si256(integers, 0); 143 | __m128i high = _mm256_extracti128_si256(integers, 1); 144 | 145 | __m128i low_bytes = _mm_shuffle_epi8(low, shuffle_mask); 146 | __m128i high_bytes = _mm_shuffle_epi8(high, shuffle_mask); 147 | 148 | uint32_t low_part = _mm_extract_epi32(low_bytes, 0); 149 | uint32_t high_part = _mm_extract_epi32(high_bytes, 0); 150 | uint64_t packed = (static_cast(high_part) << 32) | low_part; 151 | std::memcpy(y->qs + j, &packed, sizeof(packed)); 152 | } 153 | } 154 | #else 155 | for (NnUint i = start; i < end; i++) { 156 | const float *x = &input[i * Q80_BLOCK_SIZE]; 157 | NnBlockQ80 *y = &output[i]; 158 | 159 | float amax = 0.0f; 160 | for (NnUint j = 0; j < Q80_BLOCK_SIZE; j++) { 161 | const float v = fabsf(x[j]); 162 | amax = amax > v ? amax : v; 163 | } 164 | 165 | const float d = amax / ((1 << 7) - 1); 166 | const float id = d ? 
1.0f / d : 0.0f; 167 | y->d = CONVERT_F32_TO_F16(d); 168 | for (NnUint j = 0; j < Q80_BLOCK_SIZE; ++j) { 169 | y->qs[j] = roundf(x[j] * id); 170 | } 171 | } 172 | #endif 173 | } 174 | 175 | void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex) { 176 | assert(k % Q80_BLOCK_SIZE == 0); 177 | const int nBlocks = k / Q80_BLOCK_SIZE; 178 | const int blocksPerThread = nBlocks / nThreads; 179 | const int sk = blocksPerThread * Q80_BLOCK_SIZE; 180 | const int currentThreadBlocks = blocksPerThread + (threadIndex == nThreads - 1 ? nBlocks % nThreads : 0); 181 | 182 | const NnBlockQ80 *x = &input[blocksPerThread * threadIndex]; 183 | float* y = &output[sk * threadIndex]; 184 | 185 | for (int i = 0; i < currentThreadBlocks; i++) { 186 | const float d = CONVERT_F16_TO_F32(x[i].d); 187 | for (int j = 0; j < Q80_BLOCK_SIZE; j++) { 188 | y[i * Q80_BLOCK_SIZE + j] = x[i].qs[j] * d; 189 | } 190 | } 191 | } 192 | 193 | void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { 194 | assert(n % Q40_BLOCK_SIZE == 0); 195 | const NnUint nBlocks = n / Q40_BLOCK_SIZE; 196 | const NnUint halfSize = Q40_BLOCK_SIZE / 2; 197 | SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); 198 | 199 | for (NnUint i = start; i < end; i++) { 200 | float amax = 0.0f; 201 | float max = 0.0f; 202 | for (NnUint j = 0; j < Q40_BLOCK_SIZE; j++) { 203 | float v = x[i * Q40_BLOCK_SIZE + j]; 204 | if (amax < fabsf(v)) { 205 | amax = fabsf(v); 206 | max = v; 207 | } 208 | } 209 | 210 | const float d = max / -8.0f; 211 | const float id = d ? 1.0f / d : 0.0f; 212 | 213 | NnBlockQ40 *o = &output[i]; 214 | o->d = CONVERT_F32_TO_F16(d); 215 | for (NnUint j = 0; j < halfSize; j++) { 216 | const float x0 = x[i * Q40_BLOCK_SIZE + j] * id; 217 | const float x1 = x[i * Q40_BLOCK_SIZE + halfSize + j] * id; 218 | 219 | uint8_t xi0 = (int8_t)(x0 + 8.5f); 220 | uint8_t xi1 = (int8_t)(x1 + 8.5f); 221 | if (xi0 > 15) xi0 = 15; 222 | if (xi1 > 15) xi1 = 15; 223 | 224 | o->qs[j] = xi0 | (xi1 << 4); 225 | } 226 | } 227 | } 228 | 229 | void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) { 230 | assert(n % Q40_BLOCK_SIZE == 0); 231 | const NnUint nBlocks = n / Q40_BLOCK_SIZE; 232 | SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex); 233 | 234 | for (NnUint i = start; i < end; i++) { 235 | const NnBlockQ40 *b = &x[i]; 236 | const float d = CONVERT_F16_TO_F32(b->d); 237 | 238 | for (int j = 0; j < Q40_BLOCK_SIZE / 2; ++j) { 239 | const int x0 = (b->qs[j] & 0x0F) - 8; 240 | const int x1 = (b->qs[j] >> 4) - 8; 241 | 242 | output[i * Q40_BLOCK_SIZE + j] = x0 * d; 243 | output[i * Q40_BLOCK_SIZE + j + Q40_BLOCK_SIZE / 2] = x1 * d; 244 | } 245 | } 246 | } 247 | 248 | const char *floatTypeToString(NnFloatType type) { 249 | if (type == F_UNK) return "F_UNK"; 250 | if (type == F_32) return "F_32"; 251 | if (type == F_16) return "F_16"; 252 | if (type == F_Q40) return "F_Q40"; 253 | if (type == F_Q80) return "F_Q80"; 254 | throw std::invalid_argument("Unknown float type"); 255 | } 256 | -------------------------------------------------------------------------------- /src/nn/nn-quants.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_QUANTS_H 2 | #define NN_QUANTS_H 3 | 4 | #include 5 | #include 6 | #if defined(__ARM_NEON) 7 | #include 8 | #elif defined(__AVX2__) 9 | #include 10 | #endif 11 | 12 | 
typedef std::uint8_t NnByte; 13 | typedef std::uint32_t NnUint; 14 | typedef std::size_t NnSize; 15 | typedef std::uint16_t NnFp16; 16 | 17 | float convertF16toF32Impl(const NnFp16 value); 18 | NnFp16 convertF32ToF16Impl(const float x); 19 | 20 | #if defined(__ARM_NEON) && defined(__ARM_FP16_FORMAT_IEEE) 21 | inline float convertF16ToF32Neon(const NnFp16 value) { 22 | __fp16 fp; 23 | std::memcpy(&fp, &value, sizeof(fp)); 24 | return (float)fp; 25 | } 26 | 27 | inline NnFp16 convertF32ToF16Neon(const float x) { 28 | __fp16 h = x; 29 | return *(NnFp16 *)&h; 30 | } 31 | 32 | #define CONVERT_F16_TO_F32(value) convertF16ToF32Neon(value) 33 | #define CONVERT_F32_TO_F16(value) convertF32ToF16Neon(value) 34 | #elif defined(__F16C__) 35 | #define CONVERT_F32_TO_F16(v) _cvtss_sh((v), _MM_FROUND_TO_NEAREST_INT) 36 | #endif 37 | 38 | #if !defined(CONVERT_F16_TO_F32) 39 | extern float f16ToF32Lookup[65536]; 40 | 41 | inline static float convertF16ToF32Lookup(const NnFp16 value) { 42 | return f16ToF32Lookup[value]; 43 | } 44 | 45 | #define CONVERT_F16_TO_F32_LOOKUP 46 | #define CONVERT_F16_TO_F32(value) convertF16ToF32Lookup(value) 47 | #endif 48 | 49 | #if !defined(CONVERT_F32_TO_F16) 50 | #define CONVERT_F32_TO_F16(value) convertF32ToF16Impl(value) 51 | #endif 52 | 53 | #define Q40_BLOCK_SIZE 32 54 | #define Q80_BLOCK_SIZE 32 55 | 56 | enum NnFloatType { 57 | F_UNK = -1, 58 | F_32 = 0, 59 | F_16 = 1, 60 | F_Q40 = 2, 61 | F_Q80 = 3, 62 | }; 63 | 64 | typedef struct { 65 | std::uint16_t d; 66 | std::uint8_t qs[Q40_BLOCK_SIZE / 2]; 67 | } NnBlockQ40; 68 | 69 | typedef struct { 70 | std::uint16_t d; 71 | std::int8_t qs[Q80_BLOCK_SIZE]; 72 | } NnBlockQ80; 73 | 74 | void initQuants(); 75 | void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint k, const NnUint nThreads, const NnUint threadIndex); 76 | void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex); 77 | void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex); 78 | void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex); 79 | 80 | const char *floatTypeToString(NnFloatType type); 81 | 82 | #define SPLIT_THREADS(varStart, varEnd, rangeLen, nThreads, threadIndex) \ 83 | const NnUint rangeSlice = rangeLen / nThreads; \ 84 | const NnUint rangeRest = rangeLen % nThreads; \ 85 | const NnUint varStart = threadIndex * rangeSlice + (threadIndex < rangeRest ? threadIndex : rangeRest); \ 86 | const NnUint varEnd = varStart + rangeSlice + (threadIndex < rangeRest ? 
1 : 0); 87 | 88 | #endif -------------------------------------------------------------------------------- /src/nn/nn-vulkan.hpp: -------------------------------------------------------------------------------- 1 | #ifndef NN_VULKAN_HPP 2 | #define NN_VULKAN_HPP 3 | 4 | #include 5 | #include 6 | #include "nn-executor.hpp" 7 | #include "nn-cpu-ops.hpp" 8 | 9 | #define DEBUG_VULKAN_TRACE false 10 | 11 | typedef struct { 12 | vk::Instance instance; 13 | vk::PhysicalDevice physicalDevice; 14 | vk::Device device; 15 | uint32_t queueFamilyIndex; 16 | vk::CommandPool commandPool; 17 | vk::Queue queue; 18 | } NnVulkanContext; 19 | 20 | enum NnStagingVulkanCopyDirection { 21 | COPY_TO_DEVICE, 22 | COPY_FROM_DEVICE 23 | }; 24 | 25 | class NnVulkanStagingCopy { 26 | private: 27 | NnStagingVulkanCopyDirection direction; 28 | const NnVulkanContext *context; 29 | vk::DeviceSize bufferSize; 30 | vk::Buffer deviceBuffer; 31 | vk::Buffer hostBuffer; 32 | vk::DeviceMemory hostMemory; 33 | void *hostPointer; 34 | public: 35 | NnVulkanStagingCopy(const NnVulkanContext *context, vk::Buffer& deviceBuffer, const vk::DeviceSize bufferSize, const NnStagingVulkanCopyDirection direction); 36 | ~NnVulkanStagingCopy(); 37 | void copy(NnByte *data); 38 | void executeCopyCommand(); 39 | void addCopyCommand(vk::CommandBuffer& commandBuffer); 40 | }; 41 | 42 | class NnVulkanBuffer { 43 | private: 44 | bool isHostVisible; 45 | NnVulkanContext *context; 46 | vk::DeviceMemory deviceMemory; 47 | void *hostPointer; 48 | public: 49 | vk::DeviceSize bufferSize; 50 | vk::Buffer deviceBuffer; 51 | vk::BufferUsageFlags usageFlags; 52 | NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess); 53 | ~NnVulkanBuffer(); 54 | void write(const NnByte *data); 55 | void read(NnByte *data); 56 | }; 57 | 58 | typedef struct { 59 | NnUint inputOffset; 60 | NnUint inputSizeX; 61 | NnUint outputOffset; 62 | NnUint outputSizeX; 63 | } NnVulkanBatchInfo; 64 | 65 | class NnVulkanDeviceData { 66 | private: 67 | NnNetConfig *netConfig; 68 | NnNodeConfig *nodeConfig; 69 | public: 70 | std::vector> pipes; 71 | std::vector> buffers; 72 | std::vector> internalBuffers; 73 | NnVulkanDeviceData(NnVulkanContext *context, NnNetConfig *netConfig, NnNodeConfig *nodeConfig); 74 | ~NnVulkanDeviceData(); 75 | 76 | NnSize2D resolveBufferSize(NnPointerConfig *config); 77 | NnVulkanBuffer *resolvePointerVulkanBuffer(NnPointerConfig *config); 78 | NnUint resolveBufferBatchOffset(NnPointerConfig *config, NnUint batchIndex); 79 | NnUint resolveBufferBatchWidth(NnPointerConfig *config, NnUint batchIndex); 80 | }; 81 | 82 | class NnVulkanDevice : public NnDevice { 83 | private: 84 | NnVulkanContext context; 85 | NnNetConfig *netConfig; 86 | NnNodeConfig *nodeConfig; 87 | NnNetExecution *netExecution; 88 | public: 89 | NnVulkanDeviceData *data; 90 | NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution); 91 | ~NnVulkanDevice() override; 92 | NnUint maxNThreads() override; 93 | NnDeviceSegment *createSegment(NnUint segmentIndex) override; 94 | }; 95 | 96 | class NnVulkanDeviceSegmentData { 97 | private: 98 | NnVulkanDeviceData *data; 99 | std::vector batchInfoBufferIndex; 100 | std::vector weightBufferIndex; 101 | std::vector configBufferIndex; 102 | public: 103 | NnVulkanDeviceSegmentData(NnVulkanContext *context, NnVulkanDeviceData *data, NnSegmentConfig *segmentConfig, NnUint nBatches); 104 | NnVulkanBuffer *resolveOpBatchInfoVulkanBuffer(NnUint 
opIndex); 105 | NnVulkanBuffer *resolveOpWeightVulkanBuffer(NnUint opIndex); 106 | NnVulkanBuffer *resolveOpConfigVulkanBuffer(NnUint opIndex); 107 | }; 108 | 109 | class NnVulkanDeviceSegment : public NnDeviceSegment { 110 | private: 111 | NnVulkanContext *context; 112 | NnVulkanDeviceData *data; 113 | NnNetConfig *netConfig; 114 | NnUint segmentIndex; 115 | NnSegmentConfig *segmentConfig; 116 | NnNetExecution *netExecution; 117 | std::unique_ptr segmentData; 118 | 119 | std::vector shaderModules; 120 | std::vector descriptorSetLayouts; 121 | vk::DescriptorPool descriptorPool; 122 | std::vector descriptorSets; 123 | vk::Fence fence; 124 | std::vector pipelineLayouts; 125 | std::vector pipelines; 126 | vk::PipelineCache pipelineCache; 127 | vk::CommandBuffer commandBuffer; 128 | NnUint lastBatchSize; 129 | public: 130 | NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanDeviceData *data, NnNetConfig *netConfig, NnUint segmentIndex, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution); 131 | ~NnVulkanDeviceSegment() override; 132 | void loadWeight(NnUint opIndex, NnSize nBytes, NnByte *weight) override; 133 | void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override; 134 | }; 135 | 136 | #endif -------------------------------------------------------------------------------- /src/nn/pthread.h: -------------------------------------------------------------------------------- 1 | #ifndef PTHREAD_WRAPPER 2 | #define PTHREAD_WRAPPER 3 | 4 | #ifdef _WIN32 5 | #include 6 | 7 | typedef HANDLE PthreadHandler; 8 | typedef DWORD PthreadResult; 9 | typedef DWORD (WINAPI *PthreadFunc)(void *); 10 | 11 | static int pthread_create(PthreadHandler *out, void *unused, PthreadFunc func, void *arg) { 12 | (void) unused; 13 | PthreadHandler handle = CreateThread(NULL, 0, func, arg, 0, NULL); 14 | if (handle == NULL) { 15 | return EAGAIN; 16 | } 17 | *out = handle; 18 | return 0; 19 | } 20 | 21 | static int pthread_join(PthreadHandler thread, void *unused) { 22 | (void) unused; 23 | DWORD ret = WaitForSingleObject(thread, INFINITE); 24 | if (ret == WAIT_FAILED) { 25 | return -1; 26 | } 27 | CloseHandle(thread); 28 | return 0; 29 | } 30 | #else 31 | #include 32 | 33 | typedef pthread_t PthreadHandler; 34 | typedef void* PthreadResult; 35 | typedef void* (*PthreadFunc)(void *); 36 | 37 | #endif 38 | 39 | #endif -------------------------------------------------------------------------------- /src/nn/vulkan/cast-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | 18 | shared uint sharedDim; 19 | shared uint sharedXOffset; 20 | shared uint sharedYOffset; 21 | 22 | void main() { 23 | const uint threadIndex = gl_LocalInvocationID.x; 24 | 25 | if (threadIndex == 0) { 26 | const uint nWorkGroups = gl_NumWorkGroups.z; 27 | const uint batchIndex = gl_WorkGroupID.y; 28 | const uint workGroupIndex = gl_WorkGroupID.z; 29 | 30 | const BatchInfo info = infos[batchIndex]; 31 | sharedDim = info.inputSizeX / nWorkGroups; 32 | const uint dimOffset = sharedDim 
* workGroupIndex; 33 | sharedXOffset = info.inputOffset + dimOffset; 34 | sharedYOffset = info.outputOffset + dimOffset; 35 | } 36 | 37 | barrier(); 38 | memoryBarrierShared(); 39 | 40 | const uint dim = sharedDim; 41 | const uint xOffset = sharedXOffset; 42 | const uint yOffset = sharedYOffset; 43 | 44 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 45 | y[yOffset + i] = x[xOffset + i]; 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/nn/vulkan/cast-forward-f32-q80.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #extension GL_EXT_control_flow_attributes : enable 4 | #extension GL_EXT_shader_16bit_storage : enable 5 | #extension GL_EXT_shader_explicit_arithmetic_types : enable 6 | 7 | #define Q80_BLOCK_SIZE 32 8 | #define N_THREADS 256 9 | 10 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 11 | 12 | struct BatchInfo { 13 | uint inputOffset; 14 | uint inputSizeX; 15 | uint outputOffset; // number of Q80 blocks 16 | uint outputSizeX; // number of Q80 blocks 17 | }; 18 | 19 | struct BlockQ80 { 20 | float16_t d; 21 | int8_t qs[Q80_BLOCK_SIZE]; 22 | }; 23 | 24 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 25 | layout(binding = 1) writeonly buffer outputBuffer { BlockQ80 y[]; }; 26 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 27 | 28 | shared uint sharedYStart; 29 | shared uint sharedYEnd; 30 | shared uint sharedXOffset; 31 | shared uint sharedYOffset; 32 | 33 | void main() { 34 | const uint threadIndex = gl_LocalInvocationID.x; 35 | 36 | if (threadIndex == 0) { 37 | const uint nWorkGroups = gl_NumWorkGroups.z; 38 | const uint batchIndex = gl_WorkGroupID.y; 39 | const uint workGroupIndex = gl_WorkGroupID.z; 40 | 41 | const BatchInfo info = infos[batchIndex]; 42 | 43 | const uint ySlice = info.outputSizeX / nWorkGroups; 44 | const uint yRest = info.outputSizeX % nWorkGroups; 45 | sharedYStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest); 46 | sharedYEnd = sharedYStart + ySlice + (workGroupIndex < yRest ? 1 : 0); 47 | sharedXOffset = info.inputOffset; 48 | sharedYOffset = info.outputOffset; 49 | } 50 | 51 | barrier(); 52 | memoryBarrierShared(); 53 | 54 | const uint yStart = sharedYStart + threadIndex; 55 | const uint yEnd = sharedYEnd; 56 | const uint xOffset = sharedXOffset; 57 | const uint yOffset = sharedYOffset; 58 | 59 | for (uint i = yStart; i < yEnd; i += N_THREADS) { 60 | const uint xiOffset = xOffset + i * Q80_BLOCK_SIZE; 61 | const uint yiOffset = yOffset + i; 62 | 63 | float amax = 0.0; 64 | [[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) { 65 | const float v = abs(x[xiOffset + j]); 66 | amax = max(amax, v); 67 | } 68 | 69 | const float d = amax / ((1 << 7) - 1); 70 | const float id = d != 0.0 ? 
1.0 / d : 0.0; 71 | 72 | y[yiOffset].d = float16_t(d); 73 | 74 | [[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) { 75 | const float v = x[xiOffset + j]; 76 | y[yiOffset].qs[j] = int8_t(clamp(round(v * id), -127.0, 127.0)); 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/nn/vulkan/embedding-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | layout(binding = 3) readonly buffer weightBuffer { float weight[]; }; 18 | 19 | shared uint sharedPosition; 20 | shared BatchInfo sharedInfo; 21 | 22 | void main() { 23 | const uint threadIndex = gl_LocalInvocationID.x; 24 | const uint batchIndex = gl_GlobalInvocationID.y; 25 | 26 | if (threadIndex == 0) { 27 | sharedPosition = uint(x[batchIndex]); 28 | sharedInfo = infos[batchIndex]; 29 | } 30 | 31 | barrier(); 32 | memoryBarrierShared(); 33 | 34 | const uint outputSizeX = sharedInfo.outputSizeX; 35 | const uint yOffset = sharedInfo.outputOffset; 36 | const uint wOffset = sharedPosition * outputSizeX; 37 | 38 | for (uint i = threadIndex; i < outputSizeX; i += N_THREADS) { 39 | y[yOffset + i] = weight[wOffset + i]; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/nn/vulkan/inv-rms-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #extension GL_EXT_control_flow_attributes : enable 4 | 5 | #define N_THREADS 256 6 | 7 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 8 | 9 | struct BatchInfo { 10 | uint inputOffset; 11 | uint inputSizeX; 12 | uint outputOffset; 13 | uint outputSizeX; 14 | }; 15 | 16 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 17 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 18 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 19 | layout(binding = 3) readonly uniform opConfigBuffer { 20 | float epsilon; 21 | }; 22 | 23 | shared BatchInfo sharedInfo; 24 | shared float sums[N_THREADS]; 25 | 26 | void main() { 27 | const uint threadIndex = gl_LocalInvocationID.x; 28 | const uint batchIndex = gl_GlobalInvocationID.y; 29 | 30 | if (threadIndex == 0) { 31 | sharedInfo = infos[batchIndex]; 32 | } 33 | memoryBarrierShared(); 34 | barrier(); 35 | 36 | const uint inputSizeX = sharedInfo.inputSizeX; 37 | const uint offset = sharedInfo.inputOffset; 38 | const uint slice = inputSizeX / N_THREADS; 39 | const uint rest = inputSizeX % N_THREADS; 40 | const uint start = offset + threadIndex * slice + (threadIndex < rest ? threadIndex : rest); 41 | const uint end = start + slice + (threadIndex < rest ? 
1 : 0); 42 | 43 | float sum = 0.0; 44 | for (uint i = start; i < end; i++) { 45 | sum += x[i] * x[i]; 46 | } 47 | sums[threadIndex] = sum; 48 | 49 | memoryBarrierShared(); 50 | barrier(); 51 | 52 | if (threadIndex == 0) { 53 | sum = 0.0; 54 | [[unroll]] for (uint i = 0; i < N_THREADS; i++) { 55 | sum += sums[i]; 56 | } 57 | y[batchIndex] = inversesqrt((sum / float(inputSizeX)) + epsilon); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/nn/vulkan/matmul-forward-f32-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 128 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | layout(binding = 3) readonly buffer weightBuffer { float weight[]; }; 18 | 19 | shared BatchInfo sharedInfo; 20 | shared uint sharedDim; 21 | 22 | void main() { 23 | const uint threadIndex = gl_LocalInvocationID.x; 24 | const uint workGroupIndex = gl_WorkGroupID.z; 25 | 26 | if (threadIndex == 0) { 27 | const uint batchIndex = gl_WorkGroupID.y; 28 | const uint nWorkGroups = gl_NumWorkGroups.z; 29 | 30 | sharedInfo = infos[batchIndex]; 31 | sharedDim = sharedInfo.outputSizeX / nWorkGroups; 32 | } 33 | 34 | barrier(); 35 | memoryBarrierShared(); 36 | 37 | const uint inputSizeX = sharedInfo.inputSizeX; 38 | const uint xOffset = sharedInfo.inputOffset; 39 | const uint yOffset = sharedInfo.outputOffset; 40 | const uint dim = sharedDim; 41 | 42 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 43 | const uint d = (workGroupIndex * dim) + i; 44 | const uint wOffset = d * inputSizeX; 45 | 46 | float sum = 0.0; 47 | for (uint j = 0; j < inputSizeX; j++) { 48 | sum += x[xOffset + j] * weight[wOffset + j]; 49 | } 50 | y[yOffset + d] = sum; 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/nn/vulkan/matmul-forward-q80-q40-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #extension GL_EXT_control_flow_attributes : enable 4 | #extension GL_EXT_shader_16bit_storage : enable 5 | #extension GL_EXT_shader_explicit_arithmetic_types : enable 6 | 7 | #define N_THREADS 64 8 | #define TILE_SIZE_X 2 9 | #define TILE_SIZE_D 16 10 | 11 | #define Q80_Q40_BLOCK_SIZE 32 12 | 13 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 14 | 15 | struct BatchInfo { 16 | uint inputOffset; 17 | uint inputSizeX; 18 | uint outputOffset; 19 | uint outputSizeX; 20 | }; 21 | 22 | struct BlockQ80 { 23 | float16_t d; 24 | int8_t qs[Q80_Q40_BLOCK_SIZE]; 25 | }; 26 | 27 | struct BlockQ40 { 28 | float16_t d; 29 | uint8_t qs[Q80_Q40_BLOCK_SIZE / 2]; 30 | }; 31 | 32 | layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; }; 33 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 34 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 35 | layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; }; 36 | 37 | shared uint sharedXSlice; 38 | shared uint sharedXRest; 39 | shared uint sharedInputOffset; 40 | shared uint sharedInputSizeX; 41 | shared 
uint sharedOutputOffset; 42 | shared uint sharedD; 43 | shared float16_t sums[N_THREADS * TILE_SIZE_D]; 44 | 45 | void main() { 46 | const uint threadIndex = gl_LocalInvocationID.x; 47 | 48 | if (threadIndex == 0) { 49 | const uint batchIndex = gl_WorkGroupID.y; 50 | const uint workGroupIndex = gl_WorkGroupID.z; 51 | 52 | const BatchInfo info = infos[batchIndex]; 53 | 54 | const uint xTiles = info.inputSizeX / TILE_SIZE_X; 55 | sharedXSlice = xTiles / N_THREADS; 56 | sharedXRest = xTiles % N_THREADS; 57 | 58 | sharedInputOffset = info.inputOffset; 59 | sharedInputSizeX = info.inputSizeX; 60 | sharedOutputOffset = info.outputOffset; 61 | sharedD = TILE_SIZE_D * workGroupIndex; 62 | } 63 | 64 | barrier(); 65 | memoryBarrierShared(); 66 | 67 | const uint xSlice = sharedXSlice; 68 | const uint xRest = sharedXRest; 69 | const uint xStart = (threadIndex * xSlice + min(threadIndex, xRest)) * TILE_SIZE_X; 70 | const uint xEnd = xStart + (xSlice + (threadIndex < xRest ? 1 : 0)) * TILE_SIZE_X; 71 | 72 | const uint inputOffset = sharedInputOffset; 73 | const uint inputSizeX = sharedInputSizeX; 74 | const uint outputOffset = sharedOutputOffset; 75 | const uint d = sharedD; 76 | 77 | f16vec4 xTemp[Q80_Q40_BLOCK_SIZE / 4]; 78 | 79 | for (uint dt = 0; dt < TILE_SIZE_D; dt++) { 80 | sums[threadIndex * TILE_SIZE_D + dt] = float16_t(0.0f); 81 | } 82 | 83 | for (uint i = xStart; i < xEnd; i += TILE_SIZE_X) { 84 | [[unroll]] for (uint it = 0; it < TILE_SIZE_X; it++) { 85 | const uint xi = inputOffset + i + it; 86 | const float16_t xScale = x[xi].d; 87 | [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) { 88 | xTemp[j] = f16vec4( 89 | x[xi].qs[j * 2], 90 | x[xi].qs[j * 2 + Q80_Q40_BLOCK_SIZE / 2], 91 | x[xi].qs[j * 2 + 1], 92 | x[xi].qs[j * 2 + 1 + Q80_Q40_BLOCK_SIZE / 2] 93 | ); 94 | } 95 | 96 | [[unroll]] for (uint dt = 0; dt < TILE_SIZE_D; dt++) { 97 | const uint wi = (d + dt) * inputSizeX + (i + it); 98 | const BlockQ40 wBlock = weight[wi]; 99 | 100 | float16_t s = float16_t(0); 101 | [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) { 102 | uint w0 = wBlock.qs[j * 2]; 103 | uint w1 = wBlock.qs[j * 2 + 1]; 104 | ivec4 w = ivec4( 105 | w0 & 0xFu, 106 | w0 >> 4, 107 | w1 & 0xFu, 108 | w1 >> 4 109 | ) - ivec4(8); 110 | s += dot(xTemp[j], f16vec4(w)); 111 | } 112 | sums[threadIndex * TILE_SIZE_D + dt] += s * xScale * wBlock.d; 113 | } 114 | } 115 | } 116 | 117 | barrier(); 118 | memoryBarrierShared(); 119 | 120 | [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) { 121 | for (uint dt = 0; dt < TILE_SIZE_D; dt++) { 122 | if (threadIndex < i) { 123 | sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt]; 124 | } 125 | } 126 | barrier(); 127 | } 128 | for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) { 129 | y[outputOffset + d + dt] = float(sums[dt]); 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/nn/vulkan/merge-add-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | 
18 | shared uint sharedDim; 19 | shared uint sharedOutputSizeX; 20 | shared uint sharedParts; 21 | shared uint sharedXOffset; 22 | shared uint sharedYOffset; 23 | 24 | void main() { 25 | const uint threadIndex = gl_LocalInvocationID.x; 26 | 27 | if (threadIndex == 0) { 28 | const uint nWorkGroups = gl_NumWorkGroups.z; 29 | const uint batchIndex = gl_WorkGroupID.y; 30 | const uint workGroupIndex = gl_WorkGroupID.z; 31 | 32 | const BatchInfo info = infos[batchIndex]; 33 | sharedDim = info.outputSizeX / nWorkGroups; 34 | sharedOutputSizeX = info.outputSizeX; 35 | sharedParts = info.inputSizeX / info.outputSizeX; 36 | sharedXOffset = info.inputOffset + sharedDim * workGroupIndex; 37 | sharedYOffset = info.outputOffset + sharedDim * workGroupIndex; 38 | } 39 | 40 | barrier(); 41 | memoryBarrierShared(); 42 | 43 | const uint dim = sharedDim; 44 | const uint outputSizeX = sharedOutputSizeX; 45 | const uint parts = sharedParts; 46 | const uint xOffset = sharedXOffset; 47 | const uint yOffset = sharedYOffset; 48 | 49 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 50 | float sum = 0.0; 51 | const uint iOffset = xOffset + i; 52 | const uint oOffset = yOffset + i; 53 | for (uint n = 0; n < parts; n++) { 54 | sum += x[n * outputSizeX + iOffset]; 55 | } 56 | y[oOffset] += sum; 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/nn/vulkan/merge-add-forward-q80-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #extension GL_EXT_control_flow_attributes : enable 4 | #extension GL_EXT_shader_16bit_storage : enable 5 | #extension GL_EXT_shader_explicit_arithmetic_types : enable 6 | 7 | #define Q80_BLOCK_SIZE 32 8 | #define N_THREADS 256 9 | 10 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 11 | 12 | struct BatchInfo { 13 | uint inputOffset; // number of Q80 blocks 14 | uint inputSizeX; // number of Q80 blocks 15 | uint outputOffset; 16 | uint outputSizeX; 17 | }; 18 | 19 | struct BlockQ80 { 20 | float16_t d; 21 | int8_t qs[Q80_BLOCK_SIZE]; 22 | }; 23 | 24 | layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; }; 25 | layout(binding = 1) buffer outputBuffer { float y[]; }; 26 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 27 | 28 | shared uint sharedXStart; 29 | shared uint sharedXEnd; 30 | shared uint sharedNParts; 31 | shared uint sharedXJump; 32 | shared uint sharedXOffset; 33 | shared uint sharedYOffset; 34 | 35 | void main() { 36 | const uint threadIndex = gl_LocalInvocationID.x; 37 | 38 | if (threadIndex == 0) { 39 | const uint nWorkGroups = gl_NumWorkGroups.z; 40 | const uint batchIndex = gl_WorkGroupID.y; 41 | const uint workGroupIndex = gl_WorkGroupID.z; 42 | 43 | const BatchInfo info = infos[batchIndex]; 44 | const uint xJump = info.outputSizeX / Q80_BLOCK_SIZE; 45 | const uint nParts = info.inputSizeX / xJump; 46 | const uint xSlice = xJump / nWorkGroups; 47 | const uint xRest = xJump % nWorkGroups; 48 | 49 | sharedXStart = workGroupIndex * xSlice + (workGroupIndex < xRest ? workGroupIndex : xRest); 50 | sharedXEnd = sharedXStart + xSlice + (workGroupIndex < xRest ? 
1 : 0); 51 | sharedNParts = nParts; 52 | sharedXJump = xJump; 53 | sharedXOffset = info.inputOffset; 54 | sharedYOffset = info.outputOffset; 55 | } 56 | 57 | barrier(); 58 | memoryBarrierShared(); 59 | 60 | const uint xStart = sharedXStart + threadIndex; 61 | const uint xEnd = sharedXEnd; 62 | const uint xJump = sharedXJump; 63 | const uint nParts = sharedNParts; 64 | const uint xOffset = sharedXOffset; 65 | const uint yOffset = sharedYOffset; 66 | float16_t sums[Q80_BLOCK_SIZE]; 67 | 68 | for (uint i = xStart; i < xEnd; i += N_THREADS) { 69 | const uint xiOffset = xOffset + i; 70 | const uint yiOffset = yOffset + i * Q80_BLOCK_SIZE; 71 | 72 | [[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) { 73 | sums[k] = float16_t(0.0); 74 | } 75 | for (uint n = 0; n < nParts; n++) { 76 | const BlockQ80 b = x[xiOffset + n * xJump]; 77 | const float16_t d = b.d; 78 | 79 | [[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) { 80 | sums[k] += float16_t(b.qs[k]) * d; 81 | } 82 | } 83 | 84 | [[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) { 85 | y[yiOffset + k] += float(sums[k]); 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/nn/vulkan/mul-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | layout(binding = 3) readonly uniform configBuffer { 18 | uint multiplierBufferIndex; 19 | }; 20 | layout(binding = 4) readonly buffer multiplierBuffer { float m[]; }; 21 | 22 | shared uint sharedDim; 23 | shared uint sharedXyOffset; 24 | shared uint sharedMOffset; 25 | 26 | void main() { 27 | const uint threadIndex = gl_LocalInvocationID.x; 28 | 29 | if (threadIndex == 0) { 30 | const uint nWorkGroups = gl_NumWorkGroups.z; 31 | const uint batchIndex = gl_WorkGroupID.y; 32 | const uint workGroupIndex = gl_WorkGroupID.z; 33 | 34 | const BatchInfo info = infos[batchIndex]; 35 | sharedDim = info.inputSizeX / nWorkGroups; 36 | sharedXyOffset = info.inputOffset + sharedDim * workGroupIndex; 37 | sharedMOffset = info.inputSizeX * batchIndex + sharedDim * workGroupIndex; 38 | } 39 | 40 | barrier(); 41 | memoryBarrierShared(); 42 | 43 | const uint dim = sharedDim; 44 | const uint xyOffset = sharedXyOffset; 45 | const uint mOffset = sharedMOffset; 46 | 47 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 48 | y[xyOffset + i] = x[xyOffset + i] * m[mOffset + i]; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/nn/vulkan/multi-head-att-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #extension GL_EXT_control_flow_attributes : enable 4 | 5 | #define N_THREADS 256 6 | 7 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 8 | 9 | struct BatchInfo { 10 | uint inputOffset; 11 | uint inputSizeX; 12 | uint outputOffset; 13 | uint outputSizeX; 14 | }; 15 | 16 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 17 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; 
}; 18 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 19 | layout(binding = 3) readonly uniform configBuffer { 20 | uint nHeads; 21 | uint nHeads0; 22 | uint nKvHeads; 23 | uint headSize; 24 | uint seqLen; 25 | uint qSliceD0; 26 | uint kvDim0; 27 | // uint positionPipeIndex; 28 | // uint queryBufferIndex; 29 | // uint keyCacheBufferIndex; 30 | // uint valueCacheBufferIndex; 31 | // uint attBufferIndex; 32 | }; 33 | layout(binding = 4) readonly buffer positionsBuffer { float positions[]; }; 34 | layout(binding = 5) readonly buffer queryBuffer { float query[]; }; 35 | layout(binding = 6) readonly buffer keyCacheBuffer { float keyCache[]; }; 36 | layout(binding = 7) readonly buffer valueCacheBuffer { float valueCache[]; }; 37 | layout(binding = 8) buffer attBufferBuffer { float att[]; }; 38 | 39 | shared BatchInfo sharedInfo; 40 | shared uint position; 41 | shared float sharedMaxScore; 42 | shared float temp[N_THREADS]; 43 | 44 | void main() { 45 | const uint threadIndex = gl_LocalInvocationID.x; 46 | const uint batchIndex = gl_WorkGroupID.y; 47 | const uint h = gl_WorkGroupID.z; 48 | 49 | const uint kvMul = nHeads / nKvHeads; 50 | const uint headIndex = h / kvMul; 51 | const float invHeadSizeRoot = 1.0 / sqrt(float(headSize)); 52 | 53 | 54 | if (threadIndex == 0) { 55 | sharedInfo = infos[batchIndex]; 56 | position = uint(positions[batchIndex]); 57 | } 58 | 59 | barrier(); 60 | memoryBarrierShared(); 61 | 62 | const uint attOffset = batchIndex * nHeads0 * seqLen + h * seqLen; 63 | const uint qOffset = batchIndex * qSliceD0 + h * headSize; 64 | const uint kvOffset = headIndex * headSize; 65 | const uint yOffset = sharedInfo.outputOffset + h * headSize; 66 | 67 | float ms = -1e10f; 68 | for (uint p = threadIndex; p <= position; p += N_THREADS) { 69 | const uint kOffset = kvOffset + p * kvDim0; 70 | 71 | float score = 0.0; 72 | for (uint i = 0; i < headSize; i++) { 73 | score += query[qOffset + i] * keyCache[kOffset + i]; 74 | } 75 | score *= invHeadSizeRoot; 76 | ms = max(ms, score); 77 | att[attOffset + p] = score; 78 | } 79 | 80 | temp[threadIndex] = ms; 81 | 82 | barrier(); 83 | memoryBarrierShared(); 84 | 85 | [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) { 86 | if (threadIndex < i) 87 | temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]); 88 | barrier(); 89 | } 90 | 91 | if (threadIndex == 0) { 92 | sharedMaxScore = temp[0]; 93 | } 94 | 95 | barrier(); 96 | memoryBarrierShared(); 97 | 98 | const float maxScore = sharedMaxScore; 99 | 100 | float s = 0.0; 101 | for (uint p = threadIndex; p <= position; p += N_THREADS) { 102 | float v = exp(att[attOffset + p] - maxScore); 103 | att[attOffset + p] = v; 104 | s += v; 105 | } 106 | 107 | temp[threadIndex] = s; 108 | barrier(); 109 | 110 | [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) { 111 | if (threadIndex < i) 112 | temp[threadIndex] += temp[threadIndex + i]; 113 | barrier(); 114 | } 115 | 116 | const float yScale = 1.0 / temp[0]; 117 | 118 | for (uint i = threadIndex; i < headSize; i += N_THREADS) { 119 | float sum = 0.0; 120 | const uint vOffset = kvOffset + i; 121 | for (uint p = 0; p <= position; p += 1) { 122 | const float a = att[attOffset + p]; 123 | const float v = valueCache[vOffset + p * kvDim0]; 124 | sum += v * a; 125 | } 126 | y[yOffset + i] = sum * yScale; 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp: 
-------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | layout(binding = 3) readonly buffer weightBuffer { float weight[]; }; 18 | layout(binding = 4) readonly uniform configBuffer { 19 | uint invRmsBufferIndex; // not used 20 | }; 21 | layout(binding = 5) readonly buffer invRmsBuffer { float invRms[]; }; 22 | 23 | shared uint sharedDim; 24 | shared uint sharedDimOffset; 25 | shared uint sharedXOffset; 26 | shared uint sharedYOffset; 27 | shared float sharedS; 28 | 29 | void main() { 30 | const uint threadIndex = uint(gl_LocalInvocationID.x); 31 | 32 | if (threadIndex == 0) { 33 | const uint nWorkGroups = gl_NumWorkGroups.z; 34 | const uint batchIndex = gl_WorkGroupID.y; 35 | const uint workGroupIndex = gl_WorkGroupID.z; 36 | 37 | const BatchInfo info = infos[batchIndex]; 38 | sharedDim = info.inputSizeX / nWorkGroups; 39 | sharedDimOffset = sharedDim * workGroupIndex; 40 | sharedXOffset = info.inputOffset + sharedDimOffset; 41 | sharedYOffset = info.outputOffset + sharedDimOffset; 42 | sharedS = invRms[batchIndex]; 43 | } 44 | 45 | barrier(); 46 | memoryBarrierShared(); 47 | 48 | const uint dim = sharedDim; 49 | const uint dimOffset = sharedDimOffset; 50 | const uint xOffset = sharedXOffset; 51 | const uint yOffset = sharedYOffset; 52 | const float s = sharedS; 53 | 54 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 55 | y[yOffset + i] = (x[xOffset + i] * s) * weight[i + dimOffset]; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/nn/vulkan/rope-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | struct RopeSlice { 15 | uint qDim0; 16 | uint qDimStart; 17 | uint qDimEnd; 18 | uint qShift; 19 | uint kvDim; 20 | uint kvDim0; 21 | uint kvDimStart; 22 | uint sliceDim; 23 | uint seqLen; 24 | uint headSize; 25 | uint nKvHeads; 26 | float ropeTheta; 27 | // NnSize2D cacheSize; 28 | }; 29 | 30 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 31 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 32 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 33 | layout(binding = 3) readonly uniform configBuffer { 34 | bool isQ; 35 | uint positionPipeIndex; 36 | uint ropeCacheBufferIndex; 37 | float ropeScalingFactor; 38 | float ropeScalingLowFreqFactor; 39 | float ropeScalingHighFreqFactor; 40 | uint ropeScalingOrigMaxSeqLen; 41 | RopeSlice slice; 42 | }; 43 | layout(binding = 4) readonly buffer positionsBuffer { float positions[]; }; 44 | layout(binding = 5) readonly buffer ropeCacheBuffer { float ropeCache[]; }; 45 | 46 | shared uint sharedOffset; 47 | shared BatchInfo sharedInfo; 48 | 49 | void main() { 50 | const uint threadIndex = gl_LocalInvocationID.x; 51 | const uint 
batchIndex = gl_GlobalInvocationID.y; 52 | 53 | if (threadIndex == 0) { 54 | uint position = uint(positions[batchIndex]); 55 | 56 | sharedOffset = position * slice.sliceDim; 57 | if (isQ) { 58 | sharedOffset += slice.qShift; 59 | } 60 | sharedInfo = infos[batchIndex]; 61 | } 62 | 63 | barrier(); 64 | memoryBarrierShared(); 65 | 66 | const uint dim0Half = (isQ ? slice.qDim0 : slice.kvDim0) / 2; 67 | const uint xOffset = sharedInfo.inputOffset; 68 | const uint yOffset = sharedInfo.outputOffset; 69 | 70 | for (uint i = threadIndex; i < dim0Half; i += N_THREADS) { 71 | const uint j = i * 2; 72 | const uint c = sharedOffset + j; 73 | 74 | float fcr = ropeCache[c]; 75 | float fci = ropeCache[c + 1]; 76 | float v0 = x[xOffset + j]; 77 | float v1 = x[xOffset + j + 1]; 78 | 79 | float x0 = v0 * fcr - v1 * fci; 80 | float x1 = v0 * fci + v1 * fcr; 81 | 82 | y[yOffset + j] = x0; 83 | y[yOffset + j + 1] = x1; 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/nn/vulkan/shift-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | layout(binding = 3) readonly uniform configBuffer { 18 | uint indexPipeIndex; 19 | }; 20 | layout(binding = 4) readonly buffer indexBuffer { float indexes[]; }; 21 | 22 | shared uint sharedDim; 23 | shared uint sharedXOffset; 24 | shared uint sharedYOffset; 25 | 26 | void main() { 27 | const uint threadIndex = gl_LocalInvocationID.x; 28 | 29 | if (threadIndex == 0) { 30 | const uint nWorkGroups = gl_NumWorkGroups.z; 31 | const uint batchIndex = gl_WorkGroupID.y; 32 | const uint workGroupIndex = gl_WorkGroupID.z; 33 | 34 | const uint index = uint(indexes[batchIndex]); 35 | BatchInfo info = infos[batchIndex]; 36 | sharedDim = info.inputSizeX / nWorkGroups; 37 | const uint dimOffset = sharedDim * workGroupIndex; 38 | sharedXOffset = info.inputOffset + dimOffset; 39 | sharedYOffset = index * info.inputSizeX + dimOffset; 40 | } 41 | 42 | barrier(); 43 | memoryBarrierShared(); 44 | 45 | const uint dim = sharedDim; 46 | const uint xOffset = sharedXOffset; 47 | const uint yOffset = sharedYOffset; 48 | 49 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 50 | y[yOffset + i] = x[xOffset + i]; 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/nn/vulkan/silu-forward-f32-f32.comp: -------------------------------------------------------------------------------- 1 | #version 450 2 | 3 | #define N_THREADS 256 4 | 5 | layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in; 6 | 7 | struct BatchInfo { 8 | uint inputOffset; 9 | uint inputSizeX; 10 | uint outputOffset; 11 | uint outputSizeX; 12 | }; 13 | 14 | layout(binding = 0) readonly buffer inputBuffer { float x[]; }; 15 | layout(binding = 1) writeonly buffer outputBuffer { float y[]; }; 16 | layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; }; 17 | 18 | shared uint sharedDim; 19 | shared uint sharedXOffset; 20 | shared uint sharedYOffset; 21 | 22 | void main() { 23 | 
const uint threadIndex = gl_LocalInvocationID.x; 24 | 25 | if (threadIndex == 0) { 26 | const uint nWorkGroups = gl_NumWorkGroups.z; 27 | const uint batchIndex = gl_WorkGroupID.y; 28 | const uint workGroupIndex = gl_WorkGroupID.z; 29 | 30 | const BatchInfo info = infos[batchIndex]; 31 | sharedDim = info.inputSizeX / nWorkGroups; 32 | sharedXOffset = info.inputOffset + sharedDim * workGroupIndex; 33 | sharedYOffset = info.outputOffset + sharedDim * workGroupIndex; 34 | } 35 | 36 | barrier(); 37 | memoryBarrierShared(); 38 | 39 | const uint dim = sharedDim; 40 | const uint xOffset = sharedXOffset; 41 | const uint yOffset = sharedYOffset; 42 | 43 | for (uint i = threadIndex; i < dim; i += N_THREADS) { 44 | float v = x[xOffset + i]; 45 | y[yOffset + i] = v / (1.0 + exp(-v)); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/tokenizer-test.cpp: -------------------------------------------------------------------------------- 1 | #include <cstdio> 2 | #include <cstring> 3 | #include "tokenizer.hpp" 4 | 5 | #define DEV_TESTS false 6 | 7 | #define ASSERT_EQ(a, b) \ 8 | if (a != b) { \ 9 | printf("Assertion failed: %d != %d (%s:%d)\n", a, b, __FILE__, __LINE__); \ 10 | exit(-1); \ 11 | } 12 | 13 | #define TEST_EOS_ID 10000 14 | 15 | void printOk(const char *name) { 16 | printf("✅ %24s passed\n", name); 17 | } 18 | 19 | void compare(const char *name, const int *a, const int *b, const unsigned int aN, const int bN) { 20 | bool passed = true; 21 | if (aN != bN) { 22 | passed = false; 23 | } else { 24 | for (unsigned int i = 0; i < bN; i++) { 25 | if (a[i] != b[i]) { 26 | passed = false; 27 | break; 28 | } 29 | } 30 | } 31 | if (!passed) { 32 | printf("❌ %24s failed\na: ", name); 33 | for (unsigned int j = 0; j < aN; j++) 34 | printf("%5d ", a[j]); 35 | printf("\nb: "); 36 | for (unsigned int j = 0; j < bN; j++) 37 | printf("%5d ", b[j]); 38 | printf("\n"); 39 | exit(1); 40 | } 41 | printOk(name); 42 | } 43 | 44 | void dev_testEncode(Tokenizer *tokenizer) { 45 | int tokens[1024]; 46 | int nTokens; 47 | 48 | { 49 | const char *text = "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"; 50 | const int expectedTokens[] = {128000, 128006, 882, 128007, 271, 15339, 128009, 128006, 78191, 128007, 271}; 51 | 52 | tokenizer->encode((char *)text, tokens, &nTokens, true, true); 53 | compare("case0", expectedTokens, tokens, 11, nTokens); 54 | } 55 | { 56 | const char *text = "!!&&@(*x)^^!"; 57 | const int expectedTokens[] = {128000, 3001, 7827, 31, 4163, 87, 8, 22634, 0}; 58 | 59 | tokenizer->encode((char *)text, tokens, &nTokens, true, true); 60 | compare("case1", expectedTokens, tokens, 9, nTokens); 61 | } 62 | { 63 | const char *text = "😃!😇x"; 64 | const int expectedTokens[] = {128000, 76460, 225, 0, 76460, 229, 87}; 65 | 66 | tokenizer->encode((char *)text, tokens, &nTokens, true, true); 67 | compare("case2", expectedTokens, tokens, 7, nTokens); 68 | } 69 | } 70 | 71 | void dev_testDecoderEmoji(Tokenizer *tokenizer) { 72 | char *x0 = tokenizer->decode(128000); 73 | assert(x0 == nullptr); 74 | 75 | char *x1 = tokenizer->decode(76460); 76 | assert(x1 == nullptr); 77 | 78 | char *x2 = tokenizer->decode(225); 79 | assert(x2 == nullptr); 80 | 81 | char *x3 = tokenizer->decode(0); 82 | assert(strstr(x3, "😃!") != NULL); 83 | 84 | char *x4 = tokenizer->decode(56); 85 | assert(strstr(x4, "Y") != NULL); 86 | 87 | printOk("testDecoderEmoji"); 88 | } 89 | 90 | void dev_testDecoderEmojiWithEos(Tokenizer *tokenizer) { 91 | char *x0 
= tokenizer->decode(128000); 92 | char *x1 = tokenizer->decode(76460); 93 | char *x2 = tokenizer->decode(225); 94 | char *x3 = tokenizer->decode(128001); 95 | 96 | assert(x0 == nullptr); 97 | assert(x1 == nullptr); 98 | assert(x2 == nullptr); 99 | assert(strstr(x3, "😃") != NULL); // piece should not contain <|end_of_text|> 100 | printOk("decoderEmojiWithEos"); 101 | } 102 | 103 | void testChatTemplateDetection() { 104 | ChatTemplateGenerator t0(TEMPLATE_UNKNOWN, "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", ""); 105 | assert(t0.type == TEMPLATE_LLAMA3); 106 | 107 | printOk("chatTemplateDetection"); 108 | } 109 | 110 | void testEosDetectorWithPadding() { 111 | const int tokens[2] = {TEST_EOS_ID, TEST_EOS_ID + 1}; 112 | const char *pieces[2] = { "<eos>", "<stop>" }; 113 | EosDetector detector(2, tokens, pieces, 1, 1); 114 | 115 | // "<eos>" 116 | { 117 | ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS); 118 | ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS); 119 | ASSERT_EQ(detector.append(3, "s>"), EOS); 120 | assert(detector.getDelta() == nullptr); 121 | } 122 | 123 | // "<stop> " 124 | detector.reset(); 125 | { 126 | ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS); 127 | ASSERT_EQ(detector.append(2, "stop"), MAYBE_EOS); 128 | ASSERT_EQ(detector.append(3, "> "), EOS); 129 | assert(detector.getDelta() == nullptr); 130 | } 131 | 132 | // " " 133 | detector.reset(); 134 | { 135 | ASSERT_EQ(detector.append(1, " "), NOT_EOS); 136 | 137 | char *delta = detector.getDelta(); 138 | assert(delta != nullptr); 139 | assert(std::strcmp(delta, " ") == 0); 140 | } 141 | 142 | // "!<eos> " 143 | detector.reset(); 144 | { 145 | ASSERT_EQ(detector.append(1, "!<"), MAYBE_EOS); 146 | ASSERT_EQ(detector.append(2, "eos"), MAYBE_EOS); 147 | ASSERT_EQ(detector.append(3, "> "), EOS); 148 | 149 | char *delta = detector.getDelta(); 150 | assert(delta != nullptr); 151 | assert(std::strcmp(delta, "!") == 0); 152 | } 153 | 154 | // "! 
" 155 | detector.reset(); 156 | { 157 | ASSERT_EQ(detector.append(1, "XY"), NOT_EOS); 159 | 160 | char *delta = detector.getDelta(); 161 | assert(delta != nullptr); 162 | assert(std::strcmp(delta, "XY") == 0); 163 | } 164 | 165 | // "" }; 238 | EosDetector detector(1, tokens, pieces, 0, 0); 239 | 240 | // "" 241 | { 242 | ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS); 243 | ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS); 244 | ASSERT_EQ(detector.append(3, "s>"), EOS); 245 | assert(detector.getDelta() == nullptr); 246 | } 247 | 248 | // " <" 249 | detector.reset(); 250 | { 251 | ASSERT_EQ(detector.append(1, " <"), NOT_EOS); 252 | char *delta = detector.getDelta(); 253 | assert(delta != nullptr); 254 | assert(std::strcmp(delta, " <") == 0); 255 | } 256 | 257 | // " " 258 | detector.reset(); 259 | { 260 | ASSERT_EQ(detector.append(1, " "), NOT_EOS); 262 | char *delta = detector.getDelta(); 263 | assert(delta != nullptr); 264 | assert(std::strcmp(delta, " ") == 0); 265 | } 266 | 267 | // EOS 268 | detector.reset(); 269 | { 270 | ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS); 271 | assert(detector.getDelta() == nullptr); 272 | } 273 | 274 | // emoji 275 | detector.reset(); 276 | { 277 | ASSERT_EQ(detector.append(TEST_EOS_ID, "😃"), EOS); 278 | char *delta = detector.getDelta(); 279 | assert(delta != nullptr); 280 | assert(std::strcmp(delta, "😃") == 0); 281 | } 282 | 283 | printOk("eosDetectorWithLongPadding"); 284 | } 285 | 286 | int main() { 287 | #if DEV_TESTS 288 | Tokenizer tokenizer("models/llama3_2_1b_instruct_q40/dllama_tokenizer_llama3_2_1b_instruct_q40.t"); 289 | dev_testEncode(&tokenizer); 290 | dev_testDecoderEmoji(&tokenizer); 291 | dev_testDecoderEmojiWithEos(&tokenizer); 292 | #endif 293 | 294 | testChatTemplateDetection(); 295 | testEosDetectorWithPadding(); 296 | testEosDetectorWithLongPadding(); 297 | testEosDetectorWithoutPadding(); 298 | return 0; 299 | } 300 | -------------------------------------------------------------------------------- /src/tokenizer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TOKENIZER_HPP 2 | #define TOKENIZER_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | typedef struct { 9 | char *str; 10 | unsigned int id; 11 | } TokenIndex; 12 | 13 | struct TokenizerOldHeader { 14 | unsigned int vocabSize; 15 | unsigned int maxTokenLength; 16 | int bosId; 17 | int eosId; 18 | int padId; 19 | }; 20 | 21 | enum TokenizerHeaderKey { 22 | TOK_VERSION = 0, 23 | TOK_VOCAB_SIZE = 1, 24 | MAX_TOKEN_LENGTH = 2, 25 | BOS_ID = 3, 26 | EOS_ID = 4, 27 | PAD_ID = 5, 28 | CHAT_EOS_ID = 6, 29 | CHAT_TEMPLATE = 7, 30 | CHAT_STOP = 8, 31 | }; 32 | 33 | class Tokenizer { 34 | private: 35 | unsigned int maxTokenLength; 36 | unsigned int regularVocabSize; 37 | unsigned int specialVocabSize; 38 | float *vocabScores; 39 | unsigned int *vocabLength; 40 | TokenIndex *regularVocab; 41 | TokenIndex *specialVocab; 42 | size_t strBufferSize; 43 | char *strBuffer; 44 | size_t strBufferPos; 45 | 46 | 47 | public: 48 | std::vector eosTokenIds; 49 | unsigned int vocabSize; 50 | char **vocab; 51 | int bosId; 52 | char *chatTemplate; 53 | 54 | Tokenizer(const char *tokenizer_path); 55 | ~Tokenizer(); 56 | int findSpecialTokenStartWith(char *piece); 57 | int findRegularToken(char *piece); 58 | void encode(char *text, int *tokens, int *nTokens, bool addBos, bool addSpecialTokens); 59 | bool isEos(int token); 60 | char *decode(int token); 61 | void resetDecoder(); 62 | }; 63 | 64 | typedef struct { 65 | float prob; 66 | int index; 67 
| } ProbIndex; 68 | 69 | class Sampler { 70 | private: 71 | int vocab_size; 72 | ProbIndex *probindex; 73 | float temperature; 74 | float topp; 75 | unsigned long long rngState; 76 | 77 | public: 78 | Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed); 79 | ~Sampler(); 80 | int sample(float *logits); 81 | void setTemp(float temp); 82 | void setSeed(unsigned long long rngSeed); 83 | }; 84 | 85 | class TokenizerChatStops { 86 | public: 87 | const char **stops; 88 | size_t nStops; 89 | size_t maxStopLength; 90 | TokenizerChatStops(Tokenizer *tokenizer); 91 | ~TokenizerChatStops(); 92 | }; 93 | 94 | enum ChatTemplateType { 95 | TEMPLATE_UNKNOWN = 0, 96 | TEMPLATE_LLAMA2 = 1, 97 | TEMPLATE_LLAMA3 = 2, 98 | TEMPLATE_DEEP_SEEK3 = 3, 99 | }; 100 | 101 | struct ChatItem { 102 | std::string role; 103 | std::string message; 104 | }; 105 | 106 | struct GeneratedChat { 107 | const char *content; 108 | size_t length; 109 | const char *publicPrompt; 110 | }; 111 | 112 | class ChatTemplateGenerator { 113 | public: 114 | const char *eos; 115 | ChatTemplateType type; 116 | std::string buffer; 117 | ChatTemplateGenerator(const ChatTemplateType type, const char *chatTemplate, const char *eos); 118 | GeneratedChat generate(unsigned int nItems, ChatItem *items, bool appendGenerationPrompt); 119 | }; 120 | 121 | enum EosDetectorType { 122 | MAYBE_EOS = 0, 123 | EOS = 1, 124 | NOT_EOS = 2, 125 | }; 126 | 127 | class EosDetector { 128 | private: 129 | size_t nTokens; 130 | const int *tokens; 131 | const char **pieces; 132 | size_t *pieceSizes; 133 | size_t bufferPos; 134 | size_t bufferSize; 135 | int eosPos; 136 | int paddingLeft; 137 | int paddingRight; 138 | public: 139 | char *buffer; 140 | EosDetector(size_t nTokens, const int *tokens, const char* *pieces, int paddingLeft, int paddingRight); 141 | ~EosDetector(); 142 | 143 | EosDetectorType append(int tokenId, const char *piece); 144 | bool isEos(int tokenId); 145 | char *getDelta(); 146 | void reset(); 147 | }; 148 | 149 | #endif 150 | --------------------------------------------------------------------------------
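
A note on the quantized matmul shader listed above: matmul-forward-q80-q40-f32.comp computes each output element as a sum of per-block dot products between a Q8_0 input block and a Q4_0 weight block, then applies both block scales once per block. The following sketch is not a file from this repository; it shows the equivalent scalar computation for a single block pair, assuming the layout declared in the shader structs (32 elements per block, the low nibble of weight byte i holding element i and the high nibble holding element i + 16, quantized Q4_0 values stored with an offset of 8, one scale per block). The names BlockQ80, BlockQ40 and dotQ80Q40 are illustrative only.

#include <cstdint>

constexpr int BLOCK_SIZE = 32;     // Q80_Q40_BLOCK_SIZE in the shader

struct BlockQ80 {                  // mirrors the shader's BlockQ80 (scale is float16 there)
    float d;                       // block scale
    int8_t qs[BLOCK_SIZE];         // 32 signed 8-bit values
};

struct BlockQ40 {                  // mirrors the shader's BlockQ40
    float d;                       // block scale
    uint8_t qs[BLOCK_SIZE / 2];    // 32 values packed as 4-bit nibbles
};

// Scalar dot product of one Q8_0 input block with one Q4_0 weight block.
float dotQ80Q40(const BlockQ80 &x, const BlockQ40 &w) {
    int32_t sum = 0;
    for (int i = 0; i < BLOCK_SIZE / 2; i++) {
        const int w0 = (w.qs[i] & 0x0F) - 8;       // element i
        const int w1 = (w.qs[i] >> 4) - 8;         // element i + 16
        sum += int(x.qs[i]) * w0;
        sum += int(x.qs[i + BLOCK_SIZE / 2]) * w1;
    }
    return float(sum) * x.d * w.d;                 // apply both block scales once
}

Summing dotQ80Q40 over the inputSizeX blocks of a weight row d should reproduce, up to float16 rounding in the GPU path, the value the shader writes to y[outputOffset + d].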