├── .gitignore
├── LICENSE
├── README.md
├── bert_fine_tune.py
├── dataset
│   ├── BERT.py
│   ├── __init__.py
│   └── simple_transformers.py
├── llm_inference.py
├── lm_train.py
├── mlx_models
│   ├── BasicLM.py
│   ├── __init__.py
│   ├── configure.sh
│   ├── switch.py
│   ├── tiny_bert.py
│   ├── tiny_llama.py
│   └── whisper.py
├── pytorch_models
│   ├── BasicLM.py
│   ├── __init__.py
│   ├── configure.sh
│   ├── switch.py
│   ├── tiny_bert.py
│   ├── tiny_llama.py
│   └── whisper.py
├── raw_results.txt
├── requirements.txt
├── switch_test.py
├── utils
│   ├── __init__.py
│   └── initializer.py
└── whisper_inference.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | tmp/
3 | .venv/
4 | TinyLlama-1.1B-Chat-v1.0/
5 | tiny_llama/
6 | whisper_tiny_fp32/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Lucas Steuernagel
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLX-vs-PyTorch
2 |
3 | This repository contains benchmarks for comparing two popular artificial
4 | intelligence frameworks that work on Apple Silicon devices: MLX and PyTorch.
5 |
6 | The idea behind this simple project is to help you make an informed choice when starting an
7 | AI project on an Apple computer.
8 |
9 | We ran five benchmarks several times each to emulate day-to-day usage. For more information about
10 | them, please refer to the section [Details about each benchmark](#details-about-each-benchmark).
11 |
12 | 1. Training a transformer language model (`lm_train.py`).
13 | 2. Training/fine-tuning BERT (`bert_fine_tune.py`).
14 | 3. Inference using OpenAI's whisper model (`whisper_inference.py`).
15 | 4. Language model inference using TinyLLama (`llm_inference.py`).
16 | 5. A synthetic benchmark that moves data between CPU and GPU for
17 | matrix multiplication (`switch_test.py`).
18 |
19 |
20 | ## Results
21 |
22 | We executed the tests for ten iterations each, except the language model training
23 | and the BERT training ones, for which we ran only three iterations due to the
24 | extra time they took.
25 |
26 | The results in the tables below show the average time over the iterations we ran.
27 | For information about the median of the execution times for each benchmark, refer
28 | to [raw_results.txt](raw_results.txt).
29 |
30 | 
31 | ### M1 Pro (10 CPU core, 16 GPU core, 32 GB RAM)
32 | 
33 | | Benchmark                              | PyTorch time (s) | MLX time (s) |
34 | |----------------------------------------|------------------|--------------|
35 | | Training a transformer language model  | 1806.63          | 1157.00      |
36 | | Training BERT                          | 751.02           | 718.35       |
37 | | Whisper inference                      | 31.99            | 8.50         |
38 | | TinyLLama inference                    | 59.27            | 33.38        |
39 | | CPU/GPU switch                         | 349.72           | 270.15       |
40 | 
41 | ### M1 Max (10 CPU core, 32 GPU core, 64 GB RAM)
42 | 
43 | | Benchmark                              | PyTorch time (s) | MLX time (s) |
44 | |----------------------------------------|------------------|--------------|
45 | | Training a transformer language model  | 1106.75          | 752.25       |
46 | | Training BERT                          | 793.67           | 499.34       |
47 | | Whisper inference                      | 21.28            | 6.95         |
48 | | TinyLLama inference                    | 50.98            | 20.61        |
49 | | CPU/GPU switch                         | 251.71           | 214.57       |
50 | 
51 | ### M3 Max (16 CPU core, 40 GPU core, 48 GB RAM)
52 | 
53 | | Benchmark                              | PyTorch time (s) | MLX time (s) |
54 | |----------------------------------------|------------------|--------------|
55 | | Training a transformer language model  | 912.52           | 426.00       |
56 | | Training BERT                          | 550.29           | 408.45       |
57 | | Whisper inference                      | 17.90            | 4.85         |
58 | | TinyLLama inference                    | 36.18            | 15.41        |
59 | | CPU/GPU switch                         | 146.35           | 140.51       |
60 | 
157 | ## How to run the benchmarks
158 |
159 | First, make sure you have Git LFS installed. Then install the dependencies and configure the models for each framework:
160 |
161 | ```
162 | pip3 install -r requirements.txt
163 | cd pytorch_models
164 | ./configure.sh
165 | cd ..
166 | cd mlx_models
167 | ./configure.sh
168 | ```
169 |
170 | Every Python file in the root folder represents a different benchmark. All of them require two arguments: the number
171 | of times to run the benchmark and the framework. If you'd like to run, for example, the TinyLLama inference benchmark
172 | ten times using PyTorch, execute:
173 |
174 | ```
175 | python3 llm_inference.py --framework pytorch --iter 10
176 | ```
177 |
178 | When the command finishes, it prints the average and median times of the ten iterations to the terminal.
179 |
180 | ### Additional settings
181 |
182 | The `lm_train.py` benchmark needs the `PYTORCH_MPS_HIGH_WATERMARK_RATIO` environment variable set to zero when used with
183 | PyTorch.
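184 | 
185 | For example, you could export the variable inline when launching the benchmark (the iteration count below is illustrative and matches the three runs we used):
186 | 
187 | ```
188 | PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python3 lm_train.py --framework pytorch --iter 3
189 | ```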
184 |
185 | The `whisper_inference` benchmark only works with the latest commit from the PyTorch repository, so build PyTorch from
186 | source to run this benchmark.
187 |
188 | ## Details about each benchmark
189 |
190 | ### Training a transformer language model
191 |
192 | For this benchmark, we copied the model from MLX's [TransformerLM example](https://github.com/ml-explore/mlx-examples/blob/a7598e9456c6455a07ff4905712c2ea3cfcd52db/transformer_lm/main.py#L15).
193 | For the PyTorch version, we used the closest equivalent PyTorch functions available to replicate the model faithfully in another framework.
194 | The dataset utilized is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). For more information about
195 | the model size, epochs and other hyperparameters, refer to [lm_train.py](lm_train.py).
196 |
197 | ### Training/fine-tuning BERT
198 |
199 | We utilized the model presented in [Conneau et al](https://arxiv.org/pdf/1705.02364), using the
200 | [BERT-tiny model](https://huggingface.co/prajjwal1/bert-tiny) for the respective BERT blocks. It classifies pairs of
201 | sentences as having a contradiction, entailment or neutral relation, and was implemented once in pure PyTorch and
202 | once in pure MLX. We do not initialize it with any pre-trained weights, so the benchmark can be seen as pure training.
203 | The dataset for training was the [NLI dataset](https://sbert.net/datasets/AllNLI.tsv.gz).
204 |
205 | The only adaptation in this case was that we used the PyTorch DataLoader for the MLX model too, as it is compatible with
206 | the tokenizer library. Even though the data loader creates a PyTorch tensor for each input, we can convert it to a
207 | NumPy array without extra copies, so this choice did not harm the MLX results.
208 |
209 | ### Whisper inference
210 |
211 | For the PyTorch setting, we used the HuggingFace transformers library to download and execute the tiny whisper model. For
212 | the MLX benchmark, we used the [MLX examples tools](https://github.com/ml-explore/mlx-examples/tree/main/whisper) to
213 | download tiny whisper and convert it to the MLX format, using `float32` as the inner data type to match that of PyTorch
214 | (see [mlx_models/configure.sh](mlx_models/configure.sh)). The inference code for MLX leverages the `mlx_whisper`
215 | library.
216 |
217 | ### TinyLLama inference
218 |
219 | For PyTorch, we download the `TinyLlama-1.1B-Chat-v1.0` model from the HuggingFace repository
220 | (see [pytorch_models/configure.sh](pytorch_models/configure.sh)) and use the transformers library to load and execute the model.
221 |
222 | For MLX, we convert the model to the MLX format using the [MLX examples tools](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama),
223 | and use `float32` as the data type to match that of PyTorch. We utilize the execution script from the
224 | [MLX examples repository](https://github.com/ml-explore/mlx-examples/blob/main/llms/llama/llama.py) with several adaptations
225 | to account for the proper prompt formatting and execution constraints.
226 |
227 |
228 | ### CPU/GPU switch
229 |
230 | In this benchmark, we perform matrix multiplications in a loop. First, we multiply matrices on the CPU, then we
231 | multiply the resulting matrices on the GPU. Lastly, we reuse the results from the latter as the input
232 | for the next iteration's CPU multiplication.
233 |
234 | The idea behind this benchmark is to assess how efficiently each framework's mechanisms move data between
235 | execution units.
236 |
237 |
238 |
239 |
--------------------------------------------------------------------------------
/bert_fine_tune.py:
--------------------------------------------------------------------------------
1 | from dataset.BERT import load_nli
2 | from utils.initializer import initialize
3 | from pytorch_models.tiny_bert import train as pytorch_train
4 | from mlx_models.tiny_bert import train as mlx_train
5 | import numpy as np
6 | import time
7 |
8 | num_epochs = 5
9 | batch_size = 8
10 | num_labels = 3
11 | lr = 5e-5
12 | bert_config = {
13 | "hidden_size": 128,
14 | "num_attention_heads": 2,
15 | "num_hidden_layers": 2,
16 | "intermediate_size": 512,
17 | "vocab_size": 30522,
18 | }
19 |
20 | if __name__ == "__main__":
21 | args, times = initialize()
22 |
23 | dataset = load_nli()
24 |
25 | for i in range(0, args.iter):
26 | if args.framework == "mlx":
27 | start = time.time()
28 | mlx_train(num_epochs, batch_size, num_labels, bert_config, lr, dataset)
29 | end = time.time()
30 |
31 | elapsed = end - start
32 | times[i] = elapsed
33 | print(f"MLX time: {elapsed}s")
34 | else:
35 | start = time.time()
36 | pytorch_train(num_epochs, batch_size, num_labels, bert_config, lr, dataset)
37 | end = time.time()
38 |
39 | elapsed = end - start
40 | times[i] = elapsed
41 | print(f"Pytorch time: {elapsed}s")
42 |
43 | print(f"\nBERT fine tune test: ran {args.iter} times")
44 | print(
45 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s"
46 | )
47 |
--------------------------------------------------------------------------------
/dataset/BERT.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import pathlib
3 | import os
4 | import requests
5 |
6 |
7 | def load_nli():
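8 |     # Download the AllNLI dataset (cached under dataset/tmp/nli) and return a dict with
9 |     # "train", "dev" and "test" splits, each holding [sentences_1, sentences_2, one-hot labels].
10 |     # The test split is left empty to save RAM (see the commented-out branch below).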
8 | current_dir = pathlib.Path(__file__).parent.resolve()
9 | parent_save_dir = os.path.join(current_dir, "tmp")
10 | save_dir = os.path.join(parent_save_dir, "nli")
11 |
12 | if not os.path.exists(parent_save_dir):
13 | os.mkdir(parent_save_dir)
14 |
15 | if not os.path.exists(save_dir):
16 | os.mkdir(save_dir)
17 |
18 | url = "https://sbert.net/datasets/AllNLI.tsv.gz"
19 | file_name = os.path.join(save_dir, "AllNLI.tsv.gz")
20 | if not os.path.isfile(file_name):
21 | r = requests.get(url)
22 | open(file_name, "wb").write(r.content)
23 |
24 | test = [[], [], []]
25 | dev = [[], [], []]
26 | train = [[], [], []]
27 |
28 | # TODO: Should this be a torch tensor or an mx array?
29 | mp = {"contradiction": [1, 0, 0], "neutral": [0, 1, 0], "entailment": [0, 0, 1]}
30 |
31 | with gzip.open(file_name, "rb") as file:
32 | train_items = 0
33 | for line in file.readlines():
34 | line_as_string = line.decode("utf-8")
35 | items = line_as_string.strip("\n\r").split("\t")
36 |             # As we do not need more than 50,000 items, we can ignore the rest
37 | if items[0] == "train" and train_items <= 50000:
38 | train[0].append(items[3])
39 | train[1].append(items[4])
40 | train[2].append(mp[items[5]])
41 | train_items += 1
42 | # I am ignoring the test set to save RAM
43 | # elif items[0] == 'test':
44 | # test[0].append(items[3])
45 | # test[1].append(items[4])
46 | # test[2].append(mp[items[5]])
47 | elif items[0] == "dev":
48 | dev[0].append(items[3])
49 | dev[1].append(items[4])
50 | dev[2].append(mp[items[5]])
51 |
52 | samples = {"test": test, "dev": dev, "train": train}
53 |
54 | return samples
55 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/simple_transformers.py:
--------------------------------------------------------------------------------
1 | # This file will download the PTB corpus dataset https://paperswithcode.com/dataset/penn-treebank
2 | import pathlib
3 | import os
4 | from urllib import request
5 | import numpy as np
6 | import itertools
7 |
8 |
9 | # Loading function copied from: https://github.com/ml-explore/mlx-examples/blob/main/transformer_lm/datasets.py
10 | def load_ptb():
11 | current_dir = pathlib.Path(__file__).parent.resolve()
12 | parent_save_dir = os.path.join(current_dir, "tmp")
13 | save_dir = os.path.join(parent_save_dir, "ptb")
14 |
15 | contents = [
16 | "ptb.train.txt",
17 | "ptb.valid.txt",
18 | "ptb.test.txt",
19 | ]
20 |
21 | if not os.path.exists(parent_save_dir):
22 | os.mkdir(parent_save_dir)
23 |
24 | if not os.path.exists(save_dir):
25 | os.mkdir(save_dir)
26 |
27 | base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/"
28 | for file_name in contents:
29 | save_path = os.path.join(save_dir, file_name)
30 | if not os.path.exists(save_path):
31 | request.urlretrieve(base_url + file_name, save_path)
32 |
33 | # Loading
34 | with open(os.path.join(save_dir, contents[0]), "r") as f:
35 | vocab = set(t for l in f.readlines() for t in l.strip().split(" "))
36 |     eos = "<eos>"
37 | vocab.add(eos)
38 | vocab = {v: i for i, v in enumerate(vocab)}
39 |
40 | def to_array(dataset):
41 | with open(os.path.join(save_dir, dataset), "r") as f:
42 | lines = (l.strip().split(" ") for l in f.readlines())
43 | return np.array(
44 | [vocab[w] for line in lines for w in itertools.chain(line, [eos])],
45 | dtype=np.uint32,
46 | )
47 |
48 | datasets = [to_array(fn) for fn in contents]
49 | return vocab, *datasets
50 |
--------------------------------------------------------------------------------
/llm_inference.py:
--------------------------------------------------------------------------------
1 | from mlx_models.tiny_llama import MLXLlama
2 | from pytorch_models.tiny_llama import TorchLLama
3 | from utils.initializer import initialize
4 | import numpy as np
5 | import time
6 |
7 | out_filename = "_llm_out.txt"
8 |
9 | prompts = [
10 | "How to get in a good university?",
11 | "What is artificial intelligence?",
12 | "What is a computer?",
13 | "How to land in an awesome job?",
14 | "How to change the world?",
15 | ]
16 |
17 | max_tokens = 1024
18 |
19 |
20 | def run_model(model) -> float:
21 | start = time.time()
22 | for item in prompts:
23 | model.generate_and_save(item)
24 | end = time.time()
25 |
26 | return end - start
27 |
28 |
29 | if __name__ == "__main__":
30 | args, times = initialize()
31 |
32 | for i in range(0, args.iter):
33 | if args.framework == "mlx":
34 | mlx_model = MLXLlama(max_tokens, "mlx" + out_filename)
35 | elapsed = run_model(mlx_model)
36 | mlx_model.finish()
37 | times[i] = elapsed
38 | print(f"MLX time: {elapsed}s")
39 | else:
40 | torch_model = TorchLLama(max_tokens, "pytorch" + out_filename)
41 | elapsed = run_model(torch_model)
42 | torch_model.finish()
43 |
44 | times[i] = elapsed
45 | print(f"Pytorch time: {elapsed}s")
46 |
47 | print(f"\nLLM inference test: ran {args.iter} times")
48 | print(
49 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s"
50 | )
51 |
--------------------------------------------------------------------------------
/lm_train.py:
--------------------------------------------------------------------------------
1 | from dataset.simple_transformers import load_ptb
2 | from mlx_models.BasicLM import train as mlx_train
3 | from pytorch_models.BasicLM import train as pytorch_train
4 | from utils.initializer import initialize
5 | import numpy as np
6 | import time
7 |
8 | context_size = 1024
9 | num_blocks = 12
10 | dim = 1024
11 | num_heads = 16
12 | epochs = 5
13 | learning_rate = 3e-4
14 | weight_decay = 1e-5
15 | lr_warmup = 200
16 | batch_size = 32
17 |
18 | if __name__ == "__main__":
19 | args, times = initialize()
20 |
21 | data = load_ptb()
22 |
23 | for i in range(0, args.iter):
24 | if args.framework == "mlx":
25 | start = time.time()
26 | mlx_train(
27 | num_blocks,
28 | batch_size,
29 | context_size,
30 | dim,
31 | num_heads,
32 | False,
33 | learning_rate,
34 | weight_decay,
35 | epochs,
36 | lr_warmup,
37 | data,
38 | )
39 | end = time.time()
40 | elapsed = end - start
41 | print(f"MLX time: {elapsed}s")
42 | times[i] = elapsed
43 | else:
44 | start = time.time()
45 | pytorch_train(
46 | num_blocks,
47 | batch_size,
48 | context_size,
49 | dim,
50 | num_heads,
51 | False,
52 | learning_rate,
53 | weight_decay,
54 | epochs,
55 | lr_warmup,
56 | data,
57 | )
58 | end = time.time()
59 | elapsed = end - start
60 | print(f"Pytorch time: {elapsed}s")
61 | times[i] = elapsed
62 |
63 | print(f"\nLLM train test: ran {args.iter} times")
64 | print(
65 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s"
66 | )
67 |
--------------------------------------------------------------------------------
/mlx_models/BasicLM.py:
--------------------------------------------------------------------------------
1 | # File copied (and slightly modified) from: https://github.com/ml-explore/mlx-examples/blob/main/transformer_lm/main.py
2 |
3 | import mlx.nn as nn
4 | import mlx.core as mx
5 | import mlx.optimizers as optim
6 | import numpy as np
7 | from functools import partial
8 | import math
9 |
10 |
11 | class TransformerLM(nn.Module):
12 | def __init__(
13 | self,
14 | vocab_size: int,
15 | num_layers: int,
16 | dims: int,
17 | num_heads: int,
18 | checkpoint: bool,
19 | ):
20 | super().__init__()
21 |
22 | self.embedding = nn.Embedding(vocab_size, dims)
23 | self.pe = nn.SinusoidalPositionalEncoding(dims)
24 | self.transformer = nn.TransformerEncoder(
25 | num_layers, dims, num_heads, norm_first=True, checkpoint=checkpoint
26 | )
27 | self.out_proj = nn.Linear(dims, vocab_size)
28 |
29 | def __call__(self, x):
30 | l_shape = x.shape[1]
31 | mask = nn.MultiHeadAttention.create_additive_causal_mask(l_shape)
32 | x = self.embedding(x)
33 | x = x + self.pe(mx.arange(l_shape))
34 | x = self.transformer(x, mask)
35 | return self.out_proj(x)
36 |
37 |
38 | def to_samples(context_size, dataset):
39 | tokens = dataset.size
40 | window_size = context_size + 1
41 | samples = tokens - window_size + 1
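42 |     # Build overlapping windows of (context_size + 1) tokens without copying: row i of X is
43 |     # dataset[i : i + window_size], so X[:, :-1] are the inputs and X[:, 1:] the shifted targets.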
42 | X = np.lib.stride_tricks.as_strided(
43 | dataset,
44 | shape=(samples, window_size),
45 | strides=(dataset.itemsize, dataset.itemsize),
46 | )
47 | return X[:, :-1], X[:, 1:]
48 |
49 |
50 | def iterate_batches(batch_size, context_size, dataset):
51 | inputs, targets = to_samples(context_size, dataset)
52 | s = 0
53 | while True:
54 | if s == 0:
55 | # Reset permutation:
56 | perm = np.random.permutation(inputs.shape[0])
57 | ids = perm[s: s + batch_size]
58 | yield inputs[ids], targets[ids]
59 | s += batch_size
60 | if s >= inputs.shape[0]:
61 | s = 0
62 |
63 |
64 | def train(
65 | num_blocks,
66 | batch_size,
67 | context_size,
68 | dim,
69 | num_heads,
70 | checkpoint,
71 | learning_rate,
72 | weight_decay,
73 | num_iters,
74 | lr_warmup,
75 | ptb_data
76 | ):
77 | mx.set_default_device(mx.gpu)
78 | vocab, train, valid, test = ptb_data
79 |
80 | model = TransformerLM(len(vocab), num_blocks, dim, num_heads, checkpoint)
81 | mx.eval(model.parameters())
82 |
83 | def loss_fn(model, x, y):
84 | logits = model(x)
85 | losses = nn.losses.cross_entropy(logits, y)
86 | return mx.mean(losses)
87 |
88 | optimizer = optim.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
89 |
90 | def eval_fn(dataset):
91 | inputs, targets = map(mx.array, to_samples(context_size, dataset))
92 | loss = 0
93 | for s in range(0, min(targets.shape[0], 1024), batch_size):
94 | bx, by = inputs[s: s + batch_size], targets[s: s + batch_size]
95 | bx, by = map(mx.array, (bx, by))
96 | losses = loss_fn(model, bx, by)
97 | loss += mx.sum(losses).item()
98 | return loss / len(targets)
99 |
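100 |     # Capture the model and optimizer state in one list so a single mx.eval(state) call in the
101 |     # training loop forces evaluation of all pending parameter and optimizer updates.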
100 | state = [model.state, optimizer.state]
101 |
102 | # @partial(mx.compile, inputs=state, outputs=state)
103 | def step(inputs, targets):
104 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
105 | loss, grads = loss_and_grad_fn(model, inputs, targets)
106 | optimizer.update(model, grads)
107 | return loss
108 |
109 | train_iterator = iterate_batches(batch_size, context_size, train)
110 | losses = []
111 | for it, (inputs, targets) in zip(range(num_iters), train_iterator):
112 | inputs, targets = map(mx.array, (inputs, targets))
113 | optimizer.learning_rate = min(1, it / lr_warmup) * learning_rate
114 | loss = step(inputs, targets)
115 | mx.eval(state)
116 | losses.append(loss.item())
117 | train_loss = np.mean(losses)
118 | print(f"Iter {it + 1}: Train loss {train_loss:.3f}, ")
119 | val_loss = eval_fn(valid)
120 | print(
121 | f"Iter {it + 1}: "
122 | f"Val loss {val_loss:.3f}, "
123 | f"Val ppl {math.exp(val_loss):.3f}, "
124 | )
125 |
126 | test_loss = eval_fn(test)
127 | test_ppl = math.exp(test_loss)
128 | print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
129 |
--------------------------------------------------------------------------------
/mlx_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/mlx_models/__init__.py
--------------------------------------------------------------------------------
/mlx_models/configure.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | DIRECTORY="$PWD"/tmp
6 |
7 | echo "$DIRECTORY"
8 | if [ ! -d "$DIRECTORY" ]; then
9 | mkdir tmp
10 | fi
11 |
12 | pushd tmp
13 |
14 |
15 | MLX_EX="$PWD"/mlx-examples
16 | if [ ! -d "$MLX_EX" ]; then
17 | git clone https://github.com/ml-explore/mlx-examples.git
18 | fi
19 |
20 | pushd mlx-examples
21 | git checkout 09aaeac72caf0547aeacf2f2cac86195aa999cc9
22 |
23 | python3 llms/llama/convert.py --torch-path ../../../pytorch_models/TinyLlama-1.1B-Chat-v1.0 --model-name tiny_llama --dtype float32 --mlx-path ../../tiny_llama
24 | python3 whisper/convert.py --torch-name-or-path tiny --dtype float32 --mlx-path ../../whisper_tiny_fp32
25 |
26 | popd
27 | popd
28 | rm -rf tmp
--------------------------------------------------------------------------------
/mlx_models/switch.py:
--------------------------------------------------------------------------------
1 | import mlx.core as mx
2 | import time
3 |
4 |
5 | def multiply_and_divide(op1: mx.array, op2: mx.array, stream) -> mx.array:
6 | mult = mx.matmul(op1, op2, stream=stream)
7 | div = mx.divide(mult, mult.max(0, stream=stream), stream=stream)
8 | return div
9 |
10 |
11 | def multiply_items(op1: mx.array, op2: mx.array, op3: mx.array, stream):
12 | res_1 = multiply_and_divide(op1, op2, stream)
13 | res_2 = multiply_and_divide(op2, op3, stream)
14 | res_3 = multiply_and_divide(op3, op1, stream)
15 | return res_1, res_2, res_3
16 |
17 |
18 | def run_test(iterations: int, size: int, filename: str) -> float:
19 | a = mx.random.uniform(shape=(size, size), stream=mx.cpu, dtype=mx.float32)
20 | b = mx.random.uniform(shape=(size, size), stream=mx.cpu, dtype=mx.float32)
21 | c = mx.random.uniform(shape=(size, size), stream=mx.cpu, dtype=mx.float32)
22 |
23 | start = time.time()
24 |
25 | for _ in range(0, iterations):
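26 |         # Multiply on the GPU stream, force evaluation, then feed the results back into a
27 |         # CPU-stream multiplication so data keeps moving between the two execution units.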
26 | mps_1, mps_2, mps_3 = multiply_items(a, b, c, mx.gpu)
27 | mx.eval(mps_1, mps_2, mps_3)
28 | a, b, c = multiply_items(mps_1, mps_2, mps_3, mx.cpu)
29 | mx.eval(a, b, c)
30 |
31 | end = time.time()
32 |
33 | with open(filename, "w") as file:
34 | print(a, file=file, flush=True)
35 | print(b, file=file, flush=True)
36 | print(c, file=file, flush=True)
37 |
38 | duration = end - start
39 | print(f"MLX time: {duration}")
40 | return duration
41 |
--------------------------------------------------------------------------------
/mlx_models/tiny_bert.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import mlx.optimizers
4 | import numpy as np
5 | from transformers import AutoTokenizer
6 | from torch.utils.data import DataLoader
7 |
8 | import mlx.nn as nn
9 | import mlx.core as mx
10 |
11 | from pytorch_models.tiny_bert import Config
12 |
13 |
14 | class LayerNorm(nn.Module):
15 | def __init__(self, hidden_size, variance_epsilon=1e-12):
16 | super(LayerNorm, self).__init__()
17 |         self.gamma = mx.ones(hidden_size)  # scale starts at ones, matching the PyTorch LayerNorm
18 | self.beta = mx.zeros(hidden_size)
19 | self.variance_epsilon = variance_epsilon
20 |
21 | def __call__(self, x):
22 | u = mx.mean(x, -1, keepdims=True)
23 | s = mx.power((x - u), 2).mean(-1, keepdims=True)
24 | x = (x - u) / mx.sqrt(s + self.variance_epsilon)
25 | return self.gamma * x + self.beta
26 |
27 |
28 | class PositionWiseMLP(nn.Module):
29 | def __init__(self, hidden_size, intermediate_size):
30 | super(PositionWiseMLP, self).__init__()
31 | self.expansion = nn.Linear(hidden_size, intermediate_size)
32 | self.contraction = nn.Linear(intermediate_size, hidden_size)
33 | self.gelu = nn.GELU()
34 |
35 | def __call__(self, x):
36 | x = self.expansion(x)
37 | x = self.contraction(self.gelu(x))
38 | return x
39 |
40 |
41 | class Transformer(nn.Module):
42 | def __init__(self, config):
43 | super(Transformer, self).__init__()
44 |
45 | self.hidden_size = config.hidden_size
46 | self.num_attention_heads = config.num_attention_heads
47 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
48 | self.all_head_size = self.num_attention_heads * self.attention_head_size
49 |
50 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
51 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
52 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
53 |
54 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
55 |
56 | self.attn_out = nn.Linear(config.hidden_size, config.hidden_size)
57 | self.ln1 = LayerNorm(config.hidden_size)
58 |
59 | self.mlp = PositionWiseMLP(config.hidden_size, config.intermediate_size)
60 | self.ln2 = LayerNorm(config.hidden_size)
61 |
62 | @staticmethod
63 | def split_heads(tensor, num_heads, attention_head_size):
64 | new_shape = tensor.shape[:-1] + (num_heads, attention_head_size)
65 | tensor = tensor.reshape(new_shape)
66 | return tensor.transpose(0, 2, 1, 3)
67 |
68 | @staticmethod
69 | def merge_heads(tensor, num_heads, attention_head_size):
70 | tensor = tensor.transpose(0, 2, 1, 3)
71 | new_shape = tensor.shape[:-2] + (num_heads * attention_head_size,)
72 | return tensor.reshape(new_shape)
73 |
74 | def attention(self, q, k, v, attention_mask):
75 | mask = attention_mask == 1
76 | mask = mx.expand_dims(mask, 1)
77 | mask = mx.expand_dims(mask, 2)
78 |
79 | weights = mx.matmul(q, k.transpose(0, 1, 3, 2))
80 | weights = weights / math.sqrt(self.attention_head_size)
81 | weights = mx.where(mask, weights, float(-1e9))
82 |
83 | logit = nn.softmax(weights, axis=-1)
84 | logit = self.dropout(logit)
85 |
86 | scores = mx.matmul(logit, v)
87 | return scores
88 |
89 | def __call__(self, x, attention_mask):
90 | q, k, v = self.query(x), self.key(x), self.value(x)
91 |
92 | q = self.split_heads(q, self.num_attention_heads, self.attention_head_size)
93 | k = self.split_heads(k, self.num_attention_heads, self.attention_head_size)
94 | v = self.split_heads(v, self.num_attention_heads, self.attention_head_size)
95 |
96 | a = self.attention(q, k, v, attention_mask)
97 | a = self.merge_heads(a, self.num_attention_heads, self.attention_head_size)
98 | a = self.attn_out(a)
99 | a = self.dropout(a)
100 | a = self.ln1(a + x)
101 |
102 | m = self.mlp(a)
103 | m = self.dropout(m)
104 | m = self.ln2(m + a)
105 |
106 | return m
107 |
108 |
109 | class MiniBert(nn.Module):
110 | def __init__(self, config_dict):
111 | super(MiniBert, self).__init__()
112 | self.config = Config.from_dict(config_dict)
113 |
114 | self.token_embedding = nn.Embedding(
115 | self.config.vocab_size, self.config.hidden_size
116 | )
117 | self.position_embedding = nn.Embedding(
118 | self.config.max_position_embeddings, self.config.hidden_size
119 | )
120 | self.token_type_embedding = nn.Embedding(
121 | self.config.type_vocab_size, self.config.hidden_size
122 | )
123 |
124 | self.ln = LayerNorm(self.config.hidden_size)
125 | self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
126 |
127 | self.layers = [
128 | Transformer(self.config) for _ in range(self.config.num_hidden_layers)
129 | ]
130 |
131 | self.pooler = nn.Sequential(
132 | nn.Linear(self.config.hidden_size, self.config.hidden_size), nn.Tanh()
133 | )
134 |
135 | def __call__(self, input_ids, attention_mask=None, token_type_ids=None):
136 | position_ids = mx.arange(
137 | input_ids.shape[1],
138 | dtype=mx.int32,
139 | )
140 | position_ids = mx.expand_dims(position_ids, 0)
141 | position_ids = mx.repeat(position_ids, input_ids.shape[0], 0)
142 |
143 | if token_type_ids is None:
144 | token_type_ids = mx.zeros_like(input_ids)
145 |
146 | x = (
147 | self.token_embedding(input_ids)
148 | + self.position_embedding(position_ids)
149 | + self.token_type_embedding(token_type_ids)
150 | )
151 | x = self.dropout(self.ln(x))
152 |
153 | for layer in self.layers:
154 | x = layer(x, attention_mask)
155 |
156 | o = self.pooler(x[:, 0])
157 | return x, o
158 |
159 |
160 | class BertFineTuneTask(nn.Module):
161 | def __init__(self, num_labels, bert_config):
162 | super(BertFineTuneTask, self).__init__()
163 | self.bert = MiniBert(bert_config)
164 | self.linear1 = nn.Linear(4 * self.bert.config.hidden_size, num_labels)
165 |
166 | def __call__(self, input_ids, attention_mask):
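167 |         # input_ids arrives as (batch, 2, seq_len): both sentences of a pair are flattened to
168 |         # (2 * batch, seq_len) so they pass through BERT in one call, and the pooled embeddings
169 |         # are regrouped below and combined as [u, v, |u - v|, u * v] (Conneau et al.).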
167 | input_ids = input_ids.reshape(input_ids.shape[0] * 2, input_ids.shape[2])
168 | attention_mask = attention_mask.reshape(
169 | attention_mask.shape[0] * 2, attention_mask.shape[2]
170 | )
171 | bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
172 |
173 | embeds = bert_output[1].reshape(
174 | bert_output[1].shape[0] // 2, 2, bert_output[1].shape[1]
175 | )
176 |
177 | set_1_embeds = embeds[:, 0]
178 | set_2_embeds = embeds[:, 1]
179 | concat = mx.concatenate(
180 | (
181 | set_1_embeds,
182 | set_2_embeds,
183 | mx.abs(set_1_embeds - set_2_embeds),
184 | mx.multiply(set_1_embeds, set_2_embeds),
185 | ),
186 | axis=-1,
187 | )
188 |
189 | lin = self.linear1(concat)
190 | return lin
191 |
192 |
193 | def tokenize_sentence_pair_dataset(dataset, tokenizer, max_length=512):
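194 |     # Tokenize each (sentence_1, sentence_2) pair to a fixed length and pair it with its
195 |     # one-hot label, producing a list the PyTorch DataLoader can batch directly.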
194 | tokenized_dataset = []
195 | for i in range(0, len(dataset[0])):
196 | tokenized_dataset.append(
197 | (
198 | tokenizer(
199 | [dataset[0][i], dataset[1][i]],
200 | return_tensors="np",
201 | padding="max_length",
202 | max_length=max_length,
203 | truncation=True,
204 | ),
205 | np.array(dataset[2][i]),
206 | )
207 | )
208 | return tokenized_dataset
209 |
210 |
211 | def train_loop(model, optimizer, train_dataloader, num_epochs, val_dataloader):
212 | mx.eval(model.parameters())
213 | state = [model.state, optimizer.state]
214 |
215 | def loss_fn(model, x, y):
216 | input_ids, attention_mask = x
217 | embeds = model(input_ids, attention_mask)
218 | losses = nn.losses.cross_entropy(embeds, y, reduction="mean")
219 | return losses
220 |
221 | def step(inputs, targets):
222 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
223 | loss, grads = loss_and_grad_fn(model, inputs, targets)
224 | optimizer.update(model, grads)
225 | return loss
226 |
227 | for epoch in range(num_epochs):
228 | print(f"Epoch {epoch + 1} out of {num_epochs}")
229 | epoch_loss = 0
230 | for item in iter(train_dataloader):
231 | # From Pytorch tensor to numpy, there is no copy.
232 | ids = mx.array(item[0]["input_ids"].numpy())
233 | mask = mx.array(item[0]["attention_mask"].numpy())
234 | truth = mx.array(item[1].numpy())
235 |
236 | loss = step((ids, mask), truth)
237 | mx.eval(state)
238 | epoch_loss += loss.item()
239 |
240 | print(f"epoch loss: {epoch_loss}")
241 | dev_loss = 0
242 |
243 | for item in iter(val_dataloader):
244 | ids = mx.array(item[0]["input_ids"].numpy())
245 | mask = mx.array(item[0]["attention_mask"].numpy())
246 | truth = mx.array(item[1].numpy())
247 |
248 | loss = loss_fn(model, (ids, mask), truth)
249 | dev_loss += loss.item()
250 | print(f"validation loss: {dev_loss}")
251 | print()
252 |
253 |
254 | def train(num_epochs, batch_size, num_labels, bert_config, lr, dataset):
255 | mx.set_default_device(mx.gpu)
256 | model_name = "prajjwal1/bert-tiny"
257 | tokenizer = AutoTokenizer.from_pretrained(model_name)
258 | tokenized_train = tokenize_sentence_pair_dataset(
259 | dataset["train"][:50000], tokenizer, max_length=128
260 | )
261 | tokenized_val = tokenize_sentence_pair_dataset(
262 | dataset["dev"], tokenizer, max_length=128
263 | )
264 |
265 | train_dataloader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=False)
266 | val_dataloader = DataLoader(tokenized_val, batch_size=batch_size, shuffle=False)
267 |
268 | bert_model = BertFineTuneTask(num_labels, bert_config)
269 | optimizer = mlx.optimizers.AdamW(learning_rate=lr)
270 | train_loop(bert_model, optimizer, train_dataloader, num_epochs, val_dataloader)
271 |
--------------------------------------------------------------------------------
/mlx_models/tiny_llama.py:
--------------------------------------------------------------------------------
1 | # File copied from https://github.com/ml-explore/mlx-examples/blob/main/llms/llama/llama.py
2 | # and modified for the benchmark
3 | import pathlib
4 | from dataclasses import dataclass
5 | import mlx.core as mx
6 | import mlx.nn as nn
7 | import json
8 | from typing import Optional, Tuple
9 | from pathlib import Path
10 | import glob
11 | from sentencepiece import SentencePieceProcessor
12 | from mlx.utils import tree_unflatten
13 | import os
14 |
15 |
16 | @dataclass
17 | class ModelArgs:
18 | dim: int
19 | n_layers: int
20 | head_dim: int
21 | hidden_dim: int
22 | n_heads: int
23 | n_kv_heads: int
24 | norm_eps: float
25 | vocab_size: int
26 | rope_theta: float
27 | rope_traditional: bool = True
28 |
29 |
30 | class Attention(nn.Module):
31 | def __init__(self, args: ModelArgs):
32 | super().__init__()
33 | self.args = args
34 |
35 | self.n_heads: int = args.n_heads
36 | self.n_kv_heads: int = args.n_kv_heads
37 |
38 | self.repeats = self.n_heads // self.n_kv_heads
39 |
40 | self.scale = self.args.head_dim**-0.5
41 |
42 | self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
43 | self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
44 | self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
45 | self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
46 | self.rope = nn.RoPE(
47 | args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
48 | )
49 |
50 | def __call__(
51 | self,
52 | x: mx.array,
53 | mask: Optional[mx.array] = None,
54 | cache: Optional[Tuple[mx.array, mx.array]] = None,
55 | ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
56 | B, L, D = x.shape
57 |
58 | queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
59 |
60 | # Prepare the queries, keys and values for the attention computation
61 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
62 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
63 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
64 |
65 | def repeat(a):
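66 |             # Repeat the key/value heads so grouped-query attention lines up with the number
67 |             # of query heads (n_heads // n_kv_heads copies of each key/value head).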
66 | a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
67 | return a.reshape([B, self.n_heads, L, -1])
68 |
69 | keys, values = map(repeat, (keys, values))
70 |
71 | if cache is not None:
72 | key_cache, value_cache = cache
73 | queries = self.rope(queries, offset=key_cache.shape[2])
74 | keys = self.rope(keys, offset=key_cache.shape[2])
75 | keys = mx.concatenate([key_cache, keys], axis=2)
76 | values = mx.concatenate([value_cache, values], axis=2)
77 | else:
78 | queries = self.rope(queries)
79 | keys = self.rope(keys)
80 |
81 | scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
82 | if mask is not None:
83 | scores += mask
84 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
85 | output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
86 | return self.wo(output), (keys, values)
87 |
88 |
89 | class FeedForward(nn.Module):
90 | def __init__(self, args: ModelArgs):
91 | super().__init__()
92 |
93 | self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
94 | self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
95 | self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
96 |
97 | def __call__(self, x) -> mx.array:
98 | return self.w2(nn.silu(self.w1(x)) * self.w3(x))
99 |
100 |
101 | class TransformerBlock(nn.Module):
102 | def __init__(self, args: ModelArgs):
103 | super().__init__()
104 | self.n_heads = args.n_heads
105 | self.dim = args.dim
106 | self.attention = Attention(args)
107 | self.feed_forward = FeedForward(args=args)
108 | self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
109 | self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
110 | self.args = args
111 |
112 | def __call__(
113 | self,
114 | x: mx.array,
115 | mask: Optional[mx.array] = None,
116 | cache: Optional[Tuple[mx.array, mx.array]] = None,
117 |     ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
118 | r, cache = self.attention(self.attention_norm(x), mask, cache)
119 | h = x + r
120 | r = self.feed_forward(self.ffn_norm(h))
121 | out = h + r
122 | return out, cache
123 |
124 |
125 | class Llama(nn.Module):
126 | def __init__(self, args: ModelArgs):
127 | super().__init__()
128 | self.args = args
129 | self.vocab_size = args.vocab_size
130 | self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
131 | self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
132 | self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
133 | self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
134 |
135 | def __call__(self, x):
136 | mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
137 | mask = mask.astype(self.tok_embeddings.weight.dtype)
138 |
139 | x = self.tok_embeddings(x)
140 | for l in self.layers:
141 | x, _ = l(x, mask)
142 | x = self.norm(x)
143 | return self.output(x)
144 |
145 | def generate(self, x, temp=1.0):
146 | def sample(logits):
147 | if temp == 0:
148 | return mx.argmax(logits, axis=-1)
149 | else:
150 | return mx.random.categorical(logits * (1 / temp))
151 |
152 | cache = []
153 |
154 | # Make an additive causal mask. We will need that to process the prompt.
155 | mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
156 | mask = mask.astype(self.tok_embeddings.weight.dtype)
157 |
158 |         # First we process the prompt x the same way as in __call__ but
159 | # save the caches in cache
160 | x = self.tok_embeddings(x)
161 | for l in self.layers:
162 | x, c = l(x, mask=mask)
163 | # We store the per layer cache in a simple python list
164 | cache.append(c)
165 | x = self.norm(x)
166 | # We only care about the last logits that generate the next token
167 | y = self.output(x[:, -1])
168 | y = sample(y)
169 |
170 | # y now has size [1]
171 | # Since MLX is lazily evaluated nothing is computed yet.
172 | # Calling y.item() would force the computation to happen at
173 | # this point but we can also choose not to do that and let the
174 | # user choose when to start the computation.
175 | yield y
176 |
177 |         # Now that we have parsed the prompt and generated the first token, we
178 |         # need to feed it back into the model and loop to generate the
179 |         # rest.
180 | while True:
181 | # Unsqueezing the last dimension to add a sequence length
182 | # dimension of 1
183 | x = y[:, None]
184 |
185 | x = self.tok_embeddings(x)
186 | for i in range(len(cache)):
187 |                 # We are overwriting the arrays in the cache list. When
188 |                 # the computation happens, MLX will discard the
189 |                 # old cache the moment it is no longer needed.
190 | x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
191 | x = self.norm(x)
192 | y = sample(self.output(x[:, -1]))
193 |
194 | yield y
195 |
196 |
197 | def generate(args, model, tokenizer) -> str:
198 | x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
199 | tokens = []
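200 |     # Stop once the model starts a new "<|user|>" turn (token id 29958 is the final piece of
201 |     # that marker, checked against the last five decoded tokens) or max_tokens is reached.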
200 | for token in model.generate(x, args.temp):
201 | tokens.append(token)
202 | if (
203 | token.item() == 29958
204 | and tokenizer.decode([t.item() for t in tokens[-5:]]) == "<|user|>"
205 | ):
206 | tokens = tokens[:-5]
207 | break
208 |
209 | if len(tokens) >= args.max_tokens:
210 | break
211 |
212 | mx.eval(tokens)
213 | s = tokenizer.decode([t.item() for t in tokens])
214 | return s
215 |
216 |
217 | def sanitize_config(config, weights):
218 | config.pop("model_type", None)
219 | n_heads = config["n_heads"]
220 | if "n_kv_heads" not in config:
221 | config["n_kv_heads"] = n_heads
222 | if "head_dim" not in config:
223 | config["head_dim"] = config["dim"] // n_heads
224 | if "hidden_dim" not in config:
225 | config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
226 | if config.get("vocab_size", -1) < 0:
227 | config["vocab_size"] = weights["output.weight"].shape[-1]
228 | if "rope_theta" not in config:
229 | config["rope_theta"] = 10000
230 | unused = ["multiple_of", "ffn_dim_multiplier"]
231 | for k in unused:
232 | config.pop(k, None)
233 | return config
234 |
235 |
236 | def load_model(model_path):
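237 |     # Load the converted MLX weights (a single weights.npz or sharded weights.*.npz), rebuild
238 |     # the Llama module from config.json, and return it with its SentencePiece tokenizer.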
237 | model_path = Path(model_path)
238 |
239 | unsharded_weights_path = Path(model_path / "weights.npz")
240 | if unsharded_weights_path.is_file():
241 | weights = mx.load(str(unsharded_weights_path))
242 | else:
243 | sharded_weights_glob = str(model_path / "weights.*.npz")
244 | weight_files = glob.glob(sharded_weights_glob)
245 |
246 | if len(weight_files) == 0:
247 | raise FileNotFoundError("No weights found in {}".format(model_path))
248 |
249 | weights = {}
250 | for wf in weight_files:
251 | weights.update(mx.load(wf).items())
252 |
253 | with open(model_path / "config.json", "r") as f:
254 | config = sanitize_config(json.loads(f.read()), weights)
255 | quantization = config.pop("quantization", None)
256 | model = Llama(ModelArgs(**config))
257 | if quantization is not None:
258 | nn.quantize(model, **quantization)
259 | model.update(tree_unflatten(list(weights.items())))
260 | tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
261 | return model, tokenizer
262 |
263 |
264 | @dataclass
265 | class Args:
266 | prompt: str
267 | max_tokens: int
268 | temp: float
269 |
270 |
271 | class MLXLlama:
272 | def __init__(self, max_tokens: int, file: str):
273 | self.max_tokens = max_tokens
274 |
275 | current_dir = pathlib.Path(__file__).parent.resolve()
276 | save_dir = os.path.join(current_dir, "tiny_llama")
277 |
278 | mx.set_default_device(mx.gpu)
279 | self.model, self.tokenizer = load_model(save_dir)
280 | self.out_file = open(file, "w")
281 |
282 | def generate_and_save(self, prompt: str):
283 | formatted_prompt = (
284 | f"<|system|>\nYou are a friendly chatbot who always responds wisely\n"
285 | f"<|user|>\n{prompt}\n"
286 | "<|assistant|>"
287 | )
288 |         args = Args(temp=0.0, max_tokens=self.max_tokens, prompt=formatted_prompt)
289 | result = generate(args, self.model, self.tokenizer)
290 | print(result, file=self.out_file, flush=True)
291 |
292 | def finish(self):
293 | self.out_file.close()
294 |
--------------------------------------------------------------------------------
/mlx_models/whisper.py:
--------------------------------------------------------------------------------
1 | import mlx_whisper
2 | from datasets import load_dataset
3 | import mlx.core as mx
4 | import pathlib
5 | import os
6 |
7 |
8 | class MLXWhisper:
9 | def __init__(self, num_examples: int, filename: str):
10 | self.dataset = load_dataset(
11 | "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
12 | )
13 | self.file = open(filename, "w")
14 | model = "whisper_tiny_fp32"
15 | current_dir = pathlib.Path(__file__).parent.resolve()
16 | self.model_dir = os.path.join(current_dir, model)
17 |
18 | self.num_examples = 1
19 |         # Generate one example first just to load and warm up the model; its output is discarded below.
20 | self.generate()
21 | self.file.truncate(0)
22 |
23 | self.num_examples = num_examples
24 |
25 | def generate(self):
26 | mx.set_default_device(mx.gpu)
27 | for i in range(0, self.num_examples):
28 | audio_sample = self.dataset[i]["audio"]
29 | waveform = audio_sample["array"]
30 | text = mlx_whisper.transcribe(
31 | waveform, path_or_hf_repo=self.model_dir, fp16=False
32 | )["text"]
33 | print(text, file=self.file, flush=True)
34 |
35 | def finish(self):
36 | self.file.close()
37 |
--------------------------------------------------------------------------------
/pytorch_models/BasicLM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Optional
4 | import math
5 | from dataset.simple_transformers import load_ptb
6 | from mlx_models.BasicLM import to_samples, iterate_batches
7 | import numpy as np
8 |
9 | mps_dev = torch.device('mps')
10 |
11 |
12 | # Adapted from
13 | # https://github.com/ml-explore/mlx/blob/c4a471c99d0c6e6b085ff944ffef149905296a14/python/mlx/nn/layers/positional_encoding.py#L57
14 | class SinusoidalPositionalEncoding(nn.Module):
15 | def __init__(
16 | self,
17 | dims: int,
18 | min_freq: float = 0.0001,
19 | max_freq: float = 1,
20 | scale: Optional[float] = None,
21 | cos_first: bool = False,
22 | full_turns: bool = False,
23 | ):
24 | super().__init__()
25 | one_zero = 1 - torch.arange(0, dims // 2, device=mps_dev) / (dims // 2 - 1)
26 | min_freq = torch.log(torch.tensor(min_freq))
27 | max_freq = torch.log(torch.tensor(max_freq))
28 |
29 | self._sigmas = torch.exp(one_zero * (max_freq - min_freq) + min_freq)
30 | if full_turns:
31 | self._sigmas = self._sigmas * (2 * math.pi)
32 |
33 | self.scale = scale or (2 / dims) ** 0.5
34 | self.cos_first = cos_first
35 |
36 | def forward(self, x):
37 | y = x[..., None] * self._sigmas
38 | cosy = torch.cos(y)
39 | siny = torch.sin(y)
40 |
41 | if self.cos_first:
42 | y = torch.cat((cosy, siny), -1)
43 | else:
44 | y = torch.cat((siny, cosy), -1)
45 |
46 | if self.scale != 1:
47 | y = y * self.scale
48 |
49 | return y
50 |
51 |
52 | class TransformerLM(nn.Module):
53 | def __init__(
54 | self,
55 | vocab_size: int,
56 | num_layers: int,
57 | dims: int,
58 | num_heads: int,
59 | checkpoint: bool,
60 | ):
61 | super().__init__()
62 |
63 | self.embedding = nn.Embedding(vocab_size, dims)
64 | self.pe = SinusoidalPositionalEncoding(dims)
65 |
66 | encoder_layer = nn.TransformerEncoderLayer(
67 | dims, nhead=num_heads, dim_feedforward=dims * 4, norm_first=True, batch_first=True
68 | )
69 | self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
70 | self.out_proj = nn.Linear(dims, vocab_size)
71 |
72 | @staticmethod
73 | def create_additive_causal_mask(N: int, dtype: torch.dtype = torch.float32):
74 | # Adapted from
75 | # https://github.com/ml-explore/mlx/blob/c4a471c99d0c6e6b085ff944ffef149905296a14/python/mlx/nn/layers/transformer.py#L102
76 | indices = torch.arange(N, device=mps_dev)
77 | mask = indices[:, None] < indices[None]
78 | mask = mask.to(dtype) * -1e9
79 | return mask
80 |
81 | def forward(self, x):
82 | l_shape = x.shape[1]
83 | mask = self.create_additive_causal_mask(l_shape)
84 | x = self.embedding(x)
85 | x = x + self.pe(torch.arange(l_shape, device=mps_dev))
86 | x = self.transformer(x, mask)
87 | return self.out_proj(x)
88 |
89 |
90 | def mps_tensor(x):
91 | return torch.tensor(x, device=mps_dev)
92 |
93 |
94 | def train(
95 | num_blocks,
96 | batch_size,
97 | context_size,
98 | dim,
99 | num_heads,
100 | checkpoint,
101 | learning_rate,
102 | weight_decay,
103 | num_iters,
104 | lr_warmup,
105 | ptb_data
106 | ):
107 | vocab, train, valid, test = ptb_data
108 | model = TransformerLM(len(vocab), num_blocks, dim, num_heads, checkpoint).to(mps_dev)
109 |
110 | def loss_fn(model, x, y):
111 | logits = model(x)
112 | losses = nn.functional.cross_entropy(logits.permute(0, 2, 1), y)
113 | return losses
114 |
115 | optimizer = torch.optim.AdamW(
116 | model.parameters(), lr=learning_rate, weight_decay=weight_decay
117 | )
118 |
119 | def eval_fn(dataset):
120 | inputs, targets = to_samples(context_size, dataset)
121 | inputs = torch.tensor(inputs, dtype=torch.int32, device=mps_dev)
122 | targets = torch.tensor(targets, dtype=torch.int32, device=mps_dev)
123 | loss = 0
124 | model.train(False)
125 | with torch.no_grad():
126 | for s in range(0, min(targets.shape[0], 1024), batch_size):
127 | bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
128 | losses = loss_fn(model, bx, by)
129 | loss += torch.sum(losses).item()
130 | return loss / len(targets)
131 |
132 |     lr_lambda = lambda it: min(1, it / lr_warmup)  # LambdaLR scales the base lr, so return only the warmup factor
133 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
134 |
135 | # @torch.compile
136 | def step(inputs, targets, it):
137 | optimizer.zero_grad()
138 | loss = loss_fn(model, inputs, targets)
139 | loss.backward()
140 |         optimizer.step()  # apply the gradient update; without it the weights never change
141 |         scheduler.step(it)
142 |         return loss
142 |
143 | train_iterator = iterate_batches(batch_size, context_size, train)
144 | losses = []
145 | for it, (inputs, targets) in zip(range(num_iters), train_iterator):
146 | model.train(True)
147 | inputs = torch.tensor(inputs, dtype=torch.int32, device=mps_dev)
148 | targets = torch.tensor(targets, device=mps_dev, dtype=torch.int32)
149 | loss = step(inputs, targets, it)
150 | losses.append(loss.item())
151 | train_loss = np.mean(losses)
152 | print(f"Iter {it + 1}: Train loss {train_loss:.3f}, ")
153 | val_loss = eval_fn(valid)
154 | print(
155 | f"Iter {it + 1}: "
156 | f"Val loss {val_loss:.3f}, "
157 | f"Val ppl {math.exp(val_loss):.3f}, "
158 | )
159 |
160 | test_loss = eval_fn(test)
161 | test_ppl = math.exp(test_loss)
162 | print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
163 |
--------------------------------------------------------------------------------
/pytorch_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/pytorch_models/__init__.py
--------------------------------------------------------------------------------
/pytorch_models/configure.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | git lfs install
4 | git clone https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0
--------------------------------------------------------------------------------
/pytorch_models/switch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 |
5 | def multiply_and_divide(op1: torch.Tensor, op2: torch.Tensor) -> torch.Tensor:
6 | mult = torch.matmul(op1, op2)
7 | div = torch.divide(mult, torch.max(mult, 0).values)
8 | return div
9 |
10 |
11 | def multiply_items(op1: torch.Tensor, op2: torch.Tensor, op3: torch.Tensor):
12 | r_1 = multiply_and_divide(op1, op2)
13 | r_2 = multiply_and_divide(op2, op3)
14 | r_3 = multiply_and_divide(op1, op3)
15 | return r_1, r_2, r_3
16 |
17 |
18 | mps_device = torch.device("mps")
19 | cpu_device = torch.device("cpu")
20 |
21 |
22 | def run_test(iterations: int, size: int, filename: str):
23 | res_1 = torch.rand(size, size, device=cpu_device, dtype=torch.float32)
24 | res_2 = torch.rand(size, size, device=cpu_device, dtype=torch.float32)
25 | res_3 = torch.rand(size, size, device=cpu_device, dtype=torch.float32)
26 |
27 | start = time.time()
28 |
29 | for _ in range(0, iterations):
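30 |         # Copy the operands to the MPS (GPU) device, multiply there, move the results back to
31 |         # the CPU and multiply again, so each iteration exercises the CPU<->GPU transfer path.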
30 | a = res_1.to(mps_device)
31 | b = res_2.to(mps_device)
32 | c = res_3.to(mps_device)
33 |
34 | mps_1, mps_2, mps_3 = multiply_items(a, b, c)
35 |
36 | cpu_1 = mps_1.to(cpu_device)
37 | cpu_2 = mps_2.to(cpu_device)
38 | cpu_3 = mps_3.to(cpu_device)
39 |
40 | res_1, res_2, res_3 = multiply_items(cpu_1, cpu_2, cpu_3)
41 |
42 | end = time.time()
43 |
44 | with open(filename, "w") as file:
45 | print(res_1, file=file, flush=True)
46 | print(res_2, file=file, flush=True)
47 | print(res_3, file=file, flush=True)
48 |
49 | duration = end - start
50 | print(f"Pytorch time: {duration}")
51 | return duration
52 |
--------------------------------------------------------------------------------
/pytorch_models/tiny_bert.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from transformers import AutoTokenizer
7 | from torch.utils.data import DataLoader
8 |
9 | device = torch.device("mps")
10 |
11 |
12 | class Config(object):
13 | def __init__(
14 | self,
15 | vocab_size,
16 | hidden_size=768,
17 | num_hidden_layers=12,
18 | num_attention_heads=12,
19 | intermediate_size=3072,
20 | dropout_prob=0.1,
21 | max_position_embeddings=512,
22 | type_vocab_size=2,
23 | initializer_range=0.02,
24 | ):
25 | self.vocab_size = vocab_size
26 | self.hidden_size = hidden_size
27 | self.num_hidden_layers = num_hidden_layers
28 | self.num_attention_heads = num_attention_heads
29 | self.intermediate_size = intermediate_size
30 | self.hidden_dropout_prob = dropout_prob
31 | self.attention_probs_dropout_prob = dropout_prob
32 | self.max_position_embeddings = max_position_embeddings
33 | self.type_vocab_size = type_vocab_size
34 | self.initializer_range = initializer_range
35 |
36 | @classmethod
37 | def from_dict(cls, dict_object):
38 | config = Config(vocab_size=None)
39 | for key, value in dict_object.items():
40 | config.__dict__[key] = value
41 | return config
42 |
43 |
44 | class LayerNorm(nn.Module):
45 | def __init__(self, hidden_size, variance_epsilon=1e-12):
46 | super(LayerNorm, self).__init__()
47 | self.gamma = nn.Parameter(torch.ones(hidden_size))
48 | self.beta = nn.Parameter(torch.zeros(hidden_size))
49 | self.variance_epsilon = variance_epsilon
50 |
51 | def forward(self, x):
52 | u = x.mean(-1, keepdim=True)
53 | s = (x - u).pow(2).mean(-1, keepdim=True)
54 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
55 | return self.gamma * x + self.beta
56 |
57 |
58 | class PositionWiseMLP(nn.Module):
59 | def __init__(self, hidden_size, intermediate_size):
60 | super(PositionWiseMLP, self).__init__()
61 | self.expansion = nn.Linear(hidden_size, intermediate_size)
62 | self.contraction = nn.Linear(intermediate_size, hidden_size)
63 |
64 | def forward(self, x):
65 | x = self.expansion(x)
66 | x = self.contraction(torch.nn.functional.gelu(x))
67 | return x
68 |
69 |
70 | class Transformer(nn.Module):
71 | def __init__(self, config):
72 | super(Transformer, self).__init__()
73 |
74 | self.hidden_size = config.hidden_size
75 | self.num_attention_heads = config.num_attention_heads
76 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
77 | self.all_head_size = self.num_attention_heads * self.attention_head_size
78 |
79 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
80 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
81 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
82 |
83 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
84 |
85 | self.attn_out = nn.Linear(config.hidden_size, config.hidden_size)
86 | self.ln1 = LayerNorm(config.hidden_size)
87 |
88 | self.mlp = PositionWiseMLP(config.hidden_size, config.intermediate_size)
89 | self.ln2 = LayerNorm(config.hidden_size)
90 |
91 | @staticmethod
92 | def split_heads(tensor, num_heads, attention_head_size):
93 | new_shape = tensor.size()[:-1] + (num_heads, attention_head_size)
94 | tensor = tensor.view(*new_shape)
95 | return tensor.permute(0, 2, 1, 3)
96 |
97 | @staticmethod
98 | def merge_heads(tensor, num_heads, attention_head_size):
99 | tensor = tensor.permute(0, 2, 1, 3).contiguous()
100 | new_shape = tensor.size()[:-2] + (num_heads * attention_head_size,)
101 | return tensor.view(new_shape)
102 |
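# Scaled dot-product attention over the heads produced by split_heads (illustrative shapes,
# assuming batch size B and sequence length T): q, k, v are (B, num_heads, T, head_size),
# the score matrix is (B, num_heads, T, T), and the returned context is (B, num_heads, T, head_size).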
103 | def attention(self, q, k, v, attention_mask):
104 | mask = attention_mask == 1
105 | mask = mask.unsqueeze(1).unsqueeze(2)
106 |
107 | weights = torch.matmul(q, k.transpose(-2, -1))
108 | weights = weights / math.sqrt(self.attention_head_size)
109 | weights = torch.where(
110 | mask, weights, torch.tensor(float(-1e9), device=attention_mask.device)
111 | )
112 |
113 | logit = torch.nn.functional.softmax(weights, dim=-1)
114 | logit = self.dropout(logit)
115 |
116 | scores = torch.matmul(logit, v)
117 | return scores
118 |
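# The forward pass below is a standard post-norm (BERT-style) encoder block: Q/K/V projections,
# multi-head attention (padding positions, i.e. attention_mask == 0, are filled with -1e9 before
# the softmax), residual + LayerNorm, position-wise MLP, then another residual + LayerNorm.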
119 | def forward(self, x, attention_mask):
120 | q, k, v = self.query(x), self.key(x), self.value(x)
121 |
122 | q = self.split_heads(q, self.num_attention_heads, self.attention_head_size)
123 | k = self.split_heads(k, self.num_attention_heads, self.attention_head_size)
124 | v = self.split_heads(v, self.num_attention_heads, self.attention_head_size)
125 |
126 | a = self.attention(q, k, v, attention_mask)
127 | a = self.merge_heads(a, self.num_attention_heads, self.attention_head_size)
128 | a = self.attn_out(a)
129 | a = self.dropout(a)
130 | a = self.ln1(a + x)
131 |
132 | m = self.mlp(a)
133 | m = self.dropout(m)
134 | m = self.ln2(m + a)
135 |
136 | return m
137 |
138 |
139 | class MiniBert(nn.Module):
140 | def __init__(self, config_dict):
141 | super(MiniBert, self).__init__()
142 | self.config = Config.from_dict(config_dict)
143 | self.embeddings = nn.ModuleDict(
144 | {
145 | "token": nn.Embedding(
146 | self.config.vocab_size, self.config.hidden_size, padding_idx=0
147 | ),
148 | "position": nn.Embedding(
149 | self.config.max_position_embeddings, self.config.hidden_size
150 | ),
151 | "token_type": nn.Embedding(
152 | self.config.type_vocab_size, self.config.hidden_size
153 | ),
154 | }
155 | )
156 |
157 | self.ln = LayerNorm(self.config.hidden_size)
158 | self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
159 |
160 | self.layers = nn.ModuleList(
161 | [Transformer(self.config) for _ in range(self.config.num_hidden_layers)]
162 | )
163 |
164 |         # Pooling layer applied to the first ([CLS]) token of BERT's last hidden layer.
165 | self.pooler = nn.Sequential(
166 | OrderedDict(
167 | [
168 | (
169 | "dense",
170 | nn.Linear(self.config.hidden_size, self.config.hidden_size),
171 | ),
172 | ("activation", nn.Tanh()),
173 | ]
174 | )
175 | )
176 |
177 | def forward(self, input_ids, attention_mask=None, token_type_ids=None):
178 | position_ids = torch.arange(
179 | input_ids.size(1), dtype=torch.long, device=input_ids.device
180 | )
181 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
182 | if token_type_ids is None:
183 | token_type_ids = torch.zeros_like(input_ids)
184 |
185 | x = (
186 | self.embeddings.token(input_ids)
187 | + self.embeddings.position(position_ids)
188 | + self.embeddings.token_type(token_type_ids)
189 | )
190 | x = self.dropout(self.ln(x))
191 |
192 | for layer in self.layers:
193 | x = layer(x, attention_mask)
194 |
195 | o = self.pooler(x[:, 0])
196 | return x, o
197 |
198 |
199 | # BERT fine-tune task: sentence-pair classification over [u; v; |u - v|; u * v], as in https://arxiv.org/pdf/1705.02364
200 | class BertFineTuneTask(nn.Module):
201 | def __init__(self, num_labels, bert_config):
202 | super(BertFineTuneTask, self).__init__()
203 | self.bert = MiniBert(bert_config)
204 | self.linear1 = torch.nn.Linear(4 * self.bert.config.hidden_size, num_labels)
205 | self.loss_func = torch.nn.CrossEntropyLoss()
206 |
207 | def forward(self, input_ids, attention_mask, ground_truth):
208 | input_ids = input_ids.view(input_ids.size(0) * 2, input_ids.size(2))
209 | attention_mask = attention_mask.view(
210 | attention_mask.size(0) * 2, attention_mask.size(2)
211 | )
212 | bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
213 |
214 | embeds = bert_output[1].view(
215 | bert_output[1].size(0) // 2, 2, bert_output[1].size(1)
216 | )
217 |
218 | set_1_embeds = embeds[:, 0]
219 | set_2_embeds = embeds[:, 1]
220 | concat = torch.cat(
221 | (
222 | set_1_embeds,
223 | set_2_embeds,
224 | torch.abs(set_1_embeds - set_2_embeds),
225 | torch.mul(set_1_embeds, set_2_embeds),
226 | ),
227 | dim=-1,
228 | )
229 | lin = self.linear1(concat)
230 | loss = self.loss_func(lin, ground_truth)
231 | return loss
232 |
233 |
234 | def tokenize_sentence_pair_dataset(dataset, tokenizer, max_length=512):
235 | tokenized_dataset = []
236 | for i in range(0, len(dataset[0])):
237 | tokenized_dataset.append(
238 | (
239 | tokenizer(
240 | [dataset[0][i], dataset[1][i]],
241 | return_tensors="pt",
242 | padding="max_length",
243 | max_length=max_length,
244 | truncation=True,
245 | ),
246 | torch.tensor(dataset[2][i], dtype=torch.float32),
247 | )
248 | )
249 |
250 | return tokenized_dataset
251 |
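# Each dataset entry is (tokenizer output for the two sentences stacked as a batch of 2, label),
# so after DataLoader batching input_ids has shape (batch, 2, seq_len); BertFineTuneTask.forward
# flattens this to (2 * batch, seq_len) before running BERT.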
252 |
253 | def train_loop(model, optimizer, train_dataloader, num_epochs, val_dataloader):
254 | model.to(device)
255 | model.train(True)
256 |
257 | for epoch in range(num_epochs):
258 | print(f"Epoch {epoch + 1} out of {num_epochs}")
259 | epoch_loss = 0
260 | for item in iter(train_dataloader):
261 | optimizer.zero_grad()
262 |
263 | ids = item[0]["input_ids"].to(device)
264 | mask = item[0]["attention_mask"].to(device)
265 | truth = item[1].to(device)
266 |
267 | loss = model(ids, mask, truth)
268 | epoch_loss += loss
269 | loss.backward()
270 | optimizer.step()
271 |
272 | print(f"epoch loss: {epoch_loss}")
273 | dev_loss = 0
274 | with torch.no_grad():
275 | for item in iter(val_dataloader):
276 | ids = item[0]["input_ids"].to(device)
277 | mask = item[0]["attention_mask"].to(device)
278 | truth = item[1].to(device)
279 |
280 | loss = model(ids, mask, truth)
281 | dev_loss += loss
282 | print(f"validation loss: {dev_loss}")
283 | print()
284 |
285 |
286 | def train(num_epochs, batch_size, num_labels, bert_config, lr, dataset):
287 | model_name = "prajjwal1/bert-tiny"
288 | tokenizer = AutoTokenizer.from_pretrained(model_name)
289 | tokenized_train = tokenize_sentence_pair_dataset(
290 | dataset["train"][:50000], tokenizer, max_length=128
291 | )
292 | tokenized_val = tokenize_sentence_pair_dataset(
293 | dataset["dev"], tokenizer, max_length=128
294 | )
295 |
296 | train_dataloader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=False)
297 | val_dataloader = DataLoader(tokenized_val, batch_size=batch_size, shuffle=False)
298 |
299 | bert_model = BertFineTuneTask(num_labels, bert_config)
300 | optimizer = torch.optim.AdamW(bert_model.parameters(), lr=lr)
301 | train_loop(bert_model, optimizer, train_dataloader, num_epochs, val_dataloader)
302 |
--------------------------------------------------------------------------------
/pytorch_models/tiny_llama.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | import transformers
3 | import torch
4 | import pathlib
5 | import os
6 |
7 |
8 | class TorchLLama:
9 | def __init__(
10 | self,
11 | max_tokens: int,
12 | filename: str,
13 | ):
14 | self.max_tokens = max_tokens
15 | self.filename = filename
16 |
17 | model = "TinyLlama-1.1B-Chat-v1.0"
18 | current_dir = pathlib.Path(__file__).parent.resolve()
19 | save_dir = os.path.join(current_dir, model)
20 | self.pipeline = transformers.pipeline(
21 | "text-generation",
22 | model=save_dir,
23 | torch_dtype=torch.float32,
24 | device_map=torch.device("mps"),
25 | )
26 | self.file = open(filename, "w")
27 |
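# generate_and_save wraps the prompt in the <|system|>/<|user|>/<|assistant|> chat tags used by the
# chat-tuned TinyLlama checkpoint and decodes greedily (do_sample=False), so a given prompt yields
# the same text on every run, up to max_tokens newly generated tokens.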
28 | def generate_and_save(self, prompt: str):
29 | formatted_prompt = (
30 | f"<|system|>\nYou are a friendly chatbot who always responds wisely\n"
31 | f"<|user|>\n{prompt}\n"
32 | "<|assistant|>"
33 | )
34 | seqs = self.pipeline(
35 | formatted_prompt,
36 | max_new_tokens=self.max_tokens,
37 | do_sample=False,
38 | )
39 | print(seqs[0]["generated_text"], file=self.file, flush=True)
40 |
41 | def finish(self):
42 | self.file.close()
43 |
--------------------------------------------------------------------------------
/pytorch_models/whisper.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import WhisperProcessor, WhisperForConditionalGeneration
3 | import torch
4 |
5 | device = torch.device("mps")
6 |
7 |
8 | class TorchWhisper:
9 | def __init__(self, num_examples: int, filename: str):
10 | self.num_examples = num_examples
11 | self.dataset = load_dataset(
12 | "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
13 | )
14 | self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
15 | self.model = WhisperForConditionalGeneration.from_pretrained(
16 | "openai/whisper-tiny.en"
17 | ).to(device)
18 | self.file = open(filename, "w")
19 |
20 | def generate(self):
21 | for i in range(0, self.num_examples):
22 | audio_sample = self.dataset[i]["audio"]
23 | waveform = audio_sample["array"]
24 | sampling_rate = audio_sample["sampling_rate"]
25 | input_features = self.processor(
26 | waveform, sampling_rate=sampling_rate, return_tensors="pt"
27 | ).input_features
28 | predicted_ids = self.model.generate(input_features.to(device))
29 | transcription = self.processor.batch_decode(
30 | predicted_ids, skip_special_tokens=True
31 | )
32 | print(transcription[0], file=self.file, flush=True)
33 |
34 | def finish(self):
35 | self.file.close()
36 |
--------------------------------------------------------------------------------
/raw_results.txt:
--------------------------------------------------------------------------------
1 |
2 | ============================= M1 Pro =============================
3 |
4 | ---- Training a transformers language model ----
5 | => Values from three executions
6 |
7 | Framework: pytorch
8 | Average: 1806.639595190684s - Median: 1818.6110489368439s
9 |
10 | Framework: mlx
11 | Average: 1157.0066788196564s - Median: 1154.6633532047272s
12 |
13 |
14 | ---- Training BERT ----
15 | => Values from three executions
16 |
17 | Framework: pytorch
18 | Average: 751.028749704361s - Median: 753.6784768104553s
19 |
20 | Framework: mlx
21 | Average: 718.3554696242014s - Median: 718.211642742157s
22 |
23 |
24 | ---- Whisper inference ----
25 | => Values from ten executions
26 |
27 | Framework: pytorch
28 | Average: 31.998124384880064s - Median: 31.96485936641693s
29 |
30 | Framework: mlx
31 | Average: 8.509361457824706s - Median: 8.509169936180115s
32 |
33 |
34 | ---- TinyLLama inference ----
35 | => Values from ten executions
36 |
37 | Framework: pytorch
38 | Average: 59.274635887146s - Median: 55.8025221824646s
39 |
40 | Framework: mlx
41 | Average: 33.38054447174072s - Median: 33.322925329208374s
42 |
43 |
44 | ---- CPU/GPU switch ----
45 | => Values from ten executions
46 |
47 | Framework: pytorch
48 | Average: 349.7299320459366s - Median: 349.9100536108017s
49 |
50 | Framework: mlx
51 | Average: 270.1572776556015s - Median: 271.8326184749603s
52 |
53 | ==================================================================
54 |
55 | ============================= M1 Max =============================
56 |
57 | ---- Training a transformers language model ----
58 | => Values from three executions
59 |
60 | Framework: pytorch
61 | Average: 1106.7549295464829s - Median: 1101.3749282211793s
62 |
63 | Framework: mlx
64 | Average: 752.2592077255249s - Median: 752.2592812234227s
65 |
66 |
67 | ---- Training BERT ----
68 | => Values from three executions
69 |
70 | Framework: pytorch
71 | Average: 793.6759963925679s - Median: 793.610111951828s
72 |
73 | Framework: mlx
74 | Average: 499.343544960022s - Median: 498.0613958835602s
75 |
76 |
77 | ---- Whisper inference ----
78 | => Values from ten executions
79 |
80 | Framework: pytorch
81 | Average: 21.27653947197935s - Median: 21.204530119895964s
82 |
83 | Framework: mlx
84 | Average: 6.946336317062378s - Median: 6.937196493148804s
85 |
86 |
87 | ---- TinyLLama inference ----
88 | => Values from ten executions
89 |
90 | Framework: pytorch
91 | Average: 50.98743650913239s - Median: 47.73779845237732s
92 |
93 | Framework: mlx
94 | Average: 20.61291913986206s - Median: 20.613165140151978s
95 |
96 |
97 | ---- CPU/GPU switch ----
98 | => Values from ten executions
99 |
100 | Framework: pytorch
101 | Average: 251.71098244190216s - Median: 252.021404504776s
102 |
103 | Framework: mlx
104 | Average: 214.5735635280609s - Median: 214.57463908195496s
105 |
106 | ==================================================================
107 |
108 | ============================= M3 Max =============================
109 |
110 | ---- Training a transformers language model ----
111 | => Values from three executions
112 |
113 | Framework: pytorch
114 | Average: 912.5205717086792s - Median: 924.3736660480499s
115 |
116 | Framework: mlx
117 | Average: 426.00355768203735s - Median: 426.1944200992584s
118 |
119 |
120 | ---- Training BERT ----
121 | => Values from three executions
122 |
123 | Framework: pytorch
124 | Average: 550.2911033630371s - Median: 544.6988077163696s
125 |
126 | Framework: mlx
127 | Average: 408.45258100827533s - Median: 408.62774896621704s
128 |
129 |
130 | ---- Whisper inference ----
131 | => Values from ten executions
132 |
133 | Framework: pytorch
134 | Average: 17.909158730506896s - Median: 17.877812027931213s
135 |
136 | Framework: mlx
137 | Average: 4.8507798433303835s - Median: 4.839159846305847s
138 |
139 |
140 | ---- TinyLLama inference ----
141 | => Values from ten executions
142 |
143 | Framework: pytorch
144 | Average: 36.182030129432675s - Median: 34.26037609577179s
145 |
146 | Framework: mlx
147 | Average: 15.41469841003418s - Median: 15.389396786689758s
148 |
149 |
150 | ---- CPU/GPU switch ----
151 | => Values from ten executions
152 |
153 | Framework: pytorch
154 | Average: 146.35703275203704s - Median: 146.41792500019073s
155 |
156 | Framework: mlx
157 | Average: 140.5102721452713s - Median: 140.51127195358276s
158 |
159 | ==================================================================
160 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | attrs==23.2.0
5 | audioread==3.0.1
6 | black==24.4.2
7 | certifi==2024.2.2
8 | cffi==1.16.0
9 | charset-normalizer==3.3.2
10 | click==8.1.7
11 | datasets==2.19.1
12 | decorator==5.1.1
13 | dill==0.3.8
14 | filelock==3.13.4
15 | frozenlist==1.4.1
16 | fsspec==2024.3.1
17 | huggingface-hub==0.23.0
18 | idna==3.7
19 | Jinja2==3.1.3
20 | joblib==1.4.2
21 | lazy_loader==0.4
22 | librosa==0.10.2
23 | llvmlite==0.42.0
24 | MarkupSafe==2.1.5
25 | mlx==0.14.1
26 | mlx-whisper==0.1.0
27 | more-itertools==10.2.0
28 | mpmath==1.3.0
29 | msgpack==1.0.8
30 | multidict==6.0.5
31 | multiprocess==0.70.16
32 | mypy-extensions==1.0.0
33 | networkx==3.3
34 | numba==0.59.1
35 | numpy==1.26.4
36 | packaging==24.0
37 | pandas==2.2.2
38 | pathspec==0.12.1
39 | platformdirs==4.2.1
40 | pooch==1.8.1
41 | psutil==5.9.8
42 | pyarrow==16.0.0
43 | pyarrow-hotfix==0.6
44 | pycparser==2.22
45 | python-dateutil==2.9.0.post0
46 | pytz==2024.1
47 | PyYAML==6.0.1
48 | regex==2024.5.10
49 | requests==2.31.0
50 | safetensors==0.4.3
51 | scikit-learn==1.4.2
52 | scipy==1.13.0
53 | sentencepiece==0.2.0
54 | six==1.16.0
55 | soundfile==0.12.1
56 | soxr==0.3.7
57 | sympy==1.12
58 | threadpoolctl==3.5.0
59 | tiktoken==0.7.0
60 | tokenizers==0.19.1
61 | torch==2.3.0
62 | tqdm==4.66.4
63 | transformers==4.40.2
64 | typing_extensions==4.11.0
65 | tzdata==2024.1
66 | urllib3==2.2.1
67 | xxhash==3.4.1
68 | yarl==1.9.4
69 |
--------------------------------------------------------------------------------
/switch_test.py:
--------------------------------------------------------------------------------
1 | from utils.initializer import initialize
2 | from mlx_models.switch import run_test as mlx_run
3 | from pytorch_models.switch import run_test as pytorch_run
4 | import numpy as np
5 |
6 |
7 | loop_times = 50
8 | size = 10000
9 | filename = "_switch.txt"
10 |
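# CPU/GPU switch benchmark: run_test(loop_times, size, filename) repeatedly moves matrices
# (dimensions controlled by size) between the GPU and the CPU, multiplies them, and writes the
# results to "<framework>_switch.txt" (see the switch.py module of each framework).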
11 | if __name__ == "__main__":
12 | args, times = initialize()
13 |
14 | for i in range(args.iter):
15 | if args.framework == "mlx":
16 | times[i] = mlx_run(loop_times, size, "mlx" + filename)
17 | else:
18 | times[i] = pytorch_run(loop_times, size, "pytorch" + filename)
19 |
20 | print(f"\nSwitch test: ran {args.iter} times")
21 | print(
22 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s"
23 | )
24 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSte/MLX-vs-Pytorch/b1c8b97dd313455e3cce6e21ea17e8d08d7045a7/utils/__init__.py
--------------------------------------------------------------------------------
/utils/initializer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 |
4 |
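# Shared CLI setup for the benchmark scripts: parses --framework (pytorch or mlx) and --iter, and
# returns the arguments together with a zero-filled array for collecting per-iteration timings.
# Example invocation of a benchmark: python switch_test.py --framework mlx --iter 10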
5 | def initialize():
6 |     parser = argparse.ArgumentParser(description="MLX vs PyTorch benchmark")
7 | parser.add_argument(
8 | "--framework",
9 | help="Which framework to use: pytorch or mlx",
10 | type=str,
11 | )
12 | parser.add_argument(
13 | "--iter", help="How many times to run the test", default=1, type=int
14 | )
15 |
16 | args = parser.parse_args()
17 | times = np.zeros(args.iter)
18 |
19 | return args, times
20 |
--------------------------------------------------------------------------------
/whisper_inference.py:
--------------------------------------------------------------------------------
1 | from utils.initializer import initialize
2 | import time
3 | from pytorch_models.whisper import TorchWhisper
4 | from mlx_models.whisper import MLXWhisper
5 | import numpy as np
6 |
7 |
8 | num_examples = 70
9 | out_filename = "_whisper.txt"
10 |
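# Each iteration builds a fresh model, then times only generate(): num_examples audio clips are
# transcribed and the text is written to "<framework>_whisper.txt"; model and dataset loading
# happen in the constructor, outside the timed region.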
11 | if __name__ == "__main__":
12 | args, times = initialize()
13 |
14 | for i in range(args.iter):
15 | if args.framework == "mlx":
16 | mlx_model = MLXWhisper(num_examples, "mlx" + out_filename)
17 | start = time.time()
18 | mlx_model.generate()
19 | end = time.time()
20 | mlx_model.finish()
21 | elapsed = end - start
22 | print(f"MLX time: {elapsed}s")
23 | times[i] = elapsed
24 | else:
25 | torch_model = TorchWhisper(num_examples, "pytorch" + out_filename)
26 | start = time.time()
27 | torch_model.generate()
28 | end = time.time()
29 | torch_model.finish()
30 | elapsed = end - start
31 | print(f"Pytorch time: {elapsed}s")
32 | times[i] = elapsed
33 |
34 | print(f"\nWhisper inference test: ran {args.iter} times")
35 | print(
36 | f"Framework: {args.framework}\n\tAverage: {np.mean(times)}s - Median: {np.median(times)}s"
37 | )
38 |
--------------------------------------------------------------------------------