├── .gitignore ├── LICENSE ├── README.md ├── convert.py ├── data └── example.txt ├── generate.py ├── gpt ├── README.py ├── convert.py ├── hf_model.py └── model.py ├── llama ├── __init__.py ├── model.py └── optim.py ├── pyproject.toml └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | poetry.lock 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Joe Barrow
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🍏🤖 `mlx-playground`
2 |
3 | Run fast transformer decoders on your MacBook's GPU!
4 | A work in progress towards fast reimplementations of GPT-2 and Llama-like models in [mlx](https://ml-explore.github.io/mlx/build/html/index.html).
5 |
6 | The aim is that the only dependencies are:
7 | - `mlx`
8 | - `sentencepiece`
9 | - `tqdm`
10 | - `numpy`
11 |
12 | With an optional dev dependency of:
13 | - `transformers` for downloading and converting weights
14 |
15 | ## Accomplishments
16 |
17 | - [x] ~~makemore llama reimplementation~~ (train your own w/ `python train.py`!)
18 | - [x] [BERT merged into `mlx-examples`](https://github.com/ml-explore/mlx-examples/pull/43)
19 | - [x] [Phi-2 merged into `mlx-examples`](https://github.com/ml-explore/mlx-examples/pull/97)
20 | - [x] [AdamW merged into `mlx`](https://github.com/ml-explore/mlx/pull/72)
21 |
22 | ## Remaining Goals
23 |
24 | This project will be considered complete once these goals are achieved.
25 |
26 | - [ ] finetune BERT
27 | - [ ] GPT-2 reimplementation and loading in MLX
28 | - [ ] speculative decoding
29 | - [ ] learning rate scheduling
30 |
31 | ## Installation
32 |
33 | ```
34 | poetry install --no-root
35 | ```
36 |
37 | ## Phi-2
38 |
39 | To download and convert the model:
40 |
41 | ```sh
42 | python phi2/convert.py
43 | ```
44 |
45 | That will write the converted weights to `weights/phi-2.npz`.
46 |
47 | 🚧 (Not yet done) To run the model:
48 |
49 | ```sh
50 | python phi2/generate.py
51 | ```
52 |
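Once the conversion has run, a quick sanity check is to open the archive and list a few parameter names. This is a minimal sketch, assuming only the `weights/phi-2.npz` path above and that the installed `mlx` can read `.npz` archives:

```python
import mlx.core as mx

# .npz checkpoints load as a dict mapping parameter names to arrays
weights = mx.load("weights/phi-2.npz")
print(f"{len(weights)} tensors")
for name, value in list(weights.items())[:5]:
    print(name, value.shape)
```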
53 | ## Acknowledgements
54 |
55 | Some great resources:
56 |
57 | - [Brian Kitano's LLaMa from Scratch](https://blog.briankitano.com/llama-from-scratch/)
58 | - [PyTorch lab's `gpt-fast`](https://github.com/pytorch-labs/gpt-fast)
59 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jbarrow/mlx-playground/d897be4dd948a5bc33981d01b126b847143f0e15/convert.py
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jbarrow/mlx-playground/d897be4dd948a5bc33981d01b126b847143f0e15/generate.py
--------------------------------------------------------------------------------
/gpt/README.py:
--------------------------------------------------------------------------------
1 | # GPT-2
2 |
3 | https://huggingface.co/gpt2
4 | http://jalammar.github.io/illustrated-gpt2/
5 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
6 |
--------------------------------------------------------------------------------
/gpt/convert.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from transformers import AutoModelForCausalLM
3 | from dataclasses import dataclass
4 | from pprint import pprint
5 |
6 | import mlx.core as mx
7 | import mlx.nn as nn
8 |
9 |
10 | @dataclass
11 | class ModelArgs:
12 |     max_sequence_length: int = 1024
13 |     num_vocab: int = 50257
14 |     model_dim: int = 768
15 |     mlp_dim: int = 3072
16 |     num_heads: int = 12
17 |     num_layers: int = 12
18 |
19 |
20 | class Gpt2(nn.Module):
21 |     def __init__(self, config: ModelArgs):
22 |         super().__init__()
23 |         self.wte = nn.Embedding(config.num_vocab, config.model_dim)
24 |         self.wpe = nn.Embedding(config.max_sequence_length, config.model_dim)
25 |         self.decoder = nn.TransformerDecoder(
26 |             num_layers=config.num_layers,
27 |             dims=config.model_dim,
28 |             num_heads=config.num_heads,
29 |             mlp_dims=config.mlp_dim,
30 |         )
31 |
32 |         self.lm_head = nn.Linear(config.model_dim, config.num_vocab)
33 |
34 |     def __call__(
35 |         self,
36 |         input_ids: mx.array,
37 |         positions: mx.array,
38 |         attention_mask: Optional[mx.array] = None,
39 |     ) -> mx.array:
40 |         text = self.wte(input_ids)
41 |         position = self.wpe(positions)
42 |
43 |         x = text + position
44 |
45 |         if attention_mask is not None:
46 |             # convert 0's to -infs, 1's to 0's, and make it broadcastable
47 |             attention_mask = mx.log(attention_mask)
48 |             attention_mask = mx.expand_dims(attention_mask, (1, 2))
49 |
50 |         y = self.decoder(x, attention_mask)
51 |         return self.lm_head(y)
52 |
53 |
54 | def replace_key(key: str) -> str:
55 |     # stub for now: Hugging Face parameter names are passed through unchanged
56 |     return key
57 |
58 |
59 | if __name__ == "__main__":
60 |     model = AutoModelForCausalLM.from_pretrained("gpt2")
61 |
62 |     new_model = Gpt2(ModelArgs())
63 |
64 |     pprint([(k, v.shape) for k, v in model.state_dict().items()])
65 |
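`replace_key` above is currently a pass-through. As a rough illustration of the kind of renaming the converter will eventually need, the sketch below strips the Hugging Face `transformer.` prefix and re-roots the per-block weights; the `decoder.layers.*` target names are an assumption about how the MLX module tree will be addressed, not the finished GPT-2 mapping:

```python
def replace_key_sketch(key: str) -> str:
    # e.g. "transformer.wte.weight" -> "wte.weight"
    key = key.removeprefix("transformer.")
    # hypothetical re-rooting of per-block weights,
    # e.g. "h.0.attn.c_attn.weight" -> "decoder.layers.0.attn.c_attn.weight"
    if key.startswith("h."):
        key = "decoder.layers." + key.removeprefix("h.")
    return key
```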
--------------------------------------------------------------------------------
/gpt/hf_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM
2 | from pprint import pprint
3 |
4 |
5 | if __name__ == "__main__":
6 |     model = AutoModelForCausalLM.from_pretrained("gpt2")
7 |
8 |     pprint([(k, v.shape) for k, v in model.state_dict().items()])
9 |
--------------------------------------------------------------------------------
/gpt/model.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jbarrow/mlx-playground/d897be4dd948a5bc33981d01b126b847143f0e15/gpt/model.py
--------------------------------------------------------------------------------
/llama/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jbarrow/mlx-playground/d897be4dd948a5bc33981d01b126b847143f0e15/llama/__init__.py
--------------------------------------------------------------------------------
/llama/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | import mlx.nn as nn
5 | import mlx.core as mx
6 |
7 |
8 | @dataclass
9 | class ModelArgs:
10 |     block_size: int = 16
11 |     vocab_size: int = 65
12 |     n_layers: int = 4
13 |     n_heads: int = 8
14 |     dims: int = 256
15 |     intermediate_size: int = 512
16 |     n_local_heads: int = -1
17 |     head_dim: int = 64
18 |     rope_base: float = 10_000
19 |     norm_eps: float = 1e-5
20 |     n_kv_heads: int = 4
21 |
22 |     def __post_init__(self):
23 |         if self.n_local_heads == -1:
24 |             self.n_local_heads = self.n_heads
25 |
26 |         # if self.intermediate_size is None:
27 |         #     hidden_dim = 4 * self.dims
28 |         #     n_hidden = int(2 * hidden_dim / 3)
29 |         #     self.intermediate_size = find_multiple(n_hidden, 256)
30 |
31 |         self.head_dim = self.dims // self.n_heads
32 |
33 |
34 | class FeedForward(nn.Module):
35 |     def __init__(self, config: ModelArgs) -> None:
36 |         super().__init__()
37 |         self.w1 = nn.Linear(config.dims, config.intermediate_size, bias=False)
38 |         self.w2 = nn.Linear(config.intermediate_size, config.dims, bias=False)
39 |         self.w3 = nn.Linear(config.dims, config.intermediate_size, bias=False)
40 |
41 |     def __call__(self, x: mx.array) -> mx.array:
42 |         return self.w2(nn.silu(self.w1(x)) * self.w3(x))
43 |
44 |
45 | class Attention(nn.Module):
46 |     def __init__(self, config: ModelArgs):
47 |         super().__init__()
48 |         self.args = config
49 |
50 |         self.n_heads: int = config.n_heads
51 |         self.n_kv_heads: int = config.n_kv_heads
52 |
53 |         self.repeats = self.n_heads // self.n_kv_heads
54 |
55 |         self.scale = self.args.head_dim**-0.5
56 |
57 |         self.wq = nn.Linear(config.dims, config.n_heads * config.head_dim, bias=False)
58 |         self.wk = nn.Linear(
59 |             config.dims, config.n_kv_heads * config.head_dim, bias=False
60 |         )
61 |         self.wv = nn.Linear(
62 |             config.dims, config.n_kv_heads * config.head_dim, bias=False
63 |         )
64 |         self.wo = nn.Linear(config.n_heads * config.head_dim, config.dims,
bias=False) 65 | self.rope = nn.RoPE(config.head_dim, traditional=True) 66 | 67 | def __call__( 68 | self, 69 | x: mx.array, 70 | mask: Optional[mx.array] = None, 71 | cache: Optional[Tuple[mx.array, mx.array]] = None, 72 | ) -> mx.array: 73 | B, L, D = x.shape 74 | 75 | queries, keys, values = self.wq(x), self.wk(x), self.wv(x) 76 | 77 | # Prepare the queries, keys and values for the attention computation 78 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 79 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 80 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 81 | 82 | def repeat(a): 83 | a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) 84 | return a.reshape([B, self.n_heads, L, -1]) 85 | 86 | keys, values = map(repeat, (keys, values)) 87 | 88 | if cache is not None: 89 | key_cache, value_cache = cache 90 | queries = self.rope(queries, offset=key_cache.shape[2]) 91 | keys = self.rope(keys, offset=key_cache.shape[2]) 92 | keys = mx.concatenate([key_cache, keys], axis=2) 93 | values = mx.concatenate([value_cache, values], axis=2) 94 | else: 95 | queries = self.rope(queries) 96 | keys = self.rope(keys) 97 | 98 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) 99 | if mask is not None: 100 | scores += mask 101 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 102 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 103 | return self.wo(output), (keys, values) 104 | 105 | 106 | class TransformerBlock(nn.Module): 107 | def __init__(self, config: ModelArgs) -> None: 108 | super().__init__() 109 | 110 | self.attention = Attention(config) 111 | self.feed_forward = FeedForward(config) 112 | self.ffn_norm = nn.RMSNorm(config.dims, config.norm_eps) 113 | self.attention_norm = nn.RMSNorm(config.dims, config.norm_eps) 114 | 115 | def __call__(self, x, mask=None): 116 | y = self.attention_norm(x) 117 | y, _ = self.attention(y, mask) 118 | x = x + y 119 | 120 | y = self.ffn_norm(x) 121 | y = self.feed_forward(y) 122 | x = x + y 123 | 124 | return x 125 | 126 | 127 | class Llama(nn.Module): 128 | def __init__(self, config: ModelArgs) -> None: 129 | super().__init__() 130 | 131 | self.embedding = nn.Embedding(config.vocab_size, config.dims) 132 | self.attention = [TransformerBlock(config) for _ in range(config.n_layers)] 133 | self.norm = nn.RMSNorm(config.dims) 134 | self.out_proj = nn.Linear(config.dims, config.vocab_size, bias=False) 135 | 136 | def __call__(self, idx: mx.array): 137 | mask = nn.MultiHeadAttention.create_additive_causal_mask(idx.shape[1]) 138 | mask = mask.astype(self.embedding.weight.dtype) 139 | 140 | x = self.embedding(idx) 141 | for encoding_layer in self.attention: 142 | x = encoding_layer(x, mask) 143 | x = self.norm(x) 144 | 145 | return self.out_proj(x) 146 | 147 | def loss(self, x, y): 148 | logits = self(x) 149 | losses = nn.losses.cross_entropy(logits, y) 150 | mx.simplify(losses) 151 | 152 | return mx.mean(losses) 153 | -------------------------------------------------------------------------------- /llama/optim.py: -------------------------------------------------------------------------------- 1 | from mlx.optimizers import Optimizer, OptimizerState 2 | from typing import List 3 | 4 | import mlx.core as mx 5 | 6 | 7 | class AdamW(Optimizer): 8 | r"""Implementation of the AdamW optimizer [1]. 
9 |
10 |     Bias correction is applied to the first and second moment estimates, and
11 |     the weights are additionally decayed by a decoupled weight_decay (λ) term,
12 |     following [1]:
13 |
14 |     .. math::
15 |
16 |         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
17 |         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
18 |         \hat{m}_{t+1} &= \frac{m_{t+1}}{1 - \beta_1^t} \\
19 |         \hat{v}_{t+1} &= \frac{v_{t+1}}{1 - \beta_2^t} \\
20 |         w_{t+1} &= w_t - \alpha \left( \frac{\hat{m}_{t+1}}{\sqrt{\hat{v}_{t+1}} + \epsilon} + \lambda w_t \right)
21 |
22 |     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
23 |     regularization. ICLR 2019.
24 |     """
25 |
26 |     def __init__(
27 |         self,
28 |         learning_rate: float,
29 |         betas: List[float] = [0.9, 0.999],
30 |         eps: float = 1e-8,
31 |         weight_decay: float = 0.01,
32 |     ):
33 |         super().__init__()
34 |
35 |         self.learning_rate = learning_rate
36 |         self.betas = betas
37 |         self.eps = eps
38 |         self.weight_decay = weight_decay
39 |
40 |     def apply_single(
41 |         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
42 |     ):
43 |         """Performs the AdamW parameter update and stores :math:`v` and
44 |         :math:`m` in the optimizer state."""
45 |         lr = self.learning_rate
46 |         b1, b2 = self.betas
47 |         eps = self.eps
48 |         wd = self.weight_decay
49 |
50 |         m = state.get("m", gradient)
51 |         v = state.get("v", mx.square(gradient))
52 |         t = state.get("t", 1)
53 |         m = b1 * m + (1 - b1) * gradient
54 |         v = b2 * v + (1 - b2) * mx.square(gradient)
55 |         state["m"] = m
56 |         state["v"] = v
57 |         state["t"] = t + 1
58 |
59 |         m_hat = m / (1. - b1 ** t)
60 |         v_hat = v / (1. - b2 ** t)
61 |
62 |         return parameter - lr * (m_hat / (mx.sqrt(v_hat) + eps) + wd * parameter)
63 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "mlx-playground"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Joe Barrow "]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.11"
10 | mlx = "0.0.5"
11 | sentencepiece = "^0.1.99"
12 | tqdm = "^4.66.1"
13 | einops = "^0.7.0"
14 | torch = "2.1.1"
15 |
16 |
17 | [tool.poetry.group.dev.dependencies]
18 | transformers = "^4.35.2"
19 | torch = "^2.1.1"
20 |
21 | [build-system]
22 | requires = ["poetry-core"]
23 | build-backend = "poetry.core.masonry.api"
24 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | A deliberately simple train.py: character-level data (no tokenizer)
3 | and a minimal training loop.
4 | """ 5 | from llama.model import Llama, ModelArgs 6 | from llama.optim import AdamW 7 | from mlx.utils import tree_flatten 8 | from tqdm import tqdm 9 | 10 | import mlx.optimizers as optim 11 | import mlx.core as mx 12 | import mlx.nn as nn 13 | 14 | 15 | lines = open("./data/example.txt", "r").read() 16 | 17 | vocab = sorted(list(set(lines))) 18 | itos = {i: ch for i, ch in enumerate(vocab)} 19 | stoi = {ch: i for i, ch in enumerate(vocab)} 20 | 21 | CONFIG = { 22 | "context_length": 16, 23 | "batch_size": 32, 24 | "steps": 1000, 25 | "learning_rate": 0.001, 26 | } 27 | 28 | 29 | def encode(s): 30 | return [stoi[ch] for ch in s] 31 | 32 | 33 | def decode(l): 34 | return "".join([itos[i] for i in l]) 35 | 36 | 37 | def get_batches( 38 | data: mx.array, split: str, batch_size: int, context_window: int, config=CONFIG 39 | ) -> tuple[mx.array, mx.array]: 40 | train = data[: int(0.8 * len(data))] 41 | val = data[int(0.8 * len(data)) : int(0.9 * len(data))] 42 | test = data[int(0.9 * len(data)) :] 43 | 44 | batch_data = train 45 | if split == "val": 46 | batch_data = val 47 | 48 | if split == "test": 49 | batch_data = test 50 | 51 | ixs = mx.random.randint( 52 | 0, batch_data.shape[0] - context_window - 1, shape=(batch_size,) 53 | ).tolist() 54 | 55 | # create B x C tensors of x and y 56 | x = mx.concatenate( 57 | [mx.expand_dims(batch_data[ix : ix + context_window], 0) for ix in ixs], axis=0 58 | ) 59 | y = mx.concatenate( 60 | [mx.expand_dims(batch_data[ix + 1 : ix + context_window + 1], 0) for ix in ixs], 61 | axis=0, 62 | ) 63 | 64 | return x, y 65 | 66 | 67 | def evaluate_loss(model, config=CONFIG) -> dict[str, mx.array]: 68 | out = {} 69 | 70 | mx.eval(model.parameters()) 71 | for split in ["train", "val"]: 72 | losses = [] 73 | for _ in range(10): 74 | xb, yb = get_batches( 75 | dataset, split, config["batch_size"], config["context_length"], config 76 | ) 77 | loss = model.loss(xb, yb) 78 | losses.append(loss.item()) 79 | out[split] = mx.mean(mx.array(losses)).item() 80 | return out 81 | 82 | 83 | def train(model: nn.Module, optimizer, config=CONFIG): 84 | losses = [] 85 | 86 | loss_and_grad_fn = nn.value_and_grad(model, model.loss) 87 | pbar = tqdm(range(config["steps"])) 88 | 89 | for step in pbar: 90 | xs, ys = get_batches( 91 | dataset, "train", config["batch_size"], config["context_length"] 92 | ) 93 | 94 | loss, grads = loss_and_grad_fn(xs, ys) 95 | model.update(optimizer.apply_gradients(grads, model)) 96 | 97 | mx.simplify(loss, model.parameters()) 98 | # mx.eval(loss, model.parameters()) 99 | losses.append(loss.item()) 100 | 101 | pbar.set_description(f"loss: ({loss.item():.2f})") 102 | 103 | print(evaluate_loss(model)) 104 | 105 | 106 | if __name__ == "__main__": 107 | dataset = mx.array(encode(lines)) 108 | 109 | args = ModelArgs() 110 | model = Llama(args) 111 | 112 | nparams = sum(x.size for k, x in tree_flatten(model.parameters())) 113 | print(f"training a model with {nparams} trainable params") 114 | 115 | optimizer = AdamW( 116 | learning_rate=CONFIG["learning_rate"], weight_decay=0.1, betas=[0.9, 0.95] 117 | ) 118 | # optimizer = optim.Adam(learning_rate=CONFIG["learning_rate"]) 119 | 120 | train(model, optimizer) 121 | --------------------------------------------------------------------------------
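`generate.py` is linked above rather than inlined. For the character-level Llama that `train.py` produces, greedy sampling can be sketched along these lines; this is a minimal illustration that assumes `model`, `encode`, and `decode` from `train.py` are in scope, and it is not the repository's own generation script:

```python
import mlx.core as mx


def generate(model, prompt: str, max_new_tokens: int = 200, context_length: int = 16) -> str:
    # start from the encoded prompt as a (1, T) batch
    idx = mx.array([encode(prompt)])
    for _ in range(max_new_tokens):
        # the model was trained with a fixed context window (CONFIG["context_length"]),
        # so only feed it the most recent tokens
        context = idx[:, -context_length:]
        logits = model(context)  # (1, T, vocab_size)
        # greedy decoding: pick the most likely next character
        next_token = mx.argmax(logits[:, -1, :], axis=-1)
        next_token = mx.expand_dims(next_token, 1).astype(idx.dtype)
        idx = mx.concatenate([idx, next_token], axis=1)
    return decode(idx[0].tolist())


print(generate(model, prompt="\n"))
```

Swapping the `argmax` for a temperature-scaled `mx.random.categorical` draw is the usual next step once greedy output looks sensible.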