├── .github ├── 8raspi.jpg ├── 8raspi2.jpg ├── cover.png └── workflows │ └── main.yml ├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── Makefile ├── README.md ├── converter ├── .gitignore ├── convert-hf.py ├── convert-llama.py ├── convert-tokenizer-hf.py ├── convert-tokenizer-llama2.py ├── convert-tokenizer-llama3.py ├── requirements.txt ├── tokenizer-writer.py ├── writer-test.py └── writer.py ├── docs ├── HUGGINGFACE.md └── LLAMA.md ├── examples ├── chat-api-client.js ├── macbeth.sh └── n-workers.sh ├── launch.py ├── report └── report.pdf └── src ├── app.cpp ├── app.hpp ├── apps ├── dllama-api │ ├── README.md │ ├── dllama-api.cpp │ └── types.hpp ├── dllama │ └── dllama.cpp └── socket-benchmark │ └── socket-benchmark.cpp ├── commands-test.cpp ├── commands.cpp ├── commands.hpp ├── common ├── json.hpp └── pthread.h ├── funcs-test.cpp ├── funcs.cpp ├── funcs.hpp ├── grok1-tasks-test.cpp ├── grok1-tasks.cpp ├── grok1-tasks.hpp ├── llama2-tasks-test.cpp ├── llama2-tasks.cpp ├── llama2-tasks.hpp ├── mixtral-tasks.cpp ├── mixtral-tasks.hpp ├── quants-test.cpp ├── quants.cpp ├── quants.hpp ├── socket.cpp ├── socket.hpp ├── tasks.cpp ├── tasks.hpp ├── tokenizer-test.cpp ├── tokenizer.cpp ├── tokenizer.hpp ├── transformer.cpp ├── transformer.hpp ├── utils.cpp └── utils.hpp /.github/8raspi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fairydreaming/distributed-llama/424b63fdd5343fdf51352c5d55842290ca4c9b5d/.github/8raspi.jpg -------------------------------------------------------------------------------- /.github/8raspi2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fairydreaming/distributed-llama/424b63fdd5343fdf51352c5d55842290ca4c9b5d/.github/8raspi2.jpg -------------------------------------------------------------------------------- /.github/cover.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fairydreaming/distributed-llama/424b63fdd5343fdf51352c5d55842290ca4c9b5d/.github/cover.png -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: main 2 | on: 3 | pull_request: 4 | branches: 5 | - main 6 | push: 7 | branches: 8 | - main 9 | jobs: 10 | build-linux: 11 | name: Linux 12 | runs-on: ${{matrix.os}} 13 | strategy: 14 | matrix: 15 | os: 16 | - ubuntu-latest 17 | platforms: 18 | - linux/arm64 19 | - linux/amd64 20 | steps: 21 | - name: Checkout Repo 22 | uses: actions/checkout@v3 23 | - name: Dependencies 24 | id: dependencies 25 | run: sudo apt-get update && sudo apt-get install build-essential 26 | - name: Build 27 | id: build 28 | run: | 29 | make dllama 30 | make dllama-api 31 | make funcs-test 32 | make quants-test 33 | make tokenizer-test 34 | make commands-test 35 | make llama2-tasks-test 36 | make grok1-tasks-test 37 | - name: funcs-test 38 | run: ./funcs-test 39 | - name: quants-test 40 | run: ./quants-test 41 | - name: tokenizer-test 42 | run: ./tokenizer-test 43 | - name: commands-test 44 | run: ./commands-test 45 | - name: llama2-tasks-test 46 | run: ./llama2-tasks-test 47 | - name: grok1-tasks-test 48 | run: ./grok1-tasks-test 49 | 50 | build-windows: 51 | name: Windows 52 | runs-on: windows-latest 53 | steps: 54 | - name: Checkout Repo 55 | uses: actions/checkout@v3 56 | - name: Dependencies 57 | id: dependencies 58 | run: choco install make 59 | - name: Build 60 | id: build 61 | run: | 62 | make dllama 63 | make dllama-api 64 | make funcs-test 65 | make quants-test 66 | make tokenizer-test 67 | make commands-test 68 | make llama2-tasks-test 69 | make grok1-tasks-test 70 | - name: funcs-test 71 | run: ./funcs-test 72 | - name: quants-test 73 | run: ./quants-test 74 | - name: 
tokenizer-test 75 | run: ./tokenizer-test 76 | - name: commands-test 77 | run: ./commands-test 78 | - name: llama2-tasks-test 79 | run: ./llama2-tasks-test 80 | - name: grok1-tasks-test 81 | run: ./grok1-tasks-test 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/settings.json 2 | 3 | *.o 4 | *.0 5 | *.dSYM 6 | *.data 7 | *.temp 8 | __pycache__ 9 | 10 | *-test 11 | /socket-benchmark 12 | main 13 | run*.sh 14 | server 15 | /dllama 16 | /dllama-* 17 | *.exe -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "main", 6 | "type": "cppdbg", 7 | "request": "launch", 8 | "program": "${workspaceFolder}/main", 9 | "args": [], 10 | "stopAtEntry": false, 11 | "cwd": "${workspaceFolder}", 12 | "environment": [], 13 | "externalConsole": false, 14 | "MIMode": "lldb" 15 | } 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2024 Bartłomiej Tadych (b4rtaz) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
# Compile the transformer translation unit into transformer.o.
# The prerequisite must be the file the recipe actually compiles.
transformer: src/transformer.cpp
	$(CXX) $(CXXFLAGS) -c src/transformer.cpp -o transformer.o
# Build the quantization test binary.
# The source prerequisite is the test file the recipe compiles (src/quants-test.cpp).
quants-test: src/quants-test.cpp utils quants
	$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS)
def permute(tensor, nHeads: int, nKvHeads: int):
    """Reorder the rows of `tensor` by swapping two inner axes of a grouped view.

    The first dimension is viewed as (groups, 2, rows // groups // 2), the two
    inner axes are swapped, and the result is reshaped back to the input shape.

    Args:
        tensor: Array-like object exposing `shape`, `reshape` and `swapaxes`.
        nHeads: Accepted for interface compatibility; the grouping below is
            always driven by `nKvHeads` (when the two differ the kv count
            wins, and when they are equal they are interchangeable).
        nKvHeads: Number of groups the first dimension is split into.

    Returns:
        An object with the same shape as `tensor` and its rows reordered.
    """
    # NOTE(review): presumably converts between rotary-embedding weight layouts — confirm.
    groupCount = nKvHeads
    rowCount = tensor.shape[0]
    trailingDims = tensor.shape[1:]
    grouped = tensor.reshape(groupCount, 2, rowCount // groupCount // 2, *trailingDims)
    swapped = grouped.swapaxes(1, 2)
    return swapped.reshape(tensor.shape)
def __preparePlan(self):
    """Append the ordered list of tensors to export to `self.plan`.

    Every entry is a list whose first element is the float type to write,
    optionally followed by a transform callable, followed by the tensor name.
    The order of the entries is the order in which tensors are written.
    """
    cfg = self.config
    weightType = cfg['weights_float_type']
    add = self.plan.append

    add([FloatType.F32, 'model.embed_tokens.weight'])
    for layerIndex in range(0, cfg['n_layers']):
        # Attention projections; Q and K carry a permutation transform.
        add([weightType, self.__permuteQ, f'model.layers.{layerIndex}.self_attn.q_proj.weight'])
        add([weightType, self.__permuteK, f'model.layers.{layerIndex}.self_attn.k_proj.weight'])
        add([weightType, f'model.layers.{layerIndex}.self_attn.v_proj.weight'])
        add([weightType, f'model.layers.{layerIndex}.self_attn.o_proj.weight'])

        if cfg['n_experts'] > 0:
            # Mixture-of-experts feed-forward: up, gate, down for each expert.
            for expertIndex in range(cfg['n_experts']):
                add([weightType, f'model.layers.{layerIndex}.block_sparse_moe.experts.{expertIndex}.w3.weight'])  # up
                add([weightType, f'model.layers.{layerIndex}.block_sparse_moe.experts.{expertIndex}.w1.weight'])  # gate
                add([weightType, f'model.layers.{layerIndex}.block_sparse_moe.experts.{expertIndex}.w2.weight'])  # down
        else:
            # Dense feed-forward: gate, down, up.
            add([weightType, f'model.layers.{layerIndex}.mlp.gate_proj.weight'])  # gate
            add([weightType, f'model.layers.{layerIndex}.mlp.down_proj.weight'])  # down
            add([weightType, f'model.layers.{layerIndex}.mlp.up_proj.weight'])  # up

        # Normalisation weights are always exported as F32.
        add([FloatType.F32, f'model.layers.{layerIndex}.input_layernorm.weight'])
        add([FloatType.F32, f'model.layers.{layerIndex}.post_attention_layernorm.weight'])

    add([FloatType.F32, 'model.norm.weight'])
    add([weightType, 'lm_head.weight'])
def parseRopeType(rt: str):
    """Map a rope type name from the model config to its numeric code.

    Args:
        rt: Rope type name as found in the `rope_scaling` section of the config.

    Returns:
        The integer code for the rope type.

    Raises:
        Exception: If `rt` is not a supported rope type.
    """
    ropeType = {
        'llama3': 2, # LLAMA3_1
    }.get(rt)
    if (ropeType is None):
        # Report the rejected input; `ropeType` is always None on this path.
        raise Exception(f'Unsupported rope type: {rt}')
    return ropeType
def printUsage():
    """Print the command-line usage of convert-hf.py to stdout.

    The script takes three positional arguments, in this order: the source
    folder, the weights float type and the model name.
    """
    print('Usage: python convert-hf.py <sourceFolderPath> <weightsFloatType> <name>')
    print()
    print('Options:')
    print('  <sourceFolderPath> The path to the folder containing the model files')
    print('  <weightsFloatType> The float type of the weights (e.g. "q40")')
    print('  <name>             The name of the model (e.g. "llama3")')
32 | layers = [] 33 | layers.append('tok_embeddings.weight') 34 | for layerIndex in range(0, params['n_layers']): 35 | layers.append(f'layers.{layerIndex}.attention.wq.weight') 36 | layers.append(f'layers.{layerIndex}.attention.wk.weight') 37 | layers.append(f'layers.{layerIndex}.attention.wv.weight') 38 | layers.append(f'layers.{layerIndex}.attention.wo.weight') 39 | layers.append(f'layers.{layerIndex}.feed_forward.w1.weight') 40 | layers.append(f'layers.{layerIndex}.feed_forward.w2.weight') 41 | layers.append(f'layers.{layerIndex}.feed_forward.w3.weight') 42 | layers.append(f'layers.{layerIndex}.attention_norm.weight') 43 | layers.append(f'layers.{layerIndex}.ffn_norm.weight') 44 | layers.append('norm.weight') 45 | layers.append('output.weight') 46 | 47 | isHeaderWrote = False 48 | outFile = open(outputPath, 'wb') 49 | 50 | nChunks = math.ceil(len(layers) / LAYER_CHUNK_SIZE) 51 | for chunkIndex in range(0, nChunks): 52 | chunkLayerNames = layers[LAYER_CHUNK_SIZE * chunkIndex:LAYER_CHUNK_SIZE * (chunkIndex + 1)] 53 | models = {} 54 | for layerName in chunkLayerNames: 55 | models[layerName] = [] 56 | 57 | print(f'💿 Chunking model {chunkIndex + 1}/{nChunks}...') 58 | 59 | for modelPath in modelPaths: 60 | model = torch.load(modelPath, map_location='cpu') 61 | for modelKey in model: 62 | if (modelKey in chunkLayerNames): 63 | models[modelKey].append(model[modelKey]) 64 | if not isHeaderWrote: 65 | params['hidden_dim'] = model['layers.0.feed_forward.w1.weight'].shape[0] * nSlices 66 | writeHeader(outFile, params) 67 | isHeaderWrote = True 68 | del model 69 | 70 | for layerName in chunkLayerNames: 71 | if layerName == 'rope.freqs': 72 | continue 73 | 74 | isAxis1 = ( 75 | layerName == 'tok_embeddings.weight' or 76 | layerName.endswith('.attention.wo.weight') or 77 | layerName.endswith('.feed_forward.w2.weight') 78 | ) 79 | isAlwaysF32 = ( 80 | layerName == 'tok_embeddings.weight' or 81 | layerName.endswith('.attention_norm.weight') or 82 | 
def usage():
    """Print the command-line usage of convert-llama.py and exit with status 1.

    The script takes two positional arguments: the model folder and the
    target float type.
    """
    print('Usage: python convert-llama.py <modelPath> <targetFloatType>')
    # sys.exit is always available, unlike the `exit` helper injected by `site`.
    sys.exit(1)
def resolve(self):
    """Run the resolver that matches the configured tokenizer class.

    Reads `tokenizer_class` from `self.tokenizerConfig` and delegates to the
    corresponding `resolve*` method.

    Returns:
        Whatever the selected resolver returns.

    Raises:
        Exception: If the tokenizer class is not supported.
    """
    tokenizerClass = self.tokenizerConfig['tokenizer_class']
    if tokenizerClass == 'PreTrainedTokenizerFast':
        handler = self.resolvePreTrainedTokenizerFast
    elif tokenizerClass == 'LlamaTokenizer':
        handler = self.resolveLlamaTokenizer
    else:
        raise Exception(f'Tokenizer {tokenizerClass} is not supported')
    return handler()
"llama3")') 72 | 73 | if __name__ == '__main__': 74 | if (len(sys.argv) < 2): 75 | printUsage() 76 | exit(1) 77 | 78 | dirPath = sys.argv[1] 79 | name = sys.argv[2] 80 | tokenizerConfig = openJson(os.path.join(dirPath, 'tokenizer_config.json')) 81 | 82 | resolver = TokensResolver(dirPath, tokenizerConfig) 83 | resolver.resolve() 84 | 85 | print(f'bosId: {resolver.bosId} ({resolver.tokens[resolver.bosId]})') 86 | print(f'eosId: {resolver.eosId} ({resolver.tokens[resolver.eosId]})') 87 | 88 | chatTemplate = None 89 | chatExtraStop = None 90 | if ('chat_template' in tokenizerConfig): 91 | chatTemplate = tokenizerConfig['chat_template'].encode('utf-8') 92 | input = input('⏩ Enter value for chat extra stop (enter to skip): ') 93 | if (input != ''): 94 | chatExtraStop = input.encode('utf-8') 95 | 96 | outputFileName = f'dllama_tokenizer_{name}.t' 97 | with open(outputFileName, 'wb') as outputFile: 98 | writer.writeTokenizer(outputFile, { 99 | 'bos_id': resolver.bosId, 100 | 'eos_id': resolver.eosId, 101 | 'chat_eos_id': resolver.eosId, 102 | }, resolver.tokens, resolver.scores, chatTemplate, chatExtraStop) 103 | print(f'✅ Created {outputFileName}') 104 | -------------------------------------------------------------------------------- /converter/convert-tokenizer-llama2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from sentencepiece import SentencePieceProcessor 4 | writer = __import__('tokenizer-writer') 5 | 6 | chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content 
= '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" 7 | 8 | def printUsage(): 9 | print('Usage: python convert-tokenizer-llama2.py ') 10 | print() 11 | print('Options:') 12 | print(' The path to the folder with llama2 folder path') 13 | 14 | if __name__ == '__main__': 15 | if (len(sys.argv) < 2): 16 | printUsage() 17 | exit(1) 18 | 19 | dirPath = sys.argv[1] 20 | modelPath = os.path.join(dirPath, 'tokenizer.model') 21 | processor = SentencePieceProcessor(model_file=modelPath) 22 | 23 | vocabSize = processor.vocab_size() 24 | tokens = [] 25 | scores = [] 26 | for i in range(vocabSize): 27 | t = processor.id_to_piece(i) 28 | s = processor.get_score(i) 29 | t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace 30 | b = t.encode('utf-8') 31 | tokens.append(b) 32 | scores.append(s) 33 | 34 | outputFileName = 'dllama_tokenizer_llama2.t' 35 | with open(outputFileName, 'wb') as outputFile: 36 | writer.writeTokenizer(outputFile, { 37 | 'bos_id': processor.bos_id(), 38 | 'eos_id': processor.eos_id(), 39 | 'chat_eos_id': processor.eos_id(), 40 | }, tokens, scores, chatTemplate.encode('utf-8'), None) 41 | 42 | print(f'✅ Created {outputFileName}') 43 | -------------------------------------------------------------------------------- /converter/convert-tokenizer-llama3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import base64 3 | writer = __import__('tokenizer-writer') 4 | 5 | # Format of input file: 6 | # ``` 7 | # IQ== 0 8 | # Ig== 1 9 | # Iw== 2 10 | # ... 
def printUsage():
    """Print the command-line usage of convert-tokenizer-llama3.py to stdout.

    The script takes a single positional argument: the path to the tokenizer
    model file.
    """
    print('Usage: python convert-tokenizer-llama3.py <tokenizerPath>')
    print()
    print('Options:')
    print('  <tokenizerPath> The path to the Llama 3 tokenizer model (tokenizer.model)')
def writeTokenizer(file, params, tokens, scores, chatTemplate, chatExtraStop):
    """Serialize a tokenizer (header, optional chat data, vocabulary) to `file`.

    Layout written, in order:
      1. magic number (int) and total header size in bytes (int),
      2. one (key id, value) int pair per non-None entry of `params`,
      3. the chat template bytes, if given,
      4. the chat extra-stop bytes, if given,
      5. for each token: score (float), byte length (unsigned int), raw bytes.

    Args:
        file: Binary file-like object with a `write` method.
        params: Dict of header values; must contain non-None `bos_id` and
            `eos_id`. It is updated in place with `version`, `vocab_size`,
            `max_token_length` and, when applicable, `chat_template` and
            `chat_stop` (the byte lengths of those blobs).
        tokens: Sequence of non-empty `bytes` objects.
        scores: Sequence of floats, indexed in parallel with `tokens`.
        chatTemplate: Chat template as bytes, or a falsy value to omit it.
        chatExtraStop: Extra stop sequence as bytes, or a falsy value to omit it.
    """
    assert params['eos_id'] is not None
    assert params['bos_id'] is not None

    # Numeric ids under which each header field is stored.
    keyIds = {
        'version': 0,
        'vocab_size': 1,
        'max_token_length': 2,
        'bos_id': 3,
        'eos_id': 4,
        'pad_id': 5,
        'chat_eos_id': 6,
        'chat_template': 7,
        'chat_stop': 8
    }
    magic = struct.pack('i', 0x567124)

    vocabSize = len(tokens)
    longestToken = max(len(token) for token in tokens)

    params['version'] = 1
    params['vocab_size'] = vocabSize
    params['max_token_length'] = longestToken
    if chatTemplate:
        params['chat_template'] = len(chatTemplate)
    if chatExtraStop:
        params['chat_stop'] = len(chatExtraStop)

    # Encode the header fields in the dict's iteration order.
    encodedPairs = []
    for key, value in params.items():
        if value is None:
            continue
        if key not in keyIds:
            print(f'Unknown header key: {key}')
            continue
        encodedPairs.append(struct.pack('ii', keyIds[key], value))
    data = b''.join(encodedPairs)

    print('⭐ Params:')
    print(params)
    if chatTemplate:
        print('⭐ Chat template:')
        print(chatTemplate)

    # Header size covers the magic number, the size field itself and the pairs.
    headerSize = len(magic) * 2 + len(data)
    file.write(magic + struct.pack('i', headerSize))
    file.write(data)
    if chatTemplate:
        file.write(chatTemplate)
    if chatExtraStop:
        file.write(chatExtraStop)

    for index, token in enumerate(tokens):
        tokenLength = len(token)
        assert tokenLength > 0
        file.write(struct.pack('fI', scores[index], tokenLength))
        file.write(token)
def writeQuantizedQ40Tensor(file, x):
    """Quantize a tensor to Q40 blocks and write them to `file`.

    The values are grouped into blocks of 32. Each block is stored as one
    float16 scale (`delta`) followed by 16 bytes, where byte `j` holds the
    4-bit code of element `j` in its low nibble and the code of element
    `j + 16` in its high nibble.

    Args:
        file: Binary file-like object exposing `write(bytes)`.
        x: torch tensor whose total element count is a multiple of 32.

    Returns:
        Number of bytes written (18 per block).

    Raises:
        AssertionError: If the element count is not a multiple of 32.
    """
    x = x.to(torch.float32).numpy().astype(np.float32)
    blockSize = 32
    blockHalfSize = blockSize // 2
    # Check the total element count, not just the first dimension: the data is
    # regrouped with reshape(-1, blockSize), so only the overall size matters.
    assert(x.size % blockSize == 0)
    groups = x.reshape(-1, blockSize)
    gmax = np.max(groups, axis=1)
    gmin = np.min(groups, axis=1)
    # Scale by the signed extreme with the largest magnitude, mapped onto -8.
    deltas = np.divide(np.where(-gmin > gmax, gmin, gmax), -8)
    deltas16 = deltas.astype(np.float16)
    # All-zero blocks have delta == 0; the reciprocal is masked out by np.where,
    # so the division warning it would trigger is noise.
    with np.errstate(divide='ignore'):
        ids = np.where(deltas != 0, 1.0 / deltas, 0)
    groups = np.add(groups * ids[:, np.newaxis], 8.5)
    groups = np.clip(groups, 0, 15).astype(int)

    gLow = groups[:, :blockHalfSize] & 0xF
    gHigh = (groups[:, blockHalfSize:] & 0xF) << 4
    gCombined = (gLow | gHigh).astype(np.uint8)

    # Lay all blocks out in one packed structured array (2 + 16 bytes per block,
    # native byte order, no padding) and write it at once instead of calling
    # struct.pack for every block.
    blockType = np.dtype([('delta', np.float16), ('quants', np.uint8, (blockHalfSize,))])
    blocks = np.empty(len(groups), dtype=blockType)
    blocks['delta'] = deltas16
    blocks['quants'] = gCombined
    buffer = blocks.tobytes()
    file.write(buffer)
    return len(buffer)
def writeF16Tensor(file, d):
    """Write a 1-D tensor to `file` as consecutive float16 values.

    The tensor is converted and written in fixed-size chunks, as
    writeF32Tensor does, so a large tensor is never expanded into one Python
    object per element.

    Args:
        file: Binary file-like object exposing `write(bytes)`.
        d: 1-D torch tensor.

    Returns:
        Number of bytes written (2 per element).
    """
    chunkSize = 1 << 20
    nBytes = 0
    for i in range(0, len(d), chunkSize):
        chunk = d[i:i+chunkSize].to(torch.float16).numpy().astype(np.float16)
        # tobytes() emits native-order float16, the same bytes struct's 'e' format produces.
        b = chunk.tobytes()
        file.write(b)
        nBytes += len(b)
    return nBytes
'rope_scaling_factor': 14, 126 | 'rope_scaling_low_freq_factor': 15, 127 | 'rope_scaling_high_freq_factory': 16, 128 | 'rope_scaling_orig_max_seq_len': 17, 129 | 'rope_type': 18, 130 | } 131 | header = struct.pack('i', 0xA00ABCD) 132 | 133 | data = b'' 134 | for key in params: 135 | if key in headerKeys: 136 | data += struct.pack('ii', headerKeys[key], params[key]) 137 | else: 138 | print(f'Unknown header key: {key}') 139 | 140 | header += struct.pack('i', len(header) * 2 + len(data)) 141 | file.write(header) 142 | file.write(data) 143 | print(params) 144 | -------------------------------------------------------------------------------- /docs/HUGGINGFACE.md: -------------------------------------------------------------------------------- 1 | # How to Run Hugging Face 🤗 Model 2 | 3 | Currently, Distributed Llama supports three types of Hugging Face models: `llama`, `mistral`, and `mixtral`. You can try to convert any compatible Hugging Face model and run it with Distributed Llama. 4 | 5 | > [!IMPORTANT] 6 | > All converters are in the early stages of development. After conversion, the model may not work correctly. 7 | 8 | 1. Download a model, for example: [Mistral-7B-v0.3](https://huggingface.co/mistralai/Mistral-7B-v0.3/tree/main). 9 | 2. The downloaded model should contain `config.json`, `tokenizer.json`, `tokenizer_config.json` and `tokenizer.model` and safetensor files. 10 | 3. Run the converter of the model: 11 | ```sh 12 | cd converter 13 | python convert-hf.py path/to/hf/model q40 mistral-7b-0.3 14 | ``` 15 | 4. Run the converter of the tokenizer: 16 | ```sh 17 | python convert-tokenizer-hf.py path/to/hf/model mistral-7b-0.3 18 | ``` 19 | 5. That's it! Now you can run the Distributed Llama. 
20 | ``` 21 | ./dllama inference --model dllama_model_mistral-7b-0.3_q40.m --tokenizer dllama_tokenizer_mistral-7b-0.3.t --buffer-float-type q80 --prompt "Hello world" 22 | ``` 23 | -------------------------------------------------------------------------------- /docs/LLAMA.md: -------------------------------------------------------------------------------- 1 | # How to Run Llama 2 | 3 | ## How to Run Llama 2 4 | 5 | 1. Download [Llama 2](https://github.com/facebookresearch/llama) weights from Meta. This project supports 7B, 7B-chat, 13B, 13B-chat, 70B and 70B-chat models. 6 | 2. Open the `llama-2-7b/params.json` file: 7 | * replace `"vocab_size": -1` with `"vocab_size": 32000`, 8 | * add a new property: `"max_seq_len": 2048`. 9 | 3. Install dependencies of the converter: 10 | ```sh 11 | cd converter && pip install -r requirements.txt 12 | ``` 13 | 4. Convert weights to Distributed Llama format. This will take a bit of time. The script requires Python 3. 14 | ```sh 15 | python convert-llama.py /path/to/meta/llama-2-7b q40 16 | ``` 17 | 5. Download the tokenizer for Llama 2: 18 | ``` 19 | wget https://huggingface.co/b4rtaz/Llama-2-Tokenizer-Distributed-Llama/resolve/main/dllama_tokenizer_llama2.t 20 | ``` 21 | 6. Build the project: 22 | ```bash 23 | make dllama 24 | make dllama-api 25 | ``` 26 | 7. Run: 27 | ```bash 28 | ./dllama inference --model dllama_llama-2-7b_q40.bin --tokenizer dllama_tokenizer_llama2.t --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 29 | ``` 30 | 31 | In the table below, you can find the expected size of the converted weights with different floating-point types. 32 | 33 | | Model | Original size | Float32 | Float16 | Q40 | 34 | |-------------|---------------|----------|----------|----------| 35 | | Llama 2 7B | 13.48 GB | 25.10 GB | | 3.95 GB | 36 | | Llama 2 13B | 26.03 GB | | | 7.35 GB | 37 | | Llama 2 70B | 137.97 GB | | | 36.98 GB | 38 | 39 | ## How to Run Llama 3 40 | 41 | 1.
Get an access to the model on [Llama 3 website](https://llama.meta.com/llama-downloads). 42 | 2. Clone the `https://github.com/meta-llama/llama3` repository. 43 | 3. Run the `download.sh` script to download the model. 44 | 4. For Llama 3 8B model you should have the following files: 45 | - `Meta-Llama-3-8B/consolidated.00.pth` 46 | - `Meta-Llama-3-8B/params.json` 47 | - `Meta-Llama-3-8B/tokenizer.model` 48 | 5. Open `params.json` and add a new property: `"max_seq_len": 8192`. 49 | 6. Clone the `https://github.com/b4rtaz/distributed-llama.git` repository. 50 | 7. Install dependencies of the converter: 51 | ```sh 52 | cd converter && pip install -r requirements.txt 53 | ``` 54 | 8. Convert the model to the Distributed Llama format: 55 | ```bash 56 | python converter/convert-llama.py path/to/Meta-Llama-3-8B q40 57 | ``` 58 | 9. Convert the tokenizer to the Distributed Llama format: 59 | ```bash 60 | python converter/convert-tokenizer-llama3.py path/to/tokenizer.model 61 | ``` 62 | 10. Build the project: 63 | ```bash 64 | make dllama 65 | make dllama-api 66 | ``` 67 | 11. Run the Distributed Llama: 68 | ```bash 69 | ./dllama inference --weights-float-type q40 --buffer-float-type q80 --prompt "My name is" --steps 128 --nthreads 8 --model dllama_meta-llama-3-8b_q40.bin --tokenizer llama3-tokenizer.t 70 | ``` 71 | -------------------------------------------------------------------------------- /examples/chat-api-client.js: -------------------------------------------------------------------------------- 1 | // This is a simple client for dllama-api. 2 | // 3 | // Usage: 4 | // 5 | // 1. Start the server, how to do it is described in the `src/apps/dllama-api/README.md` file. 6 | // 2. Run this script: `node examples/chat-api-client.js` 7 | 8 | const HOST = process.env.HOST ? process.env.HOST : '127.0.0.1'; 9 | const PORT = process.env.PORT ? 
/**
 * Sends a chat completion request to the dllama-api server.
 *
 * @param {Array<{role: string, content: string}>} messages Conversation so far.
 * @param {number} maxTokens Value sent as `max_tokens`.
 * @returns {Promise<object>} The parsed JSON body of the response.
 * @throws {Error} When the server answers with a non-2xx status, so the caller
 *   gets the HTTP status instead of a TypeError while reading `choices`.
 */
async function chat(messages, maxTokens) {
    const response = await fetch(`http://${HOST}:${PORT}/v1/chat/completions`, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({
            messages,
            temperature: 0.7,
            stop: ['<|eot_id|>'],
            max_tokens: maxTokens
        }),
    });
    if (!response.ok) {
        throw new Error(`Chat request failed: ${response.status} ${response.statusText}`);
    }
    return await response.json();
}
19 | Say to the king the knowledge of the broil 20 | As thou didst leave it. 25 21 | 22 | Sergeant. Doubtful it stood; 23 | As two spent swimmers, that do cling together 24 | And choke their art. The merciless Macdonwald— 25 | Worthy to be a rebel, for to that 26 | The multiplying villanies of nature 30 27 | Do swarm upon him—from the western isles 28 | Of kerns and gallowglasses is supplied; 29 | And fortune, on his damned quarrel smiling, 30 | Show'd like a rebel's whore: but all's too weak: 31 | For brave Macbeth—well he deserves that name— 35 32 | Disdaining fortune, with his brandish'd steel, 33 | Which smoked with bloody execution, 34 | Like valour's minion carved out his passage 35 | Till he faced the slave; 36 | Which ne'er shook hands, nor bade farewell to him, 40 37 | Till he unseam'd him from the nave to the chaps, 38 | And fix'd his head upon our battlements. 39 | 40 | Duncan. O valiant cousin! worthy gentleman! 41 | 42 | Sergeant. As whence the sun 'gins his reflection 43 | Shipwrecking storms and direful thunders break, 45 44 | So from that spring whence comfort seem'd to come 45 | Discomfort swells. Mark, king of Scotland, mark: 46 | No sooner justice had with valour arm'd 47 | Compell'd these skipping kerns to trust their heels, 48 | But the Norweyan lord surveying vantage, 50 49 | With furbish'd arms and new supplies of men 50 | Began a fresh assault. 51 | 52 | Duncan. Dismay'd not this 53 | Our captains, Macbeth and Banquo? 54 | 55 | Sergeant. Yes; 55 56 | As sparrows eagles, or the hare the lion. 57 | If I say sooth, I must report they were 58 | As cannons overcharged with double cracks, so they 59 | Doubly redoubled strokes upon the foe: 60 | Except they meant to bathe in reeking wounds, 60 61 | Or memorise another Golgotha, 62 | I cannot tell. 63 | But I am faint, my gashes cry for help. 64 | 65 | Duncan. So well thy words become thee as thy wounds; 66 | They smack of honour both. Go get him surgeons. 
65 67 | [Exit Sergeant, attended] 68 | Who comes here?" 69 | 70 | GENERATED="Malcolm. The worthy Thane of Ross. 71 | Duncan. What a haste looks through a duel's wounds! 70 72 | Some must be pac'd. 73 | [Exit Ross] 74 | See this encounter is like to the poring 75 | On of a beggar's story, told by one 76 | That means to pluck upon the heart the strings 77 | And draw the tears thriftily. 75 78 | [Enter Lennox] 79 | How goes the night, boy? 80 | 81 | Lennox. The night is long that none should wake. 82 | 83 | Duncan. You do not need to stare. The Moor 84 | To know the man. 'Tis the Moors devices. 80 85 | [Exit Lennox] 86 | By the happy right of mine own hands, 87 | Strike all that live in this poor thing of mine. 88 | 'Tis calld the Eyrie, and I am sick at heart. 89 | As hellish-devils do the damned souls 90 | O'their bad lives, thus ill-breveted, linger 91 | O'er lamps and forks and other instruments 92 | That prove the stages of the night. 90 93 | Good sir, take note; I bid you farewell: 94 | Come sleep, and cut short this nitty romance. 95 | [He sleeps.] 96 | If cravens, I bear them like the Minion of the moon, 97 | With tiptoe foot he sneaks and starts to be a man. 95 98 | And when he is found asleep, awake him with this armed' s address: 99 | That sleep which th'assassin hallowed, 100 | Scotland, awake; your king is murder'd, sleep no more. 100 101 | *Furbish'd. Weapons polished for battle. 102 | *Thriftily. Fastidiously, thoughtfully. 103 | *Eyrie. Fortress; the lair of birds of prey. 104 | *Minion. A braggart, a coward. 105 | 106 | 1.5 107 | 108 | Macbeth. So foul and fair a day I have not seen. 5 109 | Ross. Good morning, noble Macbeth. I come from Inverness, 110 | And find our throne void, the arm'd rest you; 10 111 | My Lord of Cassil has resigned his life. 112 | Macbeth. Whate'er you owe, in time repay, fair friends. 113 | Note you the words; I pray you do. 114 | Ross. I am your faithful servant, and will keep 115 | My sworn reward upon your life; my lord. 
116 | Macbeth. You shall be well rewarded; stay the press, 20 117 | And I'll not fail. How now, good fellow? 118 | Servant. Sir, his schoolmaster. 25 119 | Macbeth. Well, good, though, old. 120 | Tell me, good fellow, how goes the night? 30 121 | Servant. There's marrygold and fire in your veins, my lord. 122 | Macbeth. He does commend you; the weight of this old night's embargoes 35 123 | Did one hour's waste of time lay upon him. 124 | I know when we are too safe, 'tis dangerous to be secure; 125 | Therefore our fearful parts do brave the danger 40 126 | Which knows it not. I see you are a gentleman. 127 | And a laudable one too; I am most off obliged. 128 | Servant. I should be sorry, my good lord, to have had the labour 45 129 | To outlive this damned hour. 50 130 | Macbeth. What's done cannot be undone. To bed, to bed, to bed. 131 | Servant. Will it please you to lie still? 55 132 | Macbeth. Lord, lord, my heart is in my mouth. All's true that ends well. 133 | Servant. I thank you, fair, and leave you to the content. 60 134 | Macbeth. You see, my lord, it smokes, and shows no cause 135 | Why the drone dies. 65 136 | Servant. Grief fills the room up of one vast stair, 137 | And downs our vaults to the inconstant man above. 70 138 | Macbeth. Go bid thy masters and thy mistress say, 75 139 | I have power in earth to do so much. 140 | There's comfort yet. They are assailable. Then say I, 141 | Thus ye may answer. 142 | Servant. He cannot be wronged; or being wronged, 80 143 | I cannot help him. 85 144 | Macbeth. You know but by this; as this, 90 145 | The Jew foole is hang'd. 95 146 | Servant. No more today, my lord. 100 147 | Macbeth. He does shame to tell him he loves him, but not remove him 105 148 | From his true place; no. 149 | Servant. 
That's true, and now I remember the story 110 150 | Of that sign in Leo four diurnal courses 151 | Returning in a constant motion were within 115 152 | A boare that had on Taurus' back tetracted; 120 153 | Or neuer, or but once in modulated accidence. 125 154 | Macbeth. Thou climd'st alone, ty'd to the stag's horn. 155 | Servant. I was a bull, for this the goodly year. 130 156 | Come, put me in my place. 157 | Macbeth. Now go to sleep. 135 158 | Servant. The west neuer sett before the equinox 140 159 | Till now; and sunnes look'd not theyr frequencie 145 160 | Upon our lappe till now, my lord. 150 161 | Macbeth. This game of chance you term a gong. 162 | Servant. A gong is a scotch word for an egg. 155 163 | Macbeth. Peace, be still. 160 164 | Servant. I coniecture I smell the blood of an Englishman. 165 165 | Macbeth. The faith is murthered. 166 | Servant. That murder'd in his sleep. 170 167 | Macbeth. And sleeping murdered. 175 168 | Servant. In the fair queen heere in his royal court. 180 169 | Macbeth. So great a mercy that it may last eternally. 170 | Servant. The earth hath bubbles as the water hath, 185 171 | And these are of them. Whate'er we will do 190 172 | To mend the trespasses of the comming time 195 173 | Shall be the seedes of new mischefe, and shall beget 200 174 | The formes of the extinctnese, which we are now. 205 175 | Macbeth. We have scorch'd the snake, not kill'd it. 210 176 | Servant. They hunt it in the morn. Good gally, good lord! 215 177 | It weares a gilded snout. 220 178 | Macbeth. It is the very painting of your fear. 225 179 | Servant. This is the worst. 230 180 | Macbeth. A fair quater of a mile is yet to go. 235 181 | Servant. A mile and half. 240 182 | Macbeth. I have run fifteen miles to-day. 183 | Servant. A calender's date. 184 | Macbeth. A bigger patch, a bigger patch. 245 185 | Servant. Thirteen of more. 250 186 | Macbeth. Wast thou with him? 255 187 | Servant. No, nor he to night. 260 188 | Macbeth. 
echo "Generating, it can take a while..."

