├── .envrc ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── assets ├── llama2_in_termux.gif └── llama_android.png ├── configurator.py ├── devshell.nix ├── export_meta_llama_bin.py ├── flake.lock ├── flake.nix ├── jni └── Android.mk ├── libs ├── arm64-v8a │ └── llama2 ├── armeabi-v7a │ └── llama2 ├── x86 │ └── llama2 └── x86_64 │ └── llama2 ├── model.py ├── requirements.txt ├── run.c ├── sample.py ├── test_all.py ├── tinystories.py ├── tokenizer.bin ├── tokenizer.model ├── tokenizer.py └── train.py /.envrc: -------------------------------------------------------------------------------- 1 | use flake -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .direnv/ 2 | out/ 3 | obj/ 4 | run -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Andrej 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Change this to whereever you keep NDK 2 | NDK = /home/manuel/Android/Sdk/ndk/25.2.9519653 3 | SRCDIR = . 4 | OBJDIR = . 
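# note (assumption, adjust to your install): if the NDK was installed through Android Studio or
# sdkmanager it typically lives under $ANDROID_HOME/ndk/<version>, so something like
# NDK = $(ANDROID_HOME)/ndk/25.2.9519653 may work here instead of the absolute home-directory path above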
5 | DBG ?= 0 6 | 7 | # Debug/Release configuration 8 | ifeq ($(DBG),1) 9 | MODE_FLAGS = -DDEBUG -g -O0 10 | else 11 | MODE_FLAGS = -Os -fdata-sections -ffunction-sections 12 | endif 13 | 14 | ## NDK configuration (clang) 15 | 16 | # NDK Version 17 | NDK_TARGETVER = 27 18 | 19 | # Target arch - here aarch64 for android 20 | NDK_TARGETARCH = aarch64-linux-android 21 | 22 | # Target CPU (ARMv8) 23 | NDK_TARGETSHORTARCH = arm64 24 | 25 | # Toolchain version 26 | NDK_TOOLVER = 4.9 27 | 28 | # Architecture of a machine that does cross compilation 29 | NDK_HOSTARCH = linux-x86_64 30 | 31 | # Set needed preprocessor symbols 32 | NDK_TOOLS = $(NDK)/toolchains/llvm/prebuilt/$(NDK_HOSTARCH)/bin 33 | NDK_SYSROOT = $(NDK)/sysroot 34 | # NDK_TOOL = $(CLANG_PATH)/bin/clang 35 | NDK_TOOL = $(NDK_TOOLS)/clang-14 36 | NDK_LIBS = $(NDK)/toolchains/$(NDK_TARGETARCH)-$(NDK_TOOLVER)/prebuilt/linux-x86_64/lib/gcc/$(NDK_TARGETARCH)/4.9.x 37 | NDK_INCLUDES = -I$(NDK)/sysroot/usr/include \ 38 | -I$(NDK)/sysroot/usr/include/$(NDK_TARGETARCH) 39 | NDK_SYSROOT = $(NDK)/platforms/android-$(NDK_TARGETVER)/arch-$(NDK_TARGETSHORTARCH) 40 | 41 | # Options common to compiler and linker 42 | OPT = $(MODE_FLAGS) \ 43 | -std=c99 \ 44 | -fPIE \ 45 | -Wall \ 46 | -target $(NDK_TARGETARCH) 47 | 48 | # Compiler options 49 | CFLAGS = $(OPT) \ 50 | $(NDK_INCLUDES) 51 | 52 | # Linker options 53 | LDFLAGS = $(OPT) \ 54 | $(MODE_FLAGS) \ 55 | -pie \ 56 | --sysroot=$(NDK_SYSROOT) \ 57 | -B $(ANDROID_NDK)/toolchains/$(NDK_TARGETARCH)-$(NDK_TOOLVER)/prebuilt/linux-x86_64/$(NDK_TARGETARCH)/bin \ 58 | -L$(NDK_LIBS) 59 | 60 | all: 61 | echo ${NDK_TOOL} 62 | $(NDK_TOOL) -c $(SRCDIR)/run.c -o $(OBJDIR)/run.o $(CFLAGS) 63 | $(NDK_TOOL) -o run $(OBJDIR)/run.o $(LDFLAGS) 64 | 65 | # the most basic way of building that is most likely to work on most systems 66 | .PHONY: run 67 | run: run.c 68 | gcc -O3 -o run run.c -lm 69 | 70 | # useful for a debug build, can then e.g. analyze with valgrind, example: 71 | # $ valgrind --leak-check=full ./run out/model.bin 1.0 3 72 | rundebug: run.c 73 | gcc -g -o run run.c -lm 74 | 75 | # https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html 76 | # https://simonbyrne.github.io/notes/fastmath/ 77 | # -Ofast enables all -O3 optimizations. 78 | # Disregards strict standards compliance. 79 | # It also enables optimizations that are not valid for all standard-compliant programs. 80 | # It turns on -ffast-math, -fallow-store-data-races and the Fortran-specific 81 | # -fstack-arrays, unless -fmax-stack-var-size is specified, and -fno-protect-parens. 82 | # It turns off -fsemantic-interposition. 83 | # In our specific application this is *probably* okay to use 84 | .PHONY: runfast 85 | runfast: run.c 86 | gcc -Ofast -o run run.c -lm 87 | 88 | # additionally compiles with OpenMP, allowing multithreaded runs 89 | # make sure to also enable multiple threads when running, e.g.: 90 | # OMP_NUM_THREADS=4 ./run out/model.bin 91 | .PHONY: runomp 92 | runomp: run.c 93 | gcc -Ofast -fopenmp -march=native run.c -lm -o run 94 | 95 | .PHONY: clean 96 | clean: 97 | rm -f run 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llama2.c-android 2 | 3 | 4 | 5 | Port of Andrej Karpathy's [llama2.c](https://github.com/karpathy/llama2.c) to Android. You can run it as raw binary or use it as shared library. 
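The shared-library route is still marked wip in the section further down; as a minimal sketch (not the checked-in `jni/Android.mk`, which builds a standalone executable via `BUILD_EXECUTABLE`), the ndk-build module description for a `.so` would look roughly like this — a thin JNI wrapper exposing `run.c`'s inference loop to Java/Kotlin would still have to be written:

```makefile
# sketch only: a shared-library variant of jni/Android.mk (the checked-in file builds an executable)
LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
LOCAL_MODULE    := llama2          # produces libllama2.so
LOCAL_SRC_FILES := ../run.c        # a JNI wrapper around run.c's entry points would still be needed
LOCAL_LDLIBS    := -llog -lm
include $(BUILD_SHARED_LIBRARY)
```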
6 | 7 | You can use the prebuilt binaries in `libs` or compile them yourself: 8 | 9 | ```bash 10 | # or wherever your ndk-build script resides 11 | cd jni && $ANDROID_HOME/ndk-bundle/ndk-build 12 | ``` 13 | 14 | ## run as binary 15 | 16 | Get e.g. [Termux](https://f-droid.org/en/packages/com.termux/) and install the APK to run the binaries. 17 | 18 | ```bash 19 | wget https://karpathy.ai/llama2c/model.bin -P out 20 | adb push libs/arm64-v8a/llama2 /storage/emulated/0/Android/data # or another ABI folder from libs/ 21 | adb push out/model.bin /storage/emulated/0/Android/data 22 | adb push tokenizer.bin /storage/emulated/0/Android/data 23 | ``` 24 | 25 | In Termux: 26 | 27 | ```bash 28 | cp /storage/emulated/0/Android/data/{llama2,model.bin,tokenizer.bin} . 29 | chmod +x llama2 30 | ./llama2 model.bin 31 | ``` 32 | 33 | ![Llama2 in Termux](assets/llama2_in_termux.gif) 34 | 35 | ## run as shared lib 36 | 37 | wip 38 | 39 | -------------------------------------------------------------------------------- /assets/llama2_in_termux.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/assets/llama2_in_termux.gif -------------------------------------------------------------------------------- /assets/llama_android.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/assets/llama_android.png -------------------------------------------------------------------------------- /configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it (e.g.
if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /devshell.nix: -------------------------------------------------------------------------------- 1 | { pkgs }: 2 | 3 | with pkgs; 4 | 5 | devshell.mkShell { 6 | name = "android-project"; 7 | motd = '' 8 | Entered the Android app development environment. 9 | ''; 10 | env = [ 11 | { 12 | name = "ANDROID_HOME"; 13 | value = "${android-sdk}/share/android-sdk"; 14 | } 15 | { 16 | name = "ANDROID_SDK_ROOT"; 17 | value = "${android-sdk}/share/android-sdk"; 18 | } 19 | { 20 | name = "JAVA_HOME"; 21 | value = jdk11.home; 22 | } 23 | ]; 24 | packages = [ 25 | android-studio 26 | android-sdk 27 | gradle 28 | jdk11 29 | gcc 30 | wget 31 | ]; 32 | } 33 | -------------------------------------------------------------------------------- /export_meta_llama_bin.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script exports the Llama 2 weights in llama2c.bin format. 3 | """ 4 | import sys 5 | import struct 6 | from pathlib import Path 7 | import json 8 | 9 | import torch 10 | 11 | from model import precompute_freqs_cis 12 | 13 | 14 | def export(p, state_dict, filepath='model.bin'): 15 | """export the model weights in fp32 into .bin file to be read from C""" 16 | f = open(filepath, 'wb') 17 | 18 | def serialize(key): 19 | print(f"writing {key}...") 20 | t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy() 21 | f.write(memoryview(t)) 22 | del state_dict[key] 23 | 24 | # first write out the header 25 | hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0] 26 | p['vocab_size'] = 32000 27 | p['max_seq_len'] = 2048 28 | 29 | n_kv_heads = p.get('n_kv_heads') or p['n_heads'] 30 | header = struct.pack( 31 | 'iiiiiii', 32 | p['dim'], hidden_dim, p['n_layers'], p['n_heads'], 33 | n_kv_heads, -p['vocab_size'], p['max_seq_len'] 34 | ) 35 | # NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present 36 | # in the checkpoint and should be loaded. 
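# (illustrative aside, not part of the original export script) the 7-int header can be
# sanity-checked after export by unpacking it the same way run.c reads it, e.g.:
#   dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
#       struct.unpack('iiiiiii', open(filepath, 'rb').read(7 * 4))
#   shared_classifier = vocab_size > 0  # run.c checks the sign, then takes abs(vocab_size)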
37 | f.write(header) 38 | 39 | # next write out the embedding weights 40 | print("writing tok_embeddings...") 41 | serialize('tok_embeddings.weight') 42 | 43 | # now all the layers 44 | # attention weights 45 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight') 46 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight') 47 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight') 48 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight') 49 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight') 50 | # ffn weights 51 | for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight') 52 | for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight') 53 | for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight') 54 | for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight') 55 | 56 | # final rmsnorm 57 | serialize('norm.weight') 58 | # freqs_cis 59 | freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) 60 | state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']] 61 | state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']] 62 | serialize('freqs_cis.real') 63 | serialize('freqs_cis.imag') 64 | 65 | # finally write the output weights 66 | serialize('output.weight') 67 | 68 | f.close() 69 | print(f"wrote {filepath}") 70 | 71 | 72 | def concat_weights(models): 73 | state_dict = {} 74 | for name in list(models[0]): 75 | tensors = [model[name] for model in models] 76 | if len(tensors) == 1 or len(tensors[0].shape) == 1: 77 | state_dict[name] = tensors[0] 78 | continue 79 | is_axis_1 = ( 80 | name.startswith('tok_embeddings.') 81 | or name.endswith('.attention.wo.weight') 82 | or name.endswith('.feed_forward.w2.weight') 83 | ) 84 | axis = 1 if is_axis_1 else 0 85 | state_dict[name] = torch.cat(tensors, dim=axis) 86 | for model in models: 87 | del model[name] 88 | return state_dict 89 | 90 | 91 | def load_and_export(model_path, output_path): 92 | with open(model_path + 'params.json') as f: 93 | params = json.load(f) 94 | print(params) 95 | 96 | model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth'))) 97 | models = [] 98 | for i in model_paths: 99 | print(f'Loading {i}') 100 | models.append(torch.load(i, map_location='cpu')) 101 | 102 | state_dict = concat_weights(models) 103 | del models 104 | export(params, state_dict, output_path) 105 | 106 | 107 | if __name__ == '__main__': 108 | if len(sys.argv) == 1: 109 | print('[Llama model folder path] [output path]') 110 | exit() 111 | 112 | model_path = sys.argv[1] 113 | output_path = sys.argv[2] 114 | load_and_export(model_path, output_path) 115 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "android": { 4 | "inputs": { 5 | "devshell": "devshell", 6 | "flake-utils": "flake-utils", 7 | "nixpkgs": "nixpkgs" 8 | }, 9 | "locked": { 10 | "lastModified": 1690316394, 11 | "narHash": "sha256-uiTwTicVz8fuwScghAXP/MGIJvtMEMiQdNJLQzf8o1g=", 12 | "owner": "tadfisher", 13 | "repo": "android-nixpkgs", 14 | "rev": "ed6c49798b760b644911eae3f4b008fa41983ddf", 15 | "type": "github" 16 | }, 17 | "original": { 18 | "owner": "tadfisher", 19 | "repo": "android-nixpkgs", 20 | "type": "github" 21 | } 22 | }, 23 | "devshell": { 24 | "inputs": { 25 | "nixpkgs": [ 26 | "android", 27 | "nixpkgs" 
28 | ], 29 | "systems": "systems" 30 | }, 31 | "locked": { 32 | "lastModified": 1688380630, 33 | "narHash": "sha256-8ilApWVb1mAi4439zS3iFeIT0ODlbrifm/fegWwgHjA=", 34 | "owner": "numtide", 35 | "repo": "devshell", 36 | "rev": "f9238ec3d75cefbb2b42a44948c4e8fb1ae9a205", 37 | "type": "github" 38 | }, 39 | "original": { 40 | "owner": "numtide", 41 | "repo": "devshell", 42 | "type": "github" 43 | } 44 | }, 45 | "devshell_2": { 46 | "inputs": { 47 | "nixpkgs": "nixpkgs_2", 48 | "systems": "systems_3" 49 | }, 50 | "locked": { 51 | "lastModified": 1688380630, 52 | "narHash": "sha256-8ilApWVb1mAi4439zS3iFeIT0ODlbrifm/fegWwgHjA=", 53 | "owner": "numtide", 54 | "repo": "devshell", 55 | "rev": "f9238ec3d75cefbb2b42a44948c4e8fb1ae9a205", 56 | "type": "github" 57 | }, 58 | "original": { 59 | "owner": "numtide", 60 | "repo": "devshell", 61 | "type": "github" 62 | } 63 | }, 64 | "flake-utils": { 65 | "inputs": { 66 | "systems": "systems_2" 67 | }, 68 | "locked": { 69 | "lastModified": 1689068808, 70 | "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", 71 | "owner": "numtide", 72 | "repo": "flake-utils", 73 | "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", 74 | "type": "github" 75 | }, 76 | "original": { 77 | "owner": "numtide", 78 | "repo": "flake-utils", 79 | "type": "github" 80 | } 81 | }, 82 | "flake-utils_2": { 83 | "inputs": { 84 | "systems": "systems_4" 85 | }, 86 | "locked": { 87 | "lastModified": 1689068808, 88 | "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", 89 | "owner": "numtide", 90 | "repo": "flake-utils", 91 | "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", 92 | "type": "github" 93 | }, 94 | "original": { 95 | "owner": "numtide", 96 | "repo": "flake-utils", 97 | "type": "github" 98 | } 99 | }, 100 | "nixpkgs": { 101 | "locked": { 102 | "lastModified": 1690179384, 103 | "narHash": "sha256-+arbgqFTAtoeKtepW9wCnA0njCOyoiDFyl0Q0SBSOtE=", 104 | "owner": "NixOS", 105 | "repo": "nixpkgs", 106 | "rev": "b12803b6d90e2e583429bb79b859ca53c348b39a", 107 | "type": "github" 108 | }, 109 | "original": { 110 | "owner": "NixOS", 111 | "ref": "nixos-unstable", 112 | "repo": "nixpkgs", 113 | "type": "github" 114 | } 115 | }, 116 | "nixpkgs_2": { 117 | "locked": { 118 | "lastModified": 1677383253, 119 | "narHash": "sha256-UfpzWfSxkfXHnb4boXZNaKsAcUrZT9Hw+tao1oZxd08=", 120 | "owner": "NixOS", 121 | "repo": "nixpkgs", 122 | "rev": "9952d6bc395f5841262b006fbace8dd7e143b634", 123 | "type": "github" 124 | }, 125 | "original": { 126 | "owner": "NixOS", 127 | "ref": "nixpkgs-unstable", 128 | "repo": "nixpkgs", 129 | "type": "github" 130 | } 131 | }, 132 | "nixpkgs_3": { 133 | "locked": { 134 | "lastModified": 1690387285, 135 | "narHash": "sha256-2nwKSKw48Uh8o+nZk+QDlqmeJpSf43pSwIwg2pCITc4=", 136 | "owner": "NixOS", 137 | "repo": "nixpkgs", 138 | "rev": "f323c3770063a6c6e30817d18287810b0804d3fb", 139 | "type": "github" 140 | }, 141 | "original": { 142 | "owner": "NixOS", 143 | "repo": "nixpkgs", 144 | "type": "github" 145 | } 146 | }, 147 | "root": { 148 | "inputs": { 149 | "android": "android", 150 | "devshell": "devshell_2", 151 | "flake-utils": "flake-utils_2", 152 | "nixpkgs": "nixpkgs_3" 153 | } 154 | }, 155 | "systems": { 156 | "locked": { 157 | "lastModified": 1681028828, 158 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 159 | "owner": "nix-systems", 160 | "repo": "default", 161 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 162 | "type": "github" 163 | }, 164 | "original": { 165 | "owner": "nix-systems", 166 | "repo": "default", 167 | "type": 
"github" 168 | } 169 | }, 170 | "systems_2": { 171 | "locked": { 172 | "lastModified": 1681028828, 173 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 174 | "owner": "nix-systems", 175 | "repo": "default", 176 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 177 | "type": "github" 178 | }, 179 | "original": { 180 | "owner": "nix-systems", 181 | "repo": "default", 182 | "type": "github" 183 | } 184 | }, 185 | "systems_3": { 186 | "locked": { 187 | "lastModified": 1681028828, 188 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 189 | "owner": "nix-systems", 190 | "repo": "default", 191 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 192 | "type": "github" 193 | }, 194 | "original": { 195 | "owner": "nix-systems", 196 | "repo": "default", 197 | "type": "github" 198 | } 199 | }, 200 | "systems_4": { 201 | "locked": { 202 | "lastModified": 1681028828, 203 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 204 | "owner": "nix-systems", 205 | "repo": "default", 206 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 207 | "type": "github" 208 | }, 209 | "original": { 210 | "owner": "nix-systems", 211 | "repo": "default", 212 | "type": "github" 213 | } 214 | } 215 | }, 216 | "root": "root", 217 | "version": 7 218 | } 219 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "android port llama2.c"; 3 | 4 | inputs = { 5 | nixpkgs.url = "github:NixOS/nixpkgs"; 6 | devshell.url = "github:numtide/devshell"; 7 | flake-utils.url = "github:numtide/flake-utils"; 8 | android.url = "github:tadfisher/android-nixpkgs"; 9 | }; 10 | 11 | outputs = { self, nixpkgs, devshell, flake-utils, android }: 12 | { 13 | overlay = final: prev: { 14 | inherit (self.packages.${final.system}) android-sdk android-studio; 15 | }; 16 | } 17 | // 18 | flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: 19 | let 20 | pkgs = import nixpkgs { 21 | inherit system; 22 | config.allowUnfree = true; 23 | overlays = [ 24 | devshell.overlays.default 25 | self.overlay 26 | ]; 27 | }; 28 | in 29 | { 30 | packages = { 31 | android-sdk = android.sdk.${system} (sdkPkgs: with sdkPkgs; [ 32 | build-tools-30-0-2 33 | cmdline-tools-latest 34 | emulator 35 | platform-tools 36 | platforms-android-3 37 | ndk-bundle 38 | ]); 39 | 40 | android-studio = pkgs.androidStudioPackages.stable; 41 | gcc = pkgs.gcc; 42 | wget = pkgs.wget; 43 | }; 44 | 45 | devShell = import ./devshell.nix { inherit pkgs; }; 46 | } 47 | ); 48 | } 49 | -------------------------------------------------------------------------------- /jni/Android.mk: -------------------------------------------------------------------------------- 1 | LOCAL_PATH := $(call my-dir) 2 | include $(CLEAR_VARS) 3 | LOCAL_MODULE := llama2 4 | LOCAL_SRC_FILES := ../run.c 5 | include $(BUILD_EXECUTABLE) -------------------------------------------------------------------------------- /libs/arm64-v8a/llama2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/arm64-v8a/llama2 -------------------------------------------------------------------------------- /libs/armeabi-v7a/llama2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/armeabi-v7a/llama2 -------------------------------------------------------------------------------- /libs/x86/llama2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/x86/llama2 -------------------------------------------------------------------------------- /libs/x86_64/llama2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/x86_64/llama2 -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import struct 3 | import inspect 4 | from dataclasses import dataclass 5 | from typing import Any, Optional, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | @dataclass 13 | class ModelArgs: 14 | dim: int = 4096 15 | n_layers: int = 32 16 | n_heads: int = 32 17 | n_kv_heads: Optional[int] = None 18 | vocab_size: int = -1 # defined later by tokenizer 19 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 20 | norm_eps: float = 1e-5 21 | max_seq_len: int = 2048 22 | dropout: float = 0.0 23 | 24 | 25 | class RMSNorm(torch.nn.Module): 26 | def __init__(self, dim: int, eps: float): 27 | super().__init__() 28 | self.eps = eps 29 | self.weight = nn.Parameter(torch.ones(dim)) 30 | 31 | def _norm(self, x): 32 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 33 | 34 | def forward(self, x): 35 | output = self._norm(x.float()).type_as(x) 36 | return output * self.weight 37 | 38 | 39 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 40 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 41 | t = torch.arange(end, device=freqs.device) # type: ignore 42 | freqs = torch.outer(t, freqs).float() # type: ignore 43 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 44 | return freqs_cis 45 | 46 | 47 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 48 | ndim = x.ndim 49 | assert 0 <= 1 < ndim 50 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 51 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 52 | return freqs_cis.view(*shape) 53 | 54 | 55 | def apply_rotary_emb( 56 | xq: torch.Tensor, 57 | xk: torch.Tensor, 58 | freqs_cis: torch.Tensor, 59 | ) -> Tuple[torch.Tensor, torch.Tensor]: 60 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 61 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 62 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 63 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 64 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 65 | return xq_out.type_as(xq), xk_out.type_as(xk) 66 | 67 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 68 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 69 | bs, slen, n_kv_heads, head_dim = x.shape 70 | if n_rep == 1: 71 | return x 72 | return ( 73 | x[:, :, :, None, :] 74 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 75 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 76 | ) 77 | 78 | class Attention(nn.Module): 79 | def 
__init__(self, args: ModelArgs): 80 | super().__init__() 81 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 82 | model_parallel_size = 1 83 | self.n_local_heads = args.n_heads // model_parallel_size 84 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 85 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 86 | self.head_dim = args.dim // args.n_heads 87 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 88 | self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 89 | self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 90 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) 91 | self.attn_dropout = nn.Dropout(args.dropout) 92 | self.resid_dropout = nn.Dropout(args.dropout) 93 | self.dropout = args.dropout 94 | 95 | # use flash attention or a manual implementation? 96 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 97 | if not self.flash: 98 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") 99 | mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) 100 | mask = torch.triu(mask, diagonal=1) 101 | self.register_buffer("mask", mask) 102 | 103 | def forward( 104 | self, 105 | x: torch.Tensor, 106 | freqs_cis: torch.Tensor, 107 | ): 108 | bsz, seqlen, _ = x.shape 109 | 110 | # QKV 111 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 112 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 113 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 114 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 115 | 116 | # RoPE relative positional embeddings 117 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis) 118 | 119 | # grouped multiquery attention: expand out keys and values 120 | xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 121 | xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 122 | 123 | # make heads into a batch dimension 124 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 125 | xk = xk.transpose(1, 2) 126 | xv = xv.transpose(1, 2) 127 | 128 | # flash implementation 129 | if self.flash: 130 | output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 131 | else: 132 | # manual implementation 133 | scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) 134 | scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) 135 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 136 | scores = self.attn_dropout(scores) 137 | output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) 138 | 139 | # restore time as batch dimension and concat heads 140 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 141 | 142 | # final projection into the residual stream 143 | output = self.wo(output) 144 | output = self.resid_dropout(output) 145 | return output 146 | 147 | 148 | class FeedForward(nn.Module): 149 | def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float): 150 | super().__init__() 151 | hidden_dim = int(2 * hidden_dim / 3) 152 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 153 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 154 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 155 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 156 | 
self.dropout = nn.Dropout(dropout) 157 | 158 | def forward(self, x): 159 | return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) 160 | 161 | 162 | class TransformerBlock(nn.Module): 163 | def __init__(self, layer_id: int, args: ModelArgs): 164 | super().__init__() 165 | self.n_heads = args.n_heads 166 | self.dim = args.dim 167 | self.head_dim = args.dim // args.n_heads 168 | self.attention = Attention(args) 169 | self.feed_forward = FeedForward( 170 | dim=args.dim, 171 | hidden_dim=4 * args.dim, 172 | multiple_of=args.multiple_of, 173 | dropout=args.dropout, 174 | ) 175 | self.layer_id = layer_id 176 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 177 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 178 | 179 | def forward(self, x, freqs_cis): 180 | h = x + self.attention.forward(self.attention_norm(x), freqs_cis) 181 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 182 | return out 183 | 184 | 185 | class Transformer(nn.Module): 186 | def __init__(self, params: ModelArgs): 187 | super().__init__() 188 | self.params = params 189 | self.vocab_size = params.vocab_size 190 | self.n_layers = params.n_layers 191 | 192 | self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) 193 | self.dropout = nn.Dropout(params.dropout) 194 | self.layers = torch.nn.ModuleList() 195 | for layer_id in range(params.n_layers): 196 | self.layers.append(TransformerBlock(layer_id, params)) 197 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 198 | self.output = nn.Linear(params.dim, params.vocab_size, bias=False) 199 | 200 | # share the unembedding parameters with the embedding parameters 201 | self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying 202 | 203 | # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? 
confuse 204 | freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2) 205 | self.register_buffer("freqs_cis", freqs_cis, persistent=False) 206 | 207 | # init all weights 208 | self.apply(self._init_weights) 209 | # apply special scaled init to the residual projections, per GPT-2 paper 210 | for pn, p in self.named_parameters(): 211 | if pn.endswith('w3.weight') or pn.endswith('wo.weight'): 212 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers)) 213 | 214 | def _init_weights(self, module): 215 | if isinstance(module, nn.Linear): 216 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 217 | if module.bias is not None: 218 | torch.nn.init.zeros_(module.bias) 219 | elif isinstance(module, nn.Embedding): 220 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 221 | 222 | def forward(self, tokens, targets=None): 223 | _bsz, seqlen = tokens.shape 224 | h = self.tok_embeddings(tokens) 225 | h = self.dropout(h) 226 | freqs_cis = self.freqs_cis[:seqlen] 227 | 228 | for layer in self.layers: 229 | h = layer(h, freqs_cis) 230 | h = self.norm(h) 231 | 232 | if targets is not None: 233 | # if we are given some desired targets also calculate the loss 234 | logits = self.output(h) 235 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 236 | else: 237 | # inference-time mini-optimization: only forward the output on the very last position 238 | logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim 239 | loss = None 240 | 241 | return logits, loss 242 | 243 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 244 | # start with all of the candidate parameters 245 | param_dict = {pn: p for pn, p in self.named_parameters()} 246 | # filter out those that do not require grad 247 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 248 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 249 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 250 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 251 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 252 | optim_groups = [ 253 | {'params': decay_params, 'weight_decay': weight_decay}, 254 | {'params': nodecay_params, 'weight_decay': 0.0} 255 | ] 256 | num_decay_params = sum(p.numel() for p in decay_params) 257 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 258 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 259 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 260 | # Create AdamW optimizer and use the fused version if it is available 261 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 262 | use_fused = fused_available and device_type == 'cuda' 263 | extra_args = dict(fused=True) if use_fused else dict() 264 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 265 | print(f"using fused AdamW: {use_fused}") 266 | 267 | return optimizer 268 | 269 | def estimate_mfu(self, fwdbwd_per_iter, dt): 270 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 271 | # first estimate the number of flops we do per iteration. 
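# (added note) this is the standard PaLM-style accounting: roughly 6*N FLOPs per token
# for the parameter matmuls (2*N forward + 4*N backward), plus 12*L*H*Q*T for the
# attention score/value matmuls (about 4*H*Q*T per layer forward, times 3 to include
# backward), which grow with sequence length T rather than with parameter count N.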
272 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 273 | N = sum(p.numel() for p in self.parameters()) 274 | cfg = self.params 275 | L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len 276 | flops_per_token = 6*N + 12*L*H*Q*T 277 | flops_per_fwdbwd = flops_per_token * T 278 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 279 | # express our flops throughput as ratio of A100 bfloat16 peak flops 280 | flops_achieved = flops_per_iter * (1.0/dt) # per second 281 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 282 | mfu = flops_achieved / flops_promised 283 | return mfu 284 | 285 | @torch.inference_mode() 286 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 287 | """ 288 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 289 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 290 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 291 | Also note this is a super inefficient version of sampling with no key/value cache. 292 | """ 293 | for _ in range(max_new_tokens): 294 | # if the sequence context is growing too long we must crop it at block_size 295 | idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:] 296 | # forward the model to get the logits for the index in the sequence 297 | logits, _ = self(idx_cond) 298 | logits = logits[:, -1, :] # crop to just the final time step 299 | if temperature == 0.0: 300 | # "sample" the single most likely index 301 | _, idx_next = torch.topk(logits, k=1, dim=-1) 302 | else: 303 | # pluck the logits at the final step and scale by desired temperature 304 | logits = logits / temperature 305 | # optionally crop the logits to only the top k options 306 | if top_k is not None: 307 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 308 | logits[logits < v[:, [-1]]] = -float('Inf') 309 | # apply softmax to convert logits to (normalized) probabilities 310 | probs = F.softmax(logits, dim=-1) 311 | idx_next = torch.multinomial(probs, num_samples=1) 312 | # append sampled index to the running sequence and continue 313 | idx = torch.cat((idx, idx_next), dim=1) 314 | 315 | return idx 316 | 317 | def export(self, filepath='model.bin'): 318 | """export the model weights in fp32 into .bin file to be read from C""" 319 | f = open(filepath, 'wb') 320 | 321 | def serialize(t): 322 | d = t.detach().cpu().view(-1).numpy().astype(np.float32) 323 | b = struct.pack(f'{len(d)}f', *d) 324 | f.write(b) 325 | 326 | # first write out the header 327 | hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0] 328 | p = self.params 329 | n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads 330 | header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, 331 | n_kv_heads, p.vocab_size, p.max_seq_len) 332 | f.write(header) 333 | 334 | # next write out the embedding weights 335 | serialize(self.tok_embeddings.weight) 336 | 337 | # now all the layers 338 | # attention weights 339 | for layer in self.layers: 340 | serialize(layer.attention_norm.weight) 341 | for layer in self.layers: 342 | serialize(layer.attention.wq.weight) 343 | for layer in self.layers: 344 | serialize(layer.attention.wk.weight) 345 | for layer in self.layers: 346 | serialize(layer.attention.wv.weight) 347 | for layer in self.layers: 348 | serialize(layer.attention.wo.weight) 349 | # ffn weights 350 | for layer in self.layers: 351 | 
serialize(layer.ffn_norm.weight) 352 | for layer in self.layers: 353 | serialize(layer.feed_forward.w1.weight) 354 | for layer in self.layers: 355 | serialize(layer.feed_forward.w2.weight) 356 | for layer in self.layers: 357 | serialize(layer.feed_forward.w3.weight) 358 | # final rmsnorm 359 | serialize(self.norm.weight) 360 | # note: no need to write final classifier weights due to weight sharing 361 | # freqs_cis 362 | serialize(self.freqs_cis.real[:p.max_seq_len]) 363 | serialize(self.freqs_cis.imag[:p.max_seq_len]) 364 | 365 | # write to binary file 366 | f.close() 367 | print(f"wrote {filepath}") 368 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | pytest==7.4.0 3 | Requests==2.31.0 4 | sentencepiece==0.1.99 5 | tiktoken==0.3.3 6 | torch==2.0.1 7 | tqdm==4.64.1 8 | wandb==0.15.5 9 | -------------------------------------------------------------------------------- /run.c: -------------------------------------------------------------------------------- 1 | /* 2 | Inference for Llama-2 Transformer model in pure C. 3 | 4 | Example compile: (see README for more details) 5 | $ gcc -O3 -o run run.c -lm 6 | 7 | Then run with: 8 | $ ./run 9 | */ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | // ---------------------------------------------------------------------------- 21 | // Transformer and RunState structs, and related memory management 22 | 23 | typedef struct { 24 | int dim; // transformer dimension 25 | int hidden_dim; // for ffn layers 26 | int n_layers; // number of layers 27 | int n_heads; // number of query heads 28 | int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) 29 | int vocab_size; // vocabulary size, usually 256 (byte-level) 30 | int seq_len; // max sequence length 31 | } Config; 32 | 33 | typedef struct { 34 | // token embedding table 35 | float* token_embedding_table; // (vocab_size, dim) 36 | // weights for rmsnorms 37 | float* rms_att_weight; // (layer, dim) rmsnorm weights 38 | float* rms_ffn_weight; // (layer, dim) 39 | // weights for matmuls 40 | float* wq; // (layer, dim, dim) 41 | float* wk; // (layer, dim, dim) 42 | float* wv; // (layer, dim, dim) 43 | float* wo; // (layer, dim, dim) 44 | // weights for ffn 45 | float* w1; // (layer, hidden_dim, dim) 46 | float* w2; // (layer, dim, hidden_dim) 47 | float* w3; // (layer, hidden_dim, dim) 48 | // final rmsnorm 49 | float* rms_final_weight; // (dim,) 50 | // freq_cis for RoPE relatively positional embeddings 51 | float* freq_cis_real; // (seq_len, dim/2) 52 | float* freq_cis_imag; // (seq_len, dim/2) 53 | // (optional) classifier weights for the logits, on the last layer 54 | float* wcls; 55 | } TransformerWeights; 56 | 57 | typedef struct { 58 | // current wave of activations 59 | float *x; // activation at current time stamp (dim,) 60 | float *xb; // same, but inside a residual branch (dim,) 61 | float *xb2; // an additional buffer just for convenience (dim,) 62 | float *hb; // buffer for hidden dimension in the ffn (hidden_dim,) 63 | float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,) 64 | float *q; // query (dim,) 65 | float *k; // key (dim,) 66 | float *v; // value (dim,) 67 | float *att; // buffer for scores/attention values (n_heads, seq_len) 68 | float *logits; // output logits 69 | // kv cache 70 | float* key_cache; // 
(layer, seq_len, dim) 71 | float* value_cache; // (layer, seq_len, dim) 72 | } RunState; 73 | 74 | void malloc_run_state(RunState* s, Config* p) { 75 | // we calloc instead of malloc to keep valgrind happy 76 | s->x = calloc(p->dim, sizeof(float)); 77 | s->xb = calloc(p->dim, sizeof(float)); 78 | s->xb2 = calloc(p->dim, sizeof(float)); 79 | s->hb = calloc(p->hidden_dim, sizeof(float)); 80 | s->hb2 = calloc(p->hidden_dim, sizeof(float)); 81 | s->q = calloc(p->dim, sizeof(float)); 82 | s->k = calloc(p->dim, sizeof(float)); 83 | s->v = calloc(p->dim, sizeof(float)); 84 | s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); 85 | s->logits = calloc(p->vocab_size, sizeof(float)); 86 | s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); 87 | s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float)); 88 | // ensure all mallocs went fine 89 | if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q 90 | || !s->k || !s->v || !s->att || !s->logits || !s->key_cache 91 | || !s->value_cache) { 92 | printf("malloc failed!\n"); 93 | exit(1); 94 | } 95 | } 96 | 97 | void free_run_state(RunState* s) { 98 | free(s->x); 99 | free(s->xb); 100 | free(s->xb2); 101 | free(s->hb); 102 | free(s->hb2); 103 | free(s->q); 104 | free(s->k); 105 | free(s->v); 106 | free(s->att); 107 | free(s->logits); 108 | free(s->key_cache); 109 | free(s->value_cache); 110 | } 111 | 112 | // ---------------------------------------------------------------------------- 113 | // initialization: read from checkpoint 114 | 115 | void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) { 116 | float* ptr = f; 117 | w->token_embedding_table = ptr; 118 | ptr += p->vocab_size * p->dim; 119 | w->rms_att_weight = ptr; 120 | ptr += p->n_layers * p->dim; 121 | w->wq = ptr; 122 | ptr += p->n_layers * p->dim * p->dim; 123 | w->wk = ptr; 124 | ptr += p->n_layers * p->dim * p->dim; 125 | w->wv = ptr; 126 | ptr += p->n_layers * p->dim * p->dim; 127 | w->wo = ptr; 128 | ptr += p->n_layers * p->dim * p->dim; 129 | w->rms_ffn_weight = ptr; 130 | ptr += p->n_layers * p->dim; 131 | w->w1 = ptr; 132 | ptr += p->n_layers * p->dim * p->hidden_dim; 133 | w->w2 = ptr; 134 | ptr += p->n_layers * p->hidden_dim * p->dim; 135 | w->w3 = ptr; 136 | ptr += p->n_layers * p->dim * p->hidden_dim; 137 | w->rms_final_weight = ptr; 138 | ptr += p->dim; 139 | w->freq_cis_real = ptr; 140 | int head_size = p->dim / p->n_heads; 141 | ptr += p->seq_len * head_size / 2; 142 | w->freq_cis_imag = ptr; 143 | ptr += p->seq_len * head_size / 2; 144 | w->wcls = shared_weights ? 
w->token_embedding_table : ptr; 145 | } 146 | 147 | // ---------------------------------------------------------------------------- 148 | // neural net blocks 149 | 150 | void accum(float *a, float *b, int size) { 151 | for (int i = 0; i < size; i++) { 152 | a[i] += b[i]; 153 | } 154 | } 155 | 156 | void rmsnorm(float* o, float* x, float* weight, int size) { 157 | // calculate sum of squares 158 | float ss = 0.0f; 159 | for (int j = 0; j < size; j++) { 160 | ss += x[j] * x[j]; 161 | } 162 | ss /= size; 163 | ss += 1e-5f; 164 | ss = 1.0f / sqrtf(ss); 165 | // normalize and scale 166 | for (int j = 0; j < size; j++) { 167 | o[j] = weight[j] * (ss * x[j]); 168 | } 169 | } 170 | 171 | void softmax(float* x, int size) { 172 | // find max value (for numerical stability) 173 | float max_val = x[0]; 174 | for (int i = 1; i < size; i++) { 175 | if (x[i] > max_val) { 176 | max_val = x[i]; 177 | } 178 | } 179 | // exp and sum 180 | float sum = 0.0f; 181 | for (int i = 0; i < size; i++) { 182 | x[i] = expf(x[i] - max_val); 183 | sum += x[i]; 184 | } 185 | // normalize 186 | for (int i = 0; i < size; i++) { 187 | x[i] /= sum; 188 | } 189 | } 190 | 191 | void matmul(float* xout, float* x, float* w, int n, int d) { 192 | // W (d,n) @ x (n,) -> xout (d,) 193 | #pragma omp parallel for 194 | for (int i = 0; i < d; i++) { 195 | float val = 0.0f; 196 | for (int j = 0; j < n; j++) { 197 | val += w[i * n + j] * x[j]; 198 | } 199 | xout[i] = val; 200 | } 201 | } 202 | 203 | void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { 204 | 205 | // a few convenience variables 206 | float *x = s->x; 207 | int dim = p->dim; 208 | int hidden_dim = p->hidden_dim; 209 | int head_size = dim / p->n_heads; 210 | 211 | // copy the token embedding into x 212 | float* content_row = &(w->token_embedding_table[token * dim]); 213 | memcpy(x, content_row, dim*sizeof(*x)); 214 | 215 | // pluck out the "pos" row of freq_cis_real and freq_cis_imag 216 | float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2; 217 | float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2; 218 | 219 | // forward all the layers 220 | for(int l = 0; l < p->n_layers; l++) { 221 | 222 | // attention rmsnorm 223 | rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); 224 | 225 | // qkv matmuls for this position 226 | matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); 227 | matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim); 228 | matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim); 229 | 230 | // apply RoPE rotation to the q and k vectors for each head 231 | for (int h = 0; h < p->n_heads; h++) { 232 | // get the q and k vectors for this head 233 | float* q = s->q + h * head_size; 234 | float* k = s->k + h * head_size; 235 | // rotate q and k by the freq_cis_real and freq_cis_imag 236 | for (int i = 0; i < head_size; i+=2) { 237 | float q0 = q[i]; 238 | float q1 = q[i+1]; 239 | float k0 = k[i]; 240 | float k1 = k[i+1]; 241 | float fcr = freq_cis_real_row[i/2]; 242 | float fci = freq_cis_imag_row[i/2]; 243 | q[i] = q0 * fcr - q1 * fci; 244 | q[i+1] = q0 * fci + q1 * fcr; 245 | k[i] = k0 * fcr - k1 * fci; 246 | k[i+1] = k0 * fci + k1 * fcr; 247 | } 248 | } 249 | 250 | // save key,value at this time step (pos) to our kv cache 251 | int loff = l * p->seq_len * dim; // kv cache layer offset for convenience 252 | float* key_cache_row = s->key_cache + loff + pos * dim; 253 | float* value_cache_row = s->value_cache + loff + pos * dim; 254 | memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row)); 255 | 
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row)); 256 | 257 | // multihead attention. iterate over all heads 258 | #pragma omp parallel for 259 | for (int h = 0; h < p->n_heads; h++) { 260 | // get the query vector for this head 261 | float* q = s->q + h * head_size; 262 | // attention scores for this head 263 | float* att = s->att + h * p->seq_len; 264 | // iterate over all timesteps, including the current one 265 | for (int t = 0; t <= pos; t++) { 266 | // get the key vector for this head and at this timestep 267 | float* k = s->key_cache + loff + t * dim + h * head_size; 268 | // calculate the attention score as the dot product of q and k 269 | float score = 0.0f; 270 | for (int i = 0; i < head_size; i++) { 271 | score += q[i] * k[i]; 272 | } 273 | score /= sqrtf(head_size); 274 | // save the score to the attention buffer 275 | att[t] = score; 276 | } 277 | 278 | // softmax the scores to get attention weights, from 0..pos inclusively 279 | softmax(att, pos + 1); 280 | 281 | // weighted sum of the values, store back into xb 282 | for (int i = 0; i < head_size; i++) { 283 | float val = 0.0f; 284 | for (int t = 0; t <= pos; t++) { 285 | val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality 286 | } 287 | s->xb[h * head_size + i] = val; 288 | } 289 | } 290 | 291 | // final matmul to get the output of the attention 292 | matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim); 293 | 294 | // residual connection back into x 295 | accum(x, s->xb2, dim); 296 | 297 | // ffn rmsnorm 298 | rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim); 299 | 300 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) 301 | // first calculate self.w1(x) and self.w3(x) 302 | matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim); 303 | matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim); 304 | 305 | // F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid 306 | for (int i = 0; i < hidden_dim; i++) { 307 | s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i]))); 308 | } 309 | 310 | // elementwise multiply with w3(x) 311 | for (int i = 0; i < hidden_dim; i++) { 312 | s->hb[i] = s->hb[i] * s->hb2[i]; 313 | } 314 | 315 | // final matmul to get the output of the ffn 316 | matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim); 317 | 318 | // residual connection 319 | accum(x, s->xb, dim); 320 | } 321 | 322 | // final rmsnorm 323 | rmsnorm(x, x, w->rms_final_weight, dim); 324 | 325 | // classifier into logits 326 | matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); 327 | } 328 | 329 | int sample(float* probabilities, int n) { 330 | // sample index from probabilities, they must sum to 1 331 | float r = (float)rand() / (float)RAND_MAX; 332 | float cdf = 0.0f; 333 | for (int i = 0; i < n; i++) { 334 | cdf += probabilities[i]; 335 | if (r < cdf) { 336 | return i; 337 | } 338 | } 339 | return n - 1; // in case of rounding errors 340 | } 341 | 342 | int argmax(float* v, int n) { 343 | // return argmax of v in elements 0..n 344 | int max_i = 0; 345 | float max_p = v[0]; 346 | for (int i = 1; i < n; i++) { 347 | if (v[i] > max_p) { 348 | max_i = i; 349 | max_p = v[i]; 350 | } 351 | } 352 | return max_i; 353 | } 354 | 355 | // ---------------------------------------------------------------------------- 356 | 357 | // long time_in_ms() { 358 | // struct timespec time; 359 | // // Get the current time with nanosecond precision 360 | // if (clock_gettime(CLOCK_REALTIME, &time) == 0) { 361 | // return time.tv_sec * 1000 + 
time.tv_nsec / 1000000; 362 | // } else { 363 | // perror("clock_gettime"); 364 | // return -1; // Return -1 to indicate an error 365 | // } 366 | // } 367 | 368 | long time_in_ms() { 369 | struct timeval time; 370 | // Get the current time with microsecond precision 371 | if (gettimeofday(&time, NULL) == 0) { 372 | return time.tv_sec * 1000 + time.tv_usec / 1000; 373 | } else { 374 | perror("gettimeofday"); 375 | return -1; // Return -1 to indicate an error 376 | } 377 | } 378 | 379 | int main(int argc, char *argv[]) { 380 | 381 | // poor man's C argparse 382 | char *checkpoint = NULL; // e.g. out/model.bin 383 | float temperature = 0.9f; // e.g. 1.0, or 0.0 384 | int steps = 256; // max number of steps to run for, 0: use seq_len 385 | // 'checkpoint' is necessary arg 386 | if (argc < 2) { 387 | printf("Usage: %s [temperature] [steps]\n", argv[0]); 388 | return 1; 389 | } 390 | if (argc >= 2) { 391 | checkpoint = argv[1]; 392 | } 393 | if (argc >= 3) { 394 | // optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline 395 | temperature = atof(argv[2]); 396 | } 397 | if (argc >= 4) { 398 | steps = atoi(argv[3]); 399 | } 400 | 401 | // seed rng with time. if you want deterministic behavior use temperature 0.0 402 | srand((unsigned int)time(NULL)); 403 | 404 | // read in the model.bin file 405 | Config config; 406 | TransformerWeights weights; 407 | int fd = 0; 408 | float* data = NULL; 409 | long file_size; 410 | { 411 | FILE *file = fopen(checkpoint, "rb"); 412 | if (!file) { 413 | printf("Unable to open the checkpoint file %s!\n", checkpoint); 414 | return 1; 415 | } 416 | // read in the config header 417 | if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; } 418 | // negative vocab size is hacky way of signaling unshared weights. bit yikes. 419 | int shared_weights = config.vocab_size > 0 ? 1 : 0; 420 | config.vocab_size = abs(config.vocab_size); 421 | // figure out the file size 422 | fseek(file, 0, SEEK_END); // move file pointer to end of file 423 | file_size = ftell(file); // get the file size, in bytes 424 | fclose(file); 425 | // memory map the Transformer weights into the data pointer 426 | fd = open(checkpoint, O_RDONLY); // open in read only mode 427 | if (fd == -1) { printf("open failed!\n"); return 1; } 428 | data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0); 429 | if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; } 430 | float* weights_ptr = data + sizeof(Config)/sizeof(float); 431 | checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights); 432 | } 433 | // right now we cannot run for more than config.seq_len steps 434 | if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } 435 | 436 | // read in the tokenizer.bin file 437 | char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); 438 | { 439 | FILE *file = fopen("tokenizer.bin", "rb"); 440 | if (!file) { 441 | printf("Unable to open the tokenizer file tokenizer.bin! 
Run " 442 | "python tokenizer.py to convert tokenizer.model -> tokenizer.bin\n"); 443 | return 1; 444 | } 445 | int len; 446 | for (int i = 0; i < config.vocab_size; i++) { 447 | if(fread(&len, sizeof(int), 1, file) != 1) { return 1; } 448 | vocab[i] = (char *)malloc(len + 1); 449 | if(fread(vocab[i], len, 1, file) != 1) { return 1; } 450 | vocab[i][len] = '\0'; // add the string terminating token 451 | } 452 | fclose(file); 453 | } 454 | 455 | // create and init the application RunState 456 | RunState state; 457 | malloc_run_state(&state, &config); 458 | 459 | // the current position we are in 460 | long start = time_in_ms(); 461 | int next; 462 | int token = 1; // 1 = BOS token in Llama-2 sentencepiece 463 | int pos = 0; 464 | printf("\n"); // explicit print the initial BOS token (=1), stylistically symmetric 465 | while (pos < steps) { 466 | 467 | // forward the transformer to get logits for the next token 468 | transformer(token, pos, &config, &state, &weights); 469 | 470 | // sample the next token 471 | if(temperature == 0.0f) { 472 | // greedy argmax sampling 473 | next = argmax(state.logits, config.vocab_size); 474 | } else { 475 | // apply the temperature to the logits 476 | for (int q=0; q" or etc. Can also specify a file, use as: "FILE:prompt.txt" 15 | num_samples = 1 # number of samples to draw 16 | max_new_tokens = 100 # number of tokens generated in each sample 17 | temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 18 | top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability 19 | seed = 1337 20 | device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 21 | #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 22 | dtype = "float32" 23 | compile = False # use PyTorch 2.0 to compile the model to be faster 24 | exec(open('configurator.py').read()) # overrides from command line or config file 25 | # ----------------------------------------------------------------------------- 26 | 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 30 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 31 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 32 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 33 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 34 | 35 | # init from a model saved in a specific directory 36 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 37 | checkpoint = torch.load(ckpt_path, map_location=device) 38 | gptconf = ModelArgs(**checkpoint['model_args']) 39 | model = Transformer(gptconf) 40 | state_dict = checkpoint['model'] 41 | unwanted_prefix = '_orig_mod.' 
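# (added note) torch.compile() wraps the model in an OptimizedModule whose state_dict keys
# are all prefixed with '_orig_mod.'; stripping that prefix here lets a checkpoint saved from
# a compiled model load into the plain (uncompiled) Transformer constructed above.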
42 | for k,v in list(state_dict.items()): 43 | if k.startswith(unwanted_prefix): 44 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 45 | model.load_state_dict(state_dict, strict=False) 46 | 47 | model.eval() 48 | model.to(device) 49 | if compile: 50 | print("Compiling the model...") 51 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 52 | 53 | # load the tokenizer 54 | enc = Tokenizer() 55 | 56 | # encode the beginning of the prompt 57 | if start.startswith('FILE:'): 58 | with open(start[5:], 'r', encoding='utf-8') as f: 59 | start = f.read() 60 | start_ids = enc.encode(start, bos=True, eos=False) 61 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) 62 | 63 | # run generation 64 | with torch.no_grad(): 65 | with ctx: 66 | for k in range(num_samples): 67 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) 68 | print(enc.decode(y[0].tolist())) 69 | print('---------------') 70 | -------------------------------------------------------------------------------- /test_all.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run simply with 3 | $ pytest 4 | """ 5 | import os 6 | import pytest # pip install pytest 7 | import subprocess 8 | 9 | import torch 10 | from model import ModelArgs, Transformer 11 | 12 | def test_argmax_inference(): 13 | """ 14 | Only the simplest test for now: run inference with temperature 0 15 | (for determinism) in both C and PyTorch, and see that the sampled tokens 16 | are the same. 17 | """ 18 | test_ckpt_dir = "out" # TODO create a dummy test checkpoint for this? 19 | 20 | # run C version 21 | model_path = os.path.join(test_ckpt_dir, "model.bin") 22 | command = ["./run", model_path, "0.0"] 23 | proc = subprocess.Popen(command, stdout=subprocess.PIPE) 24 | c_tokens = [] 25 | for line in proc.stdout: 26 | token = int(line.decode('utf-8').strip()) 27 | c_tokens.append(token) 28 | proc.wait() 29 | #print(c_tokens) 30 | 31 | # run PyTorch version 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | ckpt_path = os.path.join(test_ckpt_dir, "ckpt.pt") 34 | checkpoint = torch.load(ckpt_path, map_location=device) 35 | gptconf = ModelArgs(**checkpoint['model_args']) 36 | model = Transformer(gptconf) 37 | state_dict = checkpoint['model'] 38 | unwanted_prefix = '_orig_mod.' 39 | for k,v in list(state_dict.items()): 40 | if k.startswith(unwanted_prefix): 41 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 42 | model.load_state_dict(state_dict, strict=False) 43 | model.eval() 44 | model.to(device) 45 | x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS 46 | with torch.inference_mode(): 47 | y = model.generate(x, max_new_tokens=gptconf.max_seq_len, temperature=0.0) 48 | pt_tokens = y[0].tolist() 49 | pt_tokens = pt_tokens[1:] # remove BOS 50 | #print(pt_tokens) 51 | 52 | # compare 53 | assert c_tokens == pt_tokens 54 | -------------------------------------------------------------------------------- /tinystories.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download, preprocess and serve the TinyStories dataset as a DataLoader. 
3 | """ 4 | 5 | import argparse 6 | import glob 7 | import json 8 | import os 9 | import random 10 | from typing import List 11 | from concurrent.futures import ThreadPoolExecutor, as_completed 12 | 13 | import numpy as np 14 | import requests 15 | import torch 16 | import torch.distributed as dist 17 | from tqdm import tqdm 18 | 19 | from tokenizer import Tokenizer 20 | 21 | DATA_CACHE_DIR = "data" 22 | 23 | def download_file(url: str, fname: str, chunk_size=1024): 24 | """Helper function to download a file from a given url""" 25 | resp = requests.get(url, stream=True) 26 | total = int(resp.headers.get("content-length", 0)) 27 | with open(fname, "wb") as file, tqdm( 28 | desc=fname, 29 | total=total, 30 | unit="iB", 31 | unit_scale=True, 32 | unit_divisor=1024, 33 | ) as bar: 34 | for data in resp.iter_content(chunk_size=chunk_size): 35 | size = file.write(data) 36 | bar.update(size) 37 | 38 | 39 | def download(): 40 | """Downloads the dataset to disk.""" 41 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 42 | 43 | # download the TinyStories dataset, unless it's already downloaded 44 | data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" 45 | data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz") 46 | if not os.path.exists(data_filename): 47 | print(f"Downloading {data_url} to {data_filename}...") 48 | download_file(data_url, data_filename) 49 | else: 50 | print(f"{data_filename} already exists, skipping download...") 51 | 52 | # unpack the tar.gz file into all the data shards (json files) 53 | data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data") 54 | if not os.path.exists(data_dir): 55 | os.makedirs(data_dir, exist_ok=True) 56 | print(f"Unpacking {data_filename}...") 57 | os.system(f"tar -xzf {data_filename} -C {data_dir}") 58 | else: 59 | print(f"{data_dir} already exists, skipping unpacking...") 60 | 61 | # print a single example just for debugging and such 62 | shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json"))) 63 | with open(shard_filenames[0], "r") as f: 64 | data = json.load(f) 65 | print("Download done.") 66 | print(f"Number of shards: {len(shard_filenames)}") 67 | print(f"Example story:\n{data[0]}") 68 | 69 | def pretokenize(): 70 | enc = Tokenizer() 71 | 72 | def process_shard(shard): 73 | with open(shard, "r") as f: 74 | data = json.load(f) 75 | all_tokens = [] 76 | for example in tqdm(data): 77 | text = example["story"] 78 | text = text.strip() # get rid of leading/trailing whitespace 79 | tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS 80 | all_tokens.extend(tokens) 81 | # convert to uint16 nparray 82 | all_tokens = np.array(all_tokens, dtype=np.uint16) 83 | # write to disk 84 | tokenized_filename = shard.replace(".json", ".bin") 85 | with open(tokenized_filename, "wb") as f: 86 | f.write(all_tokens.tobytes()) 87 | print(f"Saved {tokenized_filename}") 88 | 89 | # iterate the shards and tokenize all of them one by one 90 | data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data") 91 | shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json"))) 92 | 93 | # process all the shards in a threadpool 94 | with ThreadPoolExecutor(max_workers=8) as executor: 95 | executor.map(process_shard, shard_filenames) 96 | 97 | print("Done.") 98 | 99 | 100 | class PretokDataset(torch.utils.data.IterableDataset): 101 | """Loads pretokenized examples from disk and yields them as PyTorch tensors.""" 102 | 103 | def __init__(self, split, max_seq_len): 104 | 
super().__init__() 105 | self.split = split 106 | self.max_seq_len = max_seq_len 107 | 108 | def __iter__(self): 109 | # get worker info within a DataLoader 110 | worker_info = torch.utils.data.get_worker_info() 111 | worker_id = worker_info.id if worker_info else 0 112 | # get DDP rank info 113 | rank = dist.get_rank() if dist.is_initialized() else 0 114 | # combine the worker_id and worker_rank to create a unique seed for rng 115 | seed = 42 + worker_id + 1337 * rank 116 | rng = random.Random(seed) 117 | print(f"Created a PretokDataset with rng seed {seed}") 118 | data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data") 119 | shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin"))) 120 | # train/test split. let's use only shard 0 for test split, rest train 121 | shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1] 122 | while True: 123 | rng.shuffle(shard_filenames) 124 | for shard in shard_filenames: 125 | # open the dataset for reading but keep it on disk with memmap 126 | m = np.memmap(shard, dtype=np.uint16, mode="r") 127 | num_batches = len(m) // self.max_seq_len 128 | num_batches -= 1 # drop the last partial batch 129 | assert num_batches > 0, "this shard is way too small? investigate." 130 | ixs = list(range(num_batches)) 131 | rng.shuffle(ixs) 132 | for ix in ixs: 133 | start = ix * self.max_seq_len 134 | end = start + self.max_seq_len + 1 135 | # calling .astype will copy the data into a new numpy array, now in RAM 136 | chunk = torch.from_numpy((m[start:end]).astype(np.int64)) 137 | x = chunk[:-1] 138 | y = chunk[1:] 139 | yield x, y 140 | 141 | 142 | class Task: 143 | 144 | @staticmethod 145 | def iter_batches(split, batch_size, max_seq_len, device, num_workers=0): 146 | ds = PretokDataset(split, max_seq_len) 147 | dl = torch.utils.data.DataLoader( 148 | ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers 149 | ) 150 | for x, y in dl: 151 | x = x.to(device, non_blocking=True) 152 | y = y.to(device, non_blocking=True) 153 | yield x, y 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("stage", type=str, choices=["download", "train_tokenizer", "pretokenize"]) 159 | args = parser.parse_args() 160 | 161 | # depending on the stage call the appropriate function 162 | fun = { 163 | "download": download, 164 | "pretokenize": pretokenize, 165 | } 166 | fun[args.stage]() -------------------------------------------------------------------------------- /tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/tokenizer.model -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | # Taken from llama code and lightly modified 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
4 | 5 | import os 6 | from logging import getLogger 7 | from typing import List 8 | 9 | from sentencepiece import SentencePieceProcessor 10 | 11 | TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model 12 | TOKENIZER_BIN = "tokenizer.bin" # binary version of the tokenizer for inference in C 13 | 14 | class Tokenizer: 15 | def __init__(self): 16 | model_path = TOKENIZER_MODEL 17 | assert os.path.isfile(model_path), model_path 18 | self.sp_model = SentencePieceProcessor(model_file=model_path) 19 | #print(f"Loaded SentencePiece model from {model_path}") 20 | 21 | # BOS / EOS token IDs 22 | self.n_words: int = self.sp_model.vocab_size() 23 | self.bos_id: int = self.sp_model.bos_id() 24 | self.eos_id: int = self.sp_model.eos_id() 25 | self.pad_id: int = self.sp_model.pad_id() 26 | #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") 27 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 28 | 29 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 30 | assert type(s) is str 31 | t = self.sp_model.encode(s) 32 | if bos: 33 | t = [self.bos_id] + t 34 | if eos: 35 | t = t + [self.eos_id] 36 | return t 37 | 38 | def decode(self, t: List[int]) -> str: 39 | return self.sp_model.decode(t) 40 | 41 | def export(self): 42 | tokens = [] 43 | for i in range(self.n_words): 44 | 45 | # decode the token and light postprocessing 46 | t = self.sp_model.id_to_piece(i) 47 | if i == self.bos_id: 48 | t = '\n\n' 49 | elif i == self.eos_id: 50 | t = '\n\n' 51 | elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'): 52 | t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01' 53 | t = t.replace('▁', ' ') # sentencepiece uses this as the whitespace 54 | 55 | tokens.append(t) 56 | 57 | with open(TOKENIZER_BIN, 'wb') as f: 58 | for token in tokens: 59 | bytes = token.encode('utf-8') 60 | f.write((len(bytes)).to_bytes(4, 'little')) # write length of bytes 61 | f.write(bytes) # write token bytes 62 | 63 | if __name__ == "__main__": 64 | t = Tokenizer() 65 | t.export() 66 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
4 | 5 | To run on a single GPU small debug run, example: 6 | $ python -m train.py --compile=False --eval_iters=10 --batch_size=8 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 4 gpus across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import math 20 | import os 21 | import time 22 | from contextlib import nullcontext 23 | from datetime import datetime 24 | from functools import partial 25 | 26 | import torch 27 | from model import Transformer, ModelArgs 28 | from torch.distributed import destroy_process_group, init_process_group 29 | from torch.nn.parallel import DistributedDataParallel as DDP 30 | 31 | from tinystories import Task 32 | 33 | # ----------------------------------------------------------------------------- 34 | # I/O 35 | out_dir = "out" 36 | eval_interval = 2000 37 | log_interval = 1 38 | eval_iters = 100 39 | eval_only = False # if True, script exits right after the first eval 40 | always_save_checkpoint = False # if True, always save a checkpoint after each eval 41 | init_from = "scratch" # 'scratch' or 'resume' 42 | # wandb logging 43 | wandb_log = False # disabled by default 44 | wandb_project = "llamac" 45 | wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 46 | # data 47 | batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size 48 | max_seq_len = 256 49 | # model 50 | dim = 288 51 | n_layers = 6 52 | n_heads = 6 53 | multiple_of = 32 54 | dropout = 0.0 55 | # adamw optimizer 56 | gradient_accumulation_steps = 4 # used to simulate larger batch sizes 57 | learning_rate = 5e-4 # max learning rate 58 | max_iters = 100000 # total number of training iterations 59 | weight_decay = 1e-1 60 | beta1 = 0.9 61 | beta2 = 0.95 62 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 63 | # learning rate decay settings 64 | decay_lr = True # whether to decay the learning rate 65 | warmup_iters = 1000 # how many steps to warm up for 66 | # system 67 | device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 68 | dtype = "bfloat16" # float32|bfloat16|float16 69 | compile = True # use PyTorch 2.0 to compile the model to be faster 70 | # ----------------------------------------------------------------------------- 71 | config_keys = [ 72 | k 73 | for k, v in globals().items() 74 | if not k.startswith("_") and isinstance(v, (int, float, bool, str)) 75 | ] 76 | exec(open("configurator.py").read()) # overrides from command line or config file 77 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 78 | # ----------------------------------------------------------------------------- 79 | 80 | # fixing some hyperparams to sensible defaults 81 | lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla 82 | min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 83 | 84 | # various inits, derived attributes, I/O setup 85 | ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run? 
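# RANK, LOCAL_RANK and WORLD_SIZE are environment variables set per process by the
# torchrun launcher, so their presence is what distinguishes a distributed launch
# from a plain `python train.py` invocation; the else-branch below falls back to a
# single-process, single-GPU configuration.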
86 | if ddp: 87 | init_process_group(backend="nccl") 88 | ddp_rank = int(os.environ["RANK"]) 89 | ddp_local_rank = int(os.environ["LOCAL_RANK"]) 90 | ddp_world_size = int(os.environ["WORLD_SIZE"]) 91 | device = f"cuda:{ddp_local_rank}" 92 | torch.cuda.set_device(device) 93 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 94 | seed_offset = ddp_rank # each process gets a different seed 95 | # world_size number of processes will be training simultaneously, so we can scale 96 | # down the desired gradient accumulation iterations per process proportionally 97 | assert gradient_accumulation_steps % ddp_world_size == 0 98 | gradient_accumulation_steps //= ddp_world_size 99 | else: 100 | # if not ddp, we are running on a single gpu, and one process 101 | master_process = True 102 | seed_offset = 0 103 | ddp_world_size = 1 104 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len 105 | if master_process: 106 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 107 | print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len") 108 | 109 | if master_process: 110 | os.makedirs(out_dir, exist_ok=True) 111 | torch.manual_seed(1337 + seed_offset) 112 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 113 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 114 | device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast 115 | # note: float16 data type will automatically use a GradScaler 116 | ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] 117 | ctx = ( 118 | nullcontext() 119 | if device_type == "cpu" 120 | else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 121 | ) 122 | 123 | # task-specific setup 124 | iter_batches = partial( 125 | Task.iter_batches, 126 | batch_size=batch_size, 127 | max_seq_len=max_seq_len, 128 | device=device, 129 | num_workers=0, 130 | ) 131 | 132 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 133 | iter_num = 0 134 | best_val_loss = 1e9 135 | 136 | # model init 137 | model_args = dict( 138 | dim=dim, 139 | n_layers=n_layers, 140 | n_heads=n_heads, 141 | n_kv_heads=n_heads, 142 | vocab_size=32000, 143 | multiple_of=multiple_of, 144 | max_seq_len=max_seq_len, 145 | #dropout=dropout, 146 | ) # start with model_args from command line 147 | if init_from == "scratch": 148 | # init a new model from scratch 149 | print("Initializing a new model from scratch") 150 | gptconf = ModelArgs(**model_args) 151 | model = Transformer(gptconf) 152 | elif init_from == "resume": 153 | print(f"Resuming training from {out_dir}") 154 | # resume training from a checkpoint. 155 | ckpt_path = os.path.join(out_dir, "ckpt.pt") 156 | checkpoint = torch.load(ckpt_path, map_location=device) 157 | checkpoint_model_args = checkpoint["model_args"] 158 | # force these config attributes to be equal otherwise we can't even resume training 159 | # the rest of the attributes (e.g. 
dropout) can stay as desired from command line 160 | for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]: 161 | model_args[k] = checkpoint_model_args[k] 162 | # create the model 163 | gptconf = ModelArgs(**model_args) 164 | model = Transformer(gptconf) 165 | state_dict = checkpoint["model"] 166 | # fix the keys of the state dictionary :( 167 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 168 | unwanted_prefix = "_orig_mod." 169 | for k, v in list(state_dict.items()): 170 | if k.startswith(unwanted_prefix): 171 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 172 | model.load_state_dict(state_dict) 173 | iter_num = checkpoint["iter_num"] 174 | best_val_loss = checkpoint["best_val_loss"] 175 | model.to(device) 176 | 177 | # initialize a GradScaler. If enabled=False scaler is a no-op 178 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) 179 | 180 | # optimizer 181 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 182 | if init_from == "resume": 183 | optimizer.load_state_dict(checkpoint["optimizer"]) 184 | checkpoint = None # free up memory 185 | 186 | # compile the model 187 | if compile: 188 | print("compiling the model... (takes a ~minute)") 189 | unoptimized_model = model 190 | model = torch.compile(model) # requires PyTorch 2.0 191 | 192 | # wrap model into DDP container 193 | if ddp: 194 | # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at 195 | # construction time since NCCL does not support `ComplexFloat` 196 | prefix = "_orig_mod." if compile else "" 197 | model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"} 198 | model = DDP(model, device_ids=[ddp_local_rank]) 199 | 200 | # helps estimate an arbitrarily accurate loss over either split using many batches 201 | @torch.no_grad() 202 | def estimate_loss(): 203 | out = {} 204 | model.eval() 205 | for split in ["train", "val"]: 206 | batch_iter = iter_batches(split) 207 | losses = torch.zeros(eval_iters) # keep on CPU 208 | for k in range(eval_iters): 209 | X, Y = next(batch_iter) 210 | with ctx: 211 | logits, loss = model(X, Y) 212 | losses[k] = loss.item() 213 | out[split] = losses.mean() 214 | model.train() 215 | return out 216 | 217 | # learning rate decay scheduler (cosine with warmup) 218 | def get_lr(it): 219 | # 1) linear warmup for warmup_iters steps 220 | if it < warmup_iters: 221 | return learning_rate * it / warmup_iters 222 | # 2) if it > lr_decay_iters, return min learning rate 223 | if it > lr_decay_iters: 224 | return min_lr 225 | # 3) in between, use cosine decay down to min learning rate 226 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 227 | assert 0 <= decay_ratio <= 1 228 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 229 | return min_lr + coeff * (learning_rate - min_lr) 230 | 231 | # logging 232 | if wandb_log and master_process: 233 | import wandb 234 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 235 | 236 | # training loop 237 | train_batch_iter = iter_batches("train") 238 | X, Y = next(train_batch_iter) # fetch the very first batch 239 | t0 = time.time() 240 | local_iter_num = 0 # number of iterations in the lifetime of this process 241 | raw_model = model.module if ddp else model # unwrap DDP container if needed 242 | running_mfu = -1.0 243 | while True: 244 | # determine and set the learning rate for this iteration 245 | lr = get_lr(iter_num) if 
decay_lr else learning_rate 246 | for param_group in optimizer.param_groups: 247 | param_group["lr"] = lr 248 | 249 | # evaluate the loss on train/val sets and write checkpoints 250 | if iter_num % eval_interval == 0 and master_process: 251 | losses = estimate_loss() 252 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 253 | if wandb_log: 254 | try: 255 | wandb.log( 256 | { 257 | "iter": iter_num, 258 | "tokens": iter_num * tokens_per_iter, 259 | "loss/train": losses["train"], 260 | "loss/val": losses["val"], 261 | "lr": lr, 262 | "mfu": running_mfu * 100, # convert to percentage 263 | } 264 | ) 265 | except Exception as e: 266 | print(f"logging to wandb failed: {e}") 267 | if losses["val"] < best_val_loss or always_save_checkpoint: 268 | best_val_loss = losses["val"] 269 | if iter_num > 0: 270 | checkpoint = { 271 | "model": raw_model.state_dict(), 272 | "optimizer": optimizer.state_dict(), 273 | "model_args": model_args, 274 | "iter_num": iter_num, 275 | "best_val_loss": best_val_loss, 276 | "config": config, 277 | } 278 | print(f"saving checkpoint to {out_dir}") 279 | torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt")) 280 | raw_model.export(os.path.join(out_dir, "model.bin")) 281 | if iter_num == 0 and eval_only: 282 | break 283 | 284 | # forward backward update, with optional gradient accumulation to simulate larger batch size 285 | # and using the GradScaler if data type is float16 286 | for micro_step in range(gradient_accumulation_steps): 287 | if ddp: 288 | # in DDP training we only need to sync gradients at the last micro step. 289 | # the official way to do this is with model.no_sync() context manager, but 290 | # I really dislike that this bloats the code and forces us to repeat code 291 | # looking at the source of that context manager, it just toggles this variable 292 | model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 293 | with ctx: 294 | logits, loss = model(X, Y) 295 | loss = loss / gradient_accumulation_steps 296 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 297 | X, Y = next(train_batch_iter) 298 | # backward pass, with gradient scaling if training in fp16 299 | scaler.scale(loss).backward() 300 | # clip the gradient 301 | if grad_clip != 0.0: 302 | scaler.unscale_(optimizer) 303 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 304 | # step the optimizer and scaler if training in fp16 305 | scaler.step(optimizer) 306 | scaler.update() 307 | # flush the gradients as soon as we can, no need for this memory anymore 308 | optimizer.zero_grad(set_to_none=True) 309 | 310 | # timing and logging 311 | t1 = time.time() 312 | dt = t1 - t0 313 | t0 = t1 314 | if iter_num % log_interval == 0 and master_process: 315 | # get loss as float, scale up due to the divide above. 
note: this is a CPU-GPU sync point 316 | lossf = loss.item() * gradient_accumulation_steps 317 | if local_iter_num >= 5: # let the training loop settle a bit 318 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 319 | running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu 320 | print( 321 | f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%" 322 | ) 323 | iter_num += 1 324 | local_iter_num += 1 325 | 326 | # termination conditions 327 | if iter_num > max_iters: 328 | break 329 | 330 | if ddp: 331 | destroy_process_group() 332 | --------------------------------------------------------------------------------
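As a quick sanity check of the cosine-with-warmup schedule implemented by get_lr() in train.py, the standalone sketch below re-implements the same formula with the default hyperparameters from the config block above (learning_rate = 5e-4, warmup_iters = 1000, lr_decay_iters = max_iters = 100000, min_lr = 0.0) and prints the learning rate at a few representative iterations. The sampled iteration numbers are chosen only for illustration and are not part of the repository.

import math

# defaults copied from the config block of train.py
learning_rate = 5e-4     # max learning rate
warmup_iters = 1000      # linear warmup length
lr_decay_iters = 100000  # equal to max_iters, per the Chinchilla note
min_lr = 0.0             # minimum learning rate

def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) past lr_decay_iters, hold at the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, cosine decay from learning_rate down to min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # decays from 1 to 0
    return min_lr + coeff * (learning_rate - min_lr)

if __name__ == "__main__":
    # expected: 0.0, 2.5e-4 (mid-warmup), 5e-4 (end of warmup), 2.5e-4 (halfway), 0.0 (fully decayed)
    for it in (0, 500, 1000, 50500, 100000):
        print(f"iter {it:>6}: lr = {get_lr(it):.6f}")

Running it shows the ramp from 0 up to the peak learning rate over the first 1000 iterations, followed by a half-cosine decay back to 0 at max_iters.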