├── .gitignore
├── LICENSE
├── README.md
├── bert_fine_tune.py
├── dataset
│   ├── BERT.py
│   ├── __init__.py
│   └── simple_transformers.py
├── llm_inference.py
├── lm_train.py
├── mlx_models
│   ├── BasicLM.py
│   ├── __init__.py
│   ├── configure.sh
│   ├── switch.py
│   ├── tiny_bert.py
│   ├── tiny_llama.py
│   └── whisper.py
├── pytorch_models
│   ├── BasicLM.py
│   ├── __init__.py
│   ├── configure.sh
│   ├── switch.py
│   ├── tiny_bert.py
│   ├── tiny_llama.py
│   └── whisper.py
├── raw_results.txt
├── requirements.txt
├── switch_test.py
├── utils
│   ├── __init__.py
│   └── initializer.py
└── whisper_inference.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
tmp/
.venv/
TinyLlama-1.1B-Chat-v1.0/
tiny_llama/
whisper_tiny_fp32/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Lucas Steuernagel

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MLX-vs-PyTorch

This repository contains benchmarks for comparing two popular artificial
intelligence frameworks that work on Apple Silicon devices: MLX and PyTorch.

The idea behind this simple project is to help you make an informed choice when
starting an AI project on an Apple computer.

We ran five benchmarks several times each to emulate day-to-day usage. For more
information about them, please refer to the section
[Details about each benchmark](#details-about-each-benchmark).

1. Training a transformer language model (`lm_train.py`).
2. Training/fine-tuning BERT (`bert_fine_tune.py`).
3. Inference using OpenAI's Whisper model (`whisper_inference.py`).
4. Language model inference using TinyLLama (`llm_inference.py`).
5. A synthetic benchmark that moves data between CPU and GPU for
   matrix multiplication (`switch_test.py`).


## Results

We executed the tests for ten iterations each, except the language model training
and the BERT training ones, for which we ran only three iterations due to the
extra time they took.

The results in the tables below show the average time for the iterations we ran.
For information about the median of the execution times for each benchmark, refer
to [raw_results.txt](raw_results.txt).

**M1 Pro (10 CPU core, 16 GPU core, 32 GB RAM)**

| Benchmark | PyTorch time (s) | MLX time (s) |
| --- | --- | --- |
| Training a transformer language model | 1806.63 | 1157.00 |
| Training BERT | 751.02 | 718.35 |
| Whisper inference | 31.99 | 8.50 |
| TinyLLama inference | 59.27 | 33.38 |
| CPU/GPU switch | 349.72 | 270.15 |
**M1 Max (10 CPU core, 32 GPU core, 64 GB RAM)**

| Benchmark | PyTorch time (s) | MLX time (s) |
| --- | --- | --- |
| Training a transformer language model | 1106.75 | 752.25 |
| Training BERT | 793.67 | 499.34 |
| Whisper inference | 21.28 | 6.95 |
| TinyLLama inference | 50.98 | 20.61 |
| CPU/GPU switch | 251.71 | 214.57 |
**M3 Max (16 CPU core, 40 GPU core, 48 GB RAM)**

| Benchmark | PyTorch time (s) | MLX time (s) |
| --- | --- | --- |
| Training a transformer language model | 912.52 | 426.00 |
| Training BERT | 550.29 | 408.45 |
| Whisper inference | 17.90 | 4.85 |
| TinyLLama inference | 36.18 | 15.41 |
| CPU/GPU switch | 146.35 | 140.51 |
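Read as ratios, the gap depends strongly on the workload: on the M3 Max, for example, MLX completes transformer language model training roughly 2.1× faster (912.52 s vs. 426.00 s) and Whisper inference roughly 3.7× faster (17.90 s vs. 4.85 s), while the CPU/GPU switch benchmark differs by only about 4% (146.35 s vs. 140.51 s).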
## How to run the benchmarks

First, make sure you have git LFS installed so that you can configure your repository:

```
pip3 install -r requirements.txt
cd pytorch_models
./configure.sh
cd ..
cd mlx_models
./configure.sh
```

Every Python file in the root folder represents a different benchmark. All of them require two arguments: the number
of times to run the benchmark and the framework. If you'd like to run, for example, the TinyLLama inference benchmark
ten times using PyTorch, execute:

```
python3 llm_inference.py --framework pytorch --iter 10
```

When the command finishes, it prints the average and median times of the ten iterations on the terminal.

### Additional settings

The `lm_train.py` benchmark needs the `PYTORCH_MPS_HIGH_WATERMARK_RATIO` environment variable set to zero when used with
PyTorch, e.g. `PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python3 lm_train.py --framework pytorch --iter 3`.

The `whisper_inference` benchmark only works with the latest commit from the PyTorch repository, so build PyTorch from
source to run this benchmark.

## Details about each benchmark

### Training a transformer language model

For this benchmark, we copied the model from MLX's [TransformerLM example](https://github.com/ml-explore/mlx-examples/blob/a7598e9456c6455a07ff4905712c2ea3cfcd52db/transformer_lm/main.py#L15).
For the PyTorch version, we utilized the closest functions available to replicate the model properly in the other framework.
The dataset utilized is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). For more information about
the model size, epochs and other hyperparameters, refer to [lm_train.py](lm_train.py).

### Training/fine-tuning BERT

We utilized the model presented in [Conneau et al.](https://arxiv.org/pdf/1705.02364), using the
[BERT-tiny model](https://huggingface.co/prajjwal1/bert-tiny) for the respective BERT blocks. It classifies pairs of
sentences as having a contradiction, entailment or neutral relation. It was implemented in pure PyTorch and pure
MLX, respectively. We do not initialize it with any pre-trained weights, so the benchmark can be seen as pure training.
The dataset for training was the [NLI dataset](https://sbert.net/datasets/AllNLI.tsv.gz).

The only adaptation in this case was that we used the PyTorch data loader for the MLX model too, as it was compatible with
the tokenizer library. Even though the data loader creates a PyTorch tensor for each input, we can transform it into a
numpy array without extra copies, so this setting did not harm the MLX results.

### Whisper inference

For the PyTorch setting, we used the HuggingFace transformers library to download and execute the tiny Whisper model. For
the MLX benchmark, we used the [MLX examples tools](https://github.com/ml-explore/mlx-examples/tree/main/whisper) to
download tiny Whisper and convert it to the MLX format, using `float32` as the inner data type to match that of PyTorch
(see [mlx_models/configure.sh](mlx_models/configure.sh)). The inference code for MLX leverages the `mlx_whisper`
library.
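To make the PyTorch side of this comparison concrete, a minimal sketch of that HuggingFace setup is shown below. It is only an illustration: the `openai/whisper-tiny` checkpoint name is an assumption, and the dummy LibriSpeech split is borrowed from `mlx_models/whisper.py`; the benchmark's own script may differ in its details.

```python
# Sketch of the HuggingFace-transformers Whisper path described above (not the
# repository's actual pytorch_models/whisper.py). The checkpoint name is an
# assumption; the dataset matches the one used by the MLX Whisper benchmark.
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = torch.device("mps")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds[0]["audio"]

# Convert the waveform to log-mel input features, then generate token ids on the GPU.
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
predicted_ids = model.generate(inputs.input_features.to(device))
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```

The checkpoint loads in `float32` by default, which is what the `float32` MLX conversion mentioned above is meant to match.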
216 | 217 | ### TinyLLama inference 218 | 219 | For PyTorch, we downloaded the `TinyLlama-1.1B-Chat-v1.0` model from the HuggingFace repository 220 | (see [pytorch_models/configure.sh](pytorch_models/configure.sh)), and use the transformers library to load and execute the model. 221 | 222 | For MLX, we convert the model to the MLX format using the [MLX examples tools](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama), 223 | and use `float32` as the data type to match that of PyTorch. We utilize the execution script from the 224 | [MLX examples repository](https://github.com/ml-explore/mlx-examples/blob/main/llms/llama/llama.py) with several adaptations 225 | to account for the proper prompt formatting and execution constraints. 226 | 227 | 228 | ### CPU/GPU switch 229 | 230 | In this benchmark, we perform matrix multiplications in a loop. First, we multiply matrices in the CPU, then we 231 | multiply the resulting matrices in the GPU. Lastly, we reuse the results from the latter as the input 232 | for the next iteration's CPU multiplication. 233 | 234 | The idea behind this benchmark is to assess how effective each framework's mechanisms are to move data between 235 | execution units. 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /bert_fine_tune.py: -------------------------------------------------------------------------------- 1 | from dataset.BERT import load_nli 2 | from utils.initializer import initialize 3 | from pytorch_models.tiny_bert import train as pytorch_train 4 | from mlx_models.tiny_bert import train as mlx_train 5 | import numpy as np 6 | import time 7 | 8 | num_epochs = 5 9 | batch_size = 8 10 | num_labels = 3 11 | lr = 5e-5 12 | bert_config = { 13 | "hidden_size": 128, 14 | "num_attention_heads": 2, 15 | "num_hidden_layers": 2, 16 | "intermediate_size": 512, 17 | "vocab_size": 30522, 18 | } 19 | 20 | if __name__ == "__main__": 21 | args, times = initialize() 22 | 23 | dataset = load_nli() 24 | 25 | for i in range(0, args.iter): 26 | if args.framework == "mlx": 27 | start = time.time() 28 | mlx_train(num_epochs, batch_size, num_labels, bert_config, lr, dataset) 29 | end = time.time() 30 | 31 | elapsed = end - start 32 | times[i] = elapsed 33 | print(f"MLX time: {elapsed}s") 34 | else: 35 | start = time.time() 36 | pytorch_train(num_epochs, batch_size, num_labels, bert_config, lr, dataset) 37 | end = time.time() 38 | 39 | elapsed = end - start 40 | times[i] = elapsed 41 | print(f"Pytorch time: {elapsed}s") 42 | 43 | print(f"\nBERT fine tune test: ran {args.iter} times") 44 | print( 45 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s" 46 | ) 47 | -------------------------------------------------------------------------------- /dataset/BERT.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pathlib 3 | import os 4 | import requests 5 | 6 | 7 | def load_nli(): 8 | current_dir = pathlib.Path(__file__).parent.resolve() 9 | parent_save_dir = os.path.join(current_dir, "tmp") 10 | save_dir = os.path.join(parent_save_dir, "nli") 11 | 12 | if not os.path.exists(parent_save_dir): 13 | os.mkdir(parent_save_dir) 14 | 15 | if not os.path.exists(save_dir): 16 | os.mkdir(save_dir) 17 | 18 | url = "https://sbert.net/datasets/AllNLI.tsv.gz" 19 | file_name = os.path.join(save_dir, "AllNLI.tsv.gz") 20 | if not os.path.isfile(file_name): 21 | r = requests.get(url) 22 | open(file_name, "wb").write(r.content) 23 | 24 | 
test = [[], [], []] 25 | dev = [[], [], []] 26 | train = [[], [], []] 27 | 28 | # TODO: Should this be a torch tensor or an mx array? 29 | mp = {"contradiction": [1, 0, 0], "neutral": [0, 1, 0], "entailment": [0, 0, 1]} 30 | 31 | with gzip.open(file_name, "rb") as file: 32 | train_items = 0 33 | for line in file.readlines(): 34 | line_as_string = line.decode("utf-8") 35 | items = line_as_string.strip("\n\r").split("\t") 36 | # As we do not need more than 50,000 items, we can ignore the rest 37 | if items[0] == "train" and train_items <= 50000: 38 | train[0].append(items[3]) 39 | train[1].append(items[4]) 40 | train[2].append(mp[items[5]]) 41 | train_items += 1 42 | # I am ignoring the test set to save RAM 43 | # elif items[0] == 'test': 44 | # test[0].append(items[3]) 45 | # test[1].append(items[4]) 46 | # test[2].append(mp[items[5]]) 47 | elif items[0] == "dev": 48 | dev[0].append(items[3]) 49 | dev[1].append(items[4]) 50 | dev[2].append(mp[items[5]]) 51 | 52 | samples = {"test": test, "dev": dev, "train": train} 53 | 54 | return samples 55 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/simple_transformers.py: -------------------------------------------------------------------------------- 1 | # This file will download the PTB corpus dataset https://paperswithcode.com/dataset/penn-treebank 2 | import pathlib 3 | import os 4 | from urllib import request 5 | import numpy as np 6 | import itertools 7 | 8 | 9 | # Loading function copied from: https://github.com/ml-explore/mlx-examples/blob/main/transformer_lm/datasets.py 10 | def load_ptb(): 11 | current_dir = pathlib.Path(__file__).parent.resolve() 12 | parent_save_dir = os.path.join(current_dir, "tmp") 13 | save_dir = os.path.join(parent_save_dir, "ptb") 14 | 15 | contents = [ 16 | "ptb.train.txt", 17 | "ptb.valid.txt", 18 | "ptb.test.txt", 19 | ] 20 | 21 | if not os.path.exists(parent_save_dir): 22 | os.mkdir(parent_save_dir) 23 | 24 | if not os.path.exists(save_dir): 25 | os.mkdir(save_dir) 26 | 27 | base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/" 28 | for file_name in contents: 29 | save_path = os.path.join(save_dir, file_name) 30 | if not os.path.exists(save_path): 31 | request.urlretrieve(base_url + file_name, save_path) 32 | 33 | # Loading 34 | with open(os.path.join(save_dir, contents[0]), "r") as f: 35 | vocab = set(t for l in f.readlines() for t in l.strip().split(" ")) 36 | eos = "<eos>" 37 | vocab.add(eos) 38 | vocab = {v: i for i, v in enumerate(vocab)} 39 | 40 | def to_array(dataset): 41 | with open(os.path.join(save_dir, dataset), "r") as f: 42 | lines = (l.strip().split(" ") for l in f.readlines()) 43 | return np.array( 44 | [vocab[w] for line in lines for w in itertools.chain(line, [eos])], 45 | dtype=np.uint32, 46 | ) 47 | 48 | datasets = [to_array(fn) for fn in contents] 49 | return vocab, *datasets 50 | -------------------------------------------------------------------------------- /llm_inference.py: -------------------------------------------------------------------------------- 1 | from mlx_models.tiny_llama import MLXLlama 2 | from pytorch_models.tiny_llama import TorchLLama 3 | from utils.initializer import initialize 4 | import numpy as np 5 | import
time 6 | 7 | out_filename = "_llm_out.txt" 8 | 9 | prompts = [ 10 | "How to get in a good university?", 11 | "What is artificial intelligence?", 12 | "What is a computer?", 13 | "How to land in an awesome job?", 14 | "How to change the world?", 15 | ] 16 | 17 | max_tokens = 1024 18 | 19 | 20 | def run_model(model) -> float: 21 | start = time.time() 22 | for item in prompts: 23 | model.generate_and_save(item) 24 | end = time.time() 25 | 26 | return end - start 27 | 28 | 29 | if __name__ == "__main__": 30 | args, times = initialize() 31 | 32 | for i in range(0, args.iter): 33 | if args.framework == "mlx": 34 | mlx_model = MLXLlama(max_tokens, "mlx" + out_filename) 35 | elapsed = run_model(mlx_model) 36 | mlx_model.finish() 37 | times[i] = elapsed 38 | print(f"MLX time: {elapsed}s") 39 | else: 40 | torch_model = TorchLLama(max_tokens, "pytorch" + out_filename) 41 | elapsed = run_model(torch_model) 42 | torch_model.finish() 43 | 44 | times[i] = elapsed 45 | print(f"Pytorch time: {elapsed}s") 46 | 47 | print(f"\nLLM inference test: ran {args.iter} times") 48 | print( 49 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s" 50 | ) 51 | -------------------------------------------------------------------------------- /lm_train.py: -------------------------------------------------------------------------------- 1 | from dataset.simple_transformers import load_ptb 2 | from mlx_models.BasicLM import train as mlx_train 3 | from pytorch_models.BasicLM import train as pytorch_train 4 | from utils.initializer import initialize 5 | import numpy as np 6 | import time 7 | 8 | context_size = 1024 9 | num_blocks = 12 10 | dim = 1024 11 | num_heads = 16 12 | epochs = 5 13 | learning_rate = 3e-4 14 | weight_decay = 1e-5 15 | lr_warmup = 200 16 | batch_size = 32 17 | 18 | if __name__ == "__main__": 19 | args, times = initialize() 20 | 21 | data = load_ptb() 22 | 23 | for i in range(0, args.iter): 24 | if args.framework == "mlx": 25 | start = time.time() 26 | mlx_train( 27 | num_blocks, 28 | batch_size, 29 | context_size, 30 | dim, 31 | num_heads, 32 | False, 33 | learning_rate, 34 | weight_decay, 35 | epochs, 36 | lr_warmup, 37 | data, 38 | ) 39 | end = time.time() 40 | elapsed = end - start 41 | print(f"MLX time: {elapsed}s") 42 | times[i] = elapsed 43 | else: 44 | start = time.time() 45 | pytorch_train( 46 | num_blocks, 47 | batch_size, 48 | context_size, 49 | dim, 50 | num_heads, 51 | False, 52 | learning_rate, 53 | weight_decay, 54 | epochs, 55 | lr_warmup, 56 | data, 57 | ) 58 | end = time.time() 59 | elapsed = end - start 60 | print(f"Pytorch time: {elapsed}s") 61 | times[i] = elapsed 62 | 63 | print(f"\nLLM train test: ran {args.iter} times") 64 | print( 65 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s" 66 | ) 67 | -------------------------------------------------------------------------------- /mlx_models/BasicLM.py: -------------------------------------------------------------------------------- 1 | # File copied (and slightly modified) from: https://github.com/ml-explore/mlx-examples/blob/main/transformer_lm/main.py 2 | 3 | import mlx.nn as nn 4 | import mlx.core as mx 5 | import mlx.optimizers as optim 6 | import numpy as np 7 | from functools import partial 8 | import math 9 | 10 | 11 | class TransformerLM(nn.Module): 12 | def __init__( 13 | self, 14 | vocab_size: int, 15 | num_layers: int, 16 | dims: int, 17 | num_heads: int, 18 | checkpoint: bool, 19 | ): 20 | super().__init__() 21 | 22 | self.embedding = 
nn.Embedding(vocab_size, dims) 23 | self.pe = nn.SinusoidalPositionalEncoding(dims) 24 | self.transformer = nn.TransformerEncoder( 25 | num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint 26 | ) 27 | self.out_proj = nn.Linear(dims, vocab_size) 28 | 29 | def __call__(self, x): 30 | l_shape = x.shape[1] 31 | mask = nn.MultiHeadAttention.create_additive_causal_mask(l_shape) 32 | x = self.embedding(x) 33 | x = x + self.pe(mx.arange(l_shape)) 34 | x = self.transformer(x, mask) 35 | return self.out_proj(x) 36 | 37 | 38 | def to_samples(context_size, dataset): 39 | tokens = dataset.size 40 | window_size = context_size + 1 41 | samples = tokens - window_size + 1 42 | X = np.lib.stride_tricks.as_strided( 43 | dataset, 44 | shape=(samples, window_size), 45 | strides=(dataset.itemsize, dataset.itemsize), 46 | ) 47 | return X[:, :-1], X[:, 1:] 48 | 49 | 50 | def iterate_batches(batch_size, context_size, dataset): 51 | inputs, targets = to_samples(context_size, dataset) 52 | s = 0 53 | while True: 54 | if s == 0: 55 | # Reset permutation: 56 | perm = np.random.permutation(inputs.shape[0]) 57 | ids = perm[s: s + batch_size] 58 | yield inputs[ids], targets[ids] 59 | s += batch_size 60 | if s >= inputs.shape[0]: 61 | s = 0 62 | 63 | 64 | def train( 65 | num_blocks, 66 | batch_size, 67 | context_size, 68 | dim, 69 | num_heads, 70 | checkpoint, 71 | learning_rate, 72 | weight_decay, 73 | num_iters, 74 | lr_warmup, 75 | ptb_data 76 | ): 77 | mx.set_default_device(mx.gpu) 78 | vocab, train, valid, test = ptb_data 79 | 80 | model = TransformerLM(len(vocab), num_blocks, dim, num_heads, checkpoint) 81 | mx.eval(model.parameters()) 82 | 83 | def loss_fn(model, x, y): 84 | logits = model(x) 85 | losses = nn.losses.cross_entropy(logits, y) 86 | return mx.mean(losses) 87 | 88 | optimizer = optim.AdamW(learning_rate=learning_rate, weight_decay=weight_decay) 89 | 90 | def eval_fn(dataset): 91 | inputs, targets = map(mx.array, to_samples(context_size, dataset)) 92 | loss = 0 93 | for s in range(0, min(targets.shape[0], 1024), batch_size): 94 | bx, by = inputs[s: s + batch_size], targets[s: s + batch_size] 95 | bx, by = map(mx.array, (bx, by)) 96 | losses = loss_fn(model, bx, by) 97 | loss += mx.sum(losses).item() 98 | return loss / len(targets) 99 | 100 | state = [model.state, optimizer.state] 101 | 102 | # @partial(mx.compile, inputs=state, outputs=state) 103 | def step(inputs, targets): 104 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn) 105 | loss, grads = loss_and_grad_fn(model, inputs, targets) 106 | optimizer.update(model, grads) 107 | return loss 108 | 109 | train_iterator = iterate_batches(batch_size, context_size, train) 110 | losses = [] 111 | for it, (inputs, targets) in zip(range(num_iters), train_iterator): 112 | inputs, targets = map(mx.array, (inputs, targets)) 113 | optimizer.learning_rate = min(1, it / lr_warmup) * learning_rate 114 | loss = step(inputs, targets) 115 | mx.eval(state) 116 | losses.append(loss.item()) 117 | train_loss = np.mean(losses) 118 | print(f"Iter {it + 1}: Train loss {train_loss:.3f}, ") 119 | val_loss = eval_fn(valid) 120 | print( 121 | f"Iter {it + 1}: " 122 | f"Val loss {val_loss:.3f}, " 123 | f"Val ppl {math.exp(val_loss):.3f}, " 124 | ) 125 | 126 | test_loss = eval_fn(test) 127 | test_ppl = math.exp(test_loss) 128 | print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") 129 | -------------------------------------------------------------------------------- /mlx_models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/mlx_models/__init__.py -------------------------------------------------------------------------------- /mlx_models/configure.sh: -------------------------------------------------------------------------------- 1 | #! bash 2 | 3 | set -e 4 | 5 | DIRECTORY="$PWD"/tmp 6 | 7 | echo "$DIRECTORY" 8 | if [ ! -d "$DIRECTORY" ]; then 9 | mkdir tmp 10 | fi 11 | 12 | pushd tmp 13 | 14 | 15 | MLX_EX="$PWD"/mlx-examples 16 | if [ ! -d "$MLX_EX" ]; then 17 | git clone https://github.com/ml-explore/mlx-examples.git 18 | fi 19 | 20 | pushd mlx-examples 21 | git checkout 09aaeac72caf0547aeacf2f2cac86195aa999cc9 22 | 23 | python3 llms/llama/convert.py --torch-path ../../../pytorch_models/TinyLlama-1.1B-Chat-v1.0 --model-name tiny_llama --dtype float32 --mlx-path ../../tiny_llama 24 | python3 whisper/convert.py --torch-name-or-path tiny --dtype float32 --mlx-path ../../whisper_tiny_fp32 25 | 26 | popd 27 | popd 28 | rm -rf tmp -------------------------------------------------------------------------------- /mlx_models/switch.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import time 3 | 4 | 5 | def multiply_and_divide(op1: mx.array, op2: mx.array, stream) -> mx.array: 6 | mult = mx.matmul(op1, op2, stream=stream) 7 | div = mx.divide(mult, mult.max(0, stream=stream), stream=stream) 8 | return div 9 | 10 | 11 | def multiply_items(op1: mx.array, op2: mx.array, op3: mx.array, stream): 12 | res_1 = multiply_and_divide(op1, op2, stream) 13 | res_2 = multiply_and_divide(op2, op3, stream) 14 | res_3 = multiply_and_divide(op3, op1, stream) 15 | return res_1, res_2, res_3 16 | 17 | 18 | def run_test(iterations: int, size: int, filename: str) -> float: 19 | a = mx.random.uniform(shape=(size, size), stream=mx.cpu, dtype=mx.float32) 20 | b = mx.random.uniform(shape=(size, size), stream=mx.cpu, dtype=mx.float32) 21 | c = mx.random.uniform(shape=(size, size), stream=mx.cpu, dtype=mx.float32) 22 | 23 | start = time.time() 24 | 25 | for _ in range(0, iterations): 26 | mps_1, mps_2, mps_3 = multiply_items(a, b, c, mx.gpu) 27 | mx.eval(mps_1, mps_2, mps_3) 28 | a, b, c = multiply_items(mps_1, mps_2, mps_3, mx.cpu) 29 | mx.eval(a, b, c) 30 | 31 | end = time.time() 32 | 33 | with open(filename, "w") as file: 34 | print(a, file=file, flush=True) 35 | print(b, file=file, flush=True) 36 | print(c, file=file, flush=True) 37 | 38 | duration = end - start 39 | print(f"MLX time: {duration}") 40 | return duration 41 | -------------------------------------------------------------------------------- /mlx_models/tiny_bert.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import mlx.optimizers 4 | import numpy as np 5 | from transformers import AutoTokenizer 6 | from torch.utils.data import DataLoader 7 | 8 | import mlx.nn as nn 9 | import mlx.core as mx 10 | 11 | from pytorch_models.tiny_bert import Config 12 | 13 | 14 | class LayerNorm(nn.Module): 15 | def __init__(self, hidden_size, variance_epsilon=1e-12): 16 | super(LayerNorm, self).__init__() 17 | self.gamma = mx.zeros(hidden_size) 18 | self.beta = mx.zeros(hidden_size) 19 | self.variance_epsilon = variance_epsilon 20 | 21 | def __call__(self, x): 22 | u = mx.mean(x, -1, keepdims=True) 23 | s = mx.power((x - u), 2).mean(-1, keepdims=True) 24 | x = (x - u) / mx.sqrt(s + 
self.variance_epsilon) 25 | return self.gamma * x + self.beta 26 | 27 | 28 | class PositionWiseMLP(nn.Module): 29 | def __init__(self, hidden_size, intermediate_size): 30 | super(PositionWiseMLP, self).__init__() 31 | self.expansion = nn.Linear(hidden_size, intermediate_size) 32 | self.contraction = nn.Linear(intermediate_size, hidden_size) 33 | self.gelu = nn.GELU() 34 | 35 | def __call__(self, x): 36 | x = self.expansion(x) 37 | x = self.contraction(self.gelu(x)) 38 | return x 39 | 40 | 41 | class Transformer(nn.Module): 42 | def __init__(self, config): 43 | super(Transformer, self).__init__() 44 | 45 | self.hidden_size = config.hidden_size 46 | self.num_attention_heads = config.num_attention_heads 47 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 48 | self.all_head_size = self.num_attention_heads * self.attention_head_size 49 | 50 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 51 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 52 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 53 | 54 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 55 | 56 | self.attn_out = nn.Linear(config.hidden_size, config.hidden_size) 57 | self.ln1 = LayerNorm(config.hidden_size) 58 | 59 | self.mlp = PositionWiseMLP(config.hidden_size, config.intermediate_size) 60 | self.ln2 = LayerNorm(config.hidden_size) 61 | 62 | @staticmethod 63 | def split_heads(tensor, num_heads, attention_head_size): 64 | new_shape = tensor.shape[:-1] + (num_heads, attention_head_size) 65 | tensor = tensor.reshape(new_shape) 66 | return tensor.transpose(0, 2, 1, 3) 67 | 68 | @staticmethod 69 | def merge_heads(tensor, num_heads, attention_head_size): 70 | tensor = tensor.transpose(0, 2, 1, 3) 71 | new_shape = tensor.shape[:-2] + (num_heads * attention_head_size,) 72 | return tensor.reshape(new_shape) 73 | 74 | def attention(self, q, k, v, attention_mask): 75 | mask = attention_mask == 1 76 | mask = mx.expand_dims(mask, 1) 77 | mask = mx.expand_dims(mask, 2) 78 | 79 | weights = mx.matmul(q, k.transpose(0, 1, 3, 2)) 80 | weights = weights / math.sqrt(self.attention_head_size) 81 | weights = mx.where(mask, weights, float(-1e9)) 82 | 83 | logit = nn.softmax(weights, axis=-1) 84 | logit = self.dropout(logit) 85 | 86 | scores = mx.matmul(logit, v) 87 | return scores 88 | 89 | def __call__(self, x, attention_mask): 90 | q, k, v = self.query(x), self.key(x), self.value(x) 91 | 92 | q = self.split_heads(q, self.num_attention_heads, self.attention_head_size) 93 | k = self.split_heads(k, self.num_attention_heads, self.attention_head_size) 94 | v = self.split_heads(v, self.num_attention_heads, self.attention_head_size) 95 | 96 | a = self.attention(q, k, v, attention_mask) 97 | a = self.merge_heads(a, self.num_attention_heads, self.attention_head_size) 98 | a = self.attn_out(a) 99 | a = self.dropout(a) 100 | a = self.ln1(a + x) 101 | 102 | m = self.mlp(a) 103 | m = self.dropout(m) 104 | m = self.ln2(m + a) 105 | 106 | return m 107 | 108 | 109 | class MiniBert(nn.Module): 110 | def __init__(self, config_dict): 111 | super(MiniBert, self).__init__() 112 | self.config = Config.from_dict(config_dict) 113 | 114 | self.token_embedding = nn.Embedding( 115 | self.config.vocab_size, self.config.hidden_size 116 | ) 117 | self.position_embedding = nn.Embedding( 118 | self.config.max_position_embeddings, self.config.hidden_size 119 | ) 120 | self.token_type_embedding = nn.Embedding( 121 | self.config.type_vocab_size, self.config.hidden_size 122 | ) 123 | 124 | 
self.ln = LayerNorm(self.config.hidden_size) 125 | self.dropout = nn.Dropout(self.config.hidden_dropout_prob) 126 | 127 | self.layers = [ 128 | Transformer(self.config) for _ in range(self.config.num_hidden_layers) 129 | ] 130 | 131 | self.pooler = nn.Sequential( 132 | nn.Linear(self.config.hidden_size, self.config.hidden_size), nn.Tanh() 133 | ) 134 | 135 | def __call__(self, input_ids, attention_mask=None, token_type_ids=None): 136 | position_ids = mx.arange( 137 | input_ids.shape[1], 138 | dtype=mx.int32, 139 | ) 140 | position_ids = mx.expand_dims(position_ids, 0) 141 | position_ids = mx.repeat(position_ids, input_ids.shape[0], 0) 142 | 143 | if token_type_ids is None: 144 | token_type_ids = mx.zeros_like(input_ids) 145 | 146 | x = ( 147 | self.token_embedding(input_ids) 148 | + self.position_embedding(position_ids) 149 | + self.token_type_embedding(token_type_ids) 150 | ) 151 | x = self.dropout(self.ln(x)) 152 | 153 | for layer in self.layers: 154 | x = layer(x, attention_mask) 155 | 156 | o = self.pooler(x[:, 0]) 157 | return x, o 158 | 159 | 160 | class BertFineTuneTask(nn.Module): 161 | def __init__(self, num_labels, bert_config): 162 | super(BertFineTuneTask, self).__init__() 163 | self.bert = MiniBert(bert_config) 164 | self.linear1 = nn.Linear(4 * self.bert.config.hidden_size, num_labels) 165 | 166 | def __call__(self, input_ids, attention_mask): 167 | input_ids = input_ids.reshape(input_ids.shape[0] * 2, input_ids.shape[2]) 168 | attention_mask = attention_mask.reshape( 169 | attention_mask.shape[0] * 2, attention_mask.shape[2] 170 | ) 171 | bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask) 172 | 173 | embeds = bert_output[1].reshape( 174 | bert_output[1].shape[0] // 2, 2, bert_output[1].shape[1] 175 | ) 176 | 177 | set_1_embeds = embeds[:, 0] 178 | set_2_embeds = embeds[:, 1] 179 | concat = mx.concatenate( 180 | ( 181 | set_1_embeds, 182 | set_2_embeds, 183 | mx.abs(set_1_embeds - set_2_embeds), 184 | mx.multiply(set_1_embeds, set_2_embeds), 185 | ), 186 | axis=-1, 187 | ) 188 | 189 | lin = self.linear1(concat) 190 | return lin 191 | 192 | 193 | def tokenize_sentence_pair_dataset(dataset, tokenizer, max_length=512): 194 | tokenized_dataset = [] 195 | for i in range(0, len(dataset[0])): 196 | tokenized_dataset.append( 197 | ( 198 | tokenizer( 199 | [dataset[0][i], dataset[1][i]], 200 | return_tensors="np", 201 | padding="max_length", 202 | max_length=max_length, 203 | truncation=True, 204 | ), 205 | np.array(dataset[2][i]), 206 | ) 207 | ) 208 | return tokenized_dataset 209 | 210 | 211 | def train_loop(model, optimizer, train_dataloader, num_epochs, val_dataloader): 212 | mx.eval(model.parameters()) 213 | state = [model.state, optimizer.state] 214 | 215 | def loss_fn(model, x, y): 216 | input_ids, attention_mask = x 217 | embeds = model(input_ids, attention_mask) 218 | losses = nn.losses.cross_entropy(embeds, y, reduction="mean") 219 | return losses 220 | 221 | def step(inputs, targets): 222 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn) 223 | loss, grads = loss_and_grad_fn(model, inputs, targets) 224 | optimizer.update(model, grads) 225 | return loss 226 | 227 | for epoch in range(num_epochs): 228 | print(f"Epoch {epoch + 1} out of {num_epochs}") 229 | epoch_loss = 0 230 | for item in iter(train_dataloader): 231 | # From Pytorch tensor to numpy, there is no copy. 
232 | ids = mx.array(item[0]["input_ids"].numpy()) 233 | mask = mx.array(item[0]["attention_mask"].numpy()) 234 | truth = mx.array(item[1].numpy()) 235 | 236 | loss = step((ids, mask), truth) 237 | mx.eval(state) 238 | epoch_loss += loss.item() 239 | 240 | print(f"epoch loss: {epoch_loss}") 241 | dev_loss = 0 242 | 243 | for item in iter(val_dataloader): 244 | ids = mx.array(item[0]["input_ids"].numpy()) 245 | mask = mx.array(item[0]["attention_mask"].numpy()) 246 | truth = mx.array(item[1].numpy()) 247 | 248 | loss = loss_fn(model, (ids, mask), truth) 249 | dev_loss += loss.item() 250 | print(f"validation loss: {dev_loss}") 251 | print() 252 | 253 | 254 | def train(num_epochs, batch_size, num_labels, bert_config, lr, dataset): 255 | mx.set_default_device(mx.gpu) 256 | model_name = "prajjwal1/bert-tiny" 257 | tokenizer = AutoTokenizer.from_pretrained(model_name) 258 | tokenized_train = tokenize_sentence_pair_dataset( 259 | dataset["train"][:50000], tokenizer, max_length=128 260 | ) 261 | tokenized_val = tokenize_sentence_pair_dataset( 262 | dataset["dev"], tokenizer, max_length=128 263 | ) 264 | 265 | train_dataloader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=False) 266 | val_dataloader = DataLoader(tokenized_val, batch_size=batch_size, shuffle=False) 267 | 268 | bert_model = BertFineTuneTask(num_labels, bert_config) 269 | optimizer = mlx.optimizers.AdamW(learning_rate=lr) 270 | train_loop(bert_model, optimizer, train_dataloader, num_epochs, val_dataloader) 271 | -------------------------------------------------------------------------------- /mlx_models/tiny_llama.py: -------------------------------------------------------------------------------- 1 | # File copied from https://github.com/ml-explore/mlx-examples/blob/main/llms/llama/llama.py 2 | # and modified for the benchmark 3 | import pathlib 4 | from dataclasses import dataclass 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import json 8 | from typing import Optional, Tuple 9 | from pathlib import Path 10 | import glob 11 | from sentencepiece import SentencePieceProcessor 12 | from mlx.utils import tree_unflatten 13 | import os 14 | 15 | 16 | @dataclass 17 | class ModelArgs: 18 | dim: int 19 | n_layers: int 20 | head_dim: int 21 | hidden_dim: int 22 | n_heads: int 23 | n_kv_heads: int 24 | norm_eps: float 25 | vocab_size: int 26 | rope_theta: float 27 | rope_traditional: bool = True 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, args: ModelArgs): 32 | super().__init__() 33 | self.args = args 34 | 35 | self.n_heads: int = args.n_heads 36 | self.n_kv_heads: int = args.n_kv_heads 37 | 38 | self.repeats = self.n_heads // self.n_kv_heads 39 | 40 | self.scale = self.args.head_dim**-0.5 41 | 42 | self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) 43 | self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) 44 | self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) 45 | self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) 46 | self.rope = nn.RoPE( 47 | args.head_dim, traditional=args.rope_traditional, base=args.rope_theta 48 | ) 49 | 50 | def __call__( 51 | self, 52 | x: mx.array, 53 | mask: Optional[mx.array] = None, 54 | cache: Optional[Tuple[mx.array, mx.array]] = None, 55 | ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: 56 | B, L, D = x.shape 57 | 58 | queries, keys, values = self.wq(x), self.wk(x), self.wv(x) 59 | 60 | # Prepare the queries, keys and values for the attention computation 61 | queries = 
queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 62 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 63 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 64 | 65 | def repeat(a): 66 | a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) 67 | return a.reshape([B, self.n_heads, L, -1]) 68 | 69 | keys, values = map(repeat, (keys, values)) 70 | 71 | if cache is not None: 72 | key_cache, value_cache = cache 73 | queries = self.rope(queries, offset=key_cache.shape[2]) 74 | keys = self.rope(keys, offset=key_cache.shape[2]) 75 | keys = mx.concatenate([key_cache, keys], axis=2) 76 | values = mx.concatenate([value_cache, values], axis=2) 77 | else: 78 | queries = self.rope(queries) 79 | keys = self.rope(keys) 80 | 81 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) 82 | if mask is not None: 83 | scores += mask 84 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 85 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) 86 | return self.wo(output), (keys, values) 87 | 88 | 89 | class FeedForward(nn.Module): 90 | def __init__(self, args: ModelArgs): 91 | super().__init__() 92 | 93 | self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) 94 | self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) 95 | self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) 96 | 97 | def __call__(self, x) -> mx.array: 98 | return self.w2(nn.silu(self.w1(x)) * self.w3(x)) 99 | 100 | 101 | class TransformerBlock(nn.Module): 102 | def __init__(self, args: ModelArgs): 103 | super().__init__() 104 | self.n_heads = args.n_heads 105 | self.dim = args.dim 106 | self.attention = Attention(args) 107 | self.feed_forward = FeedForward(args=args) 108 | self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) 109 | self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) 110 | self.args = args 111 | 112 | def __call__( 113 | self, 114 | x: mx.array, 115 | mask: Optional[mx.array] = None, 116 | cache: Optional[Tuple[mx.array, mx.array]] = None, 117 | ) -> mx.array: 118 | r, cache = self.attention(self.attention_norm(x), mask, cache) 119 | h = x + r 120 | r = self.feed_forward(self.ffn_norm(h)) 121 | out = h + r 122 | return out, cache 123 | 124 | 125 | class Llama(nn.Module): 126 | def __init__(self, args: ModelArgs): 127 | super().__init__() 128 | self.args = args 129 | self.vocab_size = args.vocab_size 130 | self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) 131 | self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] 132 | self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps) 133 | self.output = nn.Linear(args.dim, args.vocab_size, bias=False) 134 | 135 | def __call__(self, x): 136 | mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) 137 | mask = mask.astype(self.tok_embeddings.weight.dtype) 138 | 139 | x = self.tok_embeddings(x) 140 | for l in self.layers: 141 | x, _ = l(x, mask) 142 | x = self.norm(x) 143 | return self.output(x) 144 | 145 | def generate(self, x, temp=1.0): 146 | def sample(logits): 147 | if temp == 0: 148 | return mx.argmax(logits, axis=-1) 149 | else: 150 | return mx.random.categorical(logits * (1 / temp)) 151 | 152 | cache = [] 153 | 154 | # Make an additive causal mask. We will need that to process the prompt. 
155 | mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) 156 | mask = mask.astype(self.tok_embeddings.weight.dtype) 157 | 158 | # First we process the prompt x the same was as in __call__ but 159 | # save the caches in cache 160 | x = self.tok_embeddings(x) 161 | for l in self.layers: 162 | x, c = l(x, mask=mask) 163 | # We store the per layer cache in a simple python list 164 | cache.append(c) 165 | x = self.norm(x) 166 | # We only care about the last logits that generate the next token 167 | y = self.output(x[:, -1]) 168 | y = sample(y) 169 | 170 | # y now has size [1] 171 | # Since MLX is lazily evaluated nothing is computed yet. 172 | # Calling y.item() would force the computation to happen at 173 | # this point but we can also choose not to do that and let the 174 | # user choose when to start the computation. 175 | yield y 176 | 177 | # Now we parsed the prompt and generated the first token we 178 | # need to feed it back into the model and loop to generate the 179 | # rest. 180 | while True: 181 | # Unsqueezing the last dimension to add a sequence length 182 | # dimension of 1 183 | x = y[:, None] 184 | 185 | x = self.tok_embeddings(x) 186 | for i in range(len(cache)): 187 | # We are overwriting the arrays in the cache list. When 188 | # the computation will happen, MLX will be discarding the 189 | # old cache the moment it is not needed anymore. 190 | x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) 191 | x = self.norm(x) 192 | y = sample(self.output(x[:, -1])) 193 | 194 | yield y 195 | 196 | 197 | def generate(args, model, tokenizer) -> str: 198 | x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) 199 | tokens = [] 200 | for token in model.generate(x, args.temp): 201 | tokens.append(token) 202 | if ( 203 | token.item() == 29958 204 | and tokenizer.decode([t.item() for t in tokens[-5:]]) == "<|user|>" 205 | ): 206 | tokens = tokens[:-5] 207 | break 208 | 209 | if len(tokens) >= args.max_tokens: 210 | break 211 | 212 | mx.eval(tokens) 213 | s = tokenizer.decode([t.item() for t in tokens]) 214 | return s 215 | 216 | 217 | def sanitize_config(config, weights): 218 | config.pop("model_type", None) 219 | n_heads = config["n_heads"] 220 | if "n_kv_heads" not in config: 221 | config["n_kv_heads"] = n_heads 222 | if "head_dim" not in config: 223 | config["head_dim"] = config["dim"] // n_heads 224 | if "hidden_dim" not in config: 225 | config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] 226 | if config.get("vocab_size", -1) < 0: 227 | config["vocab_size"] = weights["output.weight"].shape[-1] 228 | if "rope_theta" not in config: 229 | config["rope_theta"] = 10000 230 | unused = ["multiple_of", "ffn_dim_multiplier"] 231 | for k in unused: 232 | config.pop(k, None) 233 | return config 234 | 235 | 236 | def load_model(model_path): 237 | model_path = Path(model_path) 238 | 239 | unsharded_weights_path = Path(model_path / "weights.npz") 240 | if unsharded_weights_path.is_file(): 241 | weights = mx.load(str(unsharded_weights_path)) 242 | else: 243 | sharded_weights_glob = str(model_path / "weights.*.npz") 244 | weight_files = glob.glob(sharded_weights_glob) 245 | 246 | if len(weight_files) == 0: 247 | raise FileNotFoundError("No weights found in {}".format(model_path)) 248 | 249 | weights = {} 250 | for wf in weight_files: 251 | weights.update(mx.load(wf).items()) 252 | 253 | with open(model_path / "config.json", "r") as f: 254 | config = sanitize_config(json.loads(f.read()), weights) 255 | quantization = 
config.pop("quantization", None) 256 | model = Llama(ModelArgs(**config)) 257 | if quantization is not None: 258 | nn.quantize(model, **quantization) 259 | model.update(tree_unflatten(list(weights.items()))) 260 | tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) 261 | return model, tokenizer 262 | 263 | 264 | @dataclass 265 | class Args: 266 | prompt: str 267 | max_tokens: int 268 | temp: float 269 | 270 | 271 | class MLXLlama: 272 | def __init__(self, max_tokens: int, file: str): 273 | self.max_tokens = max_tokens 274 | 275 | current_dir = pathlib.Path(__file__).parent.resolve() 276 | save_dir = os.path.join(current_dir, "tiny_llama") 277 | 278 | mx.set_default_device(mx.gpu) 279 | self.model, self.tokenizer = load_model(save_dir) 280 | self.out_file = open(file, "w") 281 | 282 | def generate_and_save(self, prompt: str): 283 | formatted_prompt = ( 284 | f"<|system|>\nYou are a friendly chatbot who always responds wisely\n" 285 | f"<|user|>\n{prompt}\n" 286 | "<|assistant|>" 287 | ) 288 | args = Args(temp=0.0, max_tokens=1024, prompt=formatted_prompt) 289 | result = generate(args, self.model, self.tokenizer) 290 | print(result, file=self.out_file, flush=True) 291 | 292 | def finish(self): 293 | self.out_file.close() 294 | -------------------------------------------------------------------------------- /mlx_models/whisper.py: -------------------------------------------------------------------------------- 1 | import mlx_whisper 2 | from datasets import load_dataset 3 | import mlx.core as mx 4 | import pathlib 5 | import os 6 | 7 | 8 | class MLXWhisper: 9 | def __init__(self, num_examples: int, filename: str): 10 | self.dataset = load_dataset( 11 | "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" 12 | ) 13 | self.file = open(filename, "w") 14 | model = "whisper_tiny_fp32" 15 | current_dir = pathlib.Path(__file__).parent.resolve() 16 | self.model_dir = os.path.join(current_dir, model) 17 | 18 | self.num_examples = 1 19 | # This serves to load the model 20 | self.generate() 21 | self.file.truncate(0) 22 | 23 | self.num_examples = num_examples 24 | 25 | def generate(self): 26 | mx.set_default_device(mx.gpu) 27 | for i in range(0, self.num_examples): 28 | audio_sample = self.dataset[i]["audio"] 29 | waveform = audio_sample["array"] 30 | text = mlx_whisper.transcribe( 31 | waveform, path_or_hf_repo=self.model_dir, fp16=False 32 | )["text"] 33 | print(text, file=self.file, flush=True) 34 | 35 | def finish(self): 36 | self.file.close() 37 | -------------------------------------------------------------------------------- /pytorch_models/BasicLM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional 4 | import math 5 | from dataset.simple_transformers import load_ptb 6 | from mlx_models.BasicLM import to_samples, iterate_batches 7 | import numpy as np 8 | 9 | mps_dev = torch.device('mps') 10 | 11 | 12 | # Adapted from 13 | # https://github.com/ml-explore/mlx/blob/c4a471c99d0c6e6b085ff944ffef149905296a14/python/mlx/nn/layers/positional_encoding.py#L57 14 | class SinusoidalPositionalEncoding(nn.Module): 15 | def __init__( 16 | self, 17 | dims: int, 18 | min_freq: float = 0.0001, 19 | max_freq: float = 1, 20 | scale: Optional[float] = None, 21 | cos_first: bool = False, 22 | full_turns: bool = False, 23 | ): 24 | super().__init__() 25 | one_zero = 1 - torch.arange(0, dims // 2, device=mps_dev) / (dims // 2 - 1) 26 | min_freq = 
torch.log(torch.tensor(min_freq)) 27 | max_freq = torch.log(torch.tensor(max_freq)) 28 | 29 | self._sigmas = torch.exp(one_zero * (max_freq - min_freq) + min_freq) 30 | if full_turns: 31 | self._sigmas = self._sigmas * (2 * math.pi) 32 | 33 | self.scale = scale or (2 / dims) ** 0.5 34 | self.cos_first = cos_first 35 | 36 | def forward(self, x): 37 | y = x[..., None] * self._sigmas 38 | cosy = torch.cos(y) 39 | siny = torch.sin(y) 40 | 41 | if self.cos_first: 42 | y = torch.cat((cosy, siny), -1) 43 | else: 44 | y = torch.cat((siny, cosy), -1) 45 | 46 | if self.scale != 1: 47 | y = y * self.scale 48 | 49 | return y 50 | 51 | 52 | class TransformerLM(nn.Module): 53 | def __init__( 54 | self, 55 | vocab_size: int, 56 | num_layers: int, 57 | dims: int, 58 | num_heads: int, 59 | checkpoint: bool, 60 | ): 61 | super().__init__() 62 | 63 | self.embedding = nn.Embedding(vocab_size, dims) 64 | self.pe = SinusoidalPositionalEncoding(dims) 65 | 66 | encoder_layer = nn.TransformerEncoderLayer( 67 | dims, nhead=num_heads, dim_feedforward=dims * 4, norm_first=True, batch_first=True 68 | ) 69 | self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 70 | self.out_proj = nn.Linear(dims, vocab_size) 71 | 72 | @staticmethod 73 | def create_additive_causal_mask(N: int, dtype: torch.dtype = torch.float32): 74 | # Adapted from 75 | # https://github.com/ml-explore/mlx/blob/c4a471c99d0c6e6b085ff944ffef149905296a14/python/mlx/nn/layers/transformer.py#L102 76 | indices = torch.arange(N, device=mps_dev) 77 | mask = indices[:, None] < indices[None] 78 | mask = mask.to(dtype) * -1e9 79 | return mask 80 | 81 | def forward(self, x): 82 | l_shape = x.shape[1] 83 | mask = self.create_additive_causal_mask(l_shape) 84 | x = self.embedding(x) 85 | x = x + self.pe(torch.arange(l_shape, device=mps_dev)) 86 | x = self.transformer(x, mask) 87 | return self.out_proj(x) 88 | 89 | 90 | def mps_tensor(x): 91 | return torch.tensor(x, device=mps_dev) 92 | 93 | 94 | def train( 95 | num_blocks, 96 | batch_size, 97 | context_size, 98 | dim, 99 | num_heads, 100 | checkpoint, 101 | learning_rate, 102 | weight_decay, 103 | num_iters, 104 | lr_warmup, 105 | ptb_data 106 | ): 107 | vocab, train, valid, test = ptb_data 108 | model = TransformerLM(len(vocab), num_blocks, dim, num_heads, checkpoint).to(mps_dev) 109 | 110 | def loss_fn(model, x, y): 111 | logits = model(x) 112 | losses = nn.functional.cross_entropy(logits.permute(0, 2, 1), y) 113 | return losses 114 | 115 | optimizer = torch.optim.AdamW( 116 | model.parameters(), lr=learning_rate, weight_decay=weight_decay 117 | ) 118 | 119 | def eval_fn(dataset): 120 | inputs, targets = to_samples(context_size, dataset) 121 | inputs = torch.tensor(inputs, dtype=torch.int32, device=mps_dev) 122 | targets = torch.tensor(targets, dtype=torch.int32, device=mps_dev) 123 | loss = 0 124 | model.train(False) 125 | with torch.no_grad(): 126 | for s in range(0, min(targets.shape[0], 1024), batch_size): 127 | bx, by = inputs[s : s + batch_size], targets[s : s + batch_size] 128 | losses = loss_fn(model, bx, by) 129 | loss += torch.sum(losses).item() 130 | return loss / len(targets) 131 | 132 | lr_lambda = lambda epoch: min(1, epoch / lr_warmup) * learning_rate 133 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 134 | 135 | # @torch.compile 136 | def step(inputs, targets, it): 137 | optimizer.zero_grad() 138 | loss = loss_fn(model, inputs, targets) 139 | loss.backward() 140 | scheduler.step(it) 141 | return loss 142 | 143 | train_iterator = 
iterate_batches(batch_size, context_size, train) 144 | losses = [] 145 | for it, (inputs, targets) in zip(range(num_iters), train_iterator): 146 | model.train(True) 147 | inputs = torch.tensor(inputs, dtype=torch.int32, device=mps_dev) 148 | targets = torch.tensor(targets, device=mps_dev, dtype=torch.int32) 149 | loss = step(inputs, targets, it) 150 | losses.append(loss.item()) 151 | train_loss = np.mean(losses) 152 | print(f"Iter {it + 1}: Train loss {train_loss:.3f}, ") 153 | val_loss = eval_fn(valid) 154 | print( 155 | f"Iter {it + 1}: " 156 | f"Val loss {val_loss:.3f}, " 157 | f"Val ppl {math.exp(val_loss):.3f}, " 158 | ) 159 | 160 | test_loss = eval_fn(test) 161 | test_ppl = math.exp(test_loss) 162 | print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") 163 | -------------------------------------------------------------------------------- /pytorch_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/pytorch_models/__init__.py -------------------------------------------------------------------------------- /pytorch_models/configure.sh: -------------------------------------------------------------------------------- 1 | #!bash 2 | 3 | git lfs install 4 | git clone https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0 -------------------------------------------------------------------------------- /pytorch_models/switch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | 5 | def multiply_and_divide(op1: torch.Tensor, op2: torch.Tensor) -> torch.Tensor: 6 | mult = torch.matmul(op1, op2) 7 | div = torch.divide(mult, torch.max(mult, 0).values) 8 | return div 9 | 10 | 11 | def multiply_items(op1: torch.Tensor, op2: torch.Tensor, op3: torch.Tensor): 12 | r_1 = multiply_and_divide(op1, op2) 13 | r_2 = multiply_and_divide(op2, op3) 14 | r_3 = multiply_and_divide(op1, op3) 15 | return r_1, r_2, r_3 16 | 17 | 18 | mps_device = torch.device("mps") 19 | cpu_device = torch.device("cpu") 20 | 21 | 22 | def run_test(iterations: int, size: int, filename: str): 23 | res_1 = torch.rand(size, size, device=cpu_device, dtype=torch.float32) 24 | res_2 = torch.rand(size, size, device=cpu_device, dtype=torch.float32) 25 | res_3 = torch.rand(size, size, device=cpu_device, dtype=torch.float32) 26 | 27 | start = time.time() 28 | 29 | for _ in range(0, iterations): 30 | a = res_1.to(mps_device) 31 | b = res_2.to(mps_device) 32 | c = res_3.to(mps_device) 33 | 34 | mps_1, mps_2, mps_3 = multiply_items(a, b, c) 35 | 36 | cpu_1 = mps_1.to(cpu_device) 37 | cpu_2 = mps_2.to(cpu_device) 38 | cpu_3 = mps_3.to(cpu_device) 39 | 40 | res_1, res_2, res_3 = multiply_items(cpu_1, cpu_2, cpu_3) 41 | 42 | end = time.time() 43 | 44 | with open(filename, "w") as file: 45 | print(res_1, file=file, flush=True) 46 | print(res_2, file=file, flush=True) 47 | print(res_3, file=file, flush=True) 48 | 49 | duration = end - start 50 | print(f"Pytorch time: {duration}") 51 | return duration 52 | -------------------------------------------------------------------------------- /pytorch_models/tiny_bert.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from transformers import AutoTokenizer 7 | from torch.utils.data import DataLoader 8 | 9 | device = torch.device("mps") 10 | 
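# The Config defaults below (hidden_size=768, 12 layers, 12 heads, intermediate_size=3072)
# match BERT-base; the benchmark overrides them via Config.from_dict with the bert-tiny
# sizes defined in bert_fine_tune.py (hidden_size=128, 2 attention heads, 2 hidden layers,
# intermediate_size=512, vocab_size=30522).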
11 | 12 | class Config(object): 13 | def __init__( 14 | self, 15 | vocab_size, 16 | hidden_size=768, 17 | num_hidden_layers=12, 18 | num_attention_heads=12, 19 | intermediate_size=3072, 20 | dropout_prob=0.1, 21 | max_position_embeddings=512, 22 | type_vocab_size=2, 23 | initializer_range=0.02, 24 | ): 25 | self.vocab_size = vocab_size 26 | self.hidden_size = hidden_size 27 | self.num_hidden_layers = num_hidden_layers 28 | self.num_attention_heads = num_attention_heads 29 | self.intermediate_size = intermediate_size 30 | self.hidden_dropout_prob = dropout_prob 31 | self.attention_probs_dropout_prob = dropout_prob 32 | self.max_position_embeddings = max_position_embeddings 33 | self.type_vocab_size = type_vocab_size 34 | self.initializer_range = initializer_range 35 | 36 | @classmethod 37 | def from_dict(cls, dict_object): 38 | config = Config(vocab_size=None) 39 | for key, value in dict_object.items(): 40 | config.__dict__[key] = value 41 | return config 42 | 43 | 44 | class LayerNorm(nn.Module): 45 | def __init__(self, hidden_size, variance_epsilon=1e-12): 46 | super(LayerNorm, self).__init__() 47 | self.gamma = nn.Parameter(torch.ones(hidden_size)) 48 | self.beta = nn.Parameter(torch.zeros(hidden_size)) 49 | self.variance_epsilon = variance_epsilon 50 | 51 | def forward(self, x): 52 | u = x.mean(-1, keepdim=True) 53 | s = (x - u).pow(2).mean(-1, keepdim=True) 54 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 55 | return self.gamma * x + self.beta 56 | 57 | 58 | class PositionWiseMLP(nn.Module): 59 | def __init__(self, hidden_size, intermediate_size): 60 | super(PositionWiseMLP, self).__init__() 61 | self.expansion = nn.Linear(hidden_size, intermediate_size) 62 | self.contraction = nn.Linear(intermediate_size, hidden_size) 63 | 64 | def forward(self, x): 65 | x = self.expansion(x) 66 | x = self.contraction(torch.nn.functional.gelu(x)) 67 | return x 68 | 69 | 70 | class Transformer(nn.Module): 71 | def __init__(self, config): 72 | super(Transformer, self).__init__() 73 | 74 | self.hidden_size = config.hidden_size 75 | self.num_attention_heads = config.num_attention_heads 76 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 77 | self.all_head_size = self.num_attention_heads * self.attention_head_size 78 | 79 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 80 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 81 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 82 | 83 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 84 | 85 | self.attn_out = nn.Linear(config.hidden_size, config.hidden_size) 86 | self.ln1 = LayerNorm(config.hidden_size) 87 | 88 | self.mlp = PositionWiseMLP(config.hidden_size, config.intermediate_size) 89 | self.ln2 = LayerNorm(config.hidden_size) 90 | 91 | @staticmethod 92 | def split_heads(tensor, num_heads, attention_head_size): 93 | new_shape = tensor.size()[:-1] + (num_heads, attention_head_size) 94 | tensor = tensor.view(*new_shape) 95 | return tensor.permute(0, 2, 1, 3) 96 | 97 | @staticmethod 98 | def merge_heads(tensor, num_heads, attention_head_size): 99 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 100 | new_shape = tensor.size()[:-2] + (num_heads * attention_head_size,) 101 | return tensor.view(new_shape) 102 | 103 | def attention(self, q, k, v, attention_mask): 104 | mask = attention_mask == 1 105 | mask = mask.unsqueeze(1).unsqueeze(2) 106 | 107 | weights = torch.matmul(q, k.transpose(-2, -1)) 108 | weights = weights / math.sqrt(self.attention_head_size) 109 | 
weights = torch.where( 110 | mask, weights, torch.tensor(float(-1e9), device=attention_mask.device) 111 | ) 112 | 113 | logit = torch.nn.functional.softmax(weights, dim=-1) 114 | logit = self.dropout(logit) 115 | 116 | scores = torch.matmul(logit, v) 117 | return scores 118 | 119 | def forward(self, x, attention_mask): 120 | q, k, v = self.query(x), self.key(x), self.value(x) 121 | 122 | q = self.split_heads(q, self.num_attention_heads, self.attention_head_size) 123 | k = self.split_heads(k, self.num_attention_heads, self.attention_head_size) 124 | v = self.split_heads(v, self.num_attention_heads, self.attention_head_size) 125 | 126 | a = self.attention(q, k, v, attention_mask) 127 | a = self.merge_heads(a, self.num_attention_heads, self.attention_head_size) 128 | a = self.attn_out(a) 129 | a = self.dropout(a) 130 | a = self.ln1(a + x) 131 | 132 | m = self.mlp(a) 133 | m = self.dropout(m) 134 | m = self.ln2(m + a) 135 | 136 | return m 137 | 138 | 139 | class MiniBert(nn.Module): 140 | def __init__(self, config_dict): 141 | super(MiniBert, self).__init__() 142 | self.config = Config.from_dict(config_dict) 143 | self.embeddings = nn.ModuleDict( 144 | { 145 | "token": nn.Embedding( 146 | self.config.vocab_size, self.config.hidden_size, padding_idx=0 147 | ), 148 | "position": nn.Embedding( 149 | self.config.max_position_embeddings, self.config.hidden_size 150 | ), 151 | "token_type": nn.Embedding( 152 | self.config.type_vocab_size, self.config.hidden_size 153 | ), 154 | } 155 | ) 156 | 157 | self.ln = LayerNorm(self.config.hidden_size) 158 | self.dropout = nn.Dropout(self.config.hidden_dropout_prob) 159 | 160 | self.layers = nn.ModuleList( 161 | [Transformer(self.config) for _ in range(self.config.num_hidden_layers)] 162 | ) 163 | 164 | # This is a pooling layer for Bert's last layer. 
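# It maps the final hidden state of the first ([CLS]) token through a dense
# layer followed by a tanh activation, mirroring the pooler of the original
# BERT implementation; its output is returned alongside the full sequence of
# hidden states in forward() below.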
165 | self.pooler = nn.Sequential( 166 | OrderedDict( 167 | [ 168 | ( 169 | "dense", 170 | nn.Linear(self.config.hidden_size, self.config.hidden_size), 171 | ), 172 | ("activation", nn.Tanh()), 173 | ] 174 | ) 175 | ) 176 | 177 | def forward(self, input_ids, attention_mask=None, token_type_ids=None): 178 | position_ids = torch.arange( 179 | input_ids.size(1), dtype=torch.long, device=input_ids.device 180 | ) 181 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 182 | if token_type_ids is None: 183 | token_type_ids = torch.zeros_like(input_ids) 184 | 185 | x = ( 186 | self.embeddings.token(input_ids) 187 | + self.embeddings.position(position_ids) 188 | + self.embeddings.token_type(token_type_ids) 189 | ) 190 | x = self.dropout(self.ln(x)) 191 | 192 | for layer in self.layers: 193 | x = layer(x, attention_mask) 194 | 195 | o = self.pooler(x[:, 0]) 196 | return x, o 197 | 198 | 199 | # BERT fine-tune task: https://arxiv.org/pdf/1705.02364 200 | class BertFineTuneTask(nn.Module): 201 | def __init__(self, num_labels, bert_config): 202 | super(BertFineTuneTask, self).__init__() 203 | self.bert = MiniBert(bert_config) 204 | self.linear1 = torch.nn.Linear(4 * self.bert.config.hidden_size, num_labels) 205 | self.loss_func = torch.nn.CrossEntropyLoss() 206 | 207 | def forward(self, input_ids, attention_mask, ground_truth): 208 | input_ids = input_ids.view(input_ids.size(0) * 2, input_ids.size(2)) 209 | attention_mask = attention_mask.view( 210 | attention_mask.size(0) * 2, attention_mask.size(2) 211 | ) 212 | bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask) 213 | 214 | embeds = bert_output[1].view( 215 | bert_output[1].size(0) // 2, 2, bert_output[1].size(1) 216 | ) 217 | 218 | set_1_embeds = embeds[:, 0] 219 | set_2_embeds = embeds[:, 1] 220 | concat = torch.cat( 221 | ( 222 | set_1_embeds, 223 | set_2_embeds, 224 | torch.abs(set_1_embeds - set_2_embeds), 225 | torch.mul(set_1_embeds, set_2_embeds), 226 | ), 227 | dim=-1, 228 | ) 229 | lin = self.linear1(concat) 230 | loss = self.loss_func(lin, ground_truth) 231 | return loss 232 | 233 | 234 | def tokenize_sentence_pair_dataset(dataset, tokenizer, max_length=512): 235 | tokenized_dataset = [] 236 | for i in range(0, len(dataset[0])): 237 | tokenized_dataset.append( 238 | ( 239 | tokenizer( 240 | [dataset[0][i], dataset[1][i]], 241 | return_tensors="pt", 242 | padding="max_length", 243 | max_length=max_length, 244 | truncation=True, 245 | ), 246 | torch.tensor(dataset[2][i], dtype=torch.float32), 247 | ) 248 | ) 249 | 250 | return tokenized_dataset 251 | 252 | 253 | def train_loop(model, optimizer, train_dataloader, num_epochs, val_dataloader): 254 | model.to(device) 255 | model.train(True) 256 | 257 | for epoch in range(num_epochs): 258 | print(f"Epoch {epoch + 1} out of {num_epochs}") 259 | epoch_loss = 0 260 | for item in iter(train_dataloader): 261 | optimizer.zero_grad() 262 | 263 | ids = item[0]["input_ids"].to(device) 264 | mask = item[0]["attention_mask"].to(device) 265 | truth = item[1].to(device) 266 | 267 | loss = model(ids, mask, truth) 268 | epoch_loss += loss 269 | loss.backward() 270 | optimizer.step() 271 | 272 | print(f"epoch loss: {epoch_loss}") 273 | dev_loss = 0 274 | with torch.no_grad(): 275 | for item in iter(val_dataloader): 276 | ids = item[0]["input_ids"].to(device) 277 | mask = item[0]["attention_mask"].to(device) 278 | truth = item[1].to(device) 279 | 280 | loss = model(ids, mask, truth) 281 | dev_loss += loss 282 | print(f"validation loss: {dev_loss}") 283 | print() 284 | 285 | 
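# Hedged usage sketch: judging from tokenize_sentence_pair_dataset above,
# `train` below expects `dataset` to hold "train" and "dev" splits that can be
# indexed as three parallel sequences -- first sentences, second sentences,
# and labels. The names and values in the call below are illustrative
# assumptions only, not how bert_fine_tune.py actually invokes it:
#
#     config = {"vocab_size": 30522, "hidden_size": 128, "num_hidden_layers": 2,
#               "num_attention_heads": 2, "intermediate_size": 512}
#     train(num_epochs=3, batch_size=32, num_labels=3,
#           bert_config=config, lr=2e-5, dataset=sentence_pair_splits)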
286 | def train(num_epochs, batch_size, num_labels, bert_config, lr, dataset): 287 | model_name = "prajjwal1/bert-tiny" 288 | tokenizer = AutoTokenizer.from_pretrained(model_name) 289 | tokenized_train = tokenize_sentence_pair_dataset( 290 | dataset["train"][:50000], tokenizer, max_length=128 291 | ) 292 | tokenized_val = tokenize_sentence_pair_dataset( 293 | dataset["dev"], tokenizer, max_length=128 294 | ) 295 | 296 | train_dataloader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=False) 297 | val_dataloader = DataLoader(tokenized_val, batch_size=batch_size, shuffle=False) 298 | 299 | bert_model = BertFineTuneTask(num_labels, bert_config) 300 | optimizer = torch.optim.AdamW(bert_model.parameters(), lr=lr) 301 | train_loop(bert_model, optimizer, train_dataloader, num_epochs, val_dataloader) 302 | -------------------------------------------------------------------------------- /pytorch_models/tiny_llama.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | import transformers 3 | import torch 4 | import pathlib 5 | import os 6 | 7 | 8 | class TorchLLama: 9 | def __init__( 10 | self, 11 | max_tokens: int, 12 | filename: str, 13 | ): 14 | self.max_tokens = max_tokens 15 | self.filename = filename 16 | 17 | model = "TinyLlama-1.1B-Chat-v1.0" 18 | current_dir = pathlib.Path(__file__).parent.resolve() 19 | save_dir = os.path.join(current_dir, model) 20 | self.pipeline = transformers.pipeline( 21 | "text-generation", 22 | model=save_dir, 23 | torch_dtype=torch.float32, 24 | device_map=torch.device("mps"), 25 | ) 26 | self.file = open(filename, "w") 27 | 28 | def generate_and_save(self, prompt: str): 29 | formatted_prompt = ( 30 | f"<|system|>\nYou are a friendly chatbot who always responds wisely\n" 31 | f"<|user|>\n{prompt}\n" 32 | "<|assistant|>" 33 | ) 34 | seqs = self.pipeline( 35 | formatted_prompt, 36 | max_new_tokens=self.max_tokens, 37 | do_sample=False, 38 | ) 39 | print(seqs[0]["generated_text"], file=self.file, flush=True) 40 | 41 | def finish(self): 42 | self.file.close() 43 | -------------------------------------------------------------------------------- /pytorch_models/whisper.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import WhisperProcessor, WhisperForConditionalGeneration 3 | import torch 4 | 5 | device = torch.device("mps") 6 | 7 | 8 | class TorchWhisper: 9 | def __init__(self, num_examples: int, filename: str): 10 | self.num_examples = num_examples 11 | self.dataset = load_dataset( 12 | "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" 13 | ) 14 | self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") 15 | self.model = WhisperForConditionalGeneration.from_pretrained( 16 | "openai/whisper-tiny.en" 17 | ).to(device) 18 | self.file = open(filename, "w") 19 | 20 | def generate(self): 21 | for i in range(0, self.num_examples): 22 | audio_sample = self.dataset[i]["audio"] 23 | waveform = audio_sample["array"] 24 | sampling_rate = audio_sample["sampling_rate"] 25 | input_features = self.processor( 26 | waveform, sampling_rate=sampling_rate, return_tensors="pt" 27 | ).input_features 28 | predicted_ids = self.model.generate(input_features.to(device)) 29 | transcription = self.processor.batch_decode( 30 | predicted_ids, skip_special_tokens=True 31 | ) 32 | print(transcription[0], file=self.file, flush=True) 33 | 34 | def finish(self): 35 | 
self.file.close() 36 | -------------------------------------------------------------------------------- /raw_results.txt: -------------------------------------------------------------------------------- 1 | 2 | ============================= M1 Pro ============================= 3 | 4 | ---- Training a transformers language model ---- 5 | => Values from three executions 6 | 7 | Framework: pytorch 8 | Average: 1806.639595190684s - Median: 1818.6110489368439s 9 | 10 | Framework: mlx 11 | Average: 1157.0066788196564s - Median: 1154.6633532047272s 12 | 13 | 14 | ---- Training BERT ---- 15 | => Values from three executions 16 | 17 | Framework: pytorch 18 | Average: 751.028749704361s - Median: 753.6784768104553s 19 | 20 | Framework: mlx 21 | Average: 718.3554696242014s - Median: 718.211642742157s 22 | 23 | 24 | ---- Whisper inference ---- 25 | => Values from ten executions 26 | 27 | Framework: pytorch 28 | Average: 31.998124384880064s - Median: 31.96485936641693s 29 | 30 | Framework: mlx 31 | Average: 8.509361457824706s - Median: 8.509169936180115s 32 | 33 | 34 | ---- TinyLLama inference ---- 35 | => Values from ten executions 36 | 37 | Framework: pytorch 38 | Average: 59.274635887146s - Median: 55.8025221824646s 39 | 40 | Framework: mlx 41 | Average: 33.38054447174072s - Median: 33.322925329208374s 42 | 43 | 44 | ---- CPU/GPU switch ---- 45 | => Values from ten executions 46 | 47 | Framework: pytorch 48 | Average: 349.7299320459366s - Median: 349.9100536108017s 49 | 50 | Framework: mlx 51 | Average: 270.1572776556015s - Median: 271.8326184749603s 52 | 53 | ================================================================== 54 | 55 | ============================= M1 Max ============================= 56 | 57 | ---- Training a transformers language model ---- 58 | => Values from three executions 59 | 60 | Framework: pytorch 61 | Average: 1106.7549295464829s - Median: 1101.3749282211793s 62 | 63 | Framework: mlx 64 | Average: 752.2592077255249s - Median: 752.2592812234227s 65 | 66 | 67 | ---- Training BERT ---- 68 | => Values from three executions 69 | 70 | Framework: pytorch 71 | Average: 793.6759963925679s - Median: 793.610111951828s 72 | 73 | Framework: mlx 74 | Average: 499.343544960022s - Median: 498.0613958835602s 75 | 76 | 77 | ---- Whisper inference ---- 78 | => Values from ten executions 79 | 80 | Framework: pytorch 81 | Average: 21.27653947197935s - Median: 21.204530119895964s 82 | 83 | Framework: mlx 84 | Average: 6.946336317062378s - Median: 6.937196493148804s 85 | 86 | 87 | ---- TinyLLama inference ---- 88 | => Values from ten executions 89 | 90 | Framework: pytorch 91 | Average: 50.98743650913239s - Median: 47.73779845237732s 92 | 93 | Framework: mlx 94 | Average: 20.61291913986206s - Median: 20.613165140151978s 95 | 96 | 97 | ---- CPU/GPU switch ---- 98 | => Values from ten executions 99 | 100 | Framework: pytorch 101 | Average: 251.71098244190216s - Median: 252.021404504776s 102 | 103 | Framework: mlx 104 | Average: 214.5735635280609s - Median: 214.57463908195496s 105 | 106 | ================================================================== 107 | 108 | ============================= M3 Max ============================= 109 | 110 | ---- Training a transformers language model ---- 111 | => Values from three executions 112 | 113 | Framework: pytorch 114 | Average: 912.5205717086792s - Median: 924.3736660480499s 115 | 116 | Framework: mlx 117 | Average: 426.00355768203735s - Median: 426.1944200992584s 118 | 119 | 120 | ---- Training BERT ---- 121 | => Values from three executions 122 | 123 
| Framework: pytorch 124 | Average: 550.2911033630371s - Median: 544.6988077163696s 125 | 126 | Framework: mlx 127 | Average: 408.45258100827533s - Median: 408.62774896621704s 128 | 129 | 130 | ---- Whisper inference ---- 131 | => Values from ten executions 132 | 133 | Framework: pytorch 134 | Average: 17.909158730506896s - Median: 17.877812027931213s 135 | 136 | Framework: mlx 137 | Average: 4.8507798433303835s - Median: 4.839159846305847s 138 | 139 | 140 | ---- TinyLLama inference ---- 141 | => Values from ten executions 142 | 143 | Framework: pytorch 144 | Average: 36.182030129432675s - Median: 34.26037609577179s 145 | 146 | Framework: mlx 147 | Average: 15.41469841003418s - Median: 15.389396786689758s 148 | 149 | 150 | ---- CPU/GPU switch ---- 151 | => Values from ten executions 152 | 153 | Framework: pytorch 154 | Average: 146.35703275203704s - Median: 146.41792500019073s 155 | 156 | Framework: mlx 157 | Average: 140.5102721452713s - Median: 140.51127195358276s 158 | 159 | ================================================================== 160 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | attrs==23.2.0 5 | audioread==3.0.1 6 | black==24.4.2 7 | certifi==2024.2.2 8 | cffi==1.16.0 9 | charset-normalizer==3.3.2 10 | click==8.1.7 11 | datasets==2.19.1 12 | decorator==5.1.1 13 | dill==0.3.8 14 | filelock==3.13.4 15 | frozenlist==1.4.1 16 | fsspec==2024.3.1 17 | huggingface-hub==0.23.0 18 | idna==3.7 19 | Jinja2==3.1.3 20 | joblib==1.4.2 21 | lazy_loader==0.4 22 | librosa==0.10.2 23 | llvmlite==0.42.0 24 | MarkupSafe==2.1.5 25 | mlx==0.14.1 26 | mlx-whisper==0.1.0 27 | more-itertools==10.2.0 28 | mpmath==1.3.0 29 | msgpack==1.0.8 30 | multidict==6.0.5 31 | multiprocess==0.70.16 32 | mypy-extensions==1.0.0 33 | networkx==3.3 34 | numba==0.59.1 35 | numpy==1.26.4 36 | packaging==24.0 37 | pandas==2.2.2 38 | pathspec==0.12.1 39 | platformdirs==4.2.1 40 | pooch==1.8.1 41 | psutil==5.9.8 42 | pyarrow==16.0.0 43 | pyarrow-hotfix==0.6 44 | pycparser==2.22 45 | python-dateutil==2.9.0.post0 46 | pytz==2024.1 47 | PyYAML==6.0.1 48 | regex==2024.5.10 49 | requests==2.31.0 50 | safetensors==0.4.3 51 | scikit-learn==1.4.2 52 | scipy==1.13.0 53 | sentencepiece==0.2.0 54 | six==1.16.0 55 | soundfile==0.12.1 56 | soxr==0.3.7 57 | sympy==1.12 58 | threadpoolctl==3.5.0 59 | tiktoken==0.7.0 60 | tokenizers==0.19.1 61 | torch==2.3.0 62 | tqdm==4.66.4 63 | transformers==4.40.2 64 | typing_extensions==4.11.0 65 | tzdata==2024.1 66 | urllib3==2.2.1 67 | xxhash==3.4.1 68 | yarl==1.9.4 69 | -------------------------------------------------------------------------------- /switch_test.py: -------------------------------------------------------------------------------- 1 | from utils.initializer import initialize 2 | from mlx_models.switch import run_test as mlx_run 3 | from pytorch_models.switch import run_test as pytorch_run 4 | import numpy as np 5 | 6 | 7 | loop_times = 50 8 | size = 10000 9 | filename = "_switch.txt" 10 | 11 | if __name__ == "__main__": 12 | args, times = initialize() 13 | 14 | for i in range(args.iter): 15 | if args.framework == "mlx": 16 | times[i] = mlx_run(loop_times, size, "mlx" + filename) 17 | else: 18 | times[i] = pytorch_run(loop_times, size, "pytorch" + filename) 19 | 20 | print(f"\nSwitch test: ran {args.iter} times") 21 | print( 22 | f"Framework: {args.framework}\n\tAverage: 
{np.mean(times)}s - Median: {np.median(times)}s" 23 | ) 24 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/utils/__init__.py -------------------------------------------------------------------------------- /utils/initializer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | 5 | def initialize(): 6 | parser = argparse.ArgumentParser(description="Language model training benchmark") 7 | parser.add_argument( 8 | "--framework", 9 | help="Which framework to use: pytorch or mlx", 10 | type=str, 11 | ) 12 | parser.add_argument( 13 | "--iter", help="How many times to run the test", default=1, type=int 14 | ) 15 | 16 | args = parser.parse_args() 17 | times = np.zeros(args.iter) 18 | 19 | return args, times 20 | -------------------------------------------------------------------------------- /whisper_inference.py: -------------------------------------------------------------------------------- 1 | from utils.initializer import initialize 2 | import time 3 | from pytorch_models.whisper import TorchWhisper 4 | from mlx_models.whisper import MLXWhisper 5 | import numpy as np 6 | 7 | 8 | num_examples = 70 9 | out_filename = "_whisper.txt" 10 | 11 | if __name__ == "__main__": 12 | args, times = initialize() 13 | 14 | for i in range(args.iter): 15 | if args.framework == "mlx": 16 | mlx_model = MLXWhisper(num_examples, "mlx" + out_filename) 17 | start = time.time() 18 | mlx_model.generate() 19 | end = time.time() 20 | mlx_model.finish() 21 | elapsed = end - start 22 | print(f"MLX time: {elapsed}s") 23 | times[i] = elapsed 24 | else: 25 | torch_model = TorchWhisper(num_examples, "pytorch" + out_filename) 26 | start = time.time() 27 | torch_model.generate() 28 | end = time.time() 29 | torch_model.finish() 30 | elapsed = end - start 31 | print(f"Pytorch time: {elapsed}s") 32 | times[i] = elapsed 33 | 34 | print(f"\nWhisper inference test: ran {args.iter} times") 35 | print( 36 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s" 37 | ) 38 | --------------------------------------------------------------------------------
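For reference, the benchmark entry points above (for example `whisper_inference.py` and `switch_test.py`) share the `utils.initializer.initialize` helper, which parses a `--framework` flag (`pytorch` or `mlx`) and an `--iter` flag giving the number of repetitions (default 1). Assuming the dependencies pinned in `requirements.txt` are installed and the commands are run from the repository root, an invocation could look like:

    python whisper_inference.py --framework mlx --iter 10
    python switch_test.py --framework pytorch --iter 10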