# Capture stdout and stderr of the run. The space after "$(" matters: "$((" would
# otherwise be read as the start of an arithmetic expansion by some shells.
OUTPUT=$( ( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 2 --steps 2048 --model models/llama3_8b_q40/dllama_model_llama3_8b_q40.m --tokenizer models/llama3_8b_q40/dllama_tokenizer_llama3_8b_q40.t --workers 127.0.0.1:9999 127.0.0.1:9998 127.0.0.1:9997 ) 2>&1 )

echo "$OUTPUT"

# The test passes when the expected continuation appears anywhere in the output.
if [[ $OUTPUT == *"$GENERATED"* ]]; then
    echo "✅ Output is same"
else
    echo "❌ Output is different"
fi
def parts(length):
    """Return the first `length` two-letter suffixes: 'aa', 'ab', ..., 'az', 'ba', ...

    The first letter advances every 26 entries, the second letter cycles
    through the alphabet.
    """
    base = ord('a')
    return [chr(base + index // 26) + chr(base + index % 26) for index in range(length)]
['https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true'], 32 | 'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true', 33 | 'q40', 'q80', 'chat' 34 | ], 35 | 'llama3_1_405b_instruct_q40': [ 36 | list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama31_405b_q40_{suffix}?download=true', parts(56))), 37 | 'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true', 38 | 'q40', 'q80', 'chat' 39 | ], 40 | 'llama3_2_1b_instruct_q40': [ 41 | ['https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-1b-instruct_q40.m?download=true'], 42 | 'https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true', 43 | 'q40', 'q80', 'chat', '--max-seq-len 8192' 44 | ], 45 | 'llama3_2_3b_instruct_q40': [ 46 | ['https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-3b-instruct_q40.m?download=true'], 47 | 'https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true', 48 | 'q40', 'q80', 'chat', '--max-seq-len 8192' 49 | ], 50 | } 51 | 52 | def downloadFile(urls: str, path: str): 53 | if (os.path.isfile(path)): 54 | fileName = os.path.basename(path) 55 | result = input(f'❓ {fileName} already exists, do you want to download again? 
if __name__ == '__main__':
    # Entry point: download the requested model, print and save the command that
    # runs it, and optionally execute it.
    if (len(sys.argv) < 2):
        printUsage()
        sys.exit(1)

    # Paths below ('models/...', './dllama') are relative to the script's directory.
    # abspath() keeps this working when __file__ is a bare relative name, for which
    # dirname() alone would return '' and chdir('') would fail.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    modelName = sys.argv[1].replace('-', '_')
    if modelName not in MODELS:
        print(f'Model is not supported: {modelName}')
        sys.exit(1)
    runAfterDownload = sys.argv.count('--run') > 0

    model = MODELS[modelName]
    (modelPath, tokenizerPath) = download(modelName, model)
    if (model[4] == 'chat'):
        command = './dllama chat'
    else:
        command = './dllama inference --steps 64 --prompt "Hello world"'
    command += f' --model {modelPath} --tokenizer {tokenizerPath} --buffer-float-type {model[3]} --nthreads 4'
    # An optional sixth entry carries extra command-line arguments for the model.
    if (len(model) > 5):
        command += f' {model[5]}'

    print('To run Distributed Llama you need to execute:')
    print('--- copy start ---')
    print()
    print('\033[96m' + command + '\033[0m')
    print()
    print('--- copy end -----')

    runFilePath = writeRunFile(modelName, command)
    print(f'🌻 Created {runFilePath} script to easy run')

    if (not runAfterDownload):
        # Compare case-insensitively against 'Y'. The previous `.lower() == 'Y'`
        # could never be true, so answering the prompt never started the model.
        runAfterDownload = input('❓ Do you want to run Distributed Llama? ("Y" if yes): ').upper() == 'Y'
    if (runAfterDownload):
        if (not os.path.isfile('dllama')):
            os.system('make dllama')
        os.system(command)
parseChatTemplateType(char* val) { 20 | if (strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2; 21 | if (strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3; 22 | if (strcmp(val, "zephyr") == 0) return TEMPLATE_ZEPHYR; 23 | if (strcmp(val, "chatml") == 0) return TEMPLATE_CHATML; 24 | throw std::runtime_error("Invalid chat template type"); 25 | 26 | } 27 | 28 | AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { 29 | AppArgs args; 30 | args.mode = NULL; 31 | args.nThreads = 4; 32 | args.modelPath = NULL; 33 | args.tokenizerPath = NULL; 34 | args.prompt = NULL; 35 | args.weightsFloatType = FUNK; 36 | args.bufferFloatType = F32; 37 | args.nWorkers = 0; 38 | args.port = 9990; 39 | args.temperature = 0.8f; 40 | args.topp = 0.9f; 41 | args.steps = 0; 42 | args.seed = (unsigned long long)time(NULL); 43 | args.chatTemplateType = TEMPLATE_UNKNOWN; 44 | args.maxSeqLen = 0; 45 | args.useDiscForKvCache = false; 46 | args.numa = -1; 47 | 48 | int i = 1; 49 | if (hasMode && argc > 1) { 50 | args.mode = argv[1]; 51 | i++; 52 | } 53 | for (; i + 1 < argc; i += 2) { 54 | char* name = argv[i]; 55 | char* value = argv[i + 1]; 56 | if (strcmp(name, "--model") == 0) { 57 | args.modelPath = value; 58 | } else if (strcmp(name, "--tokenizer") == 0) { 59 | args.tokenizerPath = value; 60 | } else if (strcmp(name, "--prompt") == 0) { 61 | args.prompt = value; 62 | } else if (strcmp(name, "--weights-float-type") == 0) { 63 | args.weightsFloatType = parseFloatType(value); 64 | } else if (strcmp(name, "--buffer-float-type") == 0) { 65 | args.bufferFloatType = parseFloatType(value); 66 | } else if (strcmp(name, "--workers") == 0) { 67 | int j = i + 1; 68 | for (; j < argc && argv[j][0] != '-'; j++); 69 | int count = j - i - 1; 70 | 71 | args.nWorkers = count; 72 | args.workerHosts = new char*[count]; 73 | args.workerPorts = new int[count]; 74 | 75 | for (int s = 0; s < count; s++) { 76 | char* v = argv[i + 1 + s]; 77 | char* sep = strstr(v, ":"); 78 | if (sep == NULL) { 79 | 
printf("Invalid address %s\n", v); 80 | exit(EXIT_FAILURE); 81 | } 82 | int hostLen = sep - v; 83 | args.workerHosts[s] = new char[hostLen + 1]; 84 | memcpy(args.workerHosts[s], v, hostLen); 85 | args.workerHosts[s][hostLen] = '\0'; 86 | args.workerPorts[s] = atoi(sep + 1); 87 | } 88 | 89 | i += count - 1; 90 | } else if (strcmp(name, "--port") == 0) { 91 | args.port = atoi(value); 92 | } else if (strcmp(name, "--nthreads") == 0) { 93 | args.nThreads = atoi(value); 94 | } else if (strcmp(name, "--steps") == 0) { 95 | args.steps = atoi(value); 96 | } else if (strcmp(name, "--temperature") == 0) { 97 | args.temperature = atof(value); 98 | } else if (strcmp(name, "--topp") == 0) { 99 | args.topp = atof(value); 100 | } else if (strcmp(name, "--seed") == 0) { 101 | args.seed = atoll(value); 102 | } else if (strcmp(name, "--chat-template") == 0) { 103 | args.chatTemplateType = parseChatTemplateType(value); 104 | } else if (strcmp(name, "--max-seq-len") == 0) { 105 | args.maxSeqLen = (unsigned int)atoi(value); 106 | } else if (strcmp(name, "--kv-cache-storage") == 0) { 107 | args.useDiscForKvCache = strcmp(value, "disc") == 0; 108 | } else if (strcmp(name, "--numa") == 0) { 109 | args.numa = atoi(value); 110 | } else { 111 | printf("Unknown option %s\n", name); 112 | exit(EXIT_FAILURE); 113 | } 114 | } 115 | return args; 116 | } 117 | 118 | TransformerArch TransformerArchFactory::create(TransformerSpec* spec) { 119 | if (spec->archType == LLAMA) return buildLlamaArch(spec); 120 | if (spec->archType == GROK1) return buildGrok1Arch(spec); 121 | if (spec->archType == MIXTRAL) return buildMixtralArch(spec); 122 | printf("Unsupported arch type: %d\n", spec->archType); 123 | exit(EXIT_FAILURE); 124 | } 125 | 126 | void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec)) { 127 | if (args->modelPath == NULL) { 128 | throw std::runtime_error("Model is required"); 129 | 
} 130 | if (args->tokenizerPath == NULL) { 131 | throw std::runtime_error("Tokenizer is required"); 132 | } 133 | 134 | SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts); 135 | unsigned int nSlices = args->nWorkers + 1; 136 | 137 | TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->maxSeqLen, args->weightsFloatType, args->bufferFloatType); 138 | TransformerArch arch = TransformerArchFactory::create(&spec); 139 | Tokenizer tokenizer(args->tokenizerPath, spec.vocabSize); 140 | 141 | if (args->steps == 0 || args->steps > spec.seqLen) { 142 | args->steps = spec.seqLen; 143 | } 144 | 145 | TransformerConfig config; 146 | config.useDiscForKvCache = args->useDiscForKvCache; 147 | 148 | Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, &config, socketPool); 149 | socketPool->setTurbo(true); 150 | 151 | Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool); 152 | 153 | Sampler sampler(spec.vocabSize, args->temperature, args->topp, args->seed); 154 | 155 | program(&inference, socketPool, &tokenizer, &sampler, args, &spec); 156 | 157 | delete socketPool; 158 | } 159 | -------------------------------------------------------------------------------- /src/app.hpp: -------------------------------------------------------------------------------- 1 | #ifndef APP_HPP 2 | #define APP_HPP 3 | 4 | #include "quants.hpp" 5 | #include "transformer.hpp" 6 | #include "utils.hpp" 7 | #include "utils.hpp" 8 | #include "app.hpp" 9 | #include "transformer.hpp" 10 | #include "tasks.hpp" 11 | #include "llama2-tasks.hpp" 12 | #include "grok1-tasks.hpp" 13 | #include "mixtral-tasks.hpp" 14 | #include "tokenizer.hpp" 15 | 16 | class AppArgs { 17 | public: 18 | char* mode; 19 | int nThreads; 20 | bool useDiscForKvCache; 21 | int numa; 22 | 23 | // inference 24 | char* modelPath; 25 | char* tokenizerPath; 26 | char* prompt; 27 | FloatType weightsFloatType; 28 | 
FloatType bufferFloatType; 29 | int nWorkers; 30 | char** workerHosts; 31 | int* workerPorts; 32 | float temperature; 33 | float topp; 34 | pos_t steps; 35 | bool benchmark; 36 | unsigned long long seed; 37 | ChatTemplateType chatTemplateType; 38 | unsigned int maxSeqLen; 39 | 40 | // worker 41 | int port; 42 | 43 | static AppArgs parse(int argc, char** argv, bool hasMode); 44 | }; 45 | 46 | class TransformerArchFactory { 47 | public: 48 | static TransformerArch create(TransformerSpec* spec); 49 | }; 50 | 51 | class App { 52 | public: 53 | static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec)); 54 | }; 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /src/apps/dllama-api/README.md: -------------------------------------------------------------------------------- 1 | # Distributed Llama API 2 | 3 | This is an early version of the server that is compatible with the OpenAi API. It supports only the `/v1/chat/completions` endpoint. To run this server you need a chat model and a tokenizer with the chat support. 4 | 5 | How to run? 6 | 7 | 1. Download the model and the tokenizer from [here](https://huggingface.co/b4rtaz/Llama-3-8B-Q40-Instruct-Distributed-Llama). 8 | 2. Run the server with the following command: 9 | ```bash 10 | ./dllama-api --model converter/dllama_model_lama3_instruct_q40.m --tokenizer converter/dllama_tokenizer_llama3.t --weights-float-type q40 --buffer-float-type q80 --nthreads 4 11 | ``` 12 | 13 | Check the [chat-api-client.js](../../../examples/chat-api-client.js) file to see how to use the API from NodeJS application. 
14 | -------------------------------------------------------------------------------- /src/apps/dllama-api/types.hpp: -------------------------------------------------------------------------------- 1 | #ifndef DLLAMA_API_TYPES_HPP 2 | #define DLLAMA_API_TYPES_HPP 3 | 4 | #include 5 | 6 | #include "../../common/json.hpp" 7 | 8 | using json = nlohmann::json; 9 | 10 | struct ChatMessageDelta { 11 | std::string role; 12 | std::string content; 13 | 14 | ChatMessageDelta() : role(""), content("") {} 15 | ChatMessageDelta(const std::string& role_, const std::string& content_) : role(role_), content(content_) {} 16 | }; 17 | 18 | struct ChatMessage { 19 | std::string role; 20 | std::string content; 21 | 22 | ChatMessage() : role(""), content("") {} 23 | ChatMessage(const std::string& role_, const std::string& content_) : role(role_), content(content_) {} 24 | }; 25 | 26 | struct ChunkChoice { 27 | int index; 28 | ChatMessageDelta delta; 29 | std::string finish_reason; 30 | 31 | ChunkChoice() : index(0) {} 32 | }; 33 | 34 | 35 | struct Choice { 36 | int index; 37 | ChatMessage message; 38 | std::string finish_reason; 39 | 40 | Choice() : finish_reason("") {} 41 | Choice(ChatMessage &message_) : message(message_), finish_reason("") {} 42 | Choice(const std::string &reason_) : finish_reason(reason_) {} 43 | }; 44 | 45 | struct ChatCompletionChunk { 46 | std::string id; 47 | std::string object; 48 | long long created; 49 | std::string model; 50 | std::vector choices; 51 | 52 | ChatCompletionChunk(ChunkChoice &choice_) 53 | : id("cmpl-c0"), object("chat.completion"), model("Distributed Model") { 54 | created = std::time(nullptr); // Set created to current Unix timestamp 55 | choices.push_back(choice_); 56 | } 57 | }; 58 | 59 | // Struct to represent the usage object 60 | struct ChatUsage { 61 | int prompt_tokens; 62 | int completion_tokens; 63 | int total_tokens; 64 | 65 | ChatUsage() : prompt_tokens(0), completion_tokens(0), total_tokens(0) {} 66 | ChatUsage(int pt, int ct, 
int tt) : prompt_tokens(pt), completion_tokens(ct), total_tokens(tt) {} 67 | }; 68 | 69 | struct ChatCompletion { 70 | std::string id; 71 | std::string object; 72 | long long created; // Unix timestamp 73 | std::string model; 74 | std::vector choices; 75 | ChatUsage usage; 76 | 77 | ChatCompletion() : id(), object(), model() {} 78 | ChatCompletion(const Choice &choice_, const ChatUsage& usage_) 79 | : id("cmpl-j0"), object("chat.completion"), model("Distributed Model"), usage(usage_) { 80 | created = std::time(nullptr); // Set created to current Unix timestamp 81 | choices.push_back(choice_); 82 | } 83 | }; 84 | 85 | struct InferenceParams { 86 | std::vector messages; 87 | int max_tokens; 88 | float temperature; 89 | float top_p; 90 | std::vector stop; 91 | bool stream; 92 | unsigned long long seed; 93 | }; 94 | 95 | // Define to_json for Delta struct 96 | void to_json(json& j, const ChatMessageDelta& msg) { 97 | j = json{{"role", msg.role}, {"content", msg.content}}; 98 | } 99 | 100 | void to_json(json& j, const ChatMessage& msg) { 101 | j = json{{"role", msg.role}, {"content", msg.content}}; 102 | } 103 | 104 | void to_json(json& j, const ChunkChoice& choice) { 105 | j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}}; 106 | } 107 | 108 | void to_json(json& j, const Choice& choice) { 109 | j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}}; 110 | } 111 | 112 | void to_json(json& j, const ChatCompletionChunk& completion) { 113 | j = json{{"id", completion.id}, 114 | {"object", completion.object}, 115 | {"created", completion.created}, 116 | {"model", completion.model}, 117 | {"choices", completion.choices}}; 118 | } 119 | 120 | void to_json(json& j, const ChatUsage& usage) { 121 | j = json{{"completion_tokens", usage.completion_tokens}, 122 | {"prompt_tokens", usage.prompt_tokens}, 123 | {"total_tokens", usage.total_tokens}}; 124 | } 125 | 126 | void to_json(json& 
j, const ChatCompletion& completion) { 127 | j = json{{"id", completion.id}, 128 | {"object", completion.object}, 129 | {"created", completion.created}, 130 | {"model", completion.model}, 131 | {"usage", completion.usage}, 132 | {"choices", completion.choices}}; 133 | } 134 | 135 | std::vector parseChatMessages(json &json){ 136 | std::vector messages; 137 | messages.reserve(json.size()); 138 | 139 | for (const auto& item : json) { 140 | messages.emplace_back( 141 | item["role"].template get(), 142 | item["content"].template get() 143 | ); 144 | } 145 | return messages; 146 | } 147 | 148 | #endif -------------------------------------------------------------------------------- /src/apps/dllama/dllama.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #if DLLAMA_USE_NUMA 11 | #include 12 | #endif 13 | 14 | #include "../../utils.hpp" 15 | #include "../../socket.hpp" 16 | #include "../../transformer.hpp" 17 | #include "../../tasks.hpp" 18 | #include "../../tokenizer.hpp" 19 | #include "../../app.hpp" 20 | 21 | void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) { 22 | if (args->prompt == NULL) 23 | throw std::runtime_error("Prompt is required"); 24 | 25 | // encode the (string) prompt into tokens sequence 26 | int numPromptTokens = 0; 27 | int* promptTokens = new int[strlen(args->prompt) + 3]; // +3 for '\0', ?BOS, ?EOS 28 | 29 | // TODO: this is a hack for Grok1. 
We should have a more general way to handle this 30 | bool addBos = spec->archType != GROK1; 31 | 32 | tokenizer->encode(args->prompt, promptTokens, &numPromptTokens, addBos, false); 33 | if (numPromptTokens < 1) 34 | throw std::runtime_error("Expected at least 1 prompt token"); 35 | 36 | // start the main loop 37 | long start = 0; // used to time our code, only initialized after first iteration 38 | int next; // will store the next token in the sequence 39 | int token = promptTokens[0]; // kick off with the first token in the prompt 40 | pos_t pos = 0; // position in the sequence 41 | 42 | unsigned long inferenceTime; 43 | unsigned long transferTime; 44 | size_t sentBytes; 45 | size_t recvBytes; 46 | unsigned long totalGenerationTime = 0; 47 | unsigned long totalInferenceTime = 0; 48 | unsigned long totalTransferTime = 0; 49 | while (pos < args->steps) { 50 | unsigned long startTime = timeMs(); 51 | float* logits = inference->infer(token, pos); 52 | 53 | inference->getStats(&inferenceTime, &transferTime); 54 | socketPool->getStats(&sentBytes, &recvBytes); 55 | 56 | // advance the state machine 57 | if (pos < numPromptTokens - 1) { 58 | // if we are still processing the input prompt, force the next prompt token 59 | next = promptTokens[pos + 1]; 60 | } else { 61 | // otherwise sample the next token from the logits 62 | next = sampler->sample(logits); 63 | } 64 | pos++; 65 | 66 | unsigned long generationTime = timeMs() - startTime; 67 | 68 | totalGenerationTime += generationTime; 69 | totalInferenceTime += inferenceTime; 70 | totalTransferTime += transferTime; 71 | 72 | // data-dependent terminating condition: the BOS token delimits sequences 73 | if (next == tokenizer->bosId) { 74 | break; 75 | } 76 | 77 | // print the token as string, decode it with the Tokenizer object 78 | char* piece = tokenizer->decode(token, next); 79 | 80 | if (args->benchmark) 81 | printf("🔶 G %4ld ms I %4ld ms T %4ld ms S %6ld kB R %6ld kB ", generationTime, inferenceTime, transferTime, 
sentBytes / 1024, recvBytes / 1024); 82 | safePrintf(piece); 83 | if (args->benchmark) 84 | printf("\n"); 85 | fflush(stdout); 86 | token = next; 87 | } 88 | 89 | delete[] promptTokens; 90 | 91 | if (!args->benchmark) printf("\n"); 92 | double avgGenerationTime = totalGenerationTime / (double)pos; 93 | printf("Generated tokens: %d\n", pos); 94 | printf("Avg tokens / second: %.2f\n", 1000.0 / avgGenerationTime); 95 | printf("Avg generation time: %.2f ms\n", avgGenerationTime); 96 | printf("Avg inference time: %.2f ms\n", totalInferenceTime / (double)pos); 97 | printf("Avg transfer time: %.2f ms\n", totalTransferTime / (double)pos); 98 | } 99 | 100 | size_t readStdin(const char* guide, char* buffer, size_t bufsize) { 101 | fflush(stdin); 102 | // read a line from stdin, up to but not including \n 103 | printf("%s", guide); 104 | if (fgets(buffer, bufsize, stdin) != NULL) { 105 | size_t len = strlen(buffer); 106 | if (len > 0 && buffer[len - 1] == '\n') { 107 | buffer[len - 1] = '\0'; // strip newline 108 | len--; 109 | } 110 | return len; 111 | } 112 | return 0; 113 | } 114 | 115 | class Chat { 116 | private: 117 | Inference* inference; 118 | Tokenizer* tokenizer; 119 | Sampler* sampler; 120 | AppArgs* args; 121 | TransformerSpec* spec; 122 | ChatTemplate* chatTemplate; 123 | EosDetector* eosDetector; 124 | 125 | public: 126 | Chat(Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, EosDetector* eosDetector, ChatTemplate* chatTemplate) { 127 | this->inference = inference; 128 | this->tokenizer = tokenizer; 129 | this->sampler = sampler; 130 | this->args = args; 131 | this->spec = spec; 132 | this->eosDetector = eosDetector; 133 | this->chatTemplate = chatTemplate; 134 | } 135 | 136 | void chat() { 137 | char inputBuffer[2048]; 138 | 139 | size_t sysPromptLength = readStdin("💻 System prompt (optional): ", inputBuffer, sizeof(inputBuffer)); 140 | std::vector deltaItems; 141 | if (sysPromptLength > 0) { 142 | 
deltaItems.push_back(ChatItem{"system", inputBuffer}); 143 | } 144 | 145 | pos_t pos = 0; 146 | int token; 147 | do { 148 | size_t userPromptLength; 149 | do { 150 | userPromptLength = readStdin("\n👱 User\n> ", inputBuffer, sizeof(inputBuffer)); 151 | } while (userPromptLength == 0); 152 | 153 | deltaItems.push_back(ChatItem{"user", inputBuffer}); 154 | 155 | size_t nChatItems = deltaItems.size(); 156 | ChatItem chatItems[nChatItems]; 157 | for (size_t j = 0; j < nChatItems; j++) { 158 | chatItems[j].role = deltaItems[j].role; 159 | chatItems[j].message = deltaItems[j].message; 160 | } 161 | std::string inputPrompt = chatTemplate->generate(deltaItems.size(), chatItems, true); 162 | 163 | int* inputTokens = new int[inputPrompt.size() + 3]; 164 | int nInputTokens; 165 | tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false); 166 | 167 | pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1); 168 | for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) { 169 | inference->infer(inputTokens[i], pos); 170 | token = inputTokens[i + 1]; 171 | } 172 | 173 | printf("\n🤖 Assistant\n"); 174 | 175 | for (; pos < spec->seqLen; pos++) { 176 | int prevToken = token; 177 | float* logits = inference->infer(token, pos); 178 | token = sampler->sample(logits); 179 | char* piece = tokenizer->decode(prevToken, token); 180 | bool isSafe = isSafePiece(piece); 181 | EosDetectorType eosType = eosDetector->append(token, isSafe ? 
piece : ""); 182 | if (eosType == NOT_EOS || eosType == EOS) { 183 | char* delta = eosDetector->getDelta(); 184 | if (delta != NULL) { 185 | printf("%s", delta); 186 | fflush(stdout); 187 | } 188 | eosDetector->clear(); 189 | } 190 | if (eosType == EOS) break; 191 | } 192 | 193 | inputPrompt.clear(); 194 | } while (pos < spec->seqLen); 195 | 196 | printf("(end of context)\n"); 197 | } 198 | }; 199 | 200 | void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) { 201 | TokenizerChatStops stops(tokenizer); 202 | ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]); 203 | EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength); 204 | 205 | Chat chat(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate); 206 | chat.chat(); 207 | } 208 | 209 | void worker(AppArgs* args) { 210 | if (args->port < 1024) { 211 | throw std::runtime_error("Invalid port number"); 212 | } 213 | 214 | TransformerConfig config; 215 | config.useDiscForKvCache = args->useDiscForKvCache; 216 | 217 | SocketServer server(args->port); 218 | Socket socket = server.accept(); 219 | TransformerSpec spec; 220 | Transformer transformer = Transformer::loadSlice(&spec, &config, &socket); 221 | TransformerArch arch = TransformerArchFactory::create(&spec); 222 | 223 | Worker worker = Worker(&arch, args->nThreads, &transformer, &socket); 224 | worker.work(); 225 | } 226 | 227 | int main(int argc, char *argv[]) { 228 | initQuants(); 229 | initSockets(); 230 | 231 | AppArgs args = AppArgs::parse(argc, argv, true); 232 | bool success = false; 233 | 234 | if (args.numa >= 0) { 235 | #if DLLAMA_USE_NUMA 236 | numa_run_on_node(args.numa); 237 | #else 238 | fprintf(stderr, "Application was compiled without DLLAMA_USE_NUMA option, ignoring NUMA settings!\n"); 239 | #endif 240 | } 241 | 242 | if (args.mode != NULL) { 243 | if 
(strcmp(args.mode, "inference") == 0) { 244 | args.benchmark = true; 245 | App::run(&args, generate); 246 | success = true; 247 | } else if (strcmp(args.mode, "generate") == 0) { 248 | args.benchmark = false; 249 | App::run(&args, generate); 250 | success = true; 251 | } else if (strcmp(args.mode, "chat") == 0) { 252 | App::run(&args, chat); 253 | success = true; 254 | } else if (strcmp(args.mode, "worker") == 0) { 255 | worker(&args); 256 | success = true; 257 | } 258 | } 259 | 260 | cleanupSockets(); 261 | 262 | if (success) 263 | return EXIT_SUCCESS; 264 | fprintf(stderr, "Invalid usage\n"); 265 | return EXIT_FAILURE; 266 | } 267 | -------------------------------------------------------------------------------- /src/apps/socket-benchmark/socket-benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include "../../socket.hpp" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace std::chrono; 10 | 11 | unsigned int packageSizes[] = { 128, 256, 512, 768, 1024, 1280, 1518, 2048, 4096, 8192, 16384, 32768, 65536 }; 12 | unsigned int packageSizesCount = sizeof(packageSizes) / sizeof(unsigned int); 13 | unsigned int maPackageSize = packageSizes[packageSizesCount - 1]; 14 | unsigned int nAttempts = 5000; 15 | int port = 7721; 16 | bool testTcp = true; 17 | 18 | void setNonBlocking(int socket) { 19 | //int flags = fcntl(socket, F_GETFL, 0); 20 | //if (fcntl(socket, F_SETFL, flags |= O_NONBLOCK) < 0) 21 | // throw std::runtime_error("Cannot set socket flags"); 22 | } 23 | 24 | #define MAX_PACKAGE_SIZE 1280 25 | 26 | char pktinfo[4096] = {0}; 27 | 28 | void readUdpSocket(int socket, char* buffer, unsigned int size, struct sockaddr_in* clientAddr, socklen_t* clientAddrLen) { 29 | struct msghdr msg; 30 | struct iovec iov; 31 | int received_ttl = 0; 32 | char buf[CMSG_SPACE(sizeof(received_ttl))]; 33 | iov.iov_base = buffer; 34 | iov.iov_len = size; 35 | msg.msg_name = clientAddr; 36 
| msg.msg_namelen = *clientAddrLen; 37 | msg.msg_iov = &iov; 38 | msg.msg_iovlen = 1; 39 | msg.msg_control = 0; 40 | msg.msg_controllen = 0; 41 | for (;;) { 42 | ssize_t s0 = recvmsg(socket, &msg, MSG_DONTWAIT); 43 | if (s0 == size) { 44 | //printf("read\n"); 45 | return; 46 | } 47 | if (s0 <= 0) { 48 | if (errno == EAGAIN || errno == EWOULDBLOCK) 49 | continue; 50 | printf("error read: %s\n", strerror(errno)); 51 | throw std::runtime_error("Cannot read from socket"); 52 | } 53 | }; 54 | } 55 | 56 | void writeUdpSocket(int socket, char* buffer, unsigned int size, struct sockaddr_in* clientAddr, socklen_t clientAddrLen) { 57 | struct msghdr msg; 58 | struct iovec iov; 59 | int received_ttl = 0; 60 | char buf[CMSG_SPACE(sizeof(received_ttl))]; 61 | iov.iov_base = buffer; 62 | iov.iov_len = size; 63 | msg.msg_name = clientAddr; 64 | msg.msg_namelen = clientAddrLen; 65 | msg.msg_iov = &iov; 66 | msg.msg_iovlen = 1; 67 | msg.msg_control = 0; 68 | msg.msg_controllen = 0; 69 | for (;;) { 70 | ssize_t s0 = sendmsg(socket, &msg, 0); 71 | if (s0 == size) { 72 | //printf("sent\n"); 73 | return; 74 | } 75 | if (s0 <= 0) { 76 | if (errno == EAGAIN || errno == EWOULDBLOCK) 77 | continue; 78 | printf("error write: %s\n", strerror(errno)); 79 | throw std::runtime_error("Cannot write to socket"); 80 | } 81 | } 82 | } 83 | 84 | void server() { 85 | printf("nAttempts: %d\n", nAttempts); 86 | char buffer[maPackageSize]; 87 | 88 | if (testTcp) { 89 | printf("TCP test\n"); 90 | 91 | SocketServer server(port); 92 | Socket socket = server.accept(); 93 | for (long i = 0; i < packageSizesCount; i++) { 94 | unsigned int currentPackageSize = packageSizes[i]; 95 | 96 | long long totalReadTime = 0; 97 | long long totalWriteTime = 0; 98 | long long totalTime = 0; // [us] 99 | for (long a = 0; a < nAttempts; a++) { 100 | auto t0 = high_resolution_clock::now(); 101 | socket.read(buffer, currentPackageSize); 102 | auto t1 = high_resolution_clock::now(); 103 | socket.write(buffer, 
currentPackageSize); 104 | auto t2 = high_resolution_clock::now(); 105 | 106 | totalReadTime += duration_cast(t1 - t0).count(); 107 | totalWriteTime += duration_cast(t2 - t1).count(); 108 | totalTime += duration_cast(t2 - t0).count(); 109 | } 110 | 111 | double nPingPongs = (1.0 / (totalTime / 1000000.0)) * (double)nAttempts; 112 | printf("[%6d bytes] write: %5lld us, read: %5lld us, total: %5lld us, nPingPongs: %.2f\n", 113 | currentPackageSize, totalWriteTime, totalReadTime, totalTime, nPingPongs); 114 | } 115 | } 116 | 117 | printf("UDP test\n"); 118 | 119 | { 120 | int serverSocket = ::socket(AF_INET, SOCK_DGRAM, 0); 121 | 122 | struct sockaddr_in serverAddr; 123 | struct sockaddr_in clientAddr; 124 | socklen_t clientAddrLen = sizeof(clientAddr); 125 | memset(&serverAddr, 0, sizeof(serverAddr)); 126 | memset(&clientAddr, 0, sizeof(clientAddr)); 127 | serverAddr.sin_family = AF_INET; 128 | serverAddr.sin_addr.s_addr = INADDR_ANY; 129 | serverAddr.sin_port = htons(port); 130 | setNonBlocking(serverSocket); 131 | 132 | if (bind(serverSocket, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) < 0) 133 | throw std::runtime_error("Cannot bind socket"); 134 | 135 | for (long i = 0; i < packageSizesCount; i++) { 136 | unsigned int currentPackageSize = packageSizes[i]; 137 | 138 | long long totalReadTime = 0; 139 | long long totalWriteTime = 0; 140 | long long totalTime = 0; // [us] 141 | 142 | //setsockopt(serverSocket, SOL_SOCKET, SO_RCVBUF, ¤tPackageSize, sizeof(currentPackageSize)); 143 | //setsockopt(serverSocket, SOL_SOCKET, SO_SNDBUF, ¤tPackageSize, sizeof(currentPackageSize)); 144 | 145 | for (long a = 0; a < nAttempts; a++) { 146 | auto t0 = high_resolution_clock::now(); 147 | 148 | readUdpSocket(serverSocket, buffer, currentPackageSize, &clientAddr, &clientAddrLen); 149 | 150 | auto t1 = high_resolution_clock::now(); 151 | 152 | writeUdpSocket(serverSocket, buffer, currentPackageSize, &clientAddr, clientAddrLen); 153 | 154 | auto t2 = 
high_resolution_clock::now(); 155 | 156 | totalReadTime += duration_cast(t1 - t0).count(); 157 | totalWriteTime += duration_cast(t2 - t1).count(); 158 | totalTime += duration_cast(t2 - t0).count(); 159 | } 160 | 161 | double nPingPongs = (1.0 / (totalTime / 1000000.0)) * (double)nAttempts; 162 | printf("[%6d bytes] write: %5lld us, read: %5lld us, total: %5lld us, nPingPongs: %.2f\n", 163 | currentPackageSize, totalWriteTime, totalReadTime, totalTime, nPingPongs); 164 | } 165 | } 166 | } 167 | 168 | void client(char* host) { 169 | char buffer[maPackageSize]; 170 | 171 | if (testTcp) { 172 | printf("TCP test\n"); 173 | 174 | char** hosts = new char*[1]; 175 | hosts[0] = host; 176 | int* ports = new int[1]; 177 | ports[0] = port; 178 | 179 | SocketPool* pool = SocketPool::connect(1, hosts, ports); 180 | pool->setTurbo(true); 181 | 182 | for (long i = 0; i < packageSizesCount; i++) { 183 | unsigned int currentPackageSize = packageSizes[i]; 184 | 185 | long long totalReadTime = 0; 186 | long long totalWriteTime = 0; 187 | long long totalTime = 0; // [us] 188 | for (long a = 0; a < nAttempts; a++) { 189 | auto t0 = high_resolution_clock::now(); 190 | pool->write(0, buffer, currentPackageSize); 191 | auto t1 = high_resolution_clock::now(); 192 | pool->read(0, buffer, currentPackageSize); 193 | auto t2 = high_resolution_clock::now(); 194 | 195 | totalWriteTime += duration_cast(t1 - t0).count(); 196 | totalReadTime += duration_cast(t2 - t1).count(); 197 | totalTime += duration_cast(t2 - t0).count(); 198 | } 199 | 200 | printf("[%6d bytes] write: %5lld us, read: %5lld us, total: %5lld us\n", 201 | currentPackageSize, totalWriteTime, totalReadTime, totalTime); 202 | } 203 | 204 | delete pool; 205 | delete[] hosts; 206 | delete[] ports; 207 | } 208 | 209 | printf("UDP test\n"); 210 | 211 | { 212 | int clientSocket = ::socket(AF_INET, SOCK_DGRAM, 0); 213 | struct sockaddr_in serverAddr; 214 | socklen_t serverAddrLen = sizeof(serverAddr); 215 | memset(&serverAddr, 0, 
sizeof(serverAddr)); 216 | serverAddr.sin_family = AF_INET; 217 | serverAddr.sin_port = htons(port); 218 | serverAddr.sin_addr.s_addr = inet_addr(host); 219 | setNonBlocking(clientSocket); 220 | 221 | for (long i = 0; i < packageSizesCount; i++) { 222 | unsigned int currentPackageSize = packageSizes[i]; 223 | 224 | //setsockopt(clientSocket, SOL_SOCKET, SO_RCVBUF, ¤tPackageSize, sizeof(currentPackageSize)); 225 | //setsockopt(clientSocket, SOL_SOCKET, SO_SNDBUF, ¤tPackageSize, sizeof(currentPackageSize)); 226 | 227 | long long totalReadTime = 0; 228 | long long totalWriteTime = 0; 229 | long long totalTime = 0; // [us] 230 | for (long a = 0; a < nAttempts; a++) { 231 | auto t0 = high_resolution_clock::now(); 232 | 233 | writeUdpSocket(clientSocket, buffer, currentPackageSize, &serverAddr, sizeof(serverAddr)); 234 | 235 | auto t1 = high_resolution_clock::now(); 236 | 237 | readUdpSocket(clientSocket, buffer, currentPackageSize, &serverAddr, &serverAddrLen); 238 | 239 | auto t2 = high_resolution_clock::now(); 240 | 241 | totalWriteTime += duration_cast(t1 - t0).count(); 242 | totalReadTime += duration_cast(t2 - t1).count(); 243 | totalTime += duration_cast(t2 - t0).count(); 244 | } 245 | 246 | printf("[%6d bytes] write: %5lld us, read: %5lld us, total: %5lld us\n", 247 | currentPackageSize, totalWriteTime, totalReadTime, totalTime); 248 | } 249 | } 250 | } 251 | 252 | int main(int argc, char *argv[]) { 253 | initSockets(); 254 | if (argc > 1 && strcmp(argv[1], "server") == 0) { 255 | server(); 256 | } else if (argc > 2 && strcmp(argv[1], "client") == 0) { 257 | client(argv[2]); 258 | } else { 259 | printf("Invalid arguments\n"); 260 | } 261 | } 262 | -------------------------------------------------------------------------------- /src/commands-test.cpp: -------------------------------------------------------------------------------- 1 | #include "commands.hpp" 2 | #include 3 | #include 4 | #include 5 | 6 | void testRopeSlice(int arch, const int nSliceTests, const int 
nPosTests, const int nThreadTests) { 7 | int dim = 4096; 8 | int headSize = 128; 9 | int nKvHeads = 8; 10 | int seqLen = 2048; 11 | int nHeads = dim / headSize; 12 | int kvDim = (dim * nKvHeads) / nHeads; 13 | int ropeTheta = 10000.0f; 14 | 15 | float* q = new float[dim]; 16 | float* k = new float[kvDim]; 17 | float* correctQ = new float[dim]; 18 | float* correctK = new float[kvDim]; 19 | 20 | for (int pos = 0; pos < seqLen; pos += seqLen / nPosTests) { 21 | for (int si = 0; si < nSliceTests; si++) { 22 | int nSlices = pow(2, si); 23 | 24 | for (int nThreads = 1; nThreads <= nThreadTests; nThreads++) { 25 | printf("pos=%d nSlices=%d threads=%d\n", pos, nSlices, nThreads); 26 | 27 | for (int j = 0; j < dim; j++) q[j] = 1.0; 28 | for (int j = 0; j < kvDim; j++) k[j] = 1.0; 29 | 30 | for (slice_index_t sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) { 31 | RopeSlice slice(dim, kvDim, nKvHeads, nSlices, seqLen, headSize, ropeTheta, sliceIndex); 32 | RopeCommand* rope; 33 | if (arch == 1) { 34 | rope = new LlamaRopeCommand(&slice); 35 | } else if (arch == 2) { 36 | rope = new FalconRopeCommand(&slice); 37 | } 38 | 39 | for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) { 40 | rope->forward( 41 | true, 42 | &q[(sliceIndex * dim) / nSlices], 43 | pos, nThreads, threadIndex); 44 | rope->forward( 45 | false, 46 | &k[(sliceIndex * kvDim) / nSlices], 47 | pos, nThreads, threadIndex); 48 | } 49 | 50 | delete rope; 51 | } 52 | 53 | if (si == 0 && nThreads == 1) { 54 | memcpy(correctQ, q, dim * sizeof(float)); 55 | memcpy(correctK, k, kvDim * sizeof(float)); 56 | } else { 57 | for (int j = 0; j < dim; j++) { 58 | if (fabs(q[j] - correctQ[j]) > 1e-6) { 59 | printf("q[%d] mismatch: %f != %f (arch=%d)\n", j, q[j], correctQ[j], arch); 60 | exit(EXIT_FAILURE); 61 | } 62 | } 63 | for (int j = 0; j < kvDim; j++) { 64 | if (fabs(k[j] - correctK[j]) > 1e-6) { 65 | printf("k[%d] mismatch: %f != %f (arch=%d)\n", j, k[j], correctK[j], arch); 66 | exit(EXIT_FAILURE); 67 | } 
68 | } 69 | } 70 | } 71 | } 72 | } 73 | 74 | delete[] q; 75 | delete[] k; 76 | delete[] correctQ; 77 | delete[] correctK; 78 | printf("✅ ropeSlice (arch=%d)\n", arch); 79 | } 80 | 81 | int main() { 82 | testRopeSlice(2, 4, 6, 3); 83 | testRopeSlice(1, 6, 4, 3); 84 | return 0; 85 | } 86 | -------------------------------------------------------------------------------- /src/commands.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #ifdef _WIN32 4 | #define _USE_MATH_DEFINES 5 | #endif 6 | #include 7 | #include "utils.hpp" 8 | #include "funcs.hpp" 9 | #include "commands.hpp" 10 | 11 | RowMatmulSlice::RowMatmulSlice(FloatType type, int nSlices, int n, int d) { 12 | assert(d % nSlices == 0); 13 | 14 | this->type = type; 15 | this->nSlices = nSlices; 16 | this->d0 = d / nSlices; 17 | this->n = n; 18 | this->bytes = getBatchBytes(type, this->n, d); 19 | this->sliceBytes = getBatchBytes(type, this->n, this->d0); 20 | } 21 | 22 | size_t RowMatmulSlice::splitWeights(slice_index_t sliceIndex, char* weights, char* weights0) { 23 | int numbersPerBatch = getNumbersPerBatch(this->type); 24 | int batchBytes = getBatchBytes(this->type, numbersPerBatch, 1); 25 | 26 | int n = this->n / numbersPerBatch; 27 | size_t offset = this->d0 * sliceIndex * n * batchBytes; 28 | size_t copiedBytes = 0; 29 | 30 | for (int d = 0; d < this->d0; d++) { 31 | for (int j = 0; j < n; j++) { 32 | long o = (d * n + j) * batchBytes; 33 | 34 | memcpy(weights0 + o, weights + offset + o, batchBytes); 35 | copiedBytes += batchBytes; 36 | } 37 | } 38 | return copiedBytes; 39 | } 40 | 41 | unsigned int RowMatmulSlice::dOffset(slice_index_t sliceIndex) { 42 | return this->d0 * sliceIndex; 43 | } 44 | 45 | ColMatmulSlice::ColMatmulSlice(FloatType type, int nSlices, int n, int d) { 46 | assert(n % nSlices == 0); 47 | 48 | this->type = type; 49 | this->nSlices = nSlices; 50 | this->n = n; 51 | this->n0 = n / nSlices; 52 | this->d = d; 53 | 
this->bytes = getBatchBytes(type, n, d); 54 | this->sliceBytes = getBatchBytes(type, this->n0, d); 55 | } 56 | 57 | size_t ColMatmulSlice::splitWeights(slice_index_t sliceIndex, char* weights, char* weights0) { 58 | int numbersPerBatch = getNumbersPerBatch(this->type); 59 | int batchBytes = getBatchBytes(this->type, numbersPerBatch, 1); 60 | assert(n0 % numbersPerBatch == 0); 61 | 62 | int n = this->n / numbersPerBatch; 63 | int rowBytes = n * batchBytes; 64 | int row0Bytes = (n0 / numbersPerBatch) * batchBytes; 65 | int rowOffsetBytes = sliceIndex * row0Bytes; 66 | 67 | size_t copiedBytes = 0; 68 | for (int d = 0; d < this->d; d++) { 69 | memcpy(&weights0[row0Bytes * d], &weights[rowBytes * d + rowOffsetBytes], row0Bytes); 70 | copiedBytes += row0Bytes; 71 | } 72 | return copiedBytes; 73 | } 74 | 75 | RopeSlice::RopeSlice(unsigned int dim, unsigned int kvDim, unsigned int nKvHeads, unsigned int nSlices, unsigned int seqLen, unsigned int headSize, float ropeTheta, slice_index_t sliceIndex) { 76 | assert(dim >= kvDim); 77 | assert(dim % nSlices == 0); 78 | assert(kvDim % nSlices == 0); 79 | 80 | qDim0 = dim / nSlices; 81 | kvDim0 = kvDim / nSlices; 82 | assert(qDim0 % 2 == 0); 83 | assert(kvDim0 % 2 == 0); 84 | kvDimStart = kvDim0 * sliceIndex; 85 | qDimStart = qDim0 * sliceIndex; 86 | qDimEnd = qDimStart + qDim0; 87 | qShift = qDimStart - kvDimStart; 88 | sliceDim = qDimEnd - kvDimStart; 89 | this->kvDim = kvDim; 90 | this->nKvHeads = nKvHeads; 91 | this->seqLen = seqLen; 92 | this->headSize = headSize; 93 | this->ropeTheta = ropeTheta; 94 | assert(sliceDim % 2 == 0); 95 | } 96 | 97 | KvCacheSlice::KvCacheSlice(unsigned int kvDim, unsigned int seqLen, unsigned int nSlices) { 98 | assert(kvDim % nSlices == 0); 99 | kvDim0 = kvDim / nSlices; 100 | keyCacheSize = seqLen * kvDim0 * sizeof(float); 101 | valueCacheSize = seqLen * kvDim0 * sizeof(float); 102 | } 103 | 104 | MultiHeadAttSlice::MultiHeadAttSlice(unsigned int nHeads, unsigned int seqLen, unsigned int 
nSlices, slice_index_t sliceIndex) { 105 | assert(nHeads % nSlices == 0); 106 | nHeads0 = nHeads / nSlices; 107 | attSize = seqLen * nHeads0 * sizeof(float); 108 | } 109 | 110 | MatmulCommand::MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType) { 111 | this->n = n; 112 | this->d = d; 113 | this->inputFloatType = inputFloatType; 114 | this->weightsFloatType = weightsFloatType; 115 | this->cpuSize = getBatchBytes(weightsFloatType, n, d); 116 | #if ALLOC_MEMORY 117 | this->cpuWeights = newBuffer(this->cpuSize); 118 | #endif 119 | }; 120 | 121 | MatmulCommand::~MatmulCommand() { 122 | #if ALLOC_MEMORY 123 | freeBuffer(cpuWeights); 124 | #endif 125 | } 126 | 127 | size_t MatmulCommand::loadWeights(const void* source) { 128 | #if ALLOC_MEMORY 129 | memcpy(cpuWeights, source, cpuSize); 130 | #else 131 | cpuWeights = (void*)source; 132 | #endif 133 | return cpuSize; 134 | } 135 | 136 | void MatmulCommand::forward(const void* input, float* output, const unsigned int nThreads, const unsigned int threadIndex) { 137 | matmul(weightsFloatType, inputFloatType, output, input, cpuWeights, n, d, nThreads, threadIndex); 138 | } 139 | 140 | LlamaRopeCommand::LlamaRopeCommand(RopeSlice *slice) { 141 | this->slice = slice; 142 | 143 | size_t cacheBytes = slice->seqLen * slice->sliceDim * sizeof(float); 144 | cache = (float*)newBuffer(cacheBytes); 145 | printf("🕒 ropeCacheSize: %ld kB\n", cacheBytes / 1024); 146 | 147 | for (pos_t pos = 0; pos < slice->seqLen; pos++) { 148 | for (unsigned int i = slice->kvDimStart; i < slice->qDimEnd; i += 2) { 149 | const unsigned int headDim = i % slice->headSize; 150 | const float freq = 1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize); 151 | const float val = pos * freq; 152 | const float fcr = cosf(val); 153 | const float fci = sinf(val); 154 | cache[pos * slice->sliceDim + (i - slice->kvDimStart)] = fcr; 155 | cache[pos * slice->sliceDim + (i - 
slice->kvDimStart) + 1] = fci; 156 | } 157 | } 158 | }; 159 | 160 | LlamaRopeCommand::~LlamaRopeCommand() { 161 | freeBuffer(cache); 162 | } 163 | 164 | void LlamaRopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) { 165 | const unsigned int dim0Half = (isQ ? slice->qDim0 : slice->kvDim0) / 2; 166 | const unsigned int shift = isQ ? slice->qShift : 0; 167 | SPLIT_RANGE_TO_THREADS(s, e, 0, dim0Half, nThreads, threadIndex); 168 | const unsigned int iStart = s * 2; 169 | const unsigned int iEnd = e * 2; 170 | 171 | for (unsigned int i = iStart; i < iEnd; i += 2) { 172 | float fcr = cache[pos * slice->sliceDim + shift + i]; 173 | float fci = cache[pos * slice->sliceDim + shift + i + 1]; 174 | float v0 = qOrK[i]; 175 | float v1 = qOrK[i + 1]; 176 | qOrK[i] = v0 * fcr - v1 * fci; 177 | qOrK[i + 1] = v0 * fci + v1 * fcr; 178 | } 179 | } 180 | 181 | Llama3_1RopeCommand::Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen) { 182 | this->slice = slice; 183 | this->ropeScalingFactor = ropeScalingFactor; 184 | this->ropeScalingLowFreqFactor = ropeScalingLowFreqFactor; 185 | this->ropeScalingHighFreqFactory = ropeScalingHighFreqFactory; 186 | this->ropeScalingOrigMaxSeqLen = ropeScalingOrigMaxSeqLen; 187 | printf("🕒 ropeScalingFactor: %f\n", ropeScalingFactor); 188 | printf("🕒 ropeScalingLowFreqFactor: %f\n", ropeScalingLowFreqFactor); 189 | printf("🕒 ropeScalingHighFreqFactory: %f\n", ropeScalingHighFreqFactory); 190 | printf("🕒 ropeScalingOrigMaxSeqLen: %d\n", ropeScalingOrigMaxSeqLen); 191 | } 192 | 193 | float Llama3_1RopeCommand::scale(float freq) { 194 | float waveLen = 2.0f * M_PI * freq; 195 | float lowFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingLowFreqFactor; 196 | float highFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingHighFreqFactory; 197 | if (waveLen < highFreqWavelen) { 198 | return freq; 199 | } 
else if (waveLen > lowFreqWavelen) { 200 | return freq / ropeScalingFactor; 201 | } else { 202 | float smooth = (ropeScalingOrigMaxSeqLen / waveLen - ropeScalingLowFreqFactor) / (ropeScalingHighFreqFactory - ropeScalingLowFreqFactor); 203 | return (1 - smooth) * freq / ropeScalingFactor + smooth * freq; 204 | } 205 | } 206 | 207 | void Llama3_1RopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) { 208 | const unsigned int dim0Half = (isQ ? slice->qDim0 : slice->kvDim0) / 2; 209 | const unsigned int shift = isQ ? slice->qShift : 0; 210 | SPLIT_RANGE_TO_THREADS(s, e, 0, dim0Half, nThreads, threadIndex); 211 | const unsigned int iStart = s * 2; 212 | const unsigned int iEnd = e * 2; 213 | 214 | for (unsigned int i = iStart; i < iEnd; i += 2) { 215 | const unsigned int headDim = i % slice->headSize; 216 | const float freq = 1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize); 217 | const float val = pos * freq; 218 | const float fcr = cosf(val); 219 | const float fci = sinf(val); 220 | 221 | float v0 = qOrK[i]; 222 | float v1 = qOrK[i + 1]; 223 | 224 | qOrK[i] = scale(v0 * fcr - v1 * fci); 225 | qOrK[i + 1] = scale(v0 * fci + v1 * fcr); 226 | } 227 | } 228 | 229 | FalconRopeCommand::FalconRopeCommand(RopeSlice *slice) { 230 | this->slice = slice; 231 | } 232 | 233 | FalconRopeCommand::~FalconRopeCommand() {} 234 | 235 | void FalconRopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) { 236 | // TODO: this implementation allows only a small number of slices (because it requires dim0 % headSize == 0). This could be improved. 237 | unsigned int dimStart = isQ ? slice->qDimStart : slice->kvDimStart; 238 | unsigned int dim0 = isQ ? slice->qDim0 : slice->kvDim0; 239 | unsigned int headSize = isQ ? 
slice->headSize : slice->kvDim / slice->nKvHeads; 240 | assert(dimStart % headSize == 0); 241 | assert(dim0 % headSize == 0); 242 | unsigned int nHeads0 = dim0 / headSize; 243 | SPLIT_RANGE_TO_THREADS(h0s, h0e, 0, nHeads0, nThreads, threadIndex); 244 | 245 | for (unsigned int h = h0s; h < h0e; h++) { 246 | for (unsigned int j = 0; j < headSize / 2; j++) { 247 | float freq = 1.0f / powf(slice->ropeTheta, 2.0f * (float)j / (float)headSize); 248 | float val = pos * freq; 249 | float fcr = cosf(val); 250 | float fci = sinf(val); 251 | float q0 = qOrK[h * headSize + j]; 252 | float q1 = qOrK[h * headSize + j + headSize / 2]; 253 | qOrK[h * headSize + j] = q0 * fcr - q1 * fci; 254 | qOrK[h * headSize + j + headSize / 2] = q0 * fci + q1 * fcr; 255 | } 256 | } 257 | } -------------------------------------------------------------------------------- /src/commands.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COMMANDS_HPP 2 | #define COMMANDS_HPP 3 | 4 | #include 5 | #include "quants.hpp" 6 | 7 | // RESPONSIBILITIES 8 | // 9 | // *Slice - calculates sizes, offsets, slice sizes etc. It is not responsible for memory allocation. It may help in the loading of data. 10 | // *Command - allocates memory for weights, performs calculations. 
typedef unsigned int pos_t;
typedef uint8_t slice_index_t;

