├── .envrc
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── assets
│   ├── llama2_in_termux.gif
│   └── llama_android.png
├── configurator.py
├── devshell.nix
├── export_meta_llama_bin.py
├── flake.lock
├── flake.nix
├── jni
│   └── Android.mk
├── libs
│   ├── arm64-v8a
│   │   └── llama2
│   ├── armeabi-v7a
│   │   └── llama2
│   ├── x86
│   │   └── llama2
│   └── x86_64
│       └── llama2
├── model.py
├── requirements.txt
├── run.c
├── sample.py
├── test_all.py
├── tinystories.py
├── tokenizer.bin
├── tokenizer.model
├── tokenizer.py
└── train.py
/.envrc:
--------------------------------------------------------------------------------
1 | use flake
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .direnv/
2 | out/
3 | obj/
4 | run
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Andrej
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Change this to wherever you keep the NDK
2 | NDK = /home/manuel/Android/Sdk/ndk/25.2.9519653
3 | SRCDIR = .
4 | OBJDIR = .
5 | DBG ?= 0
6 |
7 | # Debug/Release configuration
8 | ifeq ($(DBG),1)
9 | MODE_FLAGS = -DDEBUG -g -O0
10 | else
11 | MODE_FLAGS = -Os -fdata-sections -ffunction-sections
12 | endif
13 |
14 | ## NDK configuration (clang)
15 |
16 | # Android target platform version (API level)
17 | NDK_TARGETVER = 27
18 |
19 | # Target arch - here aarch64 for android
20 | NDK_TARGETARCH = aarch64-linux-android
21 |
22 | # Target CPU (ARMv8)
23 | NDK_TARGETSHORTARCH = arm64
24 |
25 | # Toolchain version
26 | NDK_TOOLVER = 4.9
27 |
28 | # Architecture of a machine that does cross compilation
29 | NDK_HOSTARCH = linux-x86_64
30 |
31 | # Set needed preprocessor symbols
32 | NDK_TOOLS = $(NDK)/toolchains/llvm/prebuilt/$(NDK_HOSTARCH)/bin
33 | NDK_SYSROOT = $(NDK)/sysroot
34 | # NDK_TOOL = $(CLANG_PATH)/bin/clang
35 | NDK_TOOL = $(NDK_TOOLS)/clang-14
36 | NDK_LIBS = $(NDK)/toolchains/$(NDK_TARGETARCH)-$(NDK_TOOLVER)/prebuilt/linux-x86_64/lib/gcc/$(NDK_TARGETARCH)/4.9.x
37 | NDK_INCLUDES = -I$(NDK)/sysroot/usr/include \
38 | -I$(NDK)/sysroot/usr/include/$(NDK_TARGETARCH)
39 | NDK_SYSROOT = $(NDK)/platforms/android-$(NDK_TARGETVER)/arch-$(NDK_TARGETSHORTARCH)
40 |
41 | # Options common to compiler and linker
42 | OPT = $(MODE_FLAGS) \
43 | -std=c99 \
44 | -fPIE \
45 | -Wall \
46 | -target $(NDK_TARGETARCH)
47 |
48 | # Compiler options
49 | CFLAGS = $(OPT) \
50 | $(NDK_INCLUDES)
51 |
52 | # Linker options
53 | LDFLAGS = $(OPT) \
54 | $(MODE_FLAGS) \
55 | -pie \
56 | --sysroot=$(NDK_SYSROOT) \
57 | -B $(ANDROID_NDK)/toolchains/$(NDK_TARGETARCH)-$(NDK_TOOLVER)/prebuilt/linux-x86_64/$(NDK_TARGETARCH)/bin \
58 | -L$(NDK_LIBS)
59 |
60 | all:
61 | echo ${NDK_TOOL}
62 | $(NDK_TOOL) -c $(SRCDIR)/run.c -o $(OBJDIR)/run.o $(CFLAGS)
63 | $(NDK_TOOL) -o run $(OBJDIR)/run.o $(LDFLAGS)
64 |
65 | # the most basic way of building that is most likely to work on most systems
66 | .PHONY: run
67 | run: run.c
68 | gcc -O3 -o run run.c -lm
69 |
70 | # useful for a debug build, can then e.g. analyze with valgrind, example:
71 | # $ valgrind --leak-check=full ./run out/model.bin 1.0 3
72 | rundebug: run.c
73 | gcc -g -o run run.c -lm
74 |
75 | # https://gcc.gnu.org/onlinedocs/gcc/Optimize-Options.html
76 | # https://simonbyrne.github.io/notes/fastmath/
77 | # -Ofast enables all -O3 optimizations.
78 | # Disregards strict standards compliance.
79 | # It also enables optimizations that are not valid for all standard-compliant programs.
80 | # It turns on -ffast-math, -fallow-store-data-races and the Fortran-specific
81 | # -fstack-arrays, unless -fmax-stack-var-size is specified, and -fno-protect-parens.
82 | # It turns off -fsemantic-interposition.
83 | # In our specific application this is *probably* okay to use
84 | .PHONY: runfast
85 | runfast: run.c
86 | gcc -Ofast -o run run.c -lm
87 |
88 | # additionally compiles with OpenMP, allowing multithreaded runs
89 | # make sure to also enable multiple threads when running, e.g.:
90 | # OMP_NUM_THREADS=4 ./run out/model.bin
91 | .PHONY: runomp
92 | runomp: run.c
93 | gcc -Ofast -fopenmp -march=native run.c -lm -o run
94 |
95 | .PHONY: clean
96 | clean:
97 | rm -f run
98 |
--------------------------------------------------------------------------------
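A quick usage sketch for the Makefile above. The NDK path and version are assumptions; point `NDK` at your own installation (a command-line assignment overrides the hard-coded default):

```bash
# cross-compile `run` for arm64 Android (adjust the NDK path to your install)
make all NDK=$HOME/Android/Sdk/ndk/25.2.9519653

# native builds on the host, straight from the same Makefile
make run        # basic -O3 build
make runfast    # -Ofast build
make runomp     # -Ofast + OpenMP, then e.g.: OMP_NUM_THREADS=4 ./run out/model.bin
```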
/README.md:
--------------------------------------------------------------------------------
1 | # llama2.c-android
2 |
3 | ![llama2.c on Android](assets/llama_android.png)
4 |
5 | Port of Andrej Karpathy's [llama2.c](https://github.com/karpathy/llama2.c) to Android. You can run it as a raw binary or use it as a shared library.
6 |
7 | You can use the prebuilt binaries in `libs` or compile them yourself:
8 |
9 | ```bash
10 | # or wherever your ndk-build script resides
11 | cd jni && $ANDROID_HOME/ndk-bundle/ndk-build
12 | ```
13 |
14 | ## run as binary
15 |
16 | Get e.g. [Termux](https://f-droid.org/en/packages/com.termux/) and install the APK to run the binaries.
17 |
18 | ```bash
19 | wget https://karpathy.ai/llama2c/model.bin -P out
20 | adb push libs/<arch>/llama2 /storage/emulated/0/Android/data
21 | adb push out/model.bin /storage/emulated/0/Android/data
22 | adb push tokenizer.bin /storage/emulated/0/Android/data
23 | ```
24 |
25 | In Termux:
26 |
27 | ```bash
28 | cp /storage/emulated/0/Android/data/{llama2,model.bin,tokenizer.bin} .
29 | chmod +x llama2
30 | ./llama2 model.bin
31 | ```
32 |
33 | 
34 |
35 | ## run as shared lib
36 |
37 | wip (a rough sketch of one possible approach is outlined below)
38 |
39 |
--------------------------------------------------------------------------------
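The "run as shared lib" section above is still marked wip. Below is a minimal, untested sketch of one way it could be wired up; the module name `llama2_jni`, the wrapper file `jni/llama2_jni.c`, and the Java class `com.example.llama2.Llama2` are all hypothetical, and renaming `main` via `-Dmain=` is just one convenient trick for reusing `run.c` unchanged.

```makefile
# jni/Android.mk (hypothetical variant) -- build run.c into a shared library instead of an executable
LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
LOCAL_MODULE    := llama2_jni
LOCAL_SRC_FILES := ../run.c llama2_jni.c
LOCAL_CFLAGS    := -O3 -Dmain=llama2_main   # rename run.c's main so the JNI wrapper can call it
include $(BUILD_SHARED_LIBRARY)
```

```c
// jni/llama2_jni.c (hypothetical) -- thin JNI wrapper around run.c's entry point
#include <jni.h>

int llama2_main(int argc, char *argv[]); // run.c's main(), renamed via -Dmain=llama2_main

JNIEXPORT jint JNICALL
Java_com_example_llama2_Llama2_run(JNIEnv *env, jclass clazz, jstring checkpoint) {
    const char *path = (*env)->GetStringUTFChars(env, checkpoint, NULL);
    char *argv[] = { "llama2", (char *)path };
    // note: run.c also expects tokenizer.bin in the current working directory
    int rc = llama2_main(2, argv);
    (*env)->ReleaseStringUTFChars(env, checkpoint, path);
    return rc;
}
```

On the Java/Kotlin side the library would then be loaded with `System.loadLibrary("llama2_jni")` before calling the native method.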
/assets/llama2_in_termux.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/assets/llama2_in_termux.gif
--------------------------------------------------------------------------------
/assets/llama_android.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/assets/llama_android.png
--------------------------------------------------------------------------------
/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 | # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
/devshell.nix:
--------------------------------------------------------------------------------
1 | { pkgs }:
2 |
3 | with pkgs;
4 |
5 | devshell.mkShell {
6 | name = "android-project";
7 | motd = ''
8 | Entered the Android app development environment.
9 | '';
10 | env = [
11 | {
12 | name = "ANDROID_HOME";
13 | value = "${android-sdk}/share/android-sdk";
14 | }
15 | {
16 | name = "ANDROID_SDK_ROOT";
17 | value = "${android-sdk}/share/android-sdk";
18 | }
19 | {
20 | name = "JAVA_HOME";
21 | value = jdk11.home;
22 | }
23 | ];
24 | packages = [
25 | android-studio
26 | android-sdk
27 | gradle
28 | jdk11
29 | gcc
30 | wget
31 | ];
32 | }
33 |
--------------------------------------------------------------------------------
/export_meta_llama_bin.py:
--------------------------------------------------------------------------------
1 | """
2 | This script exports the Llama 2 weights in llama2c.bin format.
3 | """
4 | import sys
5 | import struct
6 | from pathlib import Path
7 | import json
8 |
9 | import torch
10 |
11 | from model import precompute_freqs_cis
12 |
13 |
14 | def export(p, state_dict, filepath='model.bin'):
15 | """export the model weights in fp32 into .bin file to be read from C"""
16 | f = open(filepath, 'wb')
17 |
18 | def serialize(key):
19 | print(f"writing {key}...")
20 | t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
21 | f.write(memoryview(t))
22 | del state_dict[key]
23 |
24 | # first write out the header
25 | hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
26 | p['vocab_size'] = 32000
27 | p['max_seq_len'] = 2048
28 |
29 | n_kv_heads = p.get('n_kv_heads') or p['n_heads']
30 | header = struct.pack(
31 | 'iiiiiii',
32 | p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
33 | n_kv_heads, -p['vocab_size'], p['max_seq_len']
34 | )
35 | # NOTE ABOVE: the negative vocab_size signals that the classifier weights are present
36 | # in the checkpoint and should be loaded.
37 | f.write(header)
38 |
39 | # next write out the embedding weights
40 | print("writing tok_embeddings...")
41 | serialize('tok_embeddings.weight')
42 |
43 | # now all the layers
44 | # attention weights
45 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
46 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
47 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
48 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
49 | for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
50 | # ffn weights
51 | for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
52 | for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
53 | for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
54 | for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
55 |
56 | # final rmsnorm
57 | serialize('norm.weight')
58 | # freqs_cis
59 | freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
60 | state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
61 | state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
62 | serialize('freqs_cis.real')
63 | serialize('freqs_cis.imag')
64 |
65 | # finally write the output weights
66 | serialize('output.weight')
67 |
68 | f.close()
69 | print(f"wrote {filepath}")
70 |
71 |
72 | def concat_weights(models):
73 | state_dict = {}
74 | for name in list(models[0]):
75 | tensors = [model[name] for model in models]
76 | if len(tensors) == 1 or len(tensors[0].shape) == 1:
77 | state_dict[name] = tensors[0]
78 | continue
79 | is_axis_1 = (
80 | name.startswith('tok_embeddings.')
81 | or name.endswith('.attention.wo.weight')
82 | or name.endswith('.feed_forward.w2.weight')
83 | )
84 | axis = 1 if is_axis_1 else 0
85 | state_dict[name] = torch.cat(tensors, dim=axis)
86 | for model in models:
87 | del model[name]
88 | return state_dict
89 |
90 |
91 | def load_and_export(model_path, output_path):
92 | with open(Path(model_path) / 'params.json') as f:
93 | params = json.load(f)
94 | print(params)
95 |
96 | model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
97 | models = []
98 | for i in model_paths:
99 | print(f'Loading {i}')
100 | models.append(torch.load(i, map_location='cpu'))
101 |
102 | state_dict = concat_weights(models)
103 | del models
104 | export(params, state_dict, output_path)
105 |
106 |
107 | if __name__ == '__main__':
108 | if len(sys.argv) != 3:
109 | print('Usage: python export_meta_llama_bin.py [Llama model folder path] [output path]')
110 | exit()
111 |
112 | model_path = sys.argv[1]
113 | output_path = sys.argv[2]
114 | load_and_export(model_path, output_path)
115 |
--------------------------------------------------------------------------------
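For reference, a sketch of how the export script above would be invoked. The checkpoint directory and output name are placeholders; the folder is expected to contain params.json and the consolidated.*.pth shards of a Meta Llama 2 download:

```bash
python export_meta_llama_bin.py /path/to/llama-2-7b/ llama2_7b.bin
```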
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "android": {
4 | "inputs": {
5 | "devshell": "devshell",
6 | "flake-utils": "flake-utils",
7 | "nixpkgs": "nixpkgs"
8 | },
9 | "locked": {
10 | "lastModified": 1690316394,
11 | "narHash": "sha256-uiTwTicVz8fuwScghAXP/MGIJvtMEMiQdNJLQzf8o1g=",
12 | "owner": "tadfisher",
13 | "repo": "android-nixpkgs",
14 | "rev": "ed6c49798b760b644911eae3f4b008fa41983ddf",
15 | "type": "github"
16 | },
17 | "original": {
18 | "owner": "tadfisher",
19 | "repo": "android-nixpkgs",
20 | "type": "github"
21 | }
22 | },
23 | "devshell": {
24 | "inputs": {
25 | "nixpkgs": [
26 | "android",
27 | "nixpkgs"
28 | ],
29 | "systems": "systems"
30 | },
31 | "locked": {
32 | "lastModified": 1688380630,
33 | "narHash": "sha256-8ilApWVb1mAi4439zS3iFeIT0ODlbrifm/fegWwgHjA=",
34 | "owner": "numtide",
35 | "repo": "devshell",
36 | "rev": "f9238ec3d75cefbb2b42a44948c4e8fb1ae9a205",
37 | "type": "github"
38 | },
39 | "original": {
40 | "owner": "numtide",
41 | "repo": "devshell",
42 | "type": "github"
43 | }
44 | },
45 | "devshell_2": {
46 | "inputs": {
47 | "nixpkgs": "nixpkgs_2",
48 | "systems": "systems_3"
49 | },
50 | "locked": {
51 | "lastModified": 1688380630,
52 | "narHash": "sha256-8ilApWVb1mAi4439zS3iFeIT0ODlbrifm/fegWwgHjA=",
53 | "owner": "numtide",
54 | "repo": "devshell",
55 | "rev": "f9238ec3d75cefbb2b42a44948c4e8fb1ae9a205",
56 | "type": "github"
57 | },
58 | "original": {
59 | "owner": "numtide",
60 | "repo": "devshell",
61 | "type": "github"
62 | }
63 | },
64 | "flake-utils": {
65 | "inputs": {
66 | "systems": "systems_2"
67 | },
68 | "locked": {
69 | "lastModified": 1689068808,
70 | "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=",
71 | "owner": "numtide",
72 | "repo": "flake-utils",
73 | "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4",
74 | "type": "github"
75 | },
76 | "original": {
77 | "owner": "numtide",
78 | "repo": "flake-utils",
79 | "type": "github"
80 | }
81 | },
82 | "flake-utils_2": {
83 | "inputs": {
84 | "systems": "systems_4"
85 | },
86 | "locked": {
87 | "lastModified": 1689068808,
88 | "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=",
89 | "owner": "numtide",
90 | "repo": "flake-utils",
91 | "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4",
92 | "type": "github"
93 | },
94 | "original": {
95 | "owner": "numtide",
96 | "repo": "flake-utils",
97 | "type": "github"
98 | }
99 | },
100 | "nixpkgs": {
101 | "locked": {
102 | "lastModified": 1690179384,
103 | "narHash": "sha256-+arbgqFTAtoeKtepW9wCnA0njCOyoiDFyl0Q0SBSOtE=",
104 | "owner": "NixOS",
105 | "repo": "nixpkgs",
106 | "rev": "b12803b6d90e2e583429bb79b859ca53c348b39a",
107 | "type": "github"
108 | },
109 | "original": {
110 | "owner": "NixOS",
111 | "ref": "nixos-unstable",
112 | "repo": "nixpkgs",
113 | "type": "github"
114 | }
115 | },
116 | "nixpkgs_2": {
117 | "locked": {
118 | "lastModified": 1677383253,
119 | "narHash": "sha256-UfpzWfSxkfXHnb4boXZNaKsAcUrZT9Hw+tao1oZxd08=",
120 | "owner": "NixOS",
121 | "repo": "nixpkgs",
122 | "rev": "9952d6bc395f5841262b006fbace8dd7e143b634",
123 | "type": "github"
124 | },
125 | "original": {
126 | "owner": "NixOS",
127 | "ref": "nixpkgs-unstable",
128 | "repo": "nixpkgs",
129 | "type": "github"
130 | }
131 | },
132 | "nixpkgs_3": {
133 | "locked": {
134 | "lastModified": 1690387285,
135 | "narHash": "sha256-2nwKSKw48Uh8o+nZk+QDlqmeJpSf43pSwIwg2pCITc4=",
136 | "owner": "NixOS",
137 | "repo": "nixpkgs",
138 | "rev": "f323c3770063a6c6e30817d18287810b0804d3fb",
139 | "type": "github"
140 | },
141 | "original": {
142 | "owner": "NixOS",
143 | "repo": "nixpkgs",
144 | "type": "github"
145 | }
146 | },
147 | "root": {
148 | "inputs": {
149 | "android": "android",
150 | "devshell": "devshell_2",
151 | "flake-utils": "flake-utils_2",
152 | "nixpkgs": "nixpkgs_3"
153 | }
154 | },
155 | "systems": {
156 | "locked": {
157 | "lastModified": 1681028828,
158 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
159 | "owner": "nix-systems",
160 | "repo": "default",
161 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
162 | "type": "github"
163 | },
164 | "original": {
165 | "owner": "nix-systems",
166 | "repo": "default",
167 | "type": "github"
168 | }
169 | },
170 | "systems_2": {
171 | "locked": {
172 | "lastModified": 1681028828,
173 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
174 | "owner": "nix-systems",
175 | "repo": "default",
176 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
177 | "type": "github"
178 | },
179 | "original": {
180 | "owner": "nix-systems",
181 | "repo": "default",
182 | "type": "github"
183 | }
184 | },
185 | "systems_3": {
186 | "locked": {
187 | "lastModified": 1681028828,
188 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
189 | "owner": "nix-systems",
190 | "repo": "default",
191 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
192 | "type": "github"
193 | },
194 | "original": {
195 | "owner": "nix-systems",
196 | "repo": "default",
197 | "type": "github"
198 | }
199 | },
200 | "systems_4": {
201 | "locked": {
202 | "lastModified": 1681028828,
203 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
204 | "owner": "nix-systems",
205 | "repo": "default",
206 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
207 | "type": "github"
208 | },
209 | "original": {
210 | "owner": "nix-systems",
211 | "repo": "default",
212 | "type": "github"
213 | }
214 | }
215 | },
216 | "root": "root",
217 | "version": 7
218 | }
219 |
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
1 | {
2 | description = "android port llama2.c";
3 |
4 | inputs = {
5 | nixpkgs.url = "github:NixOS/nixpkgs";
6 | devshell.url = "github:numtide/devshell";
7 | flake-utils.url = "github:numtide/flake-utils";
8 | android.url = "github:tadfisher/android-nixpkgs";
9 | };
10 |
11 | outputs = { self, nixpkgs, devshell, flake-utils, android }:
12 | {
13 | overlay = final: prev: {
14 | inherit (self.packages.${final.system}) android-sdk android-studio;
15 | };
16 | }
17 | //
18 | flake-utils.lib.eachSystem [ "x86_64-linux" ] (system:
19 | let
20 | pkgs = import nixpkgs {
21 | inherit system;
22 | config.allowUnfree = true;
23 | overlays = [
24 | devshell.overlays.default
25 | self.overlay
26 | ];
27 | };
28 | in
29 | {
30 | packages = {
31 | android-sdk = android.sdk.${system} (sdkPkgs: with sdkPkgs; [
32 | build-tools-30-0-2
33 | cmdline-tools-latest
34 | emulator
35 | platform-tools
36 | platforms-android-3
37 | ndk-bundle
38 | ]);
39 |
40 | android-studio = pkgs.androidStudioPackages.stable;
41 | gcc = pkgs.gcc;
42 | wget = pkgs.wget;
43 | };
44 |
45 | devShell = import ./devshell.nix { inherit pkgs; };
46 | }
47 | );
48 | }
49 |
--------------------------------------------------------------------------------
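For reference, the usual ways into the development environment defined by flake.nix and devshell.nix above (assuming flakes are enabled in your Nix configuration):

```bash
# enter the dev shell directly
nix develop

# or let direnv pick it up automatically via the .envrc ("use flake")
direnv allow
```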
/jni/Android.mk:
--------------------------------------------------------------------------------
1 | LOCAL_PATH := $(call my-dir)
2 | include $(CLEAR_VARS)
3 | LOCAL_MODULE := llama2
4 | LOCAL_SRC_FILES := ../run.c
5 | include $(BUILD_EXECUTABLE)
--------------------------------------------------------------------------------
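The Android.mk above builds run.c as an executable for every ABI that ndk-build targets, which is how the four prebuilt binaries under `libs/` are produced. If you only need a subset, an optional `jni/Application.mk` (not part of the repo, shown here as a sketch) can pin the ABIs and platform level:

```makefile
# jni/Application.mk (optional, hypothetical)
APP_ABI      := arm64-v8a armeabi-v7a x86 x86_64
APP_PLATFORM := android-21
```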
/libs/arm64-v8a/llama2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/arm64-v8a/llama2
--------------------------------------------------------------------------------
/libs/armeabi-v7a/llama2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/armeabi-v7a/llama2
--------------------------------------------------------------------------------
/libs/x86/llama2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/x86/llama2
--------------------------------------------------------------------------------
/libs/x86_64/llama2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/libs/x86_64/llama2
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import struct
3 | import inspect
4 | from dataclasses import dataclass
5 | from typing import Any, Optional, Tuple
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import nn
11 |
12 | @dataclass
13 | class ModelArgs:
14 | dim: int = 4096
15 | n_layers: int = 32
16 | n_heads: int = 32
17 | n_kv_heads: Optional[int] = None
18 | vocab_size: int = -1 # defined later by tokenizer
19 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
20 | norm_eps: float = 1e-5
21 | max_seq_len: int = 2048
22 | dropout: float = 0.0
23 |
24 |
25 | class RMSNorm(torch.nn.Module):
26 | def __init__(self, dim: int, eps: float):
27 | super().__init__()
28 | self.eps = eps
29 | self.weight = nn.Parameter(torch.ones(dim))
30 |
31 | def _norm(self, x):
32 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
33 |
34 | def forward(self, x):
35 | output = self._norm(x.float()).type_as(x)
36 | return output * self.weight
37 |
38 |
39 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
40 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
41 | t = torch.arange(end, device=freqs.device) # type: ignore
42 | freqs = torch.outer(t, freqs).float() # type: ignore
43 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
44 | return freqs_cis
45 |
46 |
47 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
48 | ndim = x.ndim
49 | assert 0 <= 1 < ndim
50 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
51 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
52 | return freqs_cis.view(*shape)
53 |
54 |
55 | def apply_rotary_emb(
56 | xq: torch.Tensor,
57 | xk: torch.Tensor,
58 | freqs_cis: torch.Tensor,
59 | ) -> Tuple[torch.Tensor, torch.Tensor]:
60 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
61 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
62 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
63 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
64 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
65 | return xq_out.type_as(xq), xk_out.type_as(xk)
66 |
67 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
68 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
69 | bs, slen, n_kv_heads, head_dim = x.shape
70 | if n_rep == 1:
71 | return x
72 | return (
73 | x[:, :, :, None, :]
74 | .expand(bs, slen, n_kv_heads, n_rep, head_dim)
75 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
76 | )
77 |
78 | class Attention(nn.Module):
79 | def __init__(self, args: ModelArgs):
80 | super().__init__()
81 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
82 | model_parallel_size = 1
83 | self.n_local_heads = args.n_heads // model_parallel_size
84 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
85 | self.n_rep = self.n_local_heads // self.n_local_kv_heads
86 | self.head_dim = args.dim // args.n_heads
87 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
88 | self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
89 | self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
90 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
91 | self.attn_dropout = nn.Dropout(args.dropout)
92 | self.resid_dropout = nn.Dropout(args.dropout)
93 | self.dropout = args.dropout
94 |
95 | # use flash attention or a manual implementation?
96 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
97 | if not self.flash:
98 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
99 | mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
100 | mask = torch.triu(mask, diagonal=1)
101 | self.register_buffer("mask", mask)
102 |
103 | def forward(
104 | self,
105 | x: torch.Tensor,
106 | freqs_cis: torch.Tensor,
107 | ):
108 | bsz, seqlen, _ = x.shape
109 |
110 | # QKV
111 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
112 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
113 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
114 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
115 |
116 | # RoPE relative positional embeddings
117 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
118 |
119 | # grouped multiquery attention: expand out keys and values
120 | xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
121 | xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
122 |
123 | # make heads into a batch dimension
124 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
125 | xk = xk.transpose(1, 2)
126 | xv = xv.transpose(1, 2)
127 |
128 | # flash implementation
129 | if self.flash:
130 | output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
131 | else:
132 | # manual implementation
133 | scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
134 | scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
135 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
136 | scores = self.attn_dropout(scores)
137 | output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
138 |
139 | # restore time as batch dimension and concat heads
140 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
141 |
142 | # final projection into the residual stream
143 | output = self.wo(output)
144 | output = self.resid_dropout(output)
145 | return output
146 |
147 |
148 | class FeedForward(nn.Module):
149 | def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
150 | super().__init__()
151 | hidden_dim = int(2 * hidden_dim / 3)
152 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
153 | self.w1 = nn.Linear(dim, hidden_dim, bias=False)
154 | self.w2 = nn.Linear(hidden_dim, dim, bias=False)
155 | self.w3 = nn.Linear(dim, hidden_dim, bias=False)
156 | self.dropout = nn.Dropout(dropout)
157 |
158 | def forward(self, x):
159 | return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
160 |
161 |
162 | class TransformerBlock(nn.Module):
163 | def __init__(self, layer_id: int, args: ModelArgs):
164 | super().__init__()
165 | self.n_heads = args.n_heads
166 | self.dim = args.dim
167 | self.head_dim = args.dim // args.n_heads
168 | self.attention = Attention(args)
169 | self.feed_forward = FeedForward(
170 | dim=args.dim,
171 | hidden_dim=4 * args.dim,
172 | multiple_of=args.multiple_of,
173 | dropout=args.dropout,
174 | )
175 | self.layer_id = layer_id
176 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
177 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
178 |
179 | def forward(self, x, freqs_cis):
180 | h = x + self.attention.forward(self.attention_norm(x), freqs_cis)
181 | out = h + self.feed_forward.forward(self.ffn_norm(h))
182 | return out
183 |
184 |
185 | class Transformer(nn.Module):
186 | def __init__(self, params: ModelArgs):
187 | super().__init__()
188 | self.params = params
189 | self.vocab_size = params.vocab_size
190 | self.n_layers = params.n_layers
191 |
192 | self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
193 | self.dropout = nn.Dropout(params.dropout)
194 | self.layers = torch.nn.ModuleList()
195 | for layer_id in range(params.n_layers):
196 | self.layers.append(TransformerBlock(layer_id, params))
197 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
198 | self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
199 |
200 | # share the unembedding parameters with the embedding parameters
201 | self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
202 |
203 | # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
204 | freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
205 | self.register_buffer("freqs_cis", freqs_cis, persistent=False)
206 |
207 | # init all weights
208 | self.apply(self._init_weights)
209 | # apply special scaled init to the residual projections, per GPT-2 paper
210 | for pn, p in self.named_parameters():
211 | if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
212 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
213 |
214 | def _init_weights(self, module):
215 | if isinstance(module, nn.Linear):
216 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
217 | if module.bias is not None:
218 | torch.nn.init.zeros_(module.bias)
219 | elif isinstance(module, nn.Embedding):
220 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
221 |
222 | def forward(self, tokens, targets=None):
223 | _bsz, seqlen = tokens.shape
224 | h = self.tok_embeddings(tokens)
225 | h = self.dropout(h)
226 | freqs_cis = self.freqs_cis[:seqlen]
227 |
228 | for layer in self.layers:
229 | h = layer(h, freqs_cis)
230 | h = self.norm(h)
231 |
232 | if targets is not None:
233 | # if we are given some desired targets also calculate the loss
234 | logits = self.output(h)
235 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
236 | else:
237 | # inference-time mini-optimization: only forward the output on the very last position
238 | logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
239 | loss = None
240 |
241 | return logits, loss
242 |
243 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
244 | # start with all of the candidate parameters
245 | param_dict = {pn: p for pn, p in self.named_parameters()}
246 | # filter out those that do not require grad
247 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
248 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
249 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
250 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
251 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
252 | optim_groups = [
253 | {'params': decay_params, 'weight_decay': weight_decay},
254 | {'params': nodecay_params, 'weight_decay': 0.0}
255 | ]
256 | num_decay_params = sum(p.numel() for p in decay_params)
257 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
258 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
259 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
260 | # Create AdamW optimizer and use the fused version if it is available
261 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
262 | use_fused = fused_available and device_type == 'cuda'
263 | extra_args = dict(fused=True) if use_fused else dict()
264 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
265 | print(f"using fused AdamW: {use_fused}")
266 |
267 | return optimizer
268 |
269 | def estimate_mfu(self, fwdbwd_per_iter, dt):
270 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
271 | # first estimate the number of flops we do per iteration.
272 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
273 | N = sum(p.numel() for p in self.parameters())
274 | cfg = self.params
275 | L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
276 | flops_per_token = 6*N + 12*L*H*Q*T
277 | flops_per_fwdbwd = flops_per_token * T
278 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
279 | # express our flops throughput as ratio of A100 bfloat16 peak flops
280 | flops_achieved = flops_per_iter * (1.0/dt) # per second
281 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
282 | mfu = flops_achieved / flops_promised
283 | return mfu
284 |
285 | @torch.inference_mode()
286 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
287 | """
288 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
289 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
290 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
291 | Also note this is a super inefficient version of sampling with no key/value cache.
292 | """
293 | for _ in range(max_new_tokens):
294 | # if the sequence context is growing too long we must crop it at block_size
295 | idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
296 | # forward the model to get the logits for the index in the sequence
297 | logits, _ = self(idx_cond)
298 | logits = logits[:, -1, :] # crop to just the final time step
299 | if temperature == 0.0:
300 | # "sample" the single most likely index
301 | _, idx_next = torch.topk(logits, k=1, dim=-1)
302 | else:
303 | # pluck the logits at the final step and scale by desired temperature
304 | logits = logits / temperature
305 | # optionally crop the logits to only the top k options
306 | if top_k is not None:
307 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
308 | logits[logits < v[:, [-1]]] = -float('Inf')
309 | # apply softmax to convert logits to (normalized) probabilities
310 | probs = F.softmax(logits, dim=-1)
311 | idx_next = torch.multinomial(probs, num_samples=1)
312 | # append sampled index to the running sequence and continue
313 | idx = torch.cat((idx, idx_next), dim=1)
314 |
315 | return idx
316 |
317 | def export(self, filepath='model.bin'):
318 | """export the model weights in fp32 into .bin file to be read from C"""
319 | f = open(filepath, 'wb')
320 |
321 | def serialize(t):
322 | d = t.detach().cpu().view(-1).numpy().astype(np.float32)
323 | b = struct.pack(f'{len(d)}f', *d)
324 | f.write(b)
325 |
326 | # first write out the header
327 | hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
328 | p = self.params
329 | n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
330 | header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
331 | n_kv_heads, p.vocab_size, p.max_seq_len)
332 | f.write(header)
333 |
334 | # next write out the embedding weights
335 | serialize(self.tok_embeddings.weight)
336 |
337 | # now all the layers
338 | # attention weights
339 | for layer in self.layers:
340 | serialize(layer.attention_norm.weight)
341 | for layer in self.layers:
342 | serialize(layer.attention.wq.weight)
343 | for layer in self.layers:
344 | serialize(layer.attention.wk.weight)
345 | for layer in self.layers:
346 | serialize(layer.attention.wv.weight)
347 | for layer in self.layers:
348 | serialize(layer.attention.wo.weight)
349 | # ffn weights
350 | for layer in self.layers:
351 | serialize(layer.ffn_norm.weight)
352 | for layer in self.layers:
353 | serialize(layer.feed_forward.w1.weight)
354 | for layer in self.layers:
355 | serialize(layer.feed_forward.w2.weight)
356 | for layer in self.layers:
357 | serialize(layer.feed_forward.w3.weight)
358 | # final rmsnorm
359 | serialize(self.norm.weight)
360 | # note: no need to write final classifier weights due to weight sharing
361 | # freqs_cis
362 | serialize(self.freqs_cis.real[:p.max_seq_len])
363 | serialize(self.freqs_cis.imag[:p.max_seq_len])
364 |
365 | # write to binary file
366 | f.close()
367 | print(f"wrote {filepath}")
368 |
--------------------------------------------------------------------------------
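A minimal sketch of how the `Transformer.export` method above produces the model.bin consumed by run.c. The tiny hyperparameters are placeholders for illustration, not a trained model:

```python
# hypothetical example: serialize an untrained toy model in the run.c format
from model import ModelArgs, Transformer

args = ModelArgs(dim=288, n_layers=6, n_heads=6, vocab_size=32000, max_seq_len=256)
model = Transformer(args)
model.export('model.bin')  # writes the 7-int header followed by the fp32 weights
```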
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.5
2 | pytest==7.4.0
3 | Requests==2.31.0
4 | sentencepiece==0.1.99
5 | tiktoken==0.3.3
6 | torch==2.0.1
7 | tqdm==4.64.1
8 | wandb==0.15.5
9 |
--------------------------------------------------------------------------------
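The Python side of the repo (model.py, sample.py, test_all.py, tinystories.py, train.py) assumes the pinned packages above; a typical setup is simply:

```bash
pip install -r requirements.txt
```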
/run.c:
--------------------------------------------------------------------------------
1 | /*
2 | Inference for Llama-2 Transformer model in pure C.
3 |
4 | Example compile: (see README for more details)
5 | $ gcc -O3 -o run run.c -lm
6 |
7 | Then run with:
8 | $ ./run out/model.bin
9 | */
10 |
11 | #include <stdio.h>
12 | #include <stdlib.h>
13 | #include <time.h>
14 | #include <math.h>
15 | #include <string.h>
16 | #include <fcntl.h>
17 | #include <unistd.h>
18 | #include <sys/mman.h>
19 |
20 | // ----------------------------------------------------------------------------
21 | // Transformer and RunState structs, and related memory management
22 |
23 | typedef struct {
24 | int dim; // transformer dimension
25 | int hidden_dim; // for ffn layers
26 | int n_layers; // number of layers
27 | int n_heads; // number of query heads
28 | int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
29 | int vocab_size; // vocabulary size (e.g. 32000 for the Llama 2 tokenizer)
30 | int seq_len; // max sequence length
31 | } Config;
32 |
33 | typedef struct {
34 | // token embedding table
35 | float* token_embedding_table; // (vocab_size, dim)
36 | // weights for rmsnorms
37 | float* rms_att_weight; // (layer, dim) rmsnorm weights
38 | float* rms_ffn_weight; // (layer, dim)
39 | // weights for matmuls
40 | float* wq; // (layer, dim, dim)
41 | float* wk; // (layer, dim, dim)
42 | float* wv; // (layer, dim, dim)
43 | float* wo; // (layer, dim, dim)
44 | // weights for ffn
45 | float* w1; // (layer, hidden_dim, dim)
46 | float* w2; // (layer, dim, hidden_dim)
47 | float* w3; // (layer, hidden_dim, dim)
48 | // final rmsnorm
49 | float* rms_final_weight; // (dim,)
50 | // freq_cis for RoPE relative positional embeddings
51 | float* freq_cis_real; // (seq_len, dim/2)
52 | float* freq_cis_imag; // (seq_len, dim/2)
53 | // (optional) classifier weights for the logits, on the last layer
54 | float* wcls;
55 | } TransformerWeights;
56 |
57 | typedef struct {
58 | // current wave of activations
59 | float *x; // activation at current time stamp (dim,)
60 | float *xb; // same, but inside a residual branch (dim,)
61 | float *xb2; // an additional buffer just for convenience (dim,)
62 | float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
63 | float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
64 | float *q; // query (dim,)
65 | float *k; // key (dim,)
66 | float *v; // value (dim,)
67 | float *att; // buffer for scores/attention values (n_heads, seq_len)
68 | float *logits; // output logits
69 | // kv cache
70 | float* key_cache; // (layer, seq_len, dim)
71 | float* value_cache; // (layer, seq_len, dim)
72 | } RunState;
73 |
74 | void malloc_run_state(RunState* s, Config* p) {
75 | // we calloc instead of malloc to keep valgrind happy
76 | s->x = calloc(p->dim, sizeof(float));
77 | s->xb = calloc(p->dim, sizeof(float));
78 | s->xb2 = calloc(p->dim, sizeof(float));
79 | s->hb = calloc(p->hidden_dim, sizeof(float));
80 | s->hb2 = calloc(p->hidden_dim, sizeof(float));
81 | s->q = calloc(p->dim, sizeof(float));
82 | s->k = calloc(p->dim, sizeof(float));
83 | s->v = calloc(p->dim, sizeof(float));
84 | s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
85 | s->logits = calloc(p->vocab_size, sizeof(float));
86 | s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
87 | s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
88 | // ensure all mallocs went fine
89 | if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
90 | || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
91 | || !s->value_cache) {
92 | printf("malloc failed!\n");
93 | exit(1);
94 | }
95 | }
96 |
97 | void free_run_state(RunState* s) {
98 | free(s->x);
99 | free(s->xb);
100 | free(s->xb2);
101 | free(s->hb);
102 | free(s->hb2);
103 | free(s->q);
104 | free(s->k);
105 | free(s->v);
106 | free(s->att);
107 | free(s->logits);
108 | free(s->key_cache);
109 | free(s->value_cache);
110 | }
111 |
112 | // ----------------------------------------------------------------------------
113 | // initialization: read from checkpoint
114 |
115 | void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
116 | float* ptr = f;
117 | w->token_embedding_table = ptr;
118 | ptr += p->vocab_size * p->dim;
119 | w->rms_att_weight = ptr;
120 | ptr += p->n_layers * p->dim;
121 | w->wq = ptr;
122 | ptr += p->n_layers * p->dim * p->dim;
123 | w->wk = ptr;
124 | ptr += p->n_layers * p->dim * p->dim;
125 | w->wv = ptr;
126 | ptr += p->n_layers * p->dim * p->dim;
127 | w->wo = ptr;
128 | ptr += p->n_layers * p->dim * p->dim;
129 | w->rms_ffn_weight = ptr;
130 | ptr += p->n_layers * p->dim;
131 | w->w1 = ptr;
132 | ptr += p->n_layers * p->dim * p->hidden_dim;
133 | w->w2 = ptr;
134 | ptr += p->n_layers * p->hidden_dim * p->dim;
135 | w->w3 = ptr;
136 | ptr += p->n_layers * p->dim * p->hidden_dim;
137 | w->rms_final_weight = ptr;
138 | ptr += p->dim;
139 | w->freq_cis_real = ptr;
140 | int head_size = p->dim / p->n_heads;
141 | ptr += p->seq_len * head_size / 2;
142 | w->freq_cis_imag = ptr;
143 | ptr += p->seq_len * head_size / 2;
144 | w->wcls = shared_weights ? w->token_embedding_table : ptr;
145 | }
146 |
147 | // ----------------------------------------------------------------------------
148 | // neural net blocks
149 |
150 | void accum(float *a, float *b, int size) {
151 | for (int i = 0; i < size; i++) {
152 | a[i] += b[i];
153 | }
154 | }
155 |
156 | void rmsnorm(float* o, float* x, float* weight, int size) {
157 | // calculate sum of squares
158 | float ss = 0.0f;
159 | for (int j = 0; j < size; j++) {
160 | ss += x[j] * x[j];
161 | }
162 | ss /= size;
163 | ss += 1e-5f;
164 | ss = 1.0f / sqrtf(ss);
165 | // normalize and scale
166 | for (int j = 0; j < size; j++) {
167 | o[j] = weight[j] * (ss * x[j]);
168 | }
169 | }
170 |
171 | void softmax(float* x, int size) {
172 | // find max value (for numerical stability)
173 | float max_val = x[0];
174 | for (int i = 1; i < size; i++) {
175 | if (x[i] > max_val) {
176 | max_val = x[i];
177 | }
178 | }
179 | // exp and sum
180 | float sum = 0.0f;
181 | for (int i = 0; i < size; i++) {
182 | x[i] = expf(x[i] - max_val);
183 | sum += x[i];
184 | }
185 | // normalize
186 | for (int i = 0; i < size; i++) {
187 | x[i] /= sum;
188 | }
189 | }
190 |
191 | void matmul(float* xout, float* x, float* w, int n, int d) {
192 | // W (d,n) @ x (n,) -> xout (d,)
193 | #pragma omp parallel for
194 | for (int i = 0; i < d; i++) {
195 | float val = 0.0f;
196 | for (int j = 0; j < n; j++) {
197 | val += w[i * n + j] * x[j];
198 | }
199 | xout[i] = val;
200 | }
201 | }
202 |
203 | void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
204 |
205 | // a few convenience variables
206 | float *x = s->x;
207 | int dim = p->dim;
208 | int hidden_dim = p->hidden_dim;
209 | int head_size = dim / p->n_heads;
210 |
211 | // copy the token embedding into x
212 | float* content_row = &(w->token_embedding_table[token * dim]);
213 | memcpy(x, content_row, dim*sizeof(*x));
214 |
215 | // pluck out the "pos" row of freq_cis_real and freq_cis_imag
216 | float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
217 | float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;
218 |
219 | // forward all the layers
220 | for(int l = 0; l < p->n_layers; l++) {
221 |
222 | // attention rmsnorm
223 | rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
224 |
225 | // qkv matmuls for this position
226 | matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
227 | matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
228 | matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);
229 |
230 | // apply RoPE rotation to the q and k vectors for each head
231 | for (int h = 0; h < p->n_heads; h++) {
232 | // get the q and k vectors for this head
233 | float* q = s->q + h * head_size;
234 | float* k = s->k + h * head_size;
235 | // rotate q and k by the freq_cis_real and freq_cis_imag
236 | for (int i = 0; i < head_size; i+=2) {
237 | float q0 = q[i];
238 | float q1 = q[i+1];
239 | float k0 = k[i];
240 | float k1 = k[i+1];
241 | float fcr = freq_cis_real_row[i/2];
242 | float fci = freq_cis_imag_row[i/2];
243 | q[i] = q0 * fcr - q1 * fci;
244 | q[i+1] = q0 * fci + q1 * fcr;
245 | k[i] = k0 * fcr - k1 * fci;
246 | k[i+1] = k0 * fci + k1 * fcr;
247 | }
248 | }
249 |
250 | // save key,value at this time step (pos) to our kv cache
251 | int loff = l * p->seq_len * dim; // kv cache layer offset for convenience
252 | float* key_cache_row = s->key_cache + loff + pos * dim;
253 | float* value_cache_row = s->value_cache + loff + pos * dim;
254 | memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
255 | memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
256 |
257 | // multihead attention. iterate over all heads
258 | #pragma omp parallel for
259 | for (int h = 0; h < p->n_heads; h++) {
260 | // get the query vector for this head
261 | float* q = s->q + h * head_size;
262 | // attention scores for this head
263 | float* att = s->att + h * p->seq_len;
264 | // iterate over all timesteps, including the current one
265 | for (int t = 0; t <= pos; t++) {
266 | // get the key vector for this head and at this timestep
267 | float* k = s->key_cache + loff + t * dim + h * head_size;
268 | // calculate the attention score as the dot product of q and k
269 | float score = 0.0f;
270 | for (int i = 0; i < head_size; i++) {
271 | score += q[i] * k[i];
272 | }
273 | score /= sqrtf(head_size);
274 | // save the score to the attention buffer
275 | att[t] = score;
276 | }
277 |
278 | // softmax the scores to get attention weights, from 0..pos inclusively
279 | softmax(att, pos + 1);
280 |
281 | // weighted sum of the values, store back into xb
282 | for (int i = 0; i < head_size; i++) {
283 | float val = 0.0f;
284 | for (int t = 0; t <= pos; t++) {
285 | val += att[t] * s->value_cache[loff + t * dim + h * head_size + i]; // note bad locality
286 | }
287 | s->xb[h * head_size + i] = val;
288 | }
289 | }
290 |
291 | // final matmul to get the output of the attention
292 | matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
293 |
294 | // residual connection back into x
295 | accum(x, s->xb2, dim);
296 |
297 | // ffn rmsnorm
298 | rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
299 |
300 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
301 | // first calculate self.w1(x) and self.w3(x)
302 | matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
303 | matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
304 |
305 | // F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
306 | for (int i = 0; i < hidden_dim; i++) {
307 | s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
308 | }
309 |
310 | // elementwise multiply with w3(x)
311 | for (int i = 0; i < hidden_dim; i++) {
312 | s->hb[i] = s->hb[i] * s->hb2[i];
313 | }
314 |
315 | // final matmul to get the output of the ffn
316 | matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
317 |
318 | // residual connection
319 | accum(x, s->xb, dim);
320 | }
321 |
322 | // final rmsnorm
323 | rmsnorm(x, x, w->rms_final_weight, dim);
324 |
325 | // classifier into logits
326 | matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
327 | }
328 |
329 | int sample(float* probabilities, int n) {
330 | // sample index from probabilities, they must sum to 1
331 | float r = (float)rand() / (float)RAND_MAX;
332 | float cdf = 0.0f;
333 | for (int i = 0; i < n; i++) {
334 | cdf += probabilities[i];
335 | if (r < cdf) {
336 | return i;
337 | }
338 | }
339 | return n - 1; // in case of rounding errors
340 | }
341 |
342 | int argmax(float* v, int n) {
343 | // return argmax of v in elements 0..n
344 | int max_i = 0;
345 | float max_p = v[0];
346 | for (int i = 1; i < n; i++) {
347 | if (v[i] > max_p) {
348 | max_i = i;
349 | max_p = v[i];
350 | }
351 | }
352 | return max_i;
353 | }
354 |
355 | // ----------------------------------------------------------------------------
356 |
357 | // long time_in_ms() {
358 | // struct timespec time;
359 | // // Get the current time with nanosecond precision
360 | // if (clock_gettime(CLOCK_REALTIME, &time) == 0) {
361 | // return time.tv_sec * 1000 + time.tv_nsec / 1000000;
362 | // } else {
363 | // perror("clock_gettime");
364 | // return -1; // Return -1 to indicate an error
365 | // }
366 | // }
367 |
368 | long time_in_ms() {
369 | struct timeval time;
370 | // Get the current time with microsecond precision
371 | if (gettimeofday(&time, NULL) == 0) {
372 | return time.tv_sec * 1000 + time.tv_usec / 1000;
373 | } else {
374 | perror("gettimeofday");
375 | return -1; // Return -1 to indicate an error
376 | }
377 | }
378 |
379 | int main(int argc, char *argv[]) {
380 |
381 | // poor man's C argparse
382 | char *checkpoint = NULL; // e.g. out/model.bin
383 | float temperature = 0.9f; // e.g. 1.0, or 0.0
384 | int steps = 256; // max number of steps to run for, 0: use seq_len
385 | // 'checkpoint' is necessary arg
386 | if (argc < 2) {
387 | printf("Usage: %s [temperature] [steps]\n", argv[0]);
388 | return 1;
389 | }
390 | if (argc >= 2) {
391 | checkpoint = argv[1];
392 | }
393 | if (argc >= 3) {
394 | // optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
395 | temperature = atof(argv[2]);
396 | }
397 | if (argc >= 4) {
398 | steps = atoi(argv[3]);
399 | }
400 |
401 | // seed rng with time. if you want deterministic behavior use temperature 0.0
402 | srand((unsigned int)time(NULL));
403 |
404 | // read in the model.bin file
405 | Config config;
406 | TransformerWeights weights;
407 | int fd = 0;
408 | float* data = NULL;
409 | long file_size;
410 | {
411 | FILE *file = fopen(checkpoint, "rb");
412 | if (!file) {
413 | printf("Unable to open the checkpoint file %s!\n", checkpoint);
414 | return 1;
415 | }
416 | // read in the config header
417 | if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
418 | // negative vocab size is hacky way of signaling unshared weights. bit yikes.
419 | int shared_weights = config.vocab_size > 0 ? 1 : 0;
420 | config.vocab_size = abs(config.vocab_size);
421 | // figure out the file size
422 | fseek(file, 0, SEEK_END); // move file pointer to end of file
423 | file_size = ftell(file); // get the file size, in bytes
424 | fclose(file);
425 | // memory map the Transformer weights into the data pointer
426 | fd = open(checkpoint, O_RDONLY); // open in read only mode
427 | if (fd == -1) { printf("open failed!\n"); return 1; }
428 | data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
429 | if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
430 | float* weights_ptr = data + sizeof(Config)/sizeof(float);
431 | checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
432 | }
433 | // right now we cannot run for more than config.seq_len steps
434 | if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
435 |
436 | // read in the tokenizer.bin file
437 | char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
438 | {
439 | FILE *file = fopen("tokenizer.bin", "rb");
440 | if (!file) {
441 | printf("Unable to open the tokenizer file tokenizer.bin! Run "
442 | "python tokenizer.py to convert tokenizer.model -> tokenizer.bin\n");
443 | return 1;
444 | }
445 | int len;
446 | for (int i = 0; i < config.vocab_size; i++) {
447 | if(fread(&len, sizeof(int), 1, file) != 1) { return 1; }
448 | vocab[i] = (char *)malloc(len + 1);
449 | if(fread(vocab[i], len, 1, file) != 1) { return 1; }
450 | vocab[i][len] = '\0'; // add the string terminating token
451 | }
452 | fclose(file);
453 | }
454 |
455 | // create and init the application RunState
456 | RunState state;
457 | malloc_run_state(&state, &config);
458 |
459 | // the current position we are in
460 | long start = time_in_ms();
461 | int next;
462 | int token = 1; // 1 = BOS token in Llama-2 sentencepiece
463 | int pos = 0;
464 | printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
465 | while (pos < steps) {
466 |
467 | // forward the transformer to get logits for the next token
468 | transformer(token, pos, &config, &state, &weights);
469 |
470 | // sample the next token
471 | if(temperature == 0.0f) {
472 | // greedy argmax sampling
473 | next = argmax(state.logits, config.vocab_size);
474 | } else {
475 | // apply the temperature to the logits
476 | for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
477 | // apply softmax to the logits to get the probabilities for the next token
478 | softmax(state.logits, config.vocab_size);
479 | // we sample from this distribution to get the next token
480 | next = sample(state.logits, config.vocab_size);
481 | }
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
14 | start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
15 | num_samples = 1 # number of samples to draw
16 | max_new_tokens = 100 # number of tokens generated in each sample
17 | temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
18 | top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
19 | seed = 1337
20 | device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
21 | #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
22 | dtype = "float32"
23 | compile = False # use PyTorch 2.0 to compile the model to be faster
24 | exec(open('configurator.py').read()) # overrides from command line or config file
25 | # -----------------------------------------------------------------------------
26 |
27 | torch.manual_seed(seed)
28 | torch.cuda.manual_seed(seed)
29 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
30 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
31 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
32 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
33 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
34 |
35 | # init from a model saved in a specific directory
36 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
37 | checkpoint = torch.load(ckpt_path, map_location=device)
38 | gptconf = ModelArgs(**checkpoint['model_args'])
39 | model = Transformer(gptconf)
40 | state_dict = checkpoint['model']
41 | unwanted_prefix = '_orig_mod.'
42 | for k,v in list(state_dict.items()):
43 | if k.startswith(unwanted_prefix):
44 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
45 | model.load_state_dict(state_dict, strict=False)
46 |
47 | model.eval()
48 | model.to(device)
49 | if compile:
50 | print("Compiling the model...")
51 | model = torch.compile(model) # requires PyTorch 2.0 (optional)
52 |
53 | # load the tokenizer
54 | enc = Tokenizer()
55 |
56 | # encode the beginning of the prompt
57 | if start.startswith('FILE:'):
58 | with open(start[5:], 'r', encoding='utf-8') as f:
59 | start = f.read()
60 | start_ids = enc.encode(start, bos=True, eos=False)
61 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
62 |
63 | # run generation
64 | with torch.no_grad():
65 | with ctx:
66 | for k in range(num_samples):
67 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
68 | print(enc.decode(y[0].tolist()))
69 | print('---------------')
70 |
--------------------------------------------------------------------------------
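Note on usage (illustrative, not a file in the repo): every sample.py default above can be overridden from the command line via configurator.py, e.g.

$ python sample.py --temperature=0.0 --num_samples=2 --start="FILE:prompt.txt"

Here --temperature=0.0 selects greedy argmax decoding, the same deterministic setting that test_all.py below relies on.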
/test_all.py:
--------------------------------------------------------------------------------
1 | """
2 | Run simply with
3 | $ pytest
4 | """
5 | import os
6 | import pytest # pip install pytest
7 | import subprocess
8 |
9 | import torch
10 | from model import ModelArgs, Transformer
11 |
12 | def test_argmax_inference():
13 | """
14 | Only the simplest test for now: run inference with temperature 0
15 | (for determinism) in both C and PyTorch, and see that the sampled tokens
16 | are the same.
17 | """
18 | test_ckpt_dir = "out" # TODO create a dummy test checkpoint for this?
19 |
20 | # run C version
21 | model_path = os.path.join(test_ckpt_dir, "model.bin")
22 | command = ["./run", model_path, "0.0"]
23 | proc = subprocess.Popen(command, stdout=subprocess.PIPE)
24 | c_tokens = []
25 | for line in proc.stdout:
26 | token = int(line.decode('utf-8').strip())
27 | c_tokens.append(token)
28 | proc.wait()
29 | #print(c_tokens)
30 |
31 | # run PyTorch version
32 | device = "cuda" if torch.cuda.is_available() else "cpu"
33 | ckpt_path = os.path.join(test_ckpt_dir, "ckpt.pt")
34 | checkpoint = torch.load(ckpt_path, map_location=device)
35 | gptconf = ModelArgs(**checkpoint['model_args'])
36 | model = Transformer(gptconf)
37 | state_dict = checkpoint['model']
38 | unwanted_prefix = '_orig_mod.'
39 | for k,v in list(state_dict.items()):
40 | if k.startswith(unwanted_prefix):
41 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
42 | model.load_state_dict(state_dict, strict=False)
43 | model.eval()
44 | model.to(device)
45 | x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS
46 | with torch.inference_mode():
47 | y = model.generate(x, max_new_tokens=gptconf.max_seq_len, temperature=0.0)
48 | pt_tokens = y[0].tolist()
49 | pt_tokens = pt_tokens[1:] # remove BOS
50 | #print(pt_tokens)
51 |
52 | # compare
53 | assert c_tokens == pt_tokens
54 |
--------------------------------------------------------------------------------
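The test above assumes two artifacts already exist: a host-compiled ./run binary built from run.c, and a trained checkpoint directory out/ containing ckpt.pt and model.bin (train.py below writes both when it saves a checkpoint). A rough setup sketch; the compile command is an assumption, not a build rule taken from this repo:

$ cc -O3 -o run run.c -lm
$ python train.py
$ pytest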
/tinystories.py:
--------------------------------------------------------------------------------
1 | """
2 | Download, preprocess and serve the TinyStories dataset as a DataLoader.
3 | """
4 |
5 | import argparse
6 | import glob
7 | import json
8 | import os
9 | import random
10 | from typing import List
11 | from concurrent.futures import ThreadPoolExecutor, as_completed
12 |
13 | import numpy as np
14 | import requests
15 | import torch
16 | import torch.distributed as dist
17 | from tqdm import tqdm
18 |
19 | from tokenizer import Tokenizer
20 |
21 | DATA_CACHE_DIR = "data"
22 |
23 | def download_file(url: str, fname: str, chunk_size=1024):
24 | """Helper function to download a file from a given url"""
25 | resp = requests.get(url, stream=True)
26 | total = int(resp.headers.get("content-length", 0))
27 | with open(fname, "wb") as file, tqdm(
28 | desc=fname,
29 | total=total,
30 | unit="iB",
31 | unit_scale=True,
32 | unit_divisor=1024,
33 | ) as bar:
34 | for data in resp.iter_content(chunk_size=chunk_size):
35 | size = file.write(data)
36 | bar.update(size)
37 |
38 |
39 | def download():
40 | """Downloads the dataset to disk."""
41 | os.makedirs(DATA_CACHE_DIR, exist_ok=True)
42 |
43 | # download the TinyStories dataset, unless it's already downloaded
44 | data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
45 | data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
46 | if not os.path.exists(data_filename):
47 | print(f"Downloading {data_url} to {data_filename}...")
48 | download_file(data_url, data_filename)
49 | else:
50 | print(f"{data_filename} already exists, skipping download...")
51 |
52 | # unpack the tar.gz file into all the data shards (json files)
53 | data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
54 | if not os.path.exists(data_dir):
55 | os.makedirs(data_dir, exist_ok=True)
56 | print(f"Unpacking {data_filename}...")
57 | os.system(f"tar -xzf {data_filename} -C {data_dir}")
58 | else:
59 | print(f"{data_dir} already exists, skipping unpacking...")
60 |
61 | # print a single example just for debugging and such
62 | shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
63 | with open(shard_filenames[0], "r") as f:
64 | data = json.load(f)
65 | print("Download done.")
66 | print(f"Number of shards: {len(shard_filenames)}")
67 | print(f"Example story:\n{data[0]}")
68 |
69 | def pretokenize():
70 | enc = Tokenizer()
71 |
72 | def process_shard(shard):
73 | with open(shard, "r") as f:
74 | data = json.load(f)
75 | all_tokens = []
76 | for example in tqdm(data):
77 | text = example["story"]
78 | text = text.strip() # get rid of leading/trailing whitespace
79 | tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
80 | all_tokens.extend(tokens)
81 | # convert to uint16 nparray
82 | all_tokens = np.array(all_tokens, dtype=np.uint16)
83 | # write to disk
84 | tokenized_filename = shard.replace(".json", ".bin")
85 | with open(tokenized_filename, "wb") as f:
86 | f.write(all_tokens.tobytes())
87 | print(f"Saved {tokenized_filename}")
88 |
89 | # collect the shard filenames; they are tokenized in parallel below
90 | data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
91 | shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
92 |
93 | # process all the shards in a threadpool
94 | with ThreadPoolExecutor(max_workers=8) as executor:
95 | executor.map(process_shard, shard_filenames)
96 |
97 | print("Done.")
98 |
99 |
100 | class PretokDataset(torch.utils.data.IterableDataset):
101 | """Loads pretokenized examples from disk and yields them as PyTorch tensors."""
102 |
103 | def __init__(self, split, max_seq_len):
104 | super().__init__()
105 | self.split = split
106 | self.max_seq_len = max_seq_len
107 |
108 | def __iter__(self):
109 | # get worker info within a DataLoader
110 | worker_info = torch.utils.data.get_worker_info()
111 | worker_id = worker_info.id if worker_info else 0
112 | # get DDP rank info
113 | rank = dist.get_rank() if dist.is_initialized() else 0
114 | # combine the DataLoader worker_id and the DDP rank to create a unique seed for the rng
115 | seed = 42 + worker_id + 1337 * rank
116 | rng = random.Random(seed)
117 | print(f"Created a PretokDataset with rng seed {seed}")
118 | data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
119 | shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
120 | # train/test split. let's use only shard 0 for test split, rest train
121 | shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
122 | while True:
123 | rng.shuffle(shard_filenames)
124 | for shard in shard_filenames:
125 | # open the dataset for reading but keep it on disk with memmap
126 | m = np.memmap(shard, dtype=np.uint16, mode="r")
127 | num_batches = len(m) // self.max_seq_len
128 | num_batches -= 1 # drop the last partial batch
129 | assert num_batches > 0, "this shard is way too small? investigate."
130 | ixs = list(range(num_batches))
131 | rng.shuffle(ixs)
132 | for ix in ixs:
133 | start = ix * self.max_seq_len
134 | end = start + self.max_seq_len + 1
135 | # calling .astype will copy the data into a new numpy array, now in RAM
136 | chunk = torch.from_numpy((m[start:end]).astype(np.int64))
137 | x = chunk[:-1]
138 | y = chunk[1:]
139 | yield x, y
140 |
141 |
142 | class Task:
143 |
144 | @staticmethod
145 | def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
146 | ds = PretokDataset(split, max_seq_len)
147 | dl = torch.utils.data.DataLoader(
148 | ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
149 | )
150 | for x, y in dl:
151 | x = x.to(device, non_blocking=True)
152 | y = y.to(device, non_blocking=True)
153 | yield x, y
154 |
155 |
156 | if __name__ == "__main__":
157 | parser = argparse.ArgumentParser()
158 | parser.add_argument("stage", type=str, choices=["download", "pretokenize"])
159 | args = parser.parse_args()
160 |
161 | # depending on the stage call the appropriate function
162 | fun = {
163 | "download": download,
164 | "pretokenize": pretokenize,
165 | }
166 | fun[args.stage]()
--------------------------------------------------------------------------------
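A minimal consumption sketch for the Task.iter_batches generator defined above (illustrative only, not a repo file), assuming the download and pretokenize stages have already produced data/TinyStories_all_data/*.bin:

from tinystories import Task

batch_iter = Task.iter_batches(split="train", batch_size=4, max_seq_len=256, device="cpu")
x, y = next(batch_iter)  # LongTensors of shape (4, 256); y holds the next-token targets for x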
/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Manuel030/llama2.c-android/1d6647d72191d4d21ad59e9e01c659e096c351c3/tokenizer.model
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Taken from llama code and lightly modified
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
4 |
5 | import os
6 | from logging import getLogger
7 | from typing import List
8 |
9 | from sentencepiece import SentencePieceProcessor
10 |
11 | TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model
12 | TOKENIZER_BIN = "tokenizer.bin" # binary version of the tokenizer for inference in C
13 |
14 | class Tokenizer:
15 | def __init__(self):
16 | model_path = TOKENIZER_MODEL
17 | assert os.path.isfile(model_path), model_path
18 | self.sp_model = SentencePieceProcessor(model_file=model_path)
19 | #print(f"Loaded SentencePiece model from {model_path}")
20 |
21 | # BOS / EOS token IDs
22 | self.n_words: int = self.sp_model.vocab_size()
23 | self.bos_id: int = self.sp_model.bos_id()
24 | self.eos_id: int = self.sp_model.eos_id()
25 | self.pad_id: int = self.sp_model.pad_id()
26 | #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
27 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
28 |
29 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
30 | assert type(s) is str
31 | t = self.sp_model.encode(s)
32 | if bos:
33 | t = [self.bos_id] + t
34 | if eos:
35 | t = t + [self.eos_id]
36 | return t
37 |
38 | def decode(self, t: List[int]) -> str:
39 | return self.sp_model.decode(t)
40 |
41 | def export(self):
42 | tokens = []
43 | for i in range(self.n_words):
44 |
45 | # decode the token and light postprocessing
46 | t = self.sp_model.id_to_piece(i)
47 | if i == self.bos_id:
48 | t = '\n\n'
49 | elif i == self.eos_id:
50 | t = '\n\n'
51 | elif len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
52 | t = chr(int(t[3:5], 16)) # e.g. make '<0x01>' into '\x01'
53 | t = t.replace('▁', ' ') # sentencepiece uses this as the whitespace
54 |
55 | tokens.append(t)
56 |
57 | with open(TOKENIZER_BIN, 'wb') as f:
58 | for token in tokens:
59 | bytes = token.encode('utf-8')
60 | f.write((len(bytes)).to_bytes(4, 'little')) # write length of bytes
61 | f.write(bytes) # write token bytes
62 |
63 | if __name__ == "__main__":
64 | t = Tokenizer()
65 | t.export()
66 |
--------------------------------------------------------------------------------
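For clarity, tokenizer.bin as written by export() above is a flat sequence of records, one per vocabulary entry: a little-endian int32 byte length followed by that many UTF-8 bytes. This is the same layout the C code earlier in run.c reads back with fread. A small illustrative reader in Python (not a repo file):

import struct

vocab = []
with open("tokenizer.bin", "rb") as f:
    while True:
        header = f.read(4)
        if not header:
            break
        (n,) = struct.unpack("<i", header)   # record length in bytes
        vocab.append(f.read(n).decode("utf-8"))
# len(vocab) should equal the model's vocab_size (32000 for the Llama-2 tokenizer)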
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | This training script can be run on a single gpu in debug mode,
3 | or as a larger training run with distributed data parallel (ddp).
4 |
5 | To run on a single GPU small debug run, example:
6 | $ python train.py --compile=False --eval_iters=10 --batch_size=8
7 |
8 | To run with DDP on 4 gpus on 1 node, example:
9 | $ torchrun --standalone --nproc_per_node=4 train.py
10 |
11 | To run with DDP across 2 nodes with 8 gpus each, example:
12 | - Run on the first (master) node with example IP 123.456.123.456:
13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14 | - Run on the worker node:
15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
17 | """
18 |
19 | import math
20 | import os
21 | import time
22 | from contextlib import nullcontext
23 | from datetime import datetime
24 | from functools import partial
25 |
26 | import torch
27 | from model import Transformer, ModelArgs
28 | from torch.distributed import destroy_process_group, init_process_group
29 | from torch.nn.parallel import DistributedDataParallel as DDP
30 |
31 | from tinystories import Task
32 |
33 | # -----------------------------------------------------------------------------
34 | # I/O
35 | out_dir = "out"
36 | eval_interval = 2000
37 | log_interval = 1
38 | eval_iters = 100
39 | eval_only = False # if True, script exits right after the first eval
40 | always_save_checkpoint = False # if True, always save a checkpoint after each eval
41 | init_from = "scratch" # 'scratch' or 'resume'
42 | # wandb logging
43 | wandb_log = False # disabled by default
44 | wandb_project = "llamac"
45 | wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
46 | # data
47 | batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
48 | max_seq_len = 256
49 | # model
50 | dim = 288
51 | n_layers = 6
52 | n_heads = 6
53 | multiple_of = 32
54 | dropout = 0.0
55 | # adamw optimizer
56 | gradient_accumulation_steps = 4 # used to simulate larger batch sizes
57 | learning_rate = 5e-4 # max learning rate
58 | max_iters = 100000 # total number of training iterations
59 | weight_decay = 1e-1
60 | beta1 = 0.9
61 | beta2 = 0.95
62 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
63 | # learning rate decay settings
64 | decay_lr = True # whether to decay the learning rate
65 | warmup_iters = 1000 # how many steps to warm up for
66 | # system
67 | device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
68 | dtype = "bfloat16" # float32|bfloat16|float16
69 | compile = True # use PyTorch 2.0 to compile the model to be faster
70 | # -----------------------------------------------------------------------------
71 | config_keys = [
72 | k
73 | for k, v in globals().items()
74 | if not k.startswith("_") and isinstance(v, (int, float, bool, str))
75 | ]
76 | exec(open("configurator.py").read()) # overrides from command line or config file
77 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
78 | # -----------------------------------------------------------------------------
79 |
80 | # fixing some hyperparams to sensible defaults
81 | lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
82 | min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
83 |
84 | # various inits, derived attributes, I/O setup
85 | ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
86 | if ddp:
87 | init_process_group(backend="nccl")
88 | ddp_rank = int(os.environ["RANK"])
89 | ddp_local_rank = int(os.environ["LOCAL_RANK"])
90 | ddp_world_size = int(os.environ["WORLD_SIZE"])
91 | device = f"cuda:{ddp_local_rank}"
92 | torch.cuda.set_device(device)
93 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
94 | seed_offset = ddp_rank # each process gets a different seed
95 | # world_size number of processes will be training simultaneously, so we can scale
96 | # down the desired gradient accumulation iterations per process proportionally
97 | assert gradient_accumulation_steps % ddp_world_size == 0
98 | gradient_accumulation_steps //= ddp_world_size
99 | else:
100 | # if not ddp, we are running on a single gpu, and one process
101 | master_process = True
102 | seed_offset = 0
103 | ddp_world_size = 1
104 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
105 | if master_process:
106 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
107 | print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
108 |
109 | if master_process:
110 | os.makedirs(out_dir, exist_ok=True)
111 | torch.manual_seed(1337 + seed_offset)
112 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
113 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
114 | device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
115 | # note: float16 data type will automatically use a GradScaler
116 | ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
117 | ctx = (
118 | nullcontext()
119 | if device_type == "cpu"
120 | else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
121 | )
122 |
123 | # task-specific setup
124 | iter_batches = partial(
125 | Task.iter_batches,
126 | batch_size=batch_size,
127 | max_seq_len=max_seq_len,
128 | device=device,
129 | num_workers=0,
130 | )
131 |
132 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
133 | iter_num = 0
134 | best_val_loss = 1e9
135 |
136 | # model init
137 | model_args = dict(
138 | dim=dim,
139 | n_layers=n_layers,
140 | n_heads=n_heads,
141 | n_kv_heads=n_heads,
142 | vocab_size=32000,
143 | multiple_of=multiple_of,
144 | max_seq_len=max_seq_len,
145 | #dropout=dropout,
146 | ) # start with model_args from command line
147 | if init_from == "scratch":
148 | # init a new model from scratch
149 | print("Initializing a new model from scratch")
150 | gptconf = ModelArgs(**model_args)
151 | model = Transformer(gptconf)
152 | elif init_from == "resume":
153 | print(f"Resuming training from {out_dir}")
154 | # resume training from a checkpoint.
155 | ckpt_path = os.path.join(out_dir, "ckpt.pt")
156 | checkpoint = torch.load(ckpt_path, map_location=device)
157 | checkpoint_model_args = checkpoint["model_args"]
158 | # force these config attributes to be equal otherwise we can't even resume training
159 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
160 | for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
161 | model_args[k] = checkpoint_model_args[k]
162 | # create the model
163 | gptconf = ModelArgs(**model_args)
164 | model = Transformer(gptconf)
165 | state_dict = checkpoint["model"]
166 | # fix the keys of the state dictionary :(
167 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
168 | unwanted_prefix = "_orig_mod."
169 | for k, v in list(state_dict.items()):
170 | if k.startswith(unwanted_prefix):
171 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
172 | model.load_state_dict(state_dict)
173 | iter_num = checkpoint["iter_num"]
174 | best_val_loss = checkpoint["best_val_loss"]
175 | model.to(device)
176 |
177 | # initialize a GradScaler. If enabled=False scaler is a no-op
178 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
179 |
180 | # optimizer
181 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
182 | if init_from == "resume":
183 | optimizer.load_state_dict(checkpoint["optimizer"])
184 | checkpoint = None # free up memory
185 |
186 | # compile the model
187 | if compile:
188 | print("compiling the model... (takes a ~minute)")
189 | unoptimized_model = model
190 | model = torch.compile(model) # requires PyTorch 2.0
191 |
192 | # wrap model into DDP container
193 | if ddp:
194 | # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
195 | # construction time since NCCL does not support `ComplexFloat`
196 | prefix = "_orig_mod." if compile else ""
197 | model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
198 | model = DDP(model, device_ids=[ddp_local_rank])
199 |
200 | # helps estimate an arbitrarily accurate loss over either split using many batches
201 | @torch.no_grad()
202 | def estimate_loss():
203 | out = {}
204 | model.eval()
205 | for split in ["train", "val"]:
206 | batch_iter = iter_batches(split)
207 | losses = torch.zeros(eval_iters) # keep on CPU
208 | for k in range(eval_iters):
209 | X, Y = next(batch_iter)
210 | with ctx:
211 | logits, loss = model(X, Y)
212 | losses[k] = loss.item()
213 | out[split] = losses.mean()
214 | model.train()
215 | return out
216 |
217 | # learning rate decay scheduler (cosine with warmup)
218 | def get_lr(it):
219 | # 1) linear warmup for warmup_iters steps
220 | if it < warmup_iters:
221 | return learning_rate * it / warmup_iters
222 | # 2) if it > lr_decay_iters, return min learning rate
223 | if it > lr_decay_iters:
224 | return min_lr
225 | # 3) in between, use cosine decay down to min learning rate
226 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
227 | assert 0 <= decay_ratio <= 1
228 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
229 | return min_lr + coeff * (learning_rate - min_lr)
230 |
231 | # logging
232 | if wandb_log and master_process:
233 | import wandb
234 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
235 |
236 | # training loop
237 | train_batch_iter = iter_batches("train")
238 | X, Y = next(train_batch_iter) # fetch the very first batch
239 | t0 = time.time()
240 | local_iter_num = 0 # number of iterations in the lifetime of this process
241 | raw_model = model.module if ddp else model # unwrap DDP container if needed
242 | running_mfu = -1.0
243 | while True:
244 | # determine and set the learning rate for this iteration
245 | lr = get_lr(iter_num) if decay_lr else learning_rate
246 | for param_group in optimizer.param_groups:
247 | param_group["lr"] = lr
248 |
249 | # evaluate the loss on train/val sets and write checkpoints
250 | if iter_num % eval_interval == 0 and master_process:
251 | losses = estimate_loss()
252 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
253 | if wandb_log:
254 | try:
255 | wandb.log(
256 | {
257 | "iter": iter_num,
258 | "tokens": iter_num * tokens_per_iter,
259 | "loss/train": losses["train"],
260 | "loss/val": losses["val"],
261 | "lr": lr,
262 | "mfu": running_mfu * 100, # convert to percentage
263 | }
264 | )
265 | except Exception as e:
266 | print(f"logging to wandb failed: {e}")
267 | if losses["val"] < best_val_loss or always_save_checkpoint:
268 | best_val_loss = losses["val"]
269 | if iter_num > 0:
270 | checkpoint = {
271 | "model": raw_model.state_dict(),
272 | "optimizer": optimizer.state_dict(),
273 | "model_args": model_args,
274 | "iter_num": iter_num,
275 | "best_val_loss": best_val_loss,
276 | "config": config,
277 | }
278 | print(f"saving checkpoint to {out_dir}")
279 | torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
280 | raw_model.export(os.path.join(out_dir, "model.bin"))
281 | if iter_num == 0 and eval_only:
282 | break
283 |
284 | # forward backward update, with optional gradient accumulation to simulate larger batch size
285 | # and using the GradScaler if data type is float16
286 | for micro_step in range(gradient_accumulation_steps):
287 | if ddp:
288 | # in DDP training we only need to sync gradients at the last micro step.
289 | # the official way to do this is with model.no_sync() context manager, but
290 | # I really dislike that this bloats the code and forces us to repeat code
291 | # looking at the source of that context manager, it just toggles this variable
292 | model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
293 | with ctx:
294 | logits, loss = model(X, Y)
295 | loss = loss / gradient_accumulation_steps
296 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
297 | X, Y = next(train_batch_iter)
298 | # backward pass, with gradient scaling if training in fp16
299 | scaler.scale(loss).backward()
300 | # clip the gradient
301 | if grad_clip != 0.0:
302 | scaler.unscale_(optimizer)
303 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
304 | # step the optimizer and scaler if training in fp16
305 | scaler.step(optimizer)
306 | scaler.update()
307 | # flush the gradients as soon as we can, no need for this memory anymore
308 | optimizer.zero_grad(set_to_none=True)
309 |
310 | # timing and logging
311 | t1 = time.time()
312 | dt = t1 - t0
313 | t0 = t1
314 | if iter_num % log_interval == 0 and master_process:
315 | # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
316 | lossf = loss.item() * gradient_accumulation_steps
317 | if local_iter_num >= 5: # let the training loop settle a bit
318 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
319 | running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
320 | print(
321 | f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
322 | )
323 | iter_num += 1
324 | local_iter_num += 1
325 |
326 | # termination conditions
327 | if iter_num > max_iters:
328 | break
329 |
330 | if ddp:
331 | destroy_process_group()
332 |
--------------------------------------------------------------------------------