├── .gitignore
├── .gitmodules
├── .vscode
│   ├── launch.json
│   └── settings.json
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── main.py
├── pyproject.toml
├── requirements.txt
└── transformer.py

/.gitignore:
--------------------------------------------------------------------------------

.mypy_cache
wandb
xla-dumps

*.pyc

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "jaxutils"]
	path = jaxutils
	url = https://github.com/awf/awf-jaxutils
[submodule "awfutils"]
	path = awfutils
	url = https://github.com/awf/awfutils
[submodule "timer"]
	path = timer
	url = https://github.com/LucienShui/timer
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "main.py",
            "console": "integratedTerminal",
            "justMyCode": false
        }
    ]
}
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "editor.formatOnSave": true,
    "python.formatting.provider": "black",
    "jupyter.debugJustMyCode": false,
    "debugpy.debugJustMyCode": false,
}
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
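
## Setup

The Python dependencies include editable installs of the `timer`, `awfutils` and `jaxutils` git submodules (see `.gitmodules` and `requirements.txt`), so a fresh clone needs the submodules checked out first. A typical setup might be:
```sh
$ git submodule update --init --recursive
$ pip install -r requirements.txt
```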
## Running
```sh
$ export JAX_PLATFORM_NAME=gpu # or cpu
$ export JAX_LOG_COMPILES=1 # or 0
$ export XLA_FLAGS=--xla_dump_to=./xla-dumps/ # Also dumps jaxprs to this folder
$ python main.py -help
$ python main.py -layers 3 -dmodel 512 -heads 8 -dk 64 -dff 2048
```

Results at https://wandb.ai/awfidius/pure-transformer
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Andrew Fitzgibbon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
A fully functional (pun intended) implementation of a machine learning transformer model in Python/JAX. I do realize that 'pure functional' and 'Python' are not necessarily [mots qui vont très bien ensemble](https://forum.wordreference.com/threads/sont-les-mots-qui-vont-tr%C3%A8s-bien-ensemble.1832510/), but I'm sure you'll agree on reading the code that it has [una anima di pura programmazione funzionale](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). And a little [macaronica](https://en.wikipedia.org/wiki/Macaronic_language) appeals to the peasant soul. In other words, don't worry about the language...

Given only a few simple BLAS-like functions:
```python
def linear(params, x: jnp.ndarray):
    return x @ params.weight + params.bias[None,:]

def elementwise_linear(params, x: jnp.ndarray):
    return params.gain[None,:] * x + params.bias[None,:]

def standardize(x, eps = 1e-5):
    return (x - x.mean())/(x.std() + eps)
```
then the entire transformer forward computation is about 25 lines of code, not counting comments and shape annotations (excerpt from `transformer.py`):
```python
def transformer(cfg, params, x: Int[Array, "L"]):
    """
    cfg: Config, from transformer_init, holds hyperparameters
    params: Current transformer parameters, initialized in init
    x: 1D array of L integers, representing the input sequence
    output: L x n_vocab logits
    """
    L, = x.shape  # x is just 1D. Vmap/pmap will handle batching

    # Make shape checkers (https://github.com/awf/awfutils?tab=readme-ov-file#typecheck)
    LxL = lambda x: x.shape == (L, L)
    LxDk = lambda x: x.shape == (L, cfg.d_k)
    LxDff = lambda x: x.shape == (L, cfg.d_ff)
    LxDm = lambda x: x.shape == (L, cfg.d_model)

    # Create mask: 0 to attend, -Inf to ignore
    mask : LxL = jnp.log(jnp.tril(jnp.ones((L, L))))

    # Start with token embeddings
    embeddings : LxDm = cfg.lambda_e * params.embeddings[x, :]

    # Add (learned) positional encodings
    embeddings += cfg.lambda_pe * params.positional_encodings[:L, :]

    # Apply the transformer layers
    for layer in params.layers:

        # Layer-normalize embeddings
        t1 : LxDm = vmap(standardize)(embeddings)
        t1 : LxDm = t1 @ jnp.diag(layer.norm_self_attn)

        # Multi-head self-attention
        self_attns = []
        for head in layer.heads:

            # Project into this head's query/key space
            query : LxDk = t1 @ head.query
            key : LxDk = t1 @ head.key

            # Compute L x L attention matrix
            score : LxL = query @ key.T + mask
            attn : LxL = jax.nn.softmax(cfg.tau * score, axis=1)

            value : LxDk = t1 @ head.value
            self_attn : LxDk = attn @ value

            # Add this head's contribution to the list
            self_attns += [self_attn]  # [LxDk for #heads]

        embeddings += jnp.hstack(self_attns)

        # Layer-normalize embeddings
        t2 : LxDm = vmap(standardize)(embeddings)
        t2 : LxDm = t2 @ jnp.diag(layer.norm_ff)

        # Feedforward fully connected
        t2 : LxDff = t2 @ layer.ffn1
        t2 = jax.nn.relu(t2)
        t2 : LxDm = t2 @ layer.ffn2

        # Add this layer's contribution into embeddings
        embeddings += t2

    # Layer-normalize embeddings
    embeddings : LxDm = vmap(standardize)(embeddings)
    embeddings = embeddings @ jnp.diag(params.pre_output_norm)

    # And linearly project to output dimension
    return embeddings @ params.output  # L x n_vocab
```
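
Since `transformer` works on a single unbatched sequence (the `x is just 1D` comment above), batching is a one-liner with `vmap`, broadcasting `cfg` and `params` across the batch; this mirrors how `main.py` batches the loss. A minimal sketch, where `batch` stands for a hypothetical B x L array of token ids:
```python
batched_transformer = vmap(transformer, in_axes=(None, None, 0), out_axes=0)
# logits = batched_transformer(cfg, params, batch)  # B x L ints -> B x L x n_vocab logits
```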

The loss and its gradient need a few more lines:
```python
def crossentropy(output: jnp.ndarray, target: int):
    return -jax.nn.log_softmax(output)[target]

def seq_crossentropy(output: jnp.ndarray, targets: jnp.ndarray):
    return vmap(crossentropy)(output, targets).mean()

def transformer_loss(cfg, params, x):
    output = transformer(cfg, params, x)

    return seq_crossentropy(output[:-1], x[1:])

# Gradient wrt 'params'
grad_loss = jax.grad(transformer_loss, argnums=1)
```

The random initialization is also short:
```python
params = ParamsDict()

# Create embedding layer
rng, params.embeddings = rand(rng, jax.random.normal, (n_vocab, d_model))

# Positional encodings initialized to zeros
params.positional_encodings = jnp.zeros((max_len, d_model))

# For transformer layers
params.layers = []
for _ in range(n_layers):
    layer = ParamsDict()
    layer.norm_self_attn = jnp.ones(d_model)

    layer.heads = []
    for _ in range(n_heads):
        head = ParamsDict()
        rng, head.query = matrix_init_uniform(rng, d_model, d_k)
        rng, head.key = matrix_init_uniform(rng, d_model, d_k)
        rng, head.value = matrix_init_uniform(rng, d_model, d_k)

        layer.heads.append(head)

    layer.norm_ff = jnp.ones(d_model)

    rng, layer.ffn1 = matrix_init_uniform(rng, d_model, d_ff)
    rng, layer.ffn2 = matrix_init_uniform(rng, d_ff, d_model)

    params.layers.append(layer)

# Final normalization and output layer
params.pre_output_norm = jnp.ones(d_model)
rng, params.output = matrix_init_uniform(rng, d_model, n_vocab)
```

Add an optimizer, and we are pronto a romblare.
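
In `main.py` the optimizer is the Adam class from jaxutils, and a `-sgd` flag falls back to plain SGD, which is just one `jax.tree.map` over the parameter pytree. A minimal sketch of that update rule:
```python
def sgd_step(lr, params, grads):
    # params and grads are matching pytrees, so a single tree_map updates every leaf
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

# params = sgd_step(0.001, params, grad_loss(cfg, params, x))
```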

## Running
```sh
$ export JAX_PLATFORM_NAME=gpu # or cpu
$ export JAX_LOG_COMPILES=1 # or 0
$ export XLA_FLAGS=--xla_dump_to=./xla-dumps/ # Also dumps jaxprs to this folder
$ python main.py -help
$ python main.py -layers 3 -dmodel 512 -heads 8 -dk 64 -dff 2048
```

Results at https://wandb.ai/awfidius/pure-transformer

## Acknowledgements

The model is based on https://github.com/vpj/jax_transformer/blob/master/transformer.py, and the Adam and Dataset
classes in jaxutils are almost direct copies from https://github.com/vpj/jax_transformer.
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Pure-from-the-ground-up transformer, based on https://github.com/vpj/jax_transformer
"""

from transformer import *

import time
import re
import sys
import os
import logging

import jax
import jax.numpy as jnp
import numpy as np

from functools import partial
from itertools import islice


import wandb

from awfutils import Arg
from jaxutils.datasets import TinyShakespeare
from jaxutils.Adam import Adam
from jaxutils.show_jaxpr import show_jaxpr_and_xla, show_xla, show_jaxpr

jnp.set_printoptions(threshold=20, edgeitems=3, linewidth=2048, precision=3)
np.set_printoptions(threshold=20, edgeitems=3, linewidth=2048, precision=3)

# Noisily fail when arrays are the wrong size
jax.config.update("jax_numpy_rank_promotion", "raise")

jax.config.update(
    "jax_log_compiles", Arg("log-compiles", False, "Log JAX recompilations").peek()
)
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
logger = logging.getLogger("pure-transformer")
logger.setLevel(level=LOGLEVEL)
timer = timer.get_timer(logging.WARNING)
db = logger.debug


def tree_axpy(a, x, y):
    return jax.tree.map(lambda x, y: a * x + y, x, y)


def main():

    lr = Arg(flag="lr", doc="Learning rate", default=0.001)
    beta1 = Arg(flag="beta1", doc="Adam beta1", default=0.9)
    beta2 = Arg(flag="beta2", doc="Adam beta2", default=0.99)
    seq_len = Arg(flag="seq-len", doc="Sequence length", default=32)
    batch_size = Arg(flag="batch-size", doc="Batch size", default=128)
    epochs = Arg("epochs", 32)
    batches = Arg("batches", sys.maxsize, "Max batches")

    # Init the model params
    heads = Arg("heads", 8, "Number of attention heads")
    d_model = Arg("dmodel", 512, "Embedding dimension")
    d_k = Arg("dk", 64, "Attention head dimension")
    d_ff = Arg("dff", 512, "Feedforward layer dimension")
    n_layers = Arg("layers", 3, "Number of layers")

    save = Arg("save", "", "Save mode. Log run to wandb, lengthen epochs and batches")
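
    # Each `Arg` above declares a command-line flag (awfutils.Arg). Calling the
    # object, e.g. `heads()` or `save()`, returns the value parsed from the
    # command line, and `Arg.config()` below gathers the declared flags for
    # wandb logging.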

    if save():
        wandb.init(
            project="pure-transformer",
            entity="awfidius",
            name=save() if len(save()) else None,
            config=Arg.config(),
        )
    else:
        print("Quick mode, disabling wandb, using small prime sizes")
        wandb.init(mode="disabled")
        epochs.default = 5
        batches.default = 10
        # Sizes built from primes, to catch any shape mismatches
        d_model.default = 13 * 7
        d_k.default = 13
        heads.default = 7
        d_ff.default = 111

    start = time.time()

    # Create PRNG key
    rnd_key = jax.random.PRNGKey(42)

    # Create dataset
    dataset = TinyShakespeare(rnd_key, seq_len=seq_len(), batch_size=batch_size())
    tostr = lambda x: "".join([dataset.itos[i] for i in x]).replace("\n", "\\n")

    rnd_key, cfg, params = transformer_init(
        rnd_key,
        dataset.n_tokens,
        d_model=d_model(),
        n_layers=n_layers(),
        n_heads=heads(),
        d_k=d_k(),
        d_ff=d_ff(),
    )

    names = [k for (k, _) in params.items()]
    print(names)
    assert len(names) == len(jax.tree.flatten(params)[0])

    # gnorms_table = wandb.Table(columns=names)
    # wandb.log({"gnorms_table": gnorms_table})

    sizes = jax.tree.map(lambda v: np.prod(v.shape), params)
    sizes.print("sizes:")
    print("Total parameter count:", np.sum(jax.tree.flatten(sizes)[0]))
    # sizes_table = wandb.Table(columns=['param','size'])

    @partial(jax.jit, static_argnums=0)
    def loss_batch(cfg, params, seq):
        batched = vmap(transformer_loss, in_axes=(None, None, 0), out_axes=0)
        return jnp.mean(batched(cfg, params, seq))

    # show_jaxpr(get_loss_batch, (params, *islice(dataset,1)))
    grad_loss_batch_unjit = jax.grad(loss_batch, argnums=1)
    grad_loss_batch = jax.jit(grad_loss_batch_unjit, static_argnums=0)

    value_and_grad_loss_batch_unjit = jax.value_and_grad(loss_batch, argnums=1)
    value_and_grad_loss_batch = jax.jit(
        value_and_grad_loss_batch_unjit, static_argnums=0
    )
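
    # Note that static_argnums=0 makes cfg a static (hashable) argument to jit:
    # the cfg-dependent shape checks inside `transformer` run at trace time,
    # and a new cfg value triggers a recompile (visible with the -log-compiles
    # flag / JAX_LOG_COMPILES=1).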

    matches = re.search("--xla_dump_to=([^ ]+)", os.environ.get("XLA_FLAGS") or "")
    if matches:
        fn = matches[1] + "/grad_loss_batch.jaxpr.py"
        with open(fn, "w") as file:
            # xla = jax.xla_computation(loss_batch, static_argnums=0)(cfg, params, *islice(dataset,1))
            # print("XLA=", xla.as_hlo_text())
            show_jaxpr(
                grad_loss_batch,
                (cfg, params, *islice(dataset, 1)),
                file=file,
                static_argnums=0,
            )
        print("Saved jaxpr to", fn)

    sgd = Arg("sgd", False, "Pure sgd")

    test_data = jnp.hstack([*islice(dataset, 4)])
    # grad_loss_batch = jax.pjit(grad_loss_batch_unjit, static_argnums=0)

    optimizer = Adam(params, lr=lr(), betas=(beta1(), beta2()))

    gnorms_all = np.zeros((len(names), 0))
    for epoch in range(epochs()):

        if epoch == 0:
            # epoch zero is straight through to sample
            pass
        else:
            # Iterate through batches
            for i, data in enumerate(islice(dataset, batches())):
                # Get loss and gradients
                loss, grads = value_and_grad_loss_batch(cfg, params, data)

                # print(f"{wandb.run.name} loss {loss.item()} data {tostr(data[0])}")
                total_time = time.time() - start

                wandb.log({"time": total_time, "batch": i, "loss": loss})

                # Update parameters
                if sgd():
                    params = tree_axpy(-lr(), grads, params)
                else:
                    params = optimizer.step(params, grads)

        # Log a sample after each epoch
        test_loss = loss_batch(cfg, params, test_data)

        print(f"E{epoch} test={test_loss:.3f}")
        if epoch % 25 == 0:
            with timer("sample"):
                prompt = [dataset.stoi[c] for c in "Au"]
                sampled = transformer_sample(
                    cfg, params, jnp.array(prompt), length=20 + epoch
                )
                print(f"Sample [{tostr(prompt)}|{tostr(sampled[len(prompt):])}]")

    # Grab current time after running the code
    end = time.time()
    total_time = end - start
    print("TIME: " + str(total_time))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.black]
line-length = 88
include = '\.pyi?$'
extend-exclude = '''
^/(timer|jaxutils)/
'''

# pyproject.toml
[tool.pytest.ini_options]
minversion = "6.0"
# addopts = "-ra -q"
testpaths = [
    ".",
    "jaxutils",
]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
wandb
jax[cuda]==0.5
jaxtyping
torch # until awfutils moves to optree
-e timer/.
-e awfutils/.
-e jaxutils/.
--------------------------------------------------------------------------------
/transformer.py:
--------------------------------------------------------------------------------
"""
Pure-from-the-ground-up transformer, based on https://github.com/vpj/jax_transformer/blob/master/transformer.py

"""

from timer import timer

import jax
from jax import vmap
import jax.numpy as jnp

from jaxtyping import Array, Int

from functools import partial

from awfutils import Arg, typecheck
from jaxutils.ParamsDict import ParamsDict


def rand(rng, f, shape, **kwargs):
    """
    Wrap a jax.random.foo function to split the incoming rng, and return the new rng beside the payload

    rng = ... from previous code ...

    rng, vals1 = rand(rng, jax.random.uniform, (9,3), minval=-2.0, maxval=2.0)
    # ^-- rng is now newly split
    rng, vals2 = rand(rng, jax.random.normal, (3,9))
    # ^-- rng is split again
    """
    rng, rng1 = jax.random.split(rng)
    return rng, f(rng1, shape, **kwargs)


def matrix_init_uniform(rng: jax.random.PRNGKey, in_features: int, out_features: int):
    """
    Initialize a matrix with uniform weights, scaled by 1/sqrt(in_features)
    """
    rnd_range = 1 / in_features**0.5
    return rand(
        rng,
        jax.random.uniform,
        (in_features, out_features),
        minval=-rnd_range,
        maxval=rnd_range,
    )


# Layer norm
def elementwise_linear_init_identity(shape):
    """
    Initialize an elementwise_linear layer with unit gain, zero bias
    """
    return ParamsDict(gain=jnp.ones(shape), bias=jnp.zeros(shape))


def linear(params, x: jnp.ndarray):
    return x @ params.weight + params.bias[None, :]


def elementwise_linear(params, x: jnp.ndarray):
    return params.gain[None, :] * x + params.bias[None, :]


def standardize(x, eps=1e-5):
    return (x - x.mean()) / (x.std() + eps)


flip_pe_coef = Arg("flip-pe", False, "Scale token embedding, not position embedding")


def transformer_init(
    rng: jax.random.PRNGKey,
    n_vocab: int,
    d_model: int,
    n_layers: int,
    n_heads: int,
    d_k: int,
    d_ff: int,
    max_len=4096,
):
    assert d_k * n_heads == d_model

    # Build config struct for call
    config = ParamsDict()
    config.d_model = d_model
    config.d_ff = d_ff
    config.d_k = d_k
    config.heads = n_heads
    if flip_pe_coef():
        # Flag set: scale the token embedding, not the position embedding
        config.lambda_e = d_model**-0.5
        config.lambda_pe = 1.0
    else:
        # Default: scale the position embedding, not the token embedding
        config.lambda_e = 1.0
        config.lambda_pe = d_model**-0.5
    config.tau = 1 / d_k**0.5

    # Build initializers for params
    params = ParamsDict()

    # Create embedding layer
    rng, params.embeddings = rand(rng, jax.random.normal, (n_vocab, d_model))

    # Positional encodings initialized to zeros
    params.positional_encodings = jnp.zeros((max_len, d_model))

    # For transformer layers
    params.layers = []
    for _ in range(n_layers):
        layer = ParamsDict()
        layer.norm_self_attn = jnp.ones(d_model)

        layer.heads = []
        for _ in range(n_heads):
            head = ParamsDict()
            rng, head.query = matrix_init_uniform(rng, d_model, d_k)
            rng, head.key = matrix_init_uniform(rng, d_model, d_k)
            rng, head.value = matrix_init_uniform(rng, d_model, d_k)

            layer.heads.append(head)

        layer.norm_ff = jnp.ones(d_model)

        rng, layer.ffn1 = matrix_init_uniform(rng, d_model, d_ff)
        rng, layer.ffn2 = matrix_init_uniform(rng, d_ff, d_model)

        params.layers.append(layer)

    # Final normalization and output layer
    params.pre_output_norm = jnp.ones(d_model)
    rng, params.output = matrix_init_uniform(rng, d_model, n_vocab)

    return rng, config, params
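
# Note: the `params` returned above is an ordinary pytree (nested ParamsDict
# objects and Python lists of jnp arrays), so generic JAX utilities apply to it
# directly: main.py differentiates with respect to it via jax.grad(..., argnums=1)
# and walks it with jax.tree.map to report per-parameter sizes. As an
# illustration (not in the repo), the total parameter count is
#     sum(int(p.size) for p in jax.tree.flatten(params)[0])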

# Format off for the size annotations
# fmt: off
@partial(jax.jit, static_argnums=0)
@typecheck
def transformer(cfg, params, x: Int[Array, "L"]):
    """
    cfg: Config, from transformer_init, holds hyperparameters
    params: Current transformer parameters, initialized in init
    x: 1D array of L integers, representing the input sequence
    output: L x n_vocab logits

    Obviously, this is just one example of a transformer. There
    are many variations, depending on where the normalizations go,
    whether or not there is bias, what kinds of position
    encodings, etc.
    """
    print("Compiling for L=", x.shape)

    L, = x.shape  # x is just 1D. Vmap/pmap will handle batching

    # Make shape checkers for awfutils.typecheck
    LxL = lambda x: x.shape == (L, L)
    LxDk = lambda x: x.shape == (L, cfg.d_k)
    LxDff = lambda x: x.shape == (L, cfg.d_ff)
    LxDm = lambda x: x.shape == (L, cfg.d_model)

    # Create mask: 0 to attend, -Inf to ignore
    mask : LxL = jnp.log(jnp.tril(jnp.ones((L, L))))

    # Start with token embeddings
    embeddings : LxDm = cfg.lambda_e * params.embeddings[x, :]

    # Add (learned) positional encodings
    embeddings += cfg.lambda_pe * params.positional_encodings[:L, :]

    # Apply the transformer layers
    for layer in params.layers:

        # Layer-normalize embeddings
        t1 : LxDm = vmap(standardize)(embeddings)
        t1 : LxDm = t1 @ jnp.diag(layer.norm_self_attn)

        # Multi-head self-attention
        self_attns = []
        for head in layer.heads:

            # Project into this head's query/key space
            query : LxDk = t1 @ head.query
            key : LxDk = t1 @ head.key

            # Compute L x L attention matrix
            score : LxL = query @ key.T + mask
            attn : LxL = jax.nn.softmax(cfg.tau * score, axis=1)

            value : LxDk = t1 @ head.value
            self_attn : LxDk = attn @ value

            # Add this head's contribution to the list
            self_attns += [self_attn]  # [LxDk for #heads]

        embeddings += jnp.hstack(self_attns)

        # Layer-normalize embeddings
        t2 : LxDm = vmap(standardize)(embeddings)
        t2 : LxDm = t2 @ jnp.diag(layer.norm_ff)

        # Feedforward fully connected
        t2 : LxDff = t2 @ layer.ffn1
        t2 = jax.nn.relu(t2)
        t2 : LxDm = t2 @ layer.ffn2

        # Add this layer's contribution into embeddings
        embeddings += t2

    # Layer-normalize embeddings
    embeddings : LxDm = vmap(standardize)(embeddings)
    embeddings = embeddings @ jnp.diag(params.pre_output_norm)

    # And linearly project to output dimension
    return embeddings @ params.output  # L x n_vocab
# fmt: on


def crossentropy(output: jnp.ndarray, target: int):
    return -jax.nn.log_softmax(output)[target]


def seq_crossentropy(output: jnp.ndarray, targets: jnp.ndarray):
    return vmap(crossentropy)(output, targets).mean()


def transformer_loss(cfg, params, x):
    """
    Transformer loss for one example

    cfg: Config, from init
    params: Current transformer parameters, initialized in init
    x: 1D array of integers, representing the input sequence
    """
    output = transformer(cfg, params, x)

    return seq_crossentropy(output[:-1], x[1:])


# We don't jit this, as the loop will unroll, and take a long time to compile
def transformer_sample(cfg, params, seq: jnp.ndarray, length: int = 20):

    for _i in range(length):
        output = transformer(cfg, params, seq)

        idx = jnp.argmax(output[-1])

        seq = jnp.concatenate((seq, idx[None]))

    return seq
--------------------------------------------------------------------------------
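
`transformer_sample` decodes greedily (argmax at each step). A possible variation, not part of the repo, is temperature-controlled stochastic sampling; a minimal sketch reusing the same pieces:
```python
def transformer_sample_stochastic(cfg, params, seq, rng, length: int = 20, temp: float = 1.0):
    # Hypothetical variant of transformer_sample: draw each next token from the
    # softmax distribution over the last position's logits instead of argmax.
    for _ in range(length):
        output = transformer(cfg, params, seq)
        rng, rng1 = jax.random.split(rng)
        idx = jax.random.categorical(rng1, output[-1] / temp)
        seq = jnp.concatenate((seq, idx[None]))
    return seq
```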