// Common interface of the matmul weight slicers.
//
//   bytes      - size of the complete weight matrix, in bytes
//   sliceBytes - size of the part held by a single slice, in bytes
//
// The destructor is virtual (like RopeCommand's) so that a slicer can be
// destroyed safely through a MatmulSlice pointer.
class MatmulSlice {
public:
    size_t bytes;
    size_t sliceBytes;
    virtual ~MatmulSlice() {}
    // Copies the weights of slice `sliceIndex` from `weights` (full matrix) to
    // `weights0` and returns the number of bytes copied.
    virtual size_t splitWeights(slice_index_t sliceIndex, char* weights, char* weights0) = 0;
};
// Interface of the rotary position embedding (RoPE) commands.
//
// forward() rotates, in place, the local slice of the query vector
// (isQ == true) or of the key vector (isQ == false) for position `pos`.
// The work is split between `nThreads` threads; each call handles the share of
// thread `threadIndex`.
class RopeCommand {
public:
    virtual ~RopeCommand() {};
    virtual void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) = 0;
};

// RoPE over adjacent pairs (i, i + 1).
// The constructor precomputes the cos/sin values of every position in
// [0, seqLen) for the dimensions covered by the slice and stores them in
// `cache`; the destructor frees that cache. `slice` is only stored, it is
// never freed by this class.
class LlamaRopeCommand : public RopeCommand {
private:
    RopeSlice* slice;
    float* cache;
public:
    LlamaRopeCommand(RopeSlice *slice);
    ~LlamaRopeCommand();
    void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
};

