├── .gitignore
├── README.md
├── dataset.py
├── generate.py
├── imgs
│   └── init-vs-no-init.png
├── llama.py
├── notebooks
│   ├── dataset-dev.ipynb
│   ├── llama-dev.ipynb
│   ├── mixture-of-depths-dev.ipynb
│   ├── tokenizer-dev.ipynb
│   └── train-step-dev.ipynb
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | wikitext_data/
2 | wandb/
3 | __pycache__/
4 | .ipynb_checkpoints/
5 | .venv/
6 | tokenizer.model
7 | *.safetensors
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mini-LLaMA MLX
2 | 
3 | A simple implementation of LLaMA 2 that you can run experiments with on your MacBook.
4 | 
5 | 
6 | ## Setup
7 | 
8 | - Download the LLaMA 2 tokenizer from [Meta's website](https://llama.meta.com/llama-downloads/)
9 | - Install the required packages with `pip install -r requirements.txt`
10 | - To train the model, run `python train.py`. The training configuration is in the `TrainerConfig` class.
11 | - To generate text with a trained model, run `python generate.py <model_weights_path> "<prompt>"`.
12 | 
13 | 
14 | ## Usage
15 | 
16 | The default parameters in `train.py` are tuned so that the model trains for about 30 minutes on an M1 MacBook Air with 16 GB of RAM.
17 | Here's how to tune them for your own machine:
18 | 
19 | 1. Measure the time per iteration at the largest batch size your machine can fit without the iteration time growing linearly.
20 |    For me it's batch size 16 at ~1 sec / iteration
21 | 1. Divide 128 by the batch size to get your gradient accumulation steps `grad_acc_steps`. For me it's `128 / 16 = 8`
22 | 1. Divide 1800 by (the time per iteration * `grad_acc_steps`) to get `n_update_steps`. For me it's `1800 / (1 * 8) = 225`
23 | 1. Increase `n_epochs` when the model cannot learn well
24 | 
25 | 
26 | ## Why Another Implementation
27 | 
28 | I decided to write another implementation for a couple of reasons:
29 | 
30 | - Try out MLX: MLX is a new framework with clean APIs that combines the good features of JAX and PyTorch.
31 | - Educational: I want to create an example with more modern architectures and techniques. I will also explain the code in detail.
32 | - Lightweight: I want to create something that enables people with limited compute to experiment with language models.
33 | - Simple and hackable: I limit the supported features to reduce code complexity, and the code being simple makes it easy to hack.
34 | 
35 | 
36 | ## Dataset
37 | 
38 | Inspired by [LLM-baselines](https://github.com/epfml/llm-baselines),
39 | I chose [WikiText](https://huggingface.co/datasets/wikitext) as the dataset.
40 | WikiText is small (~100M tokens) but large enough to be interesting.
41 | 
42 | ### Data Preprocessing
43 | 
44 | The WikiText dataset is a collection of articles, and the version hosted on Hugging Face Datasets splits them into sentences.
45 | The largest context-independent body of text is an article,
46 | so I decided to concatenate the sentences of each article and tokenize the result with BOS and EOS tokens added.
47 | This turned out to be surprisingly difficult to get right.
48 | 
49 | The title of each article is formatted as `= Title =`, and section headers as `= = section title = =` or `= = = subsection title = = =`.
50 | Initially, I applied a heuristic that classifies a sequence as a title if it starts with exactly one `=`,
51 | but it turns out game stats in sports articles are formatted the same way as titles: `= Win ; D =`.
52 | As I added more and more heuristics, I found more and more edge cases.
53 | 
54 | After further inspection of the data, I found that every title is preceded by 2 empty sentences.
55 | Checking whether a title-like sentence is preceded by 2 empty sentences did the trick.
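As an illustration, here is a condensed sketch of that rule; it mirrors `create_wikitext_dataset` in `dataset.py` below, and the function names are chosen just for this example:

```python
# A sentence is treated as a title only if it looks like `= Title =`
# AND the two preceding sentences in the corpus are empty strings.
def looks_like_title(text):
    return text[:3] == ' = ' and text[-4:] == ' = \n' and text[3].isupper()

def split_into_articles(corpus):
    articles, article = [], corpus[1]  # corpus[0] is an empty string
    for i, text in enumerate(corpus[2:], start=2):
        if corpus[i-1] == corpus[i-2] == '' and looks_like_title(text):
            articles.append(article)  # A new title closes the previous article
            article = text
        else:
            article += text
    articles.append(article)  # Don't forget the last article
    return articles
```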
56 | 
57 | I learned 2 lessons:
58 | - It's difficult to preprocess a dataset perfectly, especially as the dataset size scales up.
59 | - It's important to inspect the data.
60 | 
61 | Data correctness cannot be overlooked because data quality sets the upper bound of a model's performance.
62 | 
63 | 
64 | ### Creating a Data Batch
65 | 
66 | The goal is to tokenize the list of sequences and split them into blocks (inputs + labels) for the model to ingest.
67 | Concretely, the block size is `blk_size = seq_len + 1`, where `seq_len` is the input length and the extra `1` is for the autoregressive label.
68 | 
69 | A sequence's length is usually not a multiple of `blk_size`, so we have to do something with the tokens that cannot form a full block.
70 | Here I chose to pad each sequence to a multiple of `blk_size` instead of truncating it, to maximize data usage.
71 | 
72 | After padding a sequence to a multiple of `blk_size`, we reshape it into a batch of rows of length `blk_size`.
73 | In the end, the whole dataset is processed into one large batch of rows of length `blk_size`.
74 | 
75 | To iterate over the data, we index the first axis in increments of `bsz`, where `bsz` is the batch size.
76 | We count the number of steps independently from the iteration index
77 | because `n_steps` may be larger or smaller than the total number of batches in the dataset.
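Here is a condensed sketch of the padding, reshaping, and iteration steps; it mirrors `prepare_dataset` and `config_dataloader` in `dataset.py`, and the toy token ids and sizes are made up:

```python
import mlx.core as mx

seq_len, pad_token_id, bsz = 4, 0, 2
blk_size = seq_len + 1  # seq_len input tokens + 1 autoregressive label

token_seq = mx.array([1, 7, 7, 3, 9, 5, 2])      # 7 tokens from one tokenized article
pad_len = blk_size - token_seq.size % blk_size   # pad up to a multiple of blk_size
token_seq = mx.pad(token_seq, [0, pad_len], constant_values=pad_token_id)
blocks = token_seq.reshape(-1, blk_size)         # shape: [number of blocks, blk_size]

# Iterate by indexing the first axis in increments of bsz,
# splitting each block into inputs and labels.
for idx in range(0, blocks.shape[0], bsz):
    bblk = blocks[idx:idx+bsz, :]
    inputs, labels = bblk[:, :-1], bblk[:, 1:]
```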
78 | 
79 | 
80 | ### Epochs
81 | 
82 | I realized that it is difficult for a small model to learn well with little data.
83 | Due to limited compute, I decided to make the model overfit on a small amount of data by repeating it.
84 | 
85 | `n_epochs` specifies how many times each batch is used to train the model.
86 | This hyperparameter changes the dataset size so that the total amount of compute stays the same.
87 | Specifically, `n_seqs = (bsz * n_steps) // n_epochs`,
88 | where `bsz * n_steps` is the total number of sequences the model sees throughout the whole training run.
89 | 
90 | 
91 | ## Model
92 | 
93 | The model is the flashiest part of the code.
94 | The LLaMA implemented here is unfortunately LLaMA 2 without grouped-query attention (GQA),
95 | but the architecture would be the same as LLaMA 3 if we replaced the attention with GQA.
96 | 
97 | 
98 | ### Notation
99 | 
100 | I adopted the [shape suffix notation](https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd).
101 | Every tensor variable is suffixed with its dimensions.
102 | 
103 | | Name | Description |
104 | | :-: | :- |
105 | | V | Vocabulary size |
106 | | D | Embedding dimension |
107 | | N | Number of attention heads |
108 | | T | Sequence length |
109 | | H | Head embedding size |
110 | | C | Hidden dimension for feed forward nets |
111 | | B | Batch size |
112 | 
113 | 
114 | ### Self Attention
115 | 
116 | There are many tutorials on this topic, so I'll be brief here.
117 | Self attention does the following:
118 | 
119 | 1. Project the input embedding into 3 attention matrices: Q, K, and V
120 | 1. Split the embedding dimension into multiple heads
121 | 1. Apply rotary positional embeddings to the query and key matrices
122 | 1. Perform scaled dot product attention
123 | 1. Flatten the head dimension of the output embedding and project it to the output
124 | 
125 | A few things I would like to note about my implementation:
126 | 
127 | - Projecting QKV with one large linear layer is inspired by
128 |   [Andrej Karpathy's NanoGPT](https://github.com/karpathy/nanoGPT/blob/325be85d9be8c81b436728a420e85796c57dba7e/model.py#L56).
129 |   It uses the property that in $y = Wx$ each row of $W$ is independent, so the 3 projections can be concatenated into one.
130 | - Splitting attention heads is basically moving N (the heads dimension) to an outer axis so the heads can be computed in parallel.
131 | 
132 | #### Causal Masking
133 | 
134 | An attention layer generates predictions for **every** token.
135 | Because the model is learning to predict tokens based on previous ones (causal language modeling),
136 | for every token `i`, we need to stop the model from attending (giving a relation score) to tokens `i + 1` to `seq_len`.
137 | We achieve this by masking out the attention scores of those tokens (before the softmax) in the scaled dot product attention with `mask_TT`.
138 | 
139 | `mask_TT` is passed in as an argument because `T` can change at inference time.
140 | 
141 | 
142 | ### Feed Forward Network
143 | 
144 | A feed forward network consists of 3 linear projections (gate, up, and down) combined with a SwiGLU activation.
145 | 
146 | #### Hidden Dimension
147 | 
148 | Following the LLaMA paper, I set the hidden dimension to $4d \times \frac{2}{3}$, where $d$ is the input dimension.
149 | In [Meta's implementation](https://github.com/meta-llama/llama/blob/b8348da38fde8644ef00a56596efb376f86838d1/llama/model.py#L332),
150 | they additionally rounded the number up to a multiple of 256.
151 | 
152 | #### SwiGLU
153 | 
154 | SwiGLU combines several earlier components.
155 | 
156 | GLU ([Gated Linear Units](https://paperswithcode.com/method/glu)) proposes $GLU(a, b) = \sigma(a) \otimes b$.
157 | 
158 | [Swish-1](https://paperswithcode.com/method/swish), aka SiLU (Sigmoid-weighted Linear Unit),
159 | is the activation function $Swish_1(x) = x \cdot \sigma(x)$.
160 | 
161 | [SwiGLU](https://paperswithcode.com/method/swiglu) combines the two: $SwiGLU(x) = Swish_1(x W_1) \otimes (x W_2)$.
162 | 
163 | 
164 | ### Model
165 | 
166 | One thing I would like to talk about is how Transformers make predictions during training.
167 | Unlike during inference, Transformers make `seq_len` predictions at once, where every prediction `i` is based on tokens `0` to `i - 1`.
168 | This is why we need `seq_len + 1` tokens for creating labels: the last input token needs a label as well.
169 | 
170 | Every prediction is independent, so we can calculate the loss for every prediction
171 | and average across both the batch and the sequence dimensions.
172 | 
173 | We also have to remove predictions for pad tokens.
174 | We do this with a mask, zeroing out the predictions whose inputs are pad tokens.
175 | We can also use the mask to count the number of non-pad-token predictions, which is useful for averaging the loss.
176 | 
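Here is a minimal sketch of such a masked loss; it illustrates the idea rather than copying `train.py`, and `masked_loss` is a hypothetical helper:

```python
import mlx.core as mx
from mlx import nn

def masked_loss(model, inputs_BT, labels_BT, pad_token_id):
    logits_BTV = model(inputs_BT)
    loss_BT = nn.losses.cross_entropy(logits_BTV, labels_BT, reduction='none')

    # Zero out predictions whose input tokens are pad tokens,
    # then average only over the non-pad predictions.
    mask_BT = (inputs_BT != pad_token_id).astype(loss_BT.dtype)
    return (loss_BT * mask_BT).sum() / mask_BT.sum()
```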
177 | 
178 | ## Training
179 | 
180 | Now we get to train some models and stare at the loss plot.
181 | Here is my [W&B project](https://wandb.ai/kimbochen/mini-llama-mlx) for reference.
182 | 
183 | 
184 | ### Weight Initialization
185 | 
186 | Weight initialization is important: it dictates how smoothly your model starts training.
187 | 
188 | ![](imgs/init-vs-no-init.png)
189 | 
190 | > With the weight initialization scheme (light blue) vs. without it (green)
191 | 
192 | - [Pythia](https://arxiv.org/abs/2304.01373) and [GPT-NeoX-20B](https://arxiv.org/abs/2204.06745) both cited GPT-J-6B for the initialization method,
193 |   but I cannot find it in the Mesh Transformer repo. However, the GPT-NeoX repo provides implementations
194 |   [here](https://github.com/EleutherAI/gpt-neox/blob/01657aa243aed07701660a2dd486434349daa72e/megatron/model/init_functions.py).
195 | - I adopted the initialization of [PaLM](https://arxiv.org/abs/2204.02311v5) because its model architecture is more similar.
196 | 
197 | The weight initialization scheme is as follows:
198 | 
199 | - Linear layers are initialized with a normal distribution with mean 0 and std $\displaystyle \frac{1}{\sqrt{n}}$,
200 |   where $n$ is the input dimension
201 | - The input embedding layer is initialized with a standard normal distribution (mean 0, std 1)
202 | - RMSNorm weights are initialized to the constant 1
203 | 
204 | 
205 | ### Learning Rate Scheduling
206 | 
207 | I implemented a standard linear warmup + cosine decay learning rate schedule.
208 | I use 10% - 15% of the total update steps for warmup.
209 | 
210 | 
211 | ### Gradient Accumulation
212 | 
213 | I decided to implement gradient accumulation to improve the gradient quality, but it makes the code a little more complex.
214 | Here I clarify some terminology:
215 | 
216 | | Term | Definition | Formula |
217 | | :- | :- | :- |
218 | | `n_update_steps` | The number of gradient update steps | |
219 | | `grad_acc_steps` | The number of steps we accumulate gradients over before each update | |
220 | | `n_steps` | The total number of steps | `n_steps = n_update_steps * grad_acc_steps` |
221 | | `warmup_ratio` | The ratio of warmup steps to update steps | |
222 | | `warmup_steps` | The number of warmup steps | `warmup_steps = n_update_steps * warmup_ratio` |
223 | 
224 | The command line logs every step, while W&B only logs every update step, which makes the learning rate and loss curves easier to interpret.
225 | 
226 | 
227 | ## Generation
228 | 
229 | This code focuses on training, so I did not do anything fancy for generation.
230 | In terms of sampling, I select the token with the greatest probability (greedy sampling).
231 | 
232 | Unlike training, where the cross entropy loss trains the model to select the label token,
233 | at generation time we can instead sample a token from the probability distribution the model predicts,
234 | meaning that we might select the second or third choice token with some probability.
235 | 
236 | One way to modulate the probability distribution is to set the temperature[^1].
237 | Temperature is a factor that we divide the logits by to sharpen or flatten the probability distribution.
238 | I personally find the formulation indirect[^2], so here's the takeaway:
239 | **The larger the temperature, the more chaotic the prediction (higher chance of selecting low-probability tokens).**
240 | A special case: when the temperature is 0, sampling reduces to greedy sampling.
241 | 
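Here is a minimal sketch of temperature sampling (`generate.py` in this repo only does greedy sampling; `temperature_sample` is a hypothetical helper):

```python
import mlx.core as mx

def temperature_sample(logits_V, temperature=1.0):
    # Treat temperature 0 as greedy sampling, since we cannot divide by zero.
    if temperature == 0.0:
        return mx.argmax(logits_V, axis=-1)
    # Dividing by a larger temperature flattens the distribution,
    # giving low-probability tokens a higher chance of being selected.
    return mx.random.categorical(logits_V / temperature, axis=-1)
```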
242 | - [This website](https://lukesalamone.github.io/posts/what-is-temperature/) has nice interactive plots for temperature
243 | - [This blog from Allen NLP](https://blog.allenai.org/a-guide-to-language-model-sampling-in-allennlp-3b1239274bc3) explains
244 |   more advanced techniques such as top-k and top-p sampling
245 | 
246 | 
247 | [^1]: The terminology is inspired by the [Boltzmann distribution](https://en.wikipedia.org/wiki/Boltzmann_distribution),
248 |   but I don't find it helpful for gaining intuition.
249 | [^2]: Temperature T is in (0.0, 1.0]; we **divide** the logits by T,
250 |   so higher T --> **smaller** logits --> flatter distribution --> more chaotic predictions.
251 |   Since we are dividing, T cannot be zero, so we have to treat 0 as a special case when implementing.
252 | 
253 | 
254 | ## Future Goals
255 | 
256 | Here are some things I would do if I had more time.
257 | I might do them in PyTorch in the future because my MacBook Air is too compute-constrained to develop on.
258 | 
259 | - Inference: There are many interesting inference optimizations, including the KV cache, speculative decoding, and Medusa
260 | - Efficient training: I would love to test out small-scale mixture-of-experts and mixture-of-depths models
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | 
4 | import mlx.core as mx
5 | from datasets import load_dataset
6 | from joblib import Memory
7 | from sentencepiece import SentencePieceProcessor
8 | from tqdm import tqdm
9 | 
10 | 
11 | logging.basicConfig(level=logging.INFO)
12 | memory = Memory('.data_cache/', verbose=0)
13 | 
14 | 
15 | def create_wikitext_dataset(split):
16 |     '''
17 |     Preprocess the WikiText dataset.
18 |     Output:
19 |         text_seqs : List[str] : Processed text sequences.
20 |     '''
21 |     logging.info(f'Loading WikiText {split} dataset ...')
22 |     dataset = load_dataset('wikitext', 'wikitext-103-v1')
23 |     corpus = dataset[split]['text']
24 | 
25 |     is_title = lambda text: text[:3] == ' = ' and text[-4:] == ' = \n' and text[3].isupper()
26 |     text_seqs = []
27 |     text_seq = corpus[1]  # corpus[0] is an empty string
28 | 
29 |     for i, text in enumerate(corpus[2:], start=2):
30 |         if (corpus[i-1] == corpus[i-2] == '') and is_title(text):
31 |             text_seqs.append(text_seq)  # Store the current text sequence when a new title is found
32 |             text_seq = text
33 |         else:
34 |             text_seq += text
35 |     else:
36 |         text_seqs.append(text_seq)  # The last text sequence
37 | 
38 |     return text_seqs
39 | 
40 | 
41 | @memory.cache
42 | def prepare_dataset(split, seq_len, pad_token_id):
43 |     '''
44 |     Creates the dataset, tokenizes the sequences, and splits the token sequences to the specified length.
45 | Output: 46 | token_seqs : mx.array [Number of batches, seq_len] 47 | ''' 48 | text_seqs = create_wikitext_dataset(split) 49 | 50 | logging.info('Tokenizing dataset ...') 51 | tokenizer = SentencePieceProcessor(model_file='tokenizer.model') 52 | blk_size = seq_len + 1 53 | token_seqs = [] 54 | 55 | for text_seq in tqdm(text_seqs): 56 | token_seq = tokenizer.encode(text_seq, add_bos=True, add_eos=True) 57 | token_seq = mx.array(token_seq, dtype=mx.int16) 58 | token_seq = mx.pad(token_seq, [0, blk_size - token_seq.size % blk_size], pad_token_id) 59 | token_seqs.append(token_seq.reshape(-1, blk_size)) 60 | 61 | token_seqs = mx.concatenate(token_seqs, axis=0) 62 | logging.info(f'Tokenized {token_seqs.shape[0]} batches = {token_seqs.size} tokens.') 63 | 64 | return token_seqs 65 | 66 | 67 | def config_dataloader(bsz, n_steps, n_epochs, split, seq_len, pad_token_id, **kwargs): 68 | token_seqs = prepare_dataset(split, seq_len, pad_token_id) 69 | n_seqs = (bsz * n_steps) // n_epochs 70 | logging.info(f'Training on {n_seqs} sequences for {n_epochs} epochs.') 71 | 72 | def load_data_batch(): 73 | step_idx = 0 74 | while True: 75 | for idx in range(0, n_seqs, bsz): 76 | bblk = token_seqs[idx:idx+bsz, :] 77 | yield bblk[:, :-1], bblk[:, 1:] 78 | step_idx += 1 79 | if step_idx == n_steps: 80 | return 81 | 82 | return load_data_batch 83 | 84 | 85 | if __name__ == '__main__': 86 | load_train_data = config_dataloader(256, 10, split='train', seq_len=32, pad_token_id=-1) 87 | for inputs, labels in load_train_data(): 88 | print(inputs.shape, labels.shape) 89 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass, asdict 3 | from functools import partial 4 | 5 | import mlx.core as mx 6 | from mlx import nn 7 | from mlx.utils import tree_unflatten 8 | from sentencepiece import SentencePieceProcessor 9 | from tqdm import tqdm 10 | 11 | from dataset import config_dataloader 12 | from llama import LLaMAConfig, LLaMA 13 | 14 | 15 | @dataclass 16 | class GenerationConfig: 17 | max_new_tokens: int = 256 18 | 19 | 20 | def generate(prompt): 21 | tokenizer = SentencePieceProcessor(model_file='tokenizer.model') 22 | 23 | cfg_m = LLaMAConfig(n_layers=6, d_embd=512, n_heads=8) 24 | model = LLaMA(**asdict(cfg_m)) 25 | model.update(tree_unflatten([*mx.load(sys.argv[1]).items()])) 26 | model.eval() 27 | 28 | cfg_g = GenerationConfig() 29 | tokens_BT = mx.array([tokenizer.encode(prompt, add_bos=True)], dtype=mx.uint16) 30 | new_tokens = 0 31 | new_token_BT = None 32 | 33 | while new_tokens < cfg_g.max_new_tokens and new_token_BT != tokenizer.eos_id: 34 | logits_BTC = model(tokens_BT) 35 | new_token_BT = mx.argmax(logits_BTC[:, -1, :], axis=-1, keepdims=True) # Greedy sampling 36 | tokens_BT = mx.concatenate([tokens_BT, new_token_BT], axis=-1)[:, :cfg_m.seq_len] 37 | new_tokens += 1 38 | 39 | completion = tokenizer.decode(tokens_BT.tolist())[0] 40 | print(completion) 41 | 42 | 43 | if __name__ == '__main__': 44 | generate(sys.argv[2]) 45 | -------------------------------------------------------------------------------- /imgs/init-vs-no-init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kimbochen/mini-llama-mlx/3c17dbdf721133ecb23e1f9bb9d0880241367282/imgs/init-vs-no-init.png -------------------------------------------------------------------------------- /llama.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | 3 | import mlx.core as mx 4 | from mlx import nn 5 | from mlx.nn import MultiHeadAttention as MHA 6 | from mlx.core.fast import scaled_dot_product_attention 7 | from mlx.utils import tree_flatten, tree_unflatten, tree_map 8 | 9 | 10 | 11 | @dataclass 12 | class LLaMAConfig: 13 | ''' 14 | Model configuration reference: 15 | https://github.com/epfml/llm-baselines?tab=readme-ov-file#results-on-wikitext 16 | ''' 17 | n_layers: int = 24 18 | vocab_size: int = 32000 # V , LLaMA 2 tokenizer 19 | d_embd: int = 768 # D 20 | n_heads: int = 12 # N 21 | seq_len: int = 512 # T 22 | rope_theta: float = 1e4 23 | rope_scale: float = 1.0 24 | ffn_mult: int = 256 25 | norm_eps: float = 1e-5 26 | 27 | 28 | class SelfAttention(nn.Module): 29 | def __init__(self, d_embd, n_heads, rope_theta, rope_scale, **kwargs): 30 | super().__init__() 31 | assert d_embd % n_heads == 0 32 | self.d_head = d_embd // n_heads # H 33 | 34 | self.attn_proj = nn.Linear(d_embd, 3*d_embd, bias=False) 35 | self.rope = nn.RoPE(self.d_head, base=rope_theta, scale=rope_scale) 36 | self.scale = self.d_head ** -0.5 37 | self.out_proj = nn.Linear(d_embd, d_embd, bias=False) 38 | 39 | def __call__(self, x_BTD, mask_TT): 40 | B, T, D = x_BTD.shape 41 | to_attn_heads = lambda z: z.reshape(B, T, -1, self.d_head).transpose(0, 2, 1, 3) 42 | 43 | qkv_BTD = self.attn_proj(x_BTD).split(3, axis=-1) 44 | Q_BNTH, K_BNTH, V_BNTH = map(to_attn_heads, qkv_BTD) 45 | Q_BNTH, K_BNTH = self.rope(Q_BNTH), self.rope(K_BNTH) 46 | O_BNTH = scaled_dot_product_attention(Q_BNTH, K_BNTH, V_BNTH, scale=self.scale, mask=mask_TT) 47 | out_BTD = self.out_proj(O_BNTH.transpose(0, 2, 1, 3).reshape(B, T, D)) 48 | 49 | return out_BTD 50 | 51 | 52 | class FeedForwardNet(nn.Module): 53 | def __init__(self, d_embd, ffn_mult, **kwargs): 54 | super().__init__() 55 | hidden_dim = int((4 * d_embd) * 2 / 3) # C 56 | self.hidden_dim = ffn_mult * ((hidden_dim + ffn_mult - 1) // ffn_mult) # The next multiple of ffn_mult 57 | 58 | self.gate_proj = nn.Linear(d_embd, self.hidden_dim, bias=False) 59 | self.up_proj = nn.Linear(d_embd, self.hidden_dim, bias=False) 60 | self.down_proj = nn.Linear(self.hidden_dim, d_embd, bias=False) 61 | 62 | def __call__(self, x_BTD): 63 | h_BTC = nn.silu(self.gate_proj(x_BTD)) * self.up_proj(x_BTD) # SwiGLU 64 | out_BTD = self.down_proj(h_BTC) 65 | return out_BTD 66 | 67 | 68 | class TransformerBlock(nn.Module): 69 | def __init__(self, d_embd, norm_eps, **kwargs): 70 | super().__init__() 71 | self.pre_norm = nn.RMSNorm(d_embd, norm_eps) 72 | self.attn = SelfAttention(d_embd=d_embd, **kwargs) 73 | self.ffn_norm = nn.RMSNorm(d_embd, norm_eps) 74 | self.ffn = FeedForwardNet(d_embd=d_embd, **kwargs) 75 | 76 | def __call__(self, x_BTD, mask_TT): 77 | h_BTD = x_BTD + self.attn(self.pre_norm(x_BTD), mask_TT) 78 | out_BTD = h_BTD + self.ffn(self.ffn_norm(h_BTD)) 79 | return out_BTD 80 | 81 | 82 | class LLaMA(nn.Module): 83 | def __init__(self, vocab_size, n_layers, d_embd, norm_eps, **kwargs): 84 | super().__init__() 85 | self.embd_toks = nn.Embedding(vocab_size, d_embd) 86 | self.layers = [ 87 | TransformerBlock(d_embd=d_embd, norm_eps=norm_eps,**kwargs) 88 | for _ in range(n_layers) 89 | ] 90 | self.out_norm = nn.RMSNorm(d_embd, norm_eps) 91 | self.lm_head = nn.Linear(d_embd, vocab_size, bias=False) 92 | 93 | def __call__(self, tok_idxs_BT): 94 | h_BTD = self.embd_toks(tok_idxs_BT) 95 | 96 | causal_mask_TT = 
MHA.create_additive_causal_mask(h_BTD.shape[1]) 97 | for layer in self.layers: 98 | h_BTD = layer(h_BTD, causal_mask_TT) 99 | h_BTD = self.out_norm(h_BTD) 100 | 101 | logits_BTV = self.lm_head(h_BTD) 102 | 103 | return logits_BTV 104 | 105 | 106 | def init_params(model): 107 | mx.random.seed(3985) 108 | names, weights = zip(*tree_flatten(model)) 109 | 110 | def init_weight(name, weight): 111 | if 'proj' in name or 'lm_head' in name: 112 | n_in = weight.shape[1] 113 | return mx.random.normal(weight.shape, loc=0.0, scale=(n_in**-0.5)) 114 | elif 'embd_toks' in name: 115 | return mx.random.normal(weight.shape, loc=0.0, scale=1.0) 116 | elif 'norm' in name: 117 | return mx.full(weight.shape, 1.0) 118 | else: 119 | raise ValueError(f'Weight {name} not initialized.') 120 | 121 | inited_weights = tree_map(init_weight, names, weights) 122 | model.update(tree_unflatten([*zip(names, inited_weights)])) 123 | 124 | return model 125 | 126 | 127 | if __name__ == '__main__': 128 | cfg_m = LLaMAConfig() 129 | model = LLaMA(**asdict(cfg_m)) 130 | 131 | mx.random.seed(3985) 132 | tok_idxs = mx.random.randint(0, cfg_m.vocab_size, shape=[2, cfg_m.seq_len]) 133 | 134 | print(model(tok_idxs).shape) 135 | -------------------------------------------------------------------------------- /notebooks/dataset-dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 23, 6 | "id": "484284e2-d53e-49d0-ae79-f851fcde4aa5", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "27571" 13 | ] 14 | }, 15 | "execution_count": 23, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "from pathlib import Path\n", 22 | "import mlx.core as mx\n", 23 | "\n", 24 | "train_data_dir = Path('./wikitext_data/train')\n", 25 | "train_examples = []\n", 26 | "\n", 27 | "for ex_path in sorted(train_data_dir.glob('*.npz')):\n", 28 | " train_examples.extend(mx.load(str(ex_path)).values())\n", 29 | "\n", 30 | "len(train_examples)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 37, 36 | "id": "57d669c6-a43c-4f00-a7d9-6421e82ca415", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "seq_len = 512\n", 41 | "blk_size = seq_len + 1\n", 42 | "train_examples_pad = []\n", 43 | "\n", 44 | "for example in train_examples:\n", 45 | " example_pad = mx.pad(example, pad_width=[0, blk_size-example.size%blk_size], constant_values=0)\n", 46 | " train_examples_pad.append(example_pad)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "e8db1ce2-ad6f-4190-b707-3e8d6d1bd90e", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "from pathlib import Path\n", 57 | "import mlx.core as mx\n", 58 | "\n", 59 | "class WikiTextDataLoader:\n", 60 | " def __init__(self, bsz, seq_len, pad_token_id):\n", 61 | " train_data_dir = Path('./wikitext_data/train')\n", 62 | " train_examples = []\n", 63 | " for ex_path in sorted(train_data_dir.glob('*.npz')):\n", 64 | " train_examples.extend(mx.load(str(ex_path)).values())\n", 65 | " \n", 66 | " blk_size = seq_len + 1\n", 67 | " pad_example = lambda ex: mx.pad(ex, [0, blk_size - ex.size % blk_size], pad_token_id)\n", 68 | " train_examples = [*map(pad_example, train_examples)]\n", 69 | " self.train_examples = mx.concatenate(train_examples, axis=0)\n", 70 | "\n", 71 | " self.bsz = bsz\n", 72 | " self.blk_size = blk_size\n", 73 | " self.bblk_size = bsz * blk_size # Batch block 
size\n", 74 | " self.total_batches = len(self.train_examples) - blk_size + 1\n", 75 | "\n", 76 | " def __len__(self):\n", 77 | " return self.total_batches\n", 78 | "\n", 79 | " def __getitem__(self, idx):\n", 80 | " batch_block = self.train_examples[i:i+self.bblk_size]\n", 81 | " batch_block = batch_block.reshape([self.bsz, self.blk_size])\n", 82 | " return batch_block[:, :-1], batch_block[:, 1:]\n", 83 | "\n", 84 | " def __iter__(self):\n", 85 | " for i in range(self.total_batches):\n", 86 | " yield self[i]\n", 87 | "\n", 88 | "dataloader = WikiTextDataLoader(4, 512, 0)\n", 89 | "xb, yb = next(dataloader)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 1, 95 | "id": "a268fccf-14f6-486a-8d30-c728953fad0e", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "from pathlib import Path\n", 100 | "import mlx.core as mx\n", 101 | "\n", 102 | "def config_dataloader(bsz, seq_len, pad_token_id):\n", 103 | " train_data_dir = Path('./wikitext_data/train')\n", 104 | " train_examples = []\n", 105 | " for ex_path in sorted(train_data_dir.glob('*.npz')):\n", 106 | " train_examples.extend(mx.load(str(ex_path)).values())\n", 107 | "\n", 108 | " blk_size = seq_len + 1\n", 109 | " pad_example = lambda ex: mx.pad(ex, [0, blk_size - ex.size % blk_size], pad_token_id)\n", 110 | " train_examples = [*map(pad_example, train_examples)]\n", 111 | " train_examples = mx.concatenate(train_examples, axis=0)\n", 112 | "\n", 113 | " bblk_size = bsz * blk_size # Batch block size\n", 114 | " n_batches = len(train_examples) - blk_size + 1\n", 115 | "\n", 116 | " def load_data_():\n", 117 | " for i in range(n_batches):\n", 118 | " bblk = train_examples[i:i+bblk_size].reshape([bsz, blk_size])\n", 119 | " yield bblk[:, :-1], bblk[:, 1:]\n", 120 | "\n", 121 | " return load_data_\n", 122 | "\n", 123 | "load_data = config_dataloader(4, 512, 0)\n", 124 | "data_iter = iter(load_data())\n", 125 | "xb, yb = next(data_iter)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 2, 131 | "id": "dc2e878f-ee8b-4104-a9c4-bb07e653350c", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "from sentencepiece import SentencePieceProcessor\n", 136 | "\n", 137 | "sp_model = SentencePieceProcessor(model_file='tokenizer.model')" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 3, 143 | "id": "65cc4a3c-4f23-43b4-ae0d-fde2a4c714a2", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "\"\\n Mathews was decorated by several governments , receiving appointments as a Companion of the Order of St Michael and St George , Companion of the Order of the Bath and as a Knight Commander of the Order of St Michael and St George from the British government and membership in the Prussian Order of the Crown . Zanzibar also rewarded him and he was a member of the Grand Order of Hamondieh and a first class member of the Order of the Brilliant Star of Zanzibar . Mathews died of malaria in Zanzibar on 11 October 1901 . \\n = = Early life and career = = \\n Mathews was born at Funchal on Madeira on 7 March 1850 . His father , Captain William Matthews was Welsh , and his mother Jane Wallis Penfold , was the daughter of William Penfold and Sarah Gilbert . Her sister , Augusta Jane Robley née Penfold was the author of a famous book about the flora and fauna of Madeira , which is now in the Natural History Museum . Mathews became a cadet of the Royal Navy in 1863 and was appointed a midshipman on 23 September 1866 . 
From 1868 he was stationed in the Mediterranean but his first active service was during the Third Anglo @-@ Ashanti War of 1873 – 4 where he qualified for the campaign medal . He was promoted to lieutenant on 31 March 1874 . On 27 August 1875 Mathews was posted to HMS London , a depot ship and the Royal Navy headquarters for East Africa , to assist in the suppression of the slave trade in the area . Whilst onboard he drilled his own troops , captured several slave dhows and was commended for his actions by the Admiralty . \\n = = Commander in Chief of Zanzibar = = \\n In August 1877 , Mathews was seconded from the Navy to Sultan Barghash of Zanzibar to form a European @-@ style army which could be used to enforce Zanzibar 's control over its mainland possessions . The army had traditionally been composed entirely of Arabs and Persians but Mathews opened up recruitment to\"" 150 | ] 151 | }, 152 | "execution_count": 3, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | } 156 | ], 157 | "source": [ 158 | "sp_model.decode(xb[1, :].tolist())" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 4, 164 | "id": "fa0b7e45-e014-41e2-96c9-a5d429c51e74", 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "data": { 169 | "text/plain": [ 170 | "'the'" 171 | ] 172 | }, 173 | "execution_count": 4, 174 | "metadata": {}, 175 | "output_type": "execute_result" 176 | } 177 | ], 178 | "source": [ 179 | "sp_model.decode(yb[1].tolist())" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 5, 185 | "id": "097645a2-00e3-4896-be66-9d8a1069d5f0", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "xb, yb = next(data_iter)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 6, 195 | "id": "fbe74dbf-a972-49f8-a48c-96bedf096494", 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "data": { 200 | "text/plain": [ 201 | "\"Mathews was decorated by several governments , receiving appointments as a Companion of the Order of St Michael and St George , Companion of the Order of the Bath and as a Knight Commander of the Order of St Michael and St George from the British government and membership in the Prussian Order of the Crown . Zanzibar also rewarded him and he was a member of the Grand Order of Hamondieh and a first class member of the Order of the Brilliant Star of Zanzibar . Mathews died of malaria in Zanzibar on 11 October 1901 . \\n = = Early life and career = = \\n Mathews was born at Funchal on Madeira on 7 March 1850 . His father , Captain William Matthews was Welsh , and his mother Jane Wallis Penfold , was the daughter of William Penfold and Sarah Gilbert . Her sister , Augusta Jane Robley née Penfold was the author of a famous book about the flora and fauna of Madeira , which is now in the Natural History Museum . Mathews became a cadet of the Royal Navy in 1863 and was appointed a midshipman on 23 September 1866 . From 1868 he was stationed in the Mediterranean but his first active service was during the Third Anglo @-@ Ashanti War of 1873 – 4 where he qualified for the campaign medal . He was promoted to lieutenant on 31 March 1874 . On 27 August 1875 Mathews was posted to HMS London , a depot ship and the Royal Navy headquarters for East Africa , to assist in the suppression of the slave trade in the area . Whilst onboard he drilled his own troops , captured several slave dhows and was commended for his actions by the Admiralty . 
\\n = = Commander in Chief of Zanzibar = = \\n In August 1877 , Mathews was seconded from the Navy to Sultan Barghash of Zanzibar to form a European @-@ style army which could be used to enforce Zanzibar 's control over its mainland possessions . The army had traditionally been composed entirely of Arabs and Persians but Mathews opened up recruitment to the\"" 202 | ] 203 | }, 204 | "execution_count": 6, 205 | "metadata": {}, 206 | "output_type": "execute_result" 207 | } 208 | ], 209 | "source": [ 210 | "sp_model.decode(xb[1, :].tolist())" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 7, 216 | "id": "6130ee57-0c3b-4627-8929-8ddbb12c6c36", 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "'African'" 223 | ] 224 | }, 225 | "execution_count": 7, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "sp_model.decode(yb[1].tolist())" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 8, 237 | "id": "78aecf30-0b01-4791-b9bf-21ea795363e9", 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "name": "stdout", 242 | "output_type": "stream", 243 | "text": [ 244 | "array([[314, 869, 910, ..., 408, 263, 937],\n", 245 | " [4509, 310, 278, ..., 17517, 869, 5345],\n", 246 | " [5652, 525, 9213, ..., 4096, 747, 279],\n", 247 | " [278, 4908, 5874, ..., 13, 2, 0]], dtype=uint16)\n", 248 | "am . This dhow had around 100 slaves on board and was transporting them between Pemba and Zanzibar . Captain Brownrigg led a boarding party to release the slaves but bin Hattam 's men then attacked the sailors , killing Brownrigg and his party before sailing away . Mathews led a force to Wete on Pemba and , after a short battle , took a mortally wounded bin Hattem prisoner before returning to Zanzibar . \n", 249 | " Mathews returned to the African mainland territories once more in 1884 when he landed with a force which intended to establish further garrisons there to dissuade German territorial claims . This attempt ultimately failed when five German warships steamed into Zanzibar Town harbour and threatened the Sultan into signing away the territories which would later form German East Africa . Further territories were ceded to the German East Africa Company in 1888 but unrest amongst the locals against them prevented them from taking control and Mathews was dispatched with 100 men to restore order . Finding around 8 @,@ 000 people gathered against the German administrators Mathews was forced to return with his men to Zanzibar . He landed once again with more troops but found himself subject to death threats and that his troops would not obey his orders and so returned again to Zanzibar . \n", 250 | " = = First Minister = = \n", 251 | " In October 1891 , upon the formation of the first constitutional government in Zanzibar , Mathews was appointed First Minister , despite some hostility from Sultan Ali bin Said . In this capacity Mathews was \" irremovable by the sultan \" and answerable only to the Sultan and the British Consul . His position was so strong that one missionary on the island is quoted as saying that his powers defied \" analytical examination \" and that Mathews really could say \" L 'état est moi \" ( I am the state ) . Mathews was also known as the \" Strong man of Zanzibar \" . The principal departments of government were mostly run by Britons or British Indians and Mathews ' approval was required before they could be removed from office . 
Mathews was rewarded by the Zanzibar government for his role with his appointment as a first\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "for xb, _ in load_data():\n", 257 | " if mx.any(xb == 0):\n", 258 | " print(xb)\n", 259 | " print(sp_model.decode(xb[0, :].tolist()))\n", 260 | " break" 261 | ] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "Python 3 (ipykernel)", 267 | "language": "python", 268 | "name": "python3" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.11.6" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 5 285 | } 286 | -------------------------------------------------------------------------------- /notebooks/llama-dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "30c09be3-50a2-4fbe-ac93-fb1f5bcc5526", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from dataclasses import dataclass, asdict\n", 11 | "\n", 12 | "@dataclass\n", 13 | "class ModelConfig:\n", 14 | " '''\n", 15 | " Model configuration reference:\n", 16 | " https://github.com/epfml/llm-baselines?tab=readme-ov-file#results-on-wikitext\n", 17 | " '''\n", 18 | " n_layers: int = 24\n", 19 | " vocab_size: int = 32000 # LLaMA 2 tokenizer\n", 20 | " d_embd: int = 768\n", 21 | " n_heads: int = 12\n", 22 | " seq_len: int = 512\n", 23 | " rope_theta: float = 1e4\n", 24 | " rope_scale: float = 1.0\n", 25 | " ffn_mult: int = 256\n", 26 | " norm_eps: float = 1e-5\n", 27 | "\n", 28 | "cfg_m = ModelConfig()" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "9ca9d625-839c-49c9-96c8-52ff7fa9430d", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "def x_test(B=2):\n", 39 | " mx.random.seed(3985)\n", 40 | " x_ = mx.random.uniform(shape=[B, cfg_m.seq_len, cfg_m.d_embd])\n", 41 | " return x_" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "id": "6ea0b824-41b7-4e2a-bba5-d3df45d0556b", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import mlx.core as mx\n", 52 | "from mlx import nn" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 32, 58 | "id": "1d71dd9e-203f-429e-a9df-7be6dcbdcaf6", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "((2, 512, 768), (512, 768))" 65 | ] 66 | }, 67 | "execution_count": 32, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "from mlx.core.fast import scaled_dot_product_attention\n", 74 | "\n", 75 | "class SelfAttention(nn.Module):\n", 76 | " def __init__(self, d_embd, n_heads, rope_theta, rope_scale, **kwargs):\n", 77 | " super().__init__()\n", 78 | " assert d_embd % n_heads == 0\n", 79 | " self.d_head = d_embd // n_heads\n", 80 | "\n", 81 | " self.attn_proj = nn.Linear(d_embd, 3*d_embd, bias=False)\n", 82 | " self.rope = nn.RoPE(self.d_head, base=rope_theta, scale=rope_scale)\n", 83 | " self.scale = self.d_head ** -0.5\n", 84 | " self.out_proj = nn.Linear(d_embd, d_embd, bias=False)\n", 85 | "\n", 86 | " def __call__(self, x, mask):\n", 87 | " bsz, seq_len, d_embd = x.shape\n", 88 | "\n", 89 | " # [bsz, seq_len, d_embd] * 3\n", 90 | " qkv = 
self.attn_proj(x).split(3, axis=-1)\n", 91 | "\n", 92 | " # bsz, n_heads, seq_len, d_head\n", 93 | " to_attn_heads = lambda z: z.reshape(bsz, seq_len, -1, self.d_head).transpose(0, 2, 1, 3)\n", 94 | " Q, K, V = map(to_attn_heads, qkv)\n", 95 | "\n", 96 | " # Apply rotary embeddings\n", 97 | " Q, K = self.rope(Q), self.rope(K)\n", 98 | "\n", 99 | " # bsz, n_head, seq_len, d_head\n", 100 | " O = scaled_dot_product_attention(Q, K, V, scale=self.scale, mask=mask)\n", 101 | "\n", 102 | " # bsz, seq_len, d_embd\n", 103 | " output = self.out_proj(O.transpose(0, 2, 1, 3).reshape(bsz, seq_len, d_embd))\n", 104 | "\n", 105 | " return output\n", 106 | "\n", 107 | "attn = SelfAttention(**asdict(cfg_m))\n", 108 | "attn(x_test(), mask=nn.MultiHeadAttention.create_additive_causal_mask(cfg_m.seq_len)).shape, (cfg_m.seq_len, cfg_m.d_embd)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 55, 114 | "id": "453e4e0d-d163-4884-86be-d83e044e48a9", 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "(2, 512, 768)" 121 | ] 122 | }, 123 | "execution_count": 55, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "class FeedForwardNet(nn.Module):\n", 130 | " def __init__(self, d_embd, ffn_mult, **kwargs):\n", 131 | " super().__init__()\n", 132 | " hidden_dim = int((4 * d_embd) * 2 / 3)\n", 133 | " hidden_dim = ffn_mult * ((hidden_dim + ffn_mult - 1) // ffn_mult) # The next multiple of ffn_mult\n", 134 | "\n", 135 | " self.gate_proj = nn.Linear(d_embd, hidden_dim, bias=False)\n", 136 | " self.up_proj = nn.Linear(d_embd, hidden_dim, bias=False)\n", 137 | " self.down_proj = nn.Linear(hidden_dim, d_embd, bias=False)\n", 138 | "\n", 139 | " def __call__(self, x):\n", 140 | " h = nn.silu(self.gate_proj(x)) * self.up_proj(x) # SwiGLU\n", 141 | " out = self.down_proj(h)\n", 142 | " return out\n", 143 | "\n", 144 | "ffn = FeedForwardNet(**asdict(cfg_m))\n", 145 | "ffn(x_test()).shape" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 56, 151 | "id": "f2a54139-e866-45b4-a7b3-4023b52d3d4a", 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "(2, 512, 768)" 158 | ] 159 | }, 160 | "execution_count": 56, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "class TransformerBlock(nn.Module):\n", 167 | " def __init__(self, d_embd, norm_eps, **kwargs):\n", 168 | " super().__init__()\n", 169 | " self.pre_norm = nn.RMSNorm(d_embd, norm_eps)\n", 170 | " self.attn = SelfAttention(d_embd=d_embd, **kwargs)\n", 171 | " self.ffn_norm = nn.RMSNorm(d_embd, norm_eps)\n", 172 | " self.ffn = FeedForwardNet(d_embd=d_embd, **kwargs)\n", 173 | "\n", 174 | " def __call__(self, x, mask):\n", 175 | " h = x + self.attn(self.pre_norm(x), mask)\n", 176 | " out = h + self.ffn(self.ffn_norm(h))\n", 177 | " return out\n", 178 | "\n", 179 | "layer = TransformerBlock(**asdict(cfg_m))\n", 180 | "layer(x_test(), mask=nn.MultiHeadAttention.create_additive_causal_mask(cfg_m.seq_len)).shape" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 57, 186 | "id": "8d7fe044-e056-4ff4-9601-f959ebd340c6", 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "text/plain": [ 192 | "(2, 512, 32000)" 193 | ] 194 | }, 195 | "execution_count": 57, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "class LLaMA(nn.Module):\n", 202 | " def __init__(self, 
vocab_size, n_layers, d_embd, norm_eps, **kwargs):\n", 203 | " super().__init__()\n", 204 | " self.embd_toks = nn.Embedding(vocab_size, d_embd)\n", 205 | " self.layers = [\n", 206 | " TransformerBlock(d_embd=d_embd, norm_eps=norm_eps,**kwargs)\n", 207 | " for _ in range(n_layers)\n", 208 | " ]\n", 209 | " self.out_norm = nn.RMSNorm(d_embd, norm_eps)\n", 210 | " self.lm_head = nn.Linear(d_embd, vocab_size, bias=False)\n", 211 | "\n", 212 | " def __call__(self, tok_idxs):\n", 213 | " # bsz, seq_len, d_embd\n", 214 | " h = self.embd_toks(tok_idxs)\n", 215 | "\n", 216 | " # bsz, seq_len, d_embd\n", 217 | " causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])\n", 218 | " for layer in self.layers:\n", 219 | " h = layer(h, causal_mask)\n", 220 | " h = self.out_norm(h)\n", 221 | "\n", 222 | " # bsz, seq_len, vocab_size\n", 223 | " logits = self.lm_head(h)\n", 224 | "\n", 225 | " return logits\n", 226 | "\n", 227 | "model = LLaMA(**asdict(cfg_m))\n", 228 | "model(mx.random.randint(0, cfg.vocab_size, shape=[2, cfg.seq_len])).shape" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 50, 234 | "id": "d339ad76-b4e6-4241-9eac-fb747a148528", 235 | "metadata": { 236 | "scrolled": true 237 | }, 238 | "outputs": [ 239 | { 240 | "data": { 241 | "text/plain": [ 242 | "LLaMA(\n", 243 | " (embd_toks): Embedding(32000, 768)\n", 244 | " (layers.0): TransformerBlock(\n", 245 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 246 | " (attn): SelfAttention(\n", 247 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 248 | " (rope): RoPE(64, traditional=False)\n", 249 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 250 | " )\n", 251 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 252 | " (ffn): FeedForwardNet(\n", 253 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 254 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 255 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 256 | " )\n", 257 | " )\n", 258 | " (layers.1): TransformerBlock(\n", 259 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 260 | " (attn): SelfAttention(\n", 261 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 262 | " (rope): RoPE(64, traditional=False)\n", 263 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 264 | " )\n", 265 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 266 | " (ffn): FeedForwardNet(\n", 267 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 268 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 269 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 270 | " )\n", 271 | " )\n", 272 | " (layers.2): TransformerBlock(\n", 273 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 274 | " (attn): SelfAttention(\n", 275 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 276 | " (rope): RoPE(64, traditional=False)\n", 277 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 278 | " )\n", 279 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 280 | " (ffn): FeedForwardNet(\n", 281 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 282 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 283 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 284 | " )\n", 285 | " )\n", 286 | " (layers.3): TransformerBlock(\n", 287 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 288 | " (attn): SelfAttention(\n", 
289 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 290 | " (rope): RoPE(64, traditional=False)\n", 291 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 292 | " )\n", 293 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 294 | " (ffn): FeedForwardNet(\n", 295 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 296 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 297 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 298 | " )\n", 299 | " )\n", 300 | " (layers.4): TransformerBlock(\n", 301 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 302 | " (attn): SelfAttention(\n", 303 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 304 | " (rope): RoPE(64, traditional=False)\n", 305 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 306 | " )\n", 307 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 308 | " (ffn): FeedForwardNet(\n", 309 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 310 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 311 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 312 | " )\n", 313 | " )\n", 314 | " (layers.5): TransformerBlock(\n", 315 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 316 | " (attn): SelfAttention(\n", 317 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 318 | " (rope): RoPE(64, traditional=False)\n", 319 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 320 | " )\n", 321 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 322 | " (ffn): FeedForwardNet(\n", 323 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 324 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 325 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 326 | " )\n", 327 | " )\n", 328 | " (layers.6): TransformerBlock(\n", 329 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 330 | " (attn): SelfAttention(\n", 331 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 332 | " (rope): RoPE(64, traditional=False)\n", 333 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 334 | " )\n", 335 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 336 | " (ffn): FeedForwardNet(\n", 337 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 338 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 339 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 340 | " )\n", 341 | " )\n", 342 | " (layers.7): TransformerBlock(\n", 343 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 344 | " (attn): SelfAttention(\n", 345 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 346 | " (rope): RoPE(64, traditional=False)\n", 347 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 348 | " )\n", 349 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 350 | " (ffn): FeedForwardNet(\n", 351 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 352 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 353 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 354 | " )\n", 355 | " )\n", 356 | " (layers.8): TransformerBlock(\n", 357 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 358 | " (attn): SelfAttention(\n", 359 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 360 | " (rope): RoPE(64, traditional=False)\n", 361 | " 
(out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 362 | " )\n", 363 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 364 | " (ffn): FeedForwardNet(\n", 365 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 366 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 367 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 368 | " )\n", 369 | " )\n", 370 | " (layers.9): TransformerBlock(\n", 371 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 372 | " (attn): SelfAttention(\n", 373 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 374 | " (rope): RoPE(64, traditional=False)\n", 375 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 376 | " )\n", 377 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 378 | " (ffn): FeedForwardNet(\n", 379 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 380 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 381 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 382 | " )\n", 383 | " )\n", 384 | " (layers.10): TransformerBlock(\n", 385 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 386 | " (attn): SelfAttention(\n", 387 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 388 | " (rope): RoPE(64, traditional=False)\n", 389 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 390 | " )\n", 391 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 392 | " (ffn): FeedForwardNet(\n", 393 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 394 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 395 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 396 | " )\n", 397 | " )\n", 398 | " (layers.11): TransformerBlock(\n", 399 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 400 | " (attn): SelfAttention(\n", 401 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 402 | " (rope): RoPE(64, traditional=False)\n", 403 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 404 | " )\n", 405 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 406 | " (ffn): FeedForwardNet(\n", 407 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 408 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 409 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 410 | " )\n", 411 | " )\n", 412 | " (layers.12): TransformerBlock(\n", 413 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 414 | " (attn): SelfAttention(\n", 415 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 416 | " (rope): RoPE(64, traditional=False)\n", 417 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 418 | " )\n", 419 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 420 | " (ffn): FeedForwardNet(\n", 421 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 422 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 423 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 424 | " )\n", 425 | " )\n", 426 | " (layers.13): TransformerBlock(\n", 427 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 428 | " (attn): SelfAttention(\n", 429 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 430 | " (rope): RoPE(64, traditional=False)\n", 431 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 432 | " )\n", 433 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 
434 | " (ffn): FeedForwardNet(\n", 435 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 436 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 437 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 438 | " )\n", 439 | " )\n", 440 | " (layers.14): TransformerBlock(\n", 441 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 442 | " (attn): SelfAttention(\n", 443 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 444 | " (rope): RoPE(64, traditional=False)\n", 445 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 446 | " )\n", 447 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 448 | " (ffn): FeedForwardNet(\n", 449 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 450 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 451 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 452 | " )\n", 453 | " )\n", 454 | " (layers.15): TransformerBlock(\n", 455 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 456 | " (attn): SelfAttention(\n", 457 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 458 | " (rope): RoPE(64, traditional=False)\n", 459 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 460 | " )\n", 461 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 462 | " (ffn): FeedForwardNet(\n", 463 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 464 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 465 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 466 | " )\n", 467 | " )\n", 468 | " (layers.16): TransformerBlock(\n", 469 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 470 | " (attn): SelfAttention(\n", 471 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 472 | " (rope): RoPE(64, traditional=False)\n", 473 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 474 | " )\n", 475 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 476 | " (ffn): FeedForwardNet(\n", 477 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 478 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 479 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 480 | " )\n", 481 | " )\n", 482 | " (layers.17): TransformerBlock(\n", 483 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 484 | " (attn): SelfAttention(\n", 485 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 486 | " (rope): RoPE(64, traditional=False)\n", 487 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 488 | " )\n", 489 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 490 | " (ffn): FeedForwardNet(\n", 491 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 492 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 493 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 494 | " )\n", 495 | " )\n", 496 | " (layers.18): TransformerBlock(\n", 497 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 498 | " (attn): SelfAttention(\n", 499 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 500 | " (rope): RoPE(64, traditional=False)\n", 501 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 502 | " )\n", 503 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 504 | " (ffn): FeedForwardNet(\n", 505 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 506 | " 
(up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 507 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 508 | " )\n", 509 | " )\n", 510 | " (layers.19): TransformerBlock(\n", 511 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 512 | " (attn): SelfAttention(\n", 513 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 514 | " (rope): RoPE(64, traditional=False)\n", 515 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 516 | " )\n", 517 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 518 | " (ffn): FeedForwardNet(\n", 519 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 520 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 521 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 522 | " )\n", 523 | " )\n", 524 | " (layers.20): TransformerBlock(\n", 525 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 526 | " (attn): SelfAttention(\n", 527 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 528 | " (rope): RoPE(64, traditional=False)\n", 529 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 530 | " )\n", 531 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 532 | " (ffn): FeedForwardNet(\n", 533 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 534 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 535 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 536 | " )\n", 537 | " )\n", 538 | " (layers.21): TransformerBlock(\n", 539 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 540 | " (attn): SelfAttention(\n", 541 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 542 | " (rope): RoPE(64, traditional=False)\n", 543 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 544 | " )\n", 545 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 546 | " (ffn): FeedForwardNet(\n", 547 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 548 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 549 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 550 | " )\n", 551 | " )\n", 552 | " (layers.22): TransformerBlock(\n", 553 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 554 | " (attn): SelfAttention(\n", 555 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 556 | " (rope): RoPE(64, traditional=False)\n", 557 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 558 | " )\n", 559 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 560 | " (ffn): FeedForwardNet(\n", 561 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 562 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 563 | " (down_proj): Linear(input_dims=2048, output_dims=768, bias=False)\n", 564 | " )\n", 565 | " )\n", 566 | " (layers.23): TransformerBlock(\n", 567 | " (pre_norm): RMSNorm(768, eps=1e-05)\n", 568 | " (attn): SelfAttention(\n", 569 | " (attn_proj): Linear(input_dims=768, output_dims=2304, bias=False)\n", 570 | " (rope): RoPE(64, traditional=False)\n", 571 | " (out_proj): Linear(input_dims=768, output_dims=768, bias=False)\n", 572 | " )\n", 573 | " (ffn_norm): RMSNorm(768, eps=1e-05)\n", 574 | " (ffn): FeedForwardNet(\n", 575 | " (gate_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 576 | " (up_proj): Linear(input_dims=768, output_dims=2048, bias=False)\n", 577 | " (down_proj): Linear(input_dims=2048, output_dims=768, 
bias=False)\n", 578 | " )\n", 579 | " )\n", 580 | " (out_norm): RMSNorm(768, eps=1e-05)\n", 581 | " (lm_head): Linear(input_dims=768, output_dims=32000, bias=False)\n", 582 | ")" 583 | ] 584 | }, 585 | "execution_count": 50, 586 | "metadata": {}, 587 | "output_type": "execute_result" 588 | } 589 | ], 590 | "source": [ 591 | "model" 592 | ] 593 | } 594 | ], 595 | "metadata": { 596 | "kernelspec": { 597 | "display_name": "Python 3 (ipykernel)", 598 | "language": "python", 599 | "name": "python3" 600 | }, 601 | "language_info": { 602 | "codemirror_mode": { 603 | "name": "ipython", 604 | "version": 3 605 | }, 606 | "file_extension": ".py", 607 | "mimetype": "text/x-python", 608 | "name": "python", 609 | "nbconvert_exporter": "python", 610 | "pygments_lexer": "ipython3", 611 | "version": "3.11.6" 612 | } 613 | }, 614 | "nbformat": 4, 615 | "nbformat_minor": 5 616 | } 617 | -------------------------------------------------------------------------------- /notebooks/mixture-of-depths-dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "id": "0aa8c050-28eb-48f3-81c1-91bd8803e000", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import mlx.core as mx\n", 11 | "from mlx import nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "935906c2-d490-4286-bdd7-45aa378bb4ed", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "mx.random.seed(3985)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 11, 27 | "id": "cdc7fbf4-a1d5-48f2-875f-477a1267045c", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "array([[[0.324384, 0.0728916, 0.692562, ..., 0.473036, 0.732917, 0.613625],\n", 34 | " [0.937988, 0.821932, 0.353415, ..., 0.699138, 0.0826687, 0.684287],\n", 35 | " [0.963293, 0.217973, 0.239089, ..., 0.435873, 0.525348, 0.027538],\n", 36 | " [0.332283, 0.22946, 0.521798, ..., 0.418693, 0.172103, 0.0305814]],\n", 37 | " [[0.380298, 0.500459, 0.931418, ..., 0.795432, 0.0933626, 0.835255],\n", 38 | " [0.328979, 0.496543, 0.868076, ..., 0.533726, 0.293846, 0.770115],\n", 39 | " [0.442788, 0.287145, 0.0224269, ..., 0.171359, 0.840324, 0.175392],\n", 40 | " [0.408249, 0.0423786, 0.482122, ..., 0.301558, 0.276174, 0.602193]]], dtype=float32)" 41 | ] 42 | }, 43 | "execution_count": 11, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "B, T, C = 2, 4, 8 # batch_size, seq_len, d_embd\n", 50 | "x = mx.random.uniform(shape=[B, T, C])\n", 51 | "x" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 63, 57 | "id": "9579ebcb-2d05-4e9e-87f9-e34aa272ad78", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "(array([[[0.557108, 0.0590587, 0.603644, ..., 0.288003, 0.565407, 0.752665],\n", 64 | " [-0.286823, 0.1454, -0.0832467, ..., 0.714576, 0.233552, 0.925167],\n", 65 | " [1.15273, 1.4296, -0.337541, ..., 0.774492, 1.09024, -0.284072],\n", 66 | " [0.694623, 1.15832, 0.722264, ..., 0.151736, 0.672384, -0.699643]],\n", 67 | " [[0.947543, 2.18298, -0.407171, ..., 0.649991, 1.21402, -0.0334666],\n", 68 | " [0.0934433, -0.0856612, 0.686615, ..., 0.866722, 0.30294, 0.53751],\n", 69 | " [0.693493, 0.788395, -0.268603, ..., 0.715195, 1.10677, 0.194276],\n", 70 | " [0.0828968, 0.234539, 0.210647, ..., 0.543489, 0.562059, 0.0940933]]], dtype=float32),\n", 71 | " array(0.705366, dtype=float32))" 72 
| ] 73 | }, 74 | "execution_count": 63, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "class MixtureOfDepths(nn.Module):\n", 81 | " def __init__(self, block, capacity_factor, seq_len, n_embd):\n", 82 | " super().__init__()\n", 83 | " self.block = block\n", 84 | " self.capacity_factor = capacity_factor\n", 85 | " self.capacity = int(capacity_factor * seq_len)\n", 86 | " self.router = nn.Linear(n_embd, 1)\n", 87 | "\n", 88 | " def __call__(self, x):\n", 89 | " B, T = x.shape[:2] # batch_size, seq_len\n", 90 | "\n", 91 | " # Top k expert choice\n", 92 | " r = self.router(x).squeeze(-1)\n", 93 | " capacity = min(self.capacity, self.capacity_factor*T)\n", 94 | " chosen_idx = mx.argpartition(-r, capacity, axis=1)[:, :capacity]\n", 95 | "\n", 96 | " # Sorted top k to preserve token causality\n", 97 | " # mx.sort does not support uint32?\n", 98 | " chosen_idx = mx.sort(chosen_idx.astype(mx.float32), axis=1).astype(mx.uint32)\n", 99 | "\n", 100 | " # Process chosen tokens\n", 101 | " batch_idx = mx.arange(B)[:, None]\n", 102 | " chosen_r = r[batch_idx, chosen_idx, None]\n", 103 | " chosen_x = x[batch_idx, chosen_idx, :]\n", 104 | " process_x = self.block(chosen_x)\n", 105 | " x[batch_idx, chosen_idx, :] += chosen_r * process_x\n", 106 | "\n", 107 | " # Auxiliary loss for training the router\n", 108 | " r_nll = -nn.log_softmax(chosen_r[..., 0], axis=-1).mean()\n", 109 | "\n", 110 | " return x, r_nll\n", 111 | "\n", 112 | "mod = MixtureOfDepths(nn.Linear(C, C), 0.5, T, C)\n", 113 | "mod(x)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 47, 119 | "id": "d831ee58-ee93-4382-bc26-05c75616591e", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "(2, 4, 128)" 126 | ] 127 | }, 128 | "execution_count": 47, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "from mlx.core.fast import scaled_dot_product_attention\n", 135 | "\n", 136 | "class CausalSelfAttention(nn.Module):\n", 137 | " def __init__(self, n_head, d_embd, p_drop, **kwargs):\n", 138 | " assert d_embd % n_head == 0\n", 139 | " super().__init__()\n", 140 | " self.n_head = n_head\n", 141 | " self.scale = (d_embd / n_head) ** 0.5\n", 142 | "\n", 143 | " self.attn_proj = nn.Linear(d_embd, 3*d_embd, bias=False)\n", 144 | " self.out_proj = nn.Linear(d_embd, d_embd, bias=False)\n", 145 | " self.resid_drop = nn.Dropout(p_drop)\n", 146 | "\n", 147 | " def __call__(self, x):\n", 148 | " B, T = x.shape[:2]\n", 149 | "\n", 150 | " qkv = self.attn_proj(x).split(3, axis=-1) # B, T, d_embd\n", 151 | " to_attn_weights = lambda z: z.reshape(B, T, self.n_head, -1).transpose(0, 2, 1, 3)\n", 152 | " Q, K, V = map(to_attn_weights, qkv) # B, n_head, T, d_head\n", 153 | "\n", 154 | " # MLX SDPA does not support dropout?\n", 155 | " causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(T)\n", 156 | " O = scaled_dot_product_attention(Q, K, V, scale=self.scale, mask=causal_mask) # B, n_head, T, d_head\n", 157 | " O = O.transpose(0, 2, 1, 3).reshape(B, T, -1) # B, T, d_embd\n", 158 | "\n", 159 | " output = self.resid_drop(self.out_proj(O))\n", 160 | "\n", 161 | " return output\n", 162 | "\n", 163 | "attn = CausalSelfAttention(8, 128, 0.1)\n", 164 | "attn(mx.random.uniform(shape=[2, 4, 128])).shape" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 51, 170 | "id": "af59162d-4d13-461b-a928-2cc601cb4cdc", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | 
"class FeedForwardNet(nn.Module):\n", 175 | " def __init__(self, d_embd, p_drop, **kwargs):\n", 176 | " super().__init__()\n", 177 | " self.up_proj = nn.Linear(d_embd, 4*d_embd, bias=False)\n", 178 | " self.down_proj = nn.Linear(4*d_embd, d_embd, bias=False)\n", 179 | " self.dropout = nn.Dropout(p_drop)\n", 180 | "\n", 181 | " def __call__(self, x):\n", 182 | " x = nn.gelu(self.up_proj(x))\n", 183 | " x = self.dropout(self.down_proj(x))\n", 184 | " return x" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 52, 190 | "id": "0057d629-043e-4f6e-ac3c-22bde2cf9ad9", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "class TransformerBlock(nn.Module):\n", 195 | " def __init__(self, d_embd, **kwargs):\n", 196 | " super().__init__()\n", 197 | " self.pre_norm = nn.LayerNorm(d_embd, bias=False)\n", 198 | " self.self_attn = CausalSelfAttention(d_embd=d_embd, **kwargs)\n", 199 | " self.post_norm = nn.LayerNorm(d_embd, bias=False)\n", 200 | " self.ffn = FeedForwardNet(d_embd=d_embd, **kwargs)\n", 201 | "\n", 202 | " def __call__(self, x):\n", 203 | " x = self.self_attn(self.pre_norm(x)) + x\n", 204 | " x = self.ffn(self.post_norm(x))\n", 205 | " return x" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 59, 211 | "id": "d7385929-e43c-477f-ad0b-1d4f0d34880f", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "class GPT(nn.Module):\n", 216 | " def __init__(self, n_vocab, n_ctx, d_embd, p_drop, n_layers, **kwargs):\n", 217 | " super().__init__()\n", 218 | "\n", 219 | " self.tok_embd = nn.Embedding(n_vocab, d_embd)\n", 220 | " self.pos_embd = nn.Embedding(n_ctx, d_embd)\n", 221 | " self.dropout = nn.Dropout(p_drop)\n", 222 | "\n", 223 | " self.blocks = [\n", 224 | " TransformerBlock(d_embd=d_embd, p_drop=p_drop, **kwargs)\n", 225 | " for _ in range(n_layers)\n", 226 | " ]\n", 227 | "\n", 228 | " self.norm = nn.LayerNorm(d_embd, bias=False)\n", 229 | " self.lm_proj = nn.Linear(d_embd, n_vocab, bias=False)\n", 230 | "\n", 231 | " def __call__(self, tok_idx):\n", 232 | " T = tok_idx.shape[1]\n", 233 | "\n", 234 | " tok_embd = self.tok_embd(tok_idx)\n", 235 | " pos_embd = self.pos_embd(mx.arange(T))\n", 236 | " x = self.dropout(tok_embd + pos_embd)\n", 237 | "\n", 238 | " for block in self.blocks:\n", 239 | " x = block(x)\n", 240 | "\n", 241 | " logits = self.lm_proj(self.norm(x))\n", 242 | "\n", 243 | " return logits" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 57, 249 | "id": "fc40f09e-d79f-409c-9343-d998c65529a9", 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "text/plain": [ 255 | "ModelConfig(n_vocab=128, n_ctx=32, n_layers=4, d_embd=256, n_head=8, p_drop=0.1)" 256 | ] 257 | }, 258 | "execution_count": 57, 259 | "metadata": {}, 260 | "output_type": "execute_result" 261 | } 262 | ], 263 | "source": [ 264 | "from dataclasses import dataclass, asdict\n", 265 | "\n", 266 | "@dataclass\n", 267 | "class ModelConfig:\n", 268 | " n_vocab: int\n", 269 | " n_ctx: int\n", 270 | " n_layers: int\n", 271 | " d_embd: int\n", 272 | " n_head: int\n", 273 | " p_drop: float\n", 274 | "\n", 275 | "cfg = ModelConfig(128, 32, 4, 256, 8, 0.1)\n", 276 | "cfg" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 61, 282 | "id": "8dc46bdd-ec1d-4feb-9408-877955782d81", 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "data": { 287 | "text/plain": [ 288 | "(2, 32, 128)" 289 | ] 290 | }, 291 | "execution_count": 61, 292 | "metadata": {}, 293 | "output_type": 
"execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "model = GPT(**asdict(cfg))\n", 298 | "model(mx.random.randint(0, cfg.n_vocab, shape=[2, cfg.n_ctx])).shape" 299 | ] 300 | } 301 | ], 302 | "metadata": { 303 | "kernelspec": { 304 | "display_name": "Python 3 (ipykernel)", 305 | "language": "python", 306 | "name": "python3" 307 | }, 308 | "language_info": { 309 | "codemirror_mode": { 310 | "name": "ipython", 311 | "version": 3 312 | }, 313 | "file_extension": ".py", 314 | "mimetype": "text/x-python", 315 | "name": "python", 316 | "nbconvert_exporter": "python", 317 | "pygments_lexer": "ipython3", 318 | "version": "3.11.6" 319 | } 320 | }, 321 | "nbformat": 4, 322 | "nbformat_minor": 5 323 | } 324 | -------------------------------------------------------------------------------- /notebooks/tokenizer-dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c49fc27f-7d34-411e-8853-c75ae7686ee7", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from datasets import load_dataset\n", 11 | "\n", 12 | "dataset = load_dataset('wikitext', 'wikitext-103-v1')" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "d1e55629-9cb5-4aa1-813b-2053ac22d2e3", 18 | "metadata": {}, 19 | "source": [ 20 | "- Concatenate every title with the following paragraphs until the next title\n", 21 | "- One training example: `BOS + TITLE + PARAGRAPHS + EOS`" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "268c08af-a00b-408a-b439-971f42a0f78e", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "' = Valkyria Chronicles III = \\n'" 34 | ] 35 | }, 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "train_text = dataset['train']['text']\n", 43 | "train_text[1]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 206, 49 | "id": "b79d96a2-bf86-4ec2-bc98-86f831b9f9b3", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "True" 56 | ] 57 | }, 58 | "execution_count": 206, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "def is_title(text):\n", 65 | " has_format = (text[:3] == ' = ' and text[-4:] == ' = \\n')\n", 66 | " return has_format and text[3].isupper()\n", 67 | "\n", 68 | "is_title(train_text[1])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 207, 74 | "id": "1ab8d67a-69de-425f-9d3c-f76b9f240285", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "examples = []\n", 79 | "example = train_text[1] # train_text[0] is emtpy\n", 80 | "\n", 81 | "for text in train_text[2:]:\n", 82 | " if is_title(text):\n", 83 | " examples.append(example)\n", 84 | " example = text\n", 85 | " else:\n", 86 | " example += text" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 208, 92 | "id": "a04d4e82-bb41-4904-9f87-d8e4313af396", 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "27571" 99 | ] 100 | }, 101 | "execution_count": 208, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "len(examples)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "d45a36a4-95a1-442d-889c-220ee878875c", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "ex_idxs = 
sorted(range(len(examples)), key=lambda i: len(examples[i]))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 232, 123 | "id": "edb6245d-78f2-461c-9dd9-7e22153b8dbb", 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "name": "stdout", 128 | "output_type": "stream", 129 | "text": [ 130 | " = Rank in the league ; P = \n", 131 | " Played ; W \n", 132 | " = Win ; D = \n", 133 | " Draw ; L \n", 134 | " = Loss ; F = \n", 135 | " Goals for ; A \n", 136 | "\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "i = 0\n", 142 | "print(*examples[ex_idxs[i]-1:ex_idxs[i]+2])" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 233, 148 | "id": "fee2a5e5-b634-40a9-a67a-2699cfda5527", 149 | "metadata": { 150 | "scrolled": true 151 | }, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "application/vnd.jupyter.widget-view+json": { 156 | "model_id": "6fab455a10754de69826cd991f419c4e", 157 | "version_major": 2, 158 | "version_minor": 0 159 | }, 160 | "text/plain": [ 161 | " 0%| | 0/27571 [00:00 Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the \" Nameless \" , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit \" Raven \" . \n", 216 | " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series newcomers . Character designer Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \n", 217 | " It met with positive sales in Japan , and was praised by both Japanese and western critics . After release , it received downloadable content , along with an expanded edition in November of that year . It was also adapted into manga and an original video animation series . Due to low sales of Valkyria Chronicles II , Valkyria Chronicles III was not localized , but a fan translation compatible with the game 's expanded edition was released in 2014 . Media.Vision would return to the franchise with the development of Valkyria : Azure Revolution for the PlayStation 4 . \n", 218 | " = = Gameplay = = \n", 219 | " As with previous Chronicles games , Valkyria Chronicles III is a tactical role @-@ playing game where players take control of a military unit and take part in missions against enemy forces . Stories are told through comic book @-@ like panels with animated character portraits , with characters speaking partially through voiced speech bubbles and partially through unvoiced text . The player progresses through a series of linear missions , gradually unlocked as maps that can be freely scanned through and replayed as they are unlocked . 
The route to each story location on the map varies depending on an individual player 's approach : when one option is selected , the other is sealed off to the player . Outside missions , the player characters rest in a camp , where units can be customized and character growth occurs . Alongside the main story missions are character @-@ specific sub missions relating to different squad members . After the game 's completion , additional episodes are unlocked , some of them having a higher difficulty than those found in the rest of the game . There are also love simulation elements related to the game 's two main heroines , although they take a very minor role . \n", 220 | " The game 's battle system , the system , is carried over directly from Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters ' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific \" Potentials \" , skills unique to each character . They are divided into \" Personal Potential \" , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and \" Battle Potentials \" , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique \" Masters Table \" , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate \" Direct Command \" and move around the battlefield without depleting his Action Point gauge , the character can shift into her \" Valkyria Form \" and become invincible , while Imca can target multiple enemy units with her heavy weapon . \n", 221 | " Troops are divided into five classes : Scouts , , Engineers , Lancers and Armored Soldier . Troopers can switch classes by changing their assigned weapon . Changing class does not greatly affect the stats gained while in a previous class . With victory in battle , experience points are awarded to the squad , which are distributed into five different attributes shared by the entire squad , a feature differing from early games ' method of distributing to different unit types . \n", 222 | " = = Plot = = \n", 223 | " The game takes place during the Second Europan War . Gallian Army Squad 422 , also known as \" The Nameless \" , are a penal military unit composed of criminals , foreign deserters , and military offenders whose real names are erased from the records and thereon officially referred to by numbers . Ordered by the Gallian military to perform the most dangerous missions that the Regular Army and Militia will not do , they are nevertheless up to the task , exemplified by their motto , , meaning \" Always Ready . 
\" The three main characters are No.7 Kurt Irving , an army officer falsely accused of treason who wishes to redeem himself ; Ace No.1 Imca , a female Darcsen heavy weapons specialist who seeks revenge against the Valkyria who destroyed her home ; and No.13 Riela , a seemingly jinxed young woman who is unknowingly a descendant of the Valkyria . Together with their fellow squad members , these three are tasked to fight against a mysterious Imperial unit known as Calamity Raven , consisting of mostly Darcsen soldiers . \n", 224 | " As the Nameless officially do not exist , the upper echelons of the Gallian Army exploit the concept of plausible deniability in order to send them on missions that would otherwise make Gallia lose face in the war . While at times this works to their advantage , such as a successful incursion into Imperial territory , other orders cause certain members of the 422nd great distress . One such member , , becomes so enraged that he abandons his post and defects into the ranks of Calamity Raven , attached to the ideal of Darcsen independence proposed by their leader , Dahau . At the same time , elements within Gallian Army Command move to erase the Nameless in order to protect their own interests . Hounded by both allies and enemies , and combined with the presence of a traitor within their ranks , the 422nd desperately move to keep themselves alive while at the same time fight to help the Gallian war effort . This continues until the Nameless 's commanding officer , Ramsey Crowe , who had been kept under house arrest , is escorted to the capital city of in order to present evidence exonerating the weary soldiers and expose the real traitor , the Gallian General that had accused Kurt of Treason . \n", 225 | " Partly due to these events , and partly due to the major losses in manpower Gallia suffers towards the end of the war with the Empire , the Nameless are offered a formal position as a squad in the Gallian Army rather than serve as an anonymous shadow force . This is short @-@ lived , however , as following Maximilian 's defeat , Dahau and Calamity Raven move to activate an ancient super weapon within the Empire , kept secret by their benefactor . Without the support of Maximilian or the chance to prove themselves in the war with Gallia , it is Dahau 's last trump card in creating a new Darcsen nation . As an armed Gallian force invading the Empire just following the two nations ' cease @-@ fire would certainly wreck their newfound peace , Kurt decides to once again make his squad the Nameless , asking Crowe to list himself and all under his command as killed @-@ in @-@ action . Now owing allegiance to none other than themselves , the 422nd confronts Dahau and destroys the weapon . Each member then goes their separate ways in order to begin their lives anew . \n", 226 | " = = Development = = \n", 227 | " Concept work for Valkyria Chronicles III began after development finished on Valkyria Chronicles II in early 2010 , with full development beginning shortly after this . The director of Valkyria Chronicles II , Takeshi Ozawa , returned to that role for Valkyria Chronicles III . Development work took approximately one year . After the release of Valkyria Chronicles II , the staff took a look at both the popular response for the game and what they wanted to do next for the series . 
Like its predecessor , Valkyria Chronicles III was developed for PlayStation Portable : this was due to the team wanting to refine the mechanics created for Valkyria Chronicles II , and they had not come up with the \" revolutionary \" idea that would warrant a new entry for the PlayStation 3 . Speaking in an interview , it was stated that the development team considered Valkyria Chronicles III to be the series ' first true sequel : while Valkyria Chronicles II had required a large amount of trial and error during development due to the platform move , the third game gave them a chance to improve upon the best parts of Valkyria Chronicles II due to being on the same platform . In addition to Sega staff from the previous games , development work was also handled by The original scenario was written Kazuki Yamanobe , while the script was written by Hiroyuki Fujii , Koichi Majima , Miyagi , Seiki and Takayuki . Its story was darker and more somber than that of its predecessor . \n", 228 | " The majority of material created for previous games , such as the system and the design of maps , was carried over . Alongside this , improvements were made to the game 's graphics and some elements were expanded , such as map layouts , mission structure , and the number of playable units per mission . A part of this upgrade involved creating unique polygon models for each character 's body . In order to achieve this , the cooperative elements incorporated into the second game were removed , as they took up a large portion of memory space needed for the improvements . They also adjusted the difficulty settings and ease of play so they could appeal to new players while retaining the essential components of the series ' gameplay . The newer systems were decided upon early in development . The character designs were done by Honjou , who had worked on the previous Valkyria Chronicles games . When creating the Nameless Squad , Honjou was faced with the same problem he had had during the first game : the military uniforms essentially destroyed character individuality , despite him needing to create unique characters the player could identify while maintaining a sense of reality within the Valkyria Chronicles world . The main color of the Nameless was black . As with the previous Valkyria games , Valkyria Chronicles III used the graphics engine . The anime opening was produced by Production I.G. \n", 229 | " = = = Music = = = \n", 230 | " The music was composed by Hitoshi Sakimoto , who had also worked on the previous Valkyria Chronicles games . When he originally heard about the project , he thought it would be a light tone similar to other Valkyria Chronicles games , but found the themes much darker than expected . An early theme he designed around his original vision of the project was rejected . He the main theme about seven times through the music production due to this need to reassess the game . The main theme was initially recorded using orchestra , then Sakimoto removed elements such as the guitar and bass , then adjusted the theme using a synthesizer before redoing segments such as the guitar piece on their own before incorporating them into the theme . The rejected main theme was used as a hopeful tune that played during the game 's ending . The battle themes were designed around the concept of a \" modern battle \" divorced from a fantasy scenario by using modern musical instruments , constructed to create a sense of atonality . 
While Sakimoto was most used to working with synthesized music , he felt that he needed to incorporate live instruments such as orchestra and guitar . The guitar was played by Mitsuhiro Ohta , who also arranged several of the later tracks . The game 's opening theme song , \" If You Wish for ... \" ( , Kimi ga Nara ) , was sung by Japanese singer May 'n . Its theme was the reason soldiers fought , in particular their wish to protect what was precious to them rather than a sense of responsibility or duty . Its lyrics were written by Seiko Fujibayashi , who had worked on May 'n on previous singles . \n", 231 | " = = = Release = = = \n", 232 | " In September 2010 , a teaser website was revealed by Sega , hinting at a new Valkyria Chronicles game . In its September issue , Famitsu listed that Senjō no Valkyria 3 would be arriving on the PlayStation Portable . Its first public appearance was at the 2010 Tokyo Game Show ( TGS ) , where a demo was made available for journalists and attendees . During the publicity , story details were kept scant so as not to spoil too much for potential players , along with some of its content still being in flux at the time of its reveal . To promote the game and detail the story leading into the game 's events , an episodic Flash visual novel written by Fujii began release in January 2011 . The game was released January 27 , 2011 . During an interview , the development team said that the game had the capacity for downloadable content ( DLC ) , but that no plans were finalized . Multiple DLC maps , featuring additional missions and recruitable characters , were released between February and April 2011 . An expanded edition of the game , Valkyria Chronicles III Extra Edition , released on November 23 , 2011 . Packaged and sold at a lower price than the original , Extra Edition game with seven additional episodes : three new , three chosen by staff from the game 's DLC , and one made available as a pre @-@ order bonus . People who also owned the original game could transfer their save data between versions . \n", 233 | " Unlike its two predecessors , Valkyria Chronicles III was not released in the west . According to Sega , this was due to poor sales of Valkyria Chronicles II and the general unpopularity of the PSP in the west . An unofficial fan translation patch began development in February 2012 : players with a copy of Valkyria Chronicles III could download and apply the patch , which translated the game 's text into English . Compatible with the Extra Edition , the patch was released in January 2014 . \n", 234 | " = = Reception = = \n", 235 | " On its day of release in Japan , Valkyria Chronicles III topped both platform @-@ exclusive and multi @-@ platform sales charts . By early February , the game sold 102 @,@ 779 units , coming in second overall to The Last Story for the Wii . By the end of the year , the game had sold just over 152 @,@ 500 units . \n", 236 | " Famitsu enjoyed the story , and were particularly pleased with the improvements to gameplay . Japanese gaming site Game Watch Impress , despite negatively noting its pacing and elements recycled from previous games , was generally positive about its story and characters , and found its gameplay entertaining despite off @-@ putting difficulty spikes . 4Gamer.net writer Naohiko , in a \" Play Test \" article based on the game 's PSN demo , felt that Valkyria Chronicles III provided a \" profound feeling of closure \" for the Valkyria Chronicles series . 
He praised its gameplay despite annoying limitations to aspects such as special abilities , and positively noted its shift in story to a tone similar to the first game . \n", 237 | " PlayStation Official Magazine - UK praised the story 's blurring of Gallia 's moral standing , art style , and most points about its gameplay , positively noting the latter for both its continued quality and the tweaks to balance and content . Its one major criticism were multiple difficulty spikes , something that had affected the previous games . Heath Hindman of gaming website PlayStation Lifestyle praised the addition of non @-@ linear elements and improvements or removal of mechanics from Valkyria Chronicles II in addition to praising the returning gameplay style of previous games . He also positively noted the story 's serious tone . Points criticized in the review were recycled elements , awkward cutscenes that seemed to include all characters in a scene for no good reason , pacing issues , and occasional problems with the game 's AI . \n", 238 | " In a preview of the TGS demo , Ryan Geddes of IGN was left excited as to where the game would go after completing the demo , along with enjoying the improved visuals over Valkyria Chronicles II . Kotaku 's Richard Eisenbeis was highly positive about the game , citing is story as a return to form after Valkyria Chronicles II and its gameplay being the best in the series . His main criticisms were its length and gameplay repetition , along with expressing regret that it would not be localized . \n", 239 | " = = Legacy = = \n", 240 | " Kurt and Riela were featured in the Nintendo 3DS crossover Project X Zone , representing the Valkyria series . Media.Vision would return to the series to develop Valkyria : Azure Revolution , with Ozawa returning as director . Azure Revolution is a role @-@ playing video game for the PlayStation 4 that forms the beginning of a new series within the Valkyria franchise . \n", 241 | " = = = Adaptations = = = \n", 242 | " Valkyria Chronicles 3 was adapted into a two @-@ episode original video animation series in the same year of its release . Titled Senjō no Valkyria 3 : Taga Tame no ( , lit . Valkyria of the Battlefield 3 : The Wound Taken for Someone 's Sake ) , it was originally released through PlayStation Network and between April and May 2011 . The initially @-@ planned release and availability period needed to be extended due to a stoppage to PSN during the early summer of that year . It later released for DVD on June 29 and August 31 , 2011 , with separate \" Black \" and \" Blue \" editions being available for purchase . The anime is set during the latter half of Valkyria Chronicles III , detailing a mission by the Nameless against their Imperial rivals Calamity Raven . The anime was first announced in November 2010 . It was developed by A @-@ 1 Pictures , produced by Shinji , directed by Nobuhiro Kondō , and written by Hiroshi . Sakimoto 's music for the game was used in the anime . \n", 243 | " The anime 's title was inspired by the principle purpose of the Nameless : to suffer in battle for the goals of others . A subtitle attached to the project during development was \" The Road to Kubinka \" , which referenced the Kubinka Tank Museum in Moscow . The game 's main theme was how the characters regained their sense of self when stripped of their names and identities , along with general themes focused on war and its consequences . 
While making the anime , the production team were told by Sega to make it as realistic as possible , with the consequence that the team did extensive research into aspects such as what happened when vehicles like tanks were overturned or damaged . Due to it being along the same timeline as the original game and its television anime adaptation , the cast of Valkyria Chronicles could make appearances , which pleased the team . The opening theme , \" Akari ( Light ) \" ( @-@ ) , was sung by Japanese singer . The ending theme , \" Someday the Flowers of Light Will Bloom \" ( , Itsuka Saku Hikari no Hana ) , was sung by Minami Kuribayashi . Both songs ' lyrics were written by their respective artists . \n", 244 | " Two manga adaptations were produced , following each of the game 's main female protagonists Imca and Riela . They were Senjō no Valkyria 3 : Namo no Hana ( 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 : The Flower of the Nameless Oath ) , illustrated by Naoyuki Fujisawa and eventually released in two volumes after being serialized in Dengeki Maoh between 2011 and 2012 ; and Senjō no Valkyria 3 : Unmei no ( 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 -The Valkyrie of the Crimson Fate ) , illustrated by Mizuki Tsuge and eventually released in a single volume by Kadokawa Shoten in 2012 . \n", 245 | "\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "print(sp_model.decode(train_tokens[0].tolist()))" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 236, 256 | "id": "e254360e-76d3-4153-9e7b-a13a0440b1e1", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "chunk_size = 1000\n", 261 | "for chunk_idx, idx in enumerate(range(0, len(train_tokens), chunk_size)):\n", 262 | " mx.savez(f'wikitext_data/train/example_chunk{chunk_idx:02d}', *train_tokens[idx:idx+chunk_size])" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 237, 268 | "id": "e84eee2b-cfcd-4d83-8d6b-d8dbbe5a6e45", 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/plain": [ 274 | "27571" 275 | ] 276 | }, 277 | "execution_count": 237, 278 | "metadata": {}, 279 | "output_type": "execute_result" 280 | } 281 | ], 282 | "source": [ 283 | "from pathlib import Path\n", 284 | "import mlx.core as mx\n", 285 | "\n", 286 | "train_data_dir = Path('./wikitext_data/train')\n", 287 | "train_examples = []\n", 288 | "\n", 289 | "for ex_path in sorted(train_data_dir.glob('*.npz')):\n", 290 | " train_examples.extend(mx.load(str(ex_path)).values())\n", 291 | "\n", 292 | "len(train_examples)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 238, 298 | "id": "063ffe0f-04bf-4c45-bf14-a3479effd876", 299 | "metadata": { 300 | "scrolled": true 301 | }, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "text/plain": [ 306 | "' = Lloyd Mathews = \\n Sir Lloyd William Mathews , GCMG , CB ( 7 March 1850 – 11 October 1901 ) was a British naval officer , politician and abolitionist . Mathews joined the Royal Navy as a cadet at the age of 13 and progressed through the ranks to lieutenant . He was involved with the Third Anglo @-@ Ashanti War of 1873 – 4 , afterwards being stationed in East Africa for the suppression of the slave trade . In 1877 he was seconded from the navy to Sultan Barghash of Zanzibar in order to form a European @-@ style army ; he would remain in the employment of the government of Zanzibar for the rest of his life . 
His army quickly reached 6 @,@ 300 men and was used in several expeditions to suppress the slave trade and rebellions against the Zanzibar government . \\n Mathews retired from the Royal Navy in 1881 and was appointed Brigadier @-@ General of Zanzibar . There followed more expeditions to the African mainland , including a failed attempt to stop German expansion in East Africa . In October 1891 Mathews was appointed First Minister to the Zanzibar government , a position in which he was \" irremovable by the sultan \" . During this time Mathews was a keen abolitionist and promoted this cause to the Sultans he worked with . This resulted in the prohibiting of the slave trade in Zanzibar \\'s dominions in 1890 and the abolition of slavery in 1897 . Mathews was appointed the British Consul @-@ General for East Africa in 1891 but declined to take up the position , remaining in Zanzibar instead . Mathews and his troops also played a key role in the ending of the Anglo @-@ Zanzibar War of 1896 which erupted out of an attempt to bypass the requirement that new Sultans must be vetted by the British consul . During his time as first minister Mathews continued to be involved with the military and was part of two large campaigns , one to Witu and another to Mwele . \\n Mathews was decorated by several governments , receiving appointments as a Companion of the Order of St Michael and St George , Companion of the Order of the Bath and as a Knight Commander of the Order of St Michael and St George from the British government and membership in the Prussian Order of the Crown . Zanzibar also rewarded him and he was a member of the Grand Order of Hamondieh and a first class member of the Order of the Brilliant Star of Zanzibar . Mathews died of malaria in Zanzibar on 11 October 1901 . \\n = = Early life and career = = \\n Mathews was born at Funchal on Madeira on 7 March 1850 . His father , Captain William Matthews was Welsh , and his mother Jane Wallis Penfold , was the daughter of William Penfold and Sarah Gilbert . Her sister , Augusta Jane Robley née Penfold was the author of a famous book about the flora and fauna of Madeira , which is now in the Natural History Museum . Mathews became a cadet of the Royal Navy in 1863 and was appointed a midshipman on 23 September 1866 . From 1868 he was stationed in the Mediterranean but his first active service was during the Third Anglo @-@ Ashanti War of 1873 – 4 where he qualified for the campaign medal . He was promoted to lieutenant on 31 March 1874 . On 27 August 1875 Mathews was posted to HMS London , a depot ship and the Royal Navy headquarters for East Africa , to assist in the suppression of the slave trade in the area . Whilst onboard he drilled his own troops , captured several slave dhows and was commended for his actions by the Admiralty . \\n = = Commander in Chief of Zanzibar = = \\n In August 1877 , Mathews was seconded from the Navy to Sultan Barghash of Zanzibar to form a European @-@ style army which could be used to enforce Zanzibar \\'s control over its mainland possessions . The army had traditionally been composed entirely of Arabs and Persians but Mathews opened up recruitment to the African majority on the island and had 300 recruits in training by the end of the year . In addition , Mathews employed some unorthodox recruitment methods such as purchasing slaves from their masters , using inmates from the prison and recruiting from Africans rescued from the slavers . 
In June 1877 , at the instigation of John Kirk , the explorer and friend of the Sultan , the British government sent a shipment of 500 modern rifles and ammunition as a gift with which to arm the troops . Mathews introduced a new uniform for the troops consisting of a red cap , short black jackets and white trousers for the enlisted ranks and dark blue frock coats and trousers with gold and silver lace for the Arab officers . The latter was possibly modelled on the Royal Navy officers uniform with which he was familiar . The army grew quickly ; by the 1880s Mathews would command 1 @,@ 300 men , his forces eventually numbering 1 @,@ 000 regulars and 5 @,@ 000 irregulars . \\n One of the first tasks for the new army was to suppress the smuggling of slaves from Pangani on the mainland to the island of Pemba , north of Zanzibar . The troops completed this mission , capturing several slavers and hindering the trade . Mathews retired from the Royal Navy in June 1881 and was appointed Brigadier @-@ General of Zanzibar . In 1880 , the Sultan dispatched a military force under Mathews to bring his unruly African mainland territories under control . Mathews \\' expedition was initially intended to reach but his men refused to march inland and , when made to do so , deserted in large numbers . The expedition ended instead at where a 60 @-@ man garrison was established . This had been reduced to a mere handful of men by the mid @-@ 1880s but the expedition proved that the Sultan was serious about maintaining control of all of his possessions . Mathews \\' men were also involved in several expeditions to halt the land @-@ based slave trade which had developed once the seas became too heavily policed for the traders . \\n In 1881 Mathews \\' old vessel , the HMS London , was captained by Charles J Brownrigg . This vessel and her crew made several patrols aimed at hindering the slave trade using smaller steam boats for the actual pursuits and captures . On December 3 , 1881 , they caught up with a slave dhow captained by Hindi bin Hattam . This dhow had around 100 slaves on board and was transporting them between Pemba and Zanzibar . Captain Brownrigg led a boarding party to release the slaves but bin Hattam \\'s men then attacked the sailors , killing Brownrigg and his party before sailing away . Mathews led a force to Wete on Pemba and , after a short battle , took a mortally wounded bin Hattem prisoner before returning to Zanzibar . \\n Mathews returned to the African mainland territories once more in 1884 when he landed with a force which intended to establish further garrisons there to dissuade German territorial claims . This attempt ultimately failed when five German warships steamed into Zanzibar Town harbour and threatened the Sultan into signing away the territories which would later form German East Africa . Further territories were ceded to the German East Africa Company in 1888 but unrest amongst the locals against them prevented them from taking control and Mathews was dispatched with 100 men to restore order . Finding around 8 @,@ 000 people gathered against the German administrators Mathews was forced to return with his men to Zanzibar . He landed once again with more troops but found himself subject to death threats and that his troops would not obey his orders and so returned again to Zanzibar . 
\\n = = First Minister = = \\n In October 1891 , upon the formation of the first constitutional government in Zanzibar , Mathews was appointed First Minister , despite some hostility from Sultan Ali bin Said . In this capacity Mathews was \" irremovable by the sultan \" and answerable only to the Sultan and the British Consul . His position was so strong that one missionary on the island is quoted as saying that his powers defied \" analytical examination \" and that Mathews really could say \" L \\'état est moi \" ( I am the state ) . Mathews was also known as the \" Strong man of Zanzibar \" . The principal departments of government were mostly run by Britons or British Indians and Mathews \\' approval was required before they could be removed from office . Mathews was rewarded by the Zanzibar government for his role with his appointment as a first class member of the Order of the Brilliant Star of Zanzibar , which he was granted licence by Queen Victoria to accept and wear on 17 May 1886 . Mathews used his position to suppress slavery in the country and in 1889 convinced the Sultan to issue a decree purchasing the freedom of all slaves who had taken refuge in his dominions and , from 1890 , the prohibiting the slave trade . On 1 February 1891 Mathews was appointed Her Majesty \\'s Commissioner and Consul @-@ General to the British Sphere of Influence in East Africa . He never took up the post and instead chose to remain in Zanzibar . \\n Mathews was rewarded for his service in Zanzibar by the British government which appointed him a Companion of the Order of St Michael and St George in 1880 and a Companion of the Order of the Bath on 24 May 1889 . Despite becoming renowned in East Africa as a man who ran a fair administration and was strict with criminals , unhappiness with effective British rule and his halting of the slave trade led some Arabs to petition the Sultan for his removal in 1892 . In 1893 Mathews purchased the island of for the government . He intended it to be used as a prison but it never housed prisoners and was instead used to quarantine yellow fever cases before its present use as a conservation area for giant tortoises . Mathews was appointed a Knight Commander of the Order of St Michael and St George in 1894 . He was also awarded membership of the Order of the Crown by the German government . \\n Matters came to a head when Khalid bin Barghash attempted to take control of the palace in Zanzibar Town upon the death of his uncle in August 1896 , despite failing to gain the consent of the British consul there . Mathews opposed this succession and , with British agreement , called up 900 soldiers in an attempt to prevent it . This situation eventually led to the Anglo @-@ Zanzibar War and Mathews , with the support of Admiral Harry Rawson and five vessels of the Royal Navy , bombarded the palace and secured the end of Khalid \\'s administration . Mathews \\' helped to arrange the succession of a pro @-@ British Sultan , bin Mohammed , as Khalid \\'s successor . Mathews continued his reforms after the war , abolishing slavery in 1897 and establishing new farms to grow produce using Western techniques . He was appointed a member of the Grand Order of Hamondieh of Zanzibar and was permitted to accept and wear the decoration on 25 August 1897 . 
\\n = = Military expeditions = = \\n = = = Mwele = = = \\n In addition to the smaller @-@ scale expeditions described earlier , Mathews embarked on two much larger expeditions to the African mainland during his tenure as first minister , the first at Mwele . The initial rebellion in the area had been led by Mbaruk bin Rashid at Gazi , which Mathews had put down with 1 @,@ 200 men in 1882 . However , in 1895 Mbaruk \\'s nephew , Mbaruk bin Rashid , refused to acknowledge the appointment of a new leader at . This led to open rebellion at in February of that year when the younger Mbaruk attacked Zanzibari troops under Arthur Raikes , one of Mathews \\' officers . Mathews was part of an Anglo @-@ Zanzibari expedition sent to quell it , which consisted of 310 British sailors , 50 Royal Marines , 54 Sudanese and 164 Zanzibari troops . was destroyed and the leaders fled to Gazi where the older Mbaruk failed to turn them over . Another force , under Admiral Rawson , with 400 British marines and sailors , was sent after them . This further expedition failed to capture the ringleaders and a third expedition was organised by Rawson with 220 sailors , 80 marines , 60 Sudanese and 50 Zanzibaris , which destroyed Mwele . During the latter action Mathews was wounded in the shoulder . \\n = = = Witu = = = \\n Following the death of a German logger who had been operating illegally , the Sultan of Zanzibar and the British government dispatched an expedition on 20 October 1890 to bring the Sultan of Witu to justice . Nine warships and three transports carrying 800 sailors and marines , 150 Imperial British East Africa Company ( IBEA ) Indian police , 200 Zanzibari and 50 Sudanese troops were sent , defeating the Sultan and establishing a British protectorate . The IBEA was given control of the area and established a force of 250 Indian police to maintain the peace . The police were withdrawn in July 1893 following threats of violence from the new Sultan of Witu , Oman , and another expedition was dispatched to the region . This consisted of three warships : HMS Blanche , HMS Sparrow and the Zanzibari ship HHS . The latter carried Mathews with 125 Askaris and 50 Sudanese under Brigadier @-@ General Hatch of the Zanzibar army . \\n Mathews and an escort force went to Witu where , on 31 July , they removed the flag of the IBEA company and replaced it with the red flag of Zanzibar , before destroying several villages and causing Oman to retreat into the forests . The British troops then withdrew , having suffered heavily from malaria , but the Sudanese and Zanzibari troops remained . A further expedition was sent of 140 sailors and 85 other troops but Oman died soon after and a more pliable sultan , Omar bin Hamid , was appointed to govern on behalf of Zanzibar , bringing the affair to a close . In return for this action , Mathews received the British East and West Africa campaign medal . \\n = = Later life = = \\n Mathews died of malaria in Zanzibar on 11 October 1901 and was buried with full military honours in the British cemetery outside Zanzibar Town . His successor as first minister was A.S. Rogers . island , which Mathews bought for a prison , now has a restaurant named in his honour and also a church . Mathews House , at the Western end of Zanzibar Town , is also named for him . 
\\n'" 307 | ] 308 | }, 309 | "execution_count": 238, 310 | "metadata": {}, 311 | "output_type": "execute_result" 312 | } 313 | ], 314 | "source": [ 315 | "sp_model.decode(train_examples[0].tolist())" 316 | ] 317 | } 318 | ], 319 | "metadata": { 320 | "kernelspec": { 321 | "display_name": "Python 3 (ipykernel)", 322 | "language": "python", 323 | "name": "python3" 324 | }, 325 | "language_info": { 326 | "codemirror_mode": { 327 | "name": "ipython", 328 | "version": 3 329 | }, 330 | "file_extension": ".py", 331 | "mimetype": "text/x-python", 332 | "name": "python", 333 | "nbconvert_exporter": "python", 334 | "pygments_lexer": "ipython3", 335 | "version": "3.11.6" 336 | } 337 | }, 338 | "nbformat": 4, 339 | "nbformat_minor": 5 340 | } 341 | -------------------------------------------------------------------------------- /notebooks/train-step-dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "676e491e-497e-49b9-81b4-6a332c193ace", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from dataclasses import dataclass, asdict\n", 11 | "\n", 12 | "@dataclass\n", 13 | "class TrainerConfig:\n", 14 | " bsz: int = 16\n", 15 | " lr: float = 1e-4\n", 16 | " n_steps: int = 1000\n", 17 | " pad_token_id: int = 65535 # Max value of uint16\n", 18 | "\n", 19 | "cfg_t = TrainerConfig()\n", 20 | "cfg_m = LLaMAConfig()" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "50982c1a-8cad-422c-9467-1bf56826c105", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from dataset import config_dataloader\n", 31 | "load_data = config_dataloader(cfg_t.bsz, cfg_m.seq_len, cfg_t.pad_token_id)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "id": "d2b7bfc2-0d8b-418c-8c55-2ad4a696d251", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import mlx.core as mx\n", 42 | "\n", 43 | "mx.set_default_device(mx.gpu)\n", 44 | "\n", 45 | "data_iter = iter(load_data())\n", 46 | "inputs_, targets_ = next(data_iter)\n", 47 | "while mx.all(inputs_ != cfg_t.pad_token_id):\n", 48 | " targets_ = next(data_iter)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 8, 54 | "id": "8d6c7905-d6da-4165-9ee4-e311c5e471f1", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "from llama import LLaMAConfig, LLaMA\n", 59 | "model = LLaMA(**asdict(cfg_m))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 13, 65 | "id": "c76d2104-17b9-4fb9-9058-58495b351412", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from mlx import nn\n", 70 | "\n", 71 | "def forward(model, inputs, targets):\n", 72 | " pad_mask = (inputs != cfg_t.pad_token_id)\n", 73 | " logits = model(inputs * pad_mask)\n", 74 | "\n", 75 | " logprobs = nn.losses.cross_entropy(logits, targets)\n", 76 | " logprobs_m = logprobs * pad_mask\n", 77 | " loss = logprobs_m.sum() / pad_mask.sum()\n", 78 | "\n", 79 | " return loss" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 14, 85 | "id": "05599ceb-909b-411b-b2a5-7ea4870bb0da", 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | "array(10.5277, dtype=float32)" 92 | ] 93 | }, 94 | "execution_count": 14, 95 | "metadata": {}, 96 | "output_type": "execute_result" 97 | } 98 | ], 99 | "source": [ 100 | "loss_and_grad = nn.value_and_grad(model, forward)\n", 101 | "loss, grad = loss_and_grad(model, 
inputs_, targets_)\n", 102 | "loss" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 8, 108 | "id": "c56f3377-81f2-4c6a-8477-27bed3eb1a1d", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "array(61.2031, dtype=float32)" 115 | ] 116 | }, 117 | "execution_count": 8, 118 | "metadata": {}, 119 | "output_type": "execute_result" 120 | } 121 | ], 122 | "source": [ 123 | "from functools import partial\n", 124 | "import mlx.optimizers as optim\n", 125 | "\n", 126 | "optimizer = optim.AdamW(learning_rate=cfg_t.lr)\n", 127 | "state = [model.state, optimizer.state]\n", 128 | "\n", 129 | "@partial(mx.compile, inputs=state, outputs=state)\n", 130 | "def train_step(inputs, targets):\n", 131 | " loss_and_grad = nn.value_and_grad(model, forward)\n", 132 | " loss, grads = loss_and_grad(model, inputs, targets)\n", 133 | " optimizer.update(model, grads)\n", 134 | " return loss\n", 135 | "\n", 136 | "loss = train_step(inputs_, targets_)\n", 137 | "loss" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 9, 143 | "id": "b69943cf-2466-42ec-ae74-80f41ff15238", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "array(55.4943, dtype=float32)" 150 | ] 151 | }, 152 | "execution_count": 9, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | } 156 | ], 157 | "source": [ 158 | "train_step(inputs_, targets_)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 1, 164 | "id": "28df2601-ea1e-4fd2-afd3-b596e5cf58a5", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "from dataclasses import dataclass, asdict\n", 169 | "from mlx import nn\n", 170 | "import mlx.core as mx\n", 171 | "import mlx.optimizers as optim\n", 172 | "from dataset import config_dataloader\n", 173 | "from llama import LLaMAConfig, LLaMA\n", 174 | "\n", 175 | "@dataclass\n", 176 | "class TrainerConfig:\n", 177 | " bsz: int = 16\n", 178 | " lr: float = 1e-4\n", 179 | " n_steps: int = 1000\n", 180 | " pad_token_id: int = 65535 # Max value of uint16\n", 181 | "\n", 182 | "cfg_t = TrainerConfig()\n", 183 | "cfg_m = LLaMAConfig()\n", 184 | "\n", 185 | "model = LLaMA(**asdict(cfg_m))\n", 186 | "optimizer = optim.AdamW(learning_rate=cfg_t.lr)\n", 187 | "load_data = config_dataloader(cfg_t.bsz, cfg_m.seq_len, cfg_t.pad_token_id)\n", 188 | "\n", 189 | "inputs_BT, targets_BT = next(load_data(cfg_t.n_steps))\n", 190 | "\n", 191 | "# One train forward pass: inputs_BT, cfg_t.pad_token_id, model, targets_BT\n", 192 | "# pad_mask_BT = (inputs_BT != cfg_t.pad_token_id)\n", 193 | "# logits_BTV = model(inputs_BT * pad_mask_BT)\n", 194 | "# logprobs_BT = nn.losses.cross_entropy(logits_BTV, targets_BT)\n", 195 | "# loss = (logprobs_BT * pad_mask_BT).sum() / pad_mask_BT.sum()" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 6, 201 | "id": "6d28d790-fa62-4685-8c44-d299d432a0fc", 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "Dummy(124)\n", 209 | "Dummy(0)\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "class Dummy:\n", 215 | " def __init__(self, val):\n", 216 | " self.val = val\n", 217 | " def __add__(self, rhs):\n", 218 | " return Dummy(self.val + rhs.val)\n", 219 | " def __repr__(self):\n", 220 | " return f'Dummy({self.val})'\n", 221 | "\n", 222 | "def test_scope():\n", 223 | " def inner(x):\n", 224 | " return x + y\n", 225 | " y = Dummy(39)\n", 226 | " x_ = 
Dummy(85)\n", 227 | " print(inner(x_))\n", 228 | " y.val = -85\n", 229 | " print(inner(x_))\n", 230 | "test_scope()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 2, 236 | "id": "0b5df78a-e70a-435a-b709-fcaab7a8ca11", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "def train_forward_pass(model, inputs_BT, targets_BT):\n", 241 | " pad_mask_BT = (inputs_BT != cfg_t.pad_token_id)\n", 242 | " logits_BTV = model(inputs_BT * pad_mask_BT)\n", 243 | " logprobs_BT = nn.losses.cross_entropy(logits_BTV, targets_BT)\n", 244 | " loss = (logprobs_BT * pad_mask_BT).sum() / pad_mask_BT.sum()\n", 245 | " return loss" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 3, 251 | "id": "b25cbd9d-5b14-44e5-b93b-f681aae0cecf", 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "ename": "ValueError", 256 | "evalue": "[compile] Function arguments must be trees of arrays or constants (floats, ints, or strings), but received type mlx.optimizers.optimizers.AdamW.", 257 | "output_type": "error", 258 | "traceback": [ 259 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 260 | "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", 261 | "Cell \u001b[0;32mIn[3], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mupdate(model, grads)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\u001b[38;5;241m.\u001b[39mstate, optimizer\u001b[38;5;241m.\u001b[39mstate\n\u001b[0;32m----> 7\u001b[0m model_state, opt_state \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs_BT\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets_BT\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m model\u001b[38;5;241m.\u001b[39mupdate(model_state)\n\u001b[1;32m 9\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mupdate(opt_state)\n", 262 | "\u001b[0;31mValueError\u001b[0m: [compile] Function arguments must be trees of arrays or constants (floats, ints, or strings), but received type mlx.optimizers.optimizers.AdamW." 
263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "@mx.compile\n", 268 | "def train_step(model, optimizer, inputs_BT, targets_BT):\n", 269 | " loss_and_grad = nn.value_and_grad(model, train_forward_pass)\n", 270 | " loss, grads = loss_and_grad(model, inputs_BT, targets_BT)\n", 271 | " optimizer.update(model, grads)\n", 272 | " return model.state, optimizer.state\n", 273 | "model_state, opt_state = train_step(model, optimizer, inputs_BT, targets_BT)\n", 274 | "model.update(model_state)\n", 275 | "optimizer.update(opt_state)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 9, 281 | "id": "27f6dcaa-e908-443a-bc40-a766b89749a9", 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "{'step': array(0, dtype=uint64), 'learning_rate': array(0.0001, dtype=float32)}" 288 | ] 289 | }, 290 | "execution_count": 9, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "optimizer.state" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "bd581470-2c4b-4a73-8e29-b116263c8a00", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3 (ipykernel)", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.11.6" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 5 329 | } 330 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx==0.9.1 2 | mlx-lm==0.7.0 3 | sentencepiece==0.2.0 4 | notebook==7.1.2 5 | ipywidgets==8.1.2 6 | datasets==2.18.0 7 | asitop==0.0.24 8 | wandb==0.16.6 9 | joblib==1.4.0 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict, dataclass, field 3 | from functools import partial 4 | 5 | import mlx.core as mx 6 | import mlx.optimizers as optim 7 | from mlx import nn 8 | from mlx.utils import tree_flatten, tree_map 9 | from tqdm import tqdm 10 | import wandb 11 | 12 | from dataset import config_dataloader 13 | from llama import init_params, LLaMAConfig, LLaMA 14 | 15 | 16 | @dataclass 17 | class TrainerConfig: 18 | bsz: int = 16 19 | lr: float = 1e-3 20 | n_update_steps: int = 225 21 | grad_acc_steps: int = 8 22 | warmup_ratio: float = 0.1 23 | n_steps: int = field(init=False) 24 | warmup_steps: int = field(init=False) 25 | n_epochs: int = 100 26 | pad_token_id: int = -1 27 | ckpt_name: str = 'mini-llama-wikitext-bsl' 28 | 29 | def __post_init__(self): 30 | self.n_steps = self.n_update_steps * self.grad_acc_steps 31 | self.warmup_steps = int(self.n_update_steps * self.warmup_ratio) 32 | 33 | 34 | def train(): 35 | cfg_t = TrainerConfig() 36 | cfg_m = LLaMAConfig(n_layers=6, d_embd=256, n_heads=8) 37 | wandb.init(project='mini-llama-mlx', config={**asdict(cfg_t), **asdict(cfg_m)}) 38 | 39 | load_train_data = config_dataloader(**asdict(cfg_t), **asdict(cfg_m), split='train') 40 | 41 | model = init_params(LLaMA(**asdict(cfg_m))) 42 | 43 | 
lr_schedule = optim.join_schedules([ 44 | optim.schedulers.linear_schedule(0.0, cfg_t.lr, cfg_t.warmup_steps), 45 | optim.cosine_decay(cfg_t.lr, cfg_t.n_update_steps-cfg_t.warmup_steps) 46 | ], [cfg_t.warmup_steps]) 47 | optimizer = optim.AdamW(learning_rate=lr_schedule) 48 | 49 | 50 | def train_forward_pass(model_, inputs_BT, labels_BT): 51 | pad_mask_BT = (inputs_BT != cfg_t.pad_token_id) 52 | logits_BTV = model_(inputs_BT * pad_mask_BT) 53 | logprobs_BT = nn.losses.cross_entropy(logits_BTV, labels_BT) 54 | loss = (logprobs_BT * pad_mask_BT).sum() / pad_mask_BT.sum() 55 | return loss 56 | 57 | 58 | @partial(mx.compile, inputs=model.state, outputs=model.state) 59 | def train_step(inputs_BT, labels_BT, grads): 60 | loss_and_grad = nn.value_and_grad(model, train_forward_pass) 61 | loss, grads_m = loss_and_grad(model, inputs_BT, labels_BT) 62 | grads = tree_map(lambda g, gm: (g + gm / cfg_t.grad_acc_steps), grads, grads_m) 63 | return loss, grads 64 | 65 | 66 | grads = tree_map(lambda p: mx.zeros(p.shape), model) 67 | pbar = tqdm(total=cfg_t.n_steps) 68 | model.train() 69 | 70 | for step, (inputs_BT, labels_BT) in enumerate(load_train_data()): 71 | try: 72 | loss, grads = train_step(inputs_BT, labels_BT, grads) 73 | mx.eval(loss, grads) 74 | 75 | loss, lr = map(lambda x: x.item(), [loss, optimizer.learning_rate]) 76 | pbar.set_description(f'{loss=:.4f} | {lr=:.2e} | peak_mem={(mx.metal.get_peak_memory()/2**30):.2f}GB') 77 | pbar.update(1) 78 | 79 | if ((step + 1) % cfg_t.grad_acc_steps == 0) or (step == cfg_t.n_steps - 1): 80 | optimizer.update(model, grads) 81 | mx.eval(model.state, optimizer.state) 82 | grads = tree_map(lambda p: mx.zeros(p.shape), model) 83 | wandb.log({'loss': loss, 'learning_rate': lr}) 84 | except KeyboardInterrupt: 85 | break 86 | 87 | pbar.close() 88 | mx.save_safetensors(cfg_t.ckpt_name, dict(tree_flatten(model))) 89 | 90 | 91 | if __name__ == '__main__': 92 | wandb.login(key=os.environ.get('WANDB_KEY', None)) # Set WANDB_MODE="disabled" to disable 93 | train() 94 | --------------------------------------------------------------------------------
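
The last step of `train.py` writes the trained parameters with `mx.save_safetensors(cfg_t.ckpt_name, dict(tree_flatten(model)))`. Below is a minimal sketch of how that checkpoint could be loaded back for inference, assuming the default `ckpt_name` (`mini-llama-wikitext-bsl`, to which MLX appends the `.safetensors` extension) and the same `LLaMAConfig` used in `train()`; the repo's own `generate.py` is the authoritative loader and may differ.

```python
# Sketch only: reload the weights saved by train.py's mx.save_safetensors call.
# Assumes the default ckpt_name and the LLaMAConfig from train(); generate.py may differ.
from dataclasses import asdict

import mlx.core as mx
from mlx.utils import tree_unflatten

from llama import LLaMA, LLaMAConfig

cfg_m = LLaMAConfig(n_layers=6, d_embd=256, n_heads=8)    # must match the training run
model = LLaMA(**asdict(cfg_m))

weights = mx.load('mini-llama-wikitext-bsl.safetensors')  # flat dict: dotted name -> mx.array
model.update(tree_unflatten(list(weights.items())))       # rebuild the nested parameter tree
model.eval()
```

Because the checkpoint stores `dict(tree_flatten(model))`, its keys are dotted parameter paths; `tree_unflatten` converts them back into the nested structure that `nn.Module.update` expects.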