├── mlx_esm ├── __init__.py ├── utils.py ├── cli.py ├── infer.py ├── train.py ├── tokenizer.py ├── data.py └── model.py ├── weights └── esm1-202402151405.npz ├── Justfile ├── LICENSE ├── pyproject.toml ├── notebooks ├── template.ipynb ├── 3dmol.ipynb └── train.ipynb ├── .gitignore └── README.md /mlx_esm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /weights/esm1-202402151405.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/usmanm/mlx-esm/HEAD/weights/esm1-202402151405.npz -------------------------------------------------------------------------------- /Justfile: -------------------------------------------------------------------------------- 1 | typecheck: 2 | poetry run pyright mlx_esm 3 | 4 | format: 5 | poetry run ruff format mlx_esm 6 | poetry run ruff check mlx_esm --fix 7 | 8 | vscode: 9 | poetry run code . 10 | 11 | notebook: 12 | poetry run jupyter lab 13 | -------------------------------------------------------------------------------- /mlx_esm/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | 4 | 5 | # Copy-pasted from tqdm.auto 6 | def is_notebook() -> bool: 7 | try: 8 | get_ipython = sys.modules["IPython"].get_ipython 9 | if "IPKernelApp" not in get_ipython().config: 10 | raise KeyError("IPKernelApp") 11 | return True 12 | except KeyError: 13 | return False 14 | 15 | 16 | # This hackery is needed because without ncols the CLI version becomes shit, but with it 17 | # the notebook version becomes shit. Such is life. 18 | def tqdm_ncols() -> Optional[int]: 19 | return None if is_notebook() else 120 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Usman Masood 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "mlx-esm" 3 | version = "0.1.0" 4 | description = "Implementation of Meta's ESM-1 in MLX" 5 | authors = ["Usman Masood "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.scripts] 9 | cli = "mlx_esm.cli:main" 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.11" 13 | mlx = "^0.2.0" 14 | requests = "^2.31.0" 15 | ruff = "^0.2.1" 16 | pyright = "^1.1.350" 17 | jupyterlab = "^4.1.0" 18 | ipython = "^8.21.0" 19 | matplotlib = "^3.8.2" 20 | numpy = "^1.26.4" 21 | click = "^8.1.7" 22 | tqdm = "^4.66.2" 23 | ipywidgets = "^8.1.2" 24 | py3dmol = "^2.0.4" 25 | 26 | [tool.poetry.group.dev.dependencies] 27 | torch = "^2.2.0" 28 | 29 | [tool.pyright] 30 | # https://github.com/microsoft/pyright/blob/main/docs/configuration.md 31 | useLibraryCodeForTypes = true 32 | exclude = [".cache"] 33 | # strict = ["mlx_esm/**"] 34 | 35 | [tool.ruff] 36 | # https://docs.astral.sh/ruff/configuration/ 37 | line-length = 100 38 | indent-width = 2 39 | 40 | [tool.ruff.lint] 41 | select = ['E', 'W', 'F', 'I', 'B', 'C4', 'ARG', 'SIM'] 42 | ignore = ['W291', 'W292', 'W293'] 43 | 44 | [tool.ruff.format] 45 | quote-style = "double" 46 | indent-style = "space" 47 | docstring-code-format = true 48 | 49 | [build-system] 50 | requires = ["poetry-core"] 51 | build-backend = "poetry.core.masonry.api" 52 | -------------------------------------------------------------------------------- /notebooks/template.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "cbe7bba0-e7ee-4111-9908-26ca7d0666dc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# This pre-amble sets up auto-reloading of any on-disk modules we are hacking on.\n", 11 | "%load_ext autoreload\n", 12 | "%reload_ext autoreload\n", 13 | "%autoreload 2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "57980c98-e97c-487b-9c7d-819527919601", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# This pre-amble makes mlx_esm visible to this notebook.\n", 24 | "import os\n", 25 | "import sys\n", 26 | "module_path = os.path.abspath(os.path.join(\"..\"))\n", 27 | "sys.path.insert(0, module_path)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "652688dc-2693-489f-a4df-2fd6b0ba32ee", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# This pre-amble makes matplotlib available in this notebook.\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline" 40 | ] 41 | } 42 | ], 43 | "metadata": { 44 | "kernelspec": { 45 | "display_name": "Python 3 (ipykernel)", 46 | "language": "python", 47 | "name": "python3" 48 | }, 49 | "language_info": { 50 | "codemirror_mode": { 51 | "name": "ipython", 52 | "version": 3 53 | }, 54 | "file_extension": ".py", 55 | "mimetype": "text/x-python", 56 | "name": "python", 57 | "nbconvert_exporter": "python", 58 | "pygments_lexer": "ipython3", 59 | "version": "3.11.7" 60 | } 61 | }, 62 | "nbformat": 4, 63 | "nbformat_minor": 5 64 | } 65 | -------------------------------------------------------------------------------- /notebooks/3dmol.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | 
"execution_count": null, 6 | "id": "cbe7bba0-e7ee-4111-9908-26ca7d0666dc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# This pre-amble sets up auto-reloading of any on-disk modules we are hacking on.\n", 11 | "%load_ext autoreload\n", 12 | "%reload_ext autoreload\n", 13 | "%autoreload 2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "57980c98-e97c-487b-9c7d-819527919601", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# This pre-amble makes mlx_esm visible to this notebook.\n", 24 | "import os\n", 25 | "import sys\n", 26 | "module_path = os.path.abspath(os.path.join(\"..\"))\n", 27 | "sys.path.insert(0, module_path)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "652688dc-2693-489f-a4df-2fd6b0ba32ee", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# This pre-amble makes matplotlib available in this notebook.\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "5c6d6b42-8f9a-4e7c-acfd-8962a83783e7", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "filename = \"SequenceFour.pdb\"" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 1, 55 | "id": "4ab721eb-d233-4d1b-aace-3100e578dbbd", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import py3Dmol\n", 60 | "\n", 61 | "\n", 62 | "with open(filename) as f:\n", 63 | " data = \"\".join([x for x in f])\n", 64 | "\n", 65 | "view = py3Dmol.view(width=400, height=300)\n", 66 | "view.addModelsAsFrames(data)\n", 67 | "view.setStyle({\"model\": -1}, {\"cartoon\": {\"color\": \"spectrum\"}})\n", 68 | "view.zoomTo()\n", 69 | "view.show()" 70 | ] 71 | } 72 | ], 73 | "metadata": { 74 | "kernelspec": { 75 | "display_name": "Python 3 (ipykernel)", 76 | "language": "python", 77 | "name": "python3" 78 | }, 79 | "language_info": { 80 | "codemirror_mode": { 81 | "name": "ipython", 82 | "version": 3 83 | }, 84 | "file_extension": ".py", 85 | "mimetype": "text/x-python", 86 | "name": "python", 87 | "nbconvert_exporter": "python", 88 | "pygments_lexer": "ipython3", 89 | "version": "3.11.7" 90 | } 91 | }, 92 | "nbformat": 4, 93 | "nbformat_minor": 5 94 | } 95 | -------------------------------------------------------------------------------- /mlx_esm/cli.py: -------------------------------------------------------------------------------- 1 | import random 2 | from datetime import datetime 3 | from os import path 4 | from typing import Optional 5 | 6 | import click 7 | 8 | from mlx_esm.data import Tokenizer 9 | from mlx_esm.infer import generate, unmask 10 | from mlx_esm.model import ESM1 11 | from mlx_esm.train import Config, Trainer 12 | 13 | 14 | def esm1_model() -> ESM1: 15 | return ESM1(Tokenizer()) 16 | 17 | 18 | @click.command("generate") 19 | @click.option( 20 | "--weights-file", 21 | type=click.Path(exists=True, dir_okay=False, file_okay=True, readable=True), 22 | required=True, 23 | ) 24 | @click.option("--length", type=int, default=lambda: random.randint(32, 96)) 25 | @click.option("--max-prob-only", is_flag=True) 26 | def generate_cmd(weights_file: str, length: int, max_prob_only: bool = False): 27 | """ 28 | Generate a random protein sequence. 
29 | """ 30 | m = esm1_model() 31 | m.load_weights(weights_file) 32 | 33 | generate(m, length=length, max_prob_only=max_prob_only) 34 | 35 | 36 | @click.command("unmask") 37 | @click.option( 38 | "--weights-file", 39 | type=click.Path(exists=True, dir_okay=False, file_okay=True, readable=True), 40 | required=True, 41 | ) 42 | @click.option("--seq", type=str, required=True) 43 | @click.option("--max-prob-only", is_flag=True) 44 | def unmask_cmd(weights_file: str, seq: str, max_prob_only: bool = False): 45 | """ 46 | Unmask a masked protein sequence. 47 | """ 48 | m = esm1_model() 49 | m.load_weights(weights_file) 50 | 51 | unmask(m, seq, max_prob_only=max_prob_only) 52 | 53 | 54 | @click.command("train") 55 | @click.option( 56 | "--weights-dir", 57 | type=click.Path(exists=True, dir_okay=True, file_okay=False), 58 | ) 59 | @click.option( 60 | "--weights-file", 61 | type=click.Path(exists=True, dir_okay=False, file_okay=True), 62 | ) 63 | @click.option("--dataset-partitions", type=int, multiple=True, default=lambda: [1, 2, 3]) 64 | @click.option("--num-iters", type=int, default=100_000) 65 | def train_cmd( 66 | weights_dir: Optional[str], 67 | weights_file: Optional[str], 68 | dataset_partitions: list[int], 69 | num_iters: int, 70 | ): 71 | """ 72 | Train a new/existing model and save/updates the weights in a file. 73 | """ 74 | if (weights_dir and weights_file) or (not weights_dir and not weights_file): 75 | raise click.BadParameter("You must provide exactly one of --weights-dir and --weights-file.") 76 | 77 | m = esm1_model() 78 | if weights_file: 79 | m.load_weights(weights_file) 80 | 81 | c = Config(dataset_partitions=dataset_partitions, num_iters=num_iters) 82 | t = Trainer(m, c) 83 | 84 | t.load() 85 | t.train() 86 | t.validate() 87 | 88 | if weights_file: 89 | file_path = weights_file 90 | # Clear the file 91 | with open(file_path, "w"): 92 | pass 93 | else: 94 | now = datetime.now() 95 | time_str = now.strftime("%Y%m%d%H%M") 96 | file_path = path.join(f"{weights_dir}/esm1-{time_str}.npz") 97 | 98 | m.save_weights(file_path) 99 | print(f"💾 weights saved to {file_path}") 100 | 101 | 102 | @click.group() 103 | def main(): 104 | pass 105 | 106 | 107 | main.add_command(generate_cmd) 108 | main.add_command(unmask_cmd) 109 | main.add_command(train_cmd) 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /mlx_esm/infer.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | from tqdm.auto import tqdm 4 | 5 | from mlx_esm.data import Tokenizer 6 | from mlx_esm.model import ESM1 7 | from mlx_esm.utils import tqdm_ncols 8 | 9 | 10 | def generate( 11 | model: ESM1, 12 | length: int, 13 | max_iters: int = 256, 14 | max_prob_only: bool = False, 15 | ): 16 | start_seq = "^" + "*" * length + "$" 17 | return impl(model, start_seq, max_iters, max_prob_only) 18 | 19 | 20 | def unmask( 21 | model: ESM1, 22 | masked_seq: str, 23 | max_iters: int = 256, 24 | max_prob_only: bool = False, 25 | ): 26 | return impl(model, f"^{masked_seq}$", max_iters, max_prob_only) 27 | 28 | 29 | def impl(model: ESM1, input: str, max_iters: int, max_prob_only: bool): 30 | tokenizer = Tokenizer() 31 | 32 | model.eval() 33 | mx.eval(model.parameters()) 34 | 35 | toks = tokenizer.encode(input) 36 | x = mx.array([toks], dtype=mx.int32) 37 | 38 | total = (toks == tokenizer.mask_idx).sum().item() 39 | loop = tqdm( 40 | total=total, 41 | ncols=tqdm_ncols(), 42 | 
desc="🌱 generating", 43 | ) 44 | for _ in range(max_iters): 45 | toks = x[0] 46 | 47 | if is_sequence_legit(tokenizer, toks): 48 | break 49 | 50 | # forward the model 51 | logits = model(x) 52 | x = compute_next_x(tokenizer, x, logits, max_prob_only) 53 | loop.update() 54 | 55 | loop.close() 56 | 57 | emoji = "🌳" if is_sequence_legit(tokenizer, toks) else "🍂" 58 | s = "".join(tokenizer.decode(toks)).strip().rstrip("%").rstrip("$").lstrip("^") 59 | print(emoji + " hello world: " + s) 60 | 61 | 62 | def compute_next_x( 63 | tokenizer: Tokenizer, 64 | x: mx.array, 65 | logits: mx.array, 66 | max_prob_only: bool = False, 67 | ) -> mx.array: 68 | probs = np.array(mx.softmax(logits, axis=-1)) 69 | 70 | # This is equivalent to multinomial in PyTorch. 71 | if max_prob_only: 72 | samples = np.array([np.argmax(prob) for prob in probs.reshape(-1, probs.shape[-1])]) 73 | else: 74 | samples = np.array( 75 | [ 76 | np.random.choice(range(probs.shape[-1]), p=prob) 77 | for prob in probs.reshape(-1, probs.shape[-1]) 78 | ] 79 | ) 80 | 81 | sample = mx.array(samples.reshape(probs.shape[0], probs.shape[1])) 82 | 83 | # We only swap out the first mask token to generate proetins using 84 | # an autoregressive style. My theory is that this will lead to 85 | # more realistic sequences because the model sees it grow rather 86 | # than guessing the entire sequence in a single shot. 87 | mask_all = x == tokenizer.mask_idx 88 | mask_first = mx.zeros_like(mask_all) 89 | mask_first[mx.arange(mask_all.shape[0]), mask_all.argmax(axis=1)] = 1 90 | 91 | sample = sample * mask_first 92 | x = x * (1 - mask_first) 93 | 94 | return x + sample 95 | 96 | 97 | def is_sequence_legit(tokenizer: Tokenizer, toks: mx.array) -> bool: 98 | tokens: list[int] = toks.tolist() 99 | 100 | # Any invalid tokens in the sequence? 101 | invalid_toks = [tokenizer.mask_idx, tokenizer.unk_idx] 102 | if any(invalid_tok in tokens for invalid_tok in invalid_toks): 103 | return False 104 | 105 | # Remove any padding 106 | while tokens and tokens[-1] == tokenizer.pad_idx: 107 | tokens.pop() 108 | 109 | # Does the sequence start and end with the correct tokens? 110 | if tokens[0] != tokenizer.cls_idx or tokens[-1] != tokenizer.eos_idx: 111 | return False 112 | 113 | # Are only protein tokens in middle of the sequence? 114 | protein_toks: list[int] = tokenizer.encode("".join(tokenizer.protein_toks)).tolist() 115 | return all(tok in protein_toks for tok in tokens[1:-1]) 116 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | data/ 163 | notebooks/scratch.ipynb 164 | -------------------------------------------------------------------------------- /mlx_esm/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from dataclasses import dataclass, field 4 | from functools import partial 5 | from typing import Optional 6 | 7 | import matplotlib.pyplot as plt 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | import mlx.optimizers as optim 11 | from tqdm.auto import tqdm 12 | 13 | from mlx_esm.data import DataSplit, Loader 14 | from mlx_esm.model import ESM1 15 | from mlx_esm.tokenizer import Tokenizer 16 | from mlx_esm.utils import tqdm_ncols 17 | 18 | 19 | def set_seed(seed: int): 20 | random.seed(seed) 21 | mx.random.seed(seed) 22 | 23 | 24 | @dataclass 25 | class Config(object): 26 | # The answer to the ultimate question of life, the universe, and everything. 27 | seed: int = 42 28 | 29 | dataset_partitions: list[int] = field(default_factory=lambda: [1, 2, 3]) 30 | num_iters: int = 100_000 31 | batch_size: int = 16 32 | learning_rate: float = 0.01 33 | mask_rate: float = 0.15 34 | 35 | # The maximum sequence length for proteins to train on. This effectively 36 | # limits the "context size" of the model. Larger contexts are slower to 37 | # train on my GPU-poor MacBook Air. 38 | max_seq_len: int = 126 39 | 40 | 41 | class Trainer(object): 42 | def __init__(self, model: Optional[ESM1] = None, config: Optional[Config] = None): 43 | self.config = config or Config() 44 | self.model = model or ESM1(Tokenizer()) 45 | 46 | self.loader = Loader( 47 | self.model.tokenizer, 48 | self.config.dataset_partitions, 49 | self.config.batch_size, 50 | self.config.max_seq_len, 51 | self.config.mask_rate, 52 | ) 53 | 54 | self.losses: dict[DataSplit, list[float]] = { 55 | "train": [], 56 | "validate": [], 57 | } 58 | 59 | set_seed(self.config.seed) 60 | 61 | def load(self): 62 | print("📥 loading data") 63 | self.loader.load() 64 | 65 | def train(self, num_iters: Optional[int] = None): 66 | return self.run("train", num_iters or self.config.num_iters) 67 | 68 | def validate(self, num_iters: Optional[int] = None): 69 | return self.run("validate", num_iters or int(self.config.num_iters * 0.1)) 70 | 71 | def run(self, split: DataSplit, num_iters: int): 72 | model = self.model 73 | config = self.config 74 | loader = self.loader 75 | 76 | if split == "train": 77 | model.train() 78 | desc = "🚂 training" 79 | else: 80 | model.eval() 81 | desc = "🔍 validating" 82 | 83 | mx.eval(model.parameters()) 84 | 85 | def loss_fn(model: ESM1, x: mx.array, targets: mx.array) -> mx.array: 86 | return mx.mean(nn.losses.cross_entropy(model(x), targets)) 87 | 88 | optimizer = optim.SGD(learning_rate=config.learning_rate, momentum=0.8) 89 | 90 | # https://ml-explore.github.io/mlx/build/html/usage/compile.html#compiling-training-graphs 91 | state = [model.state, optimizer.state] 92 | 93 | @partial(mx.compile, inputs=state, outputs=state) 94 | def step(x: mx.array, y: mx.array): 95 | # forward the model 96 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn) 97 | loss, grads = loss_and_grad_fn(model, x, y) 98 | 99 | # backprop and update the parameters 100 | # Update the optimizer state and model parameters 101 | # in a single call 102 | if split == "train": 103 | optimizer.update(model, grads) 104 | 105 | return loss 106 | 107 | self.iter_num = 0 108 | self.last_log_time = time.time() 109 | 110 | loop = tqdm( 111 | range(num_iters or config.num_iters), 112 
| ncols=tqdm_ncols(), 113 | desc=desc, 114 | postfix={"loss": "NaN"}, 115 | ) 116 | for _ in loop: 117 | x, y = loader.next_batch(split) 118 | loss = step(x, y) 119 | mx.eval(state) 120 | 121 | avg_loss = self.avg_loss(split, loss.item()) 122 | loop.set_postfix({"loss": f"{avg_loss:.4f}"}) 123 | 124 | def avg_loss(self, split: DataSplit, loss: float, bucket_size: int = 1000): 125 | losses = self.losses[split] 126 | losses.append(loss) 127 | 128 | window = losses[-bucket_size:] 129 | avg_loss = sum(window) / len(window) 130 | return avg_loss 131 | 132 | def plot_loss( 133 | self, 134 | split: DataSplit, 135 | bucket_size: int = 1000, 136 | start_idx: int = 0, 137 | ): 138 | losses = self.losses[split][start_idx:] 139 | 140 | xs = range(len(losses)) 141 | ys = losses 142 | 143 | # We will bucket the losses to make the plot more readable. 144 | bucketed_xs = range(0, len(xs), bucket_size) 145 | bucketed_ys = [ 146 | sum(ys[i : i + bucket_size]) / len(ys[i : i + bucket_size]) 147 | for i in range(0, len(ys), bucket_size) 148 | ] 149 | 150 | plt.xlabel("iteration") 151 | plt.ylabel("loss") 152 | plt.plot(bucketed_xs, bucketed_ys) 153 | plt.title(f"{'Training' if split == 'train' else 'Validation'} Loss") 154 | plt.show() 155 | -------------------------------------------------------------------------------- /mlx_esm/tokenizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import mlx.core as mx 5 | 6 | 7 | class Tokenizer(object): 8 | def __init__(self): 9 | # Special tokens: start of sequence, masked token, end of sequence, unknown token, padding token 10 | self.special_toks: Tuple[str, ...] = ("^", "*", "$", "?", "%") 11 | self.protein_toks: Tuple[str, ...] = ( 12 | "-", 13 | ".", 14 | "A", 15 | "B", 16 | "C", 17 | "D", 18 | "E", 19 | "F", 20 | "G", 21 | "H", 22 | "I", 23 | "K", 24 | "L", 25 | "M", 26 | "N", 27 | "O", 28 | "P", 29 | "Q", 30 | "R", 31 | "S", 32 | "T", 33 | "U", 34 | "V", 35 | "W", 36 | "X", 37 | "Y", 38 | "Z", 39 | ) 40 | self.all_toks: list[str] = sorted(list(self.special_toks) + list(self.protein_toks)) 41 | 42 | # Make padding_idx always be 0. This allows us to keep positional embeddings simple.
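# A concrete example of the swap below: sorted(all_toks) begins
# ["$", "%", "*", "-", ".", "?", "A", ...], so the padding token "%" starts out
# at index 1. Swapping it with whatever token sits at index 0 leaves the list
# starting ["%", "$", "*", "-", ...], which is what the pad_idx == 0 assertion
# further down relies on.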
43 | tmp_idx = self.all_toks.index("%") 44 | zero_val = self.all_toks[0] 45 | self.all_toks[0] = "%" 46 | self.all_toks[tmp_idx] = zero_val 47 | 48 | assert self.all_toks[0] == "%" 49 | assert len(self.all_toks) == len(set(self.all_toks)) 50 | 51 | self.idx_to_tok = dict(enumerate(self.all_toks)) 52 | self.tok_to_idx = {tok: idx for idx, tok in enumerate(self.all_toks)} 53 | self.vocab_size = len(self.all_toks) 54 | 55 | self.unk_idx = self.tok_to_idx["?"] 56 | self.pad_idx = self.tok_to_idx["%"] 57 | self.cls_idx = self.tok_to_idx["^"] 58 | self.mask_idx = self.tok_to_idx["*"] 59 | self.eos_idx = self.tok_to_idx["$"] 60 | 61 | assert self.pad_idx == 0 62 | 63 | def tokenize(self, sequence: str) -> list[str]: 64 | # Example: 65 | # -> split_on_token("H", "XYXHADHKJXXX") 66 | # -> ['XYX', 'H', 'AD', 'H', 'KJXXX'] 67 | def split_on_token(tok: str, text: str) -> list[str]: 68 | result: list[str] = [] 69 | split_text = text.split(tok) 70 | for i, sub_text in enumerate(split_text): 71 | sub_text = sub_text.strip() 72 | 73 | if i == len(split_text) - 1: 74 | if sub_text: 75 | result.append(sub_text) 76 | else: 77 | if sub_text: 78 | result.append(sub_text) 79 | result.append(tok) 80 | 81 | assert text == "".join(result) 82 | assert all(s == tok or tok not in s for s in result) 83 | 84 | return result 85 | 86 | # Example: 87 | # -> split_on_tokens(["H", "Y"], "XYXHADHKJXXX") 88 | # -> ['X', 'Y', 'X', 'H', 'AD', 'H', 'KJXXX'] 89 | def split_on_tokens(toks: list[str], text: str) -> list[str]: 90 | if text == "": 91 | return [] 92 | 93 | curr_tokens: list[str] = [text] 94 | next_tokens: list[str] = [] 95 | 96 | for tok in toks: 97 | for sub_text in curr_tokens: 98 | if sub_text not in toks: 99 | next_tokens.extend(split_on_token(tok, sub_text)) 100 | else: 101 | next_tokens.append(sub_text) 102 | curr_tokens = next_tokens 103 | next_tokens = [] 104 | 105 | return curr_tokens 106 | 107 | return split_on_tokens(self.all_toks, sequence.strip()) 108 | 109 | def encode(self, seq: str) -> mx.array: 110 | toks = self.tokenize(seq) 111 | # If token is not present, treat it as "unknown" token. 112 | return mx.array([self.tok_to_idx.get(tok, self.unk_idx) for tok in toks], dtype=mx.int32) 113 | 114 | def decode(self, encoded: mx.array) -> list[str]: 115 | return [self.idx_to_tok[idx] for idx in encoded.tolist()] 116 | 117 | 118 | class BatchTokenizer(object): 119 | def __init__(self, tokenizer: Tokenizer): 120 | self.tokenizer = tokenizer 121 | 122 | def encode(self, sequences: list[str]) -> mx.array: 123 | tokenizer = self.tokenizer 124 | 125 | batch_size = len(sequences) 126 | batch = [tokenizer.encode(seq) for seq in sequences] 127 | 128 | max_tok_len = max(len(toks) for toks in batch) 129 | # +2 for CLS and EOS tokens. Make it a multiple of 8 for performance. 130 | seq_len = math.ceil((max_tok_len + 2) / 8) * 8 131 | 132 | # B = size of batch, L = sequence length 133 | shape = (batch_size, seq_len) 134 | 135 | # (B, L) -> filled with padding token 136 | tokens = mx.full(shape, tokenizer.pad_idx, dtype=mx.int32) 137 | 138 | # Fill the tokens tensor with the actual protein sequence tokens, making 139 | # sure to add a "start of sequence" token at the beginning and an "end of 140 | # sequence" token at the end. We are using a "pad-right" strategy, because 141 | # BERT is a model with absolute position embeddings so it’s usually advised 142 | # to pad the inputs on the right rather than the left. 
143 | # 144 | # https://huggingface.co/docs/transformers/model_doc/bert#usage-tips 145 | for idx, toks in enumerate(batch): 146 | # First token of each sequence is the "start of sequence" token. 147 | tokens[idx, 0] = tokenizer.cls_idx 148 | # Then fill in the actual protein sequence tokens. 149 | tokens[idx, 1 : len(toks) + 1] = mx.array(toks) 150 | # Finally, the last token of each sequence is the "end of sequence" token. 151 | tokens[idx, len(toks) + 1] = tokenizer.eos_idx 152 | 153 | return tokens 154 | -------------------------------------------------------------------------------- /mlx_esm/data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import random 3 | import shutil 4 | import tempfile 5 | from os import path 6 | from typing import Literal, Optional, Tuple 7 | 8 | import mlx.core as mx 9 | import requests 10 | 11 | from mlx_esm.tokenizer import BatchTokenizer, Tokenizer 12 | 13 | DATA_DIR = path.join(path.dirname(__file__), path.pardir, "data") 14 | UNIPARC_DIR_URL = ( 15 | "https://ftp.uniprot.org/pub/databases/uniprot/current_release/uniparc/fasta/active/" 16 | ) 17 | 18 | 19 | def download_file(url: str, path: str): 20 | response = requests.get(url) 21 | response.raise_for_status() 22 | 23 | with open(path, "wb") as f: 24 | f.write(response.content) 25 | 26 | 27 | def extract_gz_file(gz_path: str, dest_path: str): 28 | with gzip.open(gz_path, "rb") as f_in, open(dest_path, "wb") as f_out: 29 | shutil.copyfileobj(f_in, f_out) 30 | 31 | 32 | def load_uniparc_dataset_partitions(ids: list[int]) -> list[str]: 33 | return [item for _id in ids for item in load_uniparc_dataset_partition(_id)] 34 | 35 | 36 | def load_uniparc_dataset_partition(_id: int) -> list[str]: 37 | assert 0 < _id <= 200 38 | 39 | filename = "uniparc_active_p%d.fasta" % _id 40 | filepath = path.join(DATA_DIR, filename) 41 | 42 | if not path.exists(filepath): 43 | url = path.join(UNIPARC_DIR_URL, f"{filename}.gz") 44 | with tempfile.NamedTemporaryFile() as tmp: 45 | download_file(url, tmp.name) 46 | extract_gz_file(tmp.name, filepath) 47 | 48 | # Example: 49 | # 50 | # >UPI0000000563 status=active 51 | # MSGHKCSYPWDLQDRYAQDKSVVNKMQQKYWETKQAFIKATGKKEDEHVVASDADLDAKL 52 | # ELFHSIQRTCLDLSKAIVLYQKRICSF 53 | # >UPI00000005DE status=active 54 | # MGAQDRPQCHFDIEINREPVGRIMFQLFSDICPKTCKNFLCLCSGEKGLGKTTGKKLCYK 55 | # GSTFHRVVKNFMIQGGDFSEGNGKGGESIYGGYFKDENFILKHDRAFLLSMANRGKHTNG 56 | # SQFFITTKPAPHLDGVHVVFGLVISGFEVIEQIENLKTDAASRPYADVRVIDCGVLATKL 57 | # TKDVFEKKRKKPTCSEGSDSSSRSSSSSESSSESEVERETIRRRRHKRRPKVRHAKKRRK 58 | # EMSSSEEPRRKRTVSPEG 59 | with open(filepath, "r") as f: 60 | sequences: list[str] = [] 61 | 62 | current_label = "" 63 | current_value_buf: list[str] = [] 64 | 65 | def _flush_sequence(): 66 | nonlocal current_label 67 | if current_label == "": 68 | return 69 | sequences.append("".join(current_value_buf)) 70 | current_label = "" 71 | current_value_buf.clear() 72 | 73 | for idx, line in enumerate(f): 74 | line = line.strip() 75 | 76 | if line.startswith(">"): 77 | _flush_sequence() 78 | label = line[1:].strip() 79 | current_label = label if label != "" else f"SEQ_{idx}" 80 | else: 81 | current_value_buf.append(line) 82 | 83 | return sequences 84 | 85 | 86 | DataSplit = Literal["train", "validate"] 87 | 88 | 89 | class Loader(object): 90 | def __init__( 91 | self, 92 | tokenizer: Tokenizer, 93 | dataset_partitions: list[int], 94 | batch_size: int, 95 | max_seq_len: int, 96 | mask_rate: float, 97 | ): 98 | self.dataset_partitions = 
sorted(dataset_partitions) 99 | self.batch_size = batch_size 100 | self.max_seq_len = max_seq_len 101 | self.mask_rate = mask_rate 102 | self.batch_tokenizer = BatchTokenizer(tokenizer) 103 | self.data: Optional[list[str]] = None 104 | self.train_validate_split = 0.9 105 | 106 | def load(self): 107 | if self.data is not None: 108 | return 109 | sequences = load_uniparc_dataset_partitions(self.dataset_partitions) 110 | self.data = [s for s in sequences if len(s) <= self.max_seq_len] 111 | random.shuffle(self.data) 112 | 113 | def next_batch(self, split: DataSplit) -> Tuple[mx.array, mx.array]: 114 | if self.data is None: 115 | raise Exception("data has not been loaded yet") 116 | 117 | split_idx = int(len(self.data) * self.train_validate_split) 118 | data = self.data[:split_idx] if split == "train" else self.data[split_idx:] 119 | 120 | batch: list[str] = random.sample(data, self.batch_size) 121 | 122 | encoded = self.batch_tokenizer.encode(batch) 123 | shape: list[int] = list(encoded.shape) 124 | 125 | tokenizer = self.batch_tokenizer.tokenizer 126 | pad_idx, mask_idx = tokenizer.pad_idx, tokenizer.mask_idx 127 | 128 | # We should not mask padding tokens because they do not form part of the 129 | # underlying protein sequence. We do include CLS & EOS tokens because they 130 | # are part of the sequence in so far that they do inform us about the structure 131 | # of the protein. 132 | can_mask = encoded != pad_idx 133 | 134 | # We should mask tokens with a probability of `mask_rate`. We will use a 135 | # uniform distribution to determine which tokens to mask. By multiplying 136 | # the result of the uniform distribution by `can_mask`, we ensure that we 137 | # do not mask tokens that are padding tokens. 138 | should_mask = (mx.random.uniform(0, 1, shape, dtype=mx.float32) < self.mask_rate) * can_mask 139 | should_not_mask = 1 - should_mask 140 | 141 | # BERT differs from ESM-1 in how it does masking. In ESM-1, we mask tokens 142 | # with the mask token only, while in BERT the masking strategy is a bit more 143 | # complex. We will implement the ESM-1 masking strategy here. 144 | masked = (encoded * should_not_mask) + (should_mask * mask_idx) 145 | 146 | return (masked, encoded) 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mlx-esm 2 | 3 | This is an implementation of Meta's [ESM-1](https://huggingface.co/docs/transformers/model_doc/esm) using Apple's [MLX](https://ml-explore.github.io/mlx/build/html/index.html) library. 4 | 5 | ## Backstory 6 | 7 | I've been learning about deep learning and neural nets over the last few months. The two best teachers in this space are [Jeremy Howard](https://twitter.com/jeremyphoward) and [Andrej Karpathy](https://twitter.com/karpathy). Both have an intuitive understanding of neural nets and an amazing capacity to simplify complex ideas for easy understanding. 
To get started, watch these lectures (1.5x speed recommended): 8 | - [Practical Deep Learning for Coders](https://course.fast.ai/) 9 | - [Neural Networks: Zero to Hero](https://www.youtube.com/watch?v=VMj-3S1tku0&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ) 10 | 11 | I have also been reading up on [TechBio](https://www.nfx.com/post/biotech-to-techbio), and came across Meta's research papers: 12 | - [Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences](https://www.pnas.org/doi/epdf/10.1073/pnas.2016239118) 13 | - [Evolutionary-scale prediction of atomic level protein structure 14 | with a language model](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v3.full.pdf) 15 | 16 | Like all of Meta's AI research, the architecture, source code and weights for these protein language models are [open-source](https://github.com/facebookresearch/esm). 17 | 18 | I'd been wanting to implement and train a neural net that is more than a toy example. So, I decided to reimplement ESM-1 from the research paper, but in MLX. ESM-1 is a fork of the [BERT](https://huggingface.co/docs/transformers/model_doc/bert) language model and uses the [masked language modeling](https://huggingface.co/docs/transformers/main/tasks/masked_language_modeling) objective. 19 | 20 | I used the ESM-1 PyTorch [implementation](https://github.com/facebookresearch/esm/blob/main/esm/model/esm1.py) and BERT MLX [implementation](https://github.com/ml-explore/mlx-examples/blob/main/bert/model.py) as a reference. I figured this will provide enough copy-pasta that I can do this quickly, but going from PyTorch to MLX will also expose me to some low-level concepts of neural nets. 21 | 22 | ## Hacking It 23 | 24 | I generally followed Karpathy's development workflow: 25 | - Data loading and tokenizing first. You should have a way to quickly get batches of input tensors to run your model on. Always set the seed to a constant so everything is reproducible. 26 | - Build the training loop with a noop model. Include any helpful logging and plotting you'll need so that, when you run the real model, you can diagnose bugs quickly. 27 | - Build the neural net one layer at a time. Typically, you want to start from the input embedding layer and go "deeper" into the neural net. At each layer, inspect the input and output tensor shapes to make sure the layer is doing what you expect it to do. 28 | - Use `.shape()` generously to debug dimensionality issues of tensors. Libraries like PyTorch have magical reshaping capabilities, which mostly just work out of the box. Sometimes though you'll have to test with a simple input tensor to make sure the reshaping is actually doing the right thing. 29 | 30 | Since I haven't really used notebooks much before, my development flow was in VS Code & iTerm. I also finally understood why people love GitHub Copilot. It is really fucking good when you're not an expert and need help with explaining code, concepts and debugging. Its knowledge of `mlx` is not great, but it knows `pytorch` really well and will provide helpful snippets in its answers. Converting from `pytorch` to `mlx` is fairly straightforward: 90% of the API matches exactly with `pytorch`, the remainder is (I think) JAX inspired. 31 | 32 | ## Trying It Out 33 | 34 | This project uses [Poetry](https://python-poetry.org/) to manage dependencies, so make sure to install it on your system first. Start by cloning the repository and installing all dependencies.
35 | 36 | ``` 37 | git clone git@github.com:usmanm/mlx-esm.git 38 | cd mlx-esm 39 | poetry install 40 | ``` 41 | 42 | ### Training 43 | 44 | You can now train your own ESM1 model. The training script will download the [UniParc](https://www.uniprot.org/help/uniparc) dataset. By default, the script will train on only the first 3 partitions for 100K iterations. You can use the `--num-iters` and `--dataset-partitions` CLI options to tweak these training parameters. You can also skip this step and just use the weights from my training run directly for inference. 45 | 46 | ``` 47 | ➜ poetry run cli train --weights-dir=./weights 48 | 📥 loading data 49 | 🚂 training: 100%|████████████████████████████████████████████████████████████| 100000/100000 [1:44:43<00:00, 15.91it/s, loss=0.2758] 50 | 🔍 validating: 100%|████████████████████████████████████████████████████████████| 10000/10000 [09:27<00:00, 17.63it/s, loss=0.2766] 51 | 💾 weights saved to ./weights/esm1-202402151405.npz 52 | ``` 53 | 54 | On my MacBook Air M2, training with the default parameters took about 1 hour and 41 minutes. The loss curve looks sensical, so I assume my model is working to some degree. 55 | 56 | Training Loss 57 | 58 | ### Inference 59 | 60 | There are two inference modes: 61 | - `generate`: This generates a new protein from scratch in an auto-regressive manner. You can specify `--length` to control the size of the protein. By default, a random length from the range `[32, 96)` will be picked. 62 | - `unmask`: This takes a masked protein sequence (some amino acids hidden with the `*` character) and replaces the masked tokens with amino acid predictions. 63 | 64 | ``` 65 | ➜ poetry run cli generate --weights-file=./weights/202402151405.npz 66 | 🌱 generating: 100%|████████████████████████████████████████████████████████████| 67/67 [00:00<00:00, 311.70it/s] 67 | 🌳 hello world: RTAAEVGGGHPAGPGGRAEPQVAFGAGDRPGGGRKPYGGGSVSPQAGVQVCTAIYGVTHGAWRLPDK 68 | 69 | ➜ poetry run cli unmask --weights-file=./weights/202402151405.npz --seq="MRAGRGGVPGSGGLRAPPPPLL***LAMLPAAAPRSPALAAAPAGPSVSLYLSEDEVRRLL" 70 | 🌱 generating: 100%|████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 170.82it/s] 71 | 🌳 hello world: MRAGRGGVPGSGGLRAPPPPLLAAALAMLPAAAPRSPALAAAPAGPSVSLYLSEDEVRRLL 72 | ``` 73 | 74 | Given my GPU-poor home infra, I only trained a small model (800K parameters) with ~1.5% of the UniParc dataset. 75 | 76 | I created a [FASTQ](https://knowledge.illumina.com/software/general/software-general-reference_material-list/000002211) file of 4 random proteins my model generated. 77 | 78 | ``` 79 | >SequenceOne 80 | AMDGMAGAGSTDAQAVAFVGEEAVAIALAVRAAIAARGA 81 | >SequenceTwo 82 | DMPVARGRNRSQTARGAQREIRQANSRAETGRVTIATERWAEASVDRSDEPADQEVQALRYAQQNVGWWLPSGSGAAQAGSRPAS 83 | >SequenceThree 84 | MKEVKERVPARSADDSLGVGVVEKIAAKARALEAKPRGAYHGIITVDTVTISTGLN 85 | >SequenceFour 86 | AMGIAAGLLERVAGDASYGGGVAVSQPWAIGGLAGTYERLASAVVRCTGEDEPLDVPIKRPRRRREVTEPRAAIPDIVQREREVRKRSEQQLGFRRALVTGTRVKGGTEFRLDCVGSEERIEVVGV 87 | ``` 88 | 89 | I ran these sequences through [AlphaFold](https://github.com/google-deepmind/alphafold) to predict their molecular structure. The structure comes out in `pdb` files, which I assume are named after the [Protein Data Bank](https://en.wikipedia.org/wiki/Protein_Data_Bank). Next I had to figure out how to render these 3D structures. I found [3Dmol.js](https://3dmol.csb.pitt.edu/), a free JavaScript library for visualizing molecular data which conveniently has [Python bindings](https://pypi.org/project/py3Dmol/) for notebooks.
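The gist of it is only a handful of lines. Below is a minimal sketch of what the repo's `notebooks/3dmol.ipynb` does, with `SequenceFour.pdb` standing in for whichever AlphaFold output you want to render:

```
import py3Dmol

# Read the AlphaFold-produced PDB file into a single string.
with open("SequenceFour.pdb") as f:
  pdb_data = f.read()

# Render the structure as a cartoon, colored along the chain.
view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(pdb_data)
view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()
```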
Using it is pretty straightforward, [here's](https://github.com/usmanm/mlx-esm/blob/main/notebooks/3dmol.ipynb) the full Jupyter notebook with the reference code I used. 90 | 91 | Lo and behold, here's how these sequences may look. 92 | 93 | SequenceOne 94 | SequenceTwo 95 | SequenceThree 96 | SequenceFour 97 | 98 | *Please note that these sequences are almost certainly not valid proteins. The model is too small and trained on very little data. Moreover, my implementation likely has some subtle bugs that I haven't discovered.* 99 | 100 | ## Takeaways 101 | 102 | Open-source ML feels like it's in a really dope spot. Thanks to Nvidia and Transformers, we now have both compute and an architecture that scales with compute. This has unlocked our ability to train really large neural nets. Meta's decision to open-source their AI work allows anyone to really start playing around with these models. 103 | 104 | Neural net architectures have a lego-block type feel. They're made of "modules" wrapped behind a common interface, making them composable. Composing modules together in code sometimes isn't straightforward though. I believe this is because they use a mathematical structure called a [tensor](https://en.wikipedia.org/wiki/Tensor) (think of it as an N-dimensional matrix) to talk to each other. I wish I'd taken some linear algebra courses in college. It would be nice to find a more programming-intuitive data structure instead of tensors. 105 | -------------------------------------------------------------------------------- /mlx_esm/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Union 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from mlx_esm.data import Tokenizer 8 | 9 | 10 | def count_parameters(params: Union[list, dict, mx.array]) -> int: 11 | if isinstance(params, mx.array): 12 | return params.size 13 | if isinstance(params, list): 14 | return sum([count_parameters(p) for p in params]) 15 | if isinstance(params, dict): 16 | return sum([count_parameters(p) for p in params.values()]) 17 | raise ValueError(f"unknown module type: {type(params)}") 18 | 19 | 20 | class Embedding(nn.Module): 21 | def __init__( 22 | self, 23 | vocab_size: int, 24 | embed_dims: int, 25 | scale: Optional[float] = None, 26 | pad_idx: Optional[int] = None, 27 | ): 28 | super(Embedding, self).__init__() 29 | 30 | self.embed_dims = embed_dims 31 | self.vocab_size = vocab_size 32 | self.scale = scale or math.sqrt(1 / embed_dims) 33 | self.pad_idx = pad_idx 34 | 35 | self.weight = mx.random.normal([vocab_size, embed_dims]) * self.scale 36 | 37 | # The entries at pad_idx do not contribute to the gradient, so 38 | # the embedding vector at pad_idx will default to all zeros. 39 | # 40 | # TODO: Unclear how to disable updating the embedding vector at pad_idx 41 | # during training. In PyTorch, this seems to be implemented in C-level code.
42 | # See: https://github.com/pytorch/pytorch/blob/b85568a54a9c60986235ad1e0cc5dffc71b9d5b1/aten/src/ATen/native/Embedding.cpp#L108 43 | if self.pad_idx is not None: 44 | self.weight[self.pad_idx] = 0 45 | 46 | def __call__(self, x: mx.array) -> mx.array: 47 | # x: (B x L) 48 | # W: (V x C) 49 | # y: (B x L x C) 50 | y = self.weight[x] 51 | 52 | return y 53 | 54 | def __repr__(self): 55 | args = f"vocab_size={self.vocab_size}, embed_dims={self.embed_dims}" 56 | if self.pad_idx is not None: 57 | args += f", pad_idx={self.pad_idx}" 58 | return f"Embedding({args})" 59 | 60 | 61 | # Transformers ditched recurrance in favor of self-attention. This helps with 62 | # parallelization and makes it faster to train on GPUs, but the model loses 63 | # the ability to understand the order of the sequence. To fix this, positional 64 | # encodings are added to the input embeddings. 65 | # 66 | # For a deep dive into position encodings, see: 67 | # https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ 68 | # https://github.com/facebookresearch/esm/blob/main/esm/modules.py#L260 69 | class SinusoidalPositionalEmbedding(nn.Module): 70 | def __init__(self, embed_dims: int, pad_idx: int): 71 | super(SinusoidalPositionalEmbedding, self).__init__() 72 | assert embed_dims % 2 == 0, "embed_dims must be even" 73 | 74 | self.embed_dims = embed_dims 75 | self.pad_idx = pad_idx 76 | self._embeddings = None 77 | 78 | def embeddings(self, max_pos: int): 79 | if self._embeddings is None or max_pos > self._embeddings.shape[0]: 80 | # Creates a series of values that represent the frequencies W_k for the sinusoidal functions 81 | # where each subsequent frequency is an exponential step smaller than the previous one. We 82 | # represent this as a row vector of size half_dim. 83 | half_dim = self.embed_dims // 2 84 | freqs = mx.exp(mx.arange(half_dim, dtype=mx.float32) * -(math.log(10000) / (half_dim - 1))) 85 | 86 | # Create a 2-D column-vector representing the position indices of shape (max_pos, 1) 87 | positions = mx.arange(max_pos, dtype=mx.float32)[..., None] 88 | 89 | # Create a 2-D matrix of shape (max_pos, half_dim). 90 | args = positions * freqs[None, ...] 91 | 92 | # Create a final 2-D matrix of shape (max_pos, embed_dim) by concatenating the 93 | # sin and cos of the scaled positions. 94 | embedding = mx.concatenate([mx.sin(args), mx.cos(args)], axis=-1) 95 | 96 | # No impact of padding token. 97 | embedding[0, :] = 0 98 | 99 | self._embeddings = embedding 100 | 101 | return self._embeddings 102 | 103 | def positions(self, x: mx.array) -> mx.array: 104 | mask = x != self.pad_idx 105 | # We add 1 because postition 0 is reserved for the padding token. 
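# A small worked example, relying on pad_idx being 0 for this tokenizer: for a
# batch x = [[5, 7, 0, 0]], mask is [[1, 1, 0, 0]] and the line below produces
# [[1, 2, 3, 4]], so the product returned is [[1, 2, 0, 0]] -- real tokens get
# positions starting at 1 and padded slots get position 0.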
106 | positions = mx.ones(x.shape, dtype=mx.int32) * (mx.arange(x.shape[1], dtype=mx.int32) + 1) 107 | return positions * mask 108 | 109 | def __call__(self, x: mx.array) -> mx.array: 110 | seq_len = x.shape[1] 111 | max_pos = seq_len + 1 112 | 113 | # (>=L, C) 114 | emb = self.embeddings(max_pos)[:max_pos, :] 115 | # (B, L) 116 | pos = self.positions(x) 117 | 118 | # (B, L, C) 119 | y = emb[pos] 120 | 121 | return y 122 | 123 | def __repr__(self): 124 | args = f"embed_dims={self.embed_dims}" 125 | if self.pad_idx is not None: 126 | args += f", pad_idx={self.pad_idx}" 127 | return f"SinusoidalPositionalEmbedding({args})" 128 | 129 | 130 | # https://github.com/facebookresearch/esm/blob/main/esm/modules.py#L44 131 | class LayerNorm(nn.Module): 132 | def __init__(self, embed_dims: int, eps=1e-12): 133 | """ 134 | Construct a layernorm layer in the TF style (eps inside the sqrt). 135 | """ 136 | super(LayerNorm, self).__init__() 137 | 138 | self.embed_dims = embed_dims 139 | self.eps = eps 140 | self.weight = mx.ones(embed_dims) 141 | self.bias = mx.zeros(embed_dims) 142 | 143 | def __call__(self, x: mx.array) -> mx.array: 144 | means = x.mean(-1, keepdims=True) 145 | variances = x.var(-1, keepdims=True) 146 | 147 | v = (x - means) / mx.sqrt(variances + self.eps) 148 | return (self.weight * v) + self.bias 149 | 150 | def __repr__(self): 151 | return f"LayerNorm(embed_dims={self.embed_dims})" 152 | 153 | 154 | class MultiHeadAttention(nn.Module): 155 | def __init__(self, embed_dims: int, num_heads: int, bias: bool = True, add_bias_kv: bool = True): 156 | super(MultiHeadAttention, self).__init__() 157 | 158 | assert embed_dims % num_heads == 0, "embed_dims must be divisible by num_heads" 159 | 160 | self.embed_dims = embed_dims 161 | self.num_heads = num_heads 162 | self.bias = bias 163 | 164 | # TODO: implement adding bias to the key and value projections 165 | self.add_bias_kv = add_bias_kv 166 | 167 | # We use the same dimensions for queries, keys & values. 168 | qdims = embed_dims 169 | kdims = embed_dims 170 | vdims = embed_dims 171 | 172 | self.k_proj = nn.Linear(kdims, embed_dims, bias=bias) 173 | self.v_proj = nn.Linear(vdims, embed_dims, bias=bias) 174 | self.q_proj = nn.Linear(qdims, embed_dims, bias=bias) 175 | self.out_proj = nn.Linear(embed_dims, embed_dims, bias=bias) 176 | 177 | def __call__(self, queries: mx.array, keys: mx.array, values: mx.array) -> mx.array: 178 | H = self.num_heads 179 | 180 | queries = self.q_proj(queries) 181 | B, L, C = queries.shape 182 | assert self.embed_dims == C, "queries has incorrect embed_dims" 183 | 184 | scale = math.sqrt(1.0 / C) 185 | queries = queries * scale 186 | 187 | keys = self.k_proj(keys) 188 | values = self.v_proj(values) 189 | 190 | _, S, _ = keys.shape 191 | K = C // H 192 | assert K * H == C, "embed_dims must be divisible by num_heads" 193 | 194 | # Reshape the queries, keys, and values so we can compute the attention 195 | # on all heads in parallel. 
196 | queries = queries.reshape(B, L, H, K).transpose([0, 2, 1, 3]) # (B, H, L, K) 197 | keys = keys.reshape(B, S, H, K).transpose([0, 2, 3, 1]) # (B, H, K, S) 198 | values = values.reshape(B, S, H, K).transpose([0, 2, 1, 3]) # (B, H, S, K) 199 | 200 | scores = queries @ keys # (B, H, L, S) 201 | scores = nn.softmax(scores, axis=-1) 202 | 203 | values_hat = scores @ values # (B, H, L, K) 204 | values_hat = values_hat.transpose([0, 2, 1, 3]) # (B, L, H, K) 205 | values_hat = values_hat.reshape(B, L, C) # (B, L, C) 206 | 207 | return self.out_proj(values_hat) 208 | 209 | def __repr__(self): 210 | args = f"embed_dims={self.embed_dims}, " 211 | args += f"num_heads={self.num_heads}, " 212 | args += f"bias={self.bias}" 213 | return f"MultiHeadAttention({args})" 214 | 215 | 216 | # The Transformer architecture is the driver of this latest wave of AI. 217 | # Its simple architecture makes it easy to parallelize and train on GPUs, 218 | # a big upgrade from the RNNs and LSTMs of the past. This has unlocked our 219 | # ability to train really large-scale language models, which have emergent properties 220 | # that are useful for a wide variety of tasks. Meta trained a large LM on 221 | # protein sequences, and with a few additional layers, it was able to achieve 222 | # state-of-the-art results on protein folding and contact prediction. 223 | # 224 | # For an in-depth look at the Transformer architecture, see: 225 | # https://nlp.seas.harvard.edu/annotated-transformer/ 226 | # https://jalammar.github.io/illustrated-transformer/ 227 | # https://github.com/facebookresearch/esm/blob/main/esm/modules.py#L84 228 | class TransformerLayer(nn.Module): 229 | def __init__( 230 | self, 231 | embed_dims: int, 232 | ffn_embed_dims: int, 233 | num_attn_heads: int, 234 | ): 235 | super(TransformerLayer, self).__init__() 236 | 237 | self.embed_dims = embed_dims 238 | self.ffn_embed_dims = ffn_embed_dims 239 | self.num_attn_heads = num_attn_heads 240 | 241 | self.self_attn_layer_norm = LayerNorm(self.embed_dims) 242 | self.self_attn = nn.MultiHeadAttention( 243 | self.embed_dims, 244 | self.num_attn_heads, 245 | bias=True, 246 | ) 247 | 248 | self.final_layer_norm = LayerNorm(self.embed_dims) 249 | self.fc1 = nn.Linear(self.embed_dims, self.ffn_embed_dims) 250 | self.fc2 = nn.Linear(self.ffn_embed_dims, self.embed_dims) 251 | 252 | def __call__(self, x: mx.array) -> mx.array: 253 | residual = x 254 | x = self.self_attn_layer_norm(x) 255 | x = self.self_attn(x, x, x) 256 | x = residual + x 257 | 258 | residual = x 259 | x = self.final_layer_norm(x) 260 | x = nn.gelu(self.fc1(x)) 261 | x = self.fc2(x) 262 | x = residual + x 263 | 264 | return x 265 | 266 | def __repr__(self): 267 | args = f"embed_dims={self.embed_dims}, " 268 | args += f"ffn_embed_dims={self.ffn_embed_dims}, " 269 | args += f"num_attn_heads={self.num_attn_heads}" 270 | return f"TransformerLayer({args})" 271 | 272 | 273 | class ESM1(nn.Module): 274 | # These defaults have been scaled down here.
The original defaults are: 275 | # num_layers: 33 276 | # embed_dims: 1280 277 | # ffn_embed_dims: 5120 278 | # num_attn_heads: 20 279 | # final_bias: True 280 | def __init__( 281 | self, 282 | tokenizer: Tokenizer, 283 | num_layers: int = 4, 284 | embed_dims: int = 64, 285 | ffn_embed_dims: int = 256, 286 | num_attn_heads: int = 4, 287 | final_bias: bool = True, 288 | ): 289 | super(ESM1, self).__init__() 290 | 291 | self.tokenizer = tokenizer 292 | self.pad_idx = tokenizer.pad_idx 293 | self.num_layers = num_layers 294 | self.embed_dims = embed_dims 295 | self.ffn_embed_dims = ffn_embed_dims 296 | self.num_attn_heads = num_attn_heads 297 | self.vocab_size = tokenizer.vocab_size 298 | self.final_bias = final_bias 299 | 300 | self.init_submodules() 301 | 302 | def init_submodules(self): 303 | self.embed_tokens = Embedding( 304 | self.vocab_size, 305 | self.embed_dims, 306 | pad_idx=self.pad_idx, 307 | scale=math.sqrt(self.embed_dims), 308 | ) 309 | self.embed_positions = SinusoidalPositionalEmbedding(self.embed_dims, self.pad_idx) 310 | 311 | self.transformer_layers = nn.Sequential( 312 | *[ 313 | TransformerLayer( 314 | self.embed_dims, 315 | self.ffn_embed_dims, 316 | self.num_attn_heads, 317 | ) 318 | for _ in range(self.num_layers) 319 | ] 320 | ) 321 | 322 | self.out = nn.Linear(self.embed_dims, self.vocab_size, bias=self.final_bias) 323 | 324 | def __call__(self, x: mx.array) -> mx.array: 325 | # (B, L) 326 | assert x.ndim == 2 327 | 328 | tok_embed = self.embed_tokens(x) 329 | pos_embed = self.embed_positions(x) 330 | assert tok_embed.shape == pos_embed.shape 331 | 332 | logits = tok_embed + pos_embed 333 | logits = self.transformer_layers(logits) 334 | logits = self.out(logits) 335 | 336 | assert x.shape == logits.shape[:2] and logits.shape[2] == self.vocab_size 337 | 338 | return logits 339 | 340 | def num_parameters(self): 341 | return count_parameters(self.parameters()) 342 | -------------------------------------------------------------------------------- /notebooks/train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "cbe7bba0-e7ee-4111-9908-26ca7d0666dc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# This pre-amble sets up auto-reloading of any on-disk modules we are hacking on.\n", 11 | "%load_ext autoreload\n", 12 | "%reload_ext autoreload\n", 13 | "%autoreload 2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "57980c98-e97c-487b-9c7d-819527919601", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# This pre-amble makes mlx_esm visible to this notebook.\n", 24 | "import os\n", 25 | "import sys\n", 26 | "module_path = os.path.abspath(os.path.join(\"..\"))\n", 27 | "sys.path.insert(0, module_path)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "id": "652688dc-2693-489f-a4df-2fd6b0ba32ee", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# This pre-amble makes matplotlib available in this notebook.\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 9, 45 | "id": "3a4ca955-4dca-479f-86f5-2844d66761ef", 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "📥 loading data\n" 53 | ] 54 | }, 55 | { 56 | "data": { 57 | "application/vnd.jupyter.widget-view+json": { 58 | "model_id": 
"5acbb36761b54012a84a361ccac8728f", 59 | "version_major": 2, 60 | "version_minor": 0 61 | }, 62 | "text/plain": [ 63 | "🚂 training: 0%| | 0/100000 [00:00" 89 | ] 90 | }, 91 | "metadata": {}, 92 | "output_type": "display_data" 93 | } 94 | ], 95 | "source": [ 96 | "t.plot_loss(\"train\", start_idx=100)" 97 | ] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3 (ipykernel)", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.11.7" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 5 121 | } 122 | --------------------------------------------------------------------------------