├── .circleci └── config.yml ├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGMENTS.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bert ├── README.md ├── convert.py ├── model.py ├── requirements.txt ├── test.py └── weights │ └── .gitignore ├── cifar ├── README.md ├── dataset.py ├── main.py ├── requirements.txt └── resnet.py ├── clip ├── .gitignore ├── README.md ├── assets │ ├── cat.jpeg │ └── dog.jpeg ├── clip.py ├── convert.py ├── hf_preproc.py ├── image_processor.py ├── linear_probe.py ├── model.py ├── requirements.txt ├── test.py └── tokenizer.py ├── cvae ├── .gitignore ├── README.md ├── assets │ ├── rec_mnist.png │ └── samples_mnist.png ├── dataset.py ├── main.py ├── requirements.txt └── vae.py ├── encodec ├── README.md ├── benchmarks │ ├── bench_mx.py │ └── bench_pt.py ├── convert.py ├── encodec.py ├── example.py ├── requirements.txt ├── test.py └── utils.py ├── flux ├── README.md ├── dreambooth.py ├── flux │ ├── __init__.py │ ├── autoencoder.py │ ├── clip.py │ ├── datasets.py │ ├── flux.py │ ├── layers.py │ ├── lora.py │ ├── model.py │ ├── sampler.py │ ├── t5.py │ ├── tokenizers.py │ ├── trainer.py │ └── utils.py ├── generate_interactive.py ├── requirements.txt ├── static │ ├── dog-r4-g8-1200-512x1024.png │ ├── dog-r4-g8-1200.png │ ├── dog6.png │ └── generated-mlx.png └── txt2image.py ├── gcn ├── .gitignore ├── README.md ├── datasets.py ├── gcn.py ├── main.py └── requirements.txt ├── llava ├── .gitignore ├── README.md ├── generate.py ├── language.py ├── llava.py ├── requirements.txt ├── test.py └── vision.py ├── llms ├── README.md ├── gguf_llm │ ├── README.md │ ├── generate.py │ ├── models.py │ ├── requirements.txt │ └── utils.py ├── llama │ ├── README.md │ ├── convert.py │ ├── llama.py │ ├── requirements.txt │ └── sample_prompt.txt ├── mistral │ ├── .gitignore │ ├── README.md │ ├── convert.py │ ├── mistral.py │ ├── requirements.txt │ └── test.py ├── mixtral │ ├── README.md │ ├── convert.py │ ├── mixtral.py │ ├── params.json │ └── requirements.txt └── speculative_decoding │ ├── README.md │ ├── convert.py │ ├── decoder.py │ ├── main.py │ ├── model.py │ └── requirements.txt ├── lora ├── .gitignore ├── README.md ├── convert.py ├── data │ ├── test.jsonl │ ├── train.jsonl │ ├── valid.jsonl │ └── wikisql.py ├── fuse.py ├── lora.py ├── models.py ├── requirements.txt └── utils.py ├── mnist ├── README.md ├── main.py ├── mnist.py └── requirements.txt ├── musicgen ├── README.md ├── benchmarks │ ├── bench_mx.py │ └── bench_pt.py ├── encodec.py ├── generate.py ├── musicgen.py ├── requirements.txt ├── t5.py └── utils.py ├── normalizing_flow ├── README.md ├── bijectors.py ├── distributions.py ├── flows.py ├── main.py ├── requirements.txt └── samples.png ├── segment_anything ├── README.md ├── convert.py ├── main.py ├── notebooks │ ├── automatic_mask_generator_example.ipynb │ ├── images │ │ ├── dog.jpg │ │ ├── groceries.jpg │ │ └── truck.jpg │ └── predictor_example.ipynb ├── requirements.txt └── segment_anything │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── predictor.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── transformer.py │ └── utils │ ├── __init__.py │ ├── amg.py │ └── transforms.py ├── speechcommands ├── README.md ├── kwt.py ├── main.py └── requirements.txt ├── stable_diffusion ├── README.md ├── generated-mlx.png ├── im2im.png ├── image2image.py ├── requirements.txt ├── stable_diffusion │ ├── __init__.py │ ├── clip.py │ ├── config.py │ ├── model_io.py │ ├── sampler.py │ ├── 
tokenizer.py │ ├── unet.py │ └── vae.py ├── still-life.png └── txt2image.py ├── t5 ├── .gitignore ├── README.md ├── hf_t5.py ├── requirements.txt └── t5.py ├── transformer_lm ├── README.md ├── datasets.py ├── main.py └── requirements.txt └── whisper ├── MANIFEST.in ├── README.md ├── benchmark.py ├── convert.py ├── mlx_whisper ├── __init__.py ├── _version.py ├── assets │ ├── download_alice.sh │ ├── gpt2.tiktoken │ ├── ls_test.flac │ ├── mel_filters.npz │ └── multilingual.tiktoken ├── audio.py ├── cli.py ├── decoding.py ├── load_models.py ├── requirements.txt ├── timing.py ├── tokenizer.py ├── torch_whisper.py ├── transcribe.py ├── whisper.py └── writers.py ├── setup.py └── test.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | apple: ml-explore/pr-approval@0.1.0 5 | 6 | jobs: 7 | linux_build_and_test: 8 | docker: 9 | - image: cimg/python:3.9 10 | 11 | steps: 12 | - checkout 13 | - run: 14 | name: Run style checks 15 | command: | 16 | pip install pre-commit 17 | pre-commit run --all 18 | if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi 19 | 20 | workflows: 21 | build_and_test: 22 | when: 23 | matches: 24 | pattern: "^(?!pull/)[-\\w]+$" 25 | value: << pipeline.git.branch >> 26 | jobs: 27 | - linux_build_and_test 28 | 29 | prb: 30 | when: 31 | matches: 32 | pattern: "^pull/\\d+(/head)?$" 33 | value: << pipeline.git.branch >> 34 | jobs: 35 | - hold: 36 | type: approval 37 | - apple/authenticate: 38 | context: pr-approval 39 | - linux_build_and_test: 40 | requires: [ hold ] 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Vim 10 | *.swp 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # IDE files 135 | .idea/ 136 | .vscode/ 137 | 138 | # .DS_Store files 139 | .DS_Store 140 |
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black-pre-commit-mirror 3 | rev: 25.1.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/pycqa/isort 7 | rev: 6.0.0 8 | hooks: 9 | - id: isort 10 | args: 11 | - --profile=black 12 |
-------------------------------------------------------------------------------- /ACKNOWLEDGMENTS.md: -------------------------------------------------------------------------------- 1 | # Individual Contributors 2 | 3 | If you wish to be acknowledged for your contributions, please list your name 4 | with a short description of your contribution(s) below. For example: 5 | 6 | - Jane Smith: Added the `foo` example. 7 | 8 | MLX Examples was developed with contributions from the following individuals: 9 | 10 | - Juarez Bochi: Added support for T5 models. 11 | - Sarthak Yadav: Added the `cifar` and `speechcommands` examples. 12 | - Shunta Saito: Added support for PLaMo models. 13 | - Gabrijel Boduljak: Implemented `CLIP`. 14 | - Markus Enzweiler: Added the `cvae` examples. 15 | - Prince Canuma: Helped add support for `Starcoder2` models. 16 | - Shiyu Li: Added the `Segment Anything Model`. 17 | - Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` architectures and support for `full-fine-tuning`. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mlx-examples 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | 1. Fork and submit pull requests to the repo. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. Every PR should have passing tests and at least one review. 11 | 4. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 12 | This should install hooks for running `black` and `clang-format` to ensure 13 | consistent style for C++ and Python code.
14 | 15 | You can also run the formatters manually as follows on individual files: 16 | 17 | ```bash 18 | clang-format -i file.cpp 19 | ``` 20 | 21 | ```bash 22 | black file.py 23 | ``` 24 | 25 | or, 26 | 27 | ```bash 28 | # single file 29 | pre-commit run --files file1.py 30 | 31 | # specific files 32 | pre-commit run --files file1.py file2.py 33 | ``` 34 | 35 | or run `pre-commit run --all-files` to check all files in the repo. 36 | 37 | ## Issues 38 | 39 | We use GitHub issues to track public bugs. Please ensure your description is 40 | clear and has sufficient instructions to be able to reproduce the issue. 41 | 42 | ## License 43 | 44 | By contributing to mlx-examples, you agree that your contributions will be licensed 45 | under the LICENSE file in the root directory of this source tree. 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLX Examples 2 | 3 | This repo contains a variety of standalone examples using the [MLX 4 | framework](https://github.com/ml-explore/mlx). 5 | 6 | The [MNIST](mnist) example is a good starting point to learn how to use MLX. 7 | Some more useful examples are listed below. Check-out [MLX 8 | LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python 9 | package for LLMs with MLX. 10 | 11 | ### Text Models 12 | 13 | - [Transformer language model](transformer_lm) training. 14 | - Minimal examples of large scale text generation with [LLaMA](llms/llama), 15 | [Mistral](llms/mistral), and more in the [LLMs](llms) directory. 16 | - A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral). 17 | - Parameter efficient fine-tuning with [LoRA or QLoRA](lora). 18 | - Text-to-text multi-task Transformers with [T5](t5). 19 | - Bidirectional language understanding with [BERT](bert). 20 | 21 | ### Image Models 22 | 23 | - Generating images 24 | - [FLUX](flux) 25 | - [Stable Diffusion or SDXL](stable_diffusion) 26 | - Image classification using [ResNets on CIFAR-10](cifar). 27 | - Convolutional variational autoencoder [(CVAE) on MNIST](cvae). 
28 | 29 | ### Audio Models 30 | 31 | - Speech recognition with [OpenAI's Whisper](whisper). 32 | - Audio compression and generation with [Meta's EnCodec](encodec). 33 | - Music generation with [Meta's MusicGen](musicgen). 34 | 35 | ### Multimodal models 36 | 37 | - Joint text and image embeddings with [CLIP](clip). 38 | - Text generation from image and text inputs with [LLaVA](llava). 39 | - Image segmentation with [Segment Anything (SAM)](segment_anything). 40 | 41 | ### Other Models 42 | 43 | - Semi-supervised learning on graph-structured data with [GCN](gcn). 44 | - Real NVP [normalizing flow](normalizing_flow) for density estimation and 45 | sampling. 46 | 47 | ### Hugging Face 48 | 49 | You can directly use or download converted checkpoints from the [MLX 50 | Community](https://huggingface.co/mlx-community) organization on Hugging Face. 51 | We encourage you to join the community and [contribute new 52 | models](https://github.com/ml-explore/mlx-examples/issues/155). 53 | 54 | ## Contributing 55 | 56 | We are grateful for all of [our 57 | contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute 58 | to MLX Examples and wish to be acknowledged, please add your name to the list in your 59 | pull request. 60 | 61 | ## Citing MLX Examples 62 | 63 | The MLX software suite was initially developed with equal contribution by Awni 64 | Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find 65 | MLX Examples useful in your research and wish to cite it, please use the following 66 | BibTex entry: 67 | 68 | ``` 69 | @software{mlx2023, 70 | author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert}, 71 | title = {{MLX}: Efficient and flexible machine learning on Apple silicon}, 72 | url = {https://github.com/ml-explore}, 73 | version = {0.0}, 74 | year = {2023}, 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /bert/README.md: -------------------------------------------------------------------------------- 1 | # BERT 2 | 3 | An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX. 4 | 5 | ## Setup 6 | 7 | Install the requirements: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Then convert the weights with: 14 | 15 | ``` 16 | python convert.py \ 17 | --bert-model bert-base-uncased \ 18 | --mlx-model weights/bert-base-uncased.npz 19 | ``` 20 | 21 | ## Usage 22 | 23 | To use the `Bert` model in your own code, you can load it with: 24 | 25 | ```python 26 | import mlx.core as mx 27 | from model import Bert, load_model 28 | 29 | model, tokenizer = load_model( 30 | "bert-base-uncased", 31 | "weights/bert-base-uncased.npz") 32 | 33 | batch = ["This is an example of BERT working on MLX."] 34 | tokens = tokenizer(batch, return_tensors="np", padding=True) 35 | tokens = {key: mx.array(v) for key, v in tokens.items()} 36 | 37 | output, pooled = model(**tokens) 38 | ``` 39 | 40 | The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector 41 | for every input token. If you want to train anything at the **token-level**, 42 | use this. 43 | 44 | The `pooled` contains a `Batch x Dims` tensor, which is the pooled 45 | representation for each input. If you want to train a **classification** 46 | model, use this. 
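For instance, a minimal classification head on top of the pooled output could look like the following sketch. It assumes the `pooled` array from the snippet above; the number of classes and the labels are placeholders, not part of this example:

```python
import mlx.core as mx
import mlx.nn as nn

num_classes = 2                        # placeholder label set
head = nn.Linear(pooled.shape[-1], num_classes)

logits = head(pooled)                  # Batch x num_classes
labels = mx.array([0])                 # placeholder labels, one per input
loss = nn.losses.cross_entropy(logits, labels).mean()
```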
47 | 48 | 49 | ## Test 50 | 51 | You can check the output for the default model (`bert-base-uncased`) matches the 52 | Hugging Face version with: 53 | 54 | ``` 55 | python test.py 56 | ``` 57 | -------------------------------------------------------------------------------- /bert/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy 4 | from transformers import AutoModel 5 | 6 | 7 | def replace_key(key: str) -> str: 8 | key = key.replace(".layer.", ".layers.") 9 | key = key.replace(".self.key.", ".key_proj.") 10 | key = key.replace(".self.query.", ".query_proj.") 11 | key = key.replace(".self.value.", ".value_proj.") 12 | key = key.replace(".attention.output.dense.", ".attention.out_proj.") 13 | key = key.replace(".attention.output.LayerNorm.", ".ln1.") 14 | key = key.replace(".output.LayerNorm.", ".ln2.") 15 | key = key.replace(".intermediate.dense.", ".linear1.") 16 | key = key.replace(".output.dense.", ".linear2.") 17 | key = key.replace(".LayerNorm.", ".norm.") 18 | key = key.replace("pooler.dense.", "pooler.") 19 | return key 20 | 21 | 22 | def convert(bert_model: str, mlx_model: str) -> None: 23 | model = AutoModel.from_pretrained(bert_model) 24 | # save the tensors 25 | tensors = { 26 | replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items() 27 | } 28 | numpy.savez(mlx_model, **tensors) 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.") 33 | parser.add_argument( 34 | "--bert-model", 35 | type=str, 36 | default="bert-base-uncased", 37 | help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.", 38 | ) 39 | parser.add_argument( 40 | "--mlx-model", 41 | type=str, 42 | default="weights/bert-base-uncased.npz", 43 | help="The output path for the MLX BERT weights.", 44 | ) 45 | args = parser.parse_args() 46 | 47 | convert(args.bert_model, args.mlx_model) 48 | -------------------------------------------------------------------------------- /bert/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.0.5 2 | transformers 3 | numpy 4 | -------------------------------------------------------------------------------- /bert/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List 3 | 4 | import model 5 | import numpy as np 6 | from transformers import AutoModel, AutoTokenizer 7 | 8 | 9 | def run_torch(bert_model: str, batch: List[str]): 10 | tokenizer = AutoTokenizer.from_pretrained(bert_model) 11 | torch_model = AutoModel.from_pretrained(bert_model) 12 | torch_tokens = tokenizer(batch, return_tensors="pt", padding=True) 13 | torch_forward = torch_model(**torch_tokens) 14 | torch_output = torch_forward.last_hidden_state.detach().numpy() 15 | torch_pooled = torch_forward.pooler_output.detach().numpy() 16 | return torch_output, torch_pooled 17 | 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser( 21 | description="Run a BERT-like model for a batch of text." 
22 | ) 23 | parser.add_argument( 24 | "--bert-model", 25 | type=str, 26 | default="bert-base-uncased", 27 | help="The model identifier for a BERT-like model from Hugging Face Transformers.", 28 | ) 29 | parser.add_argument( 30 | "--mlx-model", 31 | type=str, 32 | default="weights/bert-base-uncased.npz", 33 | help="The path of the stored MLX BERT weights (npz file).", 34 | ) 35 | parser.add_argument( 36 | "--text", 37 | nargs="+", 38 | default=["This is an example of BERT working in MLX."], 39 | help="A batch of texts to process. Multiple texts should be separated by spaces.", 40 | ) 41 | 42 | args = parser.parse_args() 43 | 44 | torch_output, torch_pooled = run_torch(args.bert_model, args.text) 45 | 46 | mlx_output, mlx_pooled = model.run(args.bert_model, args.mlx_model, args.text) 47 | 48 | if torch_pooled is not None and mlx_pooled is not None: 49 | assert np.allclose( 50 | torch_output, mlx_output, rtol=1e-4, atol=1e-5 51 | ), "Model output is different" 52 | assert np.allclose( 53 | torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-5 54 | ), "Model pooled output is different" 55 | print("Tests pass :)") 56 | else: 57 | print("Pooled outputs were not compared due to one or both being None.") 58 | -------------------------------------------------------------------------------- /bert/weights/.gitignore: -------------------------------------------------------------------------------- 1 | *.npz -------------------------------------------------------------------------------- /cifar/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR and ResNets 2 | 3 | An example of training a ResNet on CIFAR-10 with MLX. Several ResNet 4 | configurations in accordance with the original 5 | [paper](https://arxiv.org/abs/1512.03385) are available. The example also 6 | illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to 7 | load the dataset. 8 | 9 | ## Pre-requisites 10 | 11 | Install the dependencies: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Running the example 18 | 19 | Run the example with: 20 | 21 | ``` 22 | python main.py 23 | ``` 24 | 25 | By default the example runs on the GPU. To run on the CPU, use: 26 | 27 | ``` 28 | python main.py --cpu 29 | ``` 30 | 31 | For all available options, run: 32 | 33 | ``` 34 | python main.py --help 35 | ``` 36 | 37 | ## Results 38 | 39 | After training with the default `resnet20` architecture for 30 epochs, you 40 | should see the following results: 41 | 42 | ``` 43 | Epoch: 29 | avg. Train loss 0.294 | avg. Train acc 0.897 | Throughput: 270.81 images/sec 44 | Epoch: 29 | Test acc 0.841 45 | ``` 46 | 47 | Note this was run on an M1 Macbook Pro with 16GB RAM. 48 | 49 | At the time of writing, `mlx` doesn't have built-in learning rate schedules. 50 | We intend to update this example once these features are added. 51 | 52 | ## Distributed training 53 | 54 | The example also supports distributed data parallel training. 
You can launch a 55 | distributed training as follows: 56 | 57 | ```shell 58 | $ cat >hostfile.json 59 | [ 60 | {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}, 61 | {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]} 62 | ] 63 | $ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20 64 | ``` 65 | -------------------------------------------------------------------------------- /cifar/dataset.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | from mlx.data.datasets import load_cifar10 4 | 5 | 6 | def get_cifar10(batch_size, root=None): 7 | tr = load_cifar10(root=root) 8 | 9 | mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) 10 | std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) 11 | 12 | def normalize(x): 13 | x = x.astype("float32") / 255.0 14 | return (x - mean) / std 15 | 16 | group = mx.distributed.init() 17 | 18 | tr_iter = ( 19 | tr.shuffle() 20 | .partition_if(group.size() > 1, group.size(), group.rank()) 21 | .to_stream() 22 | .image_random_h_flip("image", prob=0.5) 23 | .pad("image", 0, 4, 4, 0.0) 24 | .pad("image", 1, 4, 4, 0.0) 25 | .image_random_crop("image", 32, 32) 26 | .key_transform("image", normalize) 27 | .batch(batch_size) 28 | .prefetch(4, 4) 29 | ) 30 | 31 | test = load_cifar10(root=root, train=False) 32 | test_iter = ( 33 | test.to_stream() 34 | .partition_if(group.size() > 1, group.size(), group.rank()) 35 | .key_transform("image", normalize) 36 | .batch(batch_size) 37 | ) 38 | 39 | return tr_iter, test_iter 40 | -------------------------------------------------------------------------------- /cifar/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.2 2 | mlx-data 3 | numpy 4 | -------------------------------------------------------------------------------- /cifar/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385]. 3 | Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202. 4 | """ 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx.utils import tree_flatten 9 | 10 | __all__ = [ 11 | "ResNet", 12 | "resnet20", 13 | "resnet32", 14 | "resnet44", 15 | "resnet56", 16 | "resnet110", 17 | "resnet1202", 18 | ] 19 | 20 | 21 | class ShortcutA(nn.Module): 22 | def __init__(self, dims): 23 | super().__init__() 24 | self.dims = dims 25 | 26 | def __call__(self, x): 27 | return mx.pad( 28 | x[:, ::2, ::2, :], 29 | pad_width=[(0, 0), (0, 0), (0, 0), (self.dims // 4, self.dims // 4)], 30 | ) 31 | 32 | 33 | class Block(nn.Module): 34 | """ 35 | Implements a ResNet block with two convolutional layers and a skip connection. 36 | As per the paper, CIFAR-10 uses Shortcut type-A skip connections. 
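    Type-A shortcuts (see ShortcutA above) subsample the identity path with stride 2
    and zero-pad the extra channels, so the skip connection adds no learnable parameters.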
(See paper for details) 37 | """ 38 | 39 | def __init__(self, in_dims, dims, stride=1): 40 | super().__init__() 41 | 42 | self.conv1 = nn.Conv2d( 43 | in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False 44 | ) 45 | self.bn1 = nn.BatchNorm(dims) 46 | 47 | self.conv2 = nn.Conv2d( 48 | dims, dims, kernel_size=3, stride=1, padding=1, bias=False 49 | ) 50 | self.bn2 = nn.BatchNorm(dims) 51 | 52 | if stride != 1: 53 | self.shortcut = ShortcutA(dims) 54 | else: 55 | self.shortcut = None 56 | 57 | def __call__(self, x): 58 | out = nn.relu(self.bn1(self.conv1(x))) 59 | out = self.bn2(self.conv2(out)) 60 | if self.shortcut is None: 61 | out += x 62 | else: 63 | out += self.shortcut(x) 64 | out = nn.relu(out) 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | """ 70 | Creates a ResNet model for CIFAR-10, as specified in the original paper. 71 | """ 72 | 73 | def __init__(self, block, num_blocks, num_classes=10): 74 | super().__init__() 75 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm(16) 77 | 78 | self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1) 79 | self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2) 80 | self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2) 81 | 82 | self.linear = nn.Linear(64, num_classes) 83 | 84 | def _make_layer(self, block, in_dims, dims, num_blocks, stride): 85 | strides = [stride] + [1] * (num_blocks - 1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(in_dims, dims, stride)) 89 | in_dims = dims 90 | return nn.Sequential(*layers) 91 | 92 | def num_params(self): 93 | nparams = sum(x.size for k, x in tree_flatten(self.parameters())) 94 | return nparams 95 | 96 | def __call__(self, x): 97 | x = nn.relu(self.bn1(self.conv1(x))) 98 | x = self.layer1(x) 99 | x = self.layer2(x) 100 | x = self.layer3(x) 101 | x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1) 102 | x = self.linear(x) 103 | return x 104 | 105 | 106 | def resnet20(**kwargs): 107 | return ResNet(Block, [3, 3, 3], **kwargs) 108 | 109 | 110 | def resnet32(**kwargs): 111 | return ResNet(Block, [5, 5, 5], **kwargs) 112 | 113 | 114 | def resnet44(**kwargs): 115 | return ResNet(Block, [7, 7, 7], **kwargs) 116 | 117 | 118 | def resnet56(**kwargs): 119 | return ResNet(Block, [9, 9, 9], **kwargs) 120 | 121 | 122 | def resnet110(**kwargs): 123 | return ResNet(Block, [18, 18, 18], **kwargs) 124 | 125 | 126 | def resnet1202(**kwargs): 127 | return ResNet(Block, [200, 200, 200], **kwargs) 128 | -------------------------------------------------------------------------------- /clip/.gitignore: -------------------------------------------------------------------------------- 1 | mlx_model/ 2 | -------------------------------------------------------------------------------- /clip/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | An example of OpenAI's CLIP in MLX. The CLIP (contrastive language-image 4 | pre-training) model embeds images and text in the same space.[^1] 5 | 6 | ### Setup 7 | 8 | Install the dependencies: 9 | 10 | ```shell 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | Next, download a CLIP model from Hugging Face and convert it to MLX. The 15 | default model is 16 | [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32). 
17 | 18 | ``` 19 | python convert.py 20 | ``` 21 | 22 | The script will by default download the model and configuration files to the 23 | directory ``mlx_model/``. 24 | 25 | ### Run 26 | 27 | You can use the CLIP model to embed images and text. 28 | 29 | ```python 30 | from PIL import Image 31 | import clip 32 | 33 | model, tokenizer, img_processor = clip.load("mlx_model") 34 | inputs = { 35 | "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]), 36 | "pixel_values": img_processor( 37 | [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")] 38 | ), 39 | } 40 | output = model(**inputs) 41 | 42 | # Get text and image embeddings: 43 | text_embeds = output.text_embeds 44 | image_embeds = output.image_embeds 45 | ``` 46 | 47 | Run the above example with `python clip.py`. 48 | 49 | To embed only images or only the text, pass only the ``input_ids`` or 50 | ``pixel_values``, respectively. 51 | 52 | This example re-implements minimal image preprocessing and tokenization to reduce 53 | dependencies. For additional preprocessing functionality, you can use 54 | ``transformers``. The file `hf_preproc.py` has an example. 55 | 56 | MLX CLIP has been tested and works with the following Hugging Face repos: 57 | 58 | - [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) 59 | - [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 60 | 61 | You can run the tests with: 62 | 63 | ```shell 64 | python test.py 65 | ``` 66 | 67 | To test new models, update the `MLX_PATH` and `HF_PATH` in `test.py`. 68 | 69 | ### Attribution 70 | 71 | - `assets/cat.jpeg` is a "Cat" by London's, licensed under CC BY-SA 2.0. 72 | - `assets/dog.jpeg` is a "Happy Dog" by tedmurphy, licensed under CC BY 2.0. 73 | 74 | [^1]: Refer to the original paper [Learning Transferable Visual Models From 75 | Natural Language Supervision ](https://arxiv.org/abs/2103.00020) or [blog 76 | post](https://openai.com/research/clip) 77 | -------------------------------------------------------------------------------- /clip/assets/cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/clip/assets/cat.jpeg -------------------------------------------------------------------------------- /clip/assets/dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/clip/assets/dog.jpeg -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from image_processor import CLIPImageProcessor 4 | from model import CLIPModel 5 | from tokenizer import CLIPTokenizer 6 | 7 | 8 | def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]: 9 | model = CLIPModel.from_pretrained(model_dir) 10 | tokenizer = CLIPTokenizer.from_pretrained(model_dir) 11 | img_processor = CLIPImageProcessor.from_pretrained(model_dir) 12 | return model, tokenizer, img_processor 13 | 14 | 15 | if __name__ == "__main__": 16 | from PIL import Image 17 | 18 | model, tokenizer, img_processor = load("mlx_model") 19 | inputs = { 20 | "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]), 21 | "pixel_values": img_processor( 22 | [Image.open("assets/cat.jpeg"), 
Image.open("assets/dog.jpeg")] 23 | ), 24 | } 25 | output = model(**inputs) 26 | 27 | # Get text and image embeddings: 28 | text_embeds = output.text_embeds 29 | image_embeds = output.image_embeds 30 | print("Text embeddings shape:", text_embeds.shape) 31 | print("Image embeddings shape:", image_embeds.shape) 32 | -------------------------------------------------------------------------------- /clip/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import argparse 4 | import json 5 | import shutil 6 | from pathlib import Path 7 | from typing import Any, Dict, Union 8 | 9 | import mlx.core as mx 10 | import torch 11 | from huggingface_hub import snapshot_download 12 | 13 | 14 | def make_shards(weights: dict, max_file_size_gb: int = 5) -> list: 15 | max_file_size_bytes = max_file_size_gb << 30 16 | shards = [] 17 | shard, shard_size = {}, 0 18 | for k, v in weights.items(): 19 | if shard_size + v.nbytes > max_file_size_bytes: 20 | shards.append(shard) 21 | shard, shard_size = {}, 0 22 | shard[k] = v 23 | shard_size += v.nbytes 24 | shards.append(shard) 25 | return shards 26 | 27 | 28 | def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None: 29 | """Save model weights into specified directory.""" 30 | if isinstance(save_path, str): 31 | save_path = Path(save_path) 32 | save_path.mkdir(parents=True, exist_ok=True) 33 | 34 | shards = make_shards(weights) 35 | shards_count = len(shards) 36 | shard_file_format = ( 37 | "model-{:05d}-of-{:05d}.safetensors" 38 | if shards_count > 1 39 | else "model.safetensors" 40 | ) 41 | 42 | total_size = sum(v.nbytes for v in weights.values()) 43 | index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} 44 | 45 | for i, shard in enumerate(shards): 46 | shard_name = shard_file_format.format(i + 1, shards_count) 47 | shard_path = save_path / shard_name 48 | 49 | mx.save_safetensors(str(shard_path), shard) 50 | 51 | for weight_name in shard.keys(): 52 | index_data["weight_map"][weight_name] = shard_name 53 | 54 | index_data["weight_map"] = { 55 | k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) 56 | } 57 | 58 | with open(save_path / "model.safetensors.index.json", "w") as f: 59 | json.dump( 60 | index_data, 61 | f, 62 | indent=4, 63 | ) 64 | 65 | 66 | def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path: 67 | model_path = Path(path_or_hf_repo) 68 | if not model_path.exists(): 69 | model_path = Path( 70 | snapshot_download( 71 | repo_id=path_or_hf_repo, 72 | allow_patterns=[ 73 | "*.bin", 74 | "*.json", 75 | "*.txt", 76 | ], 77 | force_download=force_download, 78 | ) 79 | ) 80 | return model_path 81 | 82 | 83 | def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array: 84 | # bfloat16 is not numpy convertible. 
Upcast to float32 to avoid precision loss 85 | a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype)) 86 | return mx.array(a.numpy(), getattr(mx, dtype)) 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = argparse.ArgumentParser( 91 | description="Download and Convert (OpenAI) CLIP weights to MLX" 92 | ) 93 | parser.add_argument( 94 | "--hf-repo", 95 | type=str, 96 | default="openai/clip-vit-base-patch32", 97 | help="Hugging Face repository name.", 98 | ) 99 | parser.add_argument( 100 | "--mlx-path", 101 | type=str, 102 | default="mlx_model", 103 | help="Path to save the MLX model.", 104 | ) 105 | parser.add_argument( 106 | "--dtype", 107 | help="The data type to save the converted model.", 108 | type=str, 109 | default="float32", 110 | ) 111 | parser.add_argument( 112 | "-f", 113 | "--force-download", 114 | help="Force download the model from Hugging Face.", 115 | action="store_true", 116 | ) 117 | args = parser.parse_args() 118 | 119 | torch_path = get_model_path(args.hf_repo, args.force_download) 120 | mlx_path = Path(args.mlx_path) 121 | mlx_path.mkdir(parents=True, exist_ok=True) 122 | 123 | print("[INFO] Loading") 124 | torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True) 125 | print("[INFO] Converting") 126 | mlx_weights = { 127 | k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items() 128 | } 129 | print("[INFO] Saving") 130 | save_weights(mlx_path, mlx_weights) 131 | for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]: 132 | shutil.copyfile( 133 | str(torch_path / f"{fn}"), 134 | str(mlx_path / f"{fn}"), 135 | ) 136 | -------------------------------------------------------------------------------- /clip/hf_preproc.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import transformers 3 | from PIL import Image 4 | 5 | import clip 6 | 7 | hf_model = "openai/clip-vit-base-patch32" 8 | mlx_model = "mlx_model" 9 | 10 | model, *_ = clip.load(mlx_model) 11 | processor = transformers.CLIPProcessor.from_pretrained(hf_model) 12 | 13 | inputs = processor( 14 | text=["a photo of a cat", "a photo of a dog"], 15 | images=[Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")], 16 | return_tensors="np", 17 | ) 18 | 19 | out = model( 20 | input_ids=mx.array(inputs.input_ids), 21 | pixel_values=mx.array(inputs.pixel_values).transpose((0, 2, 3, 1)), 22 | return_loss=True, 23 | ) 24 | 25 | print("text embeddings:") 26 | print(out.text_embeds) 27 | print("image embeddings:") 28 | print(out.image_embeds) 29 | print(f"CLIP loss: {out.loss.item():.3f}") 30 | -------------------------------------------------------------------------------- /clip/image_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import json 4 | from pathlib import Path 5 | from typing import List, Tuple 6 | 7 | import mlx.core as mx 8 | import numpy as np 9 | from PIL.Image import Image 10 | 11 | 12 | class CLIPImageProcessor: 13 | """ 14 | A simple port of 15 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py. 
16 | """ 17 | 18 | def __init__( 19 | self, 20 | crop_size: int = 224, 21 | do_center_crop: bool = True, 22 | do_normalize: bool = True, 23 | do_resize: bool = True, 24 | image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], 25 | image_std: List[float] = [0.26862954, 0.26130258, 0.27577711], 26 | size: int = 224, 27 | **kwargs 28 | ) -> None: 29 | self.crop_size = crop_size 30 | self.do_center_crop = do_center_crop 31 | self.do_normalize = do_normalize 32 | self.do_resize = do_resize 33 | self.image_mean = mx.array(image_mean) 34 | self.image_std = mx.array(image_std) 35 | self.size = size 36 | 37 | def __call__(self, images: List[Image]) -> mx.array: 38 | return mx.concatenate( 39 | [self._preprocess(image)[None] for image in images], axis=0 40 | ) 41 | 42 | def _preprocess(self, image: Image) -> mx.array: 43 | if self.do_resize: 44 | image = resize(image, self.size) 45 | if self.do_center_crop: 46 | image = center_crop(image, (self.crop_size, self.crop_size)) 47 | image = mx.array(np.array(image)) 48 | image = rescale(image) 49 | if self.do_normalize: 50 | image = normalize(image, self.image_mean, self.image_std) 51 | return image 52 | 53 | @staticmethod 54 | def from_pretrained(path: str): 55 | path = Path(path) 56 | with open(path / "preprocessor_config.json", encoding="utf-8") as f: 57 | config = json.load(f) 58 | return CLIPImageProcessor(**config) 59 | 60 | 61 | def resize(image: Image, short_size: int) -> Image: 62 | """ 63 | Resize so small size to short_size 64 | """ 65 | width, height = image.size 66 | short = min(width, height) 67 | long = max(width, height) 68 | if short == short_size: 69 | return image 70 | new_short = short_size 71 | new_long = int(short_size * long / short) 72 | new_size = (new_short, new_long) if width <= height else (new_long, new_short) 73 | return image.resize(new_size) 74 | 75 | 76 | def center_crop(image: Image, size: Tuple[int, int]) -> Image: 77 | if size[0] % 2 != 0 or size[1] % 2 != 0: 78 | raise ValueError("Only even crop sizes supported.") 79 | original_width, original_height = image.size 80 | crop_height, crop_width = size 81 | top = (original_height - crop_height) // 2 82 | bottom = top + crop_height 83 | left = (original_width - crop_width) // 2 84 | right = left + crop_width 85 | return image.crop((left, top, right, bottom)) 86 | 87 | 88 | def rescale(image: mx.array) -> mx.array: 89 | return image.astype(mx.float32) * (1 / 255.0) 90 | 91 | 92 | def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array: 93 | return (image - mean) / std 94 | -------------------------------------------------------------------------------- /clip/linear_probe.py: -------------------------------------------------------------------------------- 1 | # Mirror of the Linear Probe Evaluation Script 2 | # from the official CLIP Repository. 
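# The probe extracts frozen CLIP image embeddings for CIFAR-10 (loaded with
# mlx-data) and fits a scikit-learn logistic regression classifier on top of them.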
3 | 4 | import mlx.core as mx 5 | import numpy as np 6 | from image_processor import CLIPImageProcessor 7 | from mlx.data.datasets import load_cifar10 8 | from model import CLIPModel 9 | from PIL import Image 10 | from sklearn.linear_model import LogisticRegression 11 | from tqdm import tqdm 12 | 13 | 14 | def get_cifar10(batch_size, root=None): 15 | tr = load_cifar10(root=root).batch(batch_size) 16 | test = load_cifar10(root=root, train=False).batch(batch_size) 17 | 18 | return tr, test 19 | 20 | 21 | def get_features(model, image_proc, iter): 22 | all_features = [] 23 | all_labels = [] 24 | 25 | for batch in tqdm(iter): 26 | image, label = batch["image"], batch["label"] 27 | x = image_proc([Image.fromarray(im) for im in image]) 28 | y = mx.array(label) 29 | 30 | image_embeds = model.get_image_features(x) 31 | mx.eval(image_embeds) 32 | 33 | all_features.append(image_embeds) 34 | all_labels.append(y) 35 | 36 | return mx.concatenate(all_features), mx.concatenate(all_labels) 37 | 38 | 39 | if __name__ == "__main__": 40 | model = CLIPModel.from_pretrained("mlx_model") 41 | image_proc = CLIPImageProcessor.from_pretrained("mlx_model") 42 | 43 | train_iter, test_iter = get_cifar10(batch_size=256) 44 | train_features, train_labels = get_features(model, image_proc, train_iter) 45 | test_features, test_labels = get_features(model, image_proc, test_iter) 46 | 47 | # Perform logistic regression 48 | # NOTE: The value of C should be determined via a hyperparameter sweep 49 | # using a validation split 50 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) 51 | classifier.fit(train_features, train_labels) 52 | 53 | # Evaluate using the logistic regression classifier 54 | predictions = classifier.predict(test_features) 55 | accuracy = (test_labels.squeeze() == predictions).mean().item() * 100 56 | print(f"Accuracy = {accuracy:.3f}") 57 | -------------------------------------------------------------------------------- /clip/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx 2 | mlx-data 3 | numpy 4 | transformers 5 | torch 6 | huggingface_hub 7 | Pillow 8 | -------------------------------------------------------------------------------- /clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
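# A minimal byte-pair-encoding (BPE) tokenizer for CLIP. It loads the
# vocab.json and merges.txt files copied over by convert.py and maps text to
# token-id arrays wrapped in <|startoftext|> / <|endoftext|> markers.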
2 | 3 | import json 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | import mlx.core as mx 8 | import regex 9 | 10 | 11 | class CLIPTokenizer: 12 | """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" 13 | 14 | def __init__(self, bpe_ranks, vocab): 15 | self.bpe_ranks = bpe_ranks 16 | self.vocab = vocab 17 | self.pat = regex.compile( 18 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 19 | regex.IGNORECASE, 20 | ) 21 | self._cache = {self.bos: self.bos, self.eos: self.eos} 22 | 23 | @property 24 | def bos(self): 25 | return "<|startoftext|>" 26 | 27 | @property 28 | def bos_token(self): 29 | return self.vocab[self.bos] 30 | 31 | @property 32 | def eos(self): 33 | return "<|endoftext|>" 34 | 35 | @property 36 | def eos_token(self): 37 | return self.vocab[self.eos] 38 | 39 | def bpe(self, text): 40 | if text in self._cache: 41 | return self._cache[text] 42 | 43 | unigrams = list(text[:-1]) + [text[-1] + ""] 44 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 45 | 46 | if not unique_bigrams: 47 | return unigrams 48 | 49 | # In every iteration try to merge the two most likely bigrams. If none 50 | # was merged we are done. 51 | # 52 | # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_py 53 | while unique_bigrams: 54 | bigram = min( 55 | unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) 56 | ) 57 | if bigram not in self.bpe_ranks: 58 | break 59 | 60 | new_unigrams = [] 61 | skip = False 62 | for a, b in zip(unigrams, unigrams[1:]): 63 | if skip: 64 | skip = False 65 | continue 66 | 67 | if (a, b) == bigram: 68 | new_unigrams.append(a + b) 69 | skip = True 70 | 71 | else: 72 | new_unigrams.append(a) 73 | 74 | if not skip: 75 | new_unigrams.append(b) 76 | 77 | unigrams = new_unigrams 78 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 79 | 80 | self._cache[text] = unigrams 81 | 82 | return unigrams 83 | 84 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 85 | return self.tokenize(*args, **kwargs) 86 | 87 | def tokenize(self, text, prepend_bos=True, append_eos=True) -> mx.array: 88 | if isinstance(text, list): 89 | return mx.array([self.tokenize(t, prepend_bos, append_eos) for t in text]) 90 | 91 | # Lower case, cleanup, and split. Hugging Face does a much, 92 | # more thorough job here but this should suffice for 95% of 93 | # cases. 
94 | clean_text = regex.sub(r"\s+", " ", text.lower()) 95 | tokens = regex.findall(self.pat, clean_text) 96 | 97 | # Split the tokens according to the byte-pair merge file 98 | bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] 99 | 100 | # Map to token ids and return 101 | tokens = [] 102 | if prepend_bos: 103 | tokens.append(self.bos_token) 104 | tokens.extend(self.vocab[t] for t in bpe_tokens) 105 | if append_eos: 106 | tokens.append(self.eos_token) 107 | return mx.array(tokens) 108 | 109 | @staticmethod 110 | def from_pretrained(path: str): 111 | path = Path(path) 112 | 113 | with open(path / "vocab.json", encoding="utf-8") as f: 114 | vocab = json.load(f) 115 | with open(path / "merges.txt", encoding="utf-8") as f: 116 | bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] 117 | 118 | bpe_merges = [tuple(m.split()) for m in bpe_merges] 119 | bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) 120 | 121 | return CLIPTokenizer(bpe_ranks, vocab) 122 | -------------------------------------------------------------------------------- /cvae/.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | -------------------------------------------------------------------------------- /cvae/README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Variational Autoencoder (CVAE) on MNIST 2 | 3 | Convolutional variational autoencoder (CVAE) implementation in MLX using 4 | MNIST.[^1] 5 | 6 | ## Setup 7 | 8 | Install the requirements: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Run 15 | 16 | 17 | To train a VAE run: 18 | 19 | ```shell 20 | python main.py 21 | ``` 22 | 23 | To see the supported options, do `python main.py -h`. 24 | 25 | Training with the default options should give: 26 | 27 | ```shell 28 | $ python train.py 29 | Options: 30 | Device: GPU 31 | Seed: 0 32 | Batch size: 128 33 | Max number of filters: 64 34 | Number of epochs: 50 35 | Learning rate: 0.001 36 | Number of latent dimensions: 8 37 | Number of trainable params: 0.1493 M 38 | Epoch 1 | Loss 14626.96 | Throughput 1803.44 im/s | Time 34.3 (s) 39 | Epoch 2 | Loss 10462.21 | Throughput 1802.20 im/s | Time 34.3 (s) 40 | ... 41 | Epoch 50 | Loss 8293.13 | Throughput 1804.91 im/s | Time 34.2 (s) 42 | ``` 43 | 44 | The throughput was measured on a 32GB M1 Max. 45 | 46 | Reconstructed and generated images will be saved after each epoch in the 47 | `models/` path. Below are examples of reconstructed training set images and 48 | generated images. 49 | 50 | #### Reconstruction 51 | 52 | ![MNIST Reconstructions](assets/rec_mnist.png) 53 | 54 | #### Generation 55 | 56 | ![MNIST Samples](assets/samples_mnist.png) 57 | 58 | 59 | ## Limitations 60 | 61 | At the time of writing, MLX does not have transposed 2D convolutions. The 62 | example approximates them with a combination of nearest neighbor upsampling and 63 | regular convolutions, similar to the original U-Net. We intend to update this 64 | example once transposed 2D convolutions are available. 65 | 66 | [^1]: For a good overview of VAEs see the original paper [Auto-Encoding 67 | Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to 68 | Variational Autoencoders](https://arxiv.org/abs/1906.02691). 
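As a rough sketch of the workaround described in the Limitations section above (illustrative only, not the exact module used in `vae.py`), a transposed convolution can be approximated by nearest-neighbor upsampling followed by a regular convolution on NHWC inputs:

```python
import mlx.core as mx
import mlx.nn as nn

class UpsampleConv2d(nn.Module):
    """Nearest-neighbor 2x upsampling followed by a regular convolution."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)

    def __call__(self, x):
        # Inputs are NHWC: repeat rows and columns to double the resolution.
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        return self.conv(x)

x = mx.random.normal((4, 16, 16, 8))
y = UpsampleConv2d(8, 4)(x)  # shape (4, 32, 32, 4)
```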
69 | -------------------------------------------------------------------------------- /cvae/assets/rec_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/cvae/assets/rec_mnist.png -------------------------------------------------------------------------------- /cvae/assets/samples_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/cvae/assets/samples_mnist.png -------------------------------------------------------------------------------- /cvae/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from mlx.data.datasets import load_mnist 4 | 5 | 6 | def mnist(batch_size, img_size, root=None): 7 | # load train and test sets using mlx-data 8 | load_fn = load_mnist 9 | tr = load_fn(root=root, train=True) 10 | test = load_fn(root=root, train=False) 11 | 12 | # number of image channels is 1 for MNIST 13 | num_img_channels = 1 14 | 15 | # normalize to [0,1] 16 | def normalize(x): 17 | return x.astype("float32") / 255.0 18 | 19 | # iterator over training set 20 | tr_iter = ( 21 | tr.shuffle() 22 | .to_stream() 23 | .image_resize("image", h=img_size[0], w=img_size[1]) 24 | .key_transform("image", normalize) 25 | .batch(batch_size) 26 | .prefetch(4, 4) 27 | ) 28 | 29 | # iterator over test set 30 | test_iter = ( 31 | test.to_stream() 32 | .image_resize("image", h=img_size[0], w=img_size[1]) 33 | .key_transform("image", normalize) 34 | .batch(batch_size) 35 | ) 36 | return tr_iter, test_iter 37 | 38 | 39 | if __name__ == "__main__": 40 | batch_size = 32 41 | img_size = (64, 64) # (H, W) 42 | 43 | tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size) 44 | 45 | B, H, W, C = batch_size, img_size[0], img_size[1], 1 46 | print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}") 47 | 48 | batch_tr_iter = next(tr_iter) 49 | assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size" 50 | assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size" 51 | 52 | batch_test_iter = next(test_iter) 53 | assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong training set size" 54 | assert batch_test_iter["label"].shape == (batch_size,), "Wrong training set size" 55 | -------------------------------------------------------------------------------- /cvae/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.2 2 | mlx-data 3 | numpy 4 | Pillow 5 | -------------------------------------------------------------------------------- /encodec/README.md: -------------------------------------------------------------------------------- 1 | # EnCodec 2 | 3 | An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and 4 | generate audio. 5 | 6 | ### Setup 7 | 8 | Install the requirements: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | Optionally install FFmpeg and SciPy for loading and saving audio files, 15 | respectively. 
16 | 17 | Install [FFmpeg](https://ffmpeg.org/): 18 | 19 | ``` 20 | # on macOS using Homebrew (https://brew.sh/) 21 | brew install ffmpeg 22 | ``` 23 | 24 | Install SciPy: 25 | 26 | ``` 27 | pip install scipy 28 | ``` 29 | 30 | ### Example 31 | 32 | An example using the model: 33 | 34 | ```python 35 | import mlx.core as mx 36 | from encodec import EncodecModel 37 | from utils import load_audio, save_audio 38 | 39 | # Load the 48 KHz model and preprocessor. 40 | model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") 41 | 42 | # Load an audio file 43 | audio = load_audio("path/to/audio", model.sampling_rate, model.channels) 44 | 45 | # Preprocess the audio (this can also be a list of arrays for batched 46 | # processing). 47 | feats, mask = processor(audio) 48 | 49 | # Encode at the given bandwidth. A lower bandwidth results in more 50 | # compression but lower reconstruction quality. 51 | @mx.compile 52 | def encode(feats, mask): 53 | return model.encode(feats, mask, bandwidth=3) 54 | 55 | # Decode to reconstruct the audio 56 | @mx.compile 57 | def decode(codes, scales, mask): 58 | return model.decode(codes, scales, mask) 59 | 60 | 61 | codes, scales = encode(feats, mask) 62 | reconstructed = decode(codes, scales, mask) 63 | 64 | # Trim any padding: 65 | reconstructed = reconstructed[0, : len(audio)] 66 | 67 | # Save the audio as a wave file 68 | save_audio("reconstructed.wav", reconstructed, model.sampling_rate) 69 | ``` 70 | 71 | The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the 72 | [Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164) 73 | in several data types. 74 | 75 | ### Optional 76 | 77 | To convert models, use the `convert.py` script. To see the options, run: 78 | 79 | ```bash 80 | python convert.py -h 81 | ``` 82 | 83 | [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and 84 | [code](https://github.com/facebookresearch/encodec) for more details. 85 | -------------------------------------------------------------------------------- /encodec/benchmarks/bench_mx.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import time 4 | 5 | import mlx.core as mx 6 | 7 | from encodec import EncodecModel 8 | 9 | model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") 10 | 11 | audio = mx.random.uniform(shape=(288000, 2)) 12 | feats, mask = processor(audio) 13 | mx.eval(model, feats, mask) 14 | 15 | 16 | @mx.compile 17 | def fun(): 18 | codes, scales = model.encode(feats, mask, bandwidth=3) 19 | reconstructed = model.decode(codes, scales, mask) 20 | return reconstructed 21 | 22 | 23 | for _ in range(5): 24 | mx.eval(fun()) 25 | 26 | tic = time.time() 27 | for _ in range(10): 28 | mx.eval(fun()) 29 | toc = time.time() 30 | ms = 1000 * (toc - tic) / 10 31 | print(f"Time per it: {ms:.3f}") 32 | -------------------------------------------------------------------------------- /encodec/benchmarks/bench_pt.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
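# PyTorch (MPS) counterpart of bench_mx.py: encode and decode a random 48 kHz
# stereo clip with facebook/encodec_48khz and report the average time per
# iteration.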
2 | 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from transformers import AutoProcessor, EncodecModel 8 | 9 | processor = AutoProcessor.from_pretrained("facebook/encodec_48khz") 10 | audio = np.random.uniform(size=(2, 288000)).astype(np.float32) 11 | 12 | pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps") 13 | pt_inputs = processor( 14 | raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt" 15 | ).to("mps") 16 | 17 | 18 | def fun(): 19 | pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"]) 20 | pt_audio = pt_model.decode( 21 | pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"] 22 | ) 23 | torch.mps.synchronize() 24 | 25 | 26 | for _ in range(5): 27 | fun() 28 | 29 | tic = time.time() 30 | for _ in range(10): 31 | fun() 32 | toc = time.time() 33 | ms = 1000 * (toc - tic) / 10 34 | print(f"Time per it: {ms:.3f}") 35 | -------------------------------------------------------------------------------- /encodec/example.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import mlx.core as mx 4 | from utils import load_audio, save_audio 5 | 6 | from encodec import EncodecModel 7 | 8 | # Load the 48 KHz model and preprocessor. 9 | model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") 10 | 11 | # Load an audio file 12 | audio = load_audio("/path/to/audio", model.sampling_rate, model.channels) 13 | 14 | # Preprocess the audio (this can also be a list of arrays for batched 15 | # processing). 16 | feats, mask = processor(audio) 17 | 18 | 19 | # Encode at the given bandwidth. A lower bandwidth results in more 20 | # compression but lower reconstruction quality. 21 | @mx.compile 22 | def encode(feats, mask): 23 | return model.encode(feats, mask, bandwidth=3) 24 | 25 | 26 | # Decode to reconstruct the audio 27 | @mx.compile 28 | def decode(codes, scales, mask): 29 | return model.decode(codes, scales, mask) 30 | 31 | 32 | codes, scales = encode(feats, mask) 33 | reconstructed = decode(codes, scales, mask) 34 | 35 | # Trim any padding: 36 | reconstructed = reconstructed[0, : len(audio)] 37 | 38 | # Save the audio as a wave file 39 | save_audio("reconstructed.wav", reconstructed, model.sampling_rate) 40 | -------------------------------------------------------------------------------- /encodec/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.18 2 | numpy 3 | huggingface_hub 4 | -------------------------------------------------------------------------------- /encodec/test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
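# Consistency checks comparing the MLX EnCodec port (preprocessor, encoder,
# and decoder outputs) against the Hugging Face transformers implementation
# on random inputs.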
2 | 3 | import mlx.core as mx 4 | import numpy as np 5 | import torch 6 | from transformers import AutoProcessor 7 | from transformers import EncodecModel as PTEncodecModel 8 | 9 | from encodec import EncodecModel, preprocess_audio 10 | 11 | 12 | def compare_processors(): 13 | np.random.seed(0) 14 | audio_length = 95500 15 | audio = np.random.uniform(size=(2, audio_length)).astype(np.float32) 16 | 17 | processor = AutoProcessor.from_pretrained("facebook/encodec_48khz") 18 | 19 | pt_inputs = processor( 20 | raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt" 21 | ) 22 | mx_inputs = preprocess_audio( 23 | mx.array(audio).T, 24 | processor.sampling_rate, 25 | processor.chunk_length, 26 | processor.chunk_stride, 27 | ) 28 | 29 | assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1)) 30 | assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1]) 31 | 32 | 33 | def compare_models(): 34 | pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz") 35 | mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") 36 | 37 | np.random.seed(0) 38 | audio_length = 190560 39 | audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32) 40 | mask = np.ones((1, audio_length), dtype=np.int32) 41 | pt_encoded = pt_model.encode( 42 | torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None] 43 | ) 44 | mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask)) 45 | pt_codes = pt_encoded.audio_codes.numpy() 46 | mx_codes = mx_encoded[0] 47 | assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch" 48 | 49 | for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales): 50 | if mx_scale is not None: 51 | pt_scale = pt_scale.numpy() 52 | assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4) 53 | 54 | pt_audio = pt_model.decode( 55 | pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None] 56 | ) 57 | pt_audio = pt_audio[0].squeeze().T.detach().numpy() 58 | mx_audio = mx_model.decode(*mx_encoded, mx.array(mask)) 59 | mx_audio = mx_audio.squeeze() 60 | assert np.allclose( 61 | pt_audio, mx_audio, atol=1e-4, rtol=1e-4 62 | ), "Decoding audio mismatch" 63 | 64 | 65 | if __name__ == "__main__": 66 | compare_processors() 67 | compare_models() 68 | -------------------------------------------------------------------------------- /encodec/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import mlx.core as mx 4 | import numpy as np 5 | 6 | 7 | def save_audio(file: str, audio: mx.array, sampling_rate: int): 8 | """ 9 | Save audio to a wave (.wav) file. 10 | """ 11 | from scipy.io.wavfile import write 12 | 13 | audio = (audio * 32767).astype(mx.int16) 14 | write(file, sampling_rate, np.array(audio)) 15 | 16 | 17 | def load_audio(file: str, sampling_rate: int, channels: int): 18 | """ 19 | Read audio into an mx.array, resampling if necessary. 20 | 21 | Args: 22 | file (str): The audio file to open. 23 | sampling_rate (int): The sample rate to resample the audio at if needed. 24 | channels (int): The number of audio channels. 25 | 26 | Returns: 27 | An mx.array containing the audio waveform in float32. 28 | """ 29 | from subprocess import CalledProcessError, run 30 | 31 | # This launches a subprocess to decode audio while down-mixing 32 | # and resampling as necessary. Requires the ffmpeg CLI in PATH. 
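    # ffmpeg writes interleaved signed 16-bit PCM to stdout, which is reshaped
    # to (num_frames, channels) and scaled to float32 in [-1, 1] below.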
33 | # fmt: off 34 | cmd = [ 35 | "ffmpeg", 36 | "-nostdin", 37 | "-threads", "0", 38 | "-i", file, 39 | "-f", "s16le", 40 | "-ac", str(channels), 41 | "-acodec", "pcm_s16le", 42 | "-ar", str(sampling_rate), 43 | "-" 44 | ] 45 | # fmt: on 46 | try: 47 | out = run(cmd, capture_output=True, check=True).stdout 48 | except CalledProcessError as e: 49 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 50 | 51 | out = mx.array(np.frombuffer(out, np.int16)) 52 | return out.reshape(-1, channels).astype(mx.float32) / 32767.0 53 | -------------------------------------------------------------------------------- /flux/flux/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | from .datasets import Dataset, load_dataset 4 | from .flux import FluxPipeline 5 | from .lora import LoRALinear 6 | from .sampler import FluxSampler 7 | from .trainer import Trainer 8 | from .utils import ( 9 | load_ae, 10 | load_clip, 11 | load_clip_tokenizer, 12 | load_flow_model, 13 | load_t5, 14 | load_t5_tokenizer, 15 | save_config, 16 | ) 17 | -------------------------------------------------------------------------------- /flux/flux/datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | from PIL import Image 5 | 6 | 7 | class Dataset: 8 | def __getitem__(self, index: int): 9 | raise NotImplementedError() 10 | 11 | def __len__(self): 12 | raise NotImplementedError() 13 | 14 | 15 | class LocalDataset(Dataset): 16 | prompt_key = "prompt" 17 | 18 | def __init__(self, dataset: str, data_file): 19 | self.dataset_base = Path(dataset) 20 | with open(data_file, "r") as fid: 21 | self._data = [json.loads(l) for l in fid] 22 | 23 | def __len__(self): 24 | return len(self._data) 25 | 26 | def __getitem__(self, index: int): 27 | item = self._data[index] 28 | image = Image.open(self.dataset_base / item["image"]) 29 | return image, item[self.prompt_key] 30 | 31 | 32 | class LegacyDataset(LocalDataset): 33 | prompt_key = "text" 34 | 35 | def __init__(self, dataset: str): 36 | self.dataset_base = Path(dataset) 37 | with open(self.dataset_base / "index.json") as f: 38 | self._data = json.load(f)["data"] 39 | 40 | 41 | class HuggingFaceDataset(Dataset): 42 | 43 | def __init__(self, dataset: str): 44 | from datasets import load_dataset as hf_load_dataset 45 | 46 | self._df = hf_load_dataset(dataset)["train"] 47 | 48 | def __len__(self): 49 | return len(self._df) 50 | 51 | def __getitem__(self, index: int): 52 | item = self._df[index] 53 | return item["image"], item["prompt"] 54 | 55 | 56 | def load_dataset(dataset: str): 57 | dataset_base = Path(dataset) 58 | data_file = dataset_base / "train.jsonl" 59 | legacy_file = dataset_base / "index.json" 60 | 61 | if data_file.exists(): 62 | print(f"Load the local dataset {data_file} .", flush=True) 63 | dataset = LocalDataset(dataset, data_file) 64 | elif legacy_file.exists(): 65 | print(f"Load the local dataset {legacy_file} .") 66 | print() 67 | print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") 68 | print(" See the README for details.") 69 | print(flush=True) 70 | dataset = LegacyDataset(dataset) 71 | else: 72 | print(f"Load the Hugging Face dataset {dataset} .", flush=True) 73 | dataset = HuggingFaceDataset(dataset) 74 | 75 | return dataset 76 | -------------------------------------------------------------------------------- /flux/flux/lora.py: 
-------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import math 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | class LoRALinear(nn.Module): 10 | @staticmethod 11 | def from_base( 12 | linear: nn.Linear, 13 | r: int = 8, 14 | dropout: float = 0.0, 15 | scale: float = 1.0, 16 | ): 17 | output_dims, input_dims = linear.weight.shape 18 | lora_lin = LoRALinear( 19 | input_dims=input_dims, 20 | output_dims=output_dims, 21 | r=r, 22 | dropout=dropout, 23 | scale=scale, 24 | ) 25 | lora_lin.linear = linear 26 | return lora_lin 27 | 28 | def fuse(self): 29 | linear = self.linear 30 | bias = "bias" in linear 31 | weight = linear.weight 32 | dtype = weight.dtype 33 | 34 | output_dims, input_dims = weight.shape 35 | fused_linear = nn.Linear(input_dims, output_dims, bias=bias) 36 | 37 | lora_b = self.scale * self.lora_b.T 38 | lora_a = self.lora_a.T 39 | fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype) 40 | if bias: 41 | fused_linear.bias = linear.bias 42 | 43 | return fused_linear 44 | 45 | def __init__( 46 | self, 47 | input_dims: int, 48 | output_dims: int, 49 | r: int = 8, 50 | dropout: float = 0.0, 51 | scale: float = 1.0, 52 | bias: bool = False, 53 | ): 54 | super().__init__() 55 | 56 | # Regular linear layer weights 57 | self.linear = nn.Linear(input_dims, output_dims, bias=bias) 58 | 59 | self.dropout = nn.Dropout(p=dropout) 60 | 61 | # Scale for low-rank update 62 | self.scale = scale 63 | 64 | # Low rank lora weights 65 | scale = 1 / math.sqrt(input_dims) 66 | self.lora_a = mx.random.uniform( 67 | low=-scale, 68 | high=scale, 69 | shape=(input_dims, r), 70 | ) 71 | self.lora_b = mx.zeros(shape=(r, output_dims)) 72 | 73 | def __call__(self, x): 74 | y = self.linear(x) 75 | z = (self.dropout(x) @ self.lora_a) @ self.lora_b 76 | return y + (self.scale * z).astype(x.dtype) 77 | -------------------------------------------------------------------------------- /flux/flux/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import math 4 | from functools import lru_cache 5 | 6 | import mlx.core as mx 7 | 8 | 9 | class FluxSampler: 10 | def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15): 11 | self._base_shift = base_shift 12 | self._max_shift = max_shift 13 | self._schnell = "schnell" in name 14 | 15 | def _time_shift(self, x, t): 16 | x1, x2 = 256, 4096 17 | t1, t2 = self._base_shift, self._max_shift 18 | exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1) 19 | t = exp_mu / (exp_mu + (1 / t - 1)) 20 | return t 21 | 22 | @lru_cache 23 | def timesteps( 24 | self, num_steps, image_sequence_length, start: float = 1, stop: float = 0 25 | ): 26 | t = mx.linspace(start, stop, num_steps + 1) 27 | 28 | if not self._schnell: 29 | t = self._time_shift(image_sequence_length, t) 30 | 31 | return t.tolist() 32 | 33 | def random_timesteps(self, B, L, dtype=mx.float32, key=None): 34 | if self._schnell: 35 | # TODO: Should we upweigh 1 and 0.75? 
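            # i.e. draw t uniformly from {0.25, 0.5, 0.75, 1.0}.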
36 | t = mx.random.randint(1, 5, shape=(B,), key=key) 37 | t = t.astype(dtype) / 4 38 | else: 39 | t = mx.random.uniform(shape=(B,), dtype=dtype, key=key) 40 | t = self._time_shift(L, t) 41 | 42 | return t 43 | 44 | def sample_prior(self, shape, dtype=mx.float32, key=None): 45 | return mx.random.normal(shape, dtype=dtype, key=key) 46 | 47 | def add_noise(self, x, t, noise=None, key=None): 48 | noise = ( 49 | noise 50 | if noise is not None 51 | else mx.random.normal(x.shape, dtype=x.dtype, key=key) 52 | ) 53 | t = t.reshape([-1] + [1] * (x.ndim - 1)) 54 | return x * (1 - t) + t * noise 55 | 56 | def step(self, pred, x_t, t, t_prev): 57 | return x_t + (t_prev - t) * pred 58 | -------------------------------------------------------------------------------- /flux/flux/trainer.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import numpy as np 3 | from PIL import Image, ImageFile 4 | from tqdm import tqdm 5 | 6 | from .datasets import Dataset 7 | from .flux import FluxPipeline 8 | 9 | 10 | class Trainer: 11 | 12 | def __init__(self, flux: FluxPipeline, dataset: Dataset, args): 13 | self.flux = flux 14 | self.dataset = dataset 15 | self.args = args 16 | self.latents = [] 17 | self.t5_features = [] 18 | self.clip_features = [] 19 | 20 | def _random_crop_resize(self, img): 21 | resolution = self.args.resolution 22 | width, height = img.size 23 | 24 | a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() 25 | 26 | # Random crop the input image between 0.8 to 1.0 of its original dimensions 27 | crop_size = ( 28 | max((0.8 + 0.2 * a) * width, resolution[0]), 29 | max((0.8 + 0.2 * b) * height, resolution[1]), 30 | ) 31 | pan = (width - crop_size[0], height - crop_size[1]) 32 | img = img.crop( 33 | ( 34 | pan[0] * c, 35 | pan[1] * d, 36 | crop_size[0] + pan[0] * c, 37 | crop_size[1] + pan[1] * d, 38 | ) 39 | ) 40 | 41 | # Fit the largest rectangle with the ratio of resolution in the image 42 | # rectangle. 
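        # In other words, center-crop the random crop to the target aspect
        # ratio before resizing.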
43 | width, height = crop_size 44 | ratio = resolution[0] / resolution[1] 45 | r1 = (height * ratio, height) 46 | r2 = (width, width / ratio) 47 | r = r1 if r1[0] <= width else r2 48 | img = img.crop( 49 | ( 50 | (width - r[0]) / 2, 51 | (height - r[1]) / 2, 52 | (width + r[0]) / 2, 53 | (height + r[1]) / 2, 54 | ) 55 | ) 56 | 57 | # Finally resize the image to resolution 58 | img = img.resize(resolution, Image.LANCZOS) 59 | 60 | return mx.array(np.array(img)) 61 | 62 | def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): 63 | for i in range(num_augmentations): 64 | img = self._random_crop_resize(input_img) 65 | img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 66 | x_0 = self.flux.ae.encode(img[None]) 67 | x_0 = x_0.astype(self.flux.dtype) 68 | mx.eval(x_0) 69 | self.latents.append(x_0) 70 | 71 | def _encode_prompt(self, prompt): 72 | t5_tok, clip_tok = self.flux.tokenize([prompt]) 73 | t5_feat = self.flux.t5(t5_tok) 74 | clip_feat = self.flux.clip(clip_tok).pooled_output 75 | mx.eval(t5_feat, clip_feat) 76 | self.t5_features.append(t5_feat) 77 | self.clip_features.append(clip_feat) 78 | 79 | def encode_dataset(self): 80 | """Encode the images & prompt in the latent space to prepare for training.""" 81 | self.flux.ae.eval() 82 | for image, prompt in tqdm(self.dataset, desc="encode dataset"): 83 | self._encode_image(image, self.args.num_augmentations) 84 | self._encode_prompt(prompt) 85 | 86 | def iterate(self, batch_size): 87 | xs = mx.concatenate(self.latents) 88 | t5 = mx.concatenate(self.t5_features) 89 | clip = mx.concatenate(self.clip_features) 90 | mx.eval(xs, t5, clip) 91 | n_aug = self.args.num_augmentations 92 | while True: 93 | x_indices = mx.random.permutation(len(self.latents)) 94 | c_indices = x_indices // n_aug 95 | for i in range(0, len(self.latents), batch_size): 96 | x_i = x_indices[i : i + batch_size] 97 | c_i = c_indices[i : i + batch_size] 98 | yield xs[x_i], t5[c_i], clip[c_i] 99 | -------------------------------------------------------------------------------- /flux/generate_interactive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | import numpy as np 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | from flux import FluxPipeline 10 | 11 | 12 | def print_zero(group, *args, **kwargs): 13 | if group.rank() == 0: 14 | flush = kwargs.pop("flush", True) 15 | print(*args, **kwargs, flush=flush) 16 | 17 | 18 | def quantization_predicate(name, m): 19 | return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0 20 | 21 | 22 | def to_latent_size(image_size): 23 | h, w = image_size 24 | h = ((h + 15) // 16) * 16 25 | w = ((w + 15) // 16) * 16 26 | 27 | if (h, w) != image_size: 28 | print( 29 | "Warning: The image dimensions need to be divisible by 16px. " 30 | f"Changing size to {h}x{w}." 
31 | ) 32 | 33 | return (h // 8, w // 8) 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser( 38 | description="Generate images from a textual prompt using FLUX" 39 | ) 40 | parser.add_argument("--quantize", "-q", action="store_true") 41 | parser.add_argument("--model", choices=["schnell", "dev"], default="schnell") 42 | parser.add_argument("--output", default="out.png") 43 | args = parser.parse_args() 44 | 45 | flux = FluxPipeline("flux-" + args.model, t5_padding=True) 46 | 47 | if args.quantize: 48 | nn.quantize(flux.flow, class_predicate=quantization_predicate) 49 | nn.quantize(flux.t5, class_predicate=quantization_predicate) 50 | nn.quantize(flux.clip, class_predicate=quantization_predicate) 51 | 52 | group = mx.distributed.init() 53 | if group.size() > 1: 54 | flux.flow.shard(group) 55 | 56 | print_zero(group, "Loading models") 57 | flux.ensure_models_are_loaded() 58 | 59 | def print_help(): 60 | print_zero(group, "The command list:") 61 | print_zero(group, "- 'q' to exit") 62 | print_zero(group, "- 's HxW' to change the size of the image") 63 | print_zero(group, "- 'n S' to change the number of steps") 64 | print_zero(group, "- 'h' to print this help") 65 | 66 | print_zero(group, "FLUX interactive session") 67 | print_help() 68 | seed = 0 69 | size = (512, 512) 70 | latent_size = to_latent_size(size) 71 | steps = 50 if args.model == "dev" else 4 72 | while True: 73 | prompt = input(">> " if group.rank() == 0 else "") 74 | if prompt == "q": 75 | break 76 | if prompt == "h": 77 | print_help() 78 | continue 79 | if prompt.startswith("s "): 80 | size = tuple([int(xi) for xi in prompt[2:].split("x")]) 81 | print_zero(group, "Setting the size to", size) 82 | latent_size = to_latent_size(size) 83 | continue 84 | if prompt.startswith("n "): 85 | steps = int(prompt[2:]) 86 | print_zero(group, "Setting the steps to", steps) 87 | continue 88 | 89 | seed += 1 90 | latents = flux.generate_latents( 91 | prompt, 92 | n_images=1, 93 | num_steps=steps, 94 | latent_size=latent_size, 95 | guidance=4.0, 96 | seed=seed, 97 | ) 98 | print_zero(group, "Processing prompt") 99 | mx.eval(next(latents)) 100 | print_zero(group, "Generating latents") 101 | for xt in tqdm(latents, total=steps, disable=group.rank() > 0): 102 | mx.eval(xt) 103 | print_zero(group, "Generating image") 104 | xt = flux.decode(xt, latent_size) 105 | xt = (xt * 255).astype(mx.uint8) 106 | mx.eval(xt) 107 | im = Image.fromarray(np.array(xt[0])) 108 | im.save(args.output) 109 | print_zero(group, "Saved at", args.output, end="\n\n") 110 | -------------------------------------------------------------------------------- /flux/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.18.1 2 | huggingface-hub 3 | regex 4 | numpy 5 | tqdm 6 | Pillow 7 | sentencepiece 8 | -------------------------------------------------------------------------------- /flux/static/dog-r4-g8-1200-512x1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/flux/static/dog-r4-g8-1200-512x1024.png -------------------------------------------------------------------------------- /flux/static/dog-r4-g8-1200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/flux/static/dog-r4-g8-1200.png 
-------------------------------------------------------------------------------- /flux/static/dog6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/flux/static/dog6.png -------------------------------------------------------------------------------- /flux/static/generated-mlx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/flux/static/generated-mlx.png -------------------------------------------------------------------------------- /gcn/.gitignore: -------------------------------------------------------------------------------- 1 | cora/ 2 | -------------------------------------------------------------------------------- /gcn/README.md: -------------------------------------------------------------------------------- 1 | # Graph Convolutional Network 2 | 3 | An example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX. 4 | 5 | ### Install requirements 6 | First, install the few dependencies with `pip`. 7 | 8 | ``` 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ### Run 13 | To try the model, just run the `main.py` file. This will download the Cora dataset, run the training and testing. 14 | 15 | ``` 16 | python main.py 17 | ``` 18 | -------------------------------------------------------------------------------- /gcn/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | 4 | import mlx.core as mx 5 | import numpy as np 6 | import requests 7 | import scipy.sparse as sparse 8 | 9 | """ 10 | Preprocessing follows the same implementation as in: 11 | https://github.com/tkipf/gcn 12 | https://github.com/senadkurtisi/pytorch-GCN/tree/main 13 | """ 14 | 15 | 16 | def download_cora(): 17 | """Downloads the cora dataset into a local cora folder.""" 18 | 19 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz" 20 | extract_to = "." 21 | 22 | if os.path.exists(os.path.join(extract_to, "cora")): 23 | return 24 | 25 | response = requests.get(url, stream=True) 26 | if response.status_code == 200: 27 | file_path = os.path.join(extract_to, url.split("/")[-1]) 28 | 29 | # Write the file to local disk 30 | with open(file_path, "wb") as file: 31 | file.write(response.raw.read()) 32 | 33 | # Extract the .tgz file 34 | with tarfile.open(file_path, "r:gz") as tar: 35 | tar.extractall(path=extract_to) 36 | print(f"Cora dataset extracted to {extract_to}") 37 | 38 | os.remove(file_path) 39 | 40 | 41 | def train_val_test_mask(): 42 | """Splits the loaded dataset into train/validation/test sets.""" 43 | 44 | train_set = mx.arange(140) 45 | validation_set = mx.arange(200, 500) 46 | test_set = mx.arange(500, 1500) 47 | 48 | return train_set, validation_set, test_set 49 | 50 | 51 | def enumerate_labels(labels): 52 | """Converts the labels from the original 53 | string form to the integer [0:MaxLabels-1] 54 | """ 55 | label_map = {v: e for e, v in enumerate(set(labels))} 56 | labels = np.array([label_map[label] for label in labels]) 57 | return labels 58 | 59 | 60 | def normalize_adjacency(adj): 61 | """Normalizes the adjacency matrix according to the 62 | paper by Kipf et al. 
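    (symmetric renormalization: A_hat = D^-1/2 (A + I) D^-1/2)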
63 | https://arxiv.org/abs/1609.02907 64 | """ 65 | adj = adj + sparse.eye(adj.shape[0]) 66 | 67 | node_degrees = np.array(adj.sum(1)) 68 | node_degrees = np.power(node_degrees, -0.5).flatten() 69 | node_degrees[np.isinf(node_degrees)] = 0.0 70 | node_degrees[np.isnan(node_degrees)] = 0.0 71 | degree_matrix = sparse.diags(node_degrees, dtype=np.float32) 72 | 73 | adj = degree_matrix @ adj @ degree_matrix 74 | return adj 75 | 76 | 77 | def load_data(config): 78 | """Loads the Cora graph data into MLX array format.""" 79 | print("Loading Cora dataset...") 80 | 81 | # Download dataset files 82 | download_cora() 83 | 84 | # Graph nodes 85 | raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str") 86 | raw_node_ids = raw_nodes_data[:, 0].astype( 87 | "int32" 88 | ) # unique identifier of each node 89 | raw_node_labels = raw_nodes_data[:, -1] 90 | labels_enumerated = enumerate_labels(raw_node_labels) # target labels as integers 91 | node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32") 92 | 93 | # Edges 94 | ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)} 95 | raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32") 96 | edges_ordered = np.array( 97 | list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32" 98 | ).reshape(raw_edges_data.shape) 99 | 100 | # Adjacency matrix 101 | adj = sparse.coo_matrix( 102 | (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])), 103 | shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]), 104 | dtype=np.float32, 105 | ) 106 | 107 | # Make the adjacency matrix symmetric 108 | adj = adj + adj.T.multiply(adj.T > adj) 109 | adj = normalize_adjacency(adj) 110 | 111 | # Convert to mlx array 112 | features = mx.array(node_features.toarray(), mx.float32) 113 | labels = mx.array(labels_enumerated, mx.int32) 114 | adj = mx.array(adj.toarray()) 115 | 116 | print("Dataset loaded.") 117 | return features, labels, adj 118 | -------------------------------------------------------------------------------- /gcn/gcn.py: -------------------------------------------------------------------------------- 1 | import mlx.nn as nn 2 | 3 | 4 | class GCNLayer(nn.Module): 5 | def __init__(self, in_features, out_features, bias=True): 6 | super(GCNLayer, self).__init__() 7 | self.linear = nn.Linear(in_features, out_features, bias) 8 | 9 | def __call__(self, x, adj): 10 | x = self.linear(x) 11 | return adj @ x 12 | 13 | 14 | class GCN(nn.Module): 15 | def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True): 16 | super(GCN, self).__init__() 17 | 18 | layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim] 19 | self.gcn_layers = [ 20 | GCNLayer(in_dim, out_dim, bias) 21 | for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]) 22 | ] 23 | self.dropout = nn.Dropout(p=dropout) 24 | 25 | def __call__(self, x, adj): 26 | for layer in self.gcn_layers[:-1]: 27 | x = nn.relu(layer(x, adj)) 28 | x = self.dropout(x) 29 | 30 | x = self.gcn_layers[-1](x, adj) 31 | return x 32 | -------------------------------------------------------------------------------- /gcn/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import mlx.optimizers as optim 8 | from datasets import load_data, train_val_test_mask 9 | from mlx.utils import tree_flatten 10 | 11 | from gcn import GCN 12 | 13 | 14 | def 
loss_fn(y_hat, y, weight_decay=0.0, parameters=None): 15 | l = mx.mean(nn.losses.cross_entropy(y_hat, y)) 16 | 17 | if weight_decay != 0.0: 18 | assert parameters != None, "Model parameters missing for L2 reg." 19 | 20 | l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt() 21 | return l + weight_decay * l2_reg 22 | return l 23 | 24 | 25 | def eval_fn(x, y): 26 | return mx.mean(mx.argmax(x, axis=1) == y) 27 | 28 | 29 | def forward_fn(gcn, x, adj, y, train_mask, weight_decay): 30 | y_hat = gcn(x, adj) 31 | loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters()) 32 | return loss, y_hat 33 | 34 | 35 | def main(args): 36 | # Data loading 37 | x, y, adj = load_data(args) 38 | train_mask, val_mask, test_mask = train_val_test_mask() 39 | 40 | gcn = GCN( 41 | x_dim=x.shape[-1], 42 | h_dim=args.hidden_dim, 43 | out_dim=args.nb_classes, 44 | nb_layers=args.nb_layers, 45 | dropout=args.dropout, 46 | bias=args.bias, 47 | ) 48 | mx.eval(gcn.parameters()) 49 | 50 | optimizer = optim.Adam(learning_rate=args.lr) 51 | 52 | state = [gcn.state, optimizer.state, mx.random.state] 53 | 54 | @partial(mx.compile, inputs=state, outputs=state) 55 | def step(): 56 | loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn) 57 | (loss, y_hat), grads = loss_and_grad_fn( 58 | gcn, x, adj, y, train_mask, args.weight_decay 59 | ) 60 | optimizer.update(gcn, grads) 61 | return loss, y_hat 62 | 63 | best_val_loss = float("inf") 64 | cnt = 0 65 | 66 | # Training loop 67 | for epoch in range(args.epochs): 68 | tic = time.time() 69 | loss, y_hat = step() 70 | mx.eval(state) 71 | 72 | # Validation 73 | val_loss = loss_fn(y_hat[val_mask], y[val_mask]) 74 | val_acc = eval_fn(y_hat[val_mask], y[val_mask]) 75 | toc = time.time() 76 | 77 | # Early stopping 78 | if val_loss < best_val_loss: 79 | best_val_loss = val_loss 80 | cnt = 0 81 | else: 82 | cnt += 1 83 | if cnt == args.patience: 84 | break 85 | 86 | print( 87 | " | ".join( 88 | [ 89 | f"Epoch: {epoch:3d}", 90 | f"Train loss: {loss.item():.3f}", 91 | f"Val loss: {val_loss.item():.3f}", 92 | f"Val acc: {val_acc.item():.2f}", 93 | f"Time: {1e3*(toc - tic):.3f} (ms)", 94 | ] 95 | ) 96 | ) 97 | 98 | # Test 99 | test_y_hat = gcn(x, adj) 100 | test_loss = loss_fn(y_hat[test_mask], y[test_mask]) 101 | test_acc = eval_fn(y_hat[test_mask], y[test_mask]) 102 | 103 | print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}") 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = ArgumentParser() 108 | parser.add_argument("--nodes_path", type=str, default="cora/cora.content") 109 | parser.add_argument("--edges_path", type=str, default="cora/cora.cites") 110 | parser.add_argument("--hidden_dim", type=int, default=20) 111 | parser.add_argument("--dropout", type=float, default=0.5) 112 | parser.add_argument("--nb_layers", type=int, default=2) 113 | parser.add_argument("--nb_classes", type=int, default=7) 114 | parser.add_argument("--bias", type=bool, default=True) 115 | parser.add_argument("--lr", type=float, default=0.001) 116 | parser.add_argument("--weight_decay", type=float, default=0.0) 117 | parser.add_argument("--patience", type=int, default=20) 118 | parser.add_argument("--epochs", type=int, default=100) 119 | args = parser.parse_args() 120 | 121 | main(args) 122 | -------------------------------------------------------------------------------- /gcn/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.0.4 2 | numpy>=1.26.2 3 | scipy>=1.11.4 4 | requests>=2.31.0 5 | 
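As a side note on the GCN example above: the following is a tiny, self-contained NumPy sketch (not part of the repository) of what a single GCN layer computes on a toy three-node graph, using the same symmetric normalization as `normalize_adjacency` in `gcn/datasets.py`. The names and shapes here are illustrative only.

```python
import numpy as np

# Toy 3-node graph: edges 0-1 and 1-2.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=np.float32)

# Symmetric renormalization from Kipf & Welling: D^-1/2 (A + I) D^-1/2
A_hat = A + np.eye(3, dtype=np.float32)
d = A_hat.sum(axis=1)
A_norm = np.diag(d ** -0.5) @ A_hat @ np.diag(d ** -0.5)

# One GCN layer is a linear map followed by neighborhood averaging,
# matching GCNLayer above: adj @ linear(x).
rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4)).astype(np.float32)  # node features
W = rng.normal(size=(4, 2)).astype(np.float32)  # layer weights
H = A_norm @ (X @ W)
print(H.shape)  # (3, 2): one 2-dimensional embedding per node
```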
-------------------------------------------------------------------------------- /llava/.gitignore: -------------------------------------------------------------------------------- 1 | **.ipynb -------------------------------------------------------------------------------- /llava/README.md: -------------------------------------------------------------------------------- 1 | # LLaVA 2 | 3 | An example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLlava is 4 | a multimodal model that can generate text given combined image and text inputs. 5 | 6 | ## Setup 7 | 8 | Install the dependencies: 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Run 15 | 16 | You can use LLaVA to ask questions about images. 17 | 18 | For example, using the command line: 19 | 20 | ```bash 21 | python generate.py \ 22 | --model llava-hf/llava-1.5-7b-hf \ 23 | --image "http://images.cocodataset.org/val2017/000000039769.jpg" \ 24 | --prompt "USER: \nWhat are these?\nASSISTANT:" \ 25 | --max-tokens 128 \ 26 | --temp 0 27 | ``` 28 | 29 | This uses the following image: 30 | 31 | ![alt text](http://images.cocodataset.org/val2017/000000039769.jpg) 32 | 33 | And generates the output: 34 | 35 | ``` 36 | These are two cats lying on a pink couch. 37 | ``` 38 | 39 | You can also use LLaVA in Python: 40 | 41 | ```python 42 | from generate import load_model, prepare_inputs, generate_text 43 | 44 | processor, model = load_model("llava-hf/llava-1.5-7b-hf") 45 | 46 | max_tokens, temperature = 128, 0.0 47 | 48 | prompt = "USER: \nWhat are these?\nASSISTANT:" 49 | image = "http://images.cocodataset.org/val2017/000000039769.jpg" 50 | input_ids, pixel_values = prepare_inputs(processor, image, prompt) 51 | 52 | reply = generate_text( 53 | input_ids, pixel_values, model, processor, max_tokens, temperature 54 | ) 55 | 56 | print(reply) 57 | ``` 58 | 59 | [^1]: 60 | Refer to [LLaVA project webpage](https://llava-vl.github.io/) for more 61 | information. 62 | -------------------------------------------------------------------------------- /llava/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import argparse 4 | import codecs 5 | from pathlib import Path 6 | 7 | import mlx.core as mx 8 | import requests 9 | from PIL import Image 10 | from transformers import AutoProcessor 11 | 12 | from llava import LlavaModel 13 | 14 | 15 | def parse_arguments(): 16 | parser = argparse.ArgumentParser( 17 | description="Generate text from an image using a model." 18 | ) 19 | parser.add_argument( 20 | "--model", 21 | type=str, 22 | default="llava-hf/llava-1.5-7b-hf", 23 | help="The path to the local model directory or Hugging Face repo.", 24 | ) 25 | parser.add_argument( 26 | "--image", 27 | type=str, 28 | default="http://images.cocodataset.org/val2017/000000039769.jpg", 29 | help="URL or path of the image to process.", 30 | ) 31 | parser.add_argument( 32 | "--prompt", 33 | type=str, 34 | default="USER: \nWhat are these?\nASSISTANT:", 35 | help="Message to be processed by the model.", 36 | ) 37 | parser.add_argument( 38 | "--max-tokens", 39 | type=int, 40 | default=100, 41 | help="Maximum number of tokens to generate.", 42 | ) 43 | parser.add_argument( 44 | "--temp", type=float, default=0.3, help="Temperature for sampling." 
45 | ) 46 | parser.add_argument( 47 | "--eos-token", 48 | type=str, 49 | default=None, 50 | help="End of sequence token for tokenizer", 51 | ) 52 | return parser.parse_args() 53 | 54 | 55 | def load_image(image_source): 56 | """ 57 | Helper function to load an image from either a URL or file. 58 | """ 59 | if image_source.startswith(("http://", "https://")): 60 | try: 61 | response = requests.get(image_source, stream=True) 62 | response.raise_for_status() 63 | return Image.open(response.raw) 64 | except Exception as e: 65 | raise ValueError( 66 | f"Failed to load image from URL: {image_source} with error {e}" 67 | ) 68 | elif Path(image_source).is_file(): 69 | try: 70 | return Image.open(image_source) 71 | except IOError as e: 72 | raise ValueError(f"Failed to load image {image_source} with error: {e}") 73 | else: 74 | raise ValueError( 75 | f"The image {image_source} must be a valid URL or existing file." 76 | ) 77 | 78 | 79 | def prepare_inputs(processor, image, prompt): 80 | if isinstance(image, str): 81 | image = load_image(image) 82 | inputs = processor(image, prompt, return_tensors="np") 83 | pixel_values = mx.array(inputs["pixel_values"]) 84 | input_ids = mx.array(inputs["input_ids"]) 85 | return pixel_values, input_ids 86 | 87 | 88 | def load_model(model_path, tokenizer_config={}): 89 | processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config) 90 | model = LlavaModel.from_pretrained(model_path) 91 | return processor, model 92 | 93 | 94 | def sample(logits, temperature=0.0): 95 | if temperature == 0: 96 | return mx.argmax(logits, axis=-1) 97 | else: 98 | return mx.random.categorical(logits * (1 / temperature)) 99 | 100 | 101 | def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): 102 | logits, cache = model(input_ids, pixel_values) 103 | logits = logits[:, -1, :] 104 | y = sample(logits, temperature=temperature) 105 | tokens = [y.item()] 106 | 107 | for n in range(max_tokens - 1): 108 | logits, cache = model.language_model(y[None], cache=cache) 109 | logits = logits[:, -1, :] 110 | y = sample(logits, temperature) 111 | token = y.item() 112 | if token == processor.tokenizer.eos_token_id: 113 | break 114 | tokens.append(token) 115 | 116 | return processor.tokenizer.decode(tokens) 117 | 118 | 119 | def main(): 120 | args = parse_arguments() 121 | 122 | tokenizer_config = {} 123 | if args.eos_token is not None: 124 | tokenizer_config["eos_token"] = args.eos_token 125 | 126 | processor, model = load_model(args.model, tokenizer_config) 127 | 128 | prompt = codecs.decode(args.prompt, "unicode_escape") 129 | pixel_values, input_ids = prepare_inputs(processor, args.image, prompt) 130 | 131 | print(prompt) 132 | generated_text = generate_text( 133 | input_ids, pixel_values, model, processor, args.max_tokens, args.temp 134 | ) 135 | print(generated_text) 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /llava/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.8.0 2 | numpy 3 | transformers 4 | torch 5 | huggingface_hub 6 | Pillow 7 | -------------------------------------------------------------------------------- /llms/README.md: -------------------------------------------------------------------------------- 1 | # MOVE NOTICE 2 | 3 | The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm). 4 | 5 | The package has been removed from the MLX Examples repo. 
Send new contributions 6 | and issues to the MLX LM repo. 7 | -------------------------------------------------------------------------------- /llms/gguf_llm/README.md: -------------------------------------------------------------------------------- 1 | # LLMs in MLX with GGUF 2 | 3 | An example generating text using GGUF format models in MLX.[^1] 4 | 5 | > [!NOTE] 6 | > MLX is able to read most quantization formats from GGUF directly. However, 7 | > only a few quantizations are supported directly: `Q4_0`, `Q4_1`, and `Q8_0`. 8 | > Unsupported quantizations will be cast to `float16`. 9 | 10 | ## Setup 11 | 12 | Install the dependencies: 13 | 14 | ```bash 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Run 19 | 20 | Run with: 21 | 22 | ```bash 23 | python generate.py \ 24 | --repo \ 25 | --gguf \ 26 | --prompt "Write a quicksort in Python" 27 | ``` 28 | 29 | For example, to generate text with Mistral 7B use: 30 | 31 | ```bash 32 | python generate.py \ 33 | --repo TheBloke/Mistral-7B-v0.1-GGUF \ 34 | --gguf mistral-7b-v0.1.Q8_0.gguf \ 35 | --prompt "Write a quicksort in Python" 36 | ``` 37 | 38 | Run `python generate.py --help` for more options. 39 | 40 | Models that have been tested and work include: 41 | 42 | - [TheBloke/Mistral-7B-v0.1-GGUF](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF), 43 | for quantized models use: 44 | - `mistral-7b-v0.1.Q8_0.gguf` 45 | - `mistral-7b-v0.1.Q4_0.gguf` 46 | 47 | - [TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF](https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF), 48 | for quantized models use: 49 | - `tinyllama-1.1b-chat-v1.0.Q8_0.gguf` 50 | - `tinyllama-1.1b-chat-v1.0.Q4_0.gguf` 51 | 52 | - [Jaward/phi-3-mini-4k-instruct.Q4_0.gguf](https://huggingface.co/Jaward/phi-3-mini-4k-instruct.Q4_0.gguf), 53 | for 4 bits quantized phi-3-mini-4k-instruct use: 54 | - `phi-3-mini-4k-instruct.Q4_0.gguf` 55 | 56 | [^1]: For more information on GGUF see [the documentation](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md). 57 | -------------------------------------------------------------------------------- /llms/gguf_llm/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
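# Stream text from a GGUF checkpoint: load the model and tokenizer with
# models.load(), print tokens as they are generated, and report
# prompt/generation tokens-per-second at the end.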
2 | 3 | import argparse 4 | import time 5 | 6 | import mlx.core as mx 7 | import models 8 | 9 | 10 | def generate( 11 | model: models.Model, 12 | tokenizer: models.GGUFTokenizer, 13 | prompt: str, 14 | max_tokens: int, 15 | temp: float = 0.0, 16 | ): 17 | prompt = tokenizer.encode(prompt) 18 | 19 | tic = time.time() 20 | tokens = [] 21 | skip = 0 22 | for token, n in zip( 23 | models.generate(prompt, model, args.temp), 24 | range(args.max_tokens), 25 | ): 26 | if token == tokenizer.eos_token_id: 27 | break 28 | 29 | if n == 0: 30 | prompt_time = time.time() - tic 31 | tic = time.time() 32 | 33 | tokens.append(token.item()) 34 | s = tokenizer.decode(tokens) 35 | print(s[skip:], end="", flush=True) 36 | skip = len(s) 37 | print(tokenizer.decode(tokens)[skip:], flush=True) 38 | gen_time = time.time() - tic 39 | print("=" * 10) 40 | if len(tokens) == 0: 41 | print("No tokens generated for this prompt") 42 | return 43 | prompt_tps = len(prompt) / prompt_time 44 | gen_tps = (len(tokens) - 1) / gen_time 45 | print(f"Prompt: {prompt_tps:.3f} tokens-per-sec") 46 | print(f"Generation: {gen_tps:.3f} tokens-per-sec") 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser(description="Inference script") 51 | parser.add_argument( 52 | "--gguf", 53 | type=str, 54 | help="The GGUF file to load (and optionally download).", 55 | ) 56 | parser.add_argument( 57 | "--repo", 58 | type=str, 59 | default=None, 60 | help="The Hugging Face repo if downloading from the Hub.", 61 | ) 62 | 63 | parser.add_argument( 64 | "--prompt", 65 | help="The message to be processed by the model", 66 | default="In the beginning the Universe was created.", 67 | ) 68 | parser.add_argument( 69 | "--max-tokens", 70 | "-m", 71 | type=int, 72 | default=100, 73 | help="Maximum number of tokens to generate", 74 | ) 75 | parser.add_argument( 76 | "--temp", 77 | help="The sampling temperature.", 78 | type=float, 79 | default=0.0, 80 | ) 81 | parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") 82 | 83 | args = parser.parse_args() 84 | mx.random.seed(args.seed) 85 | model, tokenizer = models.load(args.gguf, args.repo) 86 | generate(model, tokenizer, args.prompt, args.max_tokens, args.temp) 87 | -------------------------------------------------------------------------------- /llms/gguf_llm/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.8 2 | numpy 3 | protobuf==3.20.2 4 | sentencepiece 5 | huggingface_hub 6 | -------------------------------------------------------------------------------- /llms/gguf_llm/utils.py: -------------------------------------------------------------------------------- 1 | import sentencepiece as spm 2 | import sentencepiece.sentencepiece_model_pb2 as model 3 | 4 | 5 | def spm_tokenizer(metadata): 6 | tokens = metadata["tokenizer.ggml.tokens"] 7 | bos = metadata["tokenizer.ggml.bos_token_id"].item() 8 | eos = metadata["tokenizer.ggml.eos_token_id"].item() 9 | unk = metadata["tokenizer.ggml.unknown_token_id"].item() 10 | 11 | normalizer_spec = model.NormalizerSpec( 12 | name="identity", 13 | precompiled_charsmap=b"", 14 | add_dummy_prefix=True, 15 | remove_extra_whitespaces=False, 16 | normalization_rule_tsv=b"", 17 | ) 18 | trainer_spec = model.TrainerSpec( 19 | model_type="BPE", 20 | vocab_size=len(tokens), 21 | input_format="text", 22 | split_by_unicode_script=True, 23 | split_by_whitespace=True, 24 | split_by_number=True, 25 | treat_whitespace_as_suffix=False, 26 | split_digits=True, 27 | 
allow_whitespace_only_pieces=True, 28 | vocabulary_output_piece_score=True, 29 | byte_fallback=True, 30 | unk_id=unk, 31 | bos_id=bos, 32 | eos_id=eos, 33 | pad_id=-1, 34 | unk_piece="", 35 | bos_piece="", 36 | eos_piece="", 37 | pad_piece="", 38 | pretokenization_delimiter="", 39 | ) 40 | m = model.ModelProto(trainer_spec=trainer_spec, normalizer_spec=normalizer_spec) 41 | scores = metadata.get("tokenizer.ggml.scores", None) 42 | scores = scores.tolist() if scores is not None else None 43 | token_types = metadata.get("tokenizer.ggml.token_type", None) 44 | token_types = token_types.tolist() if token_types is not None else None 45 | 46 | for i, token in enumerate(tokens): 47 | score = scores[i] if scores else 0 48 | token_type = token_types[i] if token_types else 0 49 | m.pieces.append( 50 | model.ModelProto.SentencePiece(piece=token, score=score, type=token_type) 51 | ) 52 | tokenizer = spm.SentencePieceProcessor(model_proto=m.SerializeToString()) 53 | return tokenizer 54 | -------------------------------------------------------------------------------- /llms/llama/README.md: -------------------------------------------------------------------------------- 1 | # Llama 2 | 3 | An example of generating text with Llama (1 or 2) using MLX. 4 | 5 | Llama is a set of open source language models from Meta AI Research[^1][^2] 6 | ranging from 7B to 70B parameters. This example also supports Meta's Llama Chat 7 | and Code Llama models, as well as the 1.1B TinyLlama models from SUTD.[^3] 8 | 9 | ### Setup 10 | 11 | Install the dependencies: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | Next, download and convert the model. If you do not have access to the model 18 | weights you will need to request access from Meta: 19 | 20 | - [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) 21 | - [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) 22 | 23 | > [!TIP] Alternatively, you can also download a few converted checkpoints from 24 | > the [MLX Community](https://huggingface.co/mlx-community) organization on 25 | > Hugging Face and skip the conversion step. 26 | 27 | You can download the TinyLlama models directly from [Hugging 28 | Face](https://huggingface.co/TinyLlama). 29 | 30 | Convert the weights with: 31 | 32 | ``` 33 | python convert.py --torch-path 34 | ``` 35 | 36 | To generate a 4-bit quantized model use the `-q` flag: 37 | 38 | ``` 39 | python convert.py --torch-path -q 40 | ``` 41 | 42 | For TinyLlama use 43 | 44 | ``` 45 | python convert.py --torch-path --model-name tiny_llama 46 | ``` 47 | 48 | By default, the conversion script will make the directory `mlx_model` and save 49 | the converted `weights.npz`, `tokenizer.model`, and `config.json` there. 50 | 51 | 52 | ### Run 53 | 54 | Once you've converted the weights to MLX format, you can interact with the 55 | LlamA model: 56 | 57 | ``` 58 | python llama.py --prompt "hello" 59 | ``` 60 | 61 | Run `python llama.py --help` for more details. 62 | 63 | [^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details. 
64 | [^2]: For Llama v2 refer to the [blob post](https://ai.meta.com/llama/) 65 | [^3]: For TinyLlama refer to the [gihub repository](https://github.com/jzhang38/TinyLlama?tab=readme-ov-file) 66 | -------------------------------------------------------------------------------- /llms/llama/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.11.0 2 | sentencepiece 3 | torch 4 | numpy 5 | -------------------------------------------------------------------------------- /llms/llama/sample_prompt.txt: -------------------------------------------------------------------------------- 1 | [Instruction] Give the list of U.S. states bordering Canada 2 | [Answer] OK, here is the list of U.S. states located on the border with Canada: 3 | - Alaska 4 | - Michigan 5 | - Maine 6 | - Minnesota 7 | - Montana 8 | - New York 9 | - Washington 10 | - North Dakota 11 | - Ohio 12 | - Vermont 13 | - New Hampshire 14 | - Idaho 15 | - Pennsylvania 16 | [Instruction] Write a paragraph about "functional analysis" 17 | [Answer] OK, here is a paragraph on the topic of functional analysis: 18 | Functional analysis is a branch of mathematical analysis, the core of which is formed by the study of vector spaces endowed with some kind of limit-related structure (for example, inner product, norm, or topology) and the linear functions defined on these spaces and suitably respecting these structures. The historical roots of functional analysis lie in the study of spaces of functions and the formulation of properties of transformations of functions such as the Fourier transform as transformations defining, for example, continuous or unitary operators between function spaces. This point of view turned out to be particularly useful for the study of differential and integral equations. 19 | [Instruction] I am starting a new dog walking business. Can you help me find 2 possible names for the business? 20 | [Answer] OK, here are two possible names for a new dog walking business: 21 | The first option is "Paws on Patrol", and the second option is "The Dog Whisperer". 22 | [Instruction] {} 23 | [Answer] 24 | -------------------------------------------------------------------------------- /llms/mistral/.gitignore: -------------------------------------------------------------------------------- 1 | mistral-7B-v0.1/ 2 | -------------------------------------------------------------------------------- /llms/mistral/README.md: -------------------------------------------------------------------------------- 1 | # Mistral 2 | 3 | An example of generating text with Mistral using MLX. 4 | 5 | Mistral 7B is one of the top large language models in its size class. It is 6 | also fully open source with a permissive license[^1]. 7 | 8 | ### Setup 9 | 10 | Install the dependencies: 11 | 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | Next, download the model and tokenizer: 17 | 18 | ``` 19 | curl -O https://models.mistralcdn.com/mistral-7b-v0-1/mistral-7B-v0.1.tar 20 | tar -xf mistral-7B-v0.1.tar 21 | ``` 22 | 23 | Then, convert the weights with: 24 | 25 | ``` 26 | python convert.py --torch-path 27 | ``` 28 | 29 | To generate a 4-bit quantized model, use ``-q``. For a full list of options: 30 | 31 | ``` 32 | python convert.py --help 33 | ``` 34 | 35 | By default, the conversion script will make the directory `mlx_model` and save 36 | the converted `weights.npz`, `tokenizer.model`, and `config.json` there. 
37 | 38 | > [!TIP] 39 | > Alternatively, you can also download a few converted checkpoints from the 40 | > [MLX Community](https://huggingface.co/mlx-community) organization on Hugging 41 | > Face and skip the conversion step. 42 | 43 | 44 | ### Run 45 | 46 | Once you've converted the weights to MLX format, you can generate text with 47 | the Mistral model: 48 | 49 | ``` 50 | python mistral.py --prompt "It is a truth universally acknowledged," 51 | ``` 52 | 53 | Run `python mistral.py --help` for more details. 54 | 55 | [^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) 56 | and [github repository](https://github.com/mistralai/mistral-src) for more 57 | details. 58 | -------------------------------------------------------------------------------- /llms/mistral/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | import copy 5 | import json 6 | import shutil 7 | from pathlib import Path 8 | 9 | import mlx.core as mx 10 | import mlx.nn as nn 11 | import numpy as np 12 | import torch 13 | from mistral import Mistral, ModelArgs 14 | from mlx.utils import tree_flatten, tree_map, tree_unflatten 15 | 16 | 17 | def quantize(weights, config, args): 18 | quantized_config = copy.deepcopy(config) 19 | 20 | # Load the model: 21 | config.pop("sliding_window", None) 22 | model = Mistral(ModelArgs(**config)) 23 | weights = tree_map(mx.array, weights) 24 | model.update(tree_unflatten(list(weights.items()))) 25 | 26 | # Quantize the model: 27 | nn.quantize(model, args.q_group_size, args.q_bits) 28 | 29 | # Update the config: 30 | quantized_config["quantization"] = { 31 | "group_size": args.q_group_size, 32 | "bits": args.q_bits, 33 | } 34 | quantized_weights = dict(tree_flatten(model.parameters())) 35 | 36 | return quantized_weights, quantized_config 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.") 41 | parser.add_argument( 42 | "--torch-path", 43 | type=str, 44 | default="mistral-7B-v0.1", 45 | help="The path to the PyTorch model.", 46 | ) 47 | parser.add_argument( 48 | "--mlx-path", 49 | type=str, 50 | default="mlx_model", 51 | help="The path to save the MLX model.", 52 | ) 53 | parser.add_argument( 54 | "-q", 55 | "--quantize", 56 | help="Generate a quantized model.", 57 | action="store_true", 58 | ) 59 | parser.add_argument( 60 | "--q-group-size", 61 | help="Group size for quantization.", 62 | type=int, 63 | default=64, 64 | ) 65 | parser.add_argument( 66 | "--q-bits", 67 | help="Bits per weight for quantization.", 68 | type=int, 69 | default=4, 70 | ) 71 | args = parser.parse_args() 72 | 73 | torch_path = Path(args.torch_path) 74 | state = torch.load(str(torch_path / "consolidated.00.pth")) 75 | mlx_path = Path(args.mlx_path) 76 | mlx_path.mkdir(parents=True, exist_ok=True) 77 | 78 | weights = {k: v.to(torch.float16).numpy() for k, v in state.items()} 79 | with open(torch_path / "params.json", "r") as f: 80 | config = json.loads(f.read()) 81 | 82 | if args.quantize: 83 | print("[INFO] Quantizing") 84 | weights, config = quantize(weights, config, args) 85 | 86 | # Save weights 87 | np.savez(str(mlx_path / "weights.npz"), **weights) 88 | 89 | # Copy tokenizer 90 | shutil.copyfile( 91 | str(torch_path / "tokenizer.model"), 92 | str(mlx_path / "tokenizer.model"), 93 | ) 94 | 95 | # Save config.json with model_type 96 | with open(mlx_path / "config.json", "w") as f: 97 | config["model_type"] = "mistral" 
98 | json.dump(config, f, indent=4) 99 | -------------------------------------------------------------------------------- /llms/mistral/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.11.0 2 | sentencepiece 3 | torch 4 | numpy 5 | -------------------------------------------------------------------------------- /llms/mistral/test.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import unittest 4 | 5 | import mistral 6 | import mlx.core as mx 7 | from mlx.utils import tree_map 8 | 9 | 10 | class TestMistral(unittest.TestCase): 11 | def test_model(self): 12 | vocab_size = 100 13 | L = 32 14 | args = mistral.ModelArgs( 15 | dim=128, 16 | n_layers=2, 17 | head_dim=32, 18 | hidden_dim=256, 19 | n_heads=4, 20 | n_kv_heads=4, 21 | norm_eps=1e-3, 22 | vocab_size=vocab_size, 23 | ) 24 | 25 | model = mistral.Mistral(args) 26 | inputs = mx.random.randint(0, vocab_size, (L,)) 27 | logits, cache = model(inputs[None]) 28 | self.assertEqual(logits.shape, [1, L, vocab_size]) 29 | self.assertEqual(logits.dtype, mx.float32) 30 | self.assertEqual(len(cache), args.n_layers) 31 | 32 | params = tree_map(lambda p: p.astype(mx.float16), model.parameters()) 33 | model.update(params) 34 | logits, _ = model(inputs[None]) 35 | self.assertEqual(logits.dtype, mx.float16) 36 | 37 | def test_generate(self): 38 | model, tokenizer = mistral.load_model("mistral-7B-v0.1") 39 | prompt = mx.array(tokenizer.encode("This is a test")) 40 | tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))] 41 | mx.eval(tokens) 42 | tokens = [t.item() for t in tokens] 43 | expected = [ 44 | 302, 45 | 272, 46 | 11843, 47 | 11837, 48 | 1587, 49 | 28723, 50 | 851, 51 | 349, 52 | 865, 53 | 264, 54 | 1369, 55 | 28723, 56 | 13, 57 | 13, 58 | 3381, 59 | 456, 60 | 654, 61 | 264, 62 | 1353, 63 | 11843, 64 | 28725, 65 | 368, 66 | 682, 67 | 347, 68 | 2240, 69 | 767, 70 | 298, 71 | 511, 72 | 28723, 73 | 13, 74 | ] 75 | self.assertEqual(tokens, expected) 76 | 77 | def benchmark(self): 78 | import time 79 | 80 | model, tokenizer = mistral.load_model("mistral-7B-v0.1") 81 | prompt = mx.random.randint(0, model.vocab_size, (128,)) 82 | 83 | # warmup 84 | for _ in range(2): 85 | generator = mistral.generate(prompt, model) 86 | mx.eval(next(generator)) 87 | 88 | tic = time.time() 89 | its = 5 90 | for _ in range(its): 91 | generator = mistral.generate(prompt, model) 92 | mx.eval(next(generator)) 93 | toc = time.time() 94 | tps = its * prompt.size / (toc - tic) 95 | print(f"Prompt processing: {tps:.2f} tokens per second") 96 | 97 | # warmup 98 | for _ in range(2): 99 | tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))] 100 | mx.eval(tokens) 101 | 102 | time_total = 0.0 103 | its = 2 104 | for _ in range(its): 105 | generator = mistral.generate(prompt, model) 106 | mx.eval(next(generator)) 107 | tic = time.time() 108 | tokens = [t for t, _ in zip(generator, range(100))] 109 | mx.eval(tokens) 110 | time_total += time.time() - tic 111 | 112 | tps = len(tokens) * its / time_total 113 | print(f"Token generation: {tps:.3f} tokens per second") 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /llms/mixtral/README.md: -------------------------------------------------------------------------------- 1 | ## Mixtral 8x7B 2 | 3 | Run the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX 
on Apple silicon. 4 | 5 | This example also supports the instruction fine-tuned Mixtral model.[^instruct] 6 | 7 | Note, for 16-bit precision this model needs a machine with substantial RAM (~100GB) to run. 8 | 9 | ### Setup 10 | 11 | Install [Git Large File 12 | Storage](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage). 13 | For example with Homebrew: 14 | 15 | ``` 16 | brew install git-lfs 17 | ``` 18 | 19 | Download the models from Hugging Face: 20 | 21 | For the base model use: 22 | 23 | ``` 24 | export MIXTRAL_MODEL=Mixtral-8x7B-v0.1 25 | ``` 26 | 27 | For the instruction fine-tuned model use: 28 | 29 | ``` 30 | export MIXTRAL_MODEL=Mixtral-8x7B-Instruct-v0.1 31 | ``` 32 | 33 | Then run: 34 | 35 | ``` 36 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mistralai/${MIXTRAL_MODEL}/ 37 | cd $MIXTRAL_MODEL/ && \ 38 | git lfs pull --include "consolidated.*.pt" && \ 39 | git lfs pull --include "tokenizer.model" 40 | ``` 41 | 42 | Now from `mlx-exmaples/mixtral` convert and save the weights as NumPy arrays so 43 | MLX can read them: 44 | 45 | ``` 46 | python convert.py --torch-path $MIXTRAL_MODEL/ 47 | ``` 48 | 49 | To generate a 4-bit quantized model, use ``-q``. For a full list of options: 50 | 51 | ``` 52 | python convert.py --help 53 | ``` 54 | 55 | By default, the conversion script will make the directory `mlx_model` and save 56 | the converted `weights.npz`, `tokenizer.model`, and `config.json` there. 57 | 58 | 59 | ### Generate 60 | 61 | As easy as: 62 | 63 | ``` 64 | python mixtral.py --model-path mlx_model 65 | ``` 66 | 67 | For more options including how to prompt the model, run: 68 | 69 | ``` 70 | python mixtral.py --help 71 | ``` 72 | 73 | For the Instruction model, make sure to follow the prompt format: 74 | 75 | ``` 76 | [INST] Instruction prompt [/INST] 77 | ``` 78 | 79 | [^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) and the [Hugging Face blog post](https://huggingface.co/blog/mixtral) for more details. 80 | [^instruc]: Refer to the [Hugging Face repo](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for more 81 | details 82 | -------------------------------------------------------------------------------- /llms/mixtral/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
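# Convert sharded Mixtral PyTorch checkpoints to MLX: each block_sparse_moe
# weight is split into per-expert feed_forward.experts linear layers, with
# optional quantization applied shard by shard.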
2 | 3 | import argparse 4 | import copy 5 | import glob 6 | import json 7 | import shutil 8 | from pathlib import Path 9 | 10 | import mlx.core as mx 11 | import mlx.nn as nn 12 | import numpy as np 13 | import torch 14 | from mixtral import Mixtral, ModelArgs 15 | from mlx.utils import tree_flatten, tree_map, tree_unflatten 16 | 17 | 18 | def convert(tf, config): 19 | def convert_single(k, v): 20 | v = v.to(torch.float16).numpy() 21 | if "block_sparse_moe" not in k: 22 | return [(k, v)] 23 | if "gate" in k: 24 | return [(k.replace("block_sparse_moe", "feed_forward"), v)] 25 | 26 | # From: layers.N.block_sparse_moe.w 27 | # To: layers.N.experts.M.w 28 | num_experts = config["moe"]["num_experts"] 29 | key_path = k.split(".") 30 | v = np.split(v, num_experts, axis=0) 31 | if key_path[-1] == "w2": 32 | v = [u.T for u in v] 33 | 34 | w_name = key_path.pop() 35 | key_path[-1] = "feed_forward.experts" 36 | return [ 37 | (".".join(key_path + [str(e), w_name, "weight"]), u) 38 | for e, u in enumerate(v) 39 | ] 40 | 41 | state = torch.load(tf) 42 | weights = {} 43 | for k, v in state.items(): 44 | weights.update(convert_single(k, v)) 45 | return weights 46 | 47 | 48 | def quantize(weights, config, args): 49 | quantized_config = copy.deepcopy(config) 50 | 51 | # Load the model and update with the subset of weights: 52 | config.pop("quantization", None) 53 | model = Mixtral(ModelArgs(**config)) 54 | all_weights = dict(tree_flatten(model.parameters())) 55 | 56 | weights = tree_map(mx.array, weights) 57 | 58 | all_weights.update(weights) 59 | all_weights = tree_unflatten(list(all_weights.items())) 60 | model.update(all_weights) 61 | 62 | # Quantize the model: 63 | nn.quantize( 64 | model, 65 | args.q_group_size, 66 | args.q_bits, 67 | ) 68 | 69 | # Extract the subset of quantized weights: 70 | all_weights = dict(tree_flatten(model.parameters())) 71 | quantized_weights = {} 72 | for k, v in all_weights.items(): 73 | if k not in weights: 74 | continue 75 | quantized_weights[k] = v 76 | prefix = k.split(".")[:-1] 77 | for qw in ["scales", "biases"]: 78 | if (k := ".".join(prefix + [qw])) in all_weights: 79 | quantized_weights[k] = all_weights[k] 80 | 81 | # Update the config: 82 | quantized_config["quantization"] = { 83 | "group_size": args.q_group_size, 84 | "bits": args.q_bits, 85 | } 86 | return quantized_weights, quantized_config 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = argparse.ArgumentParser(description="Convert Mixtral weights to MLX.") 91 | parser.add_argument( 92 | "--torch-path", 93 | type=str, 94 | default="Mixtral-8x7B-v0.1", 95 | help="The path to the PyTorch model.", 96 | ) 97 | parser.add_argument( 98 | "--mlx-path", 99 | type=str, 100 | default="mlx_model", 101 | help="The path to save the MLX model.", 102 | ) 103 | parser.add_argument( 104 | "-q", 105 | "--quantize", 106 | help="Generate a quantized model.", 107 | action="store_true", 108 | ) 109 | parser.add_argument( 110 | "--q-group-size", 111 | help="Group size for quantization.", 112 | type=int, 113 | default=64, 114 | ) 115 | parser.add_argument( 116 | "--q-bits", 117 | help="Bits per weight for quantization.", 118 | type=int, 119 | default=4, 120 | ) 121 | args = parser.parse_args() 122 | torch_path = Path(args.torch_path) 123 | mlx_path = Path(args.mlx_path) 124 | mlx_path.mkdir(parents=True, exist_ok=True) 125 | 126 | with open("params.json") as fid: 127 | config = json.load(fid) 128 | 129 | # Copy tokenizer 130 | shutil.copyfile( 131 | str(torch_path / "tokenizer.model"), 132 | str(mlx_path / "tokenizer.model"), 
133 | ) 134 | 135 | # Convert and save model in shards 136 | torch_files = glob.glob(str(torch_path / "consolidated.*.pt")) 137 | torch_files = sorted(torch_files, key=lambda tf: int(tf.split(".")[-2])) 138 | for e, tf in enumerate(torch_files): 139 | print(f"[INFO] Converting file {e + 1}/{len(torch_files)}") 140 | weights = convert(tf, config) 141 | if args.quantize: 142 | print("[INFO] Quantizing") 143 | weights, config = quantize(weights, config, args) 144 | np.savez(str(mlx_path / f"weights.{e}.npz"), **weights) 145 | 146 | # Save updated config 147 | with open(mlx_path / "config.json", "w") as f: 148 | config["model_type"] = "mixtral" 149 | json.dump(config, f, indent=4) 150 | -------------------------------------------------------------------------------- /llms/mixtral/params.json: -------------------------------------------------------------------------------- 1 | {"dim": 4096, "n_layers": 32, "head_dim": 128, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8, "norm_eps": 1e-05, "vocab_size": 32000, "moe": {"num_experts_per_tok": 2, "num_experts": 8}} 2 | -------------------------------------------------------------------------------- /llms/mixtral/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.11.0 2 | sentencepiece 3 | torch 4 | numpy 5 | -------------------------------------------------------------------------------- /llms/speculative_decoding/README.md: -------------------------------------------------------------------------------- 1 | # Speculative Decoding 2 | 3 | This example implements speculative decoding with the T5 model for text 4 | generation.[^1][^2] Speculative decoding uses a smaller draft model to propose 5 | several tokens, and a larger model to decide which tokens to accept. The 6 | distribution of the generated text is identical to what the larger model would 7 | produce on its own, but with far fewer forward passes of the large model since 8 | it can evaluate the draft tokens in parallel. 9 | 10 | ### Setup 11 | 12 | First, install the requirements: 13 | 14 | ``` 15 | cd speculative_decoding 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | Then convert the model and the draft model. We'll use T5-XXL (11B parameters) 20 | for the main model. Convert it with: 21 | 22 | ``` 23 | python convert.py --model t5-11b 24 | ``` 25 | 26 | We'll use T5-small for the draft model. Convert it with: 27 | 28 | ``` 29 | python convert.py --model t5-small 30 | ``` 31 | 32 | ### Run 33 | 34 | You can run with the default arguments: 35 | 36 | ``` 37 | python main.py 38 | ``` 39 | 40 | To see a full list of options use: 41 | ``` 42 | python main.py --help 43 | ``` 44 | 45 | ### Notes 46 | 47 | Speculative decoding works well when most of the tokens from the draft model 48 | are accepted by the larger model. That's more likely to happen if the models 49 | are trained on similar data. 50 | 51 | One way to increase the chance of accepting a draft token is with the parameter 52 | `--delta`. This parameter can be in the range $[0, 1]$. If it is $1$ then all 53 | the draft tokens will be accepted by the model. If it is $0$, then only draft 54 | tokens that match the original acceptance criterion are kept.[^1] Values 55 | closer to $1$ increase the chance that a draft token is accepted. 56 | 57 | Conversely, the fewer draft tokens accepted by the main model, the more 58 | expensive speculative decoding is. 
You can use `--num-draft` to tune the number 59 | of draft tokens per model evaluation to reduce the number of discarded 60 | draft tokens. Decreasing `--num-draft` will decrease the number of discarded 61 | draft tokens at the expense of more large model evaluations. 62 | 63 | [^1]: See the paper [Fast Inference from Transformers via Speculative 64 | Decoding](https://arxiv.org/abs/2211.17192) 65 | [^2]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683) 66 | or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5). 67 | -------------------------------------------------------------------------------- /llms/speculative_decoding/convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from transformers import T5ForConditionalGeneration 3 | 4 | SHARED_REPLACEMENT_PATTERNS = [ 5 | (".block.", ".layers."), 6 | (".k.", ".key_proj."), 7 | (".o.", ".out_proj."), 8 | (".q.", ".query_proj."), 9 | (".v.", ".value_proj."), 10 | ("shared.", "wte."), 11 | ("lm_head.", "lm_head.linear."), 12 | (".layer.0.layer_norm.", ".ln1."), 13 | (".layer.1.layer_norm.", ".ln2."), 14 | (".layer.2.layer_norm.", ".ln3."), 15 | (".final_layer_norm.", ".ln."), 16 | ( 17 | "layers.0.layer.0.SelfAttention.relative_attention_bias.", 18 | "relative_attention_bias.embeddings.", 19 | ), 20 | ] 21 | 22 | ENCODER_REPLACEMENT_PATTERNS = [ 23 | (".layer.0.SelfAttention.", ".attention."), 24 | (".layer.1.DenseReluDense.", ".dense."), 25 | ] 26 | 27 | DECODER_REPLACEMENT_PATTERNS = [ 28 | (".layer.0.SelfAttention.", ".self_attention."), 29 | (".layer.1.EncDecAttention.", ".cross_attention."), 30 | (".layer.2.DenseReluDense.", ".dense."), 31 | ] 32 | 33 | 34 | def replace_key(key: str) -> str: 35 | for old, new in SHARED_REPLACEMENT_PATTERNS: 36 | key = key.replace(old, new) 37 | if key.startswith("encoder."): 38 | for old, new in ENCODER_REPLACEMENT_PATTERNS: 39 | key = key.replace(old, new) 40 | elif key.startswith("decoder."): 41 | for old, new in DECODER_REPLACEMENT_PATTERNS: 42 | key = key.replace(old, new) 43 | return key 44 | 45 | 46 | def convert(model_name, dtype): 47 | dtype = getattr(np, dtype) 48 | model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") 49 | weights = { 50 | replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items() 51 | } 52 | file_name = model_name.replace("/", "-") 53 | print(f"Saving weights to {file_name}.npz") 54 | np.savez(f"{file_name}.npz", **weights) 55 | 56 | 57 | if __name__ == "__main__": 58 | import argparse 59 | 60 | parser = argparse.ArgumentParser(description="Convert T5 weights to MLX") 61 | parser.add_argument( 62 | "--model", 63 | type=str, 64 | help="Name of the T5 model.", 65 | default="t5-small", 66 | ) 67 | parser.add_argument( 68 | "--dtype", 69 | help="The model data type.", 70 | type=str, 71 | choices=["float16", "float32"], 72 | default="float32", 73 | ) 74 | args = parser.parse_args() 75 | convert(args.model, args.dtype) 76 | -------------------------------------------------------------------------------- /llms/speculative_decoding/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import mlx.core as mx 5 | from decoder import SpeculativeDecoder 6 | from mlx.utils import tree_unflatten 7 | from model import Model 8 | from transformers import T5Config 9 | 10 | 11 | def load_model(model_name: str): 12 | config = 
T5Config.from_pretrained(model_name) 13 | model = Model(config) 14 | weights = mx.load(f"{model_name}.npz") 15 | weights = tree_unflatten(list(weights.items())) 16 | model.update(weights) 17 | mx.eval(model.parameters()) 18 | return model 19 | 20 | 21 | def main(args): 22 | mx.random.seed(args.seed) 23 | 24 | spec_decoder = SpeculativeDecoder( 25 | model=load_model(args.model_name), 26 | draft_model=load_model(args.draft_model_name), 27 | tokenizer=args.model_name, 28 | delta=args.delta, 29 | num_draft=args.num_draft, 30 | ) 31 | 32 | tic = time.time() 33 | print(args.prompt) 34 | if args.regular_decode: 35 | spec_decoder.generate(args.prompt, max_tokens=args.max_tokens) 36 | else: 37 | stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens) 38 | print("=" * 10) 39 | print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.") 40 | print(f"Decoding steps {stats['n_steps']}.") 41 | 42 | toc = time.time() 43 | print("=" * 10) 44 | print(f"Full generation time {toc - tic:.3f}") 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") 49 | parser.add_argument( 50 | "--num-draft", 51 | type=int, 52 | default=5, 53 | help="Number of draft tokens to use per decoding step.", 54 | ) 55 | parser.add_argument( 56 | "--model-name", 57 | help="Name of the model.", 58 | default="t5-small", 59 | ) 60 | parser.add_argument( 61 | "--draft-model-name", 62 | help="Name of the draft model.", 63 | default="t5-small", 64 | ) 65 | parser.add_argument( 66 | "--seed", 67 | type=int, 68 | default=0, 69 | help="PRNG seed.", 70 | ) 71 | parser.add_argument( 72 | "--max-tokens", 73 | "-m", 74 | type=int, 75 | default=100, 76 | help="Maximum number of tokens to generate.", 77 | ) 78 | parser.add_argument( 79 | "--prompt", 80 | default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.", 81 | help="The prompt processed by the model.", 82 | ) 83 | parser.add_argument( 84 | "--delta", 85 | type=float, 86 | default=0.1, 87 | help="Lenience for accepting the proposal tokens.", 88 | ) 89 | parser.add_argument( 90 | "--regular-decode", 91 | action="store_true", 92 | help="Use regular decoding instead of speculative decoding.", 93 | ) 94 | args = parser.parse_args() 95 | main(args) 96 | -------------------------------------------------------------------------------- /llms/speculative_decoding/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.8.0 2 | transformers 3 | numpy 4 | -------------------------------------------------------------------------------- /lora/.gitignore: -------------------------------------------------------------------------------- 1 | adapters.npz 2 | -------------------------------------------------------------------------------- /lora/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
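# Converts a Hugging Face checkpoint to MLX format: weights are cast to --dtype, or quantized
# when -q is given, then saved together with the tokenizer and config under --mlx-path;
# --upload-name optionally pushes the converted model to the Hugging Face Hub.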
2 | 3 | import argparse 4 | import copy 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import models 9 | import utils 10 | from mlx.utils import tree_flatten 11 | 12 | 13 | def quantize(weights, config, args): 14 | quantized_config = copy.deepcopy(config) 15 | 16 | # Load the model: 17 | model = models.Model(models.ModelArgs.from_dict(config)) 18 | model.load_weights(list(weights.items())) 19 | 20 | # Quantize the model: 21 | nn.quantize( 22 | model, 23 | args.q_group_size, 24 | args.q_bits, 25 | ) 26 | 27 | # Update the config: 28 | quantized_config["quantization"] = { 29 | "group_size": args.q_group_size, 30 | "bits": args.q_bits, 31 | } 32 | quantized_weights = dict(tree_flatten(model.parameters())) 33 | 34 | return quantized_weights, quantized_config 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser( 39 | description="Convert Hugging Face model to MLX format" 40 | ) 41 | parser.add_argument( 42 | "--hf-path", 43 | type=str, 44 | help="Path to the Hugging Face model.", 45 | ) 46 | parser.add_argument( 47 | "--mlx-path", 48 | type=str, 49 | default="mlx_model", 50 | help="Path to save the MLX model.", 51 | ) 52 | parser.add_argument( 53 | "-q", 54 | "--quantize", 55 | help="Generate a quantized model.", 56 | action="store_true", 57 | ) 58 | parser.add_argument( 59 | "--q-group-size", 60 | help="Group size for quantization.", 61 | type=int, 62 | default=64, 63 | ) 64 | parser.add_argument( 65 | "--q-bits", 66 | help="Bits per weight for quantization.", 67 | type=int, 68 | default=4, 69 | ) 70 | parser.add_argument( 71 | "--dtype", 72 | help="Type to save the parameters, ignored if -q is given.", 73 | type=str, 74 | choices=["float16", "bfloat16", "float32"], 75 | default="float16", 76 | ) 77 | parser.add_argument( 78 | "--upload-name", 79 | help="The name of model to upload to Hugging Face MLX Community", 80 | type=str, 81 | default=None, 82 | ) 83 | 84 | args = parser.parse_args() 85 | 86 | print("[INFO] Loading") 87 | weights, config, tokenizer = utils.fetch_from_hub(args.hf_path) 88 | 89 | dtype = mx.float16 if args.quantize else getattr(mx, args.dtype) 90 | weights = {k: v.astype(dtype) for k, v in weights.items()} 91 | if args.quantize: 92 | print("[INFO] Quantizing") 93 | weights, config = quantize(weights, config, args) 94 | 95 | utils.save_model(args.mlx_path, weights, tokenizer, config) 96 | if args.upload_name is not None: 97 | utils.upload_to_hub(args.mlx_path, args.upload_name, args.hf_path) 98 | -------------------------------------------------------------------------------- /lora/data/wikisql.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """ 4 | Code to preprocess the WikiSQL dataset adapted from 5 | https://github.com/salesforce/WikiSQL and 6 | https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb . 7 | """ 8 | 9 | 10 | import json 11 | import os 12 | 13 | 14 | def load(): 15 | """ 16 | Load all three splits of the WikiSQL dataset. 
17 | """ 18 | return (WikiSQL(dn) for dn in ["train", "dev", "test"]) 19 | 20 | 21 | class WikiSQL: 22 | def __init__(self, dataset, save_dir="/tmp"): 23 | valid_sets = ("train", "dev", "test") 24 | if dataset not in valid_sets: 25 | raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}") 26 | data_dir = os.path.join(save_dir, "wikisql") 27 | self._maybe_download(data_dir) 28 | 29 | self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl")) 30 | self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl")) 31 | 32 | def _maybe_download(self, data_dir): 33 | if not os.path.exists(data_dir): 34 | import io 35 | import tarfile 36 | from urllib import request 37 | 38 | url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2" 39 | r = request.urlopen(url) 40 | with tarfile.open(fileobj=io.BytesIO(r.read())) as tf: 41 | tf.extractall(data_dir) 42 | 43 | def _parse_tables(self, tables): 44 | self._tables = {} 45 | with open(tables) as f: 46 | for line in f: 47 | table = json.loads(line) 48 | self._tables[table["id"]] = { 49 | "columns": table["header"], 50 | "types": table["types"], 51 | "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}", 52 | } 53 | 54 | def _parse_queries(self, queries): 55 | self._queries = [] 56 | with open(queries) as f: 57 | for line in f: 58 | query = json.loads(line) 59 | table = self._tables[query["table_id"]] 60 | question = query["question"] 61 | answer = self.query_to_text( 62 | query["sql"], query["table_id"], table["columns"], table["types"] 63 | ) 64 | self._queries.append( 65 | f"{table['desc']}\nQ: {question}\nA: {answer}" 66 | ) 67 | 68 | def query_to_text(self, query, table, columns, types): 69 | aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"] 70 | condition_ops = ["=", ">", "<", "OP"] 71 | column = columns[query["sel"]] 72 | aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else "" 73 | sql = f"SELECT {aggregation}{column} FROM {table}" 74 | 75 | conditions = query["conds"] 76 | if conditions: 77 | cs = [] 78 | for i, o, v in conditions: 79 | column = columns[i] 80 | op = condition_ops[o] 81 | 82 | if types[i] == "text": 83 | value = f"'{v}'" 84 | else: 85 | value = v 86 | cs.append(f"{column} {op} {value}") 87 | 88 | sql += " WHERE " + " AND ".join(cs) 89 | 90 | return sql 91 | 92 | def __getitem__(self, idx): 93 | return self._queries[idx] 94 | 95 | def __len__(self): 96 | return len(self._queries) 97 | 98 | 99 | if __name__ == "__main__": 100 | datanames = ["train", "dev", "test"] 101 | sizes = [56355, 8421, 15878] 102 | for dataname, size in zip(datanames, sizes): 103 | len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size." 104 | 105 | # Write the sets to jsonl 106 | import json 107 | 108 | train, dev, test = load() 109 | datasets = [ 110 | (train, "train", 1000), 111 | (dev, "valid", 100), 112 | (test, "test", 100), 113 | ] 114 | for dataset, name, size in datasets: 115 | with open(f"data/{name}.jsonl", "w") as fid: 116 | for e, t in zip(range(size), dataset): 117 | # Strip the , since the tokenizer adds them 118 | json.dump({"text": t[3:-4]}, fid) 119 | fid.write("\n") 120 | -------------------------------------------------------------------------------- /lora/fuse.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import utils 9 | from mlx.utils import tree_flatten, tree_unflatten 10 | from models import LoRALinear 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") 14 | parser.add_argument( 15 | "--model", 16 | default="mlx_model", 17 | help="The path to the local model directory or Hugging Face repo.", 18 | ) 19 | parser.add_argument( 20 | "--save-path", 21 | default="lora_fused_model", 22 | help="The path to save the fused model.", 23 | ) 24 | parser.add_argument( 25 | "--adapter-file", 26 | type=str, 27 | default="adapters.npz", 28 | help="Path to the trained adapter weights (npz or safetensors).", 29 | ) 30 | parser.add_argument( 31 | "--hf-path", 32 | help=( 33 | "Path to the original Hugging Face model. This is " 34 | "required for upload if --model is a local directory." 35 | ), 36 | type=str, 37 | default=None, 38 | ) 39 | parser.add_argument( 40 | "--upload-name", 41 | help="The name of model to upload to Hugging Face MLX Community.", 42 | type=str, 43 | default=None, 44 | ) 45 | parser.add_argument( 46 | "-d", 47 | "--de-quantize", 48 | help="Generate a de-quantized model.", 49 | action="store_true", 50 | ) 51 | 52 | print("Loading pretrained model") 53 | args = parser.parse_args() 54 | 55 | model, tokenizer, config = utils.load(args.model) 56 | 57 | # Load adapters and get number of LoRA layers 58 | adapters = list(mx.load(args.adapter_file).items()) 59 | lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]]) 60 | 61 | # Freeze all layers other than LORA linears 62 | model.freeze() 63 | for l in model.model.layers[len(model.model.layers) - lora_layers :]: 64 | l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) 65 | l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) 66 | if hasattr(l, "block_sparse_moe"): 67 | l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) 68 | 69 | model.update(tree_unflatten(adapters)) 70 | fused_linears = [ 71 | (n, m.to_linear()) 72 | for n, m in model.named_modules() 73 | if isinstance(m, LoRALinear) 74 | ] 75 | 76 | model.update_modules(tree_unflatten(fused_linears)) 77 | 78 | if args.de_quantize: 79 | de_quantize_layers = [] 80 | for n, m in model.named_modules(): 81 | if isinstance(m, nn.QuantizedLinear): 82 | bias = "bias" in m 83 | weight = m.weight 84 | weight = mx.dequantize( 85 | weight, 86 | m.scales, 87 | m.biases, 88 | m.group_size, 89 | m.bits, 90 | ).astype(mx.float16) 91 | output_dims, input_dims = weight.shape 92 | linear = nn.Linear(input_dims, output_dims, bias=bias) 93 | linear.weight = weight 94 | if bias: 95 | linear.bias = m.bias 96 | de_quantize_layers.append((n, linear)) 97 | 98 | model.update_modules(tree_unflatten(de_quantize_layers)) 99 | 100 | weights = dict(tree_flatten(model.parameters())) 101 | if args.de_quantize: 102 | config.pop("quantization", None) 103 | utils.save_model(args.save_path, weights, tokenizer, config) 104 | 105 | if args.upload_name is not None: 106 | hf_path = args.hf_path 107 | if not Path(args.model).exists(): 108 | # If the model path doesn't exist, assume it's an HF repo 109 | hf_path = args.model 110 | elif hf_path is None: 111 | raise ValueError( 112 | "Must provide original Hugging Face repo to upload local model." 
113 | ) 114 | utils.upload_to_hub(args.save_path, args.upload_name, hf_path) 115 | -------------------------------------------------------------------------------- /lora/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.8.0 2 | transformers 3 | numpy 4 | -------------------------------------------------------------------------------- /mnist/README.md: -------------------------------------------------------------------------------- 1 | # MNIST 2 | 3 | This example shows how to run some simple models on MNIST. 4 | 5 | Install the dependencies: 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | Run the example with: 12 | 13 | ``` 14 | python main.py 15 | ``` 16 | 17 | By default, the example runs on the CPU. To run on the GPU, use: 18 | 19 | ``` 20 | python main.py --gpu 21 | ``` 22 | 23 | For a full list of options run: 24 | 25 | ``` 26 | python main.py --help 27 | ``` 28 | 29 | To run the PyTorch or JAX examples install the respective framework. 30 | -------------------------------------------------------------------------------- /mnist/main.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | import time 5 | from functools import partial 6 | 7 | import mlx.core as mx 8 | import mlx.nn as nn 9 | import mlx.optimizers as optim 10 | import numpy as np 11 | 12 | import mnist 13 | 14 | 15 | class MLP(nn.Module): 16 | """A simple MLP.""" 17 | 18 | def __init__( 19 | self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int 20 | ): 21 | super().__init__() 22 | layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] 23 | self.layers = [ 24 | nn.Linear(idim, odim) 25 | for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) 26 | ] 27 | 28 | def __call__(self, x): 29 | for l in self.layers[:-1]: 30 | x = nn.relu(l(x)) 31 | return self.layers[-1](x) 32 | 33 | 34 | def loss_fn(model, X, y): 35 | return nn.losses.cross_entropy(model(X), y, reduction="mean") 36 | 37 | 38 | def batch_iterate(batch_size, X, y): 39 | perm = mx.array(np.random.permutation(y.size)) 40 | for s in range(0, y.size, batch_size): 41 | ids = perm[s : s + batch_size] 42 | yield X[ids], y[ids] 43 | 44 | 45 | def main(args): 46 | seed = 0 47 | num_layers = 2 48 | hidden_dim = 32 49 | num_classes = 10 50 | batch_size = 256 51 | num_epochs = 10 52 | learning_rate = 1e-1 53 | 54 | np.random.seed(seed) 55 | 56 | # Load the data 57 | train_images, train_labels, test_images, test_labels = map( 58 | mx.array, getattr(mnist, args.dataset)() 59 | ) 60 | 61 | # Load the model 62 | model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) 63 | mx.eval(model.parameters()) 64 | 65 | optimizer = optim.SGD(learning_rate=learning_rate) 66 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn) 67 | 68 | @partial(mx.compile, inputs=model.state, outputs=model.state) 69 | def step(X, y): 70 | loss, grads = loss_and_grad_fn(model, X, y) 71 | optimizer.update(model, grads) 72 | return loss 73 | 74 | @partial(mx.compile, inputs=model.state) 75 | def eval_fn(X, y): 76 | return mx.mean(mx.argmax(model(X), axis=1) == y) 77 | 78 | for e in range(num_epochs): 79 | tic = time.perf_counter() 80 | for X, y in batch_iterate(batch_size, train_images, train_labels): 81 | step(X, y) 82 | mx.eval(model.state) 83 | accuracy = eval_fn(test_images, test_labels) 84 | toc = time.perf_counter() 85 | print( 86 | f"Epoch {e}: Test accuracy {accuracy.item():.3f}," 87 | 
f" Time {toc - tic:.3f} (s)" 88 | ) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.") 93 | parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") 94 | parser.add_argument( 95 | "--dataset", 96 | type=str, 97 | default="mnist", 98 | choices=["mnist", "fashion_mnist"], 99 | help="The dataset to use.", 100 | ) 101 | args = parser.parse_args() 102 | if not args.gpu: 103 | mx.set_default_device(mx.cpu) 104 | main(args) 105 | -------------------------------------------------------------------------------- /mnist/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import gzip 4 | import os 5 | import pickle 6 | from urllib import request 7 | 8 | import numpy as np 9 | 10 | 11 | def mnist( 12 | save_dir="/tmp", 13 | base_url="https://raw.githubusercontent.com/fgnt/mnist/master/", 14 | filename="mnist.pkl", 15 | ): 16 | """ 17 | Load the MNIST dataset in 4 tensors: train images, train labels, 18 | test images, and test labels. 19 | 20 | Checks `save_dir` for already downloaded data otherwise downloads. 21 | 22 | Download code modified from: 23 | https://github.com/hsjeong5/MNIST-for-Numpy 24 | """ 25 | 26 | def download_and_save(save_file): 27 | filename = [ 28 | ["training_images", "train-images-idx3-ubyte.gz"], 29 | ["test_images", "t10k-images-idx3-ubyte.gz"], 30 | ["training_labels", "train-labels-idx1-ubyte.gz"], 31 | ["test_labels", "t10k-labels-idx1-ubyte.gz"], 32 | ] 33 | 34 | mnist = {} 35 | for name in filename: 36 | out_file = os.path.join("/tmp", name[1]) 37 | request.urlretrieve(base_url + name[1], out_file) 38 | for name in filename[:2]: 39 | out_file = os.path.join("/tmp", name[1]) 40 | with gzip.open(out_file, "rb") as f: 41 | mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape( 42 | -1, 28 * 28 43 | ) 44 | for name in filename[-2:]: 45 | out_file = os.path.join("/tmp", name[1]) 46 | with gzip.open(out_file, "rb") as f: 47 | mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8) 48 | with open(save_file, "wb") as f: 49 | pickle.dump(mnist, f) 50 | 51 | save_file = os.path.join(save_dir, filename) 52 | if not os.path.exists(save_file): 53 | download_and_save(save_file) 54 | with open(save_file, "rb") as f: 55 | mnist = pickle.load(f) 56 | 57 | def preproc(x): 58 | return x.astype(np.float32) / 255.0 59 | 60 | mnist["training_images"] = preproc(mnist["training_images"]) 61 | mnist["test_images"] = preproc(mnist["test_images"]) 62 | return ( 63 | mnist["training_images"], 64 | mnist["training_labels"].astype(np.uint32), 65 | mnist["test_images"], 66 | mnist["test_labels"].astype(np.uint32), 67 | ) 68 | 69 | 70 | def fashion_mnist(save_dir="/tmp"): 71 | return mnist( 72 | save_dir, 73 | base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/", 74 | filename="fashion_mnist.pkl", 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | train_x, train_y, test_x, test_y = mnist() 80 | assert train_x.shape == (60000, 28 * 28), "Wrong training set size" 81 | assert train_y.shape == (60000,), "Wrong training set size" 82 | assert test_x.shape == (10000, 28 * 28), "Wrong test set size" 83 | assert test_y.shape == (10000,), "Wrong test set size" 84 | -------------------------------------------------------------------------------- /mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.2 2 | numpy 3 | 
-------------------------------------------------------------------------------- /musicgen/README.md: -------------------------------------------------------------------------------- 1 | # MusicGen 2 | 3 | An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate 4 | music from text descriptions. 5 | 6 | ### Setup 7 | 8 | Install the requirements: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ### Example 15 | 16 | An example using the model: 17 | 18 | ```python 19 | from musicgen import MusicGen 20 | from utils import save_audio 21 | 22 | model = MusicGen.from_pretrained("facebook/musicgen-medium") 23 | 24 | audio = model.generate("happy rock") 25 | 26 | save_audio("out.wav", audio, model.sampling_rate) 27 | ``` 28 | 29 | [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and 30 | [code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details. 31 | -------------------------------------------------------------------------------- /musicgen/benchmarks/bench_mx.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import sys 4 | import time 5 | from pathlib import Path 6 | 7 | import mlx.core as mx 8 | 9 | cur_path = Path(__file__).parents[1].resolve() 10 | sys.path.append(str(cur_path)) 11 | 12 | from musicgen import MusicGen 13 | 14 | text = "folk ballad" 15 | model = MusicGen.from_pretrained("facebook/musicgen-medium") 16 | 17 | max_steps = 100 18 | 19 | audio = model.generate(text, max_steps=10) 20 | mx.eval(audio) 21 | 22 | tic = time.time() 23 | audio = model.generate(text, max_steps=max_steps) 24 | mx.eval(audio) 25 | toc = time.time() 26 | 27 | ms = 1000 * (toc - tic) / max_steps 28 | print(f"Time (ms) per step: {ms:.3f}") 29 | -------------------------------------------------------------------------------- /musicgen/benchmarks/bench_pt.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import time 4 | 5 | import torch 6 | from transformers import AutoProcessor, MusicgenForConditionalGeneration 7 | 8 | model_name = "facebook/musicgen-medium" 9 | processor = AutoProcessor.from_pretrained(model_name) 10 | model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps") 11 | 12 | inputs = processor( 13 | text=["folk ballad"], 14 | padding=True, 15 | return_tensors="pt", 16 | ) 17 | inputs["input_ids"] = inputs["input_ids"].to("mps") 18 | inputs["attention_mask"] = inputs["attention_mask"].to("mps") 19 | 20 | # warmup 21 | audio_values = model.generate(**inputs, max_new_tokens=10) 22 | torch.mps.synchronize() 23 | 24 | max_steps = 100 25 | tic = time.time() 26 | audio_values = model.generate(**inputs, max_new_tokens=max_steps) 27 | torch.mps.synchronize() 28 | toc = time.time() 29 | 30 | ms = 1000 * (toc - tic) / max_steps 31 | print(f"Time (ms) per step: {ms:.3f}") 32 | -------------------------------------------------------------------------------- /musicgen/encodec.py: -------------------------------------------------------------------------------- 1 | ../encodec/encodec.py -------------------------------------------------------------------------------- /musicgen/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | import argparse 4 | 5 | from utils import save_audio 6 | 7 | from musicgen import MusicGen 8 | 9 | 10 | def main(text: str, output_path: str, model_name: str, max_steps: int): 11 | model = MusicGen.from_pretrained(model_name) 12 | audio = model.generate(text, max_steps=max_steps) 13 | save_audio(output_path, audio, model.sampling_rate) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model", required=False, default="facebook/musicgen-medium") 19 | parser.add_argument("--text", required=False, default="happy rock") 20 | parser.add_argument("--output-path", required=False, default="0.wav") 21 | parser.add_argument("--max-steps", required=False, default=500, type=int) 22 | args = parser.parse_args() 23 | main(args.text, args.output_path, args.model, args.max_steps) 24 | -------------------------------------------------------------------------------- /musicgen/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.18 2 | numpy 3 | huggingface_hub 4 | torch 5 | transformers 6 | scipy 7 | -------------------------------------------------------------------------------- /musicgen/t5.py: -------------------------------------------------------------------------------- 1 | ../t5/t5.py -------------------------------------------------------------------------------- /musicgen/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 2 | 3 | import mlx.core as mx 4 | import numpy as np 5 | 6 | 7 | def save_audio(file: str, audio: mx.array, sampling_rate: int): 8 | """ 9 | Save audio to a wave (.wav) file. 10 | """ 11 | from scipy.io.wavfile import write 12 | 13 | audio = mx.clip(audio, -1, 1) 14 | audio = (audio * 32767).astype(mx.int16) 15 | write(file, sampling_rate, np.array(audio)) 16 | -------------------------------------------------------------------------------- /normalizing_flow/README.md: -------------------------------------------------------------------------------- 1 | # Normalizing Flow 2 | 3 | An example of a normalizing flow for density estimation and sampling 4 | implemented in MLX. This example implements the real NVP (non-volume 5 | preserving) model.[^1] 6 | 7 | ## Basic usage 8 | 9 | ```python 10 | import mlx.core as mx 11 | from flows import RealNVP 12 | 13 | model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4) 14 | 15 | x = mx.random.normal(shape=(32, 4)) 16 | 17 | # Evaluate log-density 18 | log_prob = model.log_prob(x=x) 19 | 20 | # Draw samples 21 | x_samples = model.sample(sample_shape=(32, 4)) 22 | ``` 23 | 24 | ## Running the example 25 | 26 | Install the dependencies: 27 | 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | The example can be run with: 33 | ``` 34 | python main.py [--cpu] 35 | ``` 36 | 37 | This trains the normalizing flow on the two moons dataset and plots the result 38 | in `samples.png`. The optional `--cpu` flag can be used to run the example on 39 | the CPU, otherwise it will use the GPU by default. 40 | 41 | For all available options, run: 42 | 43 | ``` 44 | python main.py --help 45 | ``` 46 | 47 | ## Results 48 | 49 | ![Samples](./samples.png) 50 | 51 | [^1]: This example is from [Density estimation using Real NVP]( 52 | https://arxiv.org/abs/1605.08803), Dinh et al. 
(2016) 53 | -------------------------------------------------------------------------------- /normalizing_flow/bijectors.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from typing import Tuple 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | 8 | 9 | class Bijector: 10 | def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]: 11 | raise NotImplementedError 12 | 13 | def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]: 14 | raise NotImplementedError 15 | 16 | 17 | class AffineBijector(Bijector): 18 | def __init__(self, shift_and_log_scale: mx.array): 19 | self.shift_and_log_scale = shift_and_log_scale 20 | 21 | def forward_and_log_det(self, x: mx.array): 22 | shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1) 23 | y = x * mx.exp(log_scale) + shift 24 | log_det = log_scale 25 | return y, log_det 26 | 27 | def inverse_and_log_det(self, y: mx.array): 28 | shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1) 29 | x = (y - shift) * mx.exp(-log_scale) 30 | log_det = -log_scale 31 | return x, log_det 32 | 33 | 34 | class MaskedCoupling(Bijector): 35 | def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector): 36 | """Coupling layer with masking and conditioner.""" 37 | self.mask = mask 38 | self.conditioner = conditioner 39 | self.bijector = bijector 40 | 41 | def apply_mask(self, x: mx.array, func: callable): 42 | """Transforms masked indices of `x` conditioned on unmasked indices using `func`.""" 43 | x_masked = mx.where(self.mask, 0.0, x) 44 | bijector_params = self.conditioner(x_masked) 45 | y, log_det = func(bijector_params) 46 | log_det = mx.where(self.mask, log_det, 0.0) 47 | y = mx.where(self.mask, y, x) 48 | return y, mx.sum(log_det, axis=-1) 49 | 50 | def forward_and_log_det(self, x: mx.array): 51 | """Transforms masked indices of `x` conditioned on unmasked indices using bijector.""" 52 | return self.apply_mask( 53 | x, lambda params: self.bijector(params).forward_and_log_det(x) 54 | ) 55 | 56 | def inverse_and_log_det(self, y: mx.array): 57 | """Transforms masked indices of `y` conditioned on unmasked indices using bijector.""" 58 | return self.apply_mask( 59 | y, lambda params: self.bijector(params).inverse_and_log_det(y) 60 | ) 61 | -------------------------------------------------------------------------------- /normalizing_flow/distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
2 | 3 | import math 4 | from typing import Optional, Tuple, Union 5 | 6 | import mlx.core as mx 7 | 8 | 9 | class Normal: 10 | def __init__(self, mu: mx.array, sigma: mx.array): 11 | super().__init__() 12 | self.mu = mu 13 | self.sigma = sigma 14 | 15 | def sample( 16 | self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None 17 | ): 18 | return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu 19 | 20 | def log_prob(self, x: mx.array): 21 | return ( 22 | -0.5 * math.log(2 * math.pi) 23 | - mx.log(self.sigma) 24 | - 0.5 * ((x - self.mu) / self.sigma) ** 2 25 | ) 26 | 27 | def sample_and_log_prob( 28 | self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None 29 | ): 30 | x = self.sample(sample_shape, key=key) 31 | return x, self.log_prob(x) 32 | -------------------------------------------------------------------------------- /normalizing_flow/flows.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from typing import Optional, Tuple, Union 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | from bijectors import AffineBijector, MaskedCoupling 8 | from distributions import Normal 9 | 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int): 13 | super().__init__() 14 | layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out] 15 | self.layers = [ 16 | nn.Linear(idim, odim) 17 | for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) 18 | ] 19 | 20 | def __call__(self, x): 21 | for l in self.layers[:-1]: 22 | x = nn.gelu(l(x)) 23 | return self.layers[-1](x) 24 | 25 | 26 | class RealNVP(nn.Module): 27 | def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int): 28 | super().__init__() 29 | 30 | # Alternating masks 31 | self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)] 32 | self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list] 33 | 34 | self.freeze(keys=["mask_list"]) 35 | 36 | # Conditioning MLP 37 | self.conditioner_list = [ 38 | MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms) 39 | ] 40 | 41 | self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params)) 42 | 43 | def log_prob(self, x: mx.array): 44 | """ 45 | Flow back to the primal Gaussian and compute log-density, 46 | adding the transformation log-determinant along the way. 47 | """ 48 | log_prob = mx.zeros(x.shape[0]) 49 | for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]): 50 | x, ldj = MaskedCoupling( 51 | mask, conditioner, AffineBijector 52 | ).inverse_and_log_det(x) 53 | log_prob += ldj 54 | return log_prob + self.base_dist.log_prob(x).sum(-1) 55 | 56 | def sample( 57 | self, 58 | sample_shape: Union[int, Tuple[int, ...]], 59 | key: Optional[mx.array] = None, 60 | n_transforms: Optional[int] = None, 61 | ): 62 | """ 63 | Sample from the primal Gaussian and flow towards the target distribution. 
64 | """ 65 | x = self.base_dist.sample(sample_shape, key=key) 66 | for mask, conditioner in zip( 67 | self.mask_list[:n_transforms], self.conditioner_list[:n_transforms] 68 | ): 69 | x, _ = MaskedCoupling( 70 | mask, conditioner, AffineBijector 71 | ).forward_and_log_det(x) 72 | return x 73 | 74 | def __call__(self, x: mx.array): 75 | return self.log_prob(x) 76 | -------------------------------------------------------------------------------- /normalizing_flow/main.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from functools import partial 4 | 5 | import matplotlib.pyplot as plt 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import mlx.optimizers as optim 9 | import numpy as np 10 | from flows import RealNVP 11 | from sklearn import datasets, preprocessing 12 | from tqdm import trange 13 | 14 | 15 | def get_moons_dataset(n_samples=100_000, noise=0.06): 16 | """Get two moons dataset with given noise level.""" 17 | x, _ = datasets.make_moons(n_samples=n_samples, noise=noise) 18 | scaler = preprocessing.StandardScaler() 19 | x = scaler.fit_transform(x) 20 | return x 21 | 22 | 23 | def main(args): 24 | x = get_moons_dataset(n_samples=100_000, noise=args.noise) 25 | 26 | model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers) 27 | mx.eval(model.parameters()) 28 | 29 | def loss_fn(model, x): 30 | return -mx.mean(model(x)) 31 | 32 | optimizer = optim.Adam(learning_rate=args.learning_rate) 33 | 34 | state = [model.state, optimizer.state] 35 | 36 | @partial(mx.compile, inputs=state, outputs=state) 37 | def step(x): 38 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn) 39 | loss, grads = loss_and_grad_fn(model, x) 40 | optimizer.update(model, grads) 41 | return loss 42 | 43 | with trange(args.n_steps) as steps: 44 | for it in steps: 45 | idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch) 46 | loss = step(mx.array(x[idx])) 47 | mx.eval(state) 48 | steps.set_postfix(val=loss.item()) 49 | 50 | # Plot samples from trained flow 51 | 52 | fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4)) 53 | cmap = plt.get_cmap("Blues") 54 | bins = 100 55 | 56 | # Sample from intermediate flow-transformed distributions 57 | for n_transforms in range(args.n_transforms + 1): 58 | x_samples = model.sample((100_000, 2), n_transforms=n_transforms) 59 | 60 | axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap) 61 | axs[n_transforms].set_xlim(-2, 2) 62 | axs[n_transforms].set_ylim(-2, 2) 63 | axs[n_transforms].set_title( 64 | f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution" 65 | ) 66 | axs[n_transforms].set_xticklabels([]) 67 | axs[n_transforms].set_yticklabels([]) 68 | 69 | # Plot original data 70 | axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap) 71 | axs[-1].set_xlim(-2, 2) 72 | axs[-1].set_ylim(-2, 2) 73 | axs[-1].set_title("Original data") 74 | axs[-1].set_xticklabels([]) 75 | axs[-1].set_yticklabels([]) 76 | 77 | plt.tight_layout() 78 | plt.savefig("samples.png") 79 | 80 | 81 | if __name__ == "__main__": 82 | import argparse 83 | 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument( 86 | "--n_steps", type=int, default=5_000, help="Number of steps to train" 87 | ) 88 | parser.add_argument("--n_batch", type=int, default=64, help="Batch size") 89 | parser.add_argument( 90 | "--n_transforms", type=int, default=6, help="Number of flow transforms" 91 | ) 92 | parser.add_argument( 93 | "--d_params", 
type=int, default=2, help="Dimensionality of modeled distribution" 94 | ) 95 | parser.add_argument( 96 | "--d_hidden", 97 | type=int, 98 | default=128, 99 | help="Hidden dimensionality of coupling conditioner", 100 | ) 101 | parser.add_argument( 102 | "--n_layers", 103 | type=int, 104 | default=4, 105 | help="Number of layers in coupling conditioner", 106 | ) 107 | parser.add_argument( 108 | "--learning_rate", type=float, default=3e-4, help="Learning rate" 109 | ) 110 | parser.add_argument( 111 | "--noise", type=float, default=0.06, help="Noise level in two moons dataset" 112 | ) 113 | parser.add_argument("--cpu", action="store_true") 114 | 115 | args = parser.parse_args() 116 | 117 | if args.cpu: 118 | mx.set_default_device(mx.cpu) 119 | 120 | main(args) 121 | -------------------------------------------------------------------------------- /normalizing_flow/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.2 2 | numpy 3 | tqdm 4 | scikit-learn 5 | matplotlib 6 | -------------------------------------------------------------------------------- /normalizing_flow/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/normalizing_flow/samples.png -------------------------------------------------------------------------------- /segment_anything/README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything 2 | 3 | An implementation of the Segment Anything Model (SAM) in MLX. See the original 4 | repo by Meta AI for more details.[^1] 5 | 6 | ## Installation 7 | 8 | ```bash 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## Convert 13 | 14 | ```bash 15 | python convert.py --hf-path facebook/sam-vit-base --mlx-path sam-vit-base 16 | ``` 17 | 18 | The `safetensors` weight file and configs are downloaded from Hugging Face, 19 | converted, and saved in the directory specified by `--mlx-path`. 20 | 21 | The model sizes are: 22 | 23 | - `facebook/sam-vit-base` 24 | - `facebook/sam-vit-large` 25 | - `facebook/sam-vit-huge` 26 | 27 | ## Run 28 | 29 | See examples `notebooks/predictor_example.ipynb` and 30 | `notebooks/automatic_mask_generator_example.ipynb` to try the Segment Anything 31 | Model with MLX. 32 | 33 | You can also generate masks from the command line: 34 | 35 | ```bash 36 | python main.py --model <path/to/model> --input <path/to/image> --output <path/to/output> 37 | ``` 38 | 39 | [^1]: The original Segment Anything [GitHub repo](https://github.com/facebookresearch/segment-anything/tree/main).
40 | -------------------------------------------------------------------------------- /segment_anything/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import shutil 4 | from pathlib import Path 5 | from typing import Dict, Union 6 | 7 | import mlx.core as mx 8 | from huggingface_hub import snapshot_download 9 | 10 | 11 | def save_weights(save_path: Union[str, Path], weights: Dict[str, mx.array]) -> None: 12 | """Save model weights into specified directory.""" 13 | if isinstance(save_path, str): 14 | save_path = Path(save_path) 15 | save_path.mkdir(parents=True, exist_ok=True) 16 | 17 | total_size = sum(v.nbytes for v in weights.values()) 18 | index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} 19 | 20 | model_path = save_path / "model.safetensors" 21 | mx.save_safetensors(str(model_path), weights) 22 | 23 | for weight_name in weights.keys(): 24 | index_data["weight_map"][weight_name] = "model.safetensors" 25 | 26 | index_data["weight_map"] = { 27 | k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) 28 | } 29 | 30 | with open(save_path / "model.safetensors.index.json", "w") as f: 31 | json.dump(index_data, f, indent=4) 32 | 33 | 34 | def download(hf_repo): 35 | return Path( 36 | snapshot_download( 37 | repo_id=hf_repo, 38 | allow_patterns=["*.safetensors", "*.json"], 39 | resume_download=True, 40 | ) 41 | ) 42 | 43 | 44 | def convert(model_path): 45 | weight_file = str(model_path / "model.safetensors") 46 | weights = mx.load(weight_file) 47 | 48 | mlx_weights = dict() 49 | for k, v in weights.items(): 50 | if k in { 51 | "vision_encoder.patch_embed.projection.weight", 52 | "vision_encoder.neck.conv1.weight", 53 | "vision_encoder.neck.conv2.weight", 54 | "prompt_encoder.mask_embed.conv1.weight", 55 | "prompt_encoder.mask_embed.conv2.weight", 56 | "prompt_encoder.mask_embed.conv3.weight", 57 | }: 58 | v = v.transpose(0, 2, 3, 1) 59 | if k in { 60 | "mask_decoder.upscale_conv1.weight", 61 | "mask_decoder.upscale_conv2.weight", 62 | }: 63 | v = v.transpose(1, 2, 3, 0) 64 | mlx_weights[k] = v 65 | return mlx_weights 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser(description="Convert Meta SAM weights to MLX") 70 | parser.add_argument( 71 | "--hf-path", 72 | default="facebook/sam-vit-base", 73 | type=str, 74 | help="Path to the Hugging Face model repo.", 75 | ) 76 | parser.add_argument( 77 | "--mlx-path", 78 | type=str, 79 | default="sam-vit-base", 80 | help="Path to save the MLX model.", 81 | ) 82 | args = parser.parse_args() 83 | 84 | model_path = download(args.hf_path) 85 | 86 | mlx_path = Path(args.mlx_path) 87 | mlx_path.mkdir(parents=True, exist_ok=True) 88 | 89 | mlx_weights = convert(model_path) 90 | save_weights(mlx_path, mlx_weights) 91 | shutil.copy(model_path / "config.json", mlx_path / "config.json") 92 | -------------------------------------------------------------------------------- /segment_anything/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/segment_anything/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /segment_anything/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/segment_anything/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /segment_anything/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/segment_anything/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /segment_anything/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | opencv-python 3 | huggingface_hub 4 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | from .automatic_mask_generator import SamAutomaticMaskGenerator 2 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/common.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | 6 | 7 | class MLPBlock(nn.Module): 8 | def __init__( 9 | self, 10 | embedding_dim: int, 11 | mlp_dim: int, 12 | act: Type[nn.Module] = nn.GELU, 13 | ) -> None: 14 | super().__init__() 15 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 16 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 17 | self.act = act() 18 | 19 | def __call__(self, x: mx.array) -> mx.array: 20 | return self.lin2(self.act(self.lin1(x))) 21 | 22 | 23 | class LayerNorm2d(nn.Module): 24 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 25 | super().__init__() 26 | self.weight = mx.ones(num_channels) 27 | self.bias = mx.zeros(num_channels) 28 | self.eps = eps 29 | 30 | def __call__(self, x: mx.array) -> mx.array: 31 | u = x.mean(3, keepdims=True) 32 | s = ((x - u) ** 2).mean(3, keepdims=True) 33 | x = (x - u) / mx.sqrt(s + self.eps) 34 | x = self.weight * x + self.bias 35 | return x 36 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/segment_anything/segment_anything/utils/__init__.py -------------------------------------------------------------------------------- /segment_anything/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Tuple 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | import numpy as np 7 | from PIL import Image 8 | 9 | 10 | class ResizeLongestSide: 11 | """ 12 | Resizes images to the longest side 'target_length', as well as provides 13 | methods for resizing coordinates and boxes. Provides methods for 14 | transforming both numpy array and batched mlx tensors. 15 | """ 16 | 17 | def __init__(self, target_length: int) -> None: 18 | self.target_length = target_length 19 | 20 | def apply_image(self, image: np.ndarray) -> np.ndarray: 21 | """ 22 | Expects a numpy array with shape HxWxC in uint8 format. 
23 | """ 24 | target_size = self.get_preprocess_shape( 25 | image.shape[0], image.shape[1], self.target_length 26 | ) 27 | return np.array( 28 | Image.fromarray(image).resize( 29 | target_size[::-1], resample=Image.Resampling.BILINEAR 30 | ) 31 | ) 32 | 33 | def apply_coords( 34 | self, coords: mx.array, original_size: Tuple[int, ...] 35 | ) -> mx.array: 36 | """ 37 | Expects a mlx tensor with length 2 in the last dimension. Requires the 38 | original image size in (H, W) format. 39 | """ 40 | old_h, old_w = original_size 41 | new_h, new_w = self.get_preprocess_shape( 42 | original_size[0], original_size[1], self.target_length 43 | ) 44 | return coords * mx.array([new_w / old_w, new_h / old_h]) 45 | 46 | def apply_boxes(self, boxes: mx.array, original_size: Tuple[int, ...]) -> mx.array: 47 | """ 48 | Expects a mlx tensor with shape ...x4. Requires the original image 49 | size in (H, W) format. 50 | """ 51 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 52 | return boxes.reshape(-1, 4) 53 | 54 | @staticmethod 55 | def get_preprocess_shape( 56 | oldh: int, oldw: int, long_side_length: int 57 | ) -> Tuple[int, int]: 58 | """ 59 | Compute the output size given input size and target long side length. 60 | """ 61 | scale = long_side_length * 1.0 / max(oldh, oldw) 62 | newh, neww = oldh * scale, oldw * scale 63 | neww = int(neww + 0.5) 64 | newh = int(newh + 0.5) 65 | return (newh, neww) 66 | -------------------------------------------------------------------------------- /speechcommands/README.md: -------------------------------------------------------------------------------- 1 | # Train a Keyword Spotting Transformer on Speech Commands 2 | 3 | An example of training a Keyword Spotting Transformer[^1] on the Speech 4 | Commands dataset[^2] with MLX. All supervised only configurations from the 5 | paper are available. The example also illustrates how to use [MLX 6 | Data](https://github.com/ml-explore/mlx-data) to load and process an audio 7 | dataset. 8 | 9 | ## Pre-requisites 10 | 11 | Install the remaining python requirements: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Running the example 18 | 19 | Run the example with: 20 | 21 | ``` 22 | python main.py 23 | ``` 24 | 25 | By default the example runs on the GPU. To run it on the CPU, use: 26 | 27 | ``` 28 | python main.py --cpu 29 | ``` 30 | 31 | For all available options, run: 32 | 33 | ``` 34 | python main.py --help 35 | ``` 36 | 37 | ## Results 38 | 39 | After training with the `kwt1` architecture for 100 epochs, you 40 | should see the following results: 41 | 42 | ``` 43 | Epoch: 99 | avg. Train loss 0.018 | avg. Train acc 0.996 | Throughput: 662.51 samples/sec 44 | Epoch: 99 | Val acc 0.893 | Throughput: 3091.26 samples/sec 45 | Testing best model from epoch 97 46 | Test acc -> 0.882 47 | ``` 48 | 49 | For the `kwt2` model, you should see: 50 | ``` 51 | Epoch: 99 | avg. Train loss 0.003 | avg. Train acc 1.000 | Throughput: 396.53 samples/sec 52 | Epoch: 99 | Val acc 0.901 | Throughput: 1543.48 samples/sec 53 | Testing best model from epoch 94 54 | Test acc -> 0.893 55 | ``` 56 | 57 | Note that this was run on an M1 Macbook Pro with 16GB RAM. 58 | 59 | At the time of writing, `mlx` doesn't have built-in `cosine` learning rate 60 | schedules, which is used along with the AdamW optimizer in the official 61 | implementation. We intend to update this example once these features are added, 62 | as well as with appropriate data augmentations. 
63 | 64 | [^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html) 65 | [^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details. 66 | -------------------------------------------------------------------------------- /speechcommands/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.2 2 | mlx-data 3 | -------------------------------------------------------------------------------- /stable_diffusion/README.md: -------------------------------------------------------------------------------- 1 | Stable Diffusion 2 | ================ 3 | 4 | Stable Diffusion in MLX. The implementation was ported from Hugging Face's 5 | [diffusers](https://huggingface.co/docs/diffusers/index) and model weights are 6 | downloaded directly from the Hugging Face hub. The implementation currently 7 | supports the following models: 8 | 9 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo) 10 | - [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) 11 | 12 | ![out](generated-mlx.png) 13 | *Image generated using Stable Diffusion in MLX and the prompt 'A big red sign 14 | saying MLX in capital letters.'* 15 | 16 | Installation 17 | ------------ 18 | 19 | The dependencies are minimal, namely: 20 | 21 | - `huggingface-hub` to download the checkpoints. 22 | - `regex` for the tokenization 23 | - `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script 24 | 25 | You can install all of the above with the `requirements.txt` as follows: 26 | 27 | pip install -r requirements.txt 28 | 29 | Usage 30 | ------ 31 | 32 | Although each component in this repository can be used by itself, the fastest 33 | way to get started is by using the `StableDiffusion` class from the `stable_diffusion` 34 | module. 35 | 36 | ```python 37 | import mlx.core as mx 38 | from stable_diffusion import StableDiffusion 39 | 40 | # This will download all the weights from HF hub and load the models in 41 | # memory 42 | sd = StableDiffusion() 43 | 44 | # This creates a python generator that returns the latent produced by the 45 | # reverse diffusion process. 46 | # 47 | # Because MLX is lazily evaluated, iterating over this generator doesn't 48 | # actually perform the computation until mx.eval() is called. 49 | latent_generator = sd.generate_latents( 50 | "A photo of an astronaut riding a horse on Mars." 51 | ) 52 | 53 | # Here we are evaluating each diffusion step but we could also evaluate 54 | # once at the end. 55 | for x_t in latent_generator: 56 | mx.eval(x_t) 57 | 58 | # Now x_t is the last latent from the reverse process aka x_0. We can 59 | # decode it into an image using the stable diffusion VAE. 60 | im = sd.decode(x_t) 61 | ``` 62 | 63 | The above is essentially the implementation of the `txt2image.py` script in the 64 | root of the repository. You can use the script as follows: 65 | 66 | 67 | ```shell 68 | python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2 69 | ``` 70 | 71 | You can select the model using the `--model` argument. Currently supported models 72 | are `sdxl` (default) and `sd`. 73 | 74 | Image 2 Image 75 | ------------- 76 | 77 | There is also the option of generating images based on another image using the 78 | example script `image2image.py`.
To do that, an image is first encoded using the 79 | autoencoder to get its latent representation and then noise is added according 80 | to the forward diffusion process and the `strength` parameter. A `strength` of 81 | 0.0 means no noise and a `strength` of 1.0 means starting from completely 82 | random noise. 83 | 84 | ![image2image](im2im.png) 85 | 86 | *Generations with varying strength using the original image and the prompt 'A lit fireplace'.* 87 | 88 | The command to generate the above images is: 89 | 90 | ```shell 91 | python image2image.py --strength 0.5 original.png 'A lit fireplace' 92 | ``` 93 | 94 | > [!Note] 95 | > `image2image.py` will automatically downsample your input image to guarantee 96 | > that its dimensions are divisible by 64. If you want full control of this 97 | > process, resize your image prior to using the script. 98 | 99 | Memory constrained devices 100 | -------------------------- 101 | 102 | The `txt2image.py` script loads the model in float16 by default, which significantly 103 | reduces the memory required for image generation. However, since the 104 | Stable Diffusion XL UNet alone has 2.6B parameters, quantization is practically 105 | necessary to use it on devices with 8GB of RAM. 106 | 107 | The `txt2image.py` script supports quantization using the `-q` or `--quantize` 108 | command line arguments. When quantization is used, the script quantizes the 109 | text encoder models to 4 bits and the unet to 8 bits. This allows generating 110 | images on an 8GB Mac Mini with no swapping. 111 | 112 | ``` 113 | python txt2image.py --n_images 4 -q -v --output still-life.png "A painting of a vase on a wooden table, dark background, still life." 114 | ``` 115 | 116 | ![painting](still-life.png) 117 | *Image generated using Stable Diffusion XL turbo in MLX with the above command on an 8GB M1 Mac mini* 118 | -------------------------------------------------------------------------------- /stable_diffusion/generated-mlx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/stable_diffusion/generated-mlx.png -------------------------------------------------------------------------------- /stable_diffusion/im2im.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/stable_diffusion/im2im.png -------------------------------------------------------------------------------- /stable_diffusion/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.11 2 | huggingface-hub 3 | regex 4 | numpy 5 | tqdm 6 | Pillow 7 | -------------------------------------------------------------------------------- /stable_diffusion/stable_diffusion/clip.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc.
2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | 9 | from .config import CLIPTextModelConfig 10 | 11 | _ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu} 12 | 13 | 14 | @dataclass 15 | class CLIPOutput: 16 | # The last_hidden_state indexed at the EOS token and possibly projected if 17 | # the model has a projection layer 18 | pooled_output: Optional[mx.array] = None 19 | 20 | # The full sequence output of the transformer after the final layernorm 21 | last_hidden_state: Optional[mx.array] = None 22 | 23 | # A list of hidden states corresponding to the outputs of the transformer layers 24 | hidden_states: Optional[List[mx.array]] = None 25 | 26 | 27 | class CLIPEncoderLayer(nn.Module): 28 | """The transformer encoder layer from CLIP.""" 29 | 30 | def __init__(self, model_dims: int, num_heads: int, activation: str): 31 | super().__init__() 32 | 33 | self.layer_norm1 = nn.LayerNorm(model_dims) 34 | self.layer_norm2 = nn.LayerNorm(model_dims) 35 | 36 | self.attention = nn.MultiHeadAttention(model_dims, num_heads) 37 | # Add biases to the attention projections to match CLIP 38 | self.attention.query_proj.bias = mx.zeros(model_dims) 39 | self.attention.key_proj.bias = mx.zeros(model_dims) 40 | self.attention.value_proj.bias = mx.zeros(model_dims) 41 | self.attention.out_proj.bias = mx.zeros(model_dims) 42 | 43 | self.linear1 = nn.Linear(model_dims, 4 * model_dims) 44 | self.linear2 = nn.Linear(4 * model_dims, model_dims) 45 | 46 | self.act = _ACTIVATIONS[activation] 47 | 48 | def __call__(self, x, attn_mask=None): 49 | y = self.layer_norm1(x) 50 | y = self.attention(y, y, y, attn_mask) 51 | x = y + x 52 | 53 | y = self.layer_norm2(x) 54 | y = self.linear1(y) 55 | y = self.act(y) 56 | y = self.linear2(y) 57 | x = y + x 58 | 59 | return x 60 | 61 | 62 | class CLIPTextModel(nn.Module): 63 | """Implements the text encoder transformer from CLIP.""" 64 | 65 | def __init__(self, config: CLIPTextModelConfig): 66 | super().__init__() 67 | 68 | self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims) 69 | self.position_embedding = nn.Embedding(config.max_length, config.model_dims) 70 | self.layers = [ 71 | CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act) 72 | for i in range(config.num_layers) 73 | ] 74 | self.final_layer_norm = nn.LayerNorm(config.model_dims) 75 | 76 | if config.projection_dim is not None: 77 | self.text_projection = nn.Linear( 78 | config.model_dims, config.projection_dim, bias=False 79 | ) 80 | 81 | def _get_mask(self, N, dtype): 82 | indices = mx.arange(N) 83 | mask = indices[:, None] < indices[None] 84 | mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9) 85 | return mask 86 | 87 | def __call__(self, x): 88 | # Extract some shapes 89 | B, N = x.shape 90 | eos_tokens = x.argmax(-1) 91 | 92 | # Compute the embeddings 93 | x = self.token_embedding(x) 94 | x = x + self.position_embedding.weight[:N] 95 | 96 | # Compute the features from the transformer 97 | mask = self._get_mask(N, x.dtype) 98 | hidden_states = [] 99 | for l in self.layers: 100 | x = l(x, mask) 101 | hidden_states.append(x) 102 | 103 | # Apply the final layernorm and return 104 | x = self.final_layer_norm(x) 105 | last_hidden_state = x 106 | 107 | # Select the EOS token 108 | pooled_output = x[mx.arange(len(x)), eos_tokens] 109 | if "text_projection" in self: 110 | pooled_output = self.text_projection(pooled_output) 111 | 112 | return CLIPOutput( 113 | 
pooled_output=pooled_output, 114 | last_hidden_state=last_hidden_state, 115 | hidden_states=hidden_states, 116 | ) 117 | -------------------------------------------------------------------------------- /stable_diffusion/stable_diffusion/config.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional, Tuple 5 | 6 | 7 | @dataclass 8 | class AutoencoderConfig: 9 | in_channels: int = 3 10 | out_channels: int = 3 11 | latent_channels_out: int = 8 12 | latent_channels_in: int = 4 13 | block_out_channels: Tuple[int] = (128, 256, 512, 512) 14 | layers_per_block: int = 2 15 | norm_num_groups: int = 32 16 | scaling_factor: float = 0.18215 17 | 18 | 19 | @dataclass 20 | class CLIPTextModelConfig: 21 | num_layers: int = 23 22 | model_dims: int = 1024 23 | num_heads: int = 16 24 | max_length: int = 77 25 | vocab_size: int = 49408 26 | projection_dim: Optional[int] = None 27 | hidden_act: str = "quick_gelu" 28 | 29 | 30 | @dataclass 31 | class UNetConfig: 32 | in_channels: int = 4 33 | out_channels: int = 4 34 | conv_in_kernel: int = 3 35 | conv_out_kernel: int = 3 36 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280) 37 | layers_per_block: Tuple[int] = (2, 2, 2, 2) 38 | mid_block_layers: int = 2 39 | transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1) 40 | num_attention_heads: Tuple[int] = (5, 10, 20, 20) 41 | cross_attention_dim: Tuple[int] = (1024,) * 4 42 | norm_num_groups: int = 32 43 | down_block_types: Tuple[str] = ( 44 | "CrossAttnDownBlock2D", 45 | "CrossAttnDownBlock2D", 46 | "CrossAttnDownBlock2D", 47 | "DownBlock2D", 48 | ) 49 | up_block_types: Tuple[str] = ( 50 | "UpBlock2D", 51 | "CrossAttnUpBlock2D", 52 | "CrossAttnUpBlock2D", 53 | "CrossAttnUpBlock2D", 54 | ) 55 | addition_embed_type: Optional[str] = None 56 | addition_time_embed_dim: Optional[int] = None 57 | projection_class_embeddings_input_dim: Optional[int] = None 58 | 59 | 60 | @dataclass 61 | class DiffusionConfig: 62 | beta_schedule: str = "scaled_linear" 63 | beta_start: float = 0.00085 64 | beta_end: float = 0.012 65 | num_train_steps: int = 1000 66 | -------------------------------------------------------------------------------- /stable_diffusion/stable_diffusion/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import mlx.core as mx 4 | 5 | from .config import DiffusionConfig 6 | 7 | 8 | def _linspace(a, b, num): 9 | x = mx.arange(0, num) / (num - 1) 10 | return (b - a) * x + a 11 | 12 | 13 | def _interp(y, x_new): 14 | """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new.""" 15 | x_low = x_new.astype(mx.int32) 16 | x_high = mx.minimum(x_low + 1, len(y) - 1) 17 | 18 | y_low = y[x_low] 19 | y_high = y[x_high] 20 | delta_x = x_new - x_low 21 | y_new = y_low * (1 - delta_x) + delta_x * y_high 22 | 23 | return y_new 24 | 25 | 26 | class SimpleEulerSampler: 27 | """A simple Euler integrator that can be used to sample from our diffusion models. 28 | 29 | The method ``step()`` performs one Euler step from x_t to x_t_prev. 
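    The latents handled by this class are kept scaled by 1/sqrt(sigma^2 + 1);
    ``step()`` first undoes that scaling, applies the plain Euler update
    x <- x + eps_pred * (sigma(t_prev) - sigma(t)), and then rescales with the
    new sigma.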
30 | """ 31 | 32 | def __init__(self, config: DiffusionConfig): 33 | # Compute the noise schedule 34 | if config.beta_schedule == "linear": 35 | betas = _linspace( 36 | config.beta_start, config.beta_end, config.num_train_steps 37 | ) 38 | elif config.beta_schedule == "scaled_linear": 39 | betas = _linspace( 40 | config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps 41 | ).square() 42 | else: 43 | raise NotImplementedError(f"{config.beta_schedule} is not implemented.") 44 | 45 | alphas = 1 - betas 46 | alphas_cumprod = mx.cumprod(alphas) 47 | 48 | self._sigmas = mx.concatenate( 49 | [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()] 50 | ) 51 | 52 | @property 53 | def max_time(self): 54 | return len(self._sigmas) - 1 55 | 56 | def sample_prior(self, shape, dtype=mx.float32, key=None): 57 | noise = mx.random.normal(shape, key=key) 58 | return ( 59 | noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt() 60 | ).astype(dtype) 61 | 62 | def add_noise(self, x, t, key=None): 63 | noise = mx.random.normal(x.shape, key=key) 64 | s = self.sigmas(t) 65 | return (x + noise * s) * (s.square() + 1).rsqrt() 66 | 67 | def sigmas(self, t): 68 | return _interp(self._sigmas, t) 69 | 70 | def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32): 71 | start_time = start_time or (len(self._sigmas) - 1) 72 | assert 0 < start_time <= (len(self._sigmas) - 1) 73 | steps = _linspace(start_time, 0, num_steps + 1).astype(dtype) 74 | return list(zip(steps, steps[1:])) 75 | 76 | def step(self, eps_pred, x_t, t, t_prev): 77 | sigma = self.sigmas(t).astype(eps_pred.dtype) 78 | sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype) 79 | 80 | dt = sigma_prev - sigma 81 | x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt 82 | 83 | x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt() 84 | 85 | return x_t_prev 86 | 87 | 88 | class SimpleEulerAncestralSampler(SimpleEulerSampler): 89 | def step(self, eps_pred, x_t, t, t_prev): 90 | sigma = self.sigmas(t).astype(eps_pred.dtype) 91 | sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype) 92 | 93 | sigma2 = sigma.square() 94 | sigma_prev2 = sigma_prev.square() 95 | sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt() 96 | sigma_down = (sigma_prev2 - sigma_up**2).sqrt() 97 | 98 | dt = sigma_down - sigma 99 | x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt 100 | noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype) 101 | x_t_prev = x_t_prev + noise * sigma_up 102 | 103 | x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt() 104 | 105 | return x_t_prev 106 | -------------------------------------------------------------------------------- /stable_diffusion/stable_diffusion/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import regex 4 | 5 | 6 | class Tokenizer: 7 | """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" 8 | 9 | def __init__(self, bpe_ranks, vocab): 10 | self.bpe_ranks = bpe_ranks 11 | self.vocab = vocab 12 | self.pat = regex.compile( 13 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 14 | regex.IGNORECASE, 15 | ) 16 | 17 | self._cache = {self.bos: self.bos, self.eos: self.eos} 18 | 19 | @property 20 | def bos(self): 21 | return "<|startoftext|>" 22 | 23 | @property 24 | def bos_token(self): 25 | return self.vocab[self.bos] 26 | 27 | @property 28 | def eos(self): 29 | return "<|endoftext|>" 30 | 31 | @property 32 | def eos_token(self): 33 | return self.vocab[self.eos] 34 | 35 | def bpe(self, text): 36 | if text in self._cache: 37 | return self._cache[text] 38 | 39 | unigrams = list(text[:-1]) + [text[-1] + "</w>"] 40 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 41 | 42 | if not unique_bigrams: 43 | return unigrams 44 | 45 | # In every iteration try to merge the two most likely bigrams. If none 46 | # was merged we are done. 47 | # 48 | # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py 49 | while unique_bigrams: 50 | bigram = min( 51 | unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) 52 | ) 53 | if bigram not in self.bpe_ranks: 54 | break 55 | 56 | new_unigrams = [] 57 | skip = False 58 | for a, b in zip(unigrams, unigrams[1:]): 59 | if skip: 60 | skip = False 61 | continue 62 | 63 | if (a, b) == bigram: 64 | new_unigrams.append(a + b) 65 | skip = True 66 | 67 | else: 68 | new_unigrams.append(a) 69 | 70 | if not skip: 71 | new_unigrams.append(b) 72 | 73 | unigrams = new_unigrams 74 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 75 | 76 | self._cache[text] = unigrams 77 | 78 | return unigrams 79 | 80 | def tokenize(self, text, prepend_bos=True, append_eos=True): 81 | if isinstance(text, list): 82 | return [self.tokenize(t, prepend_bos, append_eos) for t in text] 83 | 84 | # Lower case cleanup and split according to self.pat. Hugging Face does 85 | # a much more thorough job here but this should suffice for 95% of 86 | # cases. 87 | clean_text = regex.sub(r"\s+", " ", text.lower()) 88 | tokens = regex.findall(self.pat, clean_text) 89 | 90 | # Split the tokens according to the byte-pair merge file 91 | bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] 92 | 93 | # Map to token ids and return 94 | tokens = [self.vocab[t] for t in bpe_tokens] 95 | if prepend_bos: 96 | tokens = [self.bos_token] + tokens 97 | if append_eos: 98 | tokens.append(self.eos_token) 99 | 100 | return tokens 101 | -------------------------------------------------------------------------------- /stable_diffusion/still-life.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/stable_diffusion/still-life.png -------------------------------------------------------------------------------- /stable_diffusion/txt2image.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc.
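# Text-to-image generation script for this example. It supports the SDXL Turbo
# (default) and Stable Diffusion 2.1 base models (see the README in this
# directory); the -q/--quantize flag quantizes the text encoders to 4 bits and
# the UNet to 8 bits for memory-constrained machines.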
2 | 3 | import argparse 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from stable_diffusion import StableDiffusion, StableDiffusionXL 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser( 15 | description="Generate images from a textual prompt using stable diffusion" 16 | ) 17 | parser.add_argument("prompt") 18 | parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl") 19 | parser.add_argument("--n_images", type=int, default=4) 20 | parser.add_argument("--steps", type=int) 21 | parser.add_argument("--cfg", type=float) 22 | parser.add_argument("--negative_prompt", default="") 23 | parser.add_argument("--n_rows", type=int, default=1) 24 | parser.add_argument("--decoding_batch_size", type=int, default=1) 25 | parser.add_argument("--no-float16", dest="float16", action="store_false") 26 | parser.add_argument("--quantize", "-q", action="store_true") 27 | parser.add_argument("--preload-models", action="store_true") 28 | parser.add_argument("--output", default="out.png") 29 | parser.add_argument("--seed", type=int) 30 | parser.add_argument("--verbose", "-v", action="store_true") 31 | args = parser.parse_args() 32 | 33 | # Load the models 34 | if args.model == "sdxl": 35 | sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16) 36 | if args.quantize: 37 | nn.quantize( 38 | sd.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear) 39 | ) 40 | nn.quantize( 41 | sd.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear) 42 | ) 43 | nn.quantize(sd.unet, group_size=32, bits=8) 44 | args.cfg = args.cfg or 0.0 45 | args.steps = args.steps or 2 46 | else: 47 | sd = StableDiffusion( 48 | "stabilityai/stable-diffusion-2-1-base", float16=args.float16 49 | ) 50 | if args.quantize: 51 | nn.quantize( 52 | sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear) 53 | ) 54 | nn.quantize(sd.unet, group_size=32, bits=8) 55 | args.cfg = args.cfg or 7.5 56 | args.steps = args.steps or 50 57 | 58 | # Ensure that models are read in memory if needed 59 | if args.preload_models: 60 | sd.ensure_models_are_loaded() 61 | 62 | # Generate the latent vectors using diffusion 63 | latents = sd.generate_latents( 64 | args.prompt, 65 | n_images=args.n_images, 66 | cfg_weight=args.cfg, 67 | num_steps=args.steps, 68 | seed=args.seed, 69 | negative_text=args.negative_prompt, 70 | ) 71 | for x_t in tqdm(latents, total=args.steps): 72 | mx.eval(x_t) 73 | 74 | # The following is not necessary but it may help in memory 75 | # constrained systems by reusing the memory kept by the unet and the text 76 | # encoders. 
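    # Dropping these references allows the UNet and text encoder weights to be
    # freed, so the VAE decode step below can reuse that memory instead of
    # growing the peak usage.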
77 | if args.model == "sdxl": 78 | del sd.text_encoder_1 79 | del sd.text_encoder_2 80 | else: 81 | del sd.text_encoder 82 | del sd.unet 83 | del sd.sampler 84 | peak_mem_unet = mx.metal.get_peak_memory() / 1024**3 85 | 86 | # Decode them into images 87 | decoded = [] 88 | for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): 89 | decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size])) 90 | mx.eval(decoded[-1]) 91 | peak_mem_overall = mx.metal.get_peak_memory() / 1024**3 92 | 93 | # Arrange them on a grid 94 | x = mx.concatenate(decoded, axis=0) 95 | x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)]) 96 | B, H, W, C = x.shape 97 | x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4) 98 | x = x.reshape(args.n_rows * H, B // args.n_rows * W, C) 99 | x = (x * 255).astype(mx.uint8) 100 | 101 | # Save them to disc 102 | im = Image.fromarray(np.array(x)) 103 | im.save(args.output) 104 | 105 | # Report the peak memory used during generation 106 | if args.verbose: 107 | print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB") 108 | print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") 109 | -------------------------------------------------------------------------------- /t5/.gitignore: -------------------------------------------------------------------------------- 1 | *.npz 2 | -------------------------------------------------------------------------------- /t5/README.md: -------------------------------------------------------------------------------- 1 | # T5 2 | 3 | The T5 models are encoder-decoder models pre-trained on a mixture of 4 | unsupervised and supervised tasks.[^1] These models work well on a variety of 5 | tasks by prepending task-specific prefixes to the input, e.g.: 6 | `translate English to German: …`, `summarize: …`, etc. 7 | 8 | This example also supports the FLAN-T5 model variants.[^2] 9 | 10 | ## Generate 11 | 12 | Generate text with: 13 | 14 | ```sh 15 | python t5.py --model t5-small --prompt "translate English to German: A tasty apple" 16 | ``` 17 | 18 | This should give the output: `Ein leckerer Apfel` 19 | 20 | To see a list of options run: 21 | 22 | ```sh 23 | python t5.py --help 24 | ``` 25 | 26 | The `--model` argument can be any of the following: 27 | 28 | | Model Name | Model Size | 29 | | ---------- | ---------- | 30 | | t5-small | 60 million | 31 | | t5-base | 220 million | 32 | | t5-large | 770 million | 33 | | t5-3b | 3 billion | 34 | | t5-11b | 11 billion | 35 | 36 | The FLAN variants can be specified with `google/flan-t5-small`, 37 | `google/flan-t5-base`, etc. See the [Hugging Face 38 | page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a 39 | complete list of models. 40 | 41 | [^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683) 42 | or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5). 43 | [^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
44 | -------------------------------------------------------------------------------- /t5/hf_t5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel 4 | 5 | 6 | def embed(t5_model: str): 7 | batch = [ 8 | "translate English to German: That is good.", 9 | "This is an example of T5 working on MLX.", 10 | ] 11 | 12 | tokenizer = AutoTokenizer.from_pretrained(t5_model) 13 | torch_model = T5EncoderModel.from_pretrained(t5_model) 14 | torch_tokens = tokenizer(batch, return_tensors="pt", padding=True) 15 | torch_forward = torch_model(**torch_tokens, output_hidden_states=True) 16 | torch_output = torch_forward.last_hidden_state.detach().numpy() 17 | 18 | print("\n TF BERT:") 19 | for input_str, embedding in list(zip(batch, torch_output)): 20 | print("Input:", input_str) 21 | print(embedding) 22 | print() 23 | 24 | 25 | def generate(t5_model: str): 26 | prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast." 27 | tokenizer = AutoTokenizer.from_pretrained(t5_model) 28 | torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model) 29 | torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids 30 | outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512) 31 | print(tokenizer.decode(outputs[0], skip_special_tokens=True)) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser( 36 | description="Run the T5 model using Hugging Face Transformers." 37 | ) 38 | parser.add_argument( 39 | "--encode-only", 40 | action="store_true", 41 | help="Only run the encoder and print the embeddings.", 42 | default=False, 43 | ) 44 | parser.add_argument( 45 | "--model", 46 | default="t5-small", 47 | help="The huggingface name of the T5 model to save.", 48 | ) 49 | args = parser.parse_args() 50 | if args.encode_only: 51 | embed(args.model) 52 | else: 53 | generate(args.model) 54 | -------------------------------------------------------------------------------- /t5/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.8.0 2 | numpy 3 | transformers 4 | -------------------------------------------------------------------------------- /transformer_lm/README.md: -------------------------------------------------------------------------------- 1 | # Transformer LM 2 | 3 | This is an example of a decoder-only Transformer LM. The only dependency is 4 | MLX. 5 | 6 | Run the example on the GPU with: 7 | 8 | ``` 9 | python main.py --gpu 10 | ``` 11 | 12 | By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option. 13 | -------------------------------------------------------------------------------- /transformer_lm/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
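# Corpus loaders for the transformer LM example. Each loader downloads its
# dataset (to /tmp by default), builds a token- or character-to-id vocabulary
# from the training split, and returns (vocab, train, valid, test) with the
# splits encoded as uint32 numpy arrays.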
2 | 3 | import io 4 | import itertools 5 | import os 6 | import zipfile 7 | from urllib import request 8 | 9 | import numpy as np 10 | 11 | 12 | def load_dataset(dataname): 13 | if dataname == "enwik8": 14 | return enwik8() 15 | elif dataname == "ptb": 16 | return ptb() 17 | elif dataname == "wikitext2": 18 | return wikitext(dataset="2") 19 | else: 20 | return wikitext(dataset="103") 21 | 22 | 23 | def _load(save_dir, filenames): 24 | # *NB* First file is expected to be the training set 25 | with open(os.path.join(save_dir, filenames[0]), "r") as fid: 26 | vocab = set(t for l in fid.readlines() for t in l.strip().split(" ")) 27 | eos = "<eos>" 28 | vocab.add(eos) 29 | vocab = {v: i for i, v in enumerate(vocab)} 30 | 31 | def to_array(dataset): 32 | with open(os.path.join(save_dir, dataset), "r") as fid: 33 | lines = (l.strip().split(" ") for l in fid.readlines()) 34 | return np.array( 35 | [vocab[w] for line in lines for w in itertools.chain(line, [eos])], 36 | dtype=np.uint32, 37 | ) 38 | 39 | datasets = [to_array(fn) for fn in filenames] 40 | return vocab, *datasets 41 | 42 | 43 | def wikitext(dataset="2", save_dir="/tmp"): 44 | """ 45 | Load the WikiText-* language modeling dataset: 46 | https://paperswithcode.com/dataset/wikitext-2 47 | https://paperswithcode.com/dataset/wikitext-103 48 | 49 | """ 50 | if dataset not in ("2", "103"): 51 | raise ValueError(f'Dataset must be either "2" or "103", got {dataset}') 52 | 53 | filenames = ["wiki.train.tokens", "wiki.valid.tokens", "wiki.test.tokens"] 54 | dataname = f"wikitext-{dataset}" 55 | data_dir = os.path.join(save_dir, dataname) 56 | if not os.path.exists(data_dir): 57 | base_url = "https://s3.amazonaws.com/research.metamind.io/wikitext/" 58 | zip_file_url = base_url + dataname + "-v1.zip" 59 | r = request.urlopen(zip_file_url) 60 | with zipfile.ZipFile(io.BytesIO(r.read())) as zf: 61 | zf.extractall(save_dir) 62 | 63 | return _load(data_dir, filenames) 64 | 65 | 66 | def ptb(save_dir="/tmp"): 67 | """ 68 | Load the PTB language modeling dataset: 69 | https://paperswithcode.com/dataset/penn-treebank 70 | """ 71 | filenames = [ 72 | "ptb.train.txt", 73 | "ptb.valid.txt", 74 | "ptb.test.txt", 75 | ] 76 | 77 | def download_and_save(save_dir): 78 | base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/" 79 | for name in filenames: 80 | out_file = os.path.join(save_dir, name) 81 | if not os.path.exists(out_file): 82 | request.urlretrieve(base_url + name, out_file) 83 | 84 | save_dir = os.path.join(save_dir, "ptb") 85 | if not os.path.exists(save_dir): 86 | os.mkdir(save_dir) 87 | download_and_save(save_dir) 88 | 89 | return _load(save_dir, filenames) 90 | 91 | 92 | def enwik8(save_dir="/tmp"): 93 | """ 94 | Load the enwik8 language modeling dataset: 95 | https://mattmahoney.net/dc/textdata.html 96 | """ 97 | out_file = os.path.join(save_dir, "enwik8.zip") 98 | if not os.path.exists(out_file): 99 | request.urlretrieve("http://mattmahoney.net/dc/enwik8.zip", out_file) 100 | 101 | with zipfile.ZipFile(out_file) as zf: 102 | data = zf.read("enwik8") 103 | 104 | num_test_bytes = 5000000 # 90 + 5 + 5 split 105 | 106 | train_data = data[: -2 * num_test_bytes] 107 | valid_data = data[-2 * num_test_bytes : -num_test_bytes] 108 | test_data = data[-num_test_bytes:] 109 | 110 | vocab = set(c for c in train_data) 111 | vocab = {c: i for i, c in enumerate(vocab)} 112 | 113 | def to_array(dataset): 114 | return np.array([vocab[c] for c in dataset], dtype=np.uint32) 115 | 116 | return vocab, to_array(train_data), to_array(valid_data),
to_array(test_data) 117 | 118 | 119 | if __name__ == "__main__": 120 | vocab, train, val, test = enwik8() 121 | assert len(vocab) == 205, "enwik8: Wrong vocab size" 122 | 123 | vocab, train, val, test = ptb() 124 | assert len(vocab) == 10000, "PTB: Wrong vocab size" 125 | 126 | vocab, train, val, test = wikitext() 127 | assert len(vocab) == 33279, "WikiText: Wrong vocab size" 128 | -------------------------------------------------------------------------------- /transformer_lm/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.2 2 | -------------------------------------------------------------------------------- /whisper/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include mlx_whisper/requirements.txt 2 | include mlx_whisper/assets/mel_filters.npz 3 | include mlx_whisper/assets/multilingual.tiktoken 4 | include mlx_whisper/assets/gpt2.tiktoken 5 | -------------------------------------------------------------------------------- /whisper/README.md: -------------------------------------------------------------------------------- 1 | # Whisper 2 | 3 | Speech recognition with Whisper in MLX. Whisper is a set of open source speech 4 | recognition models from OpenAI, ranging from 39 million to 1.5 billion 5 | parameters.[^1] 6 | 7 | ### Setup 8 | 9 | Install [`ffmpeg`](https://ffmpeg.org/): 10 | 11 | ``` 12 | # on macOS using Homebrew (https://brew.sh/) 13 | brew install ffmpeg 14 | ``` 15 | 16 | Install the `mlx-whisper` package with: 17 | 18 | ``` 19 | pip install mlx-whisper 20 | ``` 21 | 22 | ### Run 23 | 24 | #### CLI 25 | 26 | At its simplest: 27 | 28 | ```sh 29 | mlx_whisper audio_file.mp3 30 | ``` 31 | 32 | This will make a text file `audio_file.txt` with the results. 33 | 34 | Use `-f` to specify the output format and `--model` to specify the model. There 35 | are many other supported command line options. To see them all, run 36 | `mlx_whisper -h`. 37 | 38 | You can also pipe the audio content of other programs via stdin: 39 | 40 | ```sh 41 | some-process | mlx_whisper - 42 | ``` 43 | 44 | The default output file name will be `content.*`. You can specify the name with 45 | the `--output-name` flag. 46 | 47 | #### API 48 | 49 | Transcribe audio with: 50 | 51 | ```python 52 | import mlx_whisper 53 | 54 | text = mlx_whisper.transcribe(speech_file)["text"] 55 | ``` 56 | 57 | The default model is "mlx-community/whisper-tiny". Choose the model by 58 | setting `path_or_hf_repo`. For example: 59 | 60 | ```python 61 | result = mlx_whisper.transcribe(speech_file, path_or_hf_repo="models/large") 62 | ``` 63 | 64 | This will load the model contained in `models/large`. The `path_or_hf_repo` can 65 | also point to an MLX-style Whisper model on the Hugging Face Hub. In this case, 66 | the model will be automatically downloaded. A [collection of pre-converted 67 | Whisper 68 | models](https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc) 69 | are in the Hugging Face MLX Community. 70 | 71 | The `transcribe` function also supports word-level timestamps. 
You can generate 72 | these with: 73 | 74 | ```python 75 | output = mlx_whisper.transcribe(speech_file, word_timestamps=True) 76 | print(output["segments"][0]["words"]) 77 | ``` 78 | 79 | To see more transcription options use: 80 | 81 | ``` 82 | >>> help(mlx_whisper.transcribe) 83 | ``` 84 | 85 | ### Converting models 86 | 87 | > [!TIP] 88 | > Skip the conversion step by using pre-converted checkpoints from the Hugging 89 | > Face Hub. There are a few available in the [MLX 90 | > Community](https://huggingface.co/mlx-community) organization. 91 | 92 | To convert a model, first clone the MLX Examples repo: 93 | 94 | ``` 95 | git clone https://github.com/ml-explore/mlx-examples.git 96 | ``` 97 | 98 | Then run `convert.py` from `mlx-examples/whisper`. For example, to convert the 99 | `tiny` model use: 100 | 101 | ``` 102 | python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny 103 | ``` 104 | 105 | Note you can also convert a local PyTorch checkpoint which is in the original 106 | OpenAI format. 107 | 108 | To generate a 4-bit quantized model, use `-q`. For a full list of options: 109 | 110 | ``` 111 | python convert.py --help 112 | ``` 113 | 114 | By default, the conversion script will make the directory `mlx_models` 115 | and save the converted `weights.npz` and `config.json` there. 116 | 117 | Each time it is run, `convert.py` will overwrite any model in the provided 118 | path. To save different models, make sure to set `--mlx-path` to a unique 119 | directory for each converted model. For example: 120 | 121 | ```bash 122 | model="tiny" 123 | python convert.py --torch-name-or-path ${model} --mlx-path mlx_models/${model}_fp16 124 | python convert.py --torch-name-or-path ${model} --dtype float32 --mlx-path mlx_models/${model}_fp32 125 | python convert.py --torch-name-or-path ${model} -q --q_bits 4 --mlx-path mlx_models/${model}_quantized_4bits 126 | ``` 127 | 128 | [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details. 129 | -------------------------------------------------------------------------------- /whisper/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | import argparse 3 | import os 4 | import time 5 | 6 | import mlx.core as mx 7 | from mlx_whisper import audio, decoding, load_models, transcribe 8 | 9 | audio_file = "mlx_whisper/assets/ls_test.flac" 10 | 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description="Benchmark script.") 14 | parser.add_argument( 15 | "--mlx-dir", 16 | type=str, 17 | default="mlx_models", 18 | help="The folder of MLX models", 19 | ) 20 | parser.add_argument( 21 | "--all", 22 | action="store_true", 23 | help="Use all available models, i.e. 
tiny,small,medium,large-v3", 24 | ) 25 | parser.add_argument( 26 | "-m", 27 | "--models", 28 | type=str, 29 | help="Specify models as a comma-separated list (e.g., tiny,small,medium)", 30 | ) 31 | return parser.parse_args() 32 | 33 | 34 | def timer(fn, *args): 35 | for _ in range(5): 36 | fn(*args) 37 | 38 | num_its = 10 39 | 40 | tic = time.perf_counter() 41 | for _ in range(num_its): 42 | fn(*args) 43 | toc = time.perf_counter() 44 | return (toc - tic) / num_its 45 | 46 | 47 | def feats(n_mels: int = 80): 48 | data = audio.load_audio(audio_file) 49 | data = audio.pad_or_trim(data) 50 | mels = audio.log_mel_spectrogram(data, n_mels) 51 | mx.eval(mels) 52 | return mels 53 | 54 | 55 | def model_forward(model, mels, tokens): 56 | logits = model(mels, tokens) 57 | mx.eval(logits) 58 | return logits 59 | 60 | 61 | def decode(model, mels): 62 | return decoding.decode(model, mels) 63 | 64 | 65 | def everything(model_path): 66 | return transcribe(audio_file, path_or_hf_repo=model_path) 67 | 68 | 69 | if __name__ == "__main__": 70 | args = parse_arguments() 71 | if args.all: 72 | models = ["tiny", "small", "medium", "large-v3"] 73 | elif args.models: 74 | models = args.models.split(",") 75 | else: 76 | models = ["tiny"] 77 | 78 | print("Selected models:", models) 79 | 80 | feat_time = timer(feats) 81 | print(f"\nFeature time {feat_time:.3f}") 82 | 83 | for model_name in models: 84 | model_path = f"mlx-community/whisper-{model_name}-mlx" 85 | print(f"\nModel: {model_name.upper()}") 86 | tokens = mx.array( 87 | [ 88 | 50364, 89 | 1396, 90 | 264, 91 | 665, 92 | 5133, 93 | 23109, 94 | 25462, 95 | 264, 96 | 6582, 97 | 293, 98 | 750, 99 | 632, 100 | 42841, 101 | 292, 102 | 370, 103 | 938, 104 | 294, 105 | 4054, 106 | 293, 107 | 12653, 108 | 356, 109 | 50620, 110 | 50620, 111 | 23563, 112 | 322, 113 | 3312, 114 | 13, 115 | 50680, 116 | ], 117 | mx.int32, 118 | )[None] 119 | model = load_models.load_model(path_or_hf_repo=model_path, dtype=mx.float16) 120 | mels = feats(model.dims.n_mels)[None].astype(mx.float16) 121 | model_forward_time = timer(model_forward, model, mels, tokens) 122 | print(f"Model forward time {model_forward_time:.3f}") 123 | decode_time = timer(decode, model, mels) 124 | print(f"Decode time {decode_time:.3f}") 125 | everything_time = timer(everything, model_path) 126 | print(f"Everything time {everything_time:.3f}") 127 | print(f"\n{'-----' * 10}\n") 128 | -------------------------------------------------------------------------------- /whisper/mlx_whisper/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from . import audio, decoding, load_models 4 | from ._version import __version__ 5 | from .transcribe import transcribe 6 | -------------------------------------------------------------------------------- /whisper/mlx_whisper/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
2 | 3 | __version__ = "0.4.1" 4 | -------------------------------------------------------------------------------- /whisper/mlx_whisper/assets/download_alice.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | audio_file=$HOME/.cache/whisper/alice.mp3 4 | echo $audio_file 5 | zipf=alice_in_wonderland_librivox_64kb_mp3.zip 6 | url=https://www.archive.org/download/alice_in_wonderland_librivox/ 7 | curl -LO $url/$zipf 8 | unzip $zipf 9 | mv wonderland_ch_02_64kb.mp3 $audio_file 10 | rm wonderland_* $zipf 11 | -------------------------------------------------------------------------------- /whisper/mlx_whisper/assets/ls_test.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/whisper/mlx_whisper/assets/ls_test.flac -------------------------------------------------------------------------------- /whisper/mlx_whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-examples/4c9f9f9be798e6cf04fd0f74395a3b4420077aad/whisper/mlx_whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /whisper/mlx_whisper/load_models.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import json 4 | from pathlib import Path 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from huggingface_hub import snapshot_download 9 | from mlx.utils import tree_unflatten 10 | 11 | from . import whisper 12 | 13 | 14 | def load_model( 15 | path_or_hf_repo: str, 16 | dtype: mx.Dtype = mx.float32, 17 | ) -> whisper.Whisper: 18 | model_path = Path(path_or_hf_repo) 19 | if not model_path.exists(): 20 | model_path = Path(snapshot_download(repo_id=path_or_hf_repo)) 21 | 22 | with open(str(model_path / "config.json"), "r") as f: 23 | config = json.loads(f.read()) 24 | config.pop("model_type", None) 25 | quantization = config.pop("quantization", None) 26 | 27 | model_args = whisper.ModelDimensions(**config) 28 | 29 | wf = model_path / "weights.safetensors" 30 | if not wf.exists(): 31 | wf = model_path / "weights.npz" 32 | weights = mx.load(str(wf)) 33 | 34 | model = whisper.Whisper(model_args, dtype) 35 | 36 | if quantization is not None: 37 | class_predicate = ( 38 | lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) 39 | and f"{p}.scales" in weights 40 | ) 41 | nn.quantize(model, **quantization, class_predicate=class_predicate) 42 | 43 | weights = tree_unflatten(list(weights.items())) 44 | model.update(weights) 45 | mx.eval(model.parameters()) 46 | return model 47 | -------------------------------------------------------------------------------- /whisper/mlx_whisper/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.11 2 | numba 3 | numpy 4 | torch 5 | tqdm 6 | more-itertools 7 | tiktoken 8 | huggingface_hub 9 | scipy 10 | -------------------------------------------------------------------------------- /whisper/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
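# Packaging script for the mlx-whisper package. It reads the runtime
# dependencies from mlx_whisper/requirements.txt and registers the
# `mlx_whisper` command line entry point.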
2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | from setuptools import find_namespace_packages, setup 7 | 8 | package_dir = Path(__file__).parent / "mlx_whisper" 9 | 10 | with open(package_dir / "requirements.txt") as fid: 11 | requirements = [l.strip() for l in fid.readlines()] 12 | 13 | sys.path.append(str(package_dir)) 14 | 15 | from _version import __version__ 16 | 17 | setup( 18 | name="mlx-whisper", 19 | version=__version__, 20 | description="OpenAI Whisper on Apple silicon with MLX and the Hugging Face Hub", 21 | long_description=open("README.md", encoding="utf-8").read(), 22 | long_description_content_type="text/markdown", 23 | readme="README.md", 24 | author_email="mlx@group.apple.com", 25 | author="MLX Contributors", 26 | url="https://github.com/ml-explore/mlx-examples", 27 | license="MIT", 28 | install_requires=requirements, 29 | packages=find_namespace_packages(), 30 | include_package_data=True, 31 | python_requires=">=3.8", 32 | entry_points={ 33 | "console_scripts": [ 34 | "mlx_whisper = mlx_whisper.cli:main", 35 | ] 36 | }, 37 | ) 38 | --------------------------------------------------------------------------------