├── .github └── pull_request_template.md ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── average_benchmark.md └── detailed_benchmark.md ├── images ├── plot_BCE.png ├── plot_Concat.png ├── plot_Conv2d.png ├── plot_Linear.png ├── plot_MatMul.png ├── plot_Sigmoid.png ├── plot_Softmax.png └── plot_Sort.png ├── mlx_benchmark ├── __init__.py ├── base_benchmark.py ├── config.py ├── get_cpu_gpu_config.py ├── operations │ ├── __init__.py │ ├── binary_cross_entropy.py │ ├── concat.py │ ├── conv.py │ ├── gather_scatter.py │ ├── layernorm.py │ ├── linear.py │ ├── matmul.py │ ├── scaled_dot_product_attention.py │ └── simple_operations.py ├── run_benchmark.py ├── run_viz.py └── utils.py └── requirements.txt /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | Thanks for opening this PR ⚡️ ! A few checks before submission: 2 | 3 | - [ ] I provide a description of the benchmark by running `python mlx_benchmark/get_cpu_gpu_config.py` 4 | - [ ] I verify that my benchmark ran on a newer version of MLX, or that my device is not yet integrated within the benchmark 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.10 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mlx-benchmark 2 | 3 | This repo could not exist without the contributions from everybody. There are two main ways to contribute to this project: 4 | 5 | ## Add a new device 6 | If you have an Apple Silicon chip that is missing in the benchmark, or a popular CUDA GPU not yet benchmarked on torch benchmarks, you can add your own experimental results by running the bench locally on your device. 7 | 8 | Follow the [installation instructions](README.md#installation) and run the benchmark on `mps`, `mlx` and `cpu` devices if you propose a Mac-based benchmark. This is the default behavior when running the benchmark: 9 | 10 | ```shell 11 | python run_benchmark.py 12 | ``` 13 | 14 | If you propose a CUDA GPU-based benchmark, running the benchmark on `cpu` and `cuda` devices is enough: 15 | 16 | ```shell 17 | python run_benchmark.py --include_mps=False --include_mlx_cpu=False --include_mlx_gpu=False --include_mlx_gpu_compile=False --include_cuda=True 18 | ``` 19 | 20 | Once run, 2 tables will be printed. Copy-paste the detailed benchmark into [detailed_benchmark.md](benchmarks/detailed_benchmark.md) and do the same for the average benchmark into [average_benchmark.md](benchmarks/average_benchmark.md). You can then submit a pull request. 
To ensure consistency in the results, ensure that enough memory is available before running the benchmarks. 21 | 22 | Before submitting your PR, ensure to add the config of your M ship along with the version of MLX you're using. This can be easily done by running: 23 | 24 | ```python 25 | python mlx_benchmark/get_cpu_gpu_config.py 26 | 27 | >>> (Apple M1 Pro: 2E+8P+16GPU+16GB) - mlx: 0.2.0 28 | ``` 29 | 30 | ## Add a new operation 31 | 32 | Many layers and basic operations are still missing in the benchmark. New examples can be easily added to the benchmark, we take here the example of the `concat` operation. 33 | 34 | 1. Append your benchmarks to the existing ones within `run_benchmark.py`. 35 | 36 | ```python 37 | operations = [ 38 | ... 39 | Concat(dim1="1000000x64", dim2="1000000x32", axis=1), 40 | Concat(dim1="1000000x64", dim2="1000000x128", axis=1), 41 | Concat(dim1="1000000x64", dim2="1000000x64", axis=0), 42 | Concat(dim1="64x1000000", dim2="64x1000000", axis=0), 43 | ] 44 | ``` 45 | The arguments starting with `dim*` will create an input tensor of the given shape. If multiple dims are given, like in the previous example, the input tensors for the benchmark can be accessed using `self.inputs[0]` and `self.inputs[1]`. All other arguments such as `axis` can be accessed using `self.kwargs["axis"]`. 46 | 47 | 2. Create a new file and write the actual implementation of the operation, here in `operations/concat.py`. 48 | 49 | ```python 50 | from config import USE_MLX 51 | 52 | if USE_MLX: 53 | import mlx.core as mx 54 | 55 | import torch 56 | 57 | from base_benchmark import BaseBenchmark 58 | 59 | 60 | class Concat(BaseBenchmark): 61 | def __init__(self, **kwargs): 62 | super().__init__(**kwargs) 63 | 64 | def forward_mlx(self, **kwargs): 65 | a, b = self.inputs 66 | 67 | y = mx.concatenate([a, b], axis=self.kwargs["axis"]) 68 | mx.eval(y) 69 | 70 | @torch.no_grad() 71 | def forward_torch(self, **kwargs): 72 | a, b = self.inputs 73 | 74 | y = torch.cat([a, b], dim=self.kwargs["axis"]) 75 | self.sync_mps_if_needed() 76 | ``` 77 | 78 | The structure is almost always the same for all operations. The method `forward_mlx` is the actual implementation of the mlx operation, and the same applies for `forward_torch`. For the mlx implementation, `mx.eval(.)` should be used to compute the operation, whereas `self.sync_mps_if_needed()` should be used after the torch operation. 79 | 80 | If the default inputs provided within the args are not enough to implement the new benchmark, you can add your own attributes by overriding this method: 81 | 82 | ```python 83 | def additional_preprocessing(self, framework): 84 | if framework == "mlx": 85 | self.specific_input_for_mlx = ... 86 | 87 | def forward_mlx(self, **kwargs): 88 | a = self.specific_input_for_mlx 89 | ... 90 | ``` 91 | 92 | 3. Lastly, append the new operation in `operations/__init__.py`: 93 | 94 | ```python 95 | ... 96 | from .concat import Concat 97 | ``` 98 | 99 | For simpler operations that only require the default input tensors given by the `dim` provided in args, you can implement the `SimpleOperationBenchmark` class in the same way as done in `simple_operations.py`. 100 | 101 | ## New features 102 | 103 | Enhancements and new features are always welcome! Feel free to submit issues or pull requests. 
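To make the three steps of the *Add a new operation* section above concrete, here is a minimal end-to-end sketch for a hypothetical `Cumsum` operation. The operation name, the file `operations/cumsum.py`, and the example dims are illustrative only and are not part of the repository. The sketch calls `sync_torch_gpu_if_needed`, which is the synchronization helper defined in the current `base_benchmark.py`, in place of the `sync_mps_if_needed` call shown earlier.

```python
# operations/cumsum.py -- hypothetical example, not part of the repo.
from config import USE_MLX

if USE_MLX:
    import mlx.core as mx

import torch

from base_benchmark import BaseBenchmark


class Cumsum(BaseBenchmark):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward_mlx(self, **kwargs):
        # Single input tensor built from the `dim1` argument.
        a = self.inputs[0]

        y = mx.cumsum(a, axis=self.kwargs["axis"])
        mx.eval(y)

    @torch.no_grad()
    def forward_torch(self, **kwargs):
        a = self.inputs[0]

        y = torch.cumsum(a, dim=self.kwargs["axis"])
        self.sync_torch_gpu_if_needed()
```

It would then be registered like any other operation, e.g. `Cumsum(dim1="1000000x64", axis=0)` in the `operations` list of `run_benchmark.py` and `from .cumsum import Cumsum` in `operations/__init__.py`.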
104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tristan Bilot 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⚡️ mlx-benchmark ⚡️ 2 | ### A comprehensive benchmark of MLX ops. 3 | 4 | This repo aims to benchmark Apple's MLX operations and layers, on all Apple Silicon chips, along with some GPUs. 5 | 6 | **Contributions:** Everyone can contribute to the benchmark! If you have a missing device or if you want to add a missing layer/operation, please read the [contribution guidelines](CONTRIBUTING.md). 7 | 8 | Current M chips: `M1`, `M1 Pro`, `M1 Max`, `M2`, `M2 Pro`, `M2 Max`, `M2 Ultra`, `M3`, `M3 Pro`, `M3 Max`, `M3 Ultra`. 9 | 10 | Current CUDA GPUs: `RTX4090`, `Tesla V100`, `A100`. 11 | 12 | Missing devices: `M1 Ultra`, and `other CUDA GPUs`. 13 | 14 | > [!NOTE] 15 | > You can submit your benchmark even for a device that is already listed, provided you use a newer version of MLX. Simply submit a PR by overriding the old benchmark table. Also, most of the existing benchmarks do not include the `mx.compile` feature, which has been recently added to mlx-benchmark. 16 | 17 | ## Benchmarks 🧪 18 | 19 | Benchmarks are generated by measuring the runtime of every `mlx` operations on GPU and CPU, along with their equivalent in pytorch with `mps`, `cpu` and `cuda` backends. On MLX with GPU, the operations compiled with `mx.compile` are included in the benchmark by default. To not benchmark the compiled functions, set `--compile=False`. 20 | 21 | For each operation, we measure the runtime of multiple experiments. We propose 2 benchmarks based on these experiments: 22 | 23 | * [Detailed benchmark](benchmarks/detailed_benchmark.md): provides the runtime of each experiment. 24 | * [Average runtime benchmark](benchmarks/average_benchmark.md): computes the mean of experiments. Easier to navigate, with fewer details. 25 | 26 | 27 | ## Installation 💻 28 | 29 | 30 | ### Installation on Mac devices 31 | 32 | Running the benchmark locally is straightforward. Create a new env with `osx-arm64` architecture and install the dependencies. 
33 | 34 | ```shell 35 | CONDA_SUBDIR=osx-arm64 conda create -n mlx_benchmark python=3.10 numpy pytorch torchvision scipy requests -c conda-forge 36 | 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | 41 | ### Installation on other devices 42 | Other operating systems than macOS can only run the torch experiments, on CPU or with a CUDA device. Install a new env without the `CONDA_SUBDIR=osx-arm64` prefix and install the torch package that matches your CUDA version. Then install all the requirements within `requirements.txt`, except `mlx`. 43 | 44 | Finally, open the `config.py` file and set: 45 | ``` 46 | USE_MLX = False 47 | ``` 48 | to avoid importing the mlx package, which cannot be installed on non-Mac devices. 49 | 50 | ## Run the benchmark 🧑‍💻 51 | 52 | ### Run on Mac 53 | 54 | To run the benchmark on mps, mlx and CPU: 55 | 56 | ```shell 57 | python run_benchmark.py --include_mps=True --include_mlx_gpu=True --include_mlx_cpu=True --include_cpu=True 58 | ``` 59 | 60 | ### Run on other devices 61 | 62 | To run the torch benchmark on CUDA and CPU: 63 | 64 | ```shell 65 | python run_benchmark.py --include_mps=False --include_mlx_gpu=False --include_mlx_cpu=False --include_cuda=True --include_cpu=True 66 | ``` 67 | 68 | ### Run only compiled functions 69 | 70 | If you're interested in benchmarking only operations against operations compiled with `mx.compile`, you can run: 71 | 72 | ```shell 73 | python run_benchmark.py --include_mps=False --include_cpu=False --include_mlx_cpu=False 74 | ``` 75 | 76 | ## Contributing 🚀 77 | 78 | If you have a device not yet featured in the benchmark, especially the ones listed below, your PR is welcome to broaden the scope and accuracy of this project. 79 | -------------------------------------------------------------------------------- /benchmarks/average_benchmark.md: -------------------------------------------------------------------------------- 1 | # Average MLX benchmark 2 | 3 | Averaged runtime benchmark of mlx operations, measured in `milliseconds`. 
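As a reading aid, the speedup columns in the tables below can be read as a runtime ratio minus one. For example, in the M1 table, Argmax takes 1.81 ms on `mlx_gpu` and 2.87 ms on `mps`, so the `mlx_gpu/mps` speedup is 2.87 / 1.81 - 1 ≈ +58%; a negative value means the first backend is slower than the second.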
4 | 5 | * `mlx_gpu`: mlx framework with gpu backend 6 | * `mlx_cpu`: mlx framework with cpu backend 7 | * `cpu`: torch framework with cpu backend 8 | * `mps`: torch framework with mps (gpu) backend 9 | * `mlx_gpu/mps speedup`: runtime speedup of mlx_gpu compared to mps 10 | * `mlx_gpu/mlx_cpu speedup`: runtime speedup of mlx_gpu compared to mlx_cpu 11 | * `cuda/cpu speedup`: runtime speedup of cuda compared to cpu 12 | 13 | ## Apple Silicon 14 | 15 | **M1 (cores: 4E+4P+8GPU)** 16 | 17 | | Operation | mlx_gpu | mlx_cpu | mps | cpu | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 18 | |----------------|-------|-------|------|------|-------------------|-----------------------| 19 | | Argmax | 1.81 | 10.63 | 2.87 | 8.05 | +58% | +486% | 20 | | BCE | 5.51 | 51.81 | 12.19 | 10.87 | +121% | +840% | 21 | | Concat | 19.13 | 100.88 | 19.28 | 49.63 | +0% | +427% | 22 | | Conv1d | 3.83 | 4.53 | 3.73 | 116.13 | -2% | +18% | 23 | | Conv2d | 30.12 | 436.68 | 7.06 | 45.54 | -76% | +1349% | 24 | | LeakyReLU | 2.06 | 2.90 | 1.16 | 1.37 | -43% | +41% | 25 | | Linear | 30.41 | 73.32 | 53.70 | 117.68 | +76% | +141% | 26 | | MatMul | 26.38 | 93.82 | 47.87 | 504.47 | +81% | +255% | 27 | | PReLU | 3.50 | 4.54 | 1.15 | 1.32 | -67% | +29% | 28 | | ReLU | 0.98 | 0.90 | 1.13 | 1.35 | +14% | -8% | 29 | | SeLU | 7.81 | 14.73 | 1.14 | 7.72 | -85% | +88% | 30 | | Sigmoid | 0.96 | 32.66 | 1.16 | 7.23 | +19% | +3287% | 31 | | Softmax | 10.15 | 40.98 | 19.27 | 46.69 | +89% | +303% | 32 | | Softplus | 1.07 | 33.08 | 1.73 | 10.99 | +60% | +2977% | 33 | | Sort | 18.49 | 713.23 | 73.24 | 70.11 | +296% | +3756% | 34 | | Sum | 11.33 | 12.70 | 16.35 | 13.43 | +44% | +12% | 35 | | SumAll | 6.91 | 6.85 | 7.40 | 7.00 | +7% | 0% | 36 | 37 | **M1 Pro (2E+8P+16GPU+16GB) - mlx: 0.5.0** 38 | 39 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 40 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 41 | | Argmax | 1.75 | 1.74 | 10.55 | 1.02 | 8.19 | +0% | -41% | +503% | 42 | | BCE | 2.18 | 0.97 | 59.50 | 0.84 | 8.48 | +125% | -61% | +2629% | 43 | | Concat | 6.14 | 6.13 | 87.88 | 6.21 | 36.74 | +0% | +1% | +1332% | 44 | | Conv1d | 1.76 | 1.64 | 3.42 | 1.01 | 154.38 | +7% | -42% | +94% | 45 | | Conv2d | 5.71 | 5.67 | 443.83 | 2.52 | 42.12 | +0% | -55% | +7669% | 46 | | Gather | 3.15 | 3.17 | 4.95 | 18.87 | 9.03 | 0% | +498% | +57% | 47 | | LeakyReLU | 0.46 | 0.44 | 0.80 | 0.47 | 1.21 | +4% | +2% | +74% | 48 | | Linear | 9.57 | 9.76 | 34.65 | 33.21 | 127.82 | -1% | +246% | +261% | 49 | | MatMul | 10.52 | 10.65 | 38.29 | 22.76 | 498.70 | -1% | +116% | +263% | 50 | | PReLU | 0.48 | 0.46 | 3.37 | 0.55 | 1.07 | +3% | +15% | +607% | 51 | | ReLU | 0.47 | 0.43 | 0.63 | 0.55 | 1.08 | +9% | +18% | +34% | 52 | | Scatter | 0.59 | 0.57 | 30.02 | 3.38 | 1.94 | +2% | +473% | +5002% | 53 | | ScatterSum | 0.03 | 0.04 | 0.01 | nan | 1.47 | -14% | nan% | -71% | 54 | | ScatterMax | 0.03 | 0.04 | 0.01 | nan | 1.44 | -10% | nan% | -69% | 55 | | SeLU | 0.51 | 0.46 | 4.86 | 0.47 | 6.72 | +12% | -8% | +849% | 56 | | Sigmoid | 0.44 | 0.44 | 4.58 | 0.55 | 6.39 | +0% | +23% | +931% | 57 | | Softmax | 9.44 | 7.32 | 41.66 | 5.96 | 30.23 | +28% | -36% | +341% | 58 | | Softplus | 0.46 | 0.49 | 35.26 | 0.49 | 8.97 | -7% | +6% | +7646% | 59 | | Sort | 1.69 | 1.72 | 258.35 | 37.76 | 58.56 | -1% | +2129% | +15156% | 60 | | Sum | 3.38 | 3.46 | 9.25 | 6.06 | 10.02 | -2% | +79% | +173% | 61 | 
| SumAll | 2.52 | 2.63 | 6.83 | 2.48 | 3.46 | -4% | -1% | +171% | 62 | 63 | **M1 Max (64GB)** mlx 0.2.0 64 | 65 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 66 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 67 | | Argmax | 2.14 | 1.69 | 10.80 | 1.93 | 9.17 | +27% | -10% | +403% | 68 | | BCE | 1.30 | 0.65 | 50.27 | 1.01 | 8.09 | +98% | -22% | +3777% | 69 | | Concat | 3.20 | 3.20 | 92.35 | 3.27 | 24.79 | +0% | +2% | +2782% | 70 | | Conv1d | 2.20 | 0.98 | 3.34 | 1.18 | 157.26 | +124% | -46% | +51% | 71 | | Conv2d | 8.18 | 7.24 | 455.47 | 1.98 | 35.56 | +13% | -75% | +5468% | 72 | | Gather | 2.51 | 2.37 | 5.94 | 9.78 | 8.92 | +5% | +289% | +136% | 73 | | LeakyReLU | 0.54 | 0.34 | 4.40 | 0.45 | 0.63 | +59% | -15% | +719% | 74 | | Linear | 6.73 | 6.49 | 32.46 | 16.44 | 39.44 | +3% | +144% | +382% | 75 | | MatMul | 4.66 | 4.64 | 47.17 | 11.16 | 88.32 | +0% | +139% | +913% | 76 | | PReLU | 0.82 | 0.36 | 2.64 | 0.44 | 0.57 | +127% | -46% | +222% | 77 | | ReLU | 0.36 | 0.33 | 0.82 | 0.44 | 0.60 | +9% | +21% | +125% | 78 | | Scatter | 4.11 | 4.09 | 30.31 | 1.85 | 1.78 | +0% | -55% | +637% | 79 | | ScatterSum | 0.05 | 0.03 | 0.01 | nan | 1.35 | +42% | nan% | -81% | 80 | | ScatterMax | 0.05 | 0.03 | 0.01 | nan | 1.35 | +34% | nan% | -81% | 81 | | SeLU | 1.53 | 0.36 | 7.05 | 0.46 | 5.97 | +323% | -69% | +362% | 82 | | Sigmoid | 0.38 | 0.36 | 32.57 | 0.50 | 5.43 | +7% | +30% | +8409% | 83 | | Softmax | 4.84 | 3.71 | 43.48 | 3.88 | 28.93 | +30% | -19% | +798% | 84 | | Softplus | 0.57 | 0.34 | 32.79 | 0.67 | 8.25 | +65% | +17% | +5642% | 85 | | Sort | 1.08 | 0.97 | 257.18 | 20.18 | 49.30 | +10% | +1773% | +23780% | 86 | | Sum | 1.75 | 1.74 | 8.87 | 3.10 | 10.35 | +0% | +77% | +406% | 87 | | SumAll | 1.36 | 1.34 | 6.63 | 1.50 | 3.36 | +1% | +10% | +389% | 88 | 89 | **M2 () - mlx 0.2.0** 90 | 91 | | Operation | mlx_gpu | mlx_cpu | mps | cpu | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 92 | |----------------|-------|-------|------|------|-------------------|-----------------------| 93 | | Argmax | 1.71 | 16.21 | 2.71 | 8.51 | +58% | +849% | 94 | | BCE | 3.71 | 82.34 | 13.06 | 13.34 | +251% | +2118% | 95 | | Concat | 12.14 | 161.07 | 12.51 | 46.33 | +3% | +1226% | 96 | | Conv1d | 3.66 | 6.01 | 3.29 | 132.69 | -10% | +64% | 97 | | Conv2d | 27.74 | 705.78 | 5.94 | 56.88 | -78% | +2444% | 98 | | LeakyReLU | 1.50 | 2.44 | 1.10 | 1.33 | -26% | +62% | 99 | | Linear | 25.01 | 99.71 | 57.29 | 183.02 | +129% | +298% | 100 | | MatMul | 22.04 | 120.61 | 78.10 | 629.63 | +254% | +447% | 101 | | PReLU | 2.43 | 4.58 | 1.04 | 1.35 | -57% | +88% | 102 | | ReLU | 0.77 | 1.00 | 1.00 | 1.34 | +30% | +29% | 103 | | SeLU | 5.31 | 17.25 | 1.11 | 8.24 | -79% | +224% | 104 | | Sigmoid | 0.77 | 52.85 | 1.13 | 7.47 | +47% | +6797% | 105 | | Softmax | 7.07 | 65.62 | 14.54 | 60.92 | +105% | +828% | 106 | | Softplus | 0.91 | 53.94 | 1.73 | 12.12 | +90% | +5846% | 107 | | Sort | 16.87 | 1243.25 | 46.69 | 79.31 | +176% | +7269% | 108 | | Sum | 9.15 | 18.38 | 10.47 | 14.19 | +14% | +100% | 109 | | SumAll | 4.31 | 7.79 | 4.96 | 6.11 | +14% | +80% | 110 | 111 | **M2 Pro (cores: 4E+6P+16GPU)** mlx 0.12.2 torch 2.1.2 112 | 113 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 114 | 
|-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 115 | | Argmax | 1.55 | 1.52 | 9.98 | 1.28 | 7.88 | +2% | -17% | +542% | 116 | | BCE | 2.03 | 0.82 | 59.53 | 0.74 | 8.01 | +146% | -63% | +2835% | 117 | | Concat | 6.17 | 6.42 | 86.32 | 6.26 | 36.48 | -3% | +1% | +1299% | 118 | | Conv1d | 1.62 | 1.49 | 3.05 | 0.90 | 147.44 | +8% | -44% | +88% | 119 | | Conv2d | 5.20 | 5.19 | 410.99 | 2.11 | 43.67 | +0% | -59% | +7797% | 120 | | Gather | 3.03 | 3.01 | 4.13 | 15.83 | 9.79 | +0% | +423% | +36% | 121 | | LeakyReLU | 0.36 | 0.36 | 0.90 | 0.44 | 0.93 | 0% | +19% | +146% | 122 | | Linear | 9.36 | 9.29 | 27.06 | 31.34 | 115.10 | +0% | +234% | +189% | 123 | | MatMul | 10.93 | 9.89 | 35.71 | 21.59 | 754.10 | +10% | +97% | +226% | 124 | | PReLU | 0.53 | 0.39 | 3.46 | 0.44 | 0.91 | +36% | -17% | +552% | 125 | | ReLU | 0.41 | 0.37 | 0.73 | 0.43 | 0.92 | +11% | +4% | +79% | 126 | | Scatter | 0.31 | 0.31 | 28.25 | 2.77 | 2.31 | 0% | +788% | +8959% | 127 | | ScatterSum | 0.04 | 0.03 | 0.02 | nan | 1.38 | +3% | nan% | -50% | 128 | | ScatterMax | 0.04 | 0.03 | 0.02 | nan | 1.38 | +7% | nan% | -49% | 129 | | SeLU | 0.49 | 0.43 | 4.85 | 0.51 | 2.66 | +13% | +4% | +899% | 130 | | Sigmoid | 0.37 | 0.37 | 4.33 | 0.48 | 2.23 | +1% | +26% | +1055% | 131 | | Softmax | 9.25 | 6.99 | 39.72 | 4.88 | 25.00 | +32% | -47% | +329% | 132 | | Softplus | 0.41 | 0.37 | 33.75 | 0.47 | 4.73 | +9% | +16% | +8220% | 133 | | Sort | 1.48 | 1.49 | 242.55 | 22.40 | 51.73 | 0% | +1414% | +16295% | 134 | | Sum | 3.24 | 3.22 | 9.11 | 3.09 | 10.11 | +0% | -4% | +180% | 135 | | SumAll | 2.37 | 2.37 | 6.58 | 2.36 | 3.31 | +0% | 0% | +176% | 136 | 137 | **M2 Max (cores: 4E+8P+38GPU)** mlx 0.5.0 torch 2.2.1 138 | 139 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 140 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 141 | | Argmax | 1.50 | 1.51 | 10.10 | 0.68 | 8.63 | 0% | -54% | +571% | 142 | | BCE | 1.00 | 0.44 | 59.91 | 0.60 | 9.01 | +126% | -40% | +5880% | 143 | | Concat | 3.18 | 3.16 | 83.04 | 3.34 | 27.24 | +0% | +4% | +2507% | 144 | | Conv1d | 0.86 | 0.76 | 3.03 | 0.53 | 160.40 | +13% | -38% | +250% | 145 | | Conv2d | 2.45 | 2.44 | 424.52 | 1.15 | 34.30 | +0% | -53% | +17249% | 146 | | Gather | 1.34 | 1.57 | 3.92 | 8.12 | 8.98 | -14% | +504% | +191% | 147 | | LeakyReLU | 0.22 | 0.30 | 0.72 | 0.30 | 1.21 | -24% | +35% | +219% | 148 | | Linear | 5.51 | 5.63 | 23.52 | 12.97 | 37.92 | -2% | +135% | +327% | 149 | | MatMul | 3.77 | 3.83 | 27.42 | 9.78 | 83.55 | -1% | +159% | +627% | 150 | | PReLU | 0.28 | 0.48 | 3.27 | 0.42 | 1.03 | -41% | +50% | +1062% | 151 | | ReLU | 0.37 | 0.24 | 0.62 | 0.35 | 0.94 | +51% | -5% | +67% | 152 | | Scatter | 0.22 | 0.24 | 28.88 | 1.47 | 1.82 | -9% | +567% | +12984% | 153 | | ScatterSum | 0.03 | 0.03 | 0.01 | nan | 1.37 | +10% | nan% | -69% | 154 | | ScatterMax | 0.03 | 0.03 | 0.01 | nan | 1.39 | +10% | nan% | -68% | 155 | | SeLU | 0.29 | 0.36 | 4.62 | 0.49 | 7.07 | -20% | +69% | +1511% | 156 | | Sigmoid | 0.24 | 0.27 | 4.34 | 0.36 | 6.35 | -10% | +52% | +1714% | 157 | | Softmax | 4.62 | 3.60 | 40.03 | 3.07 | 33.25 | +28% | -33% | +766% | 158 | | Softplus | 0.25 | 0.24 | 34.73 | 0.33 | 9.39 | +3% | +31% | +13696% | 159 | | Sort | 0.73 | 0.75 | 248.89 | 10.65 | 58.88 | -2% | +1360% | +34026% | 160 | | Sum | 1.61 | 1.64 | 9.22 | 
1.96 | 12.05 | -1% | +21% | +472% | 161 | | SumAll | 1.20 | 1.23 | 6.86 | 1.32 | 3.84 | -2% | +9% | +471% | 162 | 163 | **M2 Ultra (cores: 8E+16P+76GPU)** mlx 0.7.0 164 | 165 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 166 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 167 | | Argmax | 1.60 | 1.63 | 9.46 | 0.65 | 9.49 | -1% | -59% | +492% | 168 | | BCE | 0.64 | 0.45 | 56.57 | 0.47 | 4.23 | +42% | -27% | +8702% | 169 | | Concat | 1.69 | 1.69 | 81.95 | 1.66 | 38.93 | +0% | -1% | +4743% | 170 | | Conv1d | 0.55 | 0.51 | 2.64 | 0.45 | 187.91 | +7% | -17% | +382% | 171 | | Conv2d | 1.35 | 1.38 | 409.78 | 0.67 | 46.05 | -1% | -50% | +30276% | 172 | | Gather | 0.77 | 0.79 | 3.83 | 3.92 | 11.82 | -2% | +407% | +395% | 173 | | LeakyReLU | 0.32 | 0.25 | 0.85 | 0.21 | 1.99 | +28% | -34% | +162% | 174 | | Linear | 2.26 | 2.23 | 16.83 | 6.67 | 39.12 | +1% | +195% | +645% | 175 | | MatMul | 2.53 | 2.53 | 19.21 | 5.59 | 66.55 | 0% | +121% | +660% | 176 | | PReLU | 0.37 | 0.45 | 3.15 | 0.32 | 1.61 | -18% | -13% | +759% | 177 | | ReLU | 0.29 | 0.24 | 0.67 | 0.33 | 1.61 | +20% | +13% | +132% | 178 | | Scatter | 0.25 | 0.25 | 27.04 | 0.73 | 1.49 | +0% | +193% | +10802% | 179 | | ScatterSum | 0.03 | 0.03 | 0.01 | nan | 1.36 | -1% | nan% | -76% | 180 | | ScatterMax | 0.03 | 0.03 | 0.01 | nan | 1.37 | +10% | nan% | -76% | 181 | | SeLU | 0.46 | 0.28 | 4.50 | 0.29 | 1.86 | +65% | -36% | +877% | 182 | | Sigmoid | 0.24 | 0.25 | 4.11 | 0.26 | 1.71 | -2% | +6% | +1606% | 183 | | Softmax | 2.47 | 1.88 | 39.27 | 1.35 | 17.90 | +31% | -45% | +1488% | 184 | | Softplus | 0.27 | 0.26 | 32.13 | 0.26 | 3.53 | +7% | -6% | +11598% | 185 | | Sort | 0.48 | 0.49 | 229.84 | 6.41 | 33.91 | -1% | +1231% | +47639% | 186 | | Sum | 0.90 | 0.91 | 9.22 | 0.95 | 6.80 | -1% | +6% | +925% | 187 | | SumAll | 0.70 | 0.71 | 6.70 | 0.83 | 1.97 | -1% | +19% | +859% | 188 | 189 | **M3 (RAM: 16GB)** - mlx 0.2.0 190 | 191 | Average benchmark: 192 | | Operation | mlx_gpu | mlx_cpu | mps | cpu | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 193 | |----------------|-------|-------|------|------|-------------------|-----------------------| 194 | | Argmax | 1.20 | 11.63 | 1.71 | 7.10 | +43% | +870% | 195 | | BCE | 4.05 | 40.80 | 8.59 | 8.14 | +111% | +906% | 196 | | Concat | 12.52 | 83.29 | 12.60 | 35.29 | +0% | +565% | 197 | | Conv1d | 2.34 | 3.66 | 1.98 | 71.23 | -15% | +56% | 198 | | Conv2d | 16.47 | 340.03 | 4.43 | 36.36 | -73% | +1965% | 199 | | LeakyReLU | 1.43 | 3.05 | 1.01 | 1.07 | -29% | +113% | 200 | | Linear | 21.55 | 71.89 | 15.84 | 122.32 | -26% | +233% | 201 | | MatMul | 15.49 | 76.57 | 33.24 | 490.48 | +114% | +394% | 202 | | PReLU | 2.36 | 2.76 | 0.99 | 1.11 | -58% | +16% | 203 | | ReLU | 0.76 | 1.39 | 0.96 | 1.01 | +26% | +81% | 204 | | SeLU | 5.23 | 7.72 | 1.02 | 6.88 | -80% | +47% | 205 | | Sigmoid | 0.79 | 26.97 | 1.07 | 5.69 | +35% | +3309% | 206 | | Softmax | 6.31 | 41.35 | 12.08 | 32.54 | +91% | +555% | 207 | | Softplus | 0.73 | 26.82 | 1.08 | 9.09 | +47% | +3569% | 208 | | Sort | 12.67 | 724.26 | 30.73 | 60.29 | +142% | +5616% | 209 | | Sum | 6.96 | 11.24 | 6.61 | 12.27 | -5% | +61% | 210 | | SumAll | 4.26 | 7.79 | 4.78 | 4.38 | +12% | +82% | 211 | 212 | **M3 Pro (cores: 6E+5P+14GPU)** 213 | 214 | | Operation | mlx_gpu | mlx_cpu | mps | cpu | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 215 | 
|----------------|-------|-------|------|------|-------------------|-----------------------| 216 | | Argmax | 0.98 | 11.21 | 1.24 | 6.14 | +25% | +1041% | 217 | | BCE | 2.70 | 39.88 | 6.87 | 6.78 | +154% | +1374% | 218 | | Concat | 8.25 | 78.33 | 8.87 | 38.10 | +7% | +849% | 219 | | Conv1d | 2.15 | 3.36 | 2.07 | 83.18 | -3% | +56% | 220 | | Conv2d | 12.06 | 333.03 | 3.09 | 33.87 | -74% | +2660% | 221 | | LeakyReLU | 1.54 | 1.53 | 1.26 | 0.96 | -18% | 0% | 222 | | Linear | 15.30 | 52.78 | 11.44 | 91.49 | -25% | +244% | 223 | | MatMul | 16.04 | 69.27 | 22.53 | 390.04 | +40% | +331% | 224 | | PReLU | 2.04 | 2.80 | 1.35 | 0.91 | -34% | +37% | 225 | | ReLU | 0.94 | 0.61 | 1.37 | 0.92 | +45% | -34% | 226 | | SeLU | 3.98 | 10.10 | 1.27 | 4.69 | -68% | +153% | 227 | | Sigmoid | 1.03 | 26.28 | 1.30 | 4.28 | +25% | +2446% | 228 | | Softmax | 4.62 | 32.54 | 9.32 | 29.78 | +101% | +604% | 229 | | Softplus | 1.02 | 25.95 | 1.26 | 6.52 | +23% | +2444% | 230 | | Sort | 8.67 | 711.98 | 21.37 | 46.71 | +146% | +8114% | 231 | | Sum | 4.73 | 9.81 | 5.12 | 8.83 | +8% | +107% | 232 | | SumAll | 3.17 | 4.71 | 3.69 | 3.44 | +16% | +48% | 233 | 234 | **M3 Max (cores: 4E+12P+40GPU)** mlx 0.2.0 235 | 236 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 237 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 238 | | Argmax | 1.57 | 1.56 | 8.34 | 1.02 | 6.14 | +0% | -35% | +430% | 239 | | BCE | 1.12 | 0.52 | 38.72 | 0.59 | 3.73 | +114% | -47% | +3362% | 240 | | Concat | 3.32 | 3.30 | 82.26 | 3.40 | 22.89 | +0% | +2% | +2380% | 241 | | Conv1d | 0.85 | 0.75 | 2.40 | 0.92 | 156.00 | +13% | +8% | +182% | 242 | | Conv2d | 4.21 | 4.14 | 329.47 | 1.42 | 31.25 | +1% | -66% | +7723% | 243 | | Gather | 1.56 | 1.47 | 4.37 | 8.23 | 6.68 | +5% | +428% | +180% | 244 | | LeakyReLU | 0.43 | 0.29 | 2.57 | 0.54 | 0.66 | +48% | +24% | +491% | 245 | | Linear | 5.66 | 5.66 | 24.67 | 4.24 | 59.04 | +0% | -25% | +336% | 246 | | MatMul | 4.20 | 4.19 | 25.57 | 7.62 | 585.74 | +0% | +81% | +508% | 247 | | PReLU | 0.70 | 0.29 | 2.06 | 0.49 | 0.61 | +144% | -29% | +193% | 248 | | ReLU | 0.51 | 0.35 | 0.65 | 0.78 | 0.62 | +45% | +53% | +28% | 249 | | Scatter | 2.29 | 2.22 | 25.40 | 1.66 | 0.93 | +3% | -27% | +1009% | 250 | | ScatterSum | 0.04 | 0.03 | 0.01 | nan | 1.22 | +52% | nan% | -81% | 251 | | ScatterMax | 0.04 | 0.03 | 0.01 | nan | 1.23 | +52% | nan% | -81% | 252 | | SeLU | 1.35 | 0.29 | 5.14 | 0.48 | 2.93 | +361% | -64% | +281% | 253 | | Sigmoid | 0.30 | 0.29 | 26.28 | 0.49 | 2.85 | +4% | +62% | +8629% | 254 | | Softmax | 4.75 | 3.59 | 35.79 | 3.40 | 16.50 | +32% | -28% | +653% | 255 | | Softplus | 0.35 | 0.29 | 26.02 | 0.51 | 4.00 | +21% | +43% | +7257% | 256 | | Sort | 0.77 | 0.76 | 229.39 | 8.04 | 32.43 | +1% | +942% | +29646% | 257 | | Sum | 1.55 | 1.54 | 6.53 | 1.90 | 6.99 | +0% | +22% | +322% | 258 | | SumAll | 1.19 | 1.19 | 4.78 | 1.32 | 3.22 | +0% | +10% | +300% | 259 | 260 | **M4 (6E+4P+10GPU+16GB)** mlx: 0.20.0 261 | 262 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 263 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 264 | | Argmax | 1.56 | 1.49 | 8.33 | 1.43 | 5.39 | +4% | -7% | +434% | 265 | | BCE | 3.73 | 1.61 | 35.59 | 1.19 | 8.15 | +131% | 
-68% | +853% | 266 | | Concat | 12.61 | 12.42 | 50.48 | 12.59 | 29.69 | +1% | 0% | +300% | 267 | | Conv1d | 1.77 | 1.73 | 4.55 | 1.16 | 58.55 | +2% | -34% | +156% | 268 | | Conv2d | 4.94 | 4.99 | 42.63 | 1.48 | 25.15 | -1% | -70% | +763% | 269 | | Gather | 3.57 | 3.53 | 3.24 | 34.09 | 9.04 | +1% | +854% | -9% | 270 | | LeakyReLU | 0.76 | 0.76 | 0.69 | 0.82 | 0.83 | +0% | +8% | -9% | 271 | | Linear | 12.62 | 12.67 | 60.38 | 13.17 | 116.89 | 0% | +4% | +378% | 272 | | MatMul | 18.27 | 17.17 | 42.77 | 32.16 | 133.45 | +6% | +75% | +134% | 273 | | PReLU | 0.91 | 0.90 | 2.15 | 0.82 | 0.79 | +1% | -9% | +136% | 274 | | ReLU | 0.78 | 0.74 | 0.54 | 0.75 | 1.33 | +5% | -3% | -29% | 275 | | Scatter | 0.82 | 0.79 | 9.34 | 5.89 | 0.98 | +3% | +621% | +1043% | 276 | | ScatterSum | 0.00 | 0.00 | 0.00 | nan | 1.08 | +27% | nan% | -7% | 277 | | ScatterMax | 0.00 | 0.00 | 0.00 | nan | 1.14 | +36% | nan% | -5% | 278 | | SeLU | 0.89 | 0.88 | 3.65 | 0.81 | 1.65 | +1% | -8% | +308% | 279 | | Sigmoid | 0.75 | 0.75 | 3.48 | 0.81 | 1.42 | +0% | +7% | +364% | 280 | | Softmax | 18.11 | 13.82 | 38.51 | 6.02 | 28.30 | +31% | -66% | +112% | 281 | | Softplus | 0.83 | 0.76 | 21.28 | 0.78 | 3.51 | +9% | -6% | +2464% | 282 | | Sort | 1.99 | 1.99 | 218.30 | 32.71 | 98.28 | 0% | +1545% | +10884% | 283 | | Sum | 5.90 | 6.18 | 9.00 | 6.70 | 12.98 | -4% | +13% | +52% | 284 | | SumAll | 4.32 | 4.56 | 6.58 | 4.84 | 5.41 | -5% | +12% | +52% | 285 | 286 | **M4 Pro (4E+8P+16GPU+24GB)** mlx: 0.20.0 287 | 288 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 289 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 290 | | Argmax | 1.48 | 1.43 | 7.96 | 1.02 | 5.15 | +3% | -30% | +437% | 291 | | BCE | 1.47 | 0.70 | 34.00 | 0.70 | 4.26 | +110% | -52% | +2208% | 292 | | Concat | 5.59 | 5.33 | 48.85 | 5.03 | 27.93 | +5% | -10% | +773% | 293 | | Conv1d | 1.04 | 1.00 | 4.24 | 0.66 | 85.48 | +3% | -36% | +307% | 294 | | Conv2d | 3.05 | 3.08 | 32.51 | 0.80 | 29.57 | 0% | -73% | +967% | 295 | | Gather | 2.28 | 2.23 | 3.18 | 13.48 | 7.39 | +2% | +491% | +39% | 296 | | LeakyReLU | 0.30 | 0.30 | 0.64 | 0.35 | 0.79 | +0% | +17% | +112% | 297 | | Linear | 7.61 | 7.56 | 40.24 | 7.45 | 63.95 | +0% | -2% | +428% | 298 | | MatMul | 8.24 | 7.54 | 21.94 | 13.81 | 137.11 | +9% | +67% | +166% | 299 | | PReLU | 0.43 | 0.46 | 2.13 | 0.37 | 0.84 | -6% | -14% | +394% | 300 | | ReLU | 0.29 | 0.33 | 0.42 | 0.36 | 1.17 | -12% | +25% | +45% | 301 | | Scatter | 0.52 | 0.51 | 9.16 | 2.29 | 0.83 | +2% | +339% | +1658% | 302 | | ScatterSum | 0.00 | 0.00 | 0.00 | nan | 1.06 | +34% | nan% | -3% | 303 | | ScatterMax | 0.00 | 0.00 | 0.00 | nan | 1.03 | +19% | nan% | -5% | 304 | | SeLU | 0.46 | 0.43 | 3.62 | 0.40 | 1.13 | +7% | -11% | +693% | 305 | | Sigmoid | 0.28 | 0.29 | 3.46 | 0.34 | 1.01 | -5% | +23% | +1150% | 306 | | Softmax | 7.23 | 5.56 | 30.51 | 3.08 | 18.53 | +30% | -57% | +321% | 307 | | Softplus | 0.35 | 0.32 | 21.29 | 0.34 | 2.28 | +11% | -4% | +5944% | 308 | | Sort | 1.26 | 1.23 | 214.41 | 15.29 | 56.47 | +2% | +1112% | +16912% | 309 | | Sum | 2.25 | 2.28 | 5.91 | 2.61 | 6.09 | -1% | +16% | +163% | 310 | | SumAll | 1.69 | 1.70 | 4.28 | 1.77 | 1.82 | 0% | +4% | +153% | 311 | 312 | **M4 Pro (4E+10P+20GPU+24GB)** mlx: 0.24.1 313 | 314 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu 
speedup | 315 | |--------------------------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 316 | | Argmax | 1.42 | 1.45 | 7.67 | 0.78 | 5.01 | -1% | -45% | +438% | 317 | | BCE | 1.46 | 0.65 | 14.54 | 0.44 | 3.65 | +124% | -69% | +896% | 318 | | Concat | 5.54 | 5.22 | 48.61 | 4.96 | 29.00 | +6% | -10% | +777% | 319 | | Conv1d | 0.82 | 0.82 | 3.82 | 0.47 | 113.39 | 0% | -42% | +364% | 320 | | Conv2d | 2.52 | 2.54 | 31.65 | 0.67 | 30.70 | 0% | -73% | +1154% | 321 | | Gather | 1.30 | 1.32 | 3.13 | 13.28 | 6.85 | -1% | +919% | +140% | 322 | | LayerNorm | 0.42 | 0.43 | 2.62 | 0.84 | 1.16 | -1% | +96% | +517% | 323 | | LeakyReLU | 0.43 | 0.33 | 0.60 | 0.31 | 0.63 | +30% | -27% | +39% | 324 | | Linear | 6.37 | 6.22 | 39.72 | 6.20 | 36.36 | +2% | -2% | +523% | 325 | | MatMul | 9.81 | 6.86 | 22.04 | 13.28 | 67.90 | +43% | +35% | +124% | 326 | | PReLU | 0.57 | 0.31 | 3.02 | 0.34 | 0.62 | +83% | -39% | +431% | 327 | | ReLU | 0.31 | 0.30 | 0.37 | 0.41 | 0.78 | +1% | +33% | +21% | 328 | | ScaledDotProductAttention | 2.62 | 2.59 | 10.11 | 1.81 | 5.38 | +1% | -30% | +285% | 329 | | Scatter | 0.37 | 0.32 | 9.06 | 2.29 | 0.75 | +17% | +512% | +2320% | 330 | | ScatterSum | 0.00 | 0.00 | 0.00 | 0.28 | 0.97 | +18% | +22629% | +2% | 331 | | ScatterMax | 0.00 | 0.00 | 0.00 | 0.29 | 0.98 | +33% | +24052% | +0% | 332 | | SeLU | 0.95 | 0.34 | 6.10 | 0.32 | 0.93 | +176% | -66% | +541% | 333 | | Sigmoid | 0.36 | 0.36 | 2.19 | 0.28 | 0.84 | +1% | -22% | +504% | 334 | | Softmax | 7.17 | 5.37 | 32.41 | 2.77 | 13.97 | +33% | -61% | +352% | 335 | | Softplus | 0.32 | 0.32 | 20.25 | 0.30 | 1.86 | -1% | -6% | +6274% | 336 | | Sort | 1.29 | 1.24 | 209.96 | 14.67 | 45.28 | +4% | +1039% | +16202% | 337 | | Sum | 2.25 | 2.21 | 9.85 | 2.49 | 5.12 | +1% | +10% | +338% | 338 | | SumAll | 1.67 | 1.68 | 7.18 | 1.71 | 1.62 | 0% | +2% | +329% | 339 | 340 | **M4 Max (4E+12P+40GPU+128GB)** mlx: 0.20.0 341 | 342 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 343 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 344 | | Argmax | 1.41 | 1.43 | 8.01 | 0.70 | 4.69 | -1% | -50% | +468% | 345 | | BCE | 0.88 | 0.89 | 33.88 | 0.52 | 2.93 | -1% | -40% | +3770% | 346 | | Concat | 2.86 | 2.87 | 47.64 | 2.67 | 19.94 | 0% | -6% | +1563% | 347 | | Conv1d | 0.59 | 0.51 | 3.81 | 0.40 | 110.98 | +15% | -31% | +542% | 348 | | Conv2d | 1.43 | 1.43 | 32.19 | 0.60 | 26.71 | +0% | -58% | +2152% | 349 | | Gather | 1.08 | 1.04 | 3.07 | 7.37 | 6.20 | +3% | +584% | +185% | 350 | | LeakyReLU | 0.24 | 0.23 | 0.67 | 0.19 | 0.75 | +2% | -17% | +181% | 351 | | Linear | 3.35 | 3.83 | 39.12 | 3.21 | 45.63 | -12% | -4% | +1069% | 352 | | MatMul | 4.12 | 4.21 | 21.00 | 6.36 | 139.36 | -2% | +54% | +409% | 353 | | PReLU | 0.38 | 0.29 | 2.10 | 0.28 | 0.87 | +30% | -26% | +445% | 354 | | ReLU | 0.23 | 0.22 | 0.42 | 0.27 | 0.87 | +3% | +21% | +86% | 355 | | Scatter | 0.28 | 0.27 | 9.08 | 1.25 | 0.65 | +2% | +343% | +3122% | 356 | | ScatterSum | 0.00 | 0.00 | 0.00 | nan | 1.06 | +28% | nan% | -7% | 357 | | ScatterMax | 0.00 | 0.00 | 0.00 | nan | 1.06 | +10% | nan% | -3% | 358 | | SeLU | 0.34 | 0.29 | 3.63 | 0.35 | 1.12 | +18% | +2% | +967% | 359 | | Sigmoid | 0.21 | 0.22 | 3.45 | 0.27 | 0.93 | -6% | +27% | +1551% | 360 | | Softmax | 3.97 | 3.40 | 29.65 | 1.52 | 13.74 | +16% | -61% | +647% | 361 | | 
Softplus | 0.29 | 0.26 | 21.32 | 0.25 | 1.88 | +11% | -13% | +7339% | 362 | | Sort | 0.58 | 0.59 | 207.83 | 7.18 | 41.13 | 0% | +1128% | +35475% | 363 | | Sum | 1.25 | 1.23 | 6.28 | 1.44 | 4.54 | +1% | +15% | +403% | 364 | | SumAll | 0.95 | 0.93 | 4.61 | 1.07 | 1.48 | +1% | +13% | +387% | 365 | 366 | **M3 Ultra (8E+20P+60GPU+96GB)** mlx: 0.24.1 367 | | Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup | 368 | |-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------| 369 | | Argmax | 1.77 | 1.70 | 8.40 | 0.59 | 7.99 | +4% | -66% | +373% | 370 | | BCE | 0.68 | 0.37 | 15.70 | 0.54 | 2.77 | +82% | -21% | +2201% | 371 | | Concat | 1.88 | 1.85 | 84.41 | 1.80 | 40.87 | +1% | -4% | +4379% | 372 | | Conv1d | 0.52 | 0.52 | 4.16 | 0.44 | 148.64 | +0% | -14% | +700% | 373 | | Conv2d | 1.32 | 1.30 | 31.36 | 0.55 | 33.67 | +1% | -58% | +2272% | 374 | | Gather | 0.60 | 0.72 | 3.38 | 4.11 | 14.71 | -15% | +579% | +458% | 375 | | LeakyReLU | 0.26 | 0.26 | 0.67 | 0.22 | 1.47 | 0% | -13% | +158% | 376 | | Linear | 2.74 | 2.77 | 52.00 | 2.53 | 45.65 | -1% | -7% | +1800% | 377 | | MatMul | 3.38 | 3.54 | 16.76 | 4.96 | 231.43 | -4% | +46% | +395% | 378 | | PReLU | 0.62 | 0.48 | 3.15 | 0.26 | 1.67 | +29% | -58% | +412% | 379 | | ReLU | 0.26 | 0.35 | 0.42 | 0.22 | 1.47 | -25% | -12% | +63% | 380 | | Scatter | 0.26 | 0.26 | 10.27 | 0.82 | 0.93 | +2% | +211% | +3792% | 381 | | ScatterSum | 0.00 | 0.00 | 0.00 | 0.41 | 1.25 | +43% | +18803% | -33% | 382 | | ScatterMax | 0.00 | 0.00 | 0.00 | 1.22 | 1.23 | +53% | +60915% | -27% | 383 | | SeLU | 0.60 | 0.43 | 6.39 | 0.25 | 1.29 | +40% | -58% | +963% | 384 | | Sigmoid | 0.26 | 0.26 | 2.21 | 0.27 | 1.25 | +2% | +2% | +740% | 385 | | Softmax | 2.67 | 2.08 | 33.54 | 1.22 | 11.51 | +28% | -54% | +1155% | 386 | | Softplus | 0.29 | 0.25 | 24.63 | 0.23 | 1.76 | +14% | -19% | +8357% | 387 | | Sort | 0.73 | 0.60 | 213.11 | 6.18 | 33.69 | +22% | +746% | +29113% | 388 | | Sum | 0.96 | 0.96 | 10.85 | 1.09 | 4.35 | 0% | +13% | +1029% | 389 | | SumAll | 0.73 | 0.76 | 7.93 | 0.84 | 1.67 | -3% | +15% | +982% | 390 | 391 | ## CUDA GPUs 392 | 393 | **Tesla V100 PCIe (32Go / Intel Xeon Gold 5120 14 cores / 28 threads @ 2.2GHz (Skylake), 60Go)** 394 | 395 | | Operation | cpu | cuda | cuda/cpu speedup | 396 | |-----------------|------|------|----------------| 397 | | Argmax | 34.34 | 0.10 | +33411% | 398 | | BCE | 198.19 | 0.19 | +102820% | 399 | | Concat | 380.98 | 1.67 | +22679% | 400 | | Conv1d | 30.21 | 0.33 | +9027% | 401 | | Conv2d | 52.73 | 0.87 | +5938% | 402 | | Gather | 96.61 | 0.42 | +22636% | 403 | | LeakyReLU | 5.51 | 0.08 | +7010% | 404 | | Linear | 901.98 | 3.79 | +23722% | 405 | | MatMul | 1241.12 | 2.80 | +44293% | 406 | | PReLU | 5.55 | 0.08 | +7159% | 407 | | ReLU | 5.50 | 0.08 | +7032% | 408 | | Scatter | 6.92 | 0.12 | +5875% | 409 | | ScatterSum | 4.25 | 0.08 | +5058% | 410 | | ScatterMax | nan | nan | nan% | 411 | | SeLU | 11.56 | 0.08 | +14709% | 412 | | Sigmoid | 9.46 | 0.08 | +12023% | 413 | | Softmax | 221.43 | 0.71 | +31300% | 414 | | Softplus | 22.13 | 0.08 | +27658% | 415 | | Sort | 526.33 | 2.59 | +20202% | 416 | | Sum | 67.43 | 0.70 | +9472% | 417 | | SumAll | 29.82 | 0.50 | +5822% | 418 | 419 | **Tesla V100 NVLink (32Go / Intel Xeon Gold 6148 20 cores, 40 threads @ 2.4 GHz (Skylake), 60Go)** 420 | 421 | | Operation | cpu | cuda | cuda/cpu speedup | 422 | 
|-----------------|------|------|----------------| 423 | | Argmax | 28.23 | 0.10 | +28460% | 424 | | BCE | 186.05 | 0.19 | +97956% | 425 | | Concat | 531.34 | 1.67 | +31744% | 426 | | Conv1d | 22.37 | 0.31 | +7033% | 427 | | Conv2d | 52.89 | 0.83 | +6257% | 428 | | Gather | 161.56 | 0.41 | +39152% | 429 | | LeakyReLU | 16.95 | 0.08 | +21591% | 430 | | Linear | 666.79 | 3.58 | +18532% | 431 | | MatMul | 998.29 | 2.68 | +37198% | 432 | | PReLU | 15.55 | 0.08 | +20584% | 433 | | ReLU | 14.07 | 0.08 | +18496% | 434 | | Scatter | 6.19 | 0.11 | +5548% | 435 | | ScatterSum | 6.83 | 0.08 | +8757% | 436 | | ScatterMax | nan | nan | nan% | 437 | | SeLU | 20.94 | 0.08 | +27171% | 438 | | Sigmoid | 19.82 | 0.08 | +25331% | 439 | | Softmax | 253.76 | 0.70 | +36156% | 440 | | Softplus | 29.21 | 0.08 | +37131% | 441 | | Sort | 422.98 | 2.48 | +16933% | 442 | | Sum | 69.38 | 0.70 | +9861% | 443 | | SumAll | 31.13 | 0.50 | +6152% | 444 | 445 | **RTX4090 ((Desktop) / 10th Gen Intel Core i9-10940X @ 3.30GHz 128GB)** 446 | 447 | | Operation | cpu | cuda | cuda/cpu speedup | 448 | |-----------------|------|------|----------------| 449 | | Argmax | 6.67 | 0.04 | +14782% | 450 | | BCE | 23.74 | 0.14 | +16992% | 451 | | Concat | 52.08 | 1.29 | +3922% | 452 | | Conv1d | 2.84 | 0.15 | +1753% | 453 | | Conv2d | 6.60 | 0.25 | +2559% | 454 | | Gather | 19.75 | 0.27 | +7340% | 455 | | LeakyReLU | 2.44 | 0.03 | +7439% | 456 | | Linear | 62.27 | 1.01 | +6057% | 457 | | MatMul | 87.47 | 1.36 | +6322% | 458 | | PReLU | 2.28 | 0.04 | +5297% | 459 | | ReLU | 2.47 | 0.03 | +7216% | 460 | | Scatter | 1.84 | 0.07 | +2652% | 461 | | ScatterSum | 3.86 | 0.06 | +5919% | 462 | | ScatterMax | 3.86 | 0.08 | +4790% | 463 | | SeLU | 2.71 | 0.04 | +6952% | 464 | | Sigmoid | 2.63 | 0.05 | +5626% | 465 | | Softmax | 27.75 | 0.59 | +4634% | 466 | | Softplus | 3.50 | 0.04 | +8149% | 467 | | Sort | 46.67 | 0.90 | +5077% | 468 | | Sum | 12.19 | 0.62 | +1866% | 469 | | SumAll | 6.95 | 0.45 | +1428% | 470 | 471 | A100 80GB 80GB PCIe ((Server) / Intel(R) Xeon(R) Gold 6254 CPU @ 3.10GHz, 754GB) 472 | 473 | | Operation | cpu | cuda | cuda/cpu speedup | 474 | |-----------------|------|------|----------------| 475 | | Argmax | 5.04 | 0.06 | +7856% | 476 | | BCE | 18.22 | 0.11 | +16097% | 477 | | Concat | 30.47 | 0.74 | +4036% | 478 | | Conv1d | 1029.44 | 0.13 | +811270% | 479 | | Conv2d | 531.83 | 0.26 | +205989% | 480 | | Gather | 9.59 | 0.30 | +3045% | 481 | | LeakyReLU | 1.68 | 0.06 | +2579% | 482 | | Linear | 47.44 | 2.17 | +2090% | 483 | | MatMul | 50.91 | 2.07 | +2355% | 484 | | PReLU | 1.60 | 0.05 | +3332% | 485 | | ReLU | 1.43 | 0.04 | +3380% | 486 | | Scatter | 1.61 | 0.11 | +1358% | 487 | | ScatterSum | 4.95 | 0.06 | +7547% | 488 | | ScatterMax | 5.39 | 0.33 | +1511% | 489 | | SeLU | 1.82 | 0.04 | +4259% | 490 | | Sigmoid | 3.03 | 0.04 | +7553% | 491 | | Softmax | 18.18 | 0.36 | +5003% | 492 | | Softplus | 2.87 | 0.04 | +6412% | 493 | | Sort | 52.86 | 1.16 | +4449% | 494 | | Sum | 11.38 | 0.37 | +2947% | 495 | | SumAll | 6.85 | 0.29 | +2226% | 496 | 497 | -------------------------------------------------------------------------------- /images/plot_BCE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_BCE.png -------------------------------------------------------------------------------- /images/plot_Concat.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_Concat.png -------------------------------------------------------------------------------- /images/plot_Conv2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_Conv2d.png -------------------------------------------------------------------------------- /images/plot_Linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_Linear.png -------------------------------------------------------------------------------- /images/plot_MatMul.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_MatMul.png -------------------------------------------------------------------------------- /images/plot_Sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_Sigmoid.png -------------------------------------------------------------------------------- /images/plot_Softmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_Softmax.png -------------------------------------------------------------------------------- /images/plot_Sort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/images/plot_Sort.png -------------------------------------------------------------------------------- /mlx_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TristanBilot/mlx-benchmark/b5bd4888ec4deb374664b2e49cac09b41bbd9c34/mlx_benchmark/__init__.py -------------------------------------------------------------------------------- /mlx_benchmark/base_benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Callable 3 | 4 | from config import USE_MLX 5 | 6 | if USE_MLX: 7 | import mlx.core as mx 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | class BaseBenchmark: 14 | """ 15 | Base class for benchmarking different operations or layers. 16 | """ 17 | 18 | def __init__(self, **kwargs): 19 | self.device = None 20 | self.y = [] 21 | self.args_str = " ".join([f"{k[:3]}={v}" for k, v in kwargs.items()]) 22 | 23 | self.kwargs = kwargs 24 | self.compiled_fn: Callable = None 25 | 26 | self.inputs = None 27 | 28 | def compute_inputs(self, framework, device=None): 29 | """ 30 | Generates the default inputs for all benchmarks. 
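        Each of `dim1`, `dim2`, `dim3` present in kwargs is a string such as
        "64x32": it is split on "x" and converted to a list of ints, and a
        float32 tensor of that shape is drawn for it (`mx.random.normal` for
        MLX, `torch.randn` on the target device for torch) and stored in
        `self.inputs`, in order. An optional `axis` kwarg is copied to
        `self.axis`.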
31 | """ 32 | dims = [] 33 | for d in ["dim1", "dim2", "dim3"]: 34 | if d in self.kwargs: 35 | dim = self.kwargs[d].split("x") 36 | dims.append([int(i) for i in dim]) 37 | 38 | if framework == "mlx": 39 | self.inputs = [mx.random.normal(dim).astype(mx.float32) for dim in dims] 40 | else: 41 | self.inputs = [ 42 | torch.randn(dim, device=device, dtype=torch.float32) for dim in dims 43 | ] 44 | 45 | if "axis" in self.kwargs: 46 | self.axis = self.kwargs["axis"] 47 | 48 | def additional_preprocessing(self, framework=None, device=None): 49 | """ 50 | Can be overridden if custom preprocessing has to be performed on the default 51 | `self.inputs` given to operations. 52 | """ 53 | pass 54 | 55 | def forward_mlx(self, **kwargs): 56 | """ 57 | Abstract method for the forward pass of the benchmark on MLX. 58 | Should be implemented by subclasses to define the actual computation 59 | or model inference to be benchmarked. 60 | """ 61 | raise NotImplementedError 62 | 63 | def forward_torch(self, **kwargs): 64 | """ 65 | Abstract method for the forward pass of the benchmark on torch. 66 | Same implementation in torch. 67 | """ 68 | raise NotImplementedError 69 | 70 | def run(self, framework, compile=False, device=None, **kwargs) -> float: 71 | """ 72 | Runs the benchmark for a specified number of iterations. 73 | Measures and records the duration of each forward pass. 74 | """ 75 | if self.inputs is None: 76 | self.compute_inputs(framework, device) 77 | self.additional_preprocessing(framework, device) 78 | 79 | if framework == "mlx": 80 | forward_fn = self.forward_mlx 81 | kwargs = { 82 | **kwargs, 83 | "compile": compile, 84 | } 85 | 86 | elif framework == "torch": 87 | self.device = device 88 | forward_fn = self.forward_torch 89 | 90 | else: 91 | raise ValueError("Invalid framework.") 92 | 93 | # Measures runtime for n iterations. 94 | duration = np.mean( 95 | [self._measure_runtime(forward_fn, **kwargs) for _ in range(10)][1:] 96 | ) 97 | # [1:] is used to remove the first measure, usually slower 98 | # due to cold start. 99 | 100 | duration_ms = duration * 1000 101 | return duration_ms 102 | 103 | def _measure_runtime(self, fn, **kwargs) -> float: 104 | """ 105 | Simply runs the forward method and measures metrics. 106 | """ 107 | tic = time.perf_counter() 108 | try: 109 | fn(**kwargs) 110 | except (NotImplementedError, RuntimeError): 111 | return float("nan") 112 | 113 | duration = time.perf_counter() - tic 114 | 115 | return duration 116 | 117 | def sync_torch_gpu_if_needed(self): 118 | """ 119 | Call this function after every torch implementation to ensure 120 | the mps or cuda execution has finished. 121 | """ 122 | if self.device == "cuda": 123 | torch.cuda.synchronize() 124 | elif self.device == "mps": 125 | torch.mps.synchronize() 126 | 127 | def compile_if_needed(self, fn, **kwargs): 128 | """ 129 | Caches the compiled function `fun` when called for the first time. 
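        With `compile=True`, `fn` is wrapped with `mx.compile` on the first
        call and the result is cached in `self.compiled_fn`, so subsequent
        iterations reuse the same compiled callable; without `compile`, `fn`
        is returned unchanged.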
130 | """ 131 | if "compile" in kwargs and kwargs["compile"]: 132 | if self.compiled_fn is not None: 133 | return self.compiled_fn 134 | self.compiled_fn = mx.compile(fn) 135 | return self.compiled_fn 136 | return fn 137 | 138 | def clear(self): 139 | self.inputs = None 140 | self.y = [] 141 | for attribute in [ 142 | "a_torch", 143 | "b_torch", 144 | "b_mlx", 145 | "src", 146 | "index", 147 | "node_features", 148 | ]: 149 | if hasattr(self, attribute): 150 | delattr(self, attribute) 151 | -------------------------------------------------------------------------------- /mlx_benchmark/config.py: -------------------------------------------------------------------------------- 1 | # Set to True if the benchmark is run on a Mac device. 2 | # Set to False if the benchmark is run on another device, 3 | # e.g. Linux or Windows with or without CUDA. 4 | USE_MLX = True 5 | -------------------------------------------------------------------------------- /mlx_benchmark/get_cpu_gpu_config.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import re 3 | import mlx.core as mx 4 | 5 | 6 | def get_system_info(): 7 | hardware_result = subprocess.run( 8 | ["system_profiler", "SPHardwareDataType"], capture_output=True, text=True 9 | ) 10 | hardware_data = hardware_result.stdout 11 | 12 | display_result = subprocess.run( 13 | ["system_profiler", "SPDisplaysDataType"], capture_output=True, text=True 14 | ) 15 | display_data = display_result.stdout 16 | 17 | chipset_pattern = r"Chip: (.+)" 18 | chipset_match = re.search(chipset_pattern, hardware_data) 19 | chipset_model = ( 20 | chipset_match.group(1).strip() if chipset_match else "Unknown Chipset" 21 | ) 22 | 23 | cores_pattern = ( 24 | r"Total Number of Cores: (\d+) \((\d+) performance and (\d+) efficiency\)" 25 | ) 26 | match = re.search(cores_pattern, hardware_data) 27 | if match: 28 | total_cores, performance_cores, efficiency_cores = match.groups() 29 | 30 | gpu_cores_pattern = r"Total Number of Cores: (\d+)" 31 | gpu_match = re.search(gpu_cores_pattern, display_data) 32 | if gpu_match: 33 | gpu_cores = gpu_match.group(1) 34 | 35 | ram_pattern = r"Memory: (.+)" 36 | ram_match = re.search(ram_pattern, hardware_data) 37 | ram = ram_match.group(1).strip() if ram_match else "Unknown RAM" 38 | ram = ram.split(" ")[0] 39 | 40 | chipset_model = chipset_model.split("Apple")[-1] 41 | formatted_output = f"{chipset_model} ({efficiency_cores}E+{performance_cores}P+{gpu_cores}GPU+{ram}GB)" 42 | return formatted_output 43 | 44 | 45 | description = f"{get_system_info()} - mlx: {mx.__version__}" 46 | print(description) 47 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary_cross_entropy import BCE 2 | from .concat import Concat 3 | from .conv import Conv1d, Conv2d 4 | from .gather_scatter import Gather, Scatter, ScatterSum, ScatterMax 5 | from .layernorm import LayerNorm 6 | from .linear import Linear 7 | from .matmul import MatMul 8 | from .simple_operations import ( 9 | Sort, 10 | Argmax, 11 | Softmax, 12 | ReLU, 13 | PReLU, 14 | LeakyReLU, 15 | Softplus, 16 | SeLU, 17 | Sigmoid, 18 | Sum, 19 | SumAll, 20 | ) 21 | from .scaled_dot_product_attention import ScaledDotProductAttention 22 | 23 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/binary_cross_entropy.py: 
-------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | import mlx.nn as mx_nn 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from base_benchmark import BaseBenchmark 11 | 12 | 13 | class BCE(BaseBenchmark): 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | def additional_preprocessing(self, framework, device): 18 | if framework == "torch": 19 | self.a_torch = torch.rand(*self.inputs[0].shape).to(device) 20 | self.b_torch = torch.randint( 21 | size=self.inputs[1].shape, 22 | dtype=torch.float32, 23 | low=0, 24 | high=2, 25 | device=device, 26 | ) 27 | else: 28 | self.b_mlx = mx.random.randint(shape=self.inputs[1].shape, low=0, high=2) 29 | 30 | def forward_mlx(self, compile=False, **kwargs): 31 | a, _ = self.inputs 32 | b = self.b_mlx 33 | 34 | fn = self.compile_if_needed(mx_nn.losses.binary_cross_entropy, compile=compile) 35 | y = fn(a, b) 36 | mx.eval(y) 37 | 38 | @torch.no_grad() 39 | def forward_torch(self, **kwargs): 40 | a, b = self.a_torch, self.b_torch 41 | 42 | y = F.binary_cross_entropy(a, b) 43 | self.sync_torch_gpu_if_needed() 44 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/concat.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | 6 | import torch 7 | 8 | from base_benchmark import BaseBenchmark 9 | 10 | 11 | class Concat(BaseBenchmark): 12 | def __init__(self, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | def forward_mlx(self, compile=False, **kwargs): 16 | a, b = self.inputs 17 | 18 | fn = self.compile_if_needed(mx.concatenate, compile=compile) 19 | y = fn([a, b], self.kwargs["axis"]) 20 | mx.eval(y) 21 | 22 | @torch.no_grad() 23 | def forward_torch(self, **kwargs): 24 | a, b = self.inputs 25 | 26 | y = torch.cat([a, b], dim=self.kwargs["axis"]) 27 | self.sync_torch_gpu_if_needed() 28 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/conv.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from base_benchmark import BaseBenchmark 10 | 11 | 12 | class Conv1d(BaseBenchmark): 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def additional_preprocessing(self, framework, device): 17 | a, b = self.inputs 18 | 19 | # In torch, the channels are located at axis 1 20 | if framework == "torch": 21 | self.a_torch = torch.transpose(a, -1, -2) 22 | self.b_torch = torch.transpose(b, -1, -2) 23 | 24 | def forward_mlx(self, compile=False, **kwargs): 25 | a, b = self.inputs 26 | 27 | fn = self.compile_if_needed(mx.conv1d, compile=compile) 28 | y = fn(a, b) 29 | mx.eval(y) 30 | 31 | @torch.no_grad() 32 | def forward_torch(self, **kwargs): 33 | a, b = self.a_torch, self.b_torch 34 | 35 | y = F.conv1d(a, b) 36 | self.sync_torch_gpu_if_needed() 37 | 38 | 39 | class Conv2d(BaseBenchmark): 40 | def __init__(self, **kwargs): 41 | super().__init__(**kwargs) 42 | 43 | def additional_preprocessing(self, framework, device): 44 | a, b = self.inputs 45 | 46 | if framework == "torch": 47 | self.a_torch = torch.permute(a, (0, 3, 1, 2)) 48 | self.b_torch = torch.permute(b, (0, 3, 1, 2)) 49 | 50 | def forward_mlx(self, 
compile=False, **kwargs): 51 | a, b = self.inputs 52 | 53 | fn = self.compile_if_needed(mx.conv2d, compile=compile) 54 | y = fn(a, b) 55 | mx.eval(y) 56 | 57 | @torch.no_grad() 58 | def forward_torch(self, **kwargs): 59 | a, b = self.a_torch, self.b_torch 60 | 61 | y = F.conv2d(a, b) 62 | self.sync_torch_gpu_if_needed() 63 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/gather_scatter.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from config import USE_MLX 3 | 4 | if USE_MLX: 5 | import mlx.core as mx 6 | 7 | import torch 8 | 9 | from base_benchmark import BaseBenchmark 10 | from utils import get_dummy_edge_index 11 | 12 | 13 | class Gather(BaseBenchmark): 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | def additional_preprocessing(self, framework, device): 18 | node_features, edge_index = self.inputs 19 | 20 | num_nodes = node_features.shape[0] 21 | edge_index = get_dummy_edge_index( 22 | (2, edge_index.shape[0]), num_nodes, device, framework 23 | ) 24 | 25 | self.node_features = node_features 26 | self.index = edge_index[0] 27 | 28 | def forward_mlx(self, compile=False, **kwargs): 29 | a, b = self.node_features, self.index 30 | 31 | fn = lambda x, y: x[y] 32 | fn = self.compile_if_needed(fn, compile=compile) 33 | 34 | y = fn(a, b) 35 | mx.eval(y) 36 | 37 | @torch.no_grad() 38 | def forward_torch(self, **kwargs): 39 | a, b = self.node_features, self.index 40 | 41 | y = a[b] 42 | self.sync_torch_gpu_if_needed() 43 | 44 | 45 | class _Scatter(BaseBenchmark): 46 | def __init__(self, scatter_op: Literal["indexing", "add", "max"], **kwargs): 47 | super().__init__(**kwargs) 48 | 49 | self.scatter_op = scatter_op 50 | 51 | def additional_preprocessing(self, framework, device): 52 | node_features, edge_index = self.inputs 53 | 54 | num_nodes = node_features.shape[0] 55 | edge_index = get_dummy_edge_index( 56 | (2, edge_index.shape[0]), num_nodes, device, framework 57 | ) 58 | 59 | self.node_features = node_features[edge_index[0]] 60 | self.index = edge_index[1] 61 | self.src = ( 62 | mx.zeros_like(node_features) 63 | if framework == "mlx" 64 | else torch.zeros_like(node_features) 65 | ) 66 | 67 | if self.scatter_op != "indexing" and framework == "torch": 68 | self.index = self.index.unsqueeze(1) 69 | 70 | def fn_indexing(x, y, z): 71 | x[y] = z 72 | return x 73 | 74 | def fn_add(x, y, z): 75 | x.at[y].add(z) 76 | return x 77 | 78 | def fn_max(x, y, z): 79 | x.at[y].maximum(z) 80 | return x 81 | 82 | fns = { 83 | "indexing": fn_indexing, 84 | "add": fn_add, 85 | "max": fn_max, 86 | } 87 | self._scatter_fn_mlx = fns[self.scatter_op] 88 | 89 | def fn_indexing_torch(x, y, z): 90 | x[y] = z 91 | return x 92 | 93 | def fn_add_torch(x, y, z): 94 | x = torch.scatter_reduce(x, 0, y, z, reduce="sum") 95 | return x 96 | 97 | def fn_max_torch(x, y, z): 98 | x = torch.scatter_reduce(x, 0, y, z, reduce="max") 99 | return x 100 | 101 | fns = { 102 | "indexing": fn_indexing_torch, 103 | "add": fn_add_torch, 104 | "max": fn_max_torch, 105 | } 106 | self._scatter_fn_torch = fns[self.scatter_op] 107 | 108 | def forward_mlx(self, compile=False, **kwargs): 109 | a, b, c = self.src, self.index, self.node_features 110 | 111 | fn = self.compile_if_needed(self._scatter_fn_mlx, compile=compile) 112 | 113 | y = fn(a, b, c) 114 | mx.eval(y) 115 | 116 | @torch.no_grad() 117 | def forward_torch(self, **kwargs): 118 | a, b, c = self.src, self.index, 
self.node_features 119 | 120 | y = self._scatter_fn_torch(a, b, c) 121 | self.sync_torch_gpu_if_needed() 122 | 123 | 124 | class Scatter(_Scatter): 125 | def __init__(self, **kwargs): 126 | super().__init__(scatter_op="indexing", **kwargs) 127 | 128 | 129 | class ScatterSum(_Scatter): 130 | def __init__(self, **kwargs): 131 | super().__init__(scatter_op="add", **kwargs) 132 | 133 | 134 | class ScatterMax(_Scatter): 135 | def __init__(self, **kwargs): 136 | super().__init__(scatter_op="max", **kwargs) 137 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/layernorm.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from base_benchmark import BaseBenchmark 10 | 11 | 12 | class LayerNorm(BaseBenchmark): 13 | def additional_preprocessing(self, framework=None, device=None): 14 | if framework == "mlx": 15 | (x,) = self.inputs 16 | 17 | self.mlx_layernorm = nn.LayerNorm( 18 | dims=x.shape[2], 19 | affine=False, 20 | bias=False, 21 | ) 22 | 23 | return 24 | 25 | def forward_mlx(self, compile=False, **kwargs): 26 | (x,) = self.inputs 27 | 28 | fn = self.mlx_layernorm 29 | fn = self.compile_if_needed(fn, compile=compile) 30 | 31 | y = fn(x) 32 | mx.eval(y) 33 | 34 | @torch.inference_mode() 35 | def forward_torch(self, **kwargs): 36 | (x,) = self.inputs 37 | 38 | fn = lambda x: F.layer_norm( 39 | x, normalized_shape=x.shape[2:], bias=None, weight=None 40 | ) 41 | y = fn(x) 42 | 43 | self.sync_torch_gpu_if_needed() 44 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/linear.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from base_benchmark import BaseBenchmark 10 | 11 | 12 | class Linear(BaseBenchmark): 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def additional_preprocessing(self, framework, device): 17 | _, b, _ = self.inputs 18 | 19 | if framework == "torch": 20 | self.b_torch = b.T 21 | 22 | def forward_mlx(self, compile=False, **kwargs): 23 | a, b, c = self.inputs 24 | 25 | fn = lambda x, y, z: mx.addmm(z, x, y) 26 | fn = self.compile_if_needed(fn, compile=compile) 27 | 28 | y = fn(a, b, c) 29 | mx.eval(y) 30 | 31 | @torch.no_grad() 32 | def forward_torch(self, **kwargs): 33 | a, _, c = self.inputs 34 | b = self.b_torch 35 | 36 | # NOTE: torch.addmm only supports 2D matrix so we use linear 37 | y = F.linear(a, b, c) 38 | self.sync_torch_gpu_if_needed() 39 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/matmul.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | 6 | import torch 7 | 8 | from base_benchmark import BaseBenchmark 9 | 10 | 11 | class MatMul(BaseBenchmark): 12 | def __init__(self, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | def forward_mlx(self, compile=False, **kwargs): 16 | a, b = self.inputs 17 | 18 | fn = lambda x, y: x @ y 19 | fn = self.compile_if_needed(fn, compile=compile) 20 | 21 | y = fn(a, b) 22 | mx.eval(y) 23 | 24 | @torch.no_grad() 25 | def forward_torch(self, **kwargs): 26 | 
a, b = self.inputs 27 | 28 | y = a @ b 29 | self.sync_torch_gpu_if_needed() 30 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/scaled_dot_product_attention.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | import math 3 | 4 | if USE_MLX: 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import mlx.core.fast 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from base_benchmark import BaseBenchmark 13 | 14 | 15 | class ScaledDotProductAttention(BaseBenchmark): 16 | def additional_preprocessing(self, framework=None, device=None): 17 | q, k, v = self.inputs 18 | 19 | # verify input shapes 20 | batch_size, _, num_tokens_q, dim_q = q.shape 21 | batch_size_k, num_heads_k, num_tokens_k, dim_k = k.shape 22 | batch_size_v, num_heads_v, num_tokens_v, _ = v.shape 23 | 24 | if (batch_size_k != batch_size) or (batch_size_v != batch_size) or \ 25 | (num_tokens_k != num_tokens_v) or (num_heads_k != num_heads_v) or (dim_q != dim_k): 26 | raise ValueError( 27 | f"incompatible shapes: q.shape = {q.shape}, k.shape = {k.shape}, v.shape = {v.shape}" 28 | ) 29 | 30 | # create additive causal mask, scale 31 | attention_mask = torch.ones(num_tokens_q, num_tokens_k, dtype=torch.bool).tril(diagonal=0) 32 | attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf")) 33 | 34 | if framework == "mlx": 35 | self.attention_mask_mlx = mx.array(attention_mask, dtype=q.dtype) 36 | 37 | if framework == "torch" and device is not None: 38 | self.attention_mask_torch = attention_mask.to(dtype=q.dtype, device=device) 39 | 40 | self.scale_factor = 1 / math.sqrt(dim_q) 41 | 42 | return 43 | 44 | def forward_mlx(self, compile=False, **kwargs): 45 | q, k, v = self.inputs 46 | 47 | fn = mlx.core.fast.scaled_dot_product_attention 48 | fn = self.compile_if_needed(fn, compile=compile) 49 | 50 | y = fn(q, k, v, scale=self.scale_factor, mask=self.attention_mask_mlx) 51 | mx.eval(y) 52 | 53 | @torch.inference_mode() 54 | def forward_torch(self, **kwargs): 55 | q, k, v = self.inputs 56 | 57 | fn = F.scaled_dot_product_attention 58 | 59 | y = fn(q, k, v, scale=self.scale_factor, attn_mask=self.attention_mask_torch) 60 | self.sync_torch_gpu_if_needed() 61 | -------------------------------------------------------------------------------- /mlx_benchmark/operations/simple_operations.py: -------------------------------------------------------------------------------- 1 | from config import USE_MLX 2 | 3 | if USE_MLX: 4 | import mlx.core as mx 5 | import mlx.nn as mx_nn 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from base_benchmark import BaseBenchmark 11 | 12 | if USE_MLX: 13 | mlx_simple_operations = { 14 | "sort": mx.sort, 15 | "argmax": mx.argmax, 16 | "softmax": mx.softmax, 17 | "relu": mx_nn.relu, 18 | "leaky_relu": mx_nn.leaky_relu, 19 | "prelu": mx_nn.prelu, 20 | "softplus": mx_nn.softplus, 21 | "selu": mx_nn.selu, 22 | "sigmoid": mx.sigmoid, 23 | "sum": mx.sum, 24 | } 25 | 26 | torch_simple_operations = { 27 | "sort": torch.sort, 28 | "argmax": torch.argmax, 29 | "softmax": F.softmax, 30 | "relu": F.relu, 31 | "leaky_relu": F.leaky_relu, 32 | "prelu": F.prelu, 33 | "softplus": F.softplus, 34 | "selu": F.selu, 35 | "sigmoid": F.sigmoid, 36 | "sum": torch.sum, 37 | } 38 | 39 | if USE_MLX: 40 | assert ( 41 | torch_simple_operations.keys() == mlx_simple_operations.keys() 42 | ), "torch and mlx operations are not the same." 
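    # Each entry maps an operation name to a [mlx_fn, torch_fn] pair; when USE_MLX is
    # False, the MLX slot is filled with None in the branch below instead.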
43 | simple_operations = { 44 | op: [mlx_simple_operations[op], torch_simple_operations[op]] 45 | for op in torch_simple_operations.keys() 46 | } 47 | else: 48 | simple_operations = { 49 | op: [None, torch_simple_operations[op]] for op in torch_simple_operations.keys() 50 | } 51 | 52 | 53 | class SimpleOperationBenchmark(BaseBenchmark): 54 | """ 55 | This class can be overridden for simple operations that only necessitate the 56 | default inputs and an optional `axis` arg. 57 | """ 58 | 59 | def __init__(self, basic_operation, **kwargs): 60 | super().__init__(**kwargs) 61 | 62 | self.basic_operation = basic_operation 63 | 64 | def forward_mlx(self, compile=False, **kwargs): 65 | fn = simple_operations[self.basic_operation][0] 66 | 67 | if "axis" in self.kwargs: 68 | kwargs["axis"] = self.kwargs["axis"] 69 | 70 | fn = self.compile_if_needed(fn, compile=compile) 71 | y = fn(*self.inputs) 72 | mx.eval(y) 73 | 74 | @torch.no_grad() 75 | def forward_torch(self, **kwargs): 76 | fn = simple_operations[self.basic_operation][1] 77 | 78 | kwargs = {} 79 | if "axis" in self.kwargs: 80 | kwargs["dim"] = self.kwargs["axis"] 81 | 82 | y = fn(*self.inputs, **kwargs) 83 | self.sync_torch_gpu_if_needed() 84 | 85 | 86 | class Sort(SimpleOperationBenchmark): 87 | def __init__(self, **kwargs): 88 | super().__init__("sort", **kwargs) 89 | 90 | 91 | class Argmax(SimpleOperationBenchmark): 92 | def __init__(self, **kwargs): 93 | super().__init__("argmax", **kwargs) 94 | 95 | 96 | class Softmax(SimpleOperationBenchmark): 97 | def __init__(self, **kwargs): 98 | super().__init__("softmax", **kwargs) 99 | 100 | 101 | class ReLU(SimpleOperationBenchmark): 102 | def __init__(self, **kwargs): 103 | super().__init__("relu", **kwargs) 104 | 105 | 106 | class LeakyReLU(SimpleOperationBenchmark): 107 | def __init__(self, **kwargs): 108 | super().__init__("leaky_relu", **kwargs) 109 | 110 | 111 | class PReLU(SimpleOperationBenchmark): 112 | def __init__(self, **kwargs): 113 | super().__init__("prelu", **kwargs) 114 | 115 | 116 | class Softplus(SimpleOperationBenchmark): 117 | def __init__(self, **kwargs): 118 | super().__init__("softplus", **kwargs) 119 | 120 | 121 | class SeLU(SimpleOperationBenchmark): 122 | def __init__(self, **kwargs): 123 | super().__init__("selu", **kwargs) 124 | 125 | 126 | class Sigmoid(SimpleOperationBenchmark): 127 | def __init__(self, **kwargs): 128 | super().__init__("sigmoid", **kwargs) 129 | 130 | 131 | class Sum(SimpleOperationBenchmark): 132 | def __init__(self, **kwargs): 133 | super().__init__("sum", **kwargs) 134 | 135 | 136 | class SumAll(SimpleOperationBenchmark): 137 | def __init__(self, **kwargs): 138 | super().__init__("sum", **kwargs) 139 | -------------------------------------------------------------------------------- /mlx_benchmark/run_benchmark.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | try: 4 | mp.set_start_method("spawn", force=True) 5 | except RuntimeError: 6 | pass 7 | 8 | from argparse import ArgumentParser 9 | from collections import defaultdict 10 | from distutils.util import strtobool 11 | 12 | import numpy as np 13 | import torch 14 | from tqdm import tqdm 15 | 16 | from config import USE_MLX 17 | 18 | if USE_MLX: 19 | import mlx.core as mx 20 | 21 | from utils import print_benchmark 22 | from operations import * 23 | 24 | 25 | def run_processes(operations, backends, iterations=5): 26 | """ 27 | Runs all operations, on all backends, for a specified number of iterations. 
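    Hypothetical call, reusing an op and backends that appear later in this file:
        run_processes([MatMul(dim1="64x1000000", dim2="1000000x32")], ["mlx_gpu", "cpu"])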
28 | """
29 | all_times = defaultdict(dict)
30 |
31 | for i, backend in enumerate(backends):
32 | print(f"\nRunning benchmarks on {backend} ({i + 1}/{len(backends)})")
33 | if "mlx" in backend:
34 | times = run_mlx_backend(operations, backend, iterations)
35 | else:
36 | times = run_backend(operations, backend, iterations)
37 | for op_name, duration in times.items():
38 | all_times[op_name][backend] = duration
39 |
40 | print("\nDetailed benchmark:")
41 | print_benchmark(all_times, backends)
42 | print("\nAverage benchmark:")
43 | print_benchmark(all_times, backends, reduce_mean=True)
44 |
45 |
46 | def run_mlx_backend(operations, backend, iterations, ops_per_process=10):
47 | """
48 | Runs all operations on the given backend in separate processes, to reduce memory requirements.
49 |
50 | ``ops_per_process`` determines the number of ops to run in each process.
51 | Decreasing this number decreases the max memory usage.
52 | """
53 | times = {}
54 |
55 | with tqdm(total=len(operations)) as pbar:
56 | for i in range(0, len(operations), ops_per_process):
57 | queue = mp.Queue()
58 | p = mp.Process(
59 | target=run_backend,
60 | args=(
61 | operations[i : min(i + ops_per_process, len(operations))],
62 | backend,
63 | iterations,
64 | queue,
65 | ),
66 | )
67 | p.start()
68 | p.join()
69 | times.update(queue.get())
70 | queue.close()
71 | pbar.update(min(ops_per_process, len(operations) - i))
72 |
73 | return times
74 |
75 |
76 | def run_backend(operations, backend, iterations, queue=None):
77 | """
78 | Runs all operations on the given backend for a specified number of iterations.
79 | """
80 | times = {}
81 |
82 | op_iterable = tqdm(operations) if queue is None else operations
83 |
84 | for op in op_iterable:
85 | op_name = type(op).__name__ + " / " + op.args_str
86 | duration = run(op, backend, iterations)
87 | times[op_name] = duration
88 |
89 | if queue is None:
90 | return times
91 | queue.put(times)
92 |
93 |
94 | def run(op, backend, iterations):
95 | """
96 | Measures runtime of a single op on the given framework and device specified by backend.
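    Recognised backends: "mlx_gpu", "mlx_gpu_compile", "mlx_cpu", "cpu", "mps", "cuda".
    Returns the mean duration in milliseconds over `iterations` calls to `op.run`.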
97 | """ 98 | 99 | if backend == "mlx_gpu": 100 | mx.set_default_device(mx.gpu) 101 | duration = np.mean([op.run(framework="mlx") for _ in range(iterations)]) 102 | elif backend == "mlx_gpu_compile": 103 | mx.set_default_device(mx.gpu) 104 | duration = np.mean( 105 | [op.run(framework="mlx", compile=True) for _ in range(iterations)] 106 | ) 107 | elif backend == "mlx_cpu": 108 | mx.set_default_device(mx.cpu) 109 | duration = np.mean([op.run(framework="mlx") for _ in range(iterations)]) 110 | elif backend == "cpu": 111 | duration = np.mean( 112 | [op.run(framework="torch", device="cpu") for _ in range(iterations)] 113 | ) 114 | elif backend == "mps": 115 | duration = np.mean( 116 | [op.run(framework="torch", device="mps") for _ in range(iterations)] 117 | ) 118 | elif backend == "cuda": 119 | duration = np.mean( 120 | [op.run(framework="torch", device="cuda") for _ in range(iterations)] 121 | ) 122 | 123 | op.clear() 124 | 125 | if backend == "mps": 126 | torch.mps.empty_cache() 127 | 128 | if backend == "cuda": 129 | torch.cuda.empty_cache() 130 | 131 | return duration 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = ArgumentParser() 136 | parser.add_argument("--include_mlx_gpu", type=strtobool, default="True") 137 | parser.add_argument("--include_mlx_gpu_compile", type=strtobool, default="True") 138 | parser.add_argument("--include_mlx_cpu", type=strtobool, default="True") 139 | parser.add_argument("--include_mps", type=strtobool, default="True") 140 | parser.add_argument("--include_cpu", type=strtobool, default="True") 141 | parser.add_argument("--include_cuda", type=strtobool, default="False") 142 | args = parser.parse_args() 143 | print(args) 144 | print(f"Use MLX: {USE_MLX}") 145 | 146 | if args.include_mps: 147 | assert torch.backends.mps.is_available(), "MPS backend not available." 148 | if args.include_cuda: 149 | assert torch.cuda.is_available(), "CUDA device not found." 
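    # Every `--include_<backend>` flag parsed as True contributes "<backend>" to this
    # list, e.g. --include_cuda=True adds "cuda".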
150 | 151 | backends = [ 152 | arg.replace("include_", "") for arg, value in vars(args).items() if value 153 | ] 154 | 155 | operations = [ 156 | Argmax(dim1="64x1024x128", axis=0), 157 | Argmax(dim1="64x1024x128", axis=1), 158 | Argmax(dim1="64x1024x128", axis=2), 159 | Argmax(dim1="64x128x1024", axis=2), 160 | BCE(dim1="1000000", dim2="1000000"), 161 | BCE(dim1="100000x32", dim2="100000x32"), 162 | BCE(dim1="100000x64x2", dim2="100000x64x2"), 163 | BCE(dim1="128x100000", dim2="128x100000"), 164 | Concat(dim1="1000000x64", dim2="1000000x32", axis=1), 165 | Concat(dim1="1000000x64", dim2="1000000x128", axis=1), 166 | Concat(dim1="1000000x64", dim2="1000000x64", axis=0), 167 | Concat(dim1="64x1000000", dim2="64x1000000", axis=0), 168 | Conv1d(dim1="100x256x3", dim2="8x3x3"), 169 | Conv1d(dim1="100x256x256", dim2="8x3x256"), 170 | Conv1d(dim1="16x1000x80", dim2="128x11x80"), 171 | Conv1d(dim1="16x1000x3", dim2="128x11x3"), 172 | Conv2d(dim1="100x256x256x3", dim2="8x3x3x3"), 173 | Conv2d(dim1="10x256x256x12", dim2="8x3x3x12"), 174 | Conv2d(dim1="1x256x256x128", dim2="8x3x3x128"), 175 | Conv2d(dim1="100x28x28x3", dim2="8x3x3x3"), 176 | Conv2d(dim1="1000x28x28x3", dim2="8x3x3x3"), 177 | Gather(dim1="64x256", dim2="10"), 178 | Gather(dim1="64x256", dim2="1000"), 179 | Gather(dim1="64x256", dim2="1000000"), 180 | Gather(dim1="1024x32", dim2="10"), 181 | Gather(dim1="1024x32", dim2="1000"), 182 | Gather(dim1="1024x32", dim2="1000000"), 183 | LayerNorm(dim1="64x128x1024"), 184 | LayerNorm(dim1="128x64x1024"), 185 | LeakyReLU(dim1="128x16x1024"), 186 | LeakyReLU(dim1="64x128x1024"), 187 | Linear(dim1="100x1024x32", dim2="32x1024", dim3="1024"), 188 | Linear(dim1="100x1024x64", dim2="64x1024", dim3="1024"), 189 | Linear(dim1="100x1024x256", dim2="256x1024", dim3="1024"), 190 | Linear(dim1="100x1024x512", dim2="512x1024", dim3="1024"), 191 | Linear(dim1="100x1x51200", dim2="51200x1", dim3="1"), 192 | MatMul(dim1="32x1x1000", dim2="32x1000x128"), 193 | MatMul(dim1="1000x64x256", dim2="256x32"), 194 | MatMul(dim1="1000x64x1024", dim2="1000x1024x32"), 195 | MatMul(dim1="1000x1024x64", dim2="1000x64x256"), 196 | MatMul(dim1="64x1000000", dim2="1000000x32"), 197 | MatMul(dim1="1000000x64", dim2="64x1024"), 198 | PReLU(dim1="128x16x1024", dim2="1"), 199 | PReLU(dim1="64x128x1024", dim2="1"), 200 | ReLU(dim1="128x16x1024"), 201 | ReLU(dim1="64x128x1024"), 202 | ScaledDotProductAttention(dim1="64x8x16x1024", dim2="64x8x16x1024", dim3="64x8x16x1024"), 203 | ScaledDotProductAttention(dim1="64x8x16x1024", dim2="64x8x64x1024", dim3="64x8x64x1024"), 204 | ScaledDotProductAttention(dim1="64x16x16x1024", dim2="64x16x16x1024", dim3="64x16x16x1024"), 205 | ScaledDotProductAttention(dim1="64x16x16x1024", dim2="64x16x64x1024", dim3="64x16x64x1024"), 206 | Scatter(dim1="64x16", dim2="10"), 207 | Scatter(dim1="64x16", dim2="1000"), 208 | Scatter(dim1="64x16", dim2="1000000"), 209 | Scatter(dim1="1024x32", dim2="10"), 210 | Scatter(dim1="1024x32", dim2="1000"), 211 | Scatter(dim1="1024x32", dim2="1000000"), 212 | ScatterSum(dim1="64x16", dim2="10"), 213 | ScatterSum(dim1="64x16", dim2="1000"), 214 | ScatterSum(dim1="64x16", dim2="1000000"), 215 | ScatterSum(dim1="1024x32", dim2="10"), 216 | ScatterSum(dim1="1024x32", dim2="1000"), 217 | ScatterSum(dim1="1024x32", dim2="1000000"), 218 | ScatterMax(dim1="64x16", dim2="10"), 219 | ScatterMax(dim1="64x16", dim2="1000"), 220 | ScatterMax(dim1="64x16", dim2="1000000"), 221 | ScatterMax(dim1="1024x32", dim2="10"), 222 | ScatterMax(dim1="1024x32", dim2="1000"), 223 | 
ScatterMax(dim1="1024x32", dim2="1000000"), 224 | SeLU(dim1="128x16x1024"), 225 | SeLU(dim1="64x128x1024"), 226 | Sigmoid(dim1="128x16x1024"), 227 | Sigmoid(dim1="64x128x1024"), 228 | Softmax(dim1="64x1000000", axis=-1), 229 | Softmax(dim1="1000000x64", axis=-1), 230 | Softmax(dim1="64x16x32x1024", axis=-1), 231 | Softmax(dim1="128x16x32x1024", axis=-1), 232 | Softmax(dim1="1024x16x32x128", axis=-1), 233 | Softmax(dim1="1024x64x32x8", axis=-1), 234 | Softplus(dim1="128x16x1024"), 235 | Softplus(dim1="64x128x1024"), 236 | Sort(dim1="64x128x1024", axis=0), 237 | Sort(dim1="64x128x1024", axis=1), 238 | Sort(dim1="64x128x1024", axis=2), 239 | Sum(dim1="64x128x128x128", axis=0), 240 | Sum(dim1="64x128x128x128", axis=1), 241 | Sum(dim1="64x128x128x128", axis=2), 242 | Sum(dim1="64x128x128x128", axis=3), 243 | SumAll(dim1="64x128x128x128"), 244 | SumAll(dim1="1000000"), 245 | SumAll(dim1="1000000x128"), 246 | SumAll(dim1="128x1000000"), 247 | ] 248 | run_processes(operations, backends) 249 | -------------------------------------------------------------------------------- /mlx_benchmark/run_viz.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | import matplotlib.cm as cm 7 | import re 8 | from io import StringIO 9 | 10 | 11 | # Function to extract tables from markdown content 12 | def extract_markdown_tables(md_content): 13 | tables = [] 14 | current_table = [] 15 | in_table = False 16 | lines = md_content.split("\n") 17 | 18 | for line in lines: 19 | # Check for table start/end 20 | if re.match(r"\|.*\|", line): 21 | in_table = True 22 | current_table.append(line) 23 | elif in_table: 24 | # Table has ended 25 | tables.append("\n".join(current_table)) 26 | current_table = [] 27 | in_table = False 28 | 29 | # Add the last table if it wasn't added 30 | if in_table: 31 | tables.append("\n".join(current_table)) 32 | 33 | return tables 34 | 35 | 36 | # Function to extract and shorten the operation name if it's too long 37 | def shorten_name(x): 38 | return x.split("/")[0].strip() if "/" in x else re.sub(r"\s*\([^)]*\)", "", x) 39 | 40 | 41 | def plot(file_content, operation): 42 | # Extracting tables from the file content 43 | tables_md = extract_markdown_tables(file_content) 44 | 45 | # Convert markdown tables to Pandas DataFrames 46 | dataframes = [] 47 | for table_md in tables_md: 48 | df = pd.read_table(StringIO(table_md), sep="|") 49 | df.columns = df.columns.str.strip() # Clean column names 50 | df = df.apply( 51 | lambda x: x.str.strip() if x.dtype == "object" else x 52 | ) # Clean cell values 53 | df.dropna(axis=1, how="all", inplace=True) # Drop columns with all NaN values 54 | dataframes.append(df) 55 | 56 | # Extracting table titles and operation values for each table 57 | title_pattern = r"\*\*(.*?)\s*\(" 58 | titles = re.findall(title_pattern, file_content) 59 | 60 | # Preparing data for the visualization 61 | mlx_gpu_values = [] 62 | mps_values = [] 63 | cuda_values = [] 64 | table_titles = [] 65 | 66 | for i, df in enumerate(dataframes): 67 | if "Operation" in df.columns and operation in df["Operation"].values: 68 | op_row = df[df["Operation"] == operation] 69 | 70 | title = titles[i] if i < len(titles) else f"Table {i+1}" 71 | table_titles.append(shorten_name(title)) 72 | 73 | # Extract mlx_gpu, mps, and cuda values 74 | mlx_gpu_val = ( 75 | op_row["mlx_gpu"].values[0] if "mlx_gpu" in df.columns else np.nan 76 | ) 77 | mps_val = 
op_row["mps"].values[0] if "mps" in df.columns else np.nan
78 | cuda_val = op_row["cuda"].values[0] if "cuda" in df.columns else np.nan
79 |
80 | mlx_gpu_values.append(pd.to_numeric(mlx_gpu_val, errors="coerce"))
81 | mps_values.append(pd.to_numeric(mps_val, errors="coerce"))
82 | cuda_values.append(pd.to_numeric(cuda_val, errors="coerce"))
83 |
84 | plt.figure(figsize=(12, 6))
85 |
86 | # Colors for mlx_gpu, mps, and cuda
87 | mlx_gpu_color = "skyblue"
88 | mps_color = "salmon"
89 | cuda_color = "lightgreen"
90 |
91 | bar_width = 0.25
92 | indices = np.arange(len(table_titles))
93 |
94 | # Plot each set of bars
95 | for i in indices:
96 | # Plot mlx_gpu and annotate
97 | bar_mlx_gpu = plt.bar(
98 | i - bar_width,
99 | mlx_gpu_values[i],
100 | bar_width,
101 | color=mlx_gpu_color,
102 | label="mlx_gpu" if i == 0 else "",
103 | )
104 | if not np.isnan(mlx_gpu_values[i]):
105 | plt.annotate(
106 | f"{mlx_gpu_values[i]:.2f}",
107 | (
108 | bar_mlx_gpu[0].get_x() + bar_mlx_gpu[0].get_width() / 2,
109 | bar_mlx_gpu[0].get_height(),
110 | ),
111 | ha="center",
112 | va="bottom",
113 | )
114 |
115 | # Plot mps and annotate
116 | bar_mps = plt.bar(
117 | i, mps_values[i], bar_width, color=mps_color, label="mps" if i == 0 else ""
118 | )
119 | if not np.isnan(mps_values[i]):
120 | plt.annotate(
121 | f"{mps_values[i]:.2f}",
122 | (
123 | bar_mps[0].get_x() + bar_mps[0].get_width() / 2,
124 | bar_mps[0].get_height(),
125 | ),
126 | ha="center",
127 | va="bottom",
128 | )
129 |
130 | # Plot cuda and annotate
131 | bar_cuda = plt.bar(
132 | i + bar_width,
133 | cuda_values[i],
134 | bar_width,
135 | color=cuda_color,
136 | label="cuda" if i == 0 else "",
137 | )
138 | if not np.isnan(cuda_values[i]):
139 | plt.annotate(
140 | f"{cuda_values[i]:.2f}",
141 | (
142 | bar_cuda[0].get_x() + bar_cuda[0].get_width() / 2,
143 | bar_cuda[0].get_height(),
144 | ),
145 | ha="center",
146 | va="bottom",
147 | )
148 |
149 | plt.xlabel("Chips/GPUs")
150 | plt.ylabel("Runtime (ms)")
151 | plt.title(f'"{operation}" benchmark')
152 | plt.xticks(indices, table_titles, rotation=45)
153 | plt.legend()
154 | plt.tight_layout()
155 |
156 | # Save the plot
157 | op_chart_table_titles_path = f"plot_{operation}.png"
158 | plt.savefig(op_chart_table_titles_path)
159 | plt.close()
160 |
161 |
162 | if __name__ == "__main__":
163 | current_directory = Path(__file__).parent
164 | parent_directory = current_directory.parent
165 | file_path = f"{parent_directory}/benchmarks/average_benchmark.md"
166 |
167 | with open(file_path, "r") as file:
168 | file_content = file.read()
169 |
170 | for op in [
171 | "Linear",
172 | "Concat",
173 | "MatMul",
174 | "Softmax",
175 | "Conv2d",
176 | "BCE",
177 | "Sort",
178 | "Sigmoid",
179 | ]:
180 | plot(file_content, operation=op)
181 |
--------------------------------------------------------------------------------
/mlx_benchmark/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import defaultdict
3 |
4 | import torch
5 | import numpy as np
6 | import torchvision
7 | import torchvision.transforms as transforms
8 |
9 |
10 | def load_mnist(path="data/mnist/"):
11 | transform = transforms.Compose([transforms.ToTensor()])
12 |
13 | # Only load the train set locally.
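    # With download=True, torchvision fetches MNIST into `path` on the first call
    # and reuses the cached files on later calls.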
14 | data = torchvision.datasets.MNIST( 15 | root=path, train=True, download=True, transform=transform 16 | ) 17 | return data 18 | 19 | 20 | def calculate_speedup(a, compared_to): 21 | percentage_difference = -((a - compared_to) / a) 22 | return percentage_difference * 100 23 | 24 | 25 | def print_benchmark(times, backends, reduce_mean=False): 26 | times = dict(times) 27 | 28 | if reduce_mean: 29 | new_times = defaultdict(lambda: defaultdict(list)) 30 | for k, v in times.items(): 31 | op = k.split("/")[0] 32 | for backend, runtime in v.items(): 33 | new_times[op][backend].append(runtime) 34 | 35 | for k, v in new_times.items(): 36 | for backend, runtimes in v.items(): 37 | new_times[k][backend] = np.mean(new_times[k][backend]) 38 | times = new_times 39 | 40 | # Column headers 41 | header_order = ["mlx_gpu", "mlx_gpu_compile", "mlx_cpu", "mps", "cpu", "cuda"] 42 | headers = sorted(backends, key=lambda x: header_order.index(x)) 43 | 44 | if "mlx_gpu_compile" in backends and "mlx_gpu" in backends: 45 | h = "mlx_gpu_compile/mlx_gpu speedup" 46 | headers.append(h) 47 | for k, v in times.items(): 48 | v[h] = calculate_speedup(v["mlx_gpu_compile"], compared_to=v["mlx_gpu"]) 49 | 50 | if "mps" in backends and "mlx_gpu" in backends: 51 | h = "mlx_gpu/mps speedup" 52 | headers.append(h) 53 | for k, v in times.items(): 54 | v[h] = calculate_speedup(v["mlx_gpu"], compared_to=v["mps"]) 55 | 56 | if "mlx_cpu" in backends and "mlx_gpu" in backends: 57 | h = "mlx_gpu/mlx_cpu speedup" 58 | headers.append(h) 59 | for k, v in times.items(): 60 | v[h] = calculate_speedup(v["mlx_gpu"], compared_to=v["mlx_cpu"]) 61 | 62 | if "cpu" in backends and "cuda" in backends: 63 | h = "cuda/cpu speedup" 64 | headers.append(h) 65 | for k, v in times.items(): 66 | v[h] = calculate_speedup(v["cuda"], compared_to=v["cpu"]) 67 | 68 | max_name_length = max(len(name) for name in times.keys()) 69 | 70 | # Formatting the header row 71 | header_row = ( 72 | "| Operation" + " " * (max_name_length - 5) + " | " + " | ".join(headers) + " |" 73 | ) 74 | header_line_parts = ["-" * (max_name_length + 6)] + [ 75 | "-" * max(6, len(header)) for header in headers 76 | ] 77 | header_line = "|" + "|".join(header_line_parts) + "|" 78 | 79 | print(header_row) 80 | print(header_line) 81 | 82 | add_plus_symbol = ( 83 | lambda x, rounding: f"{'+' if x > 0 else ''}{(int(x) if not math.isnan(x) else x) if rounding == 0 else round(x, rounding)}" 84 | ) 85 | format_value = ( 86 | lambda header: f"{add_plus_symbol(times[header], 0):>6}%" 87 | if "speedup" in header 88 | else f"{times[header]:>6.2f}" 89 | ) 90 | 91 | for op, times in times.items(): 92 | times_str = " | ".join(format_value(header) for header in headers) 93 | 94 | # Formatting each row 95 | print(f"| {op.ljust(max_name_length)} | {times_str} |") 96 | 97 | 98 | def get_dummy_edge_index(shape, num_nodes, device, framework): 99 | if framework == "mlx": 100 | import mlx.core as mx 101 | 102 | return mx.random.randint(0, num_nodes - 1, shape) 103 | elif framework == "torch": 104 | return torch.randint(0, num_nodes - 1, shape).to(device) 105 | raise ValueError("Framework should be either mlx or torch.") 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mlx>=0.0.5 2 | black>=23.12.0 3 | pre-commit>=3.6.0 4 | tqdm>=4.66.1 5 | matplotlib>=3.8.2 6 | pandas>=2.1.4 7 | --------------------------------------------------------------------------------