// RoPE over adjacent pairs (i, i + 1) with the Llama 3.1 rope-scaling
// parameters. No cache is kept: cos/sin are computed on every forward() call.
// The constructor prints the four scaling parameters.
// scale() is the piecewise rescaling helper driven by those parameters; it is
// called from forward().
// NOTE(review): "ropeScalingHighFreqFactory" looks like a typo of "...Factor";
// it is part of the constructor signature, so it is left unchanged here.
class Llama3_1RopeCommand : public RopeCommand {
private:
    RopeSlice* slice;
    float ropeScalingFactor;
    float ropeScalingLowFreqFactor;
    float ropeScalingHighFreqFactory;
    int ropeScalingOrigMaxSeqLen;
public:
    Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen);
    void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
    float scale(float freq);
};

// RoPE over pairs (j, j + headSize / 2) inside every head.
// forward() asserts that the slice starts on a head boundary and covers whole
// heads. Nothing is cached or owned: `slice` is only stored.
class FalconRopeCommand : public RopeCommand {
private:
    RopeSlice* slice;
public:
    FalconRopeCommand(RopeSlice *slice);
    ~FalconRopeCommand();
    void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
};
// Windows stand-in for pthread_create(): starts `func(arg)` on a new thread
// created with CreateThread() and stores the thread handle in `*out`.
// The attributes argument (`unused`) is ignored.
// Returns 0 on success, EAGAIN when CreateThread() returns NULL (in that case
// `*out` is left untouched).
static int pthread_create(dl_thread * out, void * unused, thread_func_t func, void * arg) {
    (void) unused;
    dl_thread handle = CreateThread(NULL, 0, func, arg, 0, NULL);
    if (handle == NULL) {
        return EAGAIN;
    }

    *out = handle;
    return 0;
}

// Windows stand-in for pthread_join(): waits without timeout until `thread`
// has finished, then closes its handle.
// The second argument (`unused`) is ignored, so the thread's return value is
// not handed back to the caller.
// Returns 0 on success, -1 when the wait fails; the handle is not closed on
// that failure path.
static int pthread_join(dl_thread thread, void * unused) {
    (void) unused;
    DWORD ret = WaitForSingleObject(thread, INFINITE);
    if (ret == WAIT_FAILED) {
        return -1;
    }
    CloseHandle(thread);
    return 0;
}
// Test helper: passes silently when both integers are equal, otherwise prints
// the two values and terminates the test binary with a failure status.
void assertInt(int a, int b) {
    if (a == b) {
        return;
    }
    printf("❌ %d != %d\n", a, b);
    exit(EXIT_FAILURE);
}
assertInt(a2Start, 22); 110 | assertInt(a2End, 32); 111 | } 112 | 113 | // <0; 4> across 8 threads 114 | { 115 | SPLIT_RANGE_TO_THREADS(b0Start, b0End, 0, 4, 8, 0); // thread 0 116 | assertInt(b0Start, 0); 117 | assertInt(b0End, 1); 118 | } 119 | { 120 | SPLIT_RANGE_TO_THREADS(b0Start, b0End, 0, 4, 8, 3); // thread 3 121 | assertInt(b0Start, 3); 122 | assertInt(b0End, 4); 123 | } 124 | { 125 | SPLIT_RANGE_TO_THREADS(b0Start, b0End, 0, 4, 8, 4); // thread 4 126 | assertInt(b0Start, 4); 127 | assertInt(b0End, 4); 128 | } 129 | { 130 | SPLIT_RANGE_TO_THREADS(b0Start, b0End, 0, 4, 8, 7); // thread 7 131 | assertInt(b0Start, 4); 132 | assertInt(b0End, 4); 133 | } 134 | 135 | printf("✅ SPLIT_RANGE_TO_THREADS\n"); 136 | } 137 | 138 | int main() { 139 | initQuants(); 140 | 141 | testRms(); 142 | testMatmulQ80(); 143 | testAdd(); 144 | testSplitRangeToThreads(); 145 | return EXIT_SUCCESS; 146 | } -------------------------------------------------------------------------------- /src/funcs.hpp: -------------------------------------------------------------------------------- 1 | #ifndef FUNCS_HPP 2 | #define FUNCS_HPP 3 | 4 | #include "quants.hpp" 5 | 6 | void softmax(float* x, const unsigned int size); 7 | float rms(const float* x, const unsigned int size); 8 | void rmsnorm(float* o, const float* x, const float ms, const float* weight, const unsigned int size, const unsigned int nThreads, const unsigned int threadIndex); 9 | void matmul(const FloatType weightsFloatType, const FloatType inputFloatType, float* output, const void* input, const void* weights, const unsigned int n, const unsigned int d, const unsigned int nThreads, const unsigned int threadIndex); 10 | float dotProduct(const float* a, const float* b, const unsigned int size); 11 | void gelu(float* t, const unsigned int n, const unsigned int nThreads, const unsigned int threadIndex); 12 | void silu(float* t, const unsigned int n, const unsigned int nThreads, const unsigned int threadIndex); 13 | void mul(float* 
// Test helper: checks that the first `n` values of `a` (actual) match `b`
// (expected) within an absolute tolerance of 0.000035.
//
// On the first mismatch it prints the offending pair, followed by up to three
// more pairs for context, and terminates the test binary with a failure
// status.
//
// Fixes:
//   - The context lines are clamped to `n`. The previous code always printed
//     four pairs, reading past the end of both arrays whenever the mismatch
//     was found in the last three elements.
//   - A NaN on either side now counts as a mismatch (previously a NaN in `b`
//     slipped through because `NaN > tolerance` is false).
void compare(float* a, float* b, int n) {
    for (int i = 0; i < n; i++) {
        // Negated "<=" on purpose: the comparison is false for NaN, so NaN fails the check.
        if (!(std::fabs(a[i] - b[i]) <= 0.000035)) { // Optimization may cause some differences
            const int end = (n - i > 4) ? i + 4 : n;
            for (int j = i; j < end; j++) {
                printf("%.9g != %.9g\n", a[j], b[j]);
            }
            exit(EXIT_FAILURE);
        }
    }
}
spec.vocabSize = 1024; 43 | spec.nExperts = 8; 44 | spec.nActiveExperts = 2; 45 | spec.weightsFloatType = F32; 46 | spec.bufferFloatType = F32; 47 | spec.nSlices = 1; 48 | spec.hiddenAct = GELU; 49 | spec.ropeTheta = 10000.0f; 50 | 51 | TransformerConfig config; 52 | config.useDiscForKvCache = false; 53 | 54 | size_t beforeBlockBytes = spec.dim * spec.vocabSize * sizeof(float); 55 | size_t blockBytes = 956596224; 56 | size_t afterBlockBytes = (spec.dim + spec.dim * spec.vocabSize) * sizeof(float); 57 | spec.fileSize = spec.headerSize + beforeBlockBytes + blockBytes + afterBlockBytes; 58 | 59 | char* weights = (char*)newBuffer(beforeBlockBytes + blockBytes + afterBlockBytes); 60 | long nFloats = blockBytes / sizeof(float); 61 | float* block = (float*)&weights[beforeBlockBytes]; 62 | 63 | unsigned long long state = 123456789L; 64 | for (int f = 0; f < nFloats; f++) block[f] = randomF32(&state) / 100.0; 65 | 66 | SocketPool socketPool(0, NULL); 67 | Transformer transformer = Transformer::loadRoot(weights, &spec, &config, &socketPool); 68 | transformer.pos = 0; 69 | 70 | float* x = transformer.x; 71 | for (int i = 0; i < spec.dim; i++) x[i] = (randomF32(&state) / 100.0) / 78.38367176906169f; 72 | 73 | TransformerArch arch = buildGrok1Arch(&spec); 74 | 75 | int nThreads = 4; 76 | TransformerContext context; 77 | context.transformer = &transformer; 78 | context.currentBlockIndex = 0; 79 | context.socket = NULL; 80 | context.socketPool = &socketPool; 81 | 82 | int skipLastNTasks = 4; 83 | TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context); 84 | long t0 = timeMs(); 85 | loop.run(); 86 | long t1 = timeMs(); 87 | 88 | freeBuffer(weights); 89 | 90 | compare(&x[0], expectedOutput_0_4, 4); 91 | compare(&x[256], expectedOutput_256_260, 4); 92 | compare(&x[5012], expectedOutput_5012_5016, 4); 93 | 94 | printf("✅ Block forwarded correctly in %ldms\n", t1 - t0); 95 | } 96 | 
-------------------------------------------------------------------------------- /src/grok1-tasks.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "utils.hpp" 5 | #include "funcs.hpp" 6 | #include "socket.hpp" 7 | #include "tasks.hpp" 8 | #include "llama2-tasks.hpp" 9 | #include "grok1-tasks.hpp" 10 | 11 | void grokMulInput(TASK_ARGS) { 12 | TASK_VARIABLES; 13 | mulScalar(transformer->x, 78.38367176906169f, transformer->spec->dim, nThreads, threadIndex); 14 | } 15 | 16 | void grokRmfFfn(TASK_ARGS) { 17 | TASK_VARIABLES; 18 | if (threadIndex == 0) { 19 | float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); 20 | memset(xb2, 0, spec->dim * sizeof(float)); 21 | for (uint8_t s = 0; s < spec->nSlices; s++) { 22 | float* xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, s); 23 | add(xb2, xbv, spec->dim, 1, 0); 24 | } 25 | transformer->rms = rms(xb2, spec->dim); 26 | } 27 | } 28 | 29 | void grokRmfFfnNorm(TASK_ARGS) { 30 | TASK_VARIABLES; 31 | float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); 32 | 33 | rmsnorm(xb2, xb2, transformer->rms, block->rmsFfn, spec->dim, nThreads, threadIndex); 34 | } 35 | 36 | void grokRmfFfnNormJoin(TASK_ARGS) { 37 | TASK_VARIABLES; 38 | 39 | float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); 40 | add(transformer->x, xb2, spec->dim, nThreads, threadIndex); 41 | } 42 | 43 | void grokMoeRms(TASK_ARGS) { 44 | TASK_VARIABLES; 45 | if (threadIndex == 0) { 46 | transformer->rms = rms(transformer->x, spec->dim); 47 | } 48 | } 49 | 50 | void grokMoeRmsNorm(TASK_ARGS) { 51 | TASK_VARIABLES; 52 | float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB); 53 | rmsnorm(xb, transformer->x, transformer->rms, block->rmsMoe, spec->dim, nThreads, threadIndex); 54 | } 55 | 56 | void grokMoeRouter(TASK_ARGS) { 57 | TASK_VARIABLES; 58 | float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB); 59 | 
block->moeRouterMm->forward(xb, block->moeRouterProbs, nThreads, threadIndex); 60 | } 61 | 62 | void grokMoeRouterSoftmax(TASK_ARGS) { 63 | TASK_VARIABLES; 64 | if (threadIndex == 0) { 65 | softmax(block->moeRouterProbs, spec->nExperts); 66 | } 67 | } 68 | 69 | void grokMoeTopk(TASK_ARGS) { 70 | TASK_VARIABLES; 71 | if (threadIndex == 0) { 72 | assert(spec->nActiveExperts == 2); // TODO 73 | uint8_t* indexes = (uint8_t*)transformer->buffer->getUnit(TB_UNIT_MOE_INDEXES); 74 | 75 | int best0i = -1; 76 | int best1i = -1; 77 | float best0v; 78 | float best1v; 79 | for (int i = 0; i < spec->nExperts; i++) { 80 | float prob = block->moeRouterProbs[i]; 81 | if (best0i == -1 || best0v < prob) { 82 | if ((best0i != -1 && best1i == -1) || best1v < best0v) { 83 | best1v = best0v; 84 | best1i = best0i; 85 | } 86 | best0i = i; 87 | best0v = prob; 88 | } else if (best1i == -1 || best1v < prob) { 89 | best1i = i; 90 | best1v = prob; 91 | } 92 | } 93 | 94 | indexes[0] = (uint8_t)best0i; 95 | indexes[1] = (uint8_t)best1i; 96 | } 97 | } 98 | 99 | void grokMoeNormWeights(TASK_ARGS) { 100 | TASK_VARIABLES; 101 | if (threadIndex == 0) { 102 | uint8_t* indexes = (uint8_t*)transformer->buffer->getUnit(TB_UNIT_MOE_INDEXES); 103 | float* weights = (float*)transformer->buffer->getUnit(TB_UNIT_MOE_WEIGHTS); 104 | 105 | float sum = 0.0; 106 | int i; 107 | for (i = 0; i < spec->nActiveExperts; i++) { 108 | sum += block->moeRouterProbs[indexes[i]]; 109 | } 110 | for (i = 0; i < spec->nActiveExperts; i++) { 111 | weights[i] = block->moeRouterProbs[indexes[i]] / sum; 112 | } 113 | } 114 | } 115 | 116 | void grokQuantizeMoeInput(TASK_ARGS) { 117 | TASK_VARIABLES; 118 | quantizeUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB, TB_UNIT_XB_QUANTIZED); 119 | } 120 | 121 | void grokSyncMoeInput(TASK_ARGS) { 122 | TASK_VARIABLES; 123 | syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED); 124 | syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_MOE_INDEXES); 125 | syncUnitBuffer(nThreads, 
threadIndex, ctx, TB_UNIT_MOE_WEIGHTS); 126 | } 127 | 128 | void grokMoeBlock0(TASK_ARGS) { 129 | TASK_VARIABLES; 130 | 131 | uint8_t* indexes = (uint8_t*)transformer->buffer->getUnit(TB_UNIT_MOE_INDEXES); 132 | float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); 133 | float* hb = (float*)transformer->buffer->getSliced(TB_SLICED_HB, transformer->sliceIndex); 134 | 135 | for (int ae = 0; ae < spec->nActiveExperts; ae++) { 136 | uint8_t e = indexes[ae]; 137 | 138 | float* expertUp = &hb[block->moeUpAndGate0Slice->d0 * ae]; 139 | float* expertGate = &block->expertGate[block->moeUpAndGate0Slice->d0 * ae]; 140 | 141 | block->moeUpMm[e]->forward(xb, expertUp, nThreads, threadIndex); 142 | block->moeGateMm[e]->forward(xb, expertGate, nThreads, threadIndex); 143 | } 144 | } 145 | 146 | void grokMoeBlock1(TASK_ARGS) { 147 | TASK_VARIABLES; 148 | float* hb = (float*)transformer->buffer->getSliced(TB_SLICED_HB, transformer->sliceIndex); 149 | 150 | for (int ae = 0; ae < spec->nActiveExperts; ae++) { 151 | float* expertUp = &hb[block->moeUpAndGate0Slice->d0 * ae]; 152 | float* expertGate = &block->expertGate[block->moeUpAndGate0Slice->d0 * ae]; 153 | 154 | if (spec->hiddenAct == SILU) { 155 | silu(expertGate, block->moeUpAndGate0Slice->d0, nThreads, threadIndex); 156 | } else if (spec->hiddenAct == GELU) { 157 | gelu(expertGate, block->moeUpAndGate0Slice->d0, nThreads, threadIndex); 158 | } else { 159 | assert(false); 160 | } 161 | mul(expertUp, expertGate, block->moeUpAndGate0Slice->d0, nThreads, threadIndex); 162 | } 163 | } 164 | 165 | void grokQuantizeMoeMul(TASK_ARGS) { 166 | TASK_VARIABLES; 167 | quantizeSlicedBuffer(nThreads, threadIndex, ctx, true, TB_SLICED_HB, TB_SLICED_HB_QUANTIZED); 168 | } 169 | 170 | void grokSyncMoeMulA(TASK_ARGS) { 171 | TASK_VARIABLES; 172 | syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED); 173 | } 174 | 175 | void grokSyncMoeMulRearrange(TASK_ARGS) { 176 | TASK_VARIABLES; 177 | 178 | if 
(threadIndex == 0 && spec->nSlices > 1) { 179 | char* hbq = (char*)transformer->buffer->getUnit(TB_SLICED_HB_QUANTIZED); 180 | size_t bufferBytes = transformer->buffer->getUnitBytes(TB_SLICED_HB_QUANTIZED); 181 | size_t bufferSliceBytes = transformer->buffer->getSlicedBytes(TB_SLICED_HB_QUANTIZED); 182 | 183 | size_t moeUpBytes = bufferBytes / spec->nActiveExperts; 184 | size_t moeUp0SliceBytes = getBatchBytes(spec->bufferFloatType, block->moeUpAndGate0Slice->d0, 1); 185 | 186 | char* buffer = new char[bufferBytes]; 187 | 188 | for (int s = 0; s < spec->nSlices; s++) { 189 | for (int ae = 0; ae < spec->nActiveExperts; ae++) { 190 | memcpy(&buffer[ae * moeUpBytes + s * moeUp0SliceBytes], &hbq[s * bufferSliceBytes + ae * moeUp0SliceBytes], moeUp0SliceBytes); 191 | } 192 | } 193 | 194 | memcpy(hbq, buffer, bufferBytes); 195 | delete[] buffer; 196 | } 197 | } 198 | 199 | void grokSyncMoeMulB(TASK_ARGS) { 200 | TASK_VARIABLES; 201 | syncUnitBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED); 202 | } 203 | 204 | void grokMoeBlock2(TASK_ARGS) { 205 | TASK_VARIABLES; 206 | 207 | float* xb2 = (float*)transformer->buffer->getSliced(TB_SLICED_XB2, transformer->sliceIndex); 208 | char* hbq = (char*)transformer->buffer->getUnit(TB_SLICED_HB_QUANTIZED); 209 | size_t rowBytes = getBatchBytes(spec->bufferFloatType, spec->hiddenDim, 1); 210 | 211 | uint8_t* indexes = (uint8_t*)transformer->buffer->getUnit(TB_UNIT_MOE_INDEXES); 212 | float* weights = (float*)transformer->buffer->getUnit(TB_UNIT_MOE_WEIGHTS); 213 | 214 | for (int ae = 0; ae < spec->nActiveExperts; ae++) { 215 | uint8_t e = indexes[ae]; 216 | float weight = weights[ae]; 217 | 218 | char* expertUp = &hbq[rowBytes * ae]; 219 | float* expertDown = ae == 0 ? 
xb2 : &block->expertDown[block->moeDown0Slice->d0 * (ae - 1)]; 220 | 221 | block->moeDownMm[e]->forward(expertUp, expertDown, nThreads, threadIndex); 222 | 223 | mulScalar(expertDown, weight, block->moeDown0Slice->d0, nThreads, threadIndex); 224 | if (ae > 0) { 225 | add(xb2, expertDown, block->moeDown0Slice->d0, nThreads, threadIndex); 226 | } 227 | } 228 | } 229 | 230 | void grokQuantizeMoeOutput(TASK_ARGS) { 231 | TASK_VARIABLES; 232 | quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2, TB_SLICED_XB2_QUANTIZED); 233 | } 234 | 235 | void grokSyncMoeOutput(TASK_ARGS) { 236 | TASK_VARIABLES; 237 | syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XB2_QUANTIZED); 238 | } 239 | 240 | void grokDequantizeMoeOutput(TASK_ARGS) { 241 | TASK_VARIABLES; 242 | dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2_QUANTIZED, TB_SLICED_XB2); 243 | } 244 | 245 | void grokMoeRmsFinal(TASK_ARGS) { 246 | TASK_VARIABLES; 247 | if (threadIndex == 0) { 248 | float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); 249 | transformer->rms = rms(xb2, spec->dim); 250 | } 251 | } 252 | 253 | void grokMoeRmsNormFinal(TASK_ARGS) { 254 | TASK_VARIABLES; 255 | float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); 256 | rmsnorm(xb2, xb2, transformer->rms, block->rmsFfn2, spec->dim, nThreads, threadIndex); 257 | } 258 | 259 | void grokMoeAdd(TASK_ARGS) { 260 | TASK_VARIABLES; 261 | float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); 262 | add(transformer->x, xb2, spec->dim, nThreads, threadIndex); 263 | } 264 | 265 | void grokFinalize(TASK_ARGS) { 266 | TASK_VARIABLES; 267 | transformer->wclsMm->forward(transformer->x, transformer->logits, nThreads, threadIndex); 268 | } 269 | 270 | void grokFinalize2(TASK_ARGS) { 271 | TASK_VARIABLES; 272 | mulScalar(transformer->logits, 0.5773502691896257f, spec->vocabSize, nThreads, threadIndex); 273 | } 274 | 275 | TransformerArch buildGrok1Arch(TransformerSpec* spec) { 276 
| TransformerArch a; 277 | 278 | // inference 279 | 280 | a.I(TASK_WITH_NAME(sendPos), TASK_TYPE_TRANSFER); 281 | a.I(TASK_WITH_NAME(grokMulInput), TASK_TYPE_INFERENCE); 282 | for (int i = 0; i < spec->nLayers; i++) { 283 | a.I(TASK_WITH_NAME(llamaRmsAtt), TASK_TYPE_INFERENCE); 284 | a.I(TASK_WITH_NAME(llamaRmsAttNorm), TASK_TYPE_INFERENCE); 285 | a.I(TASK_WITH_NAME(llamaQuantizeRmsAtt), TASK_TYPE_INFERENCE); 286 | a.I(TASK_WITH_NAME(llamaSyncRmsAtt), TASK_TYPE_TRANSFER); 287 | a.I(TASK_WITH_NAME(llamaQkv), TASK_TYPE_INFERENCE); 288 | a.I(TASK_WITH_NAME(llamaRope), TASK_TYPE_INFERENCE); 289 | a.I(TASK_WITH_NAME(llamaMultiheadAtt), TASK_TYPE_INFERENCE); 290 | a.I(TASK_WITH_NAME(llamaQuantizeMultiheadAtt), TASK_TYPE_INFERENCE); 291 | a.I(TASK_WITH_NAME(llamaAtt), TASK_TYPE_INFERENCE); 292 | a.I(TASK_WITH_NAME(llamaQuantizeAtt), TASK_TYPE_INFERENCE); 293 | a.I(TASK_WITH_NAME(llamaSyncAtt), TASK_TYPE_TRANSFER); 294 | a.I(TASK_WITH_NAME(llamaDequantizeAtt), TASK_TYPE_INFERENCE); 295 | a.I(TASK_WITH_NAME(grokRmfFfn), TASK_TYPE_INFERENCE); 296 | a.I(TASK_WITH_NAME(grokRmfFfnNorm), TASK_TYPE_INFERENCE); 297 | a.I(TASK_WITH_NAME(grokRmfFfnNormJoin), TASK_TYPE_INFERENCE); 298 | 299 | a.I(TASK_WITH_NAME(grokMoeRms), TASK_TYPE_INFERENCE); 300 | a.I(TASK_WITH_NAME(grokMoeRmsNorm), TASK_TYPE_INFERENCE); 301 | a.I(TASK_WITH_NAME(grokMoeRouter), TASK_TYPE_INFERENCE); 302 | a.I(TASK_WITH_NAME(grokMoeRouterSoftmax), TASK_TYPE_INFERENCE); 303 | a.I(TASK_WITH_NAME(grokMoeTopk), TASK_TYPE_INFERENCE); 304 | a.I(TASK_WITH_NAME(grokMoeNormWeights), TASK_TYPE_INFERENCE); 305 | a.I(TASK_WITH_NAME(grokQuantizeMoeInput), TASK_TYPE_INFERENCE); 306 | a.I(TASK_WITH_NAME(grokSyncMoeInput), TASK_TYPE_TRANSFER); 307 | a.I(TASK_WITH_NAME(grokMoeBlock0), TASK_TYPE_INFERENCE); 308 | a.I(TASK_WITH_NAME(grokMoeBlock1), TASK_TYPE_INFERENCE); 309 | a.I(TASK_WITH_NAME(grokQuantizeMoeMul), TASK_TYPE_INFERENCE); 310 | a.I(TASK_WITH_NAME(grokSyncMoeMulA), TASK_TYPE_INFERENCE); 311 | 
a.I(TASK_WITH_NAME(grokSyncMoeMulRearrange), TASK_TYPE_INFERENCE); 312 | a.I(TASK_WITH_NAME(grokSyncMoeMulB), TASK_TYPE_INFERENCE); 313 | a.I(TASK_WITH_NAME(grokMoeBlock2), TASK_TYPE_INFERENCE); 314 | a.I(TASK_WITH_NAME(grokQuantizeMoeOutput), TASK_TYPE_INFERENCE); 315 | a.I(TASK_WITH_NAME(grokSyncMoeOutput), TASK_TYPE_TRANSFER); 316 | a.I(TASK_WITH_NAME(grokDequantizeMoeOutput), TASK_TYPE_INFERENCE); 317 | a.I(TASK_WITH_NAME(grokMoeRmsFinal), TASK_TYPE_INFERENCE); 318 | a.I(TASK_WITH_NAME(grokMoeRmsNormFinal), TASK_TYPE_INFERENCE); 319 | a.I(TASK_WITH_NAME(grokMoeAdd), TASK_TYPE_INFERENCE); 320 | a.I(TASK_WITH_NAME(llamaNextBlock), TASK_TYPE_INFERENCE); 321 | } 322 | 323 | a.I(TASK_WITH_NAME(llamaRmsFinal), TASK_TYPE_INFERENCE); 324 | a.I(TASK_WITH_NAME(llamaRmsFinalNorm), TASK_TYPE_INFERENCE); 325 | a.I(TASK_WITH_NAME(grokFinalize), TASK_TYPE_INFERENCE); 326 | a.I(TASK_WITH_NAME(grokFinalize2), TASK_TYPE_INFERENCE); 327 | 328 | // worker 329 | 330 | for (int i = 0; i < spec->nLayers; i++) { 331 | a.W(TASK_WITH_NAME(llamaSyncRmsAtt), TASK_TYPE_TRANSFER); 332 | a.W(TASK_WITH_NAME(llamaQkv), TASK_TYPE_INFERENCE); 333 | a.W(TASK_WITH_NAME(llamaRope), TASK_TYPE_INFERENCE); 334 | a.W(TASK_WITH_NAME(llamaMultiheadAtt), TASK_TYPE_INFERENCE); 335 | a.W(TASK_WITH_NAME(llamaQuantizeMultiheadAtt), TASK_TYPE_INFERENCE); 336 | a.W(TASK_WITH_NAME(llamaAtt), TASK_TYPE_INFERENCE); 337 | a.W(TASK_WITH_NAME(llamaQuantizeAtt), TASK_TYPE_INFERENCE); 338 | a.W(TASK_WITH_NAME(llamaSyncAtt), TASK_TYPE_TRANSFER); 339 | 340 | a.W(TASK_WITH_NAME(grokSyncMoeInput), TASK_TYPE_TRANSFER); 341 | a.W(TASK_WITH_NAME(grokMoeBlock0), TASK_TYPE_INFERENCE); 342 | a.W(TASK_WITH_NAME(grokMoeBlock1), TASK_TYPE_INFERENCE); 343 | a.W(TASK_WITH_NAME(grokQuantizeMoeMul), TASK_TYPE_INFERENCE); 344 | a.W(TASK_WITH_NAME(grokSyncMoeMulA), TASK_TYPE_INFERENCE); 345 | a.W(TASK_WITH_NAME(grokSyncMoeMulB), TASK_TYPE_INFERENCE); 346 | a.W(TASK_WITH_NAME(grokMoeBlock2), TASK_TYPE_INFERENCE); 347 | 
a.W(TASK_WITH_NAME(grokQuantizeMoeOutput), TASK_TYPE_INFERENCE); 348 | a.W(TASK_WITH_NAME(grokSyncMoeOutput), TASK_TYPE_TRANSFER); 349 | 350 | a.W(TASK_WITH_NAME(llamaNextBlock), TASK_TYPE_INFERENCE); 351 | } 352 | 353 | return a; 354 | } 355 | -------------------------------------------------------------------------------- /src/grok1-tasks.hpp: -------------------------------------------------------------------------------- 1 | #ifndef GROK_TASKS_HPP 2 | #define GROK_TASKS_HPP 3 | 4 | #include "tasks.hpp" 5 | 6 | void grokRmfFfn(TASK_ARGS); 7 | void grokRmfFfnNorm(TASK_ARGS); 8 | void grokRmfFfnNormJoin(TASK_ARGS); 9 | void grokMoeRms(TASK_ARGS); 10 | void grokMoeRmsNorm(TASK_ARGS); 11 | void grokMoeRouter(TASK_ARGS); 12 | void grokMoeRouterSoftmax(TASK_ARGS); 13 | void grokMoeTopk(TASK_ARGS); 14 | void grokMoeNormWeights(TASK_ARGS); 15 | void grokQuantizeMoeInput(TASK_ARGS); 16 | void grokSyncMoeInput(TASK_ARGS); 17 | void grokMoeBlock0(TASK_ARGS); 18 | void grokMoeBlock1(TASK_ARGS); 19 | void grokQuantizeMoeMul(TASK_ARGS); 20 | void grokSyncMoeMulA(TASK_ARGS); 21 | void grokSyncMoeMulRearrange(TASK_ARGS); 22 | void grokSyncMoeMulB(TASK_ARGS); 23 | void grokMoeBlock2(TASK_ARGS); 24 | void grokQuantizeMoeOutput(TASK_ARGS); 25 | void grokSyncMoeOutput(TASK_ARGS); 26 | void grokDequantizeMoeOutput(TASK_ARGS); 27 | void grokMoeRmsFinal(TASK_ARGS); 28 | void grokMoeRmsNormFinal(TASK_ARGS); 29 | void grokMoeAdd(TASK_ARGS); 30 | 31 | TransformerArch buildGrok1Arch(TransformerSpec* spec); 32 | 33 | #endif -------------------------------------------------------------------------------- /src/llama2-tasks.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "utils.hpp" 5 | #include "funcs.hpp" 6 | #include "socket.hpp" 7 | #include "tasks.hpp" 8 | #include "llama2-tasks.hpp" 9 | 10 | void llamaRmsAtt(TASK_ARGS) { 11 | TASK_VARIABLES; 12 | if (threadIndex == 0) { 13 | transformer->rms = 
rms(transformer->x, spec->dim); 14 | } 15 | } 16 | 17 | void llamaRmsAttNorm(TASK_ARGS) { 18 | TASK_VARIABLES; 19 | float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB); 20 | rmsnorm(xb, transformer->x, transformer->rms, block->rmsAtt, spec->dim, nThreads, threadIndex); 21 | } 22 | 23 | void llamaQuantizeRmsAtt(TASK_ARGS) { 24 | TASK_VARIABLES; 25 | quantizeUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB, TB_UNIT_XB_QUANTIZED); 26 | } 27 | 28 | void llamaSyncRmsAtt(TASK_ARGS) { 29 | TASK_VARIABLES; 30 | syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED); 31 | } 32 | 33 | void llamaQkv(TASK_ARGS) { 34 | TASK_VARIABLES; 35 | assert(block->kvCacheSlice->kvDim0 == block->k0Slice->d0); 36 | assert(block->kvCacheSlice->kvDim0 == block->v0Slice->d0); 37 | 38 | float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); 39 | float *k0 = &block->keyCache[transformer->pos * block->kvCacheSlice->kvDim0]; 40 | float* v0 = &block->valueCache[transformer->pos * block->kvCacheSlice->kvDim0]; 41 | 42 | block->q0mm->forward(xbq, block->qo0, nThreads, threadIndex); 43 | block->k0mm->forward(xbq, k0, nThreads, threadIndex); 44 | block->v0mm->forward(xbq, v0, nThreads, threadIndex); 45 | } 46 | 47 | void llamaRope(TASK_ARGS) { 48 | TASK_VARIABLES; 49 | float* k0 = &block->keyCache[transformer->pos * block->kvCacheSlice->kvDim0]; 50 | transformer->rope->forward(true, block->qo0, transformer->pos, nThreads, threadIndex); 51 | transformer->rope->forward(false, k0, transformer->pos, nThreads, threadIndex); 52 | } 53 | 54 | void llamaMultiheadAtt(TASK_ARGS) { 55 | TASK_VARIABLES; 56 | SPLIT_RANGE_TO_THREADS(h0Start, h0End, 0, block->multiHeadAttSlice->nHeads0, nThreads, threadIndex); 57 | 58 | float* xb = (float*)transformer->buffer->getSliced(TB_UNIT_XB, transformer->sliceIndex); 59 | 60 | int kvMul = spec->nHeads / spec->nKvHeads; // integer multiplier of the kv sharing in multiquery 61 | 62 | for (int h0 = h0Start; h0 < h0End; h0++) { 63 | // get the 
query vector for this head 64 | float* _q = block->qo0 + h0 * spec->headSize; 65 | // attention scores for this head 66 | float* _att = block->att + h0 * spec->seqLen; 67 | // iterate over all timesteps, including the current one 68 | for (int t = 0; t <= transformer->pos; t++) { 69 | // get the key vector for this head and at this timestep 70 | float* k = block->keyCache + t * block->kvCacheSlice->kvDim0 + (h0 / kvMul) * spec->headSize; 71 | // calculate the attention score as the dot product of q and k 72 | float score = dotProduct(_q, k, spec->headSize) / sqrtf(spec->headSize); 73 | _att[t] = score; 74 | } 75 | 76 | // softmax the scores to get attention weights, from 0..pos inclusively 77 | softmax(_att, transformer->pos + 1); 78 | 79 | // weighted sum of the values, store back into xb 80 | float* hxb = xb + h0 * spec->headSize; 81 | memset(hxb, 0, spec->headSize * sizeof(float)); 82 | for (int t = 0; t <= transformer->pos; t++) { 83 | // get the value vector for this head and at this timestep 84 | float* _v = block->valueCache + t * block->kvCacheSlice->kvDim0 + (h0 / kvMul) * spec->headSize; 85 | // get the attention weight for this timestep 86 | float a = _att[t]; 87 | 88 | // accumulate the weighted value into xb 89 | for (int i = 0; i < spec->headSize; i++) { 90 | hxb[i] += a * _v[i]; 91 | } 92 | } 93 | } 94 | } 95 | 96 | void llamaQuantizeMultiheadAtt(TASK_ARGS) { 97 | TASK_VARIABLES; 98 | quantizeSlicedBuffer(nThreads, threadIndex, ctx, true, TB_UNIT_XB, TB_UNIT_XB_QUANTIZED); 99 | }; 100 | 101 | void llamaAtt(TASK_ARGS) { 102 | TASK_VARIABLES; 103 | 104 | void* xbq0 = transformer->buffer->getSliced(TB_UNIT_XB_QUANTIZED, transformer->sliceIndex); 105 | float* xbv0 = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, transformer->sliceIndex); 106 | 107 | block->wo0mm->forward(xbq0, xbv0, nThreads, threadIndex); 108 | } 109 | 110 | void llamaQuantizeAtt(TASK_ARGS) { 111 | TASK_VARIABLES; 112 | quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, 
TB_SLICED_XBV, TB_SLICED_XBV_QUANTIZED); 113 | } 114 | 115 | void llamaSyncAtt(TASK_ARGS) { 116 | TASK_VARIABLES; 117 | syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XBV_QUANTIZED); 118 | } 119 | 120 | void llamaDequantizeAtt(TASK_ARGS) { 121 | TASK_VARIABLES; 122 | dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV_QUANTIZED, TB_SLICED_XBV); 123 | } 124 | 125 | void llamaMergeAtt(TASK_ARGS) { 126 | TASK_VARIABLES; 127 | for (slice_index_t sliceIndex = 0; sliceIndex < spec->nSlices; sliceIndex++) { 128 | float* xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, sliceIndex); 129 | add(transformer->x, xbv, spec->dim, nThreads, threadIndex); 130 | } 131 | } 132 | 133 | void llamaRmfFfn(TASK_ARGS) { 134 | TASK_VARIABLES; 135 | if (threadIndex == 0) { 136 | transformer->rms = rms(transformer->x, spec->dim); 137 | } 138 | } 139 | 140 | void llamaRmfFfnNorm(TASK_ARGS) { 141 | TASK_VARIABLES; 142 | float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB); 143 | float* x = (float*)transformer->x; 144 | 145 | rmsnorm(xb, x, transformer->rms, block->rmsFfn, spec->dim, nThreads, threadIndex); 146 | } 147 | 148 | void llamaQuantizeRmfFfn(TASK_ARGS) { 149 | TASK_VARIABLES; 150 | quantizeUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB, TB_UNIT_XB_QUANTIZED); 151 | } 152 | 153 | void llamaSyncFfn(TASK_ARGS) { 154 | TASK_VARIABLES; 155 | syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED); 156 | } 157 | 158 | void llamaFfn0(TASK_ARGS) { 159 | TASK_VARIABLES; 160 | 161 | float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); 162 | float* hb0 = (float*)transformer->buffer->getSliced(TB_SLICED_HB, transformer->sliceIndex); 163 | 164 | block->w10mm->forward(xb, hb0, nThreads, threadIndex); 165 | block->w30mm->forward(xb, block->hb20, nThreads, threadIndex); 166 | 167 | if (spec->hiddenAct == SILU) { 168 | silu(hb0, block->w10Slice->d0, nThreads, threadIndex); 169 | } else if (spec->hiddenAct == GELU) {
170 | gelu(hb0, block->w10Slice->d0, nThreads, threadIndex); 171 | } else { 172 | assert(false); 173 | } 174 | mul(hb0, block->hb20, block->w10Slice->d0, nThreads, threadIndex); 175 | } 176 | 177 | void llamaFfn1(TASK_ARGS) { 178 | TASK_VARIABLES; 179 | quantizeSlicedBuffer(nThreads, threadIndex, ctx, true, TB_SLICED_HB, TB_SLICED_HB_QUANTIZED); 180 | } 181 | 182 | void llamaFfn2(TASK_ARGS) { 183 | TASK_VARIABLES; 184 | 185 | float *hb = (float*)transformer->buffer->getSliced(TB_SLICED_HB_QUANTIZED, transformer->sliceIndex); 186 | float *xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, transformer->sliceIndex); 187 | 188 | block->w20mm->forward(hb, xbv, nThreads, threadIndex); 189 | } 190 | 191 | void llamaQuantizeFfn2(TASK_ARGS) { 192 | TASK_VARIABLES; 193 | quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV, TB_SLICED_XBV_QUANTIZED); 194 | } 195 | 196 | void llamaSyncFfn2(TASK_ARGS) { 197 | TASK_VARIABLES; 198 | syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XBV_QUANTIZED); 199 | } 200 | 201 | void llamaDequantizeFfn2(TASK_ARGS) { 202 | TASK_VARIABLES; 203 | dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV_QUANTIZED, TB_SLICED_XBV); 204 | } 205 | 206 | void llamaMergeFfn2(TASK_ARGS) { 207 | TASK_VARIABLES; 208 | for (slice_index_t sliceIndex = 0; sliceIndex < spec->nSlices; sliceIndex++) { 209 | float* xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, sliceIndex); 210 | add(transformer->x, xbv, spec->dim, nThreads, threadIndex); 211 | } 212 | } 213 | 214 | void llamaNextBlock(TASK_ARGS) { 215 | TASK_VARIABLES; 216 | 217 | if (threadIndex == 0) { 218 | ctx->currentBlockIndex++; 219 | } 220 | } 221 | 222 | void llamaRmsFinal(TASK_ARGS) { 223 | TASK_VARIABLES; 224 | if (threadIndex == 0) { 225 | float* x = transformer->x; 226 | transformer->rms = rms(x, spec->dim); 227 | } 228 | } 229 | 230 | void llamaRmsFinalNorm(TASK_ARGS) { 231 | TASK_VARIABLES; 232 | float* x = transformer->x; 233 | 
rmsnorm(x, x, transformer->rms, (float*)transformer->rmsFinal, spec->dim, nThreads, threadIndex); 234 | } 235 | 236 | void llamaFinalize(TASK_ARGS) { 237 | TASK_VARIABLES; 238 | float* lb = (float*)transformer->buffer->getSliced(TB_SLICED_LOGITS, transformer->sliceIndex); 239 | transformer->wclsMm->forward(transformer->x, lb, nThreads, threadIndex); 240 | } 241 | 242 | void llamaSyncFinalize(TASK_ARGS) { 243 | TASK_VARIABLES; 244 | syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_LOGITS); 245 | } 246 | 247 | void llamaCopyFinalize(TASK_ARGS) { 248 | TASK_VARIABLES; 249 | if (threadIndex == 0) { 250 | float* lb = (float*)transformer->buffer->getUnit(TB_SLICED_LOGITS); 251 | memcpy(transformer->logits, lb, transformer->buffer->getUnitBytes(TB_SLICED_LOGITS)); 252 | } 253 | } 254 | 255 | TransformerArch buildLlamaArch(TransformerSpec* spec) { 256 | TransformerArch a; 257 | 258 | // inference 259 | 260 | a.I(sendPos, "sendPos", TASK_TYPE_TRANSFER); 261 | for (int i = 0; i < spec->nLayers; i++) { 262 | a.I(TASK_WITH_NAME(llamaRmsAtt), TASK_TYPE_INFERENCE); 263 | a.I(TASK_WITH_NAME(llamaRmsAttNorm), TASK_TYPE_INFERENCE); 264 | a.I(TASK_WITH_NAME(llamaQuantizeRmsAtt), TASK_TYPE_INFERENCE); 265 | a.I(TASK_WITH_NAME(llamaSyncRmsAtt), TASK_TYPE_TRANSFER); 266 | a.I(TASK_WITH_NAME(llamaQkv), TASK_TYPE_INFERENCE); 267 | a.I(TASK_WITH_NAME(llamaRope), TASK_TYPE_INFERENCE); 268 | a.I(TASK_WITH_NAME(llamaMultiheadAtt), TASK_TYPE_INFERENCE); 269 | a.I(TASK_WITH_NAME(llamaQuantizeMultiheadAtt), TASK_TYPE_INFERENCE); 270 | a.I(TASK_WITH_NAME(llamaAtt), TASK_TYPE_INFERENCE); 271 | a.I(TASK_WITH_NAME(llamaQuantizeAtt), TASK_TYPE_INFERENCE); 272 | a.I(TASK_WITH_NAME(llamaSyncAtt), TASK_TYPE_TRANSFER); 273 | a.I(TASK_WITH_NAME(llamaDequantizeAtt), TASK_TYPE_INFERENCE); 274 | a.I(TASK_WITH_NAME(llamaMergeAtt), TASK_TYPE_INFERENCE); 275 | a.I(TASK_WITH_NAME(llamaRmfFfn), TASK_TYPE_INFERENCE); 276 | a.I(TASK_WITH_NAME(llamaRmfFfnNorm), TASK_TYPE_INFERENCE); 277 | 
a.I(TASK_WITH_NAME(llamaQuantizeRmfFfn), TASK_TYPE_INFERENCE); 278 | a.I(TASK_WITH_NAME(llamaSyncFfn), TASK_TYPE_TRANSFER); 279 | a.I(TASK_WITH_NAME(llamaFfn0), TASK_TYPE_INFERENCE); 280 | a.I(TASK_WITH_NAME(llamaFfn1), TASK_TYPE_INFERENCE); 281 | a.I(TASK_WITH_NAME(llamaFfn2), TASK_TYPE_INFERENCE); 282 | a.I(TASK_WITH_NAME(llamaQuantizeFfn2), TASK_TYPE_INFERENCE); 283 | a.I(TASK_WITH_NAME(llamaSyncFfn2), TASK_TYPE_TRANSFER); 284 | a.I(TASK_WITH_NAME(llamaDequantizeFfn2), TASK_TYPE_INFERENCE); 285 | a.I(TASK_WITH_NAME(llamaMergeFfn2), TASK_TYPE_INFERENCE); 286 | a.I(TASK_WITH_NAME(llamaNextBlock), TASK_TYPE_INFERENCE); 287 | } 288 | a.I(TASK_WITH_NAME(llamaRmsFinal), TASK_TYPE_INFERENCE); 289 | a.I(TASK_WITH_NAME(llamaRmsFinalNorm), TASK_TYPE_INFERENCE); 290 | a.I(TASK_WITH_NAME(llamaFinalize), TASK_TYPE_INFERENCE); 291 | a.I(TASK_WITH_NAME(llamaSyncFinalize), TASK_TYPE_TRANSFER); 292 | a.I(TASK_WITH_NAME(llamaCopyFinalize), TASK_TYPE_INFERENCE); 293 | 294 | // worker 295 | 296 | for (int i = 0; i < spec->nLayers; i++) { 297 | a.W(TASK_WITH_NAME(llamaSyncRmsAtt), TASK_TYPE_TRANSFER); 298 | a.W(TASK_WITH_NAME(llamaQkv), TASK_TYPE_INFERENCE); 299 | a.W(TASK_WITH_NAME(llamaRope), TASK_TYPE_INFERENCE); 300 | a.W(TASK_WITH_NAME(llamaMultiheadAtt), TASK_TYPE_INFERENCE); 301 | a.W(TASK_WITH_NAME(llamaQuantizeMultiheadAtt), TASK_TYPE_INFERENCE); 302 | a.W(TASK_WITH_NAME(llamaAtt), TASK_TYPE_INFERENCE); 303 | a.W(TASK_WITH_NAME(llamaQuantizeAtt), TASK_TYPE_INFERENCE); 304 | a.W(TASK_WITH_NAME(llamaSyncAtt), TASK_TYPE_TRANSFER); 305 | a.W(TASK_WITH_NAME(llamaSyncFfn), TASK_TYPE_TRANSFER); 306 | a.W(TASK_WITH_NAME(llamaFfn0), TASK_TYPE_INFERENCE); 307 | a.W(TASK_WITH_NAME(llamaFfn1), TASK_TYPE_INFERENCE); 308 | a.W(TASK_WITH_NAME(llamaFfn2), TASK_TYPE_INFERENCE); 309 | a.W(TASK_WITH_NAME(llamaQuantizeFfn2), TASK_TYPE_INFERENCE); 310 | a.W(TASK_WITH_NAME(llamaSyncFfn2), TASK_TYPE_TRANSFER); 311 | a.W(TASK_WITH_NAME(llamaNextBlock), TASK_TYPE_INFERENCE); 312 | } 313 | 
a.W(TASK_WITH_NAME(llamaFinalize), TASK_TYPE_INFERENCE); 314 | a.W(TASK_WITH_NAME(llamaSyncFinalize), TASK_TYPE_TRANSFER); 315 | return a; 316 | } 317 | -------------------------------------------------------------------------------- /src/llama2-tasks.hpp: -------------------------------------------------------------------------------- 1 | #ifndef LLAMA2_TASKS_HPP 2 | #define LLAMA2_TASKS_HPP 3 | 4 | #include "tasks.hpp" 5 | 6 | void llamaRmsAtt(TASK_ARGS); 7 | void llamaRmsAttNorm(TASK_ARGS); 8 | void llamaQuantizeRmsAtt(TASK_ARGS); 9 | void llamaSyncRmsAtt(TASK_ARGS); 10 | void llamaQkv(TASK_ARGS); 11 | void llamaRope(TASK_ARGS); 12 | void llamaMultiheadAtt(TASK_ARGS); 13 | void llamaQuantizeMultiheadAtt(TASK_ARGS); 14 | void llamaAtt(TASK_ARGS); 15 | void llamaQuantizeAtt(TASK_ARGS); 16 | void llamaSyncAtt(TASK_ARGS); 17 | void llamaDequantizeAtt(TASK_ARGS); 18 | void llamaMergeAtt(TASK_ARGS); 19 | void llamaRmfFfn(TASK_ARGS); 20 | void llamaRmfFfnNorm(TASK_ARGS); 21 | void llamaNextBlock(TASK_ARGS); 22 | void llamaRmsFinal(TASK_ARGS); 23 | void llamaRmsFinalNorm(TASK_ARGS); 24 | void llamaFinalize(TASK_ARGS); 25 | 26 | TransformerArch buildLlamaArch(TransformerSpec* spec); 27 | 28 | #endif -------------------------------------------------------------------------------- /src/mixtral-tasks.cpp: -------------------------------------------------------------------------------- 1 | #include "llama2-tasks.hpp" 2 | #include "grok1-tasks.hpp" 3 | #include "mixtral-tasks.hpp" 4 | 5 | TransformerArch buildMixtralArch(TransformerSpec* spec) { 6 | TransformerArch a; 7 | 8 | // inference 9 | 10 | a.I(sendPos, "sendPos", TASK_TYPE_TRANSFER); 11 | for (int i = 0; i < spec->nLayers; i++) { 12 | a.I(TASK_WITH_NAME(llamaRmsAtt), TASK_TYPE_INFERENCE); 13 | a.I(TASK_WITH_NAME(llamaRmsAttNorm), TASK_TYPE_INFERENCE); 14 | a.I(TASK_WITH_NAME(llamaQuantizeRmsAtt), TASK_TYPE_INFERENCE); 15 | a.I(TASK_WITH_NAME(llamaSyncRmsAtt), TASK_TYPE_TRANSFER); 16 | a.I(TASK_WITH_NAME(llamaQkv), 
TASK_TYPE_INFERENCE); 17 | a.I(TASK_WITH_NAME(llamaRope), TASK_TYPE_INFERENCE); 18 | a.I(TASK_WITH_NAME(llamaMultiheadAtt), TASK_TYPE_INFERENCE); 19 | a.I(TASK_WITH_NAME(llamaQuantizeMultiheadAtt), TASK_TYPE_INFERENCE); 20 | a.I(TASK_WITH_NAME(llamaAtt), TASK_TYPE_INFERENCE); 21 | a.I(TASK_WITH_NAME(llamaQuantizeAtt), TASK_TYPE_INFERENCE); 22 | a.I(TASK_WITH_NAME(llamaSyncAtt), TASK_TYPE_TRANSFER); 23 | a.I(TASK_WITH_NAME(llamaDequantizeAtt), TASK_TYPE_INFERENCE); 24 | a.I(TASK_WITH_NAME(llamaMergeAtt), TASK_TYPE_INFERENCE); 25 | a.I(TASK_WITH_NAME(llamaRmfFfn), TASK_TYPE_INFERENCE); 26 | a.I(TASK_WITH_NAME(llamaRmfFfnNorm), TASK_TYPE_INFERENCE); 27 | 28 | a.I(TASK_WITH_NAME(grokMoeRouter), TASK_TYPE_INFERENCE); 29 | a.I(TASK_WITH_NAME(grokMoeRouterSoftmax), TASK_TYPE_INFERENCE); 30 | a.I(TASK_WITH_NAME(grokMoeTopk), TASK_TYPE_INFERENCE); 31 | a.I(TASK_WITH_NAME(grokMoeNormWeights), TASK_TYPE_INFERENCE); 32 | a.I(TASK_WITH_NAME(grokQuantizeMoeInput), TASK_TYPE_INFERENCE); 33 | a.I(TASK_WITH_NAME(grokSyncMoeInput), TASK_TYPE_TRANSFER); 34 | a.I(TASK_WITH_NAME(grokMoeBlock0), TASK_TYPE_INFERENCE); 35 | a.I(TASK_WITH_NAME(grokMoeBlock1), TASK_TYPE_INFERENCE); 36 | a.I(TASK_WITH_NAME(grokQuantizeMoeMul), TASK_TYPE_INFERENCE); 37 | a.I(TASK_WITH_NAME(grokSyncMoeMulA), TASK_TYPE_INFERENCE); 38 | a.I(TASK_WITH_NAME(grokSyncMoeMulRearrange), TASK_TYPE_INFERENCE); 39 | a.I(TASK_WITH_NAME(grokSyncMoeMulB), TASK_TYPE_INFERENCE); 40 | a.I(TASK_WITH_NAME(grokMoeBlock2), TASK_TYPE_INFERENCE); 41 | a.I(TASK_WITH_NAME(grokQuantizeMoeOutput), TASK_TYPE_INFERENCE); 42 | a.I(TASK_WITH_NAME(grokSyncMoeOutput), TASK_TYPE_TRANSFER); 43 | a.I(TASK_WITH_NAME(grokDequantizeMoeOutput), TASK_TYPE_INFERENCE); 44 | a.I(TASK_WITH_NAME(grokMoeAdd), TASK_TYPE_INFERENCE); 45 | 46 | a.I(TASK_WITH_NAME(llamaNextBlock), TASK_TYPE_INFERENCE); 47 | } 48 | a.I(TASK_WITH_NAME(llamaRmsFinal), TASK_TYPE_INFERENCE); 49 | a.I(TASK_WITH_NAME(llamaRmsFinalNorm), TASK_TYPE_INFERENCE); 50 | 
a.I(TASK_WITH_NAME(llamaFinalize), TASK_TYPE_INFERENCE); 51 | 52 | // worker 53 | 54 | for (int i = 0; i < spec->nLayers; i++) { 55 | a.W(TASK_WITH_NAME(llamaSyncRmsAtt), TASK_TYPE_TRANSFER); 56 | a.W(TASK_WITH_NAME(llamaQkv), TASK_TYPE_INFERENCE); 57 | a.W(TASK_WITH_NAME(llamaRope), TASK_TYPE_INFERENCE); 58 | a.W(TASK_WITH_NAME(llamaMultiheadAtt), TASK_TYPE_INFERENCE); 59 | a.W(TASK_WITH_NAME(llamaQuantizeMultiheadAtt), TASK_TYPE_INFERENCE); 60 | a.W(TASK_WITH_NAME(llamaAtt), TASK_TYPE_INFERENCE); 61 | a.W(TASK_WITH_NAME(llamaQuantizeAtt), TASK_TYPE_INFERENCE); 62 | a.W(TASK_WITH_NAME(llamaSyncAtt), TASK_TYPE_TRANSFER); 63 | 64 | a.W(TASK_WITH_NAME(grokSyncMoeInput), TASK_TYPE_TRANSFER); 65 | a.W(TASK_WITH_NAME(grokMoeBlock0), TASK_TYPE_INFERENCE); 66 | a.W(TASK_WITH_NAME(grokMoeBlock1), TASK_TYPE_INFERENCE); 67 | a.W(TASK_WITH_NAME(grokQuantizeMoeMul), TASK_TYPE_INFERENCE); 68 | a.W(TASK_WITH_NAME(grokSyncMoeMulA), TASK_TYPE_INFERENCE); 69 | a.W(TASK_WITH_NAME(grokSyncMoeMulB), TASK_TYPE_INFERENCE); 70 | a.W(TASK_WITH_NAME(grokMoeBlock2), TASK_TYPE_INFERENCE); 71 | a.W(TASK_WITH_NAME(grokQuantizeMoeOutput), TASK_TYPE_INFERENCE); 72 | a.W(TASK_WITH_NAME(grokSyncMoeOutput), TASK_TYPE_TRANSFER); 73 | 74 | a.W(TASK_WITH_NAME(llamaNextBlock), TASK_TYPE_INFERENCE); 75 | } 76 | 77 | return a; 78 | } 79 | -------------------------------------------------------------------------------- /src/mixtral-tasks.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MIXTRAL_TASKS_HPP 2 | #define MIXTRAL_TASKS_HPP 3 | 4 | #include "tasks.hpp" 5 | 6 | TransformerArch buildMixtralArch(TransformerSpec* spec); 7 | 8 | #endif -------------------------------------------------------------------------------- /src/quants-test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "utils.hpp" 5 | #include "quants.hpp" 6 | 7 | void testQ80(const int len, int nThreads) 
{ 8 | unsigned long long state = 800000010L; 9 | float input[len]; 10 | float* output = new float[len]; 11 | BlockQ80* q80s = new BlockQ80[len / QK80]; 12 | 13 | for (int i = 0; i < len; i++) { 14 | input[i] = randomF32(&state); 15 | output[i] = 0; 16 | } 17 | 18 | for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) { 19 | quantizeQ80Row((float*)&input, (BlockQ80*)q80s, len, nThreads, threadIndex); 20 | } 21 | for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) { 22 | dequantizeQ80Row((BlockQ80*)q80s, (float*)output, len, nThreads, threadIndex); 23 | } 24 | 25 | for (int i = 0; i < len; i++) { 26 | float diff = fabs(output[i] - input[i]); 27 | if (diff > 0.0043) { 28 | printf("❌ (%d, %d) ix=%d %f != %f diff=%f nThreads=%d\n", len, nThreads, i, output[i], input[i], diff, nThreads); 29 | exit(EXIT_FAILURE); 30 | } 31 | } 32 | 33 | delete[] output; 34 | delete[] q80s; 35 | } 36 | 37 | int main() { 38 | initQuants(); 39 | 40 | testQ80(1024, 1); 41 | testQ80(1024, 2); 42 | testQ80(1024, 4); 43 | testQ80(1024, 8); 44 | testQ80(1024, 16); 45 | testQ80(768, 1); 46 | testQ80(768, 2); 47 | testQ80(768, 4); 48 | testQ80(768, 8); 49 | testQ80(2752, 1); 50 | testQ80(2752, 2); 51 | testQ80(2752, 4); 52 | 53 | printf("✅ Q80 quantized correctly\n"); 54 | return EXIT_SUCCESS; 55 | } 56 | -------------------------------------------------------------------------------- /src/quants.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "quants.hpp" 6 | 7 | #if defined(__ARM_NEON) 8 | #include 9 | #endif 10 | 11 | int getNumbersPerBatch(FloatType type) { 12 | switch (type) { 13 | case F32: 14 | return 1; 15 | case F16: 16 | return 1; 17 | case Q40: 18 | return QK40; 19 | case Q80: 20 | return QK80; 21 | case FUNK: 22 | break; 23 | } 24 | fprintf(stderr, "Unsupported float type %d\n", type); 25 | exit(EXIT_FAILURE); 26 | } 27 | 28 | long getBatchBytes(FloatType type, 
int n, int d) { 29 | switch (type) { 30 | case F32: 31 | return n * d * sizeof(float); 32 | case F16: 33 | return n * d * sizeof(uint16_t); 34 | case Q40: 35 | { 36 | assert(n % QK40 == 0); 37 | int blocks = n / QK40 * d; 38 | return blocks * sizeof(BlockQ40); 39 | } 40 | case Q80: 41 | { 42 | assert(n % QK80 == 0); 43 | int blocks = n / QK80 * d; 44 | return blocks * sizeof(BlockQ80); 45 | } 46 | case FUNK: 47 | break; 48 | } 49 | fprintf(stderr, "Unsupported float type %d\n", type); 50 | exit(EXIT_FAILURE); 51 | } 52 | 53 | float F16ToF32[65536]; 54 | 55 | // https://gist.github.com/rygorous/2144712 56 | float _convertF16ToF32(uint16_t value) { 57 | union F32 58 | { 59 | uint32_t u; 60 | float f; 61 | }; 62 | 63 | const F32 magic = { (254U - 15U) << 23 }; 64 | const F32 was_infnan = { (127U + 16U) << 23 }; 65 | F32 out; 66 | 67 | out.u = (value & 0x7FFFU) << 13; 68 | out.f *= magic.f; 69 | if (out.f >= was_infnan.f) { 70 | out.u |= 255U << 23; 71 | } 72 | out.u |= (value & 0x8000U) << 16; 73 | return out.f; 74 | } 75 | 76 | uint16_t _convertF32ToF16(float value) { 77 | unsigned int fltInt32 = *(unsigned int*)&value; 78 | unsigned short fltInt16; 79 | 80 | fltInt16 = (fltInt32 >> 31) << 5; 81 | unsigned short tmp = (fltInt32 >> 23) & 0xff; 82 | tmp = (tmp - 0x70) & ((unsigned int)((int)(0x70 - tmp) >> 4) >> 27); 83 | fltInt16 = (fltInt16 | tmp) << 10; 84 | fltInt16 |= (fltInt32 >> 13) & 0x3ff; 85 | return fltInt16; 86 | } 87 | 88 | void initF16ToF32() { 89 | for (int i = 0; i < 65536; i++) { 90 | F16ToF32[i] = _convertF16ToF32(i); 91 | } 92 | } 93 | 94 | float convertF16ToF32(uint16_t value) { 95 | return F16ToF32[value]; 96 | } 97 | 98 | // https://github.com/mitsuba-renderer/openexr/blob/dbabb6f9500ee628c1faba21bb8add2649cc32a6/IlmBase/Half/half.cpp#L85 99 | uint16_t convertF32ToF16(const float x) { 100 | int i = *(int*)&x; 101 | int s = (i >> 16) & 0x00008000; 102 | int e = ((i >> 23) & 0x000000ff) - (127 - 15); 103 | int m = i & 0x007fffff; 104 | 105 | if (e 
<= 0) { 106 | if (e < -10) { 107 | return s; 108 | } 109 | m = m | 0x00800000; 110 | int t = 14 - e; 111 | int a = (1 << (t - 1)) - 1; 112 | int b = (m >> t) & 1; 113 | m = (m + a + b) >> t; 114 | return s | m; 115 | } else if (e == 0xff - (127 - 15)) { 116 | if (m == 0) { 117 | return s | 0x7c00; 118 | } else { 119 | m >>= 13; 120 | return s | 0x7c00 | m | (m == 0); 121 | } 122 | } else { 123 | m = m + 0x00000fff + ((m >> 13) & 1); 124 | 125 | if (m & 0x00800000) { 126 | m = 0; 127 | e += 1; 128 | } 129 | if (e > 30) { 130 | // overflow (); // TODO: this should not be commented out 131 | return s | 0x7c00; 132 | } 133 | return s | (e << 10) | (m >> 13); 134 | } 135 | } 136 | 137 | void dequantizeQ40Row(const BlockQ40* x, float* y, int k) { 138 | static const int qk = QK40; 139 | assert(k % qk == 0); 140 | const int nb = k / qk; 141 | 142 | #if defined(__ARM_NEON) 143 | const uint8x16_t m4b = vdupq_n_u8(0x0F); 144 | const int8x16_t s8b = vdupq_n_s8(0x8); 145 | 146 | for (int i = 0; i < nb; i++) { 147 | const BlockQ40* b = &x[i]; 148 | const float d = convertF16ToF32(b->d); 149 | 150 | const uint8x16_t v0_0 = vld1q_u8(b->qs); 151 | 152 | const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); 153 | const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); 154 | 155 | const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); 156 | const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); 157 | 158 | int8x8_t r1 = vget_low_s8(v0_0ls); 159 | int8x8_t r2 = vget_high_s8(v0_0ls); 160 | int8x8_t r3 = vget_low_s8(v0_0hs); 161 | int8x8_t r4 = vget_high_s8(v0_0hs); 162 | 163 | for (int j = 0; j < 8; j++) { 164 | y[i * qk + j + 0] = r1[j] * d; 165 | y[i * qk + j + 8] = r2[j] * d; 166 | y[i * qk + j + 16] = r3[j] * d; 167 | y[i * qk + j + 24] = r4[j] * d; 168 | } 169 | } 170 | #else 171 | for (int i = 0; i < nb; i++) { 172 | const BlockQ40* b = &x[i]; 173 | const float d = convertF16ToF32(b->d); 174 | 175 | for (int j = 0; j < qk / 2; ++j) { 176 | const int x0 = (b->qs[j] & 0x0F) - 
8; 177 | const int x1 = (b->qs[j] >> 4) - 8; 178 | 179 | y[i * qk + j] = x0 * d; 180 | y[i * qk + j + qk / 2] = x1 * d; 181 | } 182 | } 183 | #endif 184 | } 185 | 186 | void quantizeQ80Row(float* input, BlockQ80* output, int k, unsigned int nThreads, unsigned int threadIndex) { 187 | assert(k % QK80 == 0); 188 | 189 | const int nBlocks = k / QK80; 190 | const int blocksPerThread = nBlocks / nThreads; 191 | const int sk = blocksPerThread * QK80; 192 | const int currentThreadBlocks = blocksPerThread + (threadIndex == nThreads - 1 ? nBlocks % nThreads : 0); 193 | 194 | const float* x = &input[sk * threadIndex]; 195 | BlockQ80* y = &output[blocksPerThread * threadIndex]; 196 | 197 | #if defined(__ARM_NEON) 198 | float dBuf[4]; 199 | 200 | for (int i = 0; i < currentThreadBlocks; i++) { 201 | float32x4_t srcv [8]; 202 | float32x4_t asrcv[8]; 203 | float32x4_t amaxv[8]; 204 | 205 | for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); 206 | for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); 207 | 208 | for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); 209 | for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); 210 | for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); 211 | 212 | const float amax = vmaxvq_f32(amaxv[0]); 213 | 214 | const float d = amax / ((1 << 7) - 1); 215 | const float id = d ? 
1.0f/d : 0.0f; 216 | 217 | int dbi = i % 4; 218 | dBuf[dbi] = d; 219 | if (dbi == 3) { 220 | float32x4_t dBuf32 = vld1q_f32(dBuf); 221 | int16x4_t dBuf16 = (int16x4_t)vcvt_f16_f32(dBuf32); 222 | 223 | y[i - 3].d = dBuf16[0]; 224 | y[i - 2].d = dBuf16[1]; 225 | y[i - 1].d = dBuf16[2]; 226 | y[i - 0].d = dBuf16[3]; 227 | } 228 | 229 | for (int j = 0; j < 8; j++) { 230 | const float32x4_t v = vmulq_n_f32(srcv[j], id); 231 | const int32x4_t vi = vcvtnq_s32_f32(v); 232 | 233 | y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); 234 | y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); 235 | y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); 236 | y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); 237 | } 238 | } 239 | 240 | int rest = currentThreadBlocks % 4; 241 | if (rest != 0) { 242 | float32x4_t dBuf32 = vld1q_f32(dBuf); 243 | int16x4_t dBuf16 = (int16x4_t)vcvt_f16_f32(dBuf32); 244 | for (int i = 0; i < rest; i++) { 245 | y[currentThreadBlocks - rest + i].d = dBuf16[i]; 246 | } 247 | } 248 | #else 249 | for (int i = 0; i < currentThreadBlocks; i++) { 250 | float amax = 0.0f; 251 | 252 | for (int j = 0; j < QK80; j++) { 253 | const float v = fabsf(x[i*QK80 + j]); 254 | amax = amax > v ? amax : v; 255 | } 256 | 257 | const float d = amax / ((1 << 7) - 1); 258 | const float id = d ? 1.0f/d : 0.0f; 259 | 260 | y[i].d = convertF32ToF16(d); 261 | 262 | for (int j = 0; j < QK80; ++j) { 263 | const float x0 = x[i*QK80 + j]*id; 264 | y[i].qs[j] = roundf(x0); 265 | } 266 | } 267 | #endif 268 | } 269 | 270 | void dequantizeQ80Row(const BlockQ80* input, float* output, int k, unsigned int nThreads, unsigned int threadIndex) { 271 | assert(k % QK80 == 0); 272 | 273 | const int nBlocks = k / QK80; 274 | const int blocksPerThread = nBlocks / nThreads; 275 | const int sk = blocksPerThread * QK80; 276 | const int currentThreadBlocks = blocksPerThread + (threadIndex == nThreads - 1 ? 
// One Q80 block: QK80 values quantized to signed 8-bit integers that share a
// single scale.
typedef struct {
    uint16_t d; // delta: the block scale kept as 16-bit half-precision bits (written via convertF32ToF16, read via convertF16ToF32)
    int8_t qs[QK80]; // quants: dequantizeQ80Row restores value j as qs[j] * d
} BlockQ80;
// Switches a socket between non-blocking (`enabled == true`) and blocking mode.
// Throws std::runtime_error when the mode cannot be read or changed.
static inline void setNonBlocking(int socket, bool enabled) {
#ifdef _WIN32
    u_long mode = enabled ? 1 : 0;
    if (ioctlsocket(socket, FIONBIO, &mode) != 0) {
        throw std::runtime_error("Error setting socket to non-blocking");
    }
#else
    int flags = fcntl(socket, F_GETFL, 0);
    // F_GETFL reports failure with -1; modifying that value and writing it back
    // would try to set every status flag at once.
    if (flags < 0)
        throw std::runtime_error("Error getting socket flags");
    if (enabled) {
        flags |= O_NONBLOCK;
    } else {
        flags = flags & (~O_NONBLOCK);
    }
    if (fcntl(socket, F_SETFL, flags) < 0)
        throw std::runtime_error("Error setting socket to non-blocking");
#endif
}
std::runtime_error("setsockopt failed: " + std::string(strerror(errno))); 84 | } 85 | #endif 86 | } 87 | 88 | static inline void writeSocket(int socket, const void* data, size_t size) { 89 | while (size > 0) { 90 | int s = send(socket, (const char*)data, size, 0); 91 | if (s < 0) { 92 | if (isEagainError()) { 93 | continue; 94 | } 95 | throw WriteSocketException(0, "Error writing to socket"); 96 | } else if (s == 0) { 97 | throw WriteSocketException(0, "Socket closed"); 98 | } 99 | size -= s; 100 | data = (const char*)data + s; 101 | } 102 | } 103 | 104 | static inline bool tryReadSocket(int socket, void* data, size_t size, unsigned long maxAttempts) { 105 | // maxAttempts = 0 means infinite attempts 106 | size_t s = size; 107 | while (s > 0) { 108 | int r = recv(socket, (char*)data, s, 0); 109 | if (r < 0) { 110 | if (isEagainError()) { 111 | if (s == size && maxAttempts > 0) { 112 | maxAttempts--; 113 | if (maxAttempts == 0) { 114 | return false; 115 | } 116 | } 117 | continue; 118 | } 119 | throw ReadSocketException(0, "Error reading from socket"); 120 | } else if (r == 0) { 121 | throw ReadSocketException(0, "Socket closed"); 122 | } 123 | data = (char*)data + r; 124 | s -= r; 125 | } 126 | return true; 127 | } 128 | 129 | static inline void readSocket(int socket, void* data, size_t size) { 130 | if (!tryReadSocket(socket, data, size, 0)) { 131 | throw std::runtime_error("Error reading from socket"); 132 | } 133 | } 134 | 135 | void initSockets() { 136 | #ifdef _WIN32 137 | WSADATA wsaData; 138 | if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { 139 | throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError())); 140 | } 141 | #endif 142 | } 143 | 144 | void cleanupSockets() { 145 | #ifdef _WIN32 146 | WSACleanup(); 147 | #endif 148 | } 149 | 150 | ReadSocketException::ReadSocketException(int code, const char* message) { 151 | this->code = code; 152 | this->message = message; 153 | } 154 | 155 | 
WriteSocketException::WriteSocketException(int code, const char* message) { 156 | this->code = code; 157 | this->message = message; 158 | } 159 | 160 | SocketPool* SocketPool::connect(unsigned int nSockets, char** hosts, int* ports) { 161 | int* sockets = new int[nSockets]; 162 | struct sockaddr_in addr; 163 | 164 | for (unsigned int i = 0; i < nSockets; i++) { 165 | memset(&addr, 0, sizeof(addr)); 166 | addr.sin_family = AF_INET; 167 | addr.sin_addr.s_addr = inet_addr(hosts[i]); 168 | addr.sin_port = htons(ports[i]); 169 | 170 | int clientSocket = socket(AF_INET, SOCK_STREAM, 0); 171 | if (clientSocket < 0) 172 | throw std::runtime_error("Cannot create socket"); 173 | 174 | int connectResult = ::connect(clientSocket, (struct sockaddr*)&addr, sizeof(addr)); 175 | if (connectResult != 0) { 176 | printf("Cannot connect to %s:%d (%s)\n", hosts[i], ports[i], SOCKET_LAST_ERROR); 177 | throw std::runtime_error("Cannot connect"); 178 | } 179 | 180 | setNoDelay(clientSocket); 181 | setQuickAck(clientSocket); 182 | sockets[i] = clientSocket; 183 | } 184 | return new SocketPool(nSockets, sockets); 185 | } 186 | 187 | SocketPool::SocketPool(unsigned int nSockets, int* sockets) { 188 | this->nSockets = nSockets; 189 | this->sockets = sockets; 190 | this->sentBytes.exchange(0); 191 | this->recvBytes.exchange(0); 192 | } 193 | 194 | SocketPool::~SocketPool() { 195 | for (unsigned int i = 0; i < nSockets; i++) { 196 | shutdown(sockets[i], 2); 197 | close(sockets[i]); 198 | } 199 | delete[] sockets; 200 | } 201 | 202 | void SocketPool::setTurbo(bool enabled) { 203 | for (unsigned int i = 0; i < nSockets; i++) { 204 | ::setNonBlocking(sockets[i], enabled); 205 | } 206 | } 207 | 208 | void SocketPool::write(unsigned int socketIndex, const void* data, size_t size) { 209 | assert(socketIndex >= 0 && socketIndex < nSockets); 210 | sentBytes += size; 211 | writeSocket(sockets[socketIndex], data, size); 212 | } 213 | 214 | void SocketPool::read(unsigned int socketIndex, void* data, 
size_t size) { 215 | assert(socketIndex >= 0 && socketIndex < nSockets); 216 | recvBytes += size; 217 | readSocket(sockets[socketIndex], data, size); 218 | } 219 | 220 | void SocketPool::writeMany(unsigned int n, SocketIo* ios) { 221 | bool isWriting; 222 | for (unsigned int i = 0; i < n; i++) { 223 | SocketIo* io = &ios[i]; 224 | assert(io->socketIndex >= 0 && io->socketIndex < nSockets); 225 | sentBytes += io->size; 226 | } 227 | do { 228 | isWriting = false; 229 | for (unsigned int i = 0; i < n; i++) { 230 | SocketIo* io = &ios[i]; 231 | if (io->size > 0) { 232 | isWriting = true; 233 | int socket = sockets[io->socketIndex]; 234 | ssize_t s = send(socket, (const char*)io->data, io->size, 0); 235 | if (s < 0) { 236 | if (isEagainError()) { 237 | continue; 238 | } 239 | throw WriteSocketException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR); 240 | } else if (s == 0) { 241 | throw WriteSocketException(0, "Socket closed"); 242 | } 243 | io->size -= s; 244 | io->data = (char*)io->data + s; 245 | } 246 | } 247 | } while (isWriting); 248 | } 249 | 250 | void SocketPool::readMany(unsigned int n, SocketIo* ios) { 251 | bool isReading; 252 | for (unsigned int i = 0; i < n; i++) { 253 | SocketIo* io = &ios[i]; 254 | assert(io->socketIndex >= 0 && io->socketIndex < nSockets); 255 | recvBytes += io->size; 256 | } 257 | do { 258 | isReading = false; 259 | for (unsigned int i = 0; i < n; i++) { 260 | SocketIo* io = &ios[i]; 261 | if (io->size > 0) { 262 | isReading = true; 263 | int socket = sockets[io->socketIndex]; 264 | ssize_t r = recv(socket, (char*)io->data, io->size, 0); 265 | if (r < 0) { 266 | if (isEagainError()) { 267 | continue; 268 | } 269 | throw ReadSocketException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR); 270 | } else if (r == 0) { 271 | throw ReadSocketException(0, "Socket closed"); 272 | } 273 | io->size -= r; 274 | io->data = (char*)io->data + r; 275 | } 276 | } 277 | } while (isReading); 278 | } 279 | 280 | void SocketPool::getStats(size_t* sentBytes, size_t* 
recvBytes) { 281 | *sentBytes = this->sentBytes; 282 | *recvBytes = this->recvBytes; 283 | this->sentBytes.exchange(0); 284 | this->recvBytes.exchange(0); 285 | } 286 | 287 | Socket SocketServer::accept() { 288 | struct sockaddr_in clientAddr; 289 | socklen_t clientAddrSize = sizeof(clientAddr); 290 | int clientSocket = ::accept(socket, (struct sockaddr*)&clientAddr, &clientAddrSize); 291 | if (clientSocket < 0) 292 | throw std::runtime_error("Error accepting connection"); 293 | setNoDelay(clientSocket); 294 | setQuickAck(clientSocket); 295 | return Socket(clientSocket); 296 | } 297 | 298 | Socket::Socket(int socket) { 299 | this->socket = socket; 300 | } 301 | 302 | Socket::~Socket() { 303 | shutdown(socket, 2); 304 | close(socket); 305 | } 306 | 307 | void Socket::setTurbo(bool enabled) { 308 | ::setNonBlocking(socket, enabled); 309 | } 310 | 311 | void Socket::write(const void* data, size_t size) { 312 | writeSocket(socket, data, size); 313 | } 314 | 315 | void Socket::read(void* data, size_t size) { 316 | readSocket(socket, data, size); 317 | } 318 | 319 | bool Socket::tryRead(void* data, size_t size, unsigned long maxAttempts) { 320 | return tryReadSocket(socket, data, size, maxAttempts); 321 | } 322 | 323 | std::vector Socket::readHttpRequest() { 324 | std::vector httpRequest; 325 | char buffer[1024 * 1024]; // TODO: this should be refactored asap 326 | ssize_t bytesRead; 327 | 328 | // Peek into the socket buffer to check available data 329 | bytesRead = recv(socket, buffer, sizeof(buffer), MSG_PEEK); 330 | if (bytesRead <= 0) { 331 | // No data available or error occurred 332 | if (bytesRead == 0) { 333 | // No more data to read 334 | return httpRequest; 335 | } else { 336 | // Error while peeking 337 | throw std::runtime_error("Error while peeking into socket"); 338 | } 339 | } 340 | 341 | // Resize buffer according to the amount of data available 342 | std::vector peekBuffer(bytesRead); 343 | bytesRead = recv(socket, peekBuffer.data(), bytesRead, 0); 344 
| if (bytesRead <= 0) { 345 | // Error while reading 346 | throw std::runtime_error("Error while reading from socket"); 347 | } 348 | 349 | // Append data to httpRequest 350 | httpRequest.insert(httpRequest.end(), peekBuffer.begin(), peekBuffer.end()); 351 | 352 | return httpRequest; 353 | } 354 | 355 | SocketServer::SocketServer(int port) { 356 | const char* host = "0.0.0.0"; 357 | struct sockaddr_in serverAddr; 358 | 359 | socket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); 360 | if (socket < 0) 361 | throw std::runtime_error("Cannot create socket"); 362 | setReuseAddr(socket); 363 | 364 | memset(&serverAddr, 0, sizeof(serverAddr)); 365 | serverAddr.sin_family = AF_INET; 366 | serverAddr.sin_port = htons(port); 367 | serverAddr.sin_addr.s_addr = inet_addr(host); 368 | 369 | int bindResult; 370 | #ifdef _WIN32 371 | bindResult = bind(socket, (SOCKADDR*)&serverAddr, sizeof(serverAddr)); 372 | if (bindResult == SOCKET_ERROR) { 373 | int error = WSAGetLastError(); 374 | closesocket(socket); 375 | throw std::runtime_error("Cannot bind port: " + std::to_string(error)); 376 | } 377 | #else 378 | bindResult = bind(socket, (struct sockaddr*)&serverAddr, sizeof(serverAddr)); 379 | if (bindResult < 0) { 380 | close(socket); 381 | throw std::runtime_error("Cannot bind port: " + std::string(strerror(errno))); 382 | } 383 | #endif 384 | 385 | int listenResult = listen(socket, SOMAXCONN); 386 | if (listenResult != 0) { 387 | #ifdef _WIN32 388 | closesocket(socket); 389 | throw std::runtime_error("Cannot listen on port: " + std::to_string(WSAGetLastError())); 390 | #else 391 | close(socket); 392 | throw std::runtime_error("Cannot listen on port: " + std::string(strerror(errno))); 393 | #endif 394 | } 395 | 396 | printf("Listening on %s:%d...\n", host, port); 397 | } 398 | 399 | SocketServer::~SocketServer() { 400 | shutdown(socket, 2); 401 | close(socket); 402 | } 403 | -------------------------------------------------------------------------------- /src/socket.hpp: 
-------------------------------------------------------------------------------- 1 | #ifndef SOCKET_HPP 2 | #define SOCKET_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | void initSockets(); 10 | void cleanupSockets(); 11 | 12 | class ReadSocketException : public std::exception { 13 | public: 14 | int code; 15 | const char* message; 16 | ReadSocketException(int code, const char* message); 17 | }; 18 | 19 | class WriteSocketException : public std::exception { 20 | public: 21 | int code; 22 | const char* message; 23 | WriteSocketException(int code, const char* message); 24 | }; 25 | 26 | struct SocketIo { 27 | unsigned int socketIndex; 28 | const void* data; 29 | size_t size; 30 | }; 31 | 32 | class SocketPool { 33 | private: 34 | int* sockets; 35 | std::atomic_uint sentBytes; 36 | std::atomic_uint recvBytes; 37 | 38 | public: 39 | static SocketPool* connect(unsigned int nSockets, char** hosts, int* ports); 40 | 41 | unsigned int nSockets; 42 | 43 | SocketPool(unsigned int nSockets, int* sockets); 44 | ~SocketPool(); 45 | 46 | void setTurbo(bool enabled); 47 | void write(unsigned int socketIndex, const void* data, size_t size); 48 | void read(unsigned int socketIndex, void* data, size_t size); 49 | void writeMany(unsigned int n, SocketIo* ios); 50 | void readMany(unsigned int n, SocketIo* ios); 51 | void getStats(size_t* sentBytes, size_t* recvBytes); 52 | }; 53 | 54 | class Socket { 55 | private: 56 | int socket; 57 | 58 | public: 59 | Socket(int socket); 60 | ~Socket(); 61 | 62 | void setTurbo(bool enabled); 63 | void write(const void* data, size_t size); 64 | void read(void* data, size_t size); 65 | bool tryRead(void* data, size_t size, unsigned long maxAttempts); 66 | std::vector readHttpRequest(); 67 | }; 68 | 69 | class SocketServer { 70 | private: 71 | int socket; 72 | public: 73 | SocketServer(int port); 74 | ~SocketServer(); 75 | Socket accept(); 76 | }; 77 | 78 | #endif 79 | 
-------------------------------------------------------------------------------- /src/tasks.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "tasks.hpp" 6 | 7 | TransformerArch::TransformerArch() { 8 | inference.nTasks = 0; 9 | worker.nTasks = 0; 10 | } 11 | 12 | TransformerArch::~TransformerArch() { 13 | if (inference.nTasks > 0) { 14 | delete[] inference.tasks; 15 | } 16 | if (worker.nTasks > 0) { 17 | delete[] worker.tasks; 18 | } 19 | } 20 | 21 | void addTask(TaskLoopHandler* handler, const char* taskName, unsigned int taskType, TransformerTasks* tasks) { 22 | const int alloc = 32; 23 | if (tasks->nTasks % alloc == 0) { 24 | TaskLoopTask* newTasks = new TaskLoopTask[tasks->nTasks + alloc]; 25 | if (tasks->nTasks > 0) { 26 | memcpy(newTasks, tasks->tasks, tasks->nTasks * sizeof(TaskLoopTask)); 27 | delete[] tasks->tasks; 28 | } 29 | tasks->tasks = newTasks; 30 | } 31 | tasks->tasks[tasks->nTasks].handler = handler; 32 | tasks->tasks[tasks->nTasks].taskType = taskType; 33 | tasks->tasks[tasks->nTasks].executionCount = 0; 34 | tasks->tasks[tasks->nTasks].executionTime = 0; 35 | tasks->tasks[tasks->nTasks].taskName = taskName; 36 | tasks->nTasks++; 37 | } 38 | 39 | void TransformerArch::I(TaskLoopHandler* handler, const char* taskName, unsigned int taskType) { 40 | addTask(handler, taskName, taskType, &inference); 41 | } 42 | 43 | void TransformerArch::W(TaskLoopHandler* handler, const char* taskName, unsigned int taskType) { 44 | addTask(handler, taskName, taskType, &worker); 45 | } 46 | 47 | void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex) { 48 | void* buffer = ctx->transformer->buffer->getUnit(bufferIndex); 49 | size_t bufferBytes = ctx->transformer->buffer->getUnitBytes(bufferIndex); 50 | 51 | if (ctx->socketPool != NULL) { 52 | // root 53 | 54 | unsigned int nSockets = ctx->socketPool->nSockets 
/ nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); 55 | if (nSockets > 0) { 56 | SocketIo ios[nSockets]; 57 | for (int i = 0; i < nSockets; i++) { 58 | ios[i].socketIndex = threadIndex + i * nThreads; 59 | ios[i].data = buffer; 60 | ios[i].size = bufferBytes; 61 | } 62 | ctx->socketPool->writeMany(nSockets, ios); 63 | } 64 | } else if (ctx->socket != NULL) { 65 | if (threadIndex != 0) return; 66 | 67 | // worker 68 | ctx->socket->read(buffer, bufferBytes); 69 | } 70 | } 71 | 72 | void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex) { 73 | size_t bufferBytes = ctx->transformer->buffer->getSlicedBytes(bufferIndex); 74 | if (ctx->socketPool != NULL) { 75 | // root 76 | 77 | unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); 78 | if (nSockets > 0) { 79 | SocketIo ios[nSockets]; 80 | for (int i = 0; i < nSockets; i++) { 81 | int socketIndex = threadIndex + i * nThreads; 82 | uint8_t workerSliceIndex = socketIndex + 1; 83 | ios[i].socketIndex = socketIndex; 84 | ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex); 85 | ios[i].size = bufferBytes; 86 | } 87 | 88 | ctx->socketPool->readMany(nSockets, ios); 89 | } 90 | } else if (ctx->socket != NULL) { 91 | if (threadIndex != 0) return; 92 | 93 | // worker 94 | void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex); 95 | ctx->socket->write(buffer, bufferBytes); 96 | } 97 | } 98 | 99 | void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex) { 100 | if (ctx->transformer->spec->bufferFloatType == F32) return; 101 | assert(ctx->transformer->spec->bufferFloatType == Q80); 102 | 103 | quantizeQ80Row( 104 | (float*)ctx->transformer->buffer->getUnit(sourceBufferIndex), 105 | 
(BlockQ80*)ctx->transformer->buffer->getUnit(targetBufferIndex), 106 | ctx->transformer->buffer->getUnitBytes(sourceBufferIndex) / sizeof(float), 107 | nThreads, 108 | threadIndex); 109 | } 110 | 111 | void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool quantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex) { 112 | if (ctx->transformer->spec->bufferFloatType == F32) return; 113 | if (ctx->transformer->sliceIndex == 0 && !quantizeRootSlice) return; 114 | assert(ctx->transformer->spec->bufferFloatType == Q80); 115 | 116 | quantizeQ80Row( 117 | (float*)ctx->transformer->buffer->getSliced(sourceBufferIndex, ctx->transformer->sliceIndex), 118 | (BlockQ80*)ctx->transformer->buffer->getSliced(targetBufferIndex, ctx->transformer->sliceIndex), 119 | ctx->transformer->buffer->getSlicedBytes(sourceBufferIndex) / sizeof(float), 120 | nThreads, 121 | threadIndex); 122 | } 123 | 124 | void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool dequantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex) { 125 | if (ctx->transformer->spec->bufferFloatType == F32) return; 126 | assert(ctx->transformer->spec->bufferFloatType == Q80); 127 | assert(ctx->socketPool != NULL); // This function may be called only by root. 128 | 129 | unsigned int sliceIndex = dequantizeRootSlice ? 
0 : 1; 130 | for (; sliceIndex < ctx->transformer->spec->nSlices; sliceIndex++) { 131 | dequantizeQ80Row( 132 | (BlockQ80*)ctx->transformer->buffer->getSliced(sourceBufferIndex, sliceIndex), 133 | (float*)ctx->transformer->buffer->getSliced(targetBufferIndex, sliceIndex), 134 | (ctx->transformer->buffer->getSlicedBytes(sourceBufferIndex) / sizeof(BlockQ80)) * QK80, 135 | nThreads, 136 | threadIndex); 137 | } 138 | } 139 | 140 | void sendPos(TASK_ARGS) { 141 | TASK_VARIABLES; 142 | 143 | if (ctx->socketPool != NULL) { 144 | unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); 145 | if (nSockets > 0) { 146 | SocketIo ios[nSockets]; 147 | for (int i = 0; i < nSockets; i++) { 148 | ios[i].socketIndex = threadIndex + i * nThreads; 149 | ios[i].data = &transformer->pos; 150 | ios[i].size = sizeof(pos_t); 151 | } 152 | ctx->socketPool->writeMany(nSockets, ios); 153 | } 154 | } 155 | } 156 | 157 | bool tryWaitForPos(Transformer* transformer, Socket* socket, unsigned int maxAttempts) { 158 | return socket->tryRead(&transformer->pos, sizeof(pos_t), maxAttempts); 159 | } 160 | 161 | Inference::Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool) { 162 | this->transformer = transformer; 163 | this->socketPool = socketPool; 164 | this->arch = arch; 165 | context.transformer = transformer; 166 | context.socket = NULL; 167 | context.socketPool = socketPool; 168 | assert(arch->inference.tasks[0].handler == sendPos); 169 | taskLoop = new TaskLoop(nThreads, arch->inference.nTasks, TASK_N_TYPES, arch->inference.tasks, (void*)&context); 170 | } 171 | 172 | Inference::~Inference() { 173 | delete taskLoop; 174 | } 175 | 176 | float* Inference::infer(int token, pos_t pos) { 177 | transformer->pos = pos; 178 | 179 | float* contentRow = ((float*)transformer->tokenEmbeddingTable) + token * transformer->spec->dim; 180 | memcpy(transformer->x, contentRow, 
transformer->spec->dim * sizeof(float)); 181 | 182 | context.currentBlockIndex = 0; 183 | 184 | taskLoop->run(); 185 | 186 | return transformer->logits; 187 | } 188 | 189 | void Inference::getStats(unsigned long* inferenceTime, unsigned long* transferTime) { 190 | *inferenceTime = taskLoop->executionTime[TASK_TYPE_INFERENCE]; 191 | *transferTime = taskLoop->executionTime[TASK_TYPE_TRANSFER]; 192 | } 193 | 194 | Worker::Worker(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, Socket* socket) { 195 | this->transformer = transformer; 196 | this->socket = socket; 197 | context.transformer = transformer; 198 | context.socket = socket; 199 | context.socketPool = NULL; 200 | taskLoop = new TaskLoop(nThreads, arch->worker.nTasks, TASK_N_TYPES, arch->worker.tasks, (void*)&context); 201 | } 202 | 203 | Worker::~Worker() { 204 | delete taskLoop; 205 | } 206 | 207 | void Worker::work() { 208 | const unsigned long maxAttempts = 10000; 209 | 210 | bool turbo = false; 211 | while (true) { 212 | const clock_t start = clock(); 213 | 214 | while (!tryWaitForPos(transformer, socket, maxAttempts)) { 215 | if (turbo) { 216 | // After one second of waiting with non-blocking read, we switch to blocking mode to not burn CPU. 
217 | if (clock() - start > CLOCKS_PER_SEC) { 218 | socket->setTurbo(false); 219 | turbo = false; 220 | printf("🚁 Socket is in blocking mode\n"); 221 | } 222 | } 223 | } 224 | if (!turbo) { 225 | socket->setTurbo(true); 226 | turbo = true; 227 | printf("🚁 Socket is in non-blocking mode\n"); 228 | } 229 | 230 | context.currentBlockIndex = 0; 231 | taskLoop->run(); 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /src/tasks.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TASKS_HPP 2 | #define TASKS_HPP 3 | 4 | #include "transformer.hpp" 5 | #include "utils.hpp" 6 | 7 | #define TASK_ARGS unsigned int nThreads, unsigned int threadIndex, void* userData 8 | #define TASK_WITH_NAME(x) x, #x 9 | 10 | #define TASK_N_TYPES 2 11 | #define TASK_TYPE_INFERENCE 0 12 | #define TASK_TYPE_TRANSFER 1 13 | 14 | struct TransformerContext { 15 | Transformer* transformer; 16 | Socket* socket; 17 | SocketPool* socketPool; 18 | unsigned int currentBlockIndex; 19 | }; 20 | 21 | typedef void (InferenceInitializer)(TransformerContext* context); 22 | 23 | struct TransformerTasks { 24 | unsigned int nTasks; 25 | TaskLoopTask* tasks; 26 | }; 27 | 28 | class TransformerArch { 29 | public: 30 | TransformerTasks inference; 31 | TransformerTasks worker; 32 | 33 | TransformerArch(); 34 | ~TransformerArch(); 35 | 36 | void I(TaskLoopHandler* handler, const char* taskName, unsigned int taskType); 37 | void W(TaskLoopHandler* handler, const char* taskName, unsigned int taskType); 38 | }; 39 | 40 | #define TASK_VARIABLES \ 41 | TransformerContext* ctx = (TransformerContext*)userData; \ 42 | Transformer* transformer = ctx->transformer; \ 43 | TransformerBlock* block = transformer->blocks[ctx->currentBlockIndex]; \ 44 | TransformerSpec* spec = transformer->spec; // printf("%s:%d\n", __FUNCTION__, ctx->currentBlockIndex); fflush(stdout); 45 | 46 | void syncUnitBuffer(unsigned int nThreads, unsigned int 
// Root-side driver of a forward pass: owns a TaskLoop built from the
// architecture's inference task list and runs it once per token.
class Inference {
private:
    Transformer* transformer;
    SocketPool* socketPool;
    TransformerContext context; // passed to every task as user data
    TaskLoop *taskLoop; // created in the constructor, deleted in the destructor
    TransformerArch *arch;
public:
    // Expects sendPos to be the first inference task of `arch`.
    Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool);
    ~Inference();
    // Stores `pos`, copies the embedding row of `token` into the transformer input,
    // runs the task loop and returns the transformer's logits buffer.
    float* infer(int token, pos_t pos);
    // Reports the task loop's accumulated execution time per task type
    // (TASK_TYPE_INFERENCE and TASK_TYPE_TRANSFER).
    void getStats(unsigned long* inferenceTime, unsigned long* transferTime);
};
// Checks that a ChatTemplate constructed with TEMPLATE_UNKNOWN ends up with
// the expected `type` for three different template strings
// (TEMPLATE_LLAMA3, TEMPLATE_CHATML and TEMPLATE_ZEPHYR).
// NOTE(review): the third constructor argument is an empty string in every
// case here — confirm what it represents and whether that is intended.
void testChatTemplate() {
    ChatTemplate t0(TEMPLATE_UNKNOWN, "{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "");
    assert(t0.type == TEMPLATE_LLAMA3);

    ChatTemplate t1(TEMPLATE_UNKNOWN, "{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", "");
    assert(t1.type == TEMPLATE_CHATML);

    ChatTemplate t2(TEMPLATE_UNKNOWN, "{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", "");
    assert(t2.type == TEMPLATE_ZEPHYR);

    printf("✅ ChatTemplate\n");
}
ASSERT_EOS_TYPE(detector.append(3, "> "), EOS); 45 | assert(detector.getDelta() == NULL); 46 | } 47 | 48 | // " " 49 | detector.clear(); 50 | { 51 | ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS); 52 | 53 | char* delta = detector.getDelta(); 54 | assert(delta != NULL); 55 | assert(strcmp(delta, " ") == 0); 56 | } 57 | 58 | // "! " 59 | detector.clear(); 60 | { 61 | ASSERT_EOS_TYPE(detector.append(1, "!<"), MAYBE_EOS); 62 | ASSERT_EOS_TYPE(detector.append(2, "eos"), MAYBE_EOS); 63 | ASSERT_EOS_TYPE(detector.append(3, "> "), EOS); 64 | 65 | char* delta = detector.getDelta(); 66 | assert(delta != NULL); 67 | assert(strcmp(delta, "!") == 0); 68 | } 69 | 70 | // "! " 71 | detector.clear(); 72 | { 73 | ASSERT_EOS_TYPE(detector.append(1, "XY"), NOT_EOS); 75 | 76 | char* delta = detector.getDelta(); 77 | assert(delta != NULL); 78 | assert(strcmp(delta, "XY") == 0); 79 | } 80 | 81 | // ""), EOS); 86 | 87 | char* delta = detector.getDelta(); 88 | assert(delta != NULL); 89 | assert(strcmp(delta, ""), EOS); 96 | assert(detector.getDelta() == NULL); 97 | } 98 | 99 | printf("✅ EosDetector with padding\n"); 100 | } 101 | 102 | 103 | void testEosDetectorWithLongPadding() { 104 | const char* stops[1] = { "|end|" }; 105 | EosDetector detector(EOS_ID, 1, stops, 5, 5); 106 | 107 | // "lipsum" 108 | { 109 | ASSERT_EOS_TYPE(detector.append(1, "lipsum"), NOT_EOS); 110 | char* delta = detector.getDelta(); 111 | assert(delta != NULL); 112 | assert(strcmp(delta, "lipsum") == 0); 113 | } 114 | 115 | // "lorem" 116 | detector.clear(); 117 | { 118 | ASSERT_EOS_TYPE(detector.append(1, "lorem"), NOT_EOS); 119 | char* delta = detector.getDelta(); 120 | assert(delta != NULL); 121 | assert(strcmp(delta, "lorem") == 0); 122 | } 123 | 124 | // "lorem|enQ" 125 | detector.clear(); 126 | { 127 | ASSERT_EOS_TYPE(detector.append(1, "lorem|"), MAYBE_EOS); 128 | ASSERT_EOS_TYPE(detector.append(2, "enQ"), NOT_EOS); 129 | char* delta = detector.getDelta(); 130 | assert(delta != NULL); 131 | 
assert(strcmp(delta, "lorem|enQ") == 0); 132 | } 133 | 134 | printf("✅ EosDetector with long padding\n"); 135 | } 136 | 137 | void testEosDetectorWithoutPadding() { 138 | const char* stops[1] = { "" }; 139 | EosDetector detector(EOS_ID, 1, stops, 0, 0); 140 | 141 | // "" 142 | { 143 | ASSERT_EOS_TYPE(detector.append(1, "<"), MAYBE_EOS); 144 | ASSERT_EOS_TYPE(detector.append(2, "eo"), MAYBE_EOS); 145 | ASSERT_EOS_TYPE(detector.append(3, "s>"), EOS); 146 | assert(detector.getDelta() == NULL); 147 | } 148 | 149 | // " <" 150 | detector.clear(); 151 | { 152 | ASSERT_EOS_TYPE(detector.append(1, " <"), NOT_EOS); 153 | char* delta = detector.getDelta(); 154 | assert(delta != NULL); 155 | assert(strcmp(delta, " <") == 0); 156 | } 157 | 158 | // " " 159 | detector.clear(); 160 | { 161 | ASSERT_EOS_TYPE(detector.append(1, " "), NOT_EOS); 163 | char* delta = detector.getDelta(); 164 | assert(delta != NULL); 165 | assert(strcmp(delta, " ") == 0); 166 | } 167 | 168 | // EOS 169 | detector.clear(); 170 | { 171 | ASSERT_EOS_TYPE(detector.append(EOS_ID, ""), EOS); 172 | assert(detector.getDelta() == NULL); 173 | } 174 | 175 | printf("✅ EosDetector without padding\n"); 176 | } 177 | 178 | int main() { 179 | testChatTemplate(); 180 | testEosDetectorWithPadding(); 181 | testEosDetectorWithLongPadding(); 182 | testEosDetectorWithoutPadding(); 183 | return EXIT_SUCCESS; 184 | } 185 | -------------------------------------------------------------------------------- /src/tokenizer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TOKENIZER_HPP 2 | #define TOKENIZER_HPP 3 | 4 | #include 5 | #include 6 | #include "tasks.hpp" 7 | 8 | bool isSafePiece(char *piece); 9 | void safePrintf(char *piece); 10 | 11 | typedef struct { 12 | char *str; 13 | int id; 14 | } TokenIndex; 15 | 16 | struct TokenizerOldHeader { 17 | unsigned int vocabSize; 18 | unsigned int maxTokenLength; 19 | int bosId; 20 | int eosId; 21 | int padId; 22 | }; 23 | 24 | enum 
TokenizerHeaderKey { 25 | TOK_VERSION = 0, 26 | TOK_VOCAB_SIZE = 1, 27 | MAX_TOKEN_LENGTH = 2, 28 | BOS_ID = 3, 29 | EOS_ID = 4, 30 | PAD_ID = 5, 31 | CHAT_EOS_ID = 6, 32 | CHAT_TEMPLATE = 7, 33 | CHAT_STOP = 8, 34 | }; 35 | 36 | class Tokenizer { 37 | private: 38 | unsigned int maxTokenLength; 39 | float* vocabScores; 40 | TokenIndex *sortedVocab; 41 | int vocabSize; 42 | unsigned char bytePieces[512]; // stores all single-byte strings 43 | 44 | public: 45 | char** vocab; 46 | int bosId; 47 | int eosId; 48 | int chatEosId; 49 | char* chatTemplate; 50 | char* chatStop; 51 | 52 | Tokenizer(char* tokenizer_path, int vocab_size); 53 | ~Tokenizer(); 54 | void encode(char *text, int *tokens, int *nTokens, bool addBos, bool addEos); 55 | char* decode(int prev_token, int token); 56 | }; 57 | 58 | // struct used when sorting probabilities during top-p sampling 59 | typedef struct { 60 | float prob; 61 | int index; 62 | } ProbIndex; 63 | 64 | // The Sampler, which takes logits and returns a sampled token 65 | // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling 66 | class Sampler { 67 | private: 68 | int vocab_size; 69 | ProbIndex* probindex; // buffer used in top-p sampling 70 | float temperature; 71 | float topp; 72 | unsigned long long rngState; 73 | 74 | public: 75 | Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed); 76 | ~Sampler(); 77 | int sample(float* logits); 78 | void setTemp(float temp); 79 | void setSeed(unsigned long long rngSeed); 80 | }; 81 | 82 | class TokenizerChatStops { 83 | public: 84 | const char** stops; 85 | size_t nStops; 86 | size_t maxStopLength; 87 | TokenizerChatStops(Tokenizer* tokenizer); 88 | ~TokenizerChatStops(); 89 | }; 90 | 91 | enum ChatTemplateType { 92 | TEMPLATE_UNKNOWN = 0, 93 | TEMPLATE_LLAMA2 = 1, 94 | TEMPLATE_LLAMA3 = 2, 95 | TEMPLATE_ZEPHYR = 3, 96 | TEMPLATE_CHATML = 4, 97 | }; 98 | 99 | struct ChatItem { 100 | std::string role; 101 | std::string message; 102 | }; 103 | 
104 | class ChatTemplate { 105 | public: 106 | const char* eos; 107 | ChatTemplateType type; 108 | ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos); 109 | std::string generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt); 110 | }; 111 | 112 | enum EosDetectorType { 113 | MAYBE_EOS = 0, 114 | EOS = 1, 115 | NOT_EOS = 2, 116 | }; 117 | 118 | class EosDetector { 119 | private: 120 | int eosId; 121 | size_t nStops; 122 | const char** stops; 123 | size_t* stopSizes; 124 | size_t bufferPos; 125 | size_t bufferSize; 126 | int eosPos; 127 | int paddingLeft; 128 | int paddingRight; 129 | public: 130 | char* buffer; 131 | EosDetector(int eosId, size_t nStops, const char** stops, int paddingLeft, int paddingRight); 132 | ~EosDetector(); 133 | 134 | EosDetectorType append(int tokenId, const char* piece); 135 | char* getDelta(); 136 | void clear(); 137 | }; 138 | 139 | #endif 140 | -------------------------------------------------------------------------------- /src/transformer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TRANSFORMER_HPP 2 | #define TRANSFORMER_HPP 3 | 4 | #include 5 | #include 6 | #include "quants.hpp" 7 | #include "commands.hpp" 8 | #include "socket.hpp" 9 | 10 | enum TransformerHeaderKey { 11 | VERSION = 0, 12 | ARCH_TYPE = 1, 13 | DIM = 2, 14 | HIDDEN_DIM = 3, 15 | N_LAYERS = 4, 16 | N_HEADS = 5, 17 | N_KV_HEADS = 6, 18 | N_EXPERTS = 7, 19 | N_ACTIVE_EXPERTS = 8, 20 | VOCAB_SIZE = 9, 21 | SEQ_LEN = 10, 22 | HIDDEN_ACT = 11, 23 | ROPE_THETA = 12, 24 | WEIGHTS_FLOAT_TYPE = 13, 25 | ROPE_SCALING_FACTOR = 14, 26 | ROPE_SCALING_LOW_FREQ_FACTOR = 15, 27 | ROPE_SCALING_HIGH_FREQ_FACTORY = 16, 28 | ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17, 29 | ROPE_TYPE = 18, 30 | }; 31 | 32 | struct TransformerFileOldHeader { 33 | int dim; 34 | int hiddenDim; 35 | int nLayers; 36 | int nHeads; 37 | int nKvHeads; 38 | int nExperts; 39 | int nActiveExperts; 40 | int 
vocabSize; 41 | int seqLen; 42 | }; 43 | 44 | enum TransformerArchType { 45 | LLAMA = 0xABCD00, 46 | GROK1 = 0xABCD01, 47 | MIXTRAL = 0xABCD02 48 | }; 49 | 50 | enum TransformerHiddenAct { 51 | GELU = 0, 52 | SILU = 1, 53 | }; 54 | 55 | enum TransformerRopeType { 56 | ROPE_UNKNOWN = -1, 57 | ROPE_LLAMA = 0, 58 | ROPE_FALCON = 1, 59 | ROPE_LLAMA3_1 = 2, 60 | }; 61 | 62 | struct TransformerSpec { 63 | size_t headerSize; 64 | size_t fileSize; 65 | int version; 66 | TransformerArchType archType; 67 | int dim; 68 | int nLayers; 69 | int nHeads; 70 | int headSize; 71 | int nKvHeads; 72 | int nExperts; 73 | int nActiveExperts; 74 | unsigned int origSeqLen; // Original model context length 75 | unsigned int seqLen; // Limited context length by the `--max-seq-len` argument 76 | int hiddenDim; 77 | TransformerHiddenAct hiddenAct; 78 | int kvDim; 79 | int vocabSize; 80 | float ropeTheta; 81 | TransformerRopeType ropeType; 82 | float ropeScalingFactor; 83 | float ropeScalingLowFreqFactor; 84 | float ropeScalingHighFreqFactory; 85 | int ropeScalingOrigMaxSeqLen; 86 | 87 | FloatType weightsFloatType; 88 | FloatType bufferFloatType; 89 | uint8_t nSlices; 90 | }; 91 | 92 | struct TransformerConfig { 93 | bool useDiscForKvCache; 94 | }; 95 | 96 | class TransformerBlock { 97 | public: 98 | slice_index_t sliceIndex; 99 | TransformerSpec *spec; 100 | TransformerConfig* config; 101 | 102 | size_t rmsAttBytes; 103 | float* rmsAtt; 104 | size_t rmsFfnBytes; 105 | float* rmsFfn; 106 | size_t rmsMoeBytes; 107 | float* rmsMoe; 108 | size_t rmsFfn2Bytes; 109 | float* rmsFfn2; 110 | 111 | MatmulCommand *q0mm; 112 | MatmulCommand *k0mm; 113 | MatmulCommand *v0mm; 114 | MatmulCommand *wo0mm; 115 | RowMatmulSlice* q0Slice; 116 | RowMatmulSlice* k0Slice; 117 | RowMatmulSlice* v0Slice; 118 | ColMatmulSlice* wo0Slice; 119 | 120 | MatmulCommand *w10mm; 121 | MatmulCommand *w20mm; 122 | MatmulCommand *w30mm; 123 | RowMatmulSlice* w10Slice; 124 | ColMatmulSlice* w20Slice; 125 | RowMatmulSlice* 
w30Slice; 126 | 127 | MatmulCommand* moeRouterMm; 128 | RowMatmulSlice* moeUpAndGate0Slice; 129 | RowMatmulSlice* moeDown0Slice; 130 | MatmulCommand** moeUpMm; 131 | MatmulCommand** moeGateMm; 132 | MatmulCommand** moeDownMm; 133 | 134 | float* moeRouterProbs; 135 | float* expertGate; 136 | float* expertDown; 137 | float* hb20; 138 | 139 | KvCacheSlice* kvCacheSlice; 140 | float* keyCache; 141 | float* valueCache; 142 | MultiHeadAttSlice* multiHeadAttSlice; 143 | float* att; 144 | float* qo0; 145 | 146 | TransformerBlock(TransformerSpec* spec, TransformerConfig* config, slice_index_t sliceIndex); 147 | ~TransformerBlock(); 148 | }; 149 | 150 | #define TB_LENGTH 11 151 | #define TB_NO_PAIRS 3 152 | 153 | #define TB_UNIT_XB 0 154 | #define TB_UNIT_XB_QUANTIZED 1 155 | #define TB_SLICED_XB2 2 156 | #define TB_SLICED_XB2_QUANTIZED 3 157 | #define TB_SLICED_XBV 4 158 | #define TB_SLICED_XBV_QUANTIZED 5 159 | #define TB_SLICED_HB 6 160 | #define TB_SLICED_HB_QUANTIZED 7 161 | #define TB_UNIT_MOE_INDEXES 8 162 | #define TB_UNIT_MOE_WEIGHTS 9 163 | #define TB_SLICED_LOGITS 10 164 | 165 | class TransformerBuffer { 166 | public: 167 | uint8_t nSlices; 168 | void** buffers; 169 | size_t* bufferBytes; 170 | 171 | TransformerBuffer(TransformerSpec* spec); 172 | ~TransformerBuffer(); 173 | void* getUnit(uint8_t bufferIndex); 174 | size_t getUnitBytes(uint8_t bufferIndex); 175 | void* getSliced(uint8_t bufferIndex, slice_index_t sliceIndex); 176 | size_t getSlicedBytes(uint8_t bufferIndex); 177 | }; 178 | 179 | class Transformer { 180 | public: 181 | TransformerSpec* spec; 182 | TransformerConfig* config; 183 | TransformerBlock** blocks; 184 | TransformerBuffer* buffer; 185 | slice_index_t sliceIndex; 186 | 187 | size_t tokenEmbeddingTableBytes; 188 | float* tokenEmbeddingTable; 189 | size_t rmsFinalBytes; 190 | float* rmsFinal; 191 | RowMatmulSlice* wclsSlice; 192 | MatmulCommand* wclsMm; 193 | 194 | pos_t pos; 195 | float rms; 196 | float* x; 197 | float* logits; 198 | 
RopeSlice* ropeSlice; 199 | RopeCommand* rope; 200 | 201 | ~Transformer(); 202 | 203 | static TransformerSpec loadSpecFromFile(const char* path, const unsigned int nSlices, const unsigned int maxSeqLen, FloatType weightsFloatType, FloatType bufferFloatType); 204 | static Transformer loadRootFromFile(const char* path, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool); 205 | static Transformer loadRoot(char* data, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool); 206 | static Transformer loadSlice(TransformerSpec* spec, TransformerConfig* config, Socket* socket); 207 | 208 | private: 209 | Transformer(TransformerSpec* spec, TransformerConfig* config, slice_index_t sliceIndex); 210 | }; 211 | 212 | #endif 213 | -------------------------------------------------------------------------------- /src/utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "utils.hpp" 10 | 11 | #define BUFFER_ALIGNMENT 16 12 | 13 | #ifdef _WIN32 14 | #include 15 | #else 16 | #include 17 | #include 18 | #include 19 | #endif 20 | 21 | void* newBuffer(size_t size) { 22 | void* buffer; 23 | #ifdef _WIN32 24 | buffer = _aligned_malloc(size, BUFFER_ALIGNMENT); 25 | if (buffer == NULL) { 26 | fprintf(stderr, "error: _aligned_malloc failed\n"); 27 | exit(EXIT_FAILURE); 28 | } 29 | #else 30 | if (posix_memalign((void**)&buffer, BUFFER_ALIGNMENT, size) != 0) { 31 | fprintf(stderr, "error: posix_memalign failed\n"); 32 | exit(EXIT_FAILURE); 33 | } 34 | if (mlock(buffer, size) != 0) { 35 | fprintf(stderr, "🚧 Cannot allocate %zu bytes directly in RAM\n", size); 36 | } 37 | #endif 38 | return buffer; 39 | } 40 | 41 | void freeBuffer(void* buffer) { 42 | #ifdef _WIN32 43 | _aligned_free(buffer); 44 | #else 45 | free(buffer); 46 | #endif 47 | } 48 | 49 | unsigned int lastMmapFileBufferIndex = 0; 50 | 51 | 
void* newMmapFileBuffer(unsigned int appInstanceId, size_t size) { 52 | #ifdef _WIN32 53 | throw new std::runtime_error("Mmap file buffer is not supported on Windows yet"); 54 | #else 55 | char path[256]; 56 | snprintf(path, 256, "mmap-buffer-%d-%d.temp", appInstanceId, lastMmapFileBufferIndex++); 57 | int fd = open(path, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); 58 | if (fd == -1) 59 | throw new std::runtime_error("Cannot create mmap buffer file"); 60 | if (ftruncate(fd, size) == -1) 61 | throw new std::runtime_error("Cannot truncate mmap buffer file. Not enough disk space?"); 62 | void *addr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); 63 | if (addr == MAP_FAILED) 64 | throw new std::runtime_error("Cannot mmap buffer file"); 65 | close(fd); 66 | return addr; 67 | #endif 68 | } 69 | 70 | void freeMmapFileBuffer(void* addr) { 71 | // TODO 72 | } 73 | 74 | unsigned long timeMs() { 75 | struct timeval te; 76 | gettimeofday(&te, NULL); 77 | return te.tv_sec * 1000LL + te.tv_usec / 1000; 78 | } 79 | 80 | unsigned long timeUs() { 81 | struct timeval te; 82 | gettimeofday(&te, NULL); 83 | return te.tv_sec * 1000000LL + te.tv_usec; 84 | } 85 | 86 | unsigned int randomU32(unsigned long long *state) { 87 | // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A 88 | *state ^= *state >> 12; 89 | *state ^= *state << 25; 90 | *state ^= *state >> 27; 91 | return (*state * 0x2545F4914F6CDD1Dull) >> 32; 92 | } 93 | 94 | float randomF32(unsigned long long *state) { 95 | // random float32 in <0,1) 96 | return (randomU32(state) >> 8) / 16777216.0f; 97 | } 98 | 99 | long seekToEnd(FILE* file) { 100 | #ifdef _WIN32 101 | _fseeki64(file, 0, SEEK_END); 102 | return _ftelli64(file); 103 | #else 104 | fseek(file, 0, SEEK_END); 105 | return ftell(file); 106 | #endif 107 | } 108 | 109 | void openMmapFile(MmapFile* file, const char* path, size_t size) { 110 | file->size = size; 111 | #ifdef _WIN32 112 | file->hFile = CreateFileA(path, GENERIC_READ, 0, NULL, 
OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); 113 | if (file->hFile == INVALID_HANDLE_VALUE) { 114 | printf("Cannot open file %s\n", path); 115 | exit(EXIT_FAILURE); 116 | } 117 | 118 | file->hMapping = CreateFileMappingA(file->hFile, NULL, PAGE_READONLY, 0, 0, NULL); 119 | if (file->hMapping == NULL) { 120 | printf("CreateFileMappingA failed, error: %lu\n", GetLastError()); 121 | CloseHandle(file->hFile); 122 | exit(EXIT_FAILURE); 123 | } 124 | 125 | file->data = (char*)MapViewOfFile(file->hMapping, FILE_MAP_READ, 0, 0, 0); 126 | if (file->data == NULL) { 127 | printf("MapViewOfFile failed!\n"); 128 | CloseHandle(file->hMapping); 129 | CloseHandle(file->hFile); 130 | exit(EXIT_FAILURE); 131 | } 132 | #else 133 | file->fd = open(path, O_RDONLY); 134 | if (file->fd == -1) { 135 | printf("Cannot open file %s\n", path); 136 | exit(EXIT_FAILURE); 137 | } 138 | 139 | file->data = mmap(NULL, size, PROT_READ, MAP_PRIVATE, file->fd, 0); 140 | if (file->data == MAP_FAILED) { 141 | printf("Mmap failed!\n"); 142 | close(file->fd); 143 | exit(EXIT_FAILURE); 144 | } 145 | #endif 146 | } 147 | 148 | void closeMmapFile(MmapFile* file) { 149 | #ifdef _WIN32 150 | UnmapViewOfFile(file->data); 151 | CloseHandle(file->hMapping); 152 | CloseHandle(file->hFile); 153 | #else 154 | munmap(file->data, file->size); 155 | close(file->fd); 156 | #endif 157 | } 158 | 159 | TaskLoop::TaskLoop(unsigned int nThreads, unsigned int nTasks, unsigned int nTypes, TaskLoopTask* tasks, void* userData) { 160 | this->nThreads = nThreads; 161 | this->nTasks = nTasks; 162 | this->nTypes = nTypes; 163 | this->tasks = tasks; 164 | this->userData = userData; 165 | executionTime = new unsigned long[nTypes]; 166 | 167 | threads = new TaskLoopThread[nThreads]; 168 | for (unsigned int i = 0; i < nThreads; i++) { 169 | threads[i].threadIndex = i; 170 | threads[i].nTasks = nTasks; 171 | threads[i].loop = this; 172 | } 173 | } 174 | 175 | TaskLoop::~TaskLoop() { 176 | std::map > executionProfile; 177 | for (unsigned 
int i = 0; i < this->nTasks; i++) { 178 | auto it = executionProfile.find(this->tasks[i].taskName); 179 | if (it != executionProfile.end()) { 180 | it->second.first += this->tasks[i].executionCount; 181 | it->second.second += this->tasks[i].executionTime; 182 | } else { 183 | executionProfile.insert(std::make_pair(this->tasks[i].taskName, std::make_pair(this->tasks[i].executionCount, this->tasks[i].executionTime))); 184 | } 185 | } 186 | for (const auto& it : executionProfile) { 187 | printf("%s,%ld,%ld\n", it.first, it.second.first, it.second.second); 188 | } 189 | delete[] executionTime; 190 | delete[] threads; 191 | } 192 | 193 | void TaskLoop::run() { 194 | currentTaskIndex.exchange(0); 195 | doneThreadCount.exchange(0); 196 | 197 | unsigned int i; 198 | lastTime = timeUs(); 199 | for (i = 0; i < nTypes; i++) { 200 | executionTime[i] = 0; 201 | } 202 | 203 | for (i = 1; i < nThreads; i++) { 204 | int result = pthread_create(&threads[i].handler, NULL, (thread_func_t)threadHandler, (void*)&threads[i]); 205 | if (result != 0) { 206 | printf("Cannot created thread\n"); 207 | exit(EXIT_FAILURE); 208 | } 209 | } 210 | 211 | threadHandler((void*)&threads[0]); 212 | 213 | for (i = 1; i < nThreads; i++) { 214 | pthread_join(threads[i].handler, NULL); 215 | } 216 | } 217 | 218 | void* TaskLoop::threadHandler(void* arg) { 219 | TaskLoopThread* context = (TaskLoopThread*)arg; 220 | TaskLoop* loop = context->loop; 221 | unsigned int threadIndex = context->threadIndex; 222 | 223 | while (true) { 224 | const unsigned int currentTaskIndex = loop->currentTaskIndex.load(); 225 | if (currentTaskIndex == context->nTasks) { 226 | break; 227 | } 228 | 229 | TaskLoopTask* task = &loop->tasks[currentTaskIndex % loop->nTasks]; 230 | 231 | task->handler(loop->nThreads, threadIndex, loop->userData); 232 | 233 | int currentCount = loop->doneThreadCount.fetch_add(1); 234 | 235 | if (currentCount == loop->nThreads - 1) { 236 | unsigned long currentTime = timeUs(); 237 | unsigned long 
lastTime = loop->lastTime; 238 | loop->executionTime[task->taskType] += currentTime - lastTime; 239 | loop->lastTime = currentTime; 240 | 241 | loop->doneThreadCount.store(0); 242 | loop->currentTaskIndex.fetch_add(1); 243 | 244 | task->executionTime += currentTime - lastTime; 245 | task->executionCount += 1; 246 | } else { 247 | while (loop->currentTaskIndex.load() == currentTaskIndex) { 248 | // NOP 249 | } 250 | } 251 | } 252 | 253 | // printf("@ Thread %d stopped at step %d\n", threadIndex, unsigned(loop->currentTaskIndex)); 254 | return 0; 255 | } 256 | -------------------------------------------------------------------------------- /src/utils.hpp: -------------------------------------------------------------------------------- 1 | #ifndef UTILS_HPP 2 | #define UTILS_HPP 3 | 4 | #include 5 | #include 6 | #include "common/pthread.h" 7 | 8 | #define ALLOC_MEMORY true 9 | 10 | #ifdef _WIN32 11 | #include 12 | #endif 13 | 14 | #define SPLIT_RANGE_TO_THREADS(varStart, varEnd, rangeStart, rangeEnd, nThreads, threadIndex) \ 15 | const unsigned int rangeLen = (rangeEnd - rangeStart); \ 16 | const unsigned int rangeSlice = rangeLen / nThreads; \ 17 | const unsigned int rangeRest = rangeLen % nThreads; \ 18 | const unsigned int varStart = threadIndex * rangeSlice + (threadIndex < rangeRest ? threadIndex : rangeRest); \ 19 | const unsigned int varEnd = varStart + rangeSlice + (threadIndex < rangeRest ? 
1 : 0); 20 | 21 | #define DEBUG_FLOATS(name, v, n) printf("⭕ %s ", name); for (int i = 0; i < n; i++) printf("%f ", v[i]); printf("\n"); 22 | 23 | void* newBuffer(size_t size); 24 | void freeBuffer(void* buffer); 25 | 26 | void* newMmapFileBuffer(unsigned int appInstanceId, size_t size); 27 | void freeMmapFileBuffer(void* addr); 28 | 29 | unsigned long timeMs(); 30 | unsigned int randomU32(unsigned long long *state); 31 | float randomF32(unsigned long long *state); 32 | long seekToEnd(FILE* file); 33 | 34 | struct MmapFile { 35 | void* data; 36 | size_t size; 37 | #ifdef _WIN32 38 | HANDLE hFile; 39 | HANDLE hMapping; 40 | #else 41 | int fd; 42 | #endif 43 | }; 44 | 45 | void openMmapFile(MmapFile* file, const char* path, size_t size); 46 | void closeMmapFile(MmapFile* file); 47 | 48 | typedef void (TaskLoopHandler)(unsigned int nThreads, unsigned int threadIndex, void* userData); 49 | typedef struct { 50 | TaskLoopHandler* handler; 51 | unsigned int taskType; 52 | const char* taskName; 53 | unsigned long executionCount; 54 | unsigned long executionTime; 55 | } TaskLoopTask; 56 | 57 | class TaskLoop; 58 | 59 | struct TaskLoopThread { 60 | unsigned int threadIndex; 61 | unsigned int nTasks; 62 | dl_thread handler; 63 | TaskLoop* loop; 64 | }; 65 | 66 | class TaskLoop { 67 | public: 68 | unsigned int nThreads; 69 | unsigned int nTasks; 70 | unsigned int nTypes; 71 | TaskLoopTask* tasks; 72 | void* userData; 73 | std::atomic_uint currentTaskIndex; 74 | std::atomic_uint doneThreadCount; 75 | unsigned long lastTime; 76 | unsigned long* executionTime; 77 | TaskLoopThread* threads; 78 | 79 | TaskLoop(unsigned int nThreads, unsigned int nTasks, unsigned int nTypes, TaskLoopTask* tasks, void* userData); 80 | ~TaskLoop(); 81 | void run(); 82 | static void* threadHandler(void* args); 83 | }; 84 | 85 | #endif 86 | --------------------------------------------------------------------------------