├── .circleci └── config.yml ├── .clang-format ├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md ├── pull_request_template.md └── workflows │ └── pull_request.yml ├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGMENTS.md ├── CITATION.cff ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks ├── cpp │ ├── CMakeLists.txt │ ├── autograd.cpp │ ├── compare_devices.cpp │ ├── irregular_strides.cpp │ ├── single_ops.cpp │ └── time_utils.h ├── numpy │ ├── single_ops.py │ └── time_utils.py └── python │ ├── batch_matmul_bench.py │ ├── blas │ ├── bench_gemm.py │ └── bench_gemv.py │ ├── comparative │ ├── README.md │ ├── bench_mlx.py │ ├── bench_torch.py │ └── compare.py │ ├── compile_bench.py │ ├── conv1d_bench.py │ ├── conv2d_bench_cpu.py │ ├── conv2d_train_bench_cpu.py │ ├── conv2d_transpose_bench_cpu.py │ ├── conv3d_bench_cpu.py │ ├── conv3d_train_bench_cpu.py │ ├── conv3d_transpose_bench_cpu.py │ ├── conv_bench.py │ ├── conv_transpose_bench.py │ ├── distributed_bench.py │ ├── einsum_bench.py │ ├── fft_bench.py │ ├── gather_bench.py │ ├── gather_mm_bench.py │ ├── gather_qmm_bench.py │ ├── hadamard_bench.py │ ├── layer_norm_bench.py │ ├── rms_norm_bench.py │ ├── rope_bench.py │ ├── scatter_bench.py │ ├── sdpa_bench.py │ ├── sdpa_vector_bench.py │ ├── single_ops.py │ ├── synchronize_bench.py │ └── time_utils.py ├── cmake └── extension.cmake ├── docs ├── .clang-format ├── .gitignore ├── .nojekyll ├── Doxyfile ├── Makefile ├── README.md ├── index.html ├── requirements.txt └── src │ ├── _static │ ├── metal_debugger │ │ ├── capture.png │ │ └── schema.png │ ├── mlx_logo.png │ └── mlx_logo_dark.png │ ├── _templates │ ├── module-base-class.rst │ ├── nn-module-template.rst │ └── optimizers-template.rst │ ├── conf.py │ ├── cpp │ └── ops.rst │ ├── dev │ ├── custom_metal_kernels.rst │ ├── extensions.rst │ ├── metal_debugger.rst │ └── mlx_in_cpp.rst │ ├── examples │ ├── linear_regression.rst │ ├── llama-inference.rst 
│ └── mlp.rst │ ├── index.rst │ ├── install.rst │ ├── python │ ├── array.rst │ ├── data_types.rst │ ├── devices_and_streams.rst │ ├── distributed.rst │ ├── export.rst │ ├── fast.rst │ ├── fft.rst │ ├── linalg.rst │ ├── memory_management.rst │ ├── metal.rst │ ├── nn.rst │ ├── nn │ │ ├── functions.rst │ │ ├── init.rst │ │ ├── layers.rst │ │ ├── losses.rst │ │ └── module.rst │ ├── ops.rst │ ├── optimizers.rst │ ├── optimizers │ │ ├── common_optimizers.rst │ │ ├── optimizer.rst │ │ └── schedulers.rst │ ├── random.rst │ ├── transforms.rst │ └── tree_utils.rst │ └── usage │ ├── compile.rst │ ├── distributed.rst │ ├── export.rst │ ├── function_transforms.rst │ ├── indexing.rst │ ├── launching_distributed.rst │ ├── lazy_evaluation.rst │ ├── numpy.rst │ ├── quick_start.rst │ ├── saving_and_loading.rst │ ├── unified_memory.rst │ └── using_streams.rst ├── examples ├── cmake_project │ ├── CMakeLists.txt │ ├── README.md │ └── example.cpp ├── cpp │ ├── CMakeLists.txt │ ├── distributed.cpp │ ├── linear_regression.cpp │ ├── logistic_regression.cpp │ ├── metal_capture.cpp │ ├── timer.h │ └── tutorial.cpp ├── export │ ├── CMakeLists.txt │ ├── README.md │ ├── eval_mlp.cpp │ ├── eval_mlp.py │ ├── train_mlp.cpp │ └── train_mlp.py ├── extensions │ ├── CMakeLists.txt │ ├── README.md │ ├── axpby │ │ ├── axpby.cpp │ │ ├── axpby.h │ │ └── axpby.metal │ ├── bindings.cpp │ ├── mlx_sample_extensions │ │ └── __init__.py │ ├── pyproject.toml │ ├── requirements.txt │ ├── setup.py │ └── test.py └── python │ ├── linear_regression.py │ └── logistic_regression.py ├── mlx.pc.in ├── mlx ├── 3rdparty │ ├── .clang-format │ └── pocketfft.h ├── CMakeLists.txt ├── allocator.cpp ├── allocator.h ├── array.cpp ├── array.h ├── backend │ ├── common │ │ ├── CMakeLists.txt │ │ ├── binary.h │ │ ├── broadcasting.cpp │ │ ├── broadcasting.h │ │ ├── buffer_cache.h │ │ ├── common.cpp │ │ ├── compiled.cpp │ │ ├── compiled.h │ │ ├── copy.h │ │ ├── hadamard.h │ │ ├── load.cpp │ │ ├── reduce.cpp │ │ ├── reduce.h │ │ ├── 
slicing.cpp │ │ ├── slicing.h │ │ ├── ternary.h │ │ ├── utils.cpp │ │ └── utils.h │ ├── cpu │ │ ├── CMakeLists.txt │ │ ├── arange.h │ │ ├── arg_reduce.cpp │ │ ├── available.cpp │ │ ├── available.h │ │ ├── binary.cpp │ │ ├── binary.h │ │ ├── binary_ops.h │ │ ├── binary_two.h │ │ ├── cholesky.cpp │ │ ├── compiled.cpp │ │ ├── compiled_preamble.h │ │ ├── conv.cpp │ │ ├── copy.cpp │ │ ├── copy.h │ │ ├── distributed.cpp │ │ ├── eig.cpp │ │ ├── eigh.cpp │ │ ├── encoder.cpp │ │ ├── encoder.h │ │ ├── eval.cpp │ │ ├── eval.h │ │ ├── fft.cpp │ │ ├── gemm.h │ │ ├── gemms │ │ │ ├── bnns.cpp │ │ │ ├── cblas.cpp │ │ │ ├── simd_bf16.cpp │ │ │ ├── simd_fp16.cpp │ │ │ └── simd_gemm.h │ │ ├── hadamard.cpp │ │ ├── indexing.cpp │ │ ├── inverse.cpp │ │ ├── jit_compiler.cpp │ │ ├── jit_compiler.h │ │ ├── lapack.h │ │ ├── logsumexp.cpp │ │ ├── luf.cpp │ │ ├── make_compiled_preamble.ps1 │ │ ├── make_compiled_preamble.sh │ │ ├── masked_mm.cpp │ │ ├── matmul.cpp │ │ ├── primitives.cpp │ │ ├── qrf.cpp │ │ ├── quantized.cpp │ │ ├── reduce.cpp │ │ ├── scan.cpp │ │ ├── select.cpp │ │ ├── simd │ │ │ ├── accelerate_fp16_simd.h │ │ │ ├── accelerate_simd.h │ │ │ ├── base_simd.h │ │ │ ├── math.h │ │ │ ├── neon_fp16_simd.h │ │ │ ├── simd.h │ │ │ └── type.h │ │ ├── slicing.h │ │ ├── softmax.cpp │ │ ├── sort.cpp │ │ ├── svd.cpp │ │ ├── ternary.h │ │ ├── threefry.cpp │ │ ├── threefry.h │ │ ├── unary.cpp │ │ ├── unary.h │ │ └── unary_ops.h │ ├── cuda │ │ ├── CMakeLists.txt │ │ ├── allocator.cpp │ │ ├── allocator.h │ │ ├── copy.cpp │ │ ├── device.cpp │ │ ├── device.h │ │ ├── eval.cpp │ │ ├── event.cu │ │ ├── event.h │ │ ├── fence.cu │ │ ├── kernel_utils.cu │ │ ├── kernel_utils.cuh │ │ ├── kernels │ │ │ ├── arange.cuh │ │ │ └── fp16_math.cuh │ │ ├── primitives.cu │ │ ├── slicing.cpp │ │ ├── utils.cpp │ │ ├── utils.h │ │ ├── worker.cpp │ │ └── worker.h │ ├── gpu │ │ ├── CMakeLists.txt │ │ ├── available.h │ │ ├── copy.cpp │ │ ├── copy.h │ │ ├── eval.h │ │ ├── primitives.cpp │ │ ├── slicing.cpp │ │ └── 
slicing.h │ ├── metal │ │ ├── CMakeLists.txt │ │ ├── allocator.cpp │ │ ├── allocator.h │ │ ├── binary.cpp │ │ ├── binary.h │ │ ├── compiled.cpp │ │ ├── conv.cpp │ │ ├── copy.cpp │ │ ├── custom_kernel.cpp │ │ ├── device.cpp │ │ ├── device.h │ │ ├── distributed.cpp │ │ ├── eval.cpp │ │ ├── event.cpp │ │ ├── fence.cpp │ │ ├── fft.cpp │ │ ├── hadamard.cpp │ │ ├── indexing.cpp │ │ ├── jit │ │ │ ├── includes.h │ │ │ └── indexing.h │ │ ├── jit_kernels.cpp │ │ ├── kernels.h │ │ ├── kernels │ │ │ ├── CMakeLists.txt │ │ │ ├── arange.h │ │ │ ├── arange.metal │ │ │ ├── arg_reduce.metal │ │ │ ├── atomic.h │ │ │ ├── bf16_math.h │ │ │ ├── binary.h │ │ │ ├── binary.metal │ │ │ ├── binary_ops.h │ │ │ ├── binary_two.h │ │ │ ├── binary_two.metal │ │ │ ├── complex.h │ │ │ ├── conv.metal │ │ │ ├── copy.h │ │ │ ├── copy.metal │ │ │ ├── defines.h │ │ │ ├── erf.h │ │ │ ├── expm1f.h │ │ │ ├── fence.metal │ │ │ ├── fft.h │ │ │ ├── fft.metal │ │ │ ├── fft │ │ │ │ ├── radix.h │ │ │ │ └── readwrite.h │ │ │ ├── gather.h │ │ │ ├── gather_axis.h │ │ │ ├── gemv.metal │ │ │ ├── gemv_masked.h │ │ │ ├── gemv_masked.metal │ │ │ ├── hadamard.h │ │ │ ├── indexing.h │ │ │ ├── jit │ │ │ │ └── bf16.h │ │ │ ├── layer_norm.metal │ │ │ ├── logsumexp.h │ │ │ ├── logsumexp.metal │ │ │ ├── metal_3_0 │ │ │ │ └── bf16.h │ │ │ ├── metal_3_1 │ │ │ │ └── bf16.h │ │ │ ├── quantized.h │ │ │ ├── quantized.metal │ │ │ ├── random.metal │ │ │ ├── reduce.h │ │ │ ├── reduce.metal │ │ │ ├── reduce_utils.h │ │ │ ├── reduction │ │ │ │ ├── ops.h │ │ │ │ ├── reduce_all.h │ │ │ │ ├── reduce_col.h │ │ │ │ ├── reduce_init.h │ │ │ │ └── reduce_row.h │ │ │ ├── rms_norm.metal │ │ │ ├── rope.metal │ │ │ ├── scaled_dot_product_attention.metal │ │ │ ├── scan.h │ │ │ ├── scan.metal │ │ │ ├── scatter.h │ │ │ ├── scatter_axis.h │ │ │ ├── sdpa_vector.h │ │ │ ├── softmax.h │ │ │ ├── softmax.metal │ │ │ ├── sort.h │ │ │ ├── sort.metal │ │ │ ├── steel │ │ │ │ ├── attn │ │ │ │ │ ├── attn.h │ │ │ │ │ ├── kernels │ │ │ │ │ │ ├── steel_attention.h │ 
│ │ │ │ │ └── steel_attention.metal │ │ │ │ │ ├── loader.h │ │ │ │ │ ├── mma.h │ │ │ │ │ ├── params.h │ │ │ │ │ └── transforms.h │ │ │ │ ├── conv │ │ │ │ │ ├── conv.h │ │ │ │ │ ├── kernels │ │ │ │ │ │ ├── steel_conv.h │ │ │ │ │ │ ├── steel_conv.metal │ │ │ │ │ │ ├── steel_conv_general.h │ │ │ │ │ │ └── steel_conv_general.metal │ │ │ │ │ ├── loader.h │ │ │ │ │ ├── loaders │ │ │ │ │ │ ├── loader_channel_l.h │ │ │ │ │ │ ├── loader_channel_n.h │ │ │ │ │ │ └── loader_general.h │ │ │ │ │ └── params.h │ │ │ │ ├── defines.h │ │ │ │ ├── gemm │ │ │ │ │ ├── gemm.h │ │ │ │ │ ├── kernels │ │ │ │ │ │ ├── steel_gemm_fused.h │ │ │ │ │ │ ├── steel_gemm_fused.metal │ │ │ │ │ │ ├── steel_gemm_gather.h │ │ │ │ │ │ ├── steel_gemm_gather.metal │ │ │ │ │ │ ├── steel_gemm_masked.h │ │ │ │ │ │ ├── steel_gemm_masked.metal │ │ │ │ │ │ ├── steel_gemm_splitk.h │ │ │ │ │ │ └── steel_gemm_splitk.metal │ │ │ │ │ ├── loader.h │ │ │ │ │ ├── mma.h │ │ │ │ │ ├── params.h │ │ │ │ │ └── transforms.h │ │ │ │ ├── utils.h │ │ │ │ └── utils │ │ │ │ │ ├── integral_constant.h │ │ │ │ │ └── type_traits.h │ │ │ ├── ternary.h │ │ │ ├── ternary.metal │ │ │ ├── ternary_ops.h │ │ │ ├── unary.h │ │ │ ├── unary.metal │ │ │ ├── unary_ops.h │ │ │ └── utils.h │ │ ├── logsumexp.cpp │ │ ├── make_compiled_preamble.sh │ │ ├── matmul.cpp │ │ ├── matmul.h │ │ ├── metal.cpp │ │ ├── metal.h │ │ ├── no_metal.cpp │ │ ├── nojit_kernels.cpp │ │ ├── normalization.cpp │ │ ├── primitives.cpp │ │ ├── quantized.cpp │ │ ├── reduce.cpp │ │ ├── reduce.h │ │ ├── resident.cpp │ │ ├── resident.h │ │ ├── rope.cpp │ │ ├── scaled_dot_product_attention.cpp │ │ ├── scan.cpp │ │ ├── slicing.cpp │ │ ├── softmax.cpp │ │ ├── sort.cpp │ │ ├── ternary.cpp │ │ ├── ternary.h │ │ ├── unary.cpp │ │ ├── unary.h │ │ ├── utils.cpp │ │ └── utils.h │ ├── no_cpu │ │ ├── CMakeLists.txt │ │ ├── available.cpp │ │ ├── compiled.cpp │ │ └── primitives.cpp │ └── no_gpu │ │ ├── CMakeLists.txt │ │ ├── allocator.cpp │ │ ├── apple_memory.h │ │ ├── eval.cpp │ │ ├── 
event.cpp │ │ ├── fence.cpp │ │ ├── linux_memory.h │ │ └── primitives.cpp ├── compile.cpp ├── compile.h ├── compile_impl.h ├── device.cpp ├── device.h ├── distributed │ ├── CMakeLists.txt │ ├── distributed.cpp │ ├── distributed.h │ ├── distributed_impl.h │ ├── mpi │ │ ├── CMakeLists.txt │ │ ├── mpi.cpp │ │ ├── mpi.h │ │ ├── mpi_declarations.h │ │ └── no_mpi.cpp │ ├── ops.cpp │ ├── ops.h │ ├── primitives.cpp │ ├── primitives.h │ └── ring │ │ ├── CMakeLists.txt │ │ ├── no_ring.cpp │ │ ├── ring.cpp │ │ └── ring.h ├── dtype.cpp ├── dtype.h ├── dtype_utils.cpp ├── dtype_utils.h ├── einsum.cpp ├── einsum.h ├── event.h ├── export.cpp ├── export.h ├── export_impl.h ├── fast.cpp ├── fast.h ├── fast_primitives.h ├── fence.h ├── fft.cpp ├── fft.h ├── graph_utils.cpp ├── graph_utils.h ├── io.h ├── io │ ├── CMakeLists.txt │ ├── gguf.cpp │ ├── gguf.h │ ├── gguf_quants.cpp │ ├── load.cpp │ ├── load.h │ ├── no_gguf.cpp │ ├── no_safetensors.cpp │ └── safetensors.cpp ├── linalg.cpp ├── linalg.h ├── memory.h ├── mlx.h ├── ops.cpp ├── ops.h ├── primitives.cpp ├── primitives.h ├── random.cpp ├── random.h ├── scheduler.cpp ├── scheduler.h ├── stream.h ├── threadpool.h ├── transforms.cpp ├── transforms.h ├── transforms_impl.h ├── types │ ├── bf16.h │ ├── complex.h │ ├── fp16.h │ ├── half_types.h │ └── limits.h ├── utils.cpp ├── utils.h ├── version.cpp └── version.h ├── pyproject.toml ├── python ├── mlx │ ├── __main__.py │ ├── _os_warning.py │ ├── _reprlib_fix.py │ ├── _stub_patterns.txt │ ├── distributed_run.py │ ├── extension.py │ ├── nn │ │ ├── __init__.py │ │ ├── init.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── base.py │ │ │ ├── containers.py │ │ │ ├── convolution.py │ │ │ ├── convolution_transpose.py │ │ │ ├── distributed.py │ │ │ ├── dropout.py │ │ │ ├── embedding.py │ │ │ ├── linear.py │ │ │ ├── normalization.py │ │ │ ├── pooling.py │ │ │ ├── positional_encoding.py │ │ │ ├── quantized.py │ │ │ ├── recurrent.py │ │ │ ├── transformer.py │ │ │ └── 
upsample.py │ │ ├── losses.py │ │ └── utils.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── optimizers.py │ │ └── schedulers.py │ ├── py.typed │ └── utils.py ├── src │ ├── CMakeLists.txt │ ├── array.cpp │ ├── buffer.h │ ├── constants.cpp │ ├── convert.cpp │ ├── convert.h │ ├── device.cpp │ ├── distributed.cpp │ ├── export.cpp │ ├── fast.cpp │ ├── fft.cpp │ ├── indexing.cpp │ ├── indexing.h │ ├── linalg.cpp │ ├── load.cpp │ ├── load.h │ ├── memory.cpp │ ├── metal.cpp │ ├── mlx.cpp │ ├── mlx_func.cpp │ ├── mlx_func.h │ ├── ops.cpp │ ├── random.cpp │ ├── stream.cpp │ ├── transforms.cpp │ ├── trees.cpp │ ├── trees.h │ ├── utils.cpp │ └── utils.h └── tests │ ├── mlx_distributed_tests.py │ ├── mlx_tests.py │ ├── mpi_test_distributed.py │ ├── ring_test_distributed.py │ ├── test_array.py │ ├── test_autograd.py │ ├── test_bf16.py │ ├── test_blas.py │ ├── test_compile.py │ ├── test_constants.py │ ├── test_conv.py │ ├── test_conv_transpose.py │ ├── test_device.py │ ├── test_double.py │ ├── test_einsum.py │ ├── test_eval.py │ ├── test_export_import.py │ ├── test_fast.py │ ├── test_fast_sdpa.py │ ├── test_fft.py │ ├── test_graph.py │ ├── test_init.py │ ├── test_linalg.py │ ├── test_load.py │ ├── test_losses.py │ ├── test_memory.py │ ├── test_nn.py │ ├── test_ops.py │ ├── test_optimizers.py │ ├── test_quantized.py │ ├── test_random.py │ ├── test_reduce.py │ ├── test_tree.py │ ├── test_upsample.py │ └── test_vmap.py ├── setup.py └── tests ├── CMakeLists.txt ├── allocator_tests.cpp ├── arg_reduce_tests.cpp ├── array_tests.cpp ├── autograd_tests.cpp ├── blas_tests.cpp ├── compile_tests.cpp ├── creations_tests.cpp ├── custom_vjp_tests.cpp ├── device_tests.cpp ├── einsum_tests.cpp ├── eval_tests.cpp ├── export_import_tests.cpp ├── fft_tests.cpp ├── gpu_tests.cpp ├── linalg_tests.cpp ├── load_tests.cpp ├── ops_tests.cpp ├── random_tests.cpp ├── scheduler_tests.cpp ├── tests.cpp ├── utils_tests.cpp └── vmap_tests.cpp /.github/ISSUE_TEMPLATE/bug_report.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report about an issue you've encountered 4 | title: "[BUG] " 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | 15 | Include code snippet 16 | ```python 17 | 18 | ``` 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Desktop (please complete the following information):** 24 | - OS Version: [e.g. MacOS 14.1.2] 25 | - Version [e.g. 0.7.0] 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 29 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Proposed changes 2 | 3 | Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #. 4 | 5 | ## Checklist 6 | 7 | Put an `x` in the boxes that apply. 
8 | 9 | - [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document 10 | - [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes 11 | - [ ] I have added tests that prove my fix is effective or that my feature works 12 | - [ ] I have updated the necessary documentation (if needed) 13 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | on: 2 | pull_request: 3 | branches: 4 | - main 5 | 6 | jobs: 7 | check_lint: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: actions/setup-python@v4 12 | with: 13 | python-version: 3.8 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install pre-commit black isort clang-format 18 | - name: Run lint 19 | run: | 20 | pre-commit run --all-files 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # tensor files 10 | *.safe 11 | *.safetensors 12 | 13 | # Metal libraries 14 | *.metallib 15 | venv/ 16 | 17 | # Distribution / packaging 18 | python/mlx/core 19 | python/mlx/share 20 | python/mlx/include 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | uv.lock 40 | 41 | # vim 42 | *.swp 43 | 44 | # Ignore build dir 45 | build/ 46 | 47 | # Prerequisites 48 | *.d 49 | 50 | # Compiled Object files 51 | *.slo 52 | *.lo 53 | *.o 54 | *.obj 55 | 56 | # 
Precompiled Headers 57 | *.gch 58 | *.pch 59 | 60 | # Compiled Dynamic libraries 61 | *.so 62 | *.dylib 63 | *.dll 64 | 65 | # Fortran module files 66 | *.mod 67 | *.smod 68 | 69 | # Compiled Static libraries 70 | *.lai 71 | *.la 72 | *.a 73 | *.lib 74 | 75 | # Executables 76 | *.exe 77 | *.out 78 | *.app 79 | 80 | # Debug symbols 81 | *.pdb 82 | 83 | # VSCode 84 | .vscode/ 85 | .DS_Store 86 | 87 | # Jetbrains 88 | .cache 89 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-clang-format 3 | rev: v19.1.7 4 | hooks: 5 | - id: clang-format 6 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 7 | - repo: https://github.com/psf/black-pre-commit-mirror 8 | rev: 25.1.0 9 | hooks: 10 | - id: black 11 | 12 | - repo: https://github.com/pycqa/isort 13 | rev: 6.0.0 14 | hooks: 15 | - id: isort 16 | args: 17 | - --profile=black 18 | - repo: https://github.com/cheshirekow/cmake-format-precommit 19 | rev: v0.6.13 20 | hooks: 21 | - id: cmake-format 22 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: mlx 3 | message: >- 4 | If you use this software, please cite it using the 5 | metadata from this file. 
6 | type: software 7 | authors: 8 | - given-names: Awni 9 | family-names: Hannun 10 | affiliation: Apple 11 | - given-names: Jagrit 12 | family-names: Digani 13 | affiliation: Apple 14 | - given-names: Angelos 15 | family-names: Katharopoulos 16 | affiliation: Apple 17 | - given-names: Ronan 18 | family-names: Collobert 19 | affiliation: Apple 20 | repository-code: 'https://github.com/ml-explore' 21 | abstract: >- 22 | MLX: efficient and flexible machine learning on Apple 23 | silicon 24 | license: MIT 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | 1. Fork and submit pull requests to the repo. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If a change is likely to impact efficiency, run some of the benchmarks before 11 | and after the change. Examples of benchmarks can be found in `benchmarks/python/`. 12 | 4. If you've changed APIs, update the documentation. 13 | 5. Every PR should have passing tests and at least one review. 14 | 6. For code formatting, install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 15 | This should install hooks for running `black` and `clang-format` to ensure 16 | consistent style for C++ and Python code. 17 | 18 | You can also run the formatters manually as follows: 19 | 20 | ```shell 21 | clang-format -i file.cpp 22 | ``` 23 | 24 | ```shell 25 | black file.py 26 | ``` 27 | 28 | or run `pre-commit run --all-files` to check all files in the repo. 29 | 30 | ## Issues 31 | 32 | We use GitHub issues to track public bugs. Please ensure your description is 33 | clear and has sufficient instructions to be able to reproduce the issue. 
34 | 35 | ## License 36 | 37 | By contributing to MLX, you agree that your contributions will be licensed 38 | under the LICENSE file in the root directory of this source tree. 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CMakeLists.txt 2 | include mlx.pc.in 3 | recursive-include mlx/ * 4 | include cmake/* 5 | include python/src/* 6 | include python/mlx/py.typed # support type hinting as in PEP-561 7 | -------------------------------------------------------------------------------- /benchmarks/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(build_benchmark SRCFILE) 2 | get_filename_component(src_name ${SRCFILE} NAME_WE) 3 | set(target "${src_name}") 4 | add_executable(${target} ${SRCFILE}) 5 | target_link_libraries(${target} PRIVATE mlx) 6 | endfunction(build_benchmark) 7 | 8 | build_benchmark(single_ops.cpp) 9 | build_benchmark(irregular_strides.cpp) 10 | build_benchmark(compare_devices.cpp) 11 | build_benchmark(autograd.cpp) 12 | -------------------------------------------------------------------------------- /benchmarks/cpp/autograd.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include <iostream> 4 | 5 | #include "mlx/mlx.h" 6 | #include "time_utils.h" 7 | 8 | namespace mx = mlx::core; 9 | 10 | void time_value_and_grad() { 11 | auto x = mx::ones({200, 1000}); 12 | mx::eval(x); 13 | auto fn = [](mx::array x) { 14 | for (int i = 0; i < 20; ++i) { 15 | x = mx::log(mx::exp(x)); 16 | } 17 | return mx::sum(x); 18 | }; 19 | 20 | auto grad_fn = mx::grad(fn); 21 | auto independent_value_and_grad = [&]() { 22 | auto value = fn(x); 23 | auto dfdx = grad_fn(x); 24 | return std::vector{value, dfdx}; 25 | }; 26 | TIME(independent_value_and_grad); 27 | 28 | auto value_and_grad_fn = mx::value_and_grad(fn); 29 | auto combined_value_and_grad = [&]() { 30 | auto [value, dfdx] = value_and_grad_fn(x); 31 | return std::vector{value, dfdx}; 32 | }; 33 | TIME(combined_value_and_grad); 34 | } 35 | 36 | int main() { 37 | std::cout << "Benchmarks for " << mx::default_device() << std::endl; 38 | time_value_and_grad(); 39 | } 40 | -------------------------------------------------------------------------------- /benchmarks/cpp/compare_devices.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include <iostream> 4 | #include "mlx/mlx.h" 5 | #include "time_utils.h" 6 | 7 | namespace mx = mlx::core; 8 | 9 | void time_add_op() { 10 | std::vector sizes(1, 1); 11 | for (int i = 0; i < 9; ++i) { 12 | sizes.push_back(10 * sizes.back()); 13 | } 14 | set_default_device(mx::Device::cpu); 15 | for (auto size : sizes) { 16 | auto a = mx::random::uniform({size}); 17 | auto b = mx::random::uniform({size}); 18 | mx::eval(a, b); 19 | std::cout << "Size " << size << std::endl; 20 | TIMEM("cpu", mx::add, a, b, mx::Device::cpu); 21 | TIMEM("gpu", mx::add, a, b, mx::Device::gpu); 22 | } 23 | } 24 | 25 | int main() { 26 | time_add_op(); 27 | } 28 | -------------------------------------------------------------------------------- /benchmarks/cpp/time_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include <chrono> 6 | #include <iomanip> 7 | #include <iostream> 8 | 9 | #include "mlx/mlx.h" 10 | 11 | #define milliseconds(x) \ 12 | (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6) 13 | #define time_now() std::chrono::high_resolution_clock::now() 14 | 15 | #define TIME(FUNC, ...) \ 16 | std::cout << "Timing " << #FUNC << " ... " << std::flush \ 17 | << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ 18 | << std::endl; 19 | 20 | #define TIMEM(MSG, FUNC, ...) \ 21 | std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \ 22 | << std::flush << std::setprecision(5) \ 23 | << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl; 24 | 25 | template <typename F, typename... Args> 26 | double time_fn(F fn, Args&&... 
args) { 27 | // warmup 28 | for (int i = 0; i < 5; ++i) { 29 | eval(fn(std::forward<Args>(args)...)); 30 | } 31 | 32 | int num_iters = 100; 33 | auto start = time_now(); 34 | for (int i = 0; i < num_iters; i++) { 35 | eval(fn(std::forward<Args>(args)...)); 36 | } 37 | auto end = time_now(); 38 | return milliseconds(end - start) / static_cast<double>(num_iters); 39 | } 40 | -------------------------------------------------------------------------------- /benchmarks/numpy/single_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import numpy as np 4 | from time_utils import time_fn 5 | 6 | 7 | def time_add(): 8 | a = np.ones((100, 100, 10), dtype=np.float32) 9 | b = np.ones((100, 100, 10), dtype=np.float32) 10 | time_fn(np.add, a, b) 11 | 12 | 13 | def time_matmul(): 14 | a = np.random.rand(1000, 500).astype(np.float32) 15 | b = np.random.rand(500, 1000).astype(np.float32) 16 | time_fn(np.matmul, a, b) 17 | 18 | 19 | def time_exp(): 20 | a = np.random.randn(1000, 100).astype(np.float32) 21 | time_fn(np.exp, a) 22 | 23 | 24 | def time_take(): 25 | a = np.random.rand(10000, 500) 26 | ids = np.random.randint(0, 10000, (20, 10)) 27 | ids = [idx.reshape(-1) for idx in np.split(ids, 20)] 28 | 29 | def random_take(): 30 | return [np.take(a, idx, 0) for idx in ids] 31 | 32 | time_fn(random_take) 33 | 34 | 35 | if __name__ == "__main__": 36 | time_add() 37 | time_matmul() 38 | time_exp() 39 | time_take() 40 | -------------------------------------------------------------------------------- /benchmarks/numpy/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import time 4 | 5 | 6 | def time_fn(fn, *args): 7 | print(f"Timing {fn.__name__} ...", end=" ") 8 | 9 | # warmup 10 | for _ in range(5): 11 | fn(*args) 12 | 13 | num_iters = 100 14 | tic = time.perf_counter() 15 | for _ in range(num_iters): 16 | x = fn(*args) 17 | toc = time.perf_counter() 18 | 19 | msec = 1e3 * (toc - tic) / num_iters 20 | print(f"{msec:.5f} msec") 21 | -------------------------------------------------------------------------------- /benchmarks/python/batch_matmul_bench.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | 5 | import mlx.core as mx 6 | from time_utils import time_fn 7 | 8 | B = 8 9 | T = 1024 10 | D = 512 11 | 12 | 13 | def time_batch_matmul(): 14 | mx.random.seed(3) 15 | a = mx.random.uniform(shape=(B, T, D)) 16 | b = mx.random.uniform(shape=(D, D)) 17 | c = mx.random.uniform(shape=(B, T, D)) 18 | mx.eval(a, b, c) 19 | 20 | time_fn(mx.matmul, a, b) 21 | 22 | def batch_vjp_first(): 23 | return mx.vjp(mx.matmul, [a, b], [c])[1][0] 24 | 25 | time_fn(batch_vjp_first) 26 | 27 | def batch_vjp_second(): 28 | return mx.vjp(mx.matmul, [a, b], [c])[1][1] 29 | 30 | time_fn(batch_vjp_second) 31 | 32 | 33 | def time_unbatch_matmul(): 34 | mx.random.seed(3) 35 | a = mx.random.uniform(shape=(B * T, D)) 36 | b = mx.random.uniform(shape=(D, D)) 37 | c = mx.random.uniform(shape=(B * T, D)) 38 | mx.eval(a, b, c) 39 | time_fn(mx.matmul, a, b) 40 | 41 | def unbatch_vjp_first(): 42 | return mx.matmul(c, mx.transpose(b)) 43 | 44 | time_fn(unbatch_vjp_first) 45 | 46 | def unbatch_vjp_second(): 47 | return mx.matmul(mx.transpose(a), c) 48 | 49 | time_fn(unbatch_vjp_second) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser("MLX benchmarks.") 54 | parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") 55 | args = parser.parse_args() 56 | if args.gpu: 57 | mx.set_default_device(mx.gpu) 58 | else: 59 | 
mx.set_default_device(mx.cpu) 60 | 61 | time_batch_matmul() 62 | time_unbatch_matmul() 63 | -------------------------------------------------------------------------------- /benchmarks/python/comparative/README.md: -------------------------------------------------------------------------------- 1 | Microbenchmarks comparing MLX to PyTorch 2 | ======================================== 3 | 4 | Implement the same microbenchmarks in MLX and PyTorch to compare and make a 5 | list of the biggest possible performance improvements and/or regressions. 6 | 7 | Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for 8 | instance to measure the time it takes to sum across the 3rd axis of the above 9 | tensor on the CPU. 10 | 11 | `compare.py` runs several benchmarks and compares the speed-up or lack thereof 12 | in comparison to PyTorch. 13 | 14 | Each bench script can be run with `--print-pid` to print the PID and wait for a 15 | key in order to ease attaching a debugger. 16 | -------------------------------------------------------------------------------- /benchmarks/python/distributed_bench.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | """ 4 | Run with: 5 | mpirun -n 2 python /path/to/distributed_bench.py 6 | """ 7 | 8 | import time 9 | 10 | import mlx.core as mx 11 | 12 | 13 | def time_fn(fn, *args, **kwargs): 14 | msg = kwargs.pop("msg", None) 15 | world = mx.distributed.init() 16 | if world.rank() == 0: 17 | if msg: 18 | print(f"Timing {msg} ...", end=" ") 19 | else: 20 | print(f"Timing {fn.__name__} ...", end=" ") 21 | 22 | # warmup 23 | for _ in range(5): 24 | mx.eval(fn(*args, **kwargs)) 25 | 26 | num_iters = 100 27 | tic = time.perf_counter() 28 | for _ in range(num_iters): 29 | x = mx.eval(fn(*args, **kwargs)) 30 | toc = time.perf_counter() 31 | 32 | msec = 1e3 * (toc - tic) / num_iters 33 | if world.rank() == 0: 34 | print(f"{msec:.5f} msec") 35 | 36 | 37 | def time_all_sum(): 38 | shape = (4096,) 39 | x = mx.random.uniform(shape=shape) 40 | mx.eval(x) 41 | 42 | def sine(x): 43 | for _ in range(20): 44 | x = mx.sin(x) 45 | return x 46 | 47 | time_fn(sine, x) 48 | 49 | def all_sum_plain(x): 50 | for _ in range(20): 51 | x = mx.distributed.all_sum(x) 52 | return x 53 | 54 | time_fn(all_sum_plain, x) 55 | 56 | def all_sum_with_sine(x): 57 | for _ in range(20): 58 | x = mx.sin(x) 59 | x = mx.distributed.all_sum(x) 60 | return x 61 | 62 | time_fn(all_sum_with_sine, x) 63 | 64 | 65 | if __name__ == "__main__": 66 | time_all_sum() 67 | -------------------------------------------------------------------------------- /benchmarks/python/gather_bench.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
2 | 3 | import argparse 4 | 5 | import mlx.core as mx 6 | import torch 7 | from time_utils import measure_runtime 8 | 9 | 10 | def benchmark_gather_mlx(x_shape, idx_shape): 11 | def gather(x, idx): 12 | mx.eval(x[idx]) 13 | 14 | idx = mx.random.randint(0, x_shape[0] - 1, idx_shape) 15 | x = mx.random.normal(x_shape).astype(mx.float32) 16 | 17 | runtime = measure_runtime(gather, x=x, idx=idx) 18 | print(f"MLX: {runtime:.3f}ms") 19 | 20 | 21 | def benchmark_gather_torch(x_shape, idx_shape, device): 22 | def gather(x, idx, device): 23 | _ = x[idx] 24 | if device == torch.device("mps"): 25 | torch.mps.synchronize() 26 | 27 | idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device) 28 | x = torch.randn(x_shape, dtype=torch.float32).to(device) 29 | 30 | runtime = measure_runtime(gather, x=x, idx=idx, device=device) 31 | print(f"PyTorch: {runtime:.3f}ms") 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser("Gather benchmarks.") 36 | parser.add_argument("--cpu", action="store_true", help="Use the CPU.") 37 | args = parser.parse_args() 38 | 39 | if args.cpu: 40 | mx.set_default_device(mx.cpu) 41 | device = torch.device("cpu") 42 | else: 43 | device = torch.device("mps") 44 | 45 | idx_shapes = [(1_000_000,), (100_000,), ()] 46 | x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)] 47 | 48 | for x_shape, idx_shape in zip(x_shapes, idx_shapes): 49 | print("=" * 20) 50 | print(f"X {x_shape}, Indices {idx_shape}") 51 | benchmark_gather_mlx(x_shape, idx_shape) 52 | benchmark_gather_torch(x_shape, idx_shape, device=device) 53 | -------------------------------------------------------------------------------- /benchmarks/python/rope_bench.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 
2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | from time_utils import time_fn 6 | 7 | 8 | def time_rope(): 9 | rope = nn.RoPE(64) 10 | 11 | # vec 12 | x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16) 13 | mx.eval(x) 14 | 15 | def rope_vec(x): 16 | for _ in range(32): 17 | x = rope(x, offset=100) 18 | return x 19 | 20 | time_fn(rope_vec, x) 21 | 22 | # matrix 23 | x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16) 24 | mx.eval(x) 25 | 26 | def rope_mat(x): 27 | for _ in range(32): 28 | x = rope(x) 29 | return x 30 | 31 | time_fn(rope_mat, x) 32 | 33 | 34 | if __name__ == "__main__": 35 | time_rope() 36 | -------------------------------------------------------------------------------- /benchmarks/python/synchronize_bench.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import mlx.core as mx 4 | 5 | rank = mx.distributed.init().rank() 6 | 7 | 8 | def timeit(fn, a): 9 | 10 | # warmup 11 | for _ in range(5): 12 | mx.eval(fn(a)) 13 | 14 | its = 10 15 | tic = time.perf_counter() 16 | for _ in range(its): 17 | mx.eval(fn(a)) 18 | toc = time.perf_counter() 19 | ms = 1000 * (toc - tic) / its 20 | return ms 21 | 22 | 23 | def all_reduce_benchmark(): 24 | a = mx.ones((5, 5), mx.int32) 25 | 26 | its_per_eval = 100 27 | 28 | def fn(x): 29 | for _ in range(its_per_eval): 30 | x = mx.distributed.all_sum(x) 31 | x = x - 1 32 | return x 33 | 34 | ms = timeit(fn, a) / its_per_eval 35 | if rank == 0: 36 | print(f"All Reduce: time per iteration {ms:.6f} (ms)") 37 | 38 | 39 | def all_gather_benchmark(): 40 | a = mx.ones((5, 5), mx.int32) 41 | its_per_eval = 100 42 | 43 | def fn(x): 44 | for _ in range(its_per_eval): 45 | x = mx.distributed.all_gather(x)[0] 46 | return x 47 | 48 | ms = timeit(fn, a) / its_per_eval 49 | if rank == 0: 50 | print(f"All gather: time per iteration {ms:.6f} (ms)") 51 | 52 | 53 | if __name__ == "__main__": 54 | all_reduce_benchmark() 55 | all_gather_benchmark() 
56 | -------------------------------------------------------------------------------- /benchmarks/python/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import time 4 | 5 | import mlx.core as mx 6 | 7 | 8 | def time_fn(fn, *args, **kwargs): 9 | msg = kwargs.pop("msg", None) 10 | if msg: 11 | print(f"Timing {msg} ...", end=" ") 12 | else: 13 | print(f"Timing {fn.__name__} ...", end=" ") 14 | 15 | # warmup 16 | for _ in range(5): 17 | mx.eval(fn(*args, **kwargs)) 18 | 19 | num_iters = 100 20 | tic = time.perf_counter() 21 | for _ in range(num_iters): 22 | x = mx.eval(fn(*args, **kwargs)) 23 | toc = time.perf_counter() 24 | 25 | msec = 1e3 * (toc - tic) / num_iters 26 | print(f"{msec:.5f} msec") 27 | 28 | 29 | def measure_runtime(fn, **kwargs): 30 | # Warmup 31 | for _ in range(5): 32 | fn(**kwargs) 33 | 34 | tic = time.time() 35 | iters = 100 36 | for _ in range(iters): 37 | fn(**kwargs) 38 | return (time.time() - tic) * 1000 / iters 39 | -------------------------------------------------------------------------------- /docs/.clang-format: -------------------------------------------------------------------------------- 1 | DisableFormat: true 2 | SortIncludes: Never 3 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | src/python/_autosummary*/ 2 | src/python/nn/_autosummary*/ 3 | src/python/optimizers/_autosummary*/ 4 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx/db5a7c6192af90eed81ff7eac8213e7fe7b7a0c8/docs/.nojekyll -------------------------------------------------------------------------------- /docs/Makefile: 
-------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line. 4 | SPHINXOPTS = 5 | SPHINXBUILD = sphinx-build 6 | SOURCEDIR = src 7 | BUILDDIR = build 8 | 9 | # Put it first so that "make" without argument is like "make help". 10 | help: 11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 12 | 13 | .PHONY: help Makefile 14 | 15 | # Catch-all target: route all unknown targets to Sphinx using the new 16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 17 | %: Makefile 18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Build the Docs 2 | 3 | ### Setup (do once) 4 | 5 | Install Doxygen: 6 | 7 | ``` 8 | brew install doxygen 9 | ``` 10 | 11 | Install Python packages: 12 | 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ### Build 18 | 19 | Build the docs from `mlx/docs/` 20 | 21 | ``` 22 | doxygen && make html 23 | ``` 24 | 25 | View the docs by running a server in `mlx/docs/build/html/`: 26 | 27 | ``` 28 | python -m http.server <port> 29 | ``` 30 | 31 | and point your browser to `http://localhost:<port>`. 32 | 33 | ### Push to GitHub Pages 34 | 35 | Check-out the `gh-pages` branch (`git switch gh-pages`) and build 36 | the docs. Then force add the `build/html` directory: 37 | 38 | `git add -f build/html` 39 | 40 | Commit and push the changes to the `gh-pages` branch. 
41 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | breathe 3 | sphinx-book-theme 4 | mlx 5 | -------------------------------------------------------------------------------- /docs/src/_static/metal_debugger/capture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx/db5a7c6192af90eed81ff7eac8213e7fe7b7a0c8/docs/src/_static/metal_debugger/capture.png -------------------------------------------------------------------------------- /docs/src/_static/metal_debugger/schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx/db5a7c6192af90eed81ff7eac8213e7fe7b7a0c8/docs/src/_static/metal_debugger/schema.png -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx/db5a7c6192af90eed81ff7eac8213e7fe7b7a0c8/docs/src/_static/mlx_logo.png -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx/db5a7c6192af90eed81ff7eac8213e7fe7b7a0c8/docs/src/_static/mlx_logo_dark.png -------------------------------------------------------------------------------- /docs/src/_templates/module-base-class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | 
underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: Attributes 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | .. rubric:: Methods 24 | 25 | .. autosummary:: 26 | :toctree: . 27 | {% for item in methods %} 28 | {%- if item not in inherited_members and item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /docs/src/_templates/nn-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | {% for item in methods %} 14 | {%- if item not in inherited_members and item != "__init__" %} 15 | ~{{ name }}.{{ item }} 16 | {%- endif %} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | -------------------------------------------------------------------------------- /docs/src/_templates/optimizers-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. 
autosummary:: 13 | {% for item in methods %} 14 | {%- if item not in inherited_members %} 15 | ~{{ name }}.{{ item }} 16 | {%- endif %} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | -------------------------------------------------------------------------------- /docs/src/cpp/ops.rst: -------------------------------------------------------------------------------- 1 | .. _cpp_ops: 2 | 3 | Operations 4 | ========== 5 | 6 | .. doxygengroup:: ops 7 | :content-only: 8 | -------------------------------------------------------------------------------- /docs/src/python/array.rst: -------------------------------------------------------------------------------- 1 | .. _array: 2 | 3 | Array 4 | ===== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | array 12 | array.astype 13 | array.at 14 | array.item 15 | array.tolist 16 | array.dtype 17 | array.itemsize 18 | array.nbytes 19 | array.ndim 20 | array.shape 21 | array.size 22 | array.real 23 | array.imag 24 | array.abs 25 | array.all 26 | array.any 27 | array.argmax 28 | array.argmin 29 | array.conj 30 | array.cos 31 | array.cummax 32 | array.cummin 33 | array.cumprod 34 | array.cumsum 35 | array.diag 36 | array.diagonal 37 | array.exp 38 | array.flatten 39 | array.log 40 | array.log10 41 | array.log1p 42 | array.log2 43 | array.logcumsumexp 44 | array.logsumexp 45 | array.max 46 | array.mean 47 | array.min 48 | array.moveaxis 49 | array.prod 50 | array.reciprocal 51 | array.reshape 52 | array.round 53 | array.rsqrt 54 | array.sin 55 | array.split 56 | array.sqrt 57 | array.square 58 | array.squeeze 59 | array.std 60 | array.sum 61 | array.swapaxes 62 | array.transpose 63 | array.T 64 | array.var 65 | array.view 66 | -------------------------------------------------------------------------------- /docs/src/python/data_types.rst: -------------------------------------------------------------------------------- 1 | .. 
_data_types: 2 | 3 | Data Types 4 | ========== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | The default floating point type is ``float32`` and the default integer type is 9 | ``int32``. The table below shows supported values for :obj:`Dtype`. 10 | 11 | .. list-table:: Supported Data Types 12 | :widths: 5 3 20 13 | :header-rows: 1 14 | 15 | * - Type 16 | - Bytes 17 | - Description 18 | * - ``bool_`` 19 | - 1 20 | - Boolean (``True``, ``False``) data type 21 | * - ``uint8`` 22 | - 1 23 | - 8-bit unsigned integer 24 | * - ``uint16`` 25 | - 2 26 | - 16-bit unsigned integer 27 | * - ``uint32`` 28 | - 4 29 | - 32-bit unsigned integer 30 | * - ``uint64`` 31 | - 8 32 | - 64-bit unsigned integer 33 | * - ``int8`` 34 | - 1 35 | - 8-bit signed integer 36 | * - ``int16`` 37 | - 2 38 | - 16-bit signed integer 39 | * - ``int32`` 40 | - 4 41 | - 32-bit signed integer 42 | * - ``int64`` 43 | - 8 44 | - 64-bit signed integer 45 | * - ``bfloat16`` 46 | - 2 47 | - 16-bit brain float (e8, m7) 48 | * - ``float16`` 49 | - 2 50 | - 16-bit IEEE float (e5, m10) 51 | * - ``float32`` 52 | - 4 53 | - 32-bit float 54 | * - ``float64`` 55 | - 8 56 | - 64-bit double 57 | * - ``complex64`` 58 | - 8 59 | - 64-bit complex float 60 | 61 | 62 | .. note:: 63 | 64 | Arrays with type ``float64`` only work with CPU operations. Using 65 | ``float64`` arrays on the GPU will result in an exception. 66 | 67 | 68 | Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object 69 | documentation for more information. Use :func:`issubdtype` to determine if one 70 | ``dtype`` (or category) is a subtype of another category. 71 | 72 | .. autosummary:: 73 | :toctree: _autosummary 74 | 75 | Dtype 76 | DtypeCategory 77 | issubdtype 78 | finfo 79 | -------------------------------------------------------------------------------- /docs/src/python/devices_and_streams.rst: -------------------------------------------------------------------------------- 1 | .. 
_devices_and_streams: 2 | 3 | Devices and Streams 4 | =================== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | Device 12 | Stream 13 | default_device 14 | set_default_device 15 | default_stream 16 | new_stream 17 | set_default_stream 18 | stream 19 | synchronize 20 | -------------------------------------------------------------------------------- /docs/src/python/distributed.rst: -------------------------------------------------------------------------------- 1 | .. _distributed: 2 | 3 | .. currentmodule:: mlx.core.distributed 4 | 5 | Distributed Communication 6 | ========================== 7 | 8 | MLX provides a distributed communication package using MPI. The MPI library is 9 | loaded at runtime; if MPI is available then distributed communication is also 10 | made available. 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | Group 16 | is_available 17 | init 18 | all_sum 19 | all_gather 20 | send 21 | recv 22 | recv_like 23 | -------------------------------------------------------------------------------- /docs/src/python/export.rst: -------------------------------------------------------------------------------- 1 | .. _export: 2 | 3 | Export Functions 4 | ================ 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | export_function 12 | import_function 13 | exporter 14 | export_to_dot 15 | -------------------------------------------------------------------------------- /docs/src/python/fast.rst: -------------------------------------------------------------------------------- 1 | .. _fast: 2 | 3 | Fast 4 | ==== 5 | 6 | .. currentmodule:: mlx.core.fast 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | 11 | rms_norm 12 | layer_norm 13 | rope 14 | scaled_dot_product_attention 15 | metal_kernel 16 | -------------------------------------------------------------------------------- /docs/src/python/fft.rst: -------------------------------------------------------------------------------- 1 | .. _fft: 2 | 3 | FFT 4 | === 5 | 6 | .. currentmodule:: mlx.core.fft 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | fft 12 | ifft 13 | fft2 14 | ifft2 15 | fftn 16 | ifftn 17 | rfft 18 | irfft 19 | rfft2 20 | irfft2 21 | rfftn 22 | irfftn 23 | fftshift 24 | ifftshift 25 | -------------------------------------------------------------------------------- /docs/src/python/linalg.rst: -------------------------------------------------------------------------------- 1 | .. _linalg: 2 | 3 | Linear Algebra 4 | ============== 5 | 6 | .. currentmodule:: mlx.core.linalg 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | inv 12 | tri_inv 13 | norm 14 | cholesky 15 | cholesky_inv 16 | cross 17 | qr 18 | svd 19 | eigvals 20 | eig 21 | eigvalsh 22 | eigh 23 | lu 24 | lu_factor 25 | pinv 26 | solve 27 | solve_triangular 28 | -------------------------------------------------------------------------------- /docs/src/python/memory_management.rst: -------------------------------------------------------------------------------- 1 | Memory Management 2 | ================= 3 | 4 | .. currentmodule:: mlx.core 5 | 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | 9 | get_active_memory 10 | get_peak_memory 11 | reset_peak_memory 12 | get_cache_memory 13 | set_memory_limit 14 | set_cache_limit 15 | set_wired_limit 16 | clear_cache 17 | -------------------------------------------------------------------------------- /docs/src/python/metal.rst: -------------------------------------------------------------------------------- 1 | Metal 2 | ===== 3 | 4 | .. currentmodule:: mlx.core.metal 5 | 6 | .. 
autosummary:: 7 | :toctree: _autosummary 8 | 9 | is_available 10 | device_info 11 | start_capture 12 | stop_capture 13 | -------------------------------------------------------------------------------- /docs/src/python/nn/functions.rst: -------------------------------------------------------------------------------- 1 | .. _nn_functions: 2 | 3 | .. currentmodule:: mlx.nn 4 | 5 | Functions 6 | --------- 7 | 8 | Layers without parameters (e.g. activation functions) are also provided as 9 | simple functions. 10 | 11 | .. autosummary:: 12 | :toctree: _autosummary_functions 13 | :template: nn-module-template.rst 14 | 15 | elu 16 | celu 17 | gelu 18 | gelu_approx 19 | gelu_fast_approx 20 | glu 21 | hard_shrink 22 | hard_tanh 23 | hardswish 24 | leaky_relu 25 | log_sigmoid 26 | log_softmax 27 | mish 28 | prelu 29 | relu 30 | relu6 31 | selu 32 | sigmoid 33 | silu 34 | softmax 35 | softmin 36 | softplus 37 | softshrink 38 | step 39 | tanh 40 | -------------------------------------------------------------------------------- /docs/src/python/nn/init.rst: -------------------------------------------------------------------------------- 1 | .. _init: 2 | 3 | .. currentmodule:: mlx.nn.init 4 | 5 | Initializers 6 | ------------ 7 | 8 | The ``mlx.nn.init`` package contains commonly used initializers for neural 9 | network parameters. Initializers return a function which can be applied to any 10 | input :obj:`mlx.core.array` to produce an initialized output. 11 | 12 | For example: 13 | 14 | .. code:: python 15 | 16 | import mlx.core as mx 17 | import mlx.nn as nn 18 | 19 | init_fn = nn.init.uniform() 20 | 21 | # Produces a [2, 2] uniform matrix 22 | param = init_fn(mx.zeros((2, 2))) 23 | 24 | To re-initialize all the parameters in an :obj:`mlx.nn.Module` from say a uniform 25 | distribution, you can do: 26 | 27 | .. 
code:: python 28 | 29 | import mlx.nn as nn 30 | model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5)) 31 | init_fn = nn.init.uniform(low=-0.1, high=0.1) 32 | model.apply(init_fn) 33 | 34 | 35 | .. autosummary:: 36 | :toctree: _autosummary 37 | 38 | constant 39 | normal 40 | uniform 41 | identity 42 | glorot_normal 43 | glorot_uniform 44 | he_normal 45 | he_uniform 46 | -------------------------------------------------------------------------------- /docs/src/python/nn/layers.rst: -------------------------------------------------------------------------------- 1 | .. _layers: 2 | 3 | .. currentmodule:: mlx.nn 4 | 5 | Layers 6 | ------ 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | :template: nn-module-template.rst 11 | 12 | ALiBi 13 | AvgPool1d 14 | AvgPool2d 15 | AvgPool3d 16 | BatchNorm 17 | CELU 18 | Conv1d 19 | Conv2d 20 | Conv3d 21 | ConvTranspose1d 22 | ConvTranspose2d 23 | ConvTranspose3d 24 | Dropout 25 | Dropout2d 26 | Dropout3d 27 | Embedding 28 | ELU 29 | GELU 30 | GLU 31 | GroupNorm 32 | GRU 33 | HardShrink 34 | HardTanh 35 | Hardswish 36 | InstanceNorm 37 | LayerNorm 38 | LeakyReLU 39 | Linear 40 | LogSigmoid 41 | LogSoftmax 42 | LSTM 43 | MaxPool1d 44 | MaxPool2d 45 | MaxPool3d 46 | Mish 47 | MultiHeadAttention 48 | PReLU 49 | QuantizedEmbedding 50 | QuantizedLinear 51 | RMSNorm 52 | ReLU 53 | ReLU6 54 | RNN 55 | RoPE 56 | SELU 57 | Sequential 58 | Sigmoid 59 | SiLU 60 | SinusoidalPositionalEncoding 61 | Softmin 62 | Softshrink 63 | Softsign 64 | Softmax 65 | Softplus 66 | Step 67 | Tanh 68 | Transformer 69 | Upsample 70 | -------------------------------------------------------------------------------- /docs/src/python/nn/losses.rst: -------------------------------------------------------------------------------- 1 | .. _losses: 2 | 3 | .. currentmodule:: mlx.nn.losses 4 | 5 | Loss Functions 6 | -------------- 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary_functions 10 | :template: nn-module-template.rst 11 | 12 | binary_cross_entropy 13 | cosine_similarity_loss 14 | cross_entropy 15 | gaussian_nll_loss 16 | hinge_loss 17 | huber_loss 18 | kl_div_loss 19 | l1_loss 20 | log_cosh_loss 21 | margin_ranking_loss 22 | mse_loss 23 | nll_loss 24 | smooth_l1_loss 25 | triplet_loss -------------------------------------------------------------------------------- /docs/src/python/nn/module.rst: -------------------------------------------------------------------------------- 1 | Module 2 | ====== 3 | 4 | .. currentmodule:: mlx.nn 5 | 6 | .. autoclass:: Module 7 | 8 | .. rubric:: Attributes 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | Module.training 14 | Module.state 15 | 16 | .. rubric:: Methods 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | 21 | Module.apply 22 | Module.apply_to_modules 23 | Module.children 24 | Module.eval 25 | Module.filter_and_map 26 | Module.freeze 27 | Module.leaf_modules 28 | Module.load_weights 29 | Module.modules 30 | Module.named_modules 31 | Module.parameters 32 | Module.save_weights 33 | Module.set_dtype 34 | Module.train 35 | Module.trainable_parameters 36 | Module.unfreeze 37 | Module.update 38 | Module.update_modules 39 | -------------------------------------------------------------------------------- /docs/src/python/optimizers/common_optimizers.rst: -------------------------------------------------------------------------------- 1 | .. _common_optimizers: 2 | 3 | Common Optimizers 4 | ================= 5 | 6 | .. currentmodule:: mlx.optimizers 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | :template: optimizers-template.rst 11 | 12 | SGD 13 | RMSprop 14 | Adagrad 15 | Adafactor 16 | AdaDelta 17 | Adam 18 | AdamW 19 | Adamax 20 | Lion 21 | MultiOptimizer 22 | -------------------------------------------------------------------------------- /docs/src/python/optimizers/optimizer.rst: -------------------------------------------------------------------------------- 1 | Optimizer 2 | ========= 3 | 4 | .. currentmodule:: mlx.optimizers 5 | 6 | .. autoclass:: Optimizer 7 | 8 | 9 | .. rubric:: Attributes 10 | 11 | .. autosummary:: 12 | :toctree: _autosummary 13 | 14 | Optimizer.state 15 | 16 | .. rubric:: Methods 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | 21 | Optimizer.apply_gradients 22 | Optimizer.init 23 | Optimizer.update 24 | -------------------------------------------------------------------------------- /docs/src/python/optimizers/schedulers.rst: -------------------------------------------------------------------------------- 1 | .. _schedulers: 2 | 3 | Schedulers 4 | ========== 5 | 6 | .. currentmodule:: mlx.optimizers 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | cosine_decay 12 | exponential_decay 13 | join_schedules 14 | linear_schedule 15 | step_decay 16 | -------------------------------------------------------------------------------- /docs/src/python/random.rst: -------------------------------------------------------------------------------- 1 | .. _random: 2 | 3 | Random 4 | ====== 5 | 6 | Random sampling functions in MLX use an implicit global PRNG state by default. 7 | However, all functions take an optional ``key`` keyword argument for when more 8 | fine-grained control or explicit state management is needed. 9 | 10 | For example, you can generate random numbers with: 11 | 12 | .. code-block:: python 13 | 14 | for _ in range(3): 15 | print(mx.random.uniform()) 16 | 17 | which will print a sequence of unique pseudo random numbers. 
Alternatively you 18 | can explicitly set the key: 19 | 20 | .. code-block:: python 21 | 22 | key = mx.random.key(0) 23 | for _ in range(3): 24 | print(mx.random.uniform(key=key)) 25 | 26 | which will yield the same pseudo random number at each iteration. 27 | 28 | Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_ 29 | we use a splittable version of Threefry, which is a counter-based PRNG. 30 | 31 | .. currentmodule:: mlx.core.random 32 | 33 | .. autosummary:: 34 | :toctree: _autosummary 35 | 36 | bernoulli 37 | categorical 38 | gumbel 39 | key 40 | normal 41 | multivariate_normal 42 | randint 43 | seed 44 | split 45 | truncated_normal 46 | uniform 47 | laplace 48 | permutation 49 | -------------------------------------------------------------------------------- /docs/src/python/transforms.rst: -------------------------------------------------------------------------------- 1 | .. _transforms: 2 | 3 | Transforms 4 | ========== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | eval 12 | async_eval 13 | compile 14 | custom_function 15 | disable_compile 16 | enable_compile 17 | grad 18 | value_and_grad 19 | jvp 20 | vjp 21 | vmap 22 | -------------------------------------------------------------------------------- /docs/src/python/tree_utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils: 2 | 3 | Tree Utils 4 | ========== 5 | 6 | In MLX we consider a python tree to be an arbitrarily nested collection of 7 | dictionaries, lists and tuples without cycles. Functions in this module that 8 | return python trees will be using the default python ``dict``, ``list`` and 9 | ``tuple`` but they can usually process objects that inherit from any of these. 10 | 11 | .. note:: 12 | Dictionaries should have keys that are valid python identifiers. 13 | 14 | .. currentmodule:: mlx.utils 15 | 16 | .. 
autosummary:: 17 | :toctree: _autosummary 18 | 19 | tree_flatten 20 | tree_unflatten 21 | tree_map 22 | tree_map_with_path 23 | tree_reduce 24 | -------------------------------------------------------------------------------- /docs/src/usage/using_streams.rst: -------------------------------------------------------------------------------- 1 | .. _using_streams: 2 | 3 | Using Streams 4 | ============= 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | Specifying the :obj:`Stream` 9 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 10 | 11 | All operations (including random number generation) take an optional 12 | keyword argument ``stream``. The ``stream`` kwarg specifies which 13 | :obj:`Stream` the operation should run on. If the stream is unspecified then 14 | the operation is run on the default stream of the default device: 15 | ``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also 16 | be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is 17 | run on the default stream of the provided device 18 | ``mx.default_stream(my_device)``. 19 | -------------------------------------------------------------------------------- /examples/cmake_project/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27) 2 | 3 | project(example LANGUAGES CXX) 4 | 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | # Comment out the following two commands if only the MLX C++ library is installed, 9 | # and set(MLX_ROOT "/path/to/mlx") directly if needed. 
10 | find_package( 11 | Python 3.9 12 | COMPONENTS Interpreter Development.Module 13 | REQUIRED) 14 | execute_process( 15 | COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir 16 | OUTPUT_STRIP_TRAILING_WHITESPACE 17 | OUTPUT_VARIABLE MLX_ROOT) 18 | 19 | find_package(MLX CONFIG REQUIRED) 20 | 21 | add_executable(example example.cpp) 22 | target_link_libraries(example PRIVATE mlx) 23 | -------------------------------------------------------------------------------- /examples/cmake_project/README.md: -------------------------------------------------------------------------------- 1 | ## Build and Run 2 | 3 | Install MLX with Python: 4 | 5 | ```bash 6 | pip install "mlx>=0.22" 7 | ``` 8 | 9 | Build the C++ example: 10 | 11 | ```bash 12 | cmake -B build -DCMAKE_BUILD_TYPE=Release 13 | cmake --build build 14 | ``` 15 | 16 | Run the C++ example: 17 | 18 | ``` 19 | ./build/example 20 | ``` 21 | 22 | which should output: 23 | 24 | ``` 25 | array([2, 4, 6], dtype=int32) 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/cmake_project/example.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/mlx.h" 6 | 7 | namespace mx = mlx::core; 8 | 9 | int main() { 10 | auto x = mx::array({1, 2, 3}); 11 | auto y = mx::array({1, 2, 3}); 12 | std::cout << x + y << std::endl; 13 | return 0; 14 | } 15 | -------------------------------------------------------------------------------- /examples/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(build_example SRCFILE) 2 | get_filename_component(src_name ${SRCFILE} NAME_WE) 3 | set(target "${src_name}") 4 | add_executable(${target} ${SRCFILE}) 5 | target_link_libraries(${target} PRIVATE mlx) 6 | endfunction(build_example) 7 | 8 | build_example(tutorial.cpp) 9 | build_example(linear_regression.cpp) 10 | build_example(logistic_regression.cpp) 11 | build_example(metal_capture.cpp) 12 | build_example(distributed.cpp) 13 | -------------------------------------------------------------------------------- /examples/cpp/distributed.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/mlx.h" 6 | 7 | namespace mx = mlx::core; 8 | 9 | int main() { 10 | if (!mx::distributed::is_available()) { 11 | std::cout << "No communication backend found" << std::endl; 12 | return 1; 13 | } 14 | 15 | auto global_group = mx::distributed::init(); 16 | std::cout << global_group.rank() << " / " << global_group.size() << std::endl; 17 | 18 | mx::array x = mx::ones({10}); 19 | mx::array out = mx::distributed::all_sum(x, global_group); 20 | 21 | std::cout << out << std::endl; 22 | } 23 | -------------------------------------------------------------------------------- /examples/cpp/linear_regression.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/mlx.h" 8 | #include "timer.h" 9 | 10 | /** 11 | * An example of linear regression with MLX. 12 | */ 13 | namespace mx = mlx::core; 14 | 15 | int main() { 16 | int num_features = 100; 17 | int num_examples = 1'000; 18 | int num_iters = 10'000; 19 | float learning_rate = 0.01; 20 | 21 | // True parameters 22 | auto w_star = mx::random::normal({num_features}); 23 | 24 | // The input examples (design matrix) 25 | auto X = mx::random::normal({num_examples, num_features}); 26 | 27 | // Noisy labels 28 | auto eps = 1e-2 * mx::random::normal({num_examples}); 29 | auto y = mx::matmul(X, w_star) + eps; 30 | 31 | // Initialize random parameters 32 | mx::array w = 1e-2 * mx::random::normal({num_features}); 33 | 34 | auto loss_fn = [&](mx::array w) { 35 | auto yhat = mx::matmul(X, w); 36 | return (0.5f / num_examples) * mx::sum(mx::square(yhat - y)); 37 | }; 38 | 39 | auto grad_fn = mx::grad(loss_fn); 40 | 41 | auto tic = timer::time(); 42 | for (int it = 0; it < num_iters; ++it) { 43 | auto grads = grad_fn(w); 44 | w = w - learning_rate * grads; 45 | mx::eval(w); 46 | } 47 | auto toc = timer::time(); 48 | 49 | auto loss = loss_fn(w); 50 | auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item()); 51 | auto throughput = num_iters / timer::seconds(toc - tic); 52 | std::cout << "Loss " << loss << ", |w - w*| = " << error_norm 53 | << ", Throughput " << throughput << " (it/s)." << std::endl; 54 | } 55 | -------------------------------------------------------------------------------- /examples/cpp/logistic_regression.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/mlx.h" 8 | #include "timer.h" 9 | 10 | /** 11 | * An example of logistic regression with MLX. 
12 | */ 13 | namespace mx = mlx::core; 14 | 15 | int main() { 16 | int num_features = 100; 17 | int num_examples = 1'000; 18 | int num_iters = 10'000; 19 | float learning_rate = 0.1; 20 | 21 | // True parameters 22 | auto w_star = mx::random::normal({num_features}); 23 | 24 | // The input examples 25 | auto X = mx::random::normal({num_examples, num_features}); 26 | 27 | // Labels 28 | auto y = mx::matmul(X, w_star) > 0; 29 | 30 | // Initialize random parameters 31 | mx::array w = 1e-2 * mx::random::normal({num_features}); 32 | 33 | auto loss_fn = [&](mx::array w) { 34 | auto logits = mx::matmul(X, w); 35 | auto scale = (1.0f / num_examples); 36 | return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits); 37 | }; 38 | 39 | auto grad_fn = mx::grad(loss_fn); 40 | 41 | auto tic = timer::time(); 42 | for (int it = 0; it < num_iters; ++it) { 43 | auto grads = grad_fn(w); 44 | w = w - learning_rate * grads; 45 | mx::eval(w); 46 | } 47 | auto toc = timer::time(); 48 | 49 | auto loss = loss_fn(w); 50 | auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples; 51 | auto throughput = num_iters / timer::seconds(toc - tic); 52 | std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput " 53 | << throughput << " (it/s)." << std::endl; 54 | } 55 | -------------------------------------------------------------------------------- /examples/cpp/metal_capture.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/mlx.h" 7 | 8 | namespace mx = mlx::core; 9 | 10 | int main() { 11 | // To use Metal debugging and profiling: 12 | // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON). 13 | // 2. Run with MTL_CAPTURE_ENABLED=1. 14 | mx::metal::start_capture("mlx_trace.gputrace"); 15 | 16 | // Start at index two because the default GPU and CPU streams have indices 17 | // zero and one, respectively. 
This naming matches the label assigned to each 18 | // stream's command queue. 19 | auto s2 = new_stream(mx::Device::gpu); 20 | auto s3 = new_stream(mx::Device::gpu); 21 | 22 | auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2); 23 | auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3); 24 | auto x = mx::add(a, a, s2); 25 | auto y = mx::add(b, b, s3); 26 | 27 | // The multiply will happen on the default stream. 28 | std::cout << mx::multiply(x, y) << std::endl; 29 | 30 | mx::metal::stop_capture(); 31 | } 32 | -------------------------------------------------------------------------------- /examples/cpp/timer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace timer { 8 | 9 | using namespace std::chrono; 10 | 11 | template 12 | inline double seconds(duration x) { 13 | return duration_cast(x).count() / 1e9; 14 | } 15 | 16 | inline auto time() { 17 | return high_resolution_clock::now(); 18 | } 19 | 20 | } // namespace timer 21 | -------------------------------------------------------------------------------- /examples/export/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27) 2 | 3 | project(import_mlx LANGUAGES CXX) 4 | 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | 8 | find_package( 9 | Python 3.9 10 | COMPONENTS Interpreter Development.Module 11 | REQUIRED) 12 | execute_process( 13 | COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir 14 | OUTPUT_STRIP_TRAILING_WHITESPACE 15 | OUTPUT_VARIABLE MLX_ROOT) 16 | find_package(MLX CONFIG REQUIRED) 17 | 18 | add_executable(eval_mlp eval_mlp.cpp) 19 | target_link_libraries(eval_mlp PRIVATE mlx) 20 | 21 | add_executable(train_mlp train_mlp.cpp) 22 | target_link_libraries(train_mlp PRIVATE mlx) 23 | -------------------------------------------------------------------------------- 
/examples/export/README.md: -------------------------------------------------------------------------------- 1 | ## Setup 2 | 3 | Install MLX: 4 | 5 | ```bash 6 | pip install mlx>=0.22 7 | ``` 8 | 9 | Build the C++ examples: 10 | 11 | ```bash 12 | cmake -B build -DCMAKE_BUILD_TYPE=Release 13 | cmake --build build 14 | ``` 15 | 16 | ## Run 17 | 18 | ### Eval MLP 19 | 20 | Run the Python script to export the eval function: 21 | 22 | ```bash 23 | python eval_mlp.py 24 | ``` 25 | 26 | Then run the C++ program to import and run the function: 27 | 28 | ``` 29 | ./build/eval_mlp 30 | ``` 31 | 32 | The Python and C++ programs should output the same result. 33 | 34 | ### Train MLP 35 | 36 | Run the Python script to export the model initialization and training 37 | functions: 38 | 39 | ```bash 40 | python train_mlp.py 41 | ``` 42 | 43 | Then run the C++ program to import and run the functions: 44 | 45 | ``` 46 | ./build/train_mlp 47 | ``` 48 | 49 | The Python and C++ programs should output the same results. 50 | -------------------------------------------------------------------------------- /examples/export/eval_mlp.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | namespace mx = mlx::core; 7 | 8 | int main() { 9 | int batch_size = 8; 10 | int input_dim = 32; 11 | 12 | // Make the input 13 | mx::random::seed(42); 14 | auto example_x = mx::random::uniform({batch_size, input_dim}); 15 | 16 | // Import the function 17 | auto forward = mx::import_function("eval_mlp.mlxfn"); 18 | 19 | // Call the imported function 20 | auto out = forward({example_x})[0]; 21 | 22 | std::cout << out << std::endl; 23 | 24 | return 0; 25 | } 26 | -------------------------------------------------------------------------------- /examples/export/eval_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2024 Apple Inc. 
2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | import mlx.utils 6 | 7 | 8 | class MLP(nn.Module): 9 | """A simple MLP.""" 10 | 11 | def __init__( 12 | self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int 13 | ): 14 | super().__init__() 15 | layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] 16 | self.layers = [ 17 | nn.Linear(idim, odim) 18 | for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) 19 | ] 20 | 21 | def __call__(self, x): 22 | for l in self.layers[:-1]: 23 | x = nn.relu(l(x)) 24 | return self.layers[-1](x) 25 | 26 | 27 | if __name__ == "__main__": 28 | 29 | batch_size = 8 30 | input_dim = 32 31 | output_dim = 10 32 | 33 | # Load the model 34 | mx.random.seed(0) # Seed for params 35 | model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim) 36 | mx.eval(model) 37 | 38 | # Note, the model parameters are saved in the export function 39 | def forward(x): 40 | return model(x) 41 | 42 | mx.random.seed(42) # Seed for input 43 | example_x = mx.random.uniform(shape=(batch_size, input_dim)) 44 | 45 | mx.export_function("eval_mlp.mlxfn", forward, example_x) 46 | 47 | # Import in Python 48 | imported_forward = mx.import_function("eval_mlp.mlxfn") 49 | expected = forward(example_x) 50 | (out,) = imported_forward(example_x) 51 | assert mx.allclose(expected, out) 52 | print(out) 53 | -------------------------------------------------------------------------------- /examples/export/train_mlp.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | namespace mx = mlx::core; 7 | 8 | int main() { 9 | int batch_size = 8; 10 | int input_dim = 32; 11 | int output_dim = 10; 12 | 13 | auto state = mx::import_function("init_mlp.mlxfn")({}); 14 | 15 | // Make the input 16 | mx::random::seed(42); 17 | auto example_X = mx::random::normal({batch_size, input_dim}); 18 | auto example_y = mx::random::randint(0, output_dim, {batch_size}); 19 | 20 | // Import the function 21 | auto step = mx::import_function("train_mlp.mlxfn"); 22 | 23 | // Call the imported function 24 | for (int it = 0; it < 100; ++it) { 25 | state.insert(state.end(), {example_X, example_y}); 26 | state = step(state); 27 | eval(state); 28 | auto loss = state.back(); 29 | state.pop_back(); 30 | if (it % 10 == 0) { 31 | std::cout << "Loss " << loss.item() << std::endl; 32 | } 33 | } 34 | return 0; 35 | } 36 | -------------------------------------------------------------------------------- /examples/extensions/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Build 3 | 4 | ``` 5 | pip install -e . 6 | ``` 7 | 8 | For faster builds during development, you can also pre-install the requirements: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | And then run: 15 | 16 | ``` 17 | python setup.py build_ext -j8 --inplace 18 | ``` 19 | 20 | ## Test 21 | 22 | ``` 23 | python test.py 24 | ``` 25 | -------------------------------------------------------------------------------- /examples/extensions/axpby/axpby.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2025 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/kernels/utils.h" 6 | 7 | template 8 | [[kernel]] void axpby_general( 9 | device const T* x [[buffer(0)]], 10 | device const T* y [[buffer(1)]], 11 | device T* out [[buffer(2)]], 12 | constant const float& alpha [[buffer(3)]], 13 | constant const float& beta [[buffer(4)]], 14 | constant const int* shape [[buffer(5)]], 15 | constant const int64_t* x_strides [[buffer(6)]], 16 | constant const int64_t* y_strides [[buffer(7)]], 17 | constant const int& ndim [[buffer(8)]], 18 | uint index [[thread_position_in_grid]]) { 19 | auto x_offset = elem_to_loc(index, shape, x_strides, ndim); 20 | auto y_offset = elem_to_loc(index, shape, y_strides, ndim); 21 | out[index] = 22 | static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; 23 | } 24 | 25 | template 26 | [[kernel]] void axpby_contiguous( 27 | device const T* x [[buffer(0)]], 28 | device const T* y [[buffer(1)]], 29 | device T* out [[buffer(2)]], 30 | constant const float& alpha [[buffer(3)]], 31 | constant const float& beta [[buffer(4)]], 32 | uint index [[thread_position_in_grid]]) { 33 | out[index] = 34 | static_cast(alpha) * x[index] + static_cast(beta) * y[index]; 35 | } 36 | 37 | // clang-format off 38 | #define instantiate_axpby(type_name, type) \ 39 | instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \ 40 | instantiate_kernel( \ 41 | "axpby_contiguous_" #type_name, axpby_contiguous, type) 42 | 43 | instantiate_axpby(float32, float); 44 | instantiate_axpby(float16, half); 45 | instantiate_axpby(bfloat16, bfloat16_t); 46 | instantiate_axpby(complex64, complex64_t); 47 | // clang-format on 48 | -------------------------------------------------------------------------------- /examples/extensions/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | #include "axpby/axpby.h" 7 | 8 | namespace nb = nanobind; 9 | using namespace nb::literals; 10 | 11 | NB_MODULE(_ext, m) { 12 | m.doc() = "Sample extension for MLX"; 13 | 14 | m.def( 15 | "axpby", 16 | &my_ext::axpby, 17 | "x"_a, 18 | "y"_a, 19 | "alpha"_a, 20 | "beta"_a, 21 | nb::kw_only(), 22 | "stream"_a = nb::none(), 23 | R"( 24 | Scale and sum two vectors element-wise 25 | ``z = alpha * x + beta * y`` 26 | 27 | Follows numpy style broadcasting between ``x`` and ``y`` 28 | Inputs are upcasted to floats if needed 29 | 30 | Args: 31 | x (array): Input array. 32 | y (array): Input array. 33 | alpha (float): Scaling factor for ``x``. 34 | beta (float): Scaling factor for ``y``. 35 | 36 | Returns: 37 | array: ``alpha * x + beta * y`` 38 | )"); 39 | } 40 | -------------------------------------------------------------------------------- /examples/extensions/mlx_sample_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 

import mlx.core as mx

from ._ext import axpby

--------------------------------------------------------------------------------
/examples/extensions/pyproject.toml:
--------------------------------------------------------------------------------
# NOTE(review): requirements.txt pins mlx>=0.21.0 and nanobind==2.2.0, which
# differs from the pins below — confirm which set is intended.
[build-system]
requires = [
  "setuptools>=42",
  "cmake>=3.25",
  "mlx>=0.18.0",
  "nanobind==2.4.0",
]
build-backend = "setuptools.build_meta"

--------------------------------------------------------------------------------
/examples/extensions/requirements.txt:
--------------------------------------------------------------------------------
# NOTE(review): pyproject.toml pins mlx>=0.18.0 and nanobind==2.4.0, which
# differs from the pins below — confirm which set is intended.
setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.2.0

--------------------------------------------------------------------------------
/examples/extensions/setup.py:
--------------------------------------------------------------------------------
# Copyright © 2023-2024 Apple Inc.

from setuptools import setup

from mlx import extension

# Builds the `mlx_sample_extensions._ext` module through the CMake-based
# build_ext command provided by mlx.extension.
if __name__ == "__main__":
    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        # Ship the compiled libraries alongside the Python package.
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        zip_safe=False,
        python_requires=">=3.8",
    )

--------------------------------------------------------------------------------
/examples/extensions/test.py:
--------------------------------------------------------------------------------
import mlx.core as mx
from mlx_sample_extensions import axpby

# 4 * 1 + 2 * 1 == 6 for every element, computed on the CPU stream.
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

--------------------------------------------------------------------------------
/examples/python/linear_regression.py:
--------------------------------------------------------------------------------
# Copyright © 2023 Apple Inc.

import time

import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000
lr = 0.01

# True parameters
w_star = mx.random.normal((num_features,))

# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))

# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps

# Initialize random parameters
w = 1e-2 * mx.random.normal((num_features,))


def loss_fn(w):
    """Half the mean squared error between ``X @ w`` and ``y``."""
    return 0.5 * mx.mean(mx.square(X @ w - y))


grad_fn = mx.grad(loss_fn)

# Plain gradient descent; evaluating each step keeps the work inside the
# timed region.
tic = time.time()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)
toc = time.time()

loss = loss_fn(w)
# Euclidean distance between the learned and the true parameters.
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
throughput = num_iters / (toc - tic)

print(
    f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
    f"Throughput {throughput:.5f} (it/s)"
)

--------------------------------------------------------------------------------
/examples/python/logistic_regression.py:
--------------------------------------------------------------------------------
# Copyright © 2023 Apple Inc.

import time

import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000
lr = 0.1

# True parameters
w_star = mx.random.normal((num_features,))

# Input examples
X = mx.random.normal((num_examples, num_features))

# Labels
y = (X @ w_star) > 0


# Initialize random parameters
w = 1e-2 * mx.random.normal((num_features,))


def loss_fn(w):
    """Mean of ``logaddexp(0, logits) - y * logits`` with ``logits = X @ w``."""
    logits = X @ w
    return mx.mean(mx.logaddexp(0.0, logits) - y * logits)


grad_fn = mx.grad(loss_fn)

# Plain gradient descent; evaluating each step keeps the work inside the
# timed region.
tic = time.time()
for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)

toc = time.time()

loss = loss_fn(w)
# Fraction of training examples whose predicted label matches y.
final_preds = (X @ w) > 0
acc = mx.mean(final_preds == y)

throughput = num_iters / (toc - tic)
print(
    f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
    f"Throughput {throughput:.5f} (it/s)"
)

--------------------------------------------------------------------------------
/mlx.pc.in:
--------------------------------------------------------------------------------
# Find MLX
#
# Defines the following variables:
#
#   MLX_FOUND            : True if MLX is found
#   MLX_INCLUDE_DIRS     : Include directory
#   MLX_LIBRARIES        : Libraries to link against
#   MLX_CXX_FLAGS        : Additional compiler flags
#   MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
#   MLX_BUILD_METAL      : True if MLX was built with metal

@PACKAGE_INIT@

include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/MLXTargets.cmake)
include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/extension.cmake)

set_and_check(MLX_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@)
set_and_check(MLX_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@)
set(MLX_LIBRARIES mlx)

find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})

if (@MLX_BUILD_ACCELERATE@)
    set(MLX_BUILD_ACCELERATE @MLX_BUILD_ACCELERATE@)
    set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
endif()

if (@MLX_BUILD_METAL@)
    set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
    set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
    set(MLX_INCLUDE_DIRS
        "${MLX_INCLUDE_DIRS};"
        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
    )
    # Add the kernel headers that match the Metal version MLX was built with.
    if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
      set(MLX_INCLUDE_DIRS
        "${MLX_INCLUDE_DIRS};"
        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
    else()
      set(MLX_INCLUDE_DIRS
        "${MLX_INCLUDE_DIRS};"
        @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
    endif()
endif()

set_target_properties(mlx PROPERTIES
    CXX_STANDARD 17
    INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

--------------------------------------------------------------------------------
/mlx/3rdparty/.clang-format:
--------------------------------------------------------------------------------
DisableFormat: true
SortIncludes: Never

--------------------------------------------------------------------------------
/mlx/allocator.cpp:
--------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.
2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/allocator.h" 7 | 8 | namespace mlx::core::allocator { 9 | 10 | Buffer malloc(size_t size) { 11 | auto buffer = allocator().malloc(size); 12 | if (size && !buffer.ptr()) { 13 | std::ostringstream msg; 14 | msg << "[malloc] Unable to allocate " << size << " bytes."; 15 | throw std::runtime_error(msg.str()); 16 | } 17 | return buffer; 18 | } 19 | 20 | void free(Buffer buffer) { 21 | allocator().free(buffer); 22 | } 23 | 24 | } // namespace mlx::core::allocator 25 | -------------------------------------------------------------------------------- /mlx/allocator.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace mlx::core::allocator { 8 | 9 | // Simple wrapper around buffer pointers 10 | // WARNING: Only Buffer objects constructed from and those that wrap 11 | // raw pointers from mlx::allocator are supported. 12 | class Buffer { 13 | private: 14 | void* ptr_; 15 | 16 | public: 17 | Buffer(void* ptr) : ptr_(ptr) {}; 18 | 19 | // Get the raw data pointer from the buffer 20 | void* raw_ptr(); 21 | 22 | // Get the buffer pointer from the buffer 23 | const void* ptr() const { 24 | return ptr_; 25 | }; 26 | void* ptr() { 27 | return ptr_; 28 | }; 29 | }; 30 | 31 | Buffer malloc(size_t size); 32 | 33 | void free(Buffer buffer); 34 | 35 | class Allocator { 36 | /** Abstract base class for a memory allocator. 
*/ 37 | public: 38 | virtual Buffer malloc(size_t size) = 0; 39 | virtual void free(Buffer buffer) = 0; 40 | virtual size_t size(Buffer buffer) const = 0; 41 | 42 | Allocator() = default; 43 | Allocator(const Allocator& other) = delete; 44 | Allocator(Allocator&& other) = delete; 45 | Allocator& operator=(const Allocator& other) = delete; 46 | Allocator& operator=(Allocator&& other) = delete; 47 | virtual ~Allocator() = default; 48 | }; 49 | 50 | Allocator& allocator(); 51 | 52 | } // namespace mlx::core::allocator 53 | -------------------------------------------------------------------------------- /mlx/backend/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp 9 | ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) 10 | -------------------------------------------------------------------------------- /mlx/backend/common/broadcasting.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/backend/common/utils.h" 4 | 5 | namespace mlx::core { 6 | 7 | void broadcast(const array& in, array& out) { 8 | if (out.size() == 0) { 9 | out.set_data(nullptr); 10 | return; 11 | } 12 | Strides strides(out.ndim(), 0); 13 | int diff = out.ndim() - in.ndim(); 14 | for (int i = in.ndim() - 1; i >= 0; --i) { 15 | strides[i + diff] = (in.shape()[i] == 1) ? 
0 : in.strides()[i]; 16 | } 17 | auto flags = in.flags(); 18 | if (out.size() > in.size()) { 19 | flags.row_contiguous = flags.col_contiguous = false; 20 | } 21 | out.copy_shared_buffer(in, strides, flags, in.data_size()); 22 | } 23 | 24 | } // namespace mlx::core 25 | -------------------------------------------------------------------------------- /mlx/backend/common/broadcasting.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void broadcast(const array& in, array& out); 10 | 11 | } // namespace mlx::core 12 | -------------------------------------------------------------------------------- /mlx/backend/common/copy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | enum class CopyType { 10 | // Copy a raw scalar input into the full contiguous output 11 | Scalar, 12 | 13 | // Copy the raw input buffer contiguously into a raw output buffer of the same 14 | // size 15 | Vector, 16 | 17 | // Copy the full virtual input to the full contiguous output 18 | General, 19 | 20 | // Copy the full virtual input to the full virtual output. We assume the 21 | // input and output have the same shape. 22 | GeneralGeneral 23 | }; 24 | 25 | inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { 26 | if (ctype == CopyType::Vector) { 27 | // If the input is donateable, we are doing a vector copy and the types 28 | // have the same size, then the input buffer can hold the output. 
29 | if (in.is_donatable() && in.itemsize() == out.itemsize()) { 30 | out.copy_shared_buffer(in); 31 | return true; 32 | } else { 33 | out.set_data( 34 | allocator::malloc(in.data_size() * out.itemsize()), 35 | in.data_size(), 36 | in.strides(), 37 | in.flags()); 38 | return false; 39 | } 40 | } else { 41 | out.set_data(allocator::malloc(out.nbytes())); 42 | return false; 43 | } 44 | } 45 | 46 | } // namespace mlx::core 47 | -------------------------------------------------------------------------------- /mlx/backend/common/load.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/primitives.h" 7 | #include "mlx/scheduler.h" 8 | 9 | namespace { 10 | 11 | template 12 | void swap_endianness(uint8_t* data_bytes, size_t N) { 13 | struct Elem { 14 | uint8_t bytes[scalar_size]; 15 | }; 16 | 17 | Elem* data = reinterpret_cast(data_bytes); 18 | 19 | for (size_t i = 0; i < N; i++) { 20 | for (size_t j = 0; j < (scalar_size / 2); j++) { 21 | std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); 22 | } 23 | } 24 | } 25 | 26 | } // namespace 27 | 28 | namespace mlx::core { 29 | 30 | void Load::eval_cpu(const std::vector& inputs, array& out) { 31 | out.set_data(allocator::malloc(out.nbytes())); 32 | auto read_task = [out_ptr = out.data(), 33 | size = out.size(), 34 | itemsize = out.itemsize(), 35 | offset = offset_, 36 | reader = reader_, 37 | swap_endianness_ = swap_endianness_]() mutable { 38 | reader->read(out_ptr, size * itemsize, offset); 39 | if (swap_endianness_) { 40 | switch (itemsize) { 41 | case 2: 42 | swap_endianness<2>(reinterpret_cast(out_ptr), size); 43 | break; 44 | case 4: 45 | swap_endianness<4>(reinterpret_cast(out_ptr), size); 46 | break; 47 | case 8: 48 | swap_endianness<8>(reinterpret_cast(out_ptr), size); 49 | break; 50 | } 51 | } 52 | }; 53 | auto fut = io::thread_pool().enqueue(std::move(read_task)).share(); 54 | 
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); }); 55 | } 56 | 57 | } // namespace mlx::core 58 | -------------------------------------------------------------------------------- /mlx/backend/common/reduce.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/common/utils.h" 6 | 7 | namespace mlx::core { 8 | 9 | enum ReductionOpType { 10 | // Self-explanatory. Read everything and produce 1 output. 11 | ContiguousAllReduce, 12 | 13 | // The input is contiguous and the last axis is reduced 14 | // N1xR1xN2xR2x...xNnxRn 15 | ContiguousReduce, 16 | 17 | // The input is contiguous and the last axis is not reduced 18 | // R1xN1xR2xN2x...xRnxNn 19 | ContiguousStridedReduce, 20 | 21 | // The input is not contiguous but the last axis is and it is reduced so we 22 | // need to figure out the offsets but we can call the contiguous reduce after 23 | // that. 24 | // N3xR1xN1xR4x...xRn 25 | GeneralContiguousReduce, 26 | 27 | // The input is not contiguous but the last reduction axis and the last axis 28 | // are so we need to figure out the offset but we can call the strided reduce 29 | // after that. 30 | GeneralStridedReduce, 31 | 32 | // The input is not contiguous after the reduction axis and it may contain 33 | // 0-stride axes or transpositions. We could copy the strides and produce a 34 | // transposed outcome or we can read the input out of order and write the 35 | // output in order. 
36 | GeneralReduce 37 | }; 38 | 39 | struct ReductionPlan { 40 | ReductionOpType type; 41 | Shape shape; 42 | Strides strides; 43 | 44 | ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_) 45 | : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} 46 | ReductionPlan(ReductionOpType type_) : type(type_) {} 47 | }; 48 | 49 | ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); 50 | 51 | std::pair shapes_without_reduction_axes( 52 | const array& x, 53 | const std::vector& axes); 54 | 55 | } // namespace mlx::core 56 | -------------------------------------------------------------------------------- /mlx/backend/common/slicing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | std::tuple prepare_slice( 10 | const array& in, 11 | const Shape& start_indices, 12 | const Shape& strides); 13 | 14 | void slice( 15 | const array& in, 16 | array& out, 17 | const Shape& start_indices, 18 | const Shape& strides); 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /mlx/backend/cpu/arange.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/backend/cpu/encoder.h" 7 | 8 | namespace mlx::core { 9 | 10 | namespace { 11 | 12 | template 13 | void arange(T start, T next, array& out, size_t size, Stream stream) { 14 | auto ptr = out.data(); 15 | auto step_size = next - start; 16 | auto& encoder = cpu::get_command_encoder(stream); 17 | encoder.set_output_array(out); 18 | encoder.dispatch([ptr, start, step_size, size]() mutable { 19 | for (int i = 0; i < size; ++i) { 20 | ptr[i] = start; 21 | start += step_size; 22 | } 23 | }); 24 | } 25 | 26 | } // namespace 27 | 28 | } // namespace mlx::core 29 | -------------------------------------------------------------------------------- /mlx/backend/cpu/available.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cpu/available.h" 4 | 5 | namespace mlx::core::cpu { 6 | 7 | bool is_available() { 8 | return true; 9 | } 10 | 11 | } // namespace mlx::core::cpu 12 | -------------------------------------------------------------------------------- /mlx/backend/cpu/available.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core::cpu { 6 | 7 | bool is_available(); 8 | 9 | } // namespace mlx::core::cpu 10 | -------------------------------------------------------------------------------- /mlx/backend/cpu/compiled_preamble.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-24 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | // clang-format off 6 | #include "mlx/types/half_types.h" 7 | #include "mlx/types/complex.h" 8 | #include "mlx/backend/cpu/unary_ops.h" 9 | #include "mlx/backend/cpu/binary_ops.h" 10 | // clang-format on 11 | 12 | const char* get_kernel_preamble(); 13 | -------------------------------------------------------------------------------- /mlx/backend/cpu/copy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/array.h" 8 | #include "mlx/backend/common/copy.h" 9 | #include "mlx/backend/common/utils.h" 10 | 11 | namespace mlx::core { 12 | 13 | void copy(const array& src, array& dst, CopyType ctype, Stream stream); 14 | void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream); 15 | 16 | void copy_inplace( 17 | const array& src, 18 | array& dst, 19 | const Shape& data_shape, 20 | const Strides& i_strides, 21 | const Strides& o_strides, 22 | int64_t i_offset, 23 | int64_t o_offset, 24 | CopyType ctype, 25 | Stream stream, 26 | const std::optional& dynamic_i_offset = std::nullopt, 27 | const std::optional& dynamic_o_offset = std::nullopt); 28 | 29 | } // namespace mlx::core 30 | -------------------------------------------------------------------------------- /mlx/backend/cpu/encoder.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cpu/encoder.h" 4 | 5 | namespace mlx::core::cpu { 6 | 7 | CommandEncoder& get_command_encoder(Stream stream) { 8 | static std::unordered_map encoder_map; 9 | auto it = encoder_map.find(stream.index); 10 | if (it == encoder_map.end()) { 11 | it = encoder_map.emplace(stream.index, stream).first; 12 | } 13 | return it->second; 14 | } 15 | 16 | } // namespace mlx::core::cpu 17 | -------------------------------------------------------------------------------- /mlx/backend/cpu/eval.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | #include "mlx/backend/cpu/eval.h" 3 | #include "mlx/backend/cpu/encoder.h" 4 | #include "mlx/primitives.h" 5 | #include "mlx/scheduler.h" 6 | #include "mlx/utils.h" 7 | 8 | namespace mlx::core::cpu { 9 | 10 | void eval(array& arr) { 11 | auto s = arr.primitive().stream(); 12 | 13 | auto outputs = arr.outputs(); 14 | { 15 | // If the array is a tracer hold a reference 16 | // to its inputs so they don't get donated 17 | std::vector inputs; 18 | if (arr.is_tracer()) { 19 | inputs = arr.inputs(); 20 | } 21 | arr.primitive().eval_cpu(arr.inputs(), outputs); 22 | } 23 | 24 | std::unordered_set> buffers; 25 | for (auto& in : arr.inputs()) { 26 | buffers.insert(in.data_shared_ptr()); 27 | } 28 | for (auto& s : arr.siblings()) { 29 | buffers.insert(s.data_shared_ptr()); 30 | } 31 | // Remove the output if it was donated to by an input 32 | if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) { 33 | buffers.erase(it); 34 | } 35 | auto& encoder = cpu::get_command_encoder(s); 36 | encoder.dispatch([buffers = std::move(buffers), 37 | temps = std::move(encoder.temporaries())]() {}); 38 | } 39 | 40 | } // namespace mlx::core::cpu 41 | -------------------------------------------------------------------------------- /mlx/backend/cpu/eval.h: -------------------------------------------------------------------------------- 1 | // Copyright © 
2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/stream.h" 7 | 8 | namespace mlx::core::cpu { 9 | 10 | void eval(array& arr); 11 | 12 | } // namespace mlx::core::cpu 13 | -------------------------------------------------------------------------------- /mlx/backend/cpu/gemm.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | #include "mlx/array.h" 5 | 6 | namespace mlx::core { 7 | 8 | template 9 | void matmul( 10 | const T* a, 11 | const T* b, 12 | T* out, 13 | bool a_transposed, 14 | bool b_transposed, 15 | size_t lda, 16 | size_t ldb, 17 | size_t ldc, 18 | float alpha, 19 | float beta, 20 | size_t batch_size, 21 | const Shape& a_shape, 22 | const Strides& a_strides, 23 | const Shape& b_shape, 24 | const Strides& b_strides); 25 | 26 | } // namespace mlx::core 27 | -------------------------------------------------------------------------------- /mlx/backend/cpu/gemms/simd_bf16.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/common/utils.h" 4 | #include "mlx/backend/cpu/gemm.h" 5 | #include "mlx/backend/cpu/gemms/simd_gemm.h" 6 | 7 | namespace mlx::core { 8 | // bfloat16 specialization: calls simd_gemm once per batch entry. The a and b operands are located through elem_to_loc with their shapes/strides; out is written at offset M * N * i. 9 | template <> 10 | void matmul( 11 | const bfloat16_t* a, 12 | const bfloat16_t* b, 13 | bfloat16_t* out, 14 | bool a_transposed, 15 | bool b_transposed, 16 | size_t lda, 17 | size_t ldb, 18 | size_t ldc, 19 | float alpha, 20 | float beta, 21 | size_t batch_size, 22 | const Shape& a_shape, 23 | const Strides& a_strides, 24 | const Shape& b_shape, 25 | const Strides& b_strides) { 26 | auto ndim = a_shape.size(); 27 | size_t M = a_shape[ndim - 2]; 28 | size_t N = b_shape[ndim - 1]; 29 | size_t K = a_shape[ndim - 1]; 30 | for (size_t i = 0; i < batch_size; ++i) { // size_t index: same type as batch_size, no signed/unsigned comparison 31 | simd_gemm( 32 | a + elem_to_loc(M * K * i, a_shape, a_strides), 33 | b + elem_to_loc(K * N * i, b_shape, b_strides), 34 | out + M * N * i, 35 | a_transposed, 36 | b_transposed, 37 | M, 38 | N, 39 | K, 40 | alpha, 41 | beta); 42 | } 43 | } 44 | 45 | } // namespace mlx::core 46 | -------------------------------------------------------------------------------- /mlx/backend/cpu/gemms/simd_fp16.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/common/utils.h" 4 | #include "mlx/backend/cpu/gemm.h" 5 | #include "mlx/backend/cpu/gemms/simd_gemm.h" 6 | 7 | namespace mlx::core { 8 | // float16 specialization: calls simd_gemm once per batch entry. The a and b operands are located through elem_to_loc with their shapes/strides; out is written at offset M * N * i. 9 | template <> 10 | void matmul( 11 | const float16_t* a, 12 | const float16_t* b, 13 | float16_t* out, 14 | bool a_transposed, 15 | bool b_transposed, 16 | size_t lda, 17 | size_t ldb, 18 | size_t ldc, 19 | float alpha, 20 | float beta, 21 | size_t batch_size, 22 | const Shape& a_shape, 23 | const Strides& a_strides, 24 | const Shape& b_shape, 25 | const Strides& b_strides) { 26 | auto ndim = a_shape.size(); 27 | size_t M = a_shape[ndim - 2]; 28 | size_t N = b_shape[ndim - 1]; 29 | size_t K = a_shape[ndim - 1]; 30 | for (size_t i = 0; i < batch_size; ++i) { // size_t index: same type as batch_size, no signed/unsigned comparison 31 | simd_gemm( 32 | a + elem_to_loc(M * K * i, a_shape, a_strides), 33 | b + elem_to_loc(K * N * i, b_shape, b_strides), 34 | out + M * N * i, 35 | a_transposed, 36 | b_transposed, 37 | M, 38 | N, 39 | K, 40 | alpha, 41 | beta); 42 | } 43 | } 44 | 45 | } // namespace mlx::core 46 | -------------------------------------------------------------------------------- /mlx/backend/cpu/jit_compiler.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | #pragma once 3 | 4 | #include 5 | 6 | namespace mlx::core { 7 | 8 | class JitCompiler { 9 | public: 10 | // Build a shell command that compiles a source code file to a shared library. 11 | static std::string build_command( 12 | const std::filesystem::path& dir, 13 | const std::string& source_file_name, 14 | const std::string& shared_lib_name); 15 | 16 | // Run a command and get its output. 
17 | static std::string exec(const std::string& cmd); 18 | }; 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /mlx/backend/cpu/make_compiled_preamble.ps1: -------------------------------------------------------------------------------- 1 | # This script generates a C++ function that provides the CPU 2 | # code for use with kernel generation. 3 | # 4 | # Copyright © 2024 Apple Inc. 5 | 6 | $OUTPUT_FILE = $args[0] 7 | $CL = $args[1] 8 | $SRCDIR = $args[2] 9 | 10 | # Get command result as array. 11 | $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 12 | # Remove empty lines. 13 | # Otherwise there will be too many empty lines, making the result unreadable. 14 | $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } 15 | # Concatenate to string. 16 | $CONTENT = $CONTENT -join "`n" 17 | 18 | # Append extra content. 19 | $CONTENT = @" 20 | $($CONTENT) 21 | using namespace mlx::core; 22 | using namespace mlx::core::detail; 23 | "@ 24 | 25 | # Convert each char to ASCII code. 26 | # Unlike the unix script that outputs a string literal directly, the output from 27 | # MSVC is way too large to be embedded as a string and compilation will fail, so 28 | # we store it as a static array instead. 29 | $CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0' 30 | 31 | $OUTPUT = @" 32 | const char* get_kernel_preamble() { 33 | static char preamble[] = { $CHARCODES }; 34 | return preamble; 35 | } 36 | "@ 37 | 38 | Set-Content -Path $OUTPUT_FILE -Value $OUTPUT 39 | -------------------------------------------------------------------------------- /mlx/backend/cpu/make_compiled_preamble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script generates a C++ function that provides the CPU 4 | # code for use with kernel generation. 5 | # 6 | # Copyright © 2023-24 Apple Inc. 
7 | 8 | 9 | OUTPUT_FILE=$1 10 | GCC=$2 11 | SRCDIR=$3 12 | CLANG=$4 13 | ARCH=$5 14 | 15 | if [ "$CLANG" = "TRUE" ]; then 16 | read -r -d '' INCLUDES <<- EOM 17 | #include 18 | #include 19 | #include 20 | #include 21 | #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC 22 | #include 23 | #endif 24 | EOM 25 | CC_FLAGS="-arch ${ARCH} -nobuiltininc -nostdinc" 26 | else 27 | CC_FLAGS="-std=c++17" 28 | fi 29 | 30 | CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null) 31 | 32 | cat << EOF > "$OUTPUT_FILE" 33 | const char* get_kernel_preamble() { 34 | return R"preamble( 35 | $INCLUDES 36 | $CONTENT 37 | using namespace mlx::core; 38 | using namespace mlx::core::detail; 39 | )preamble"; 40 | } 41 | EOF 42 | -------------------------------------------------------------------------------- /mlx/backend/cpu/simd/accelerate_fp16_simd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlx/backend/cpu/simd/base_simd.h" 4 | 5 | #if MLX_SIMD_LIBRARY_VERSION < 6 6 | #include "mlx/backend/cpu/simd/neon_fp16_simd.h" 7 | #endif 8 | 9 | namespace mlx::core::simd { 10 | 11 | #if MLX_SIMD_LIBRARY_VERSION >= 6 12 | constexpr int N = 8; 13 | template 14 | struct ScalarT { 15 | using v = _Float16; 16 | }; 17 | #endif 18 | 19 | template <> 20 | inline constexpr int max_size = N; 21 | 22 | #define SIMD_FP16_DEFAULT_UNARY(op) \ 23 | template <> \ 24 | inline Simd op(Simd v) { \ 25 | Simd in = v; \ 26 | return op(in); \ 27 | } 28 | 29 | SIMD_FP16_DEFAULT_UNARY(acos) 30 | SIMD_FP16_DEFAULT_UNARY(acosh) 31 | SIMD_FP16_DEFAULT_UNARY(asin) 32 | SIMD_FP16_DEFAULT_UNARY(asinh) 33 | SIMD_FP16_DEFAULT_UNARY(atan) 34 | SIMD_FP16_DEFAULT_UNARY(atanh) 35 | SIMD_FP16_DEFAULT_UNARY(cosh) 36 | SIMD_FP16_DEFAULT_UNARY(expm1) 37 | SIMD_FP16_DEFAULT_UNARY(log) 38 | SIMD_FP16_DEFAULT_UNARY(log2) 39 | SIMD_FP16_DEFAULT_UNARY(log10) 40 | SIMD_FP16_DEFAULT_UNARY(log1p) 41 | SIMD_FP16_DEFAULT_UNARY(sinh) 42 
| SIMD_FP16_DEFAULT_UNARY(tan) 43 | SIMD_FP16_DEFAULT_UNARY(tanh) 44 | 45 | #define SIMD_FP16_DEFAULT_BINARY(op) \ 46 | template <> \ 47 | inline Simd op(Simd x, Simd y) { \ 48 | Simd a = x; \ 49 | Simd b = y; \ 50 | return op(a, b); \ 51 | } 52 | SIMD_FP16_DEFAULT_BINARY(atan2) 53 | SIMD_FP16_DEFAULT_BINARY(remainder) 54 | SIMD_FP16_DEFAULT_BINARY(pow) 55 | 56 | } // namespace mlx::core::simd 57 | -------------------------------------------------------------------------------- /mlx/backend/cpu/simd/simd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlx/backend/cpu/simd/math.h" 4 | #include "mlx/backend/cpu/simd/type.h" 5 | -------------------------------------------------------------------------------- /mlx/backend/cpu/simd/type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlx/backend/cpu/simd/base_simd.h" 4 | 5 | #ifdef MLX_USE_ACCELERATE 6 | #include "mlx/backend/cpu/simd/accelerate_simd.h" 7 | #endif 8 | -------------------------------------------------------------------------------- /mlx/backend/cpu/slicing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | std::tuple prepare_slice( 10 | const array& in, 11 | const Shape& start_indices, 12 | const Shape& strides); 13 | 14 | void shared_buffer_slice( 15 | const array& in, 16 | const Strides& out_strides, 17 | size_t data_offset, 18 | size_t data_size, 19 | array& out); 20 | 21 | } // namespace mlx::core 22 | -------------------------------------------------------------------------------- /mlx/backend/cpu/threefry.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/backend/cpu/threefry.h" 4 | 5 | namespace mlx::core::random { 6 | 7 | std::pair threefry2x32_hash( 8 | const std::pair& key, 9 | std::pair count) { 10 | constexpr static uint32_t rotations[2][4] = { 11 | {13, 15, 26, 6}, {17, 29, 16, 24}}; 12 | 13 | uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA}; 14 | 15 | count.first += ks[0]; 16 | count.second += ks[1]; 17 | 18 | for (int i = 0; i < 5; ++i) { 19 | for (auto r : rotations[i % 2]) { 20 | count.first += count.second; 21 | count.second = (count.second << r) | (count.second >> (32 - r)); 22 | count.second ^= count.first; 23 | } 24 | count.first += ks[(i + 1) % 3]; 25 | count.second += ks[(i + 2) % 3] + i + 1; 26 | } 27 | 28 | return count; 29 | } 30 | 31 | } // namespace mlx::core::random 32 | -------------------------------------------------------------------------------- /mlx/backend/cpu/threefry.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | namespace mlx::core::random { 9 | 10 | /** Applies the Threefry 2x32 hash function. 11 | * This code is based on the Jax counter-based and splittable PRNG 12 | * https://github.com/google/jax/blob/main/docs/jep/263-prng.md 13 | * 14 | * Original Threefry reference: 15 | * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf 16 | */ 17 | std::pair threefry2x32_hash( 18 | const std::pair& key, 19 | std::pair count); 20 | 21 | } // namespace mlx::core::random 22 | -------------------------------------------------------------------------------- /mlx/backend/cuda/allocator.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/allocator.h" 6 | #include "mlx/backend/common/buffer_cache.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace mlx::core::cu { 14 | 15 | class Worker; 16 | 17 | using allocator::Buffer; 18 | 19 | // Stores CUDA-managed unified memory. 20 | struct CudaBuffer { 21 | void* data; 22 | size_t size; 23 | }; 24 | 25 | class CudaAllocator : public allocator::Allocator { 26 | public: 27 | Buffer malloc(size_t size) override; 28 | void free(Buffer buffer) override; 29 | size_t size(Buffer buffer) const override; 30 | 31 | // Register the current thread as safe to free buffers on. 32 | // In CUDA, freeing a buffer implicitly synchronizes the stream, so freeing 33 | // buffers on a thread that a GPU stream may wait on (for example the CPU 34 | // stream threads) would result in a deadlock. 35 | void register_this_thread(); 36 | 37 | size_t get_active_memory() const; 38 | size_t get_peak_memory() const; 39 | void reset_peak_memory(); 40 | size_t get_memory_limit(); 41 | size_t set_memory_limit(size_t limit); 42 | size_t get_cache_memory() const; 43 | size_t set_cache_limit(size_t limit); 44 | void clear_cache(); 45 | 46 | private: 47 | CudaAllocator(); 48 | friend CudaAllocator& allocator(); 49 | 50 | void cuda_free(CudaBuffer* buf); 51 | 52 | std::mutex worker_mutex_; 53 | std::unique_ptr worker_; 54 | std::set allowed_threads_; 55 | 56 | std::mutex mutex_; 57 | size_t memory_limit_; 58 | size_t max_pool_size_; 59 | BufferCache buffer_cache_; 60 | size_t active_memory_{0}; 61 | size_t peak_memory_{0}; 62 | }; 63 | 64 | CudaAllocator& allocator(); 65 | 66 | } // namespace mlx::core::cu 67 | -------------------------------------------------------------------------------- /mlx/backend/cuda/copy.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/gpu/copy.h" 4 | 5 | namespace mlx::core { 6 | 7 | void copy_gpu_inplace( 8 | const array& in, 9 | array& out, 10 | const Shape& data_shape, 11 | const Strides& strides_in_pre, 12 | const Strides& strides_out_pre, 13 | int64_t inp_offset, 14 | int64_t out_offset, 15 | CopyType ctype, 16 | const Stream& s, 17 | const std::optional& dynamic_i_offset /* = std::nullopt */, 18 | const std::optional& dynamic_o_offset /* = std::nullopt */) { 19 | throw std::runtime_error("copy_gpu_inplace not implemented in CUDA backend."); 20 | } 21 | 22 | void fill_gpu(const array& val, array& out, const Stream& s) { 23 | throw std::runtime_error("fill_gpu not implemented in CUDA backend."); 24 | } 25 | 26 | } // namespace mlx::core 27 | -------------------------------------------------------------------------------- /mlx/backend/cuda/event.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/stream.h" 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | namespace mlx::core::cu { 13 | 14 | class CudaEventHandle; 15 | 16 | // Wrapper of native cuda event. It can synchronize between GPU streams, or wait 17 | // on GPU stream in CPU stream, but can not wait on CPU stream. 18 | class CudaEvent { 19 | public: 20 | CudaEvent(); 21 | 22 | void wait(); 23 | void wait(cudaStream_t stream); 24 | void wait(Stream s); 25 | void record(cudaStream_t stream); 26 | void record(Stream s); 27 | 28 | // Return whether the recorded kernels have completed. Note that this method 29 | // returns true if record() has not been called. 30 | bool completed() const; 31 | 32 | bool recorded() const { 33 | return recorded_; 34 | } 35 | 36 | private: 37 | bool recorded_{false}; 38 | std::shared_ptr event_; 39 | }; 40 | 41 | // Event that can synchronize between CPU and GPU. It is much slower than 42 | // CudaEvent so the latter should always be preferred when possible. 
43 | class SharedEvent { 44 | public: 45 | using Atomic = cuda::atomic; 46 | 47 | SharedEvent(); 48 | 49 | void wait(uint64_t value); 50 | void wait(cudaStream_t stream, uint64_t value); 51 | void wait(Stream s, uint64_t value); 52 | void signal(uint64_t value); 53 | void signal(cudaStream_t stream, uint64_t value); 54 | void signal(Stream s, uint64_t value); 55 | bool is_signaled(uint64_t value) const; 56 | uint64_t value() const; 57 | 58 | const std::shared_ptr& atomic() const { 59 | return ac_; 60 | } 61 | 62 | private: 63 | std::shared_ptr ac_; 64 | }; 65 | 66 | } // namespace mlx::core::cu 67 | -------------------------------------------------------------------------------- /mlx/backend/cuda/kernel_utils.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/common/utils.h" 4 | #include "mlx/backend/cuda/kernel_utils.cuh" 5 | 6 | namespace mlx::core { 7 | 8 | dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) { 9 | Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2); 10 | return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); 11 | } 12 | 13 | dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { 14 | Dims dims = get_2d_grid_dims_common(shape, strides); 15 | return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); 16 | } 17 | 18 | dim3 get_2d_grid_dims( 19 | const Shape& shape, 20 | const Strides& strides, 21 | size_t divisor) { 22 | Dims dims = get_2d_grid_dims_common(shape, strides, divisor); 23 | return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); 24 | } 25 | 26 | } // namespace mlx::core 27 | -------------------------------------------------------------------------------- /mlx/backend/cuda/kernel_utils.cuh: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | // This file includes host-only utilities for writing CUDA kernels; the difference 4 | // from backend/cuda/kernels/utils.cuh is that the latter file only includes 5 | // device-only code. 6 | 7 | #pragma once 8 | 9 | #include "mlx/array.h" 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace mlx::core { 16 | 17 | // Maps CPU types to CUDA types. 18 | template 19 | struct CTypeToCudaType { 20 | using type = T; 21 | }; 22 | 23 | template <> 24 | struct CTypeToCudaType { 25 | using type = __half; 26 | }; 27 | 28 | template <> 29 | struct CTypeToCudaType { 30 | using type = __nv_bfloat16; 31 | }; 32 | 33 | template <> 34 | struct CTypeToCudaType { 35 | using type = cuComplex; 36 | }; 37 | 38 | template 39 | using cuda_type_t = typename CTypeToCudaType::type; 40 | 41 | // Compute the grid and block dimensions, check backend/common/utils.h for docs. 42 | dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); 43 | dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides); 44 | dim3 get_2d_grid_dims( 45 | const Shape& shape, 46 | const Strides& strides, 47 | size_t divisor); 48 | 49 | } // namespace mlx::core 50 | -------------------------------------------------------------------------------- /mlx/backend/cuda/kernels/arange.cuh: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | namespace mlx::core::cu { 4 | 5 | template 6 | struct Arange { 7 | const T start; 8 | const T step; 9 | 10 | __device__ T operator()(uint32_t i) const { 11 | return start + i * step; 12 | } 13 | }; 14 | 15 | } // namespace mlx::core::cu 16 | -------------------------------------------------------------------------------- /mlx/backend/cuda/slicing.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/gpu/slicing.h" 4 | 5 | namespace mlx::core { 6 | 7 | void concatenate_gpu( 8 | const std::vector& inputs, 9 | array& out, 10 | int axis, 11 | const Stream& s) { 12 | throw std::runtime_error("concatenate_gpu not implemented in CUDA backend."); 13 | } 14 | 15 | } // namespace mlx::core 16 | -------------------------------------------------------------------------------- /mlx/backend/cuda/utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/utils.h" 4 | #include "mlx/backend/cuda/device.h" 5 | 6 | #include 7 | 8 | namespace mlx::core { 9 | 10 | CudaStream::CudaStream(cu::Device& device) { 11 | device.make_current(); 12 | CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); 13 | } 14 | 15 | CudaStream::~CudaStream() { 16 | CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); // NOTE(review): check_cuda_error throws on failure and destructors are implicitly noexcept, so a failed destroy would end in std::terminate; confirm this is intended. 17 | } 18 | 19 | void check_cuda_error(const char* name, cudaError_t err) { 20 | if (err != cudaSuccess) { 21 | throw std::runtime_error( 22 | fmt::format("{} failed: {}", name, cudaGetErrorString(err))); 23 | } 24 | } 25 | 26 | } // namespace mlx::core 27 | -------------------------------------------------------------------------------- /mlx/backend/cuda/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | // This file includes utilities that are used by C++ code (i.e. .cpp files). 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace mlx::core { 10 | 11 | namespace cu { 12 | class Device; 13 | } 14 | 15 | // Cuda stream managed with RAII. 
16 | class CudaStream { 17 | public: 18 | explicit CudaStream(cu::Device& device); 19 | ~CudaStream(); 20 | 21 | CudaStream(const CudaStream&) = delete; 22 | CudaStream& operator=(const CudaStream&) = delete; 23 | 24 | operator cudaStream_t() const { 25 | return stream_; 26 | } 27 | 28 | private: 29 | cudaStream_t stream_; 30 | }; 31 | 32 | // Throw an exception if the CUDA API call does not succeed. 33 | void check_cuda_error(const char* name, cudaError_t err); 34 | 35 | // The macro version that prints the command that failed. 36 | #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) 37 | 38 | } // namespace mlx::core 39 | -------------------------------------------------------------------------------- /mlx/backend/cuda/worker.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/cuda/event.h" 6 | #include "mlx/backend/cuda/utils.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace mlx::core::cu { 14 | 15 | // Run tasks in a worker thread, synchronized with a CUDA stream. 16 | class Worker { 17 | public: 18 | Worker(); 19 | ~Worker(); 20 | 21 | Worker(const Worker&) = delete; 22 | Worker& operator=(const Worker&) = delete; 23 | 24 | // Add a pending |task| that will run when consumed or committed. 25 | void add_task(std::function task); 26 | 27 | // Run pending tasks immediately in the current thread. 28 | void consume_in_this_thread(); 29 | 30 | // Put pending tasks in a batch. 31 | void end_batch(); 32 | 33 | // Inform the worker thread to run current batches now. 34 | void commit(); 35 | 36 | // Inform the worker thread to run current batches after kernels in |stream| 37 | // finish running. 38 | void commit(cudaStream_t stream); 39 | 40 | // Return how many batches have been added but not committed yet. 
41 | size_t uncommited_batches() const { 42 | return uncommited_batches_; 43 | } 44 | 45 | private: 46 | void thread_fn(); 47 | 48 | uint64_t batch_{0}; 49 | size_t uncommited_batches_{0}; 50 | 51 | // Cuda stream and event for signaling kernel completion. 52 | CudaStream signal_stream_; 53 | CudaEvent signal_event_; 54 | 55 | // Worker thread. 56 | SharedEvent worker_event_; 57 | std::thread worker_; 58 | std::mutex worker_mutex_; 59 | bool stop_{false}; 60 | 61 | // Tasks are put in |pending_tasks_| first, and then moved to 62 | // |worker_tasks_| when end_batch() is called. 63 | using Tasks = std::vector>; 64 | Tasks pending_tasks_; 65 | std::map worker_tasks_; 66 | }; 67 | 68 | } // namespace mlx::core::cu 69 | -------------------------------------------------------------------------------- /mlx/backend/gpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) 6 | -------------------------------------------------------------------------------- /mlx/backend/gpu/available.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core::gpu { 6 | 7 | bool is_available(); 8 | 9 | } // namespace mlx::core::gpu 10 | -------------------------------------------------------------------------------- /mlx/backend/gpu/copy.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include "mlx/backend/gpu/copy.h" 4 | #include "mlx/primitives.h" 5 | 6 | #include 7 | 8 | namespace mlx::core { 9 | 10 | void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { 11 | bool donated = set_copy_output_data(in, out, ctype); 12 | if (donated && in.dtype() == out.dtype()) { 13 | // If the output has the same type as the input then there is nothing to 14 | // copy, just use the buffer. 15 | return; 16 | } 17 | if (ctype == CopyType::GeneralGeneral) { 18 | ctype = CopyType::General; 19 | } 20 | copy_gpu_inplace(in, out, ctype, s); 21 | } 22 | 23 | void copy_gpu(const array& in, array& out, CopyType ctype) { 24 | copy_gpu(in, out, ctype, out.primitive().stream()); 25 | } 26 | 27 | void copy_gpu_inplace( 28 | const array& in, 29 | array& out, 30 | CopyType ctype, 31 | const Stream& s) { 32 | assert(in.shape() == out.shape()); 33 | return copy_gpu_inplace( 34 | in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); 35 | } 36 | 37 | void copy_gpu_inplace( 38 | const array& in, 39 | array& out, 40 | const Strides& i_strides, 41 | int64_t i_offset, 42 | CopyType ctype, 43 | const Stream& s) { 44 | assert(in.shape() == out.shape()); 45 | return copy_gpu_inplace( 46 | in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); 47 | } 48 | 49 | } // namespace mlx::core 50 | -------------------------------------------------------------------------------- /mlx/backend/gpu/copy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/common/copy.h" 6 | #include "mlx/stream.h" 7 | 8 | #include 9 | 10 | namespace mlx::core { 11 | 12 | // Generic copy inplace 13 | void copy_gpu_inplace( 14 | const array& in, 15 | array& out, 16 | const Shape& data_shape, 17 | const Strides& i_strides, 18 | const Strides& o_strides, 19 | int64_t i_offset, 20 | int64_t o_offset, 21 | CopyType ctype, 22 | const Stream& s, 23 | const std::optional& dynamic_i_offset = std::nullopt, 24 | const std::optional& dynamic_o_offset = std::nullopt); 25 | 26 | void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); 27 | void copy_gpu(const array& src, array& out, CopyType ctype); 28 | 29 | void copy_gpu_inplace( 30 | const array& in, 31 | array& out, 32 | CopyType ctype, 33 | const Stream& s); 34 | 35 | void copy_gpu_inplace( 36 | const array& in, 37 | array& out, 38 | const Strides& i_strides, 39 | int64_t i_offset, 40 | CopyType ctype, 41 | const Stream& s); 42 | 43 | // Fill the output with the scalar val 44 | void fill_gpu(const array& val, array& out, const Stream& s); 45 | 46 | } // namespace mlx::core 47 | -------------------------------------------------------------------------------- /mlx/backend/gpu/eval.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/array.h" 9 | #include "mlx/stream.h" 10 | 11 | namespace mlx::core::gpu { 12 | 13 | void new_stream(Stream stream); 14 | void eval(array& arr); 15 | void finalize(Stream s); 16 | void synchronize(Stream s); 17 | 18 | } // namespace mlx::core::gpu 19 | -------------------------------------------------------------------------------- /mlx/backend/gpu/slicing.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/common/slicing.h" 4 | #include "mlx/backend/gpu/copy.h" 5 | #include "mlx/backend/gpu/slicing.h" 6 | 7 | namespace mlx::core { 8 | // Slice `in` into `out` by forwarding to the common slice(); the stream argument is not used here. 9 | void slice_gpu( 10 | const array& in, 11 | array& out, 12 | const Shape& start_indices, 13 | const Shape& strides, 14 | const Stream& s) { 15 | slice(in, out, start_indices, strides); 16 | } 17 | // Pad: fill `out` with `val`, then copy `in` into the slice of `out` that begins at the low-padding offset. 18 | void pad_gpu( 19 | const array& in, 20 | const array& val, 21 | array& out, 22 | const std::vector& axes, 23 | const Shape& low_pad_size, 24 | const Stream& s) { 25 | // Fill output with val 26 | fill_gpu(val, out, s); 27 | 28 | // Find offset for start of input values 29 | size_t data_offset = 0; 30 | for (size_t i = 0; i < axes.size(); i++) { // size_t index: same type as axes.size(), no signed/unsigned comparison 31 | auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i]; 32 | data_offset += out.strides()[ax] * low_pad_size[i]; 33 | } 34 | 35 | // Extract slice from output where input will be pasted 36 | array out_slice(in.shape(), out.dtype(), nullptr, {}); 37 | out_slice.copy_shared_buffer( 38 | out, out.strides(), out.flags(), out_slice.size(), data_offset); 39 | 40 | // Copy input values into the slice 41 | copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); 42 | } 43 | 44 | } // namespace mlx::core 45 | -------------------------------------------------------------------------------- /mlx/backend/gpu/slicing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void slice_gpu( 10 | const array& in, 11 | array& out, 12 | const Shape& start_indices, 13 | const Shape& strides, 14 | const Stream& s); 15 | 16 | void concatenate_gpu( 17 | const std::vector& inputs, 18 | array& out, 19 | int axis, 20 | const Stream& s); 21 | 22 | void pad_gpu( 23 | const array& in, 24 | const array& val, 25 | array& out, 26 | const std::vector& axes, 27 | const Shape& low_pad_size, 28 | const Stream& s); 29 | 30 | } // namespace mlx::core 31 | -------------------------------------------------------------------------------- /mlx/backend/metal/binary.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void binary_op_gpu( 10 | const std::vector& inputs, 11 | std::vector& outputs, 12 | const std::string& op, 13 | const Stream& s); 14 | 15 | void binary_op_gpu( 16 | const std::vector& inputs, 17 | array& out, 18 | const std::string& op, 19 | const Stream& s); 20 | 21 | void binary_op_gpu_inplace( 22 | const std::vector& inputs, 23 | std::vector& outputs, 24 | const std::string& op, 25 | const Stream& s); 26 | 27 | void binary_op_gpu_inplace( 28 | const std::vector& inputs, 29 | array& out, 30 | const std::string& op, 31 | const Stream& s); 32 | 33 | } // namespace mlx::core 34 | -------------------------------------------------------------------------------- /mlx/backend/metal/distributed.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/allocator.h" 6 | #include "mlx/backend/common/utils.h" 7 | #include "mlx/backend/gpu/copy.h" 8 | #include "mlx/backend/metal/device.h" 9 | #include "mlx/backend/metal/utils.h" 10 | #include "mlx/distributed/ops.h" 11 | #include "mlx/distributed/primitives.h" 12 | #include "mlx/fence.h" 13 | #include "mlx/scheduler.h" 14 | 15 | namespace mlx::core::distributed { 16 | 17 | void AllReduce::eval_gpu(const std::vector&, std::vector&) { 18 | throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation."); 19 | } 20 | 21 | void AllGather::eval_gpu(const std::vector&, std::vector&) { 22 | throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation."); 23 | } 24 | 25 | void Send::eval_gpu(const std::vector&, std::vector&) { 26 | throw std::runtime_error("[Send::eval_gpu] has no GPU implementation."); 27 | } 28 | 29 | void Recv::eval_gpu(const std::vector&, std::vector&) { 30 | throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation."); 31 | } 32 | 33 | } // namespace mlx::core::distributed 34 | -------------------------------------------------------------------------------- /mlx/backend/metal/jit/includes.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | namespace mlx::core::metal { 6 | 7 | const char* utils(); 8 | const char* binary_ops(); 9 | const char* unary_ops(); 10 | const char* ternary_ops(); 11 | const char* reduce_utils(); 12 | const char* gather(); 13 | const char* scatter(); 14 | 15 | const char* arange(); 16 | const char* unary(); 17 | const char* binary(); 18 | const char* binary_two(); 19 | const char* copy(); 20 | const char* fft(); 21 | const char* gather_axis(); 22 | const char* hadamard(); 23 | const char* logsumexp(); 24 | const char* quantized(); 25 | const char* ternary(); 26 | const char* scan(); 27 | const char* scatter_axis(); 28 | const char* softmax(); 29 | const char* sort(); 30 | const char* reduce(); 31 | 32 | const char* gemm(); 33 | const char* steel_gemm_fused(); 34 | const char* steel_gemm_masked(); 35 | const char* steel_gemm_splitk(); 36 | const char* steel_gemm_gather(); 37 | const char* conv(); 38 | const char* steel_conv(); 39 | const char* steel_conv_general(); 40 | const char* gemv_masked(); 41 | 42 | } // namespace mlx::core::metal 43 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/arange.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | template 3 | [[kernel]] void arange( 4 | constant const T& start, 5 | constant const T& step, 6 | device T* out, 7 | uint index [[thread_position_in_grid]]) { 8 | out[index] = start + index * step; 9 | } 10 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/arange.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | // clang-format off 4 | #include "mlx/backend/metal/kernels/utils.h" 5 | #include "mlx/backend/metal/kernels/arange.h" 6 | 7 | #define instantiate_arange(tname, type) \ 8 | instantiate_kernel("arange" #tname, arange, type) 9 | 10 | instantiate_arange(uint8, uint8_t) 11 | instantiate_arange(uint16, uint16_t) 12 | instantiate_arange(uint32, uint32_t) 13 | instantiate_arange(uint64, uint64_t) 14 | instantiate_arange(int8, int8_t) 15 | instantiate_arange(int16, int16_t) 16 | instantiate_arange(int32, int32_t) 17 | instantiate_arange(int64, int64_t) 18 | instantiate_arange(float16, half) 19 | instantiate_arange(float32, float) 20 | instantiate_arange(bfloat16, bfloat16_t) // clang-format on 21 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/defines.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #if defined __METAL__ || defined MLX_METAL_JIT 6 | #define MTL_CONST constant 7 | #else 8 | #define MTL_CONST 9 | #endif 10 | 11 | static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; 12 | static MTL_CONST constexpr int REDUCE_N_READS = 4; 13 | static MTL_CONST constexpr int REDUCE_N_WRITES = 4; 14 | static MTL_CONST constexpr int SOFTMAX_N_READS = 4; 15 | static MTL_CONST constexpr int RMS_N_READS = 4; 16 | static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; 17 | 18 | // Instantiate a templated kernel. 19 | // Extra args are used as template parameters: 20 | // e.g. instantiate_kernel(binary_int, binary, a, b) -> 21 | // [[host_name(binary_int)]] [kernel] binary 22 | #define instantiate_kernel(name, func, ...) 
\ 23 | template [[host_name( \ 24 | name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; 25 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/fence.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma METAL internals : enable 4 | 5 | #ifndef __METAL_MEMORY_SCOPE_SYSTEM__ 6 | #define __METAL_MEMORY_SCOPE_SYSTEM__ 3 7 | namespace metal { 8 | constexpr constant metal::thread_scope thread_scope_system = 9 | static_cast(__METAL_MEMORY_SCOPE_SYSTEM__); 10 | } 11 | #endif 12 | 13 | #include 14 | 15 | [[kernel]] void input_coherent( 16 | volatile coherent(system) device uint* input [[buffer(0)]], 17 | const constant uint& size [[buffer(1)]], 18 | uint index [[thread_position_in_grid]]) { 19 | if (index < size) { 20 | input[index] = input[index]; 21 | } 22 | metal::atomic_thread_fence( 23 | metal::mem_flags::mem_device, 24 | metal::memory_order_seq_cst, 25 | metal::thread_scope_system); 26 | } 27 | 28 | // single thread kernel to update timestamp 29 | [[kernel]] void fence_update( 30 | volatile coherent(system) device uint* timestamp [[buffer(0)]], 31 | constant uint& value [[buffer(1)]]) { 32 | timestamp[0] = value; 33 | metal::atomic_thread_fence( 34 | metal::mem_flags::mem_device, 35 | metal::memory_order_seq_cst, 36 | metal::thread_scope_system); 37 | } 38 | 39 | // single thread kernel to spin wait for timestamp value 40 | [[kernel]] void fence_wait( 41 | volatile coherent(system) device uint* timestamp [[buffer(0)]], 42 | constant uint& value [[buffer(1)]]) { 43 | while (1) { 44 | metal::atomic_thread_fence( 45 | metal::mem_flags::mem_device, 46 | metal::memory_order_seq_cst, 47 | metal::thread_scope_system); 48 | if (timestamp[0] >= value) { 49 | break; 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/gather_axis.h: 
-------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | template 6 | [[kernel]] void gather_axis( 7 | const device T* src [[buffer(0)]], 8 | const device IdxT* indices [[buffer(1)]], 9 | device T* out [[buffer(2)]], 10 | const constant int* shape [[buffer(3)]], 11 | const constant int64_t* src_strides [[buffer(4)]], 12 | const constant int64_t* idx_strides [[buffer(5)]], 13 | const constant size_t& ndim [[buffer(6)]], 14 | const constant int& axis [[buffer(7)]], 15 | const constant int& axis_size [[buffer(8)]], 16 | const constant size_t& src_ax_stride [[buffer(9)]], 17 | const constant size_t& idx_ax_stride [[buffer(10)]], 18 | uint3 index [[thread_position_in_grid]], 19 | uint3 grid_dim [[threads_per_grid]]) { 20 | LocT elem_idx = index.z * static_cast(grid_dim.x); 21 | LocT out_idx = elem_idx * grid_dim.y + index.x; 22 | 23 | LocT idx_loc = index.y * static_cast(idx_ax_stride); 24 | if (IdxC) { 25 | idx_loc += out_idx; 26 | } else { 27 | idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); 28 | } 29 | 30 | auto idx_val = indices[idx_loc]; 31 | if (is_signed_v) { 32 | idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val; 33 | } 34 | 35 | LocT src_idx = idx_val * static_cast(src_ax_stride); 36 | if (SrcC) { 37 | src_idx += elem_idx * axis_size + index.x; 38 | } else { 39 | src_idx += elem_to_loc(elem_idx + index.x, shape, src_strides, ndim); 40 | } 41 | 42 | out_idx += index.y * static_cast(grid_dim.x); 43 | out[out_idx] = src[src_idx]; 44 | } 45 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/indexing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | template 8 | struct Indices { 9 | const array buffers; 10 | const constant int* shapes; 11 | const constant int64_t* strides; 12 | const constant bool* row_contiguous; 13 | const int ndim; 14 | }; 15 | 16 | template 17 | METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { 18 | if (is_unsigned_v) { 19 | return idx; 20 | } else { 21 | return (idx < 0) ? idx + size : idx; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/jit/bf16.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | // clang-format off 4 | #define jit_if #if 5 | #define jit_else #else 6 | #define jit_endif #endif 7 | 8 | jit_if (__METAL_VERSION__ >= 310) 9 | 10 | #include "mlx/backend/metal/kernels/metal_3_1/bf16.h" 11 | 12 | jit_else 13 | 14 | #include "mlx/backend/metal/kernels/metal_3_0/bf16.h" 15 | 16 | jit_endif // clang-format on 17 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/logsumexp.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | using namespace metal; 7 | 8 | // clang-format off 9 | #include "mlx/backend/metal/kernels/utils.h" 10 | #include "mlx/backend/metal/kernels/logsumexp.h" 11 | 12 | #define instantiate_logsumexp(name, itype) \ 13 | instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \ 14 | instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \ 15 | 16 | instantiate_logsumexp(float32, float) 17 | instantiate_logsumexp(float16, half) 18 | instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on 19 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/metal_3_1/bf16.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | using namespace metal; 8 | 9 | typedef bfloat bfloat16_t; 10 | inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { 11 | return as_type(x); 12 | } 13 | 14 | inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { 15 | return as_type(x); 16 | } 17 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/reduce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "mlx/backend/metal/kernels/reduction/reduce_all.h" 3 | #include "mlx/backend/metal/kernels/reduction/reduce_col.h" 4 | #include "mlx/backend/metal/kernels/reduction/reduce_init.h" 5 | #include "mlx/backend/metal/kernels/reduction/reduce_row.h" 6 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/reduce_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/atomic.h" 6 | #include "mlx/backend/metal/kernels/reduction/ops.h" 7 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/reduction/reduce_init.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | template 4 | [[kernel]] void init_reduce( 5 | device T* out [[buffer(0)]], 6 | uint tid [[thread_position_in_grid]]) { 7 | out[tid] = Op::init; 8 | } 9 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/scatter_axis.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | template < 6 | typename T, 7 | typename IdxT, 8 | typename LocT, 9 | typename Op, 10 | bool UpdC, 11 | bool IdxC> 12 | [[kernel]] void scatter_axis( 13 | const device T* upd [[buffer(0)]], 14 | const device IdxT* indices [[buffer(1)]], 15 | device mlx_atomic* out [[buffer(2)]], 16 | const constant int* shape [[buffer(3)]], 17 | const constant int64_t* upd_strides [[buffer(4)]], 18 | const constant int64_t* idx_strides [[buffer(5)]], 19 | const constant size_t& ndim [[buffer(6)]], 20 | const constant int& axis [[buffer(7)]], 21 | const constant int& out_axis_size [[buffer(8)]], 22 | const constant size_t& upd_ax_stride [[buffer(9)]], 23 | const constant size_t& idx_ax_stride [[buffer(10)]], 24 | uint3 index [[thread_position_in_grid]], 25 | uint3 grid_dim [[threads_per_grid]]) { 26 | Op op; 27 | 28 | LocT elem_idx = index.z * static_cast(grid_dim.x); 29 | 30 | LocT idx_loc = index.y * static_cast(idx_ax_stride); 31 | if (IdxC) { 32 | idx_loc += elem_idx * grid_dim.y + index.x; 33 | } else { 34 | idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); 35 | } 36 | 37 | auto idx_val = indices[idx_loc]; 38 | if (is_signed_v) { 39 | idx_val = (idx_val 
< 0) ? idx_val + out_axis_size : idx_val; 40 | } 41 | 42 | LocT upd_idx = index.y * static_cast(upd_ax_stride); 43 | if (UpdC) { 44 | upd_idx += elem_idx * grid_dim.y + index.x; 45 | } else { 46 | upd_idx += elem_to_loc(elem_idx + index.x, shape, upd_strides, ndim); 47 | } 48 | 49 | LocT out_idx = elem_idx * static_cast(out_axis_size) + 50 | idx_val * grid_dim.x + index.x; 51 | op.atomic_update(out, upd[upd_idx], out_idx); 52 | } 53 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/softmax.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | using namespace metal; 7 | 8 | // clang-format off 9 | #include "mlx/backend/metal/kernels/utils.h" 10 | #include "mlx/backend/metal/kernels/softmax.h" 11 | 12 | #define instantiate_softmax(name, itype) \ 13 | instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \ 14 | instantiate_kernel("looped_softmax_" #name, softmax_looped, itype) 15 | 16 | #define instantiate_softmax_precise(name, itype) \ 17 | instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \ 18 | instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float) 19 | 20 | instantiate_softmax(float32, float) 21 | instantiate_softmax(float16, half) 22 | instantiate_softmax(bfloat16, bfloat16_t) 23 | instantiate_softmax_precise(float16, half) 24 | instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on 25 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2024-25 Apple Inc. 
2 | 3 | // clang-format off 4 | #include "mlx/backend/metal/kernels/utils.h" 5 | 6 | #include "mlx/backend/metal/kernels/steel/attn/attn.h" 7 | #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" 8 | 9 | #define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ 10 | instantiate_kernel( \ 11 | "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ 12 | "_wm" #wm "_wn" #wn "_mask" #mname, \ 13 | attention, dtype, bq, bk, bd, wm, wn, mtype, float) 14 | 15 | #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ 16 | instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ 17 | instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ 18 | instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) 19 | 20 | #define instantiate_attn_mask_helper(iname, itype) \ 21 | instantiate_attn_shapes_helper(iname, itype, iname, itype) \ 22 | instantiate_attn_shapes_helper(iname, itype, bool_, bool) 23 | 24 | instantiate_attn_mask_helper(float16, half); 25 | instantiate_attn_mask_helper(bfloat16, bfloat16_t); 26 | 27 | instantiate_attn_mask_helper(float32, float); 28 | // clang-format on 29 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/attn/params.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | /////////////////////////////////////////////////////////////////////////////// 6 | // Attn param classes 7 | /////////////////////////////////////////////////////////////////////////////// 8 | 9 | namespace mlx { 10 | namespace steel { 11 | 12 | struct AttnParams { 13 | int B; ///< Batch Size 14 | int H; ///< Heads 15 | int D; ///< Head Dim 16 | 17 | int qL; ///< Query Sequence Length 18 | int kL; ///< Key Sequence Length 19 | 20 | int gqa_factor; ///< Group Query factor 21 | float scale; ///< Attention scale 22 | 23 | int NQ; ///< Number of query blocks 24 | int NK; ///< Number of key/value blocks 25 | 26 | int NQ_aligned; ///< Number of full query blocks 27 | int NK_aligned; ///< Number of full key/value blocks 28 | 29 | int qL_rem; ///< Remainder in last query block 30 | int kL_rem; ///< Remainder in last key/value block 31 | int qL_off; ///< Offset in query sequence start 32 | 33 | int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) 34 | int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) 35 | int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) 36 | int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) 37 | }; 38 | 39 | struct AttnMaskParams { 40 | int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) 41 | }; 42 | 43 | } // namespace steel 44 | } // namespace mlx 45 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/attn/transforms.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/steel/utils.h" 6 | 7 | /////////////////////////////////////////////////////////////////////////////// 8 | // Transforms and Epilogues 9 | /////////////////////////////////////////////////////////////////////////////// 10 | 11 | namespace mlx { 12 | namespace steel { 13 | 14 | template 15 | struct TransformNone { 16 | static METAL_FUNC OutT apply(InT x) { 17 | return static_cast(x); 18 | } 19 | 20 | static METAL_FUNC OutT apply(InT x, OutT) { 21 | return static_cast(x); 22 | } 23 | }; 24 | 25 | template 26 | struct TransformAdd { 27 | TransformAdd(const float, const float) {} 28 | 29 | static METAL_FUNC OutT apply(InT x) { 30 | return static_cast(x); 31 | } 32 | 33 | static METAL_FUNC OutT apply(InT x, OutT c) { 34 | return static_cast(x) + c; 35 | } 36 | }; 37 | 38 | template 39 | struct TransformAxpby { 40 | const float alpha; 41 | const float beta; 42 | 43 | TransformAxpby(const float alpha_, const float beta_) 44 | : alpha(alpha_), beta(beta_) {} 45 | 46 | static METAL_FUNC OutT apply(InT x) { 47 | return static_cast(x); 48 | } 49 | 50 | METAL_FUNC OutT apply(InT x, OutT c) const { 51 | return static_cast(x * alpha + (beta * c)); 52 | } 53 | }; 54 | 55 | template 56 | struct AccumHelper { 57 | typedef float accum_type; 58 | }; 59 | 60 | struct BlockSwizzle { 61 | static METAL_FUNC int2 62 | swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { 63 | const int tid_x = (tid.x) >> swizzle_log; 64 | const int tid_y = 65 | ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); 66 | return int2(tid_x, tid_y); 67 | } 68 | }; 69 | 70 | } // namespace steel 71 | } // namespace mlx -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/conv/conv.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/steel/defines.h" 6 | #include "mlx/backend/metal/kernels/steel/utils.h" 7 | 8 | #include "mlx/backend/metal/kernels/steel/conv/loader.h" 9 | #include "mlx/backend/metal/kernels/steel/conv/params.h" 10 | #include "mlx/backend/metal/kernels/steel/gemm/mma.h" 11 | 12 | using namespace metal; 13 | using namespace mlx::steel; 14 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/conv/loader.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h" 6 | #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h" -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/conv/params.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | template 6 | struct MLXConvParams { 7 | const int N; // Batch size 8 | const int C; // In channels 9 | const int O; // Out channels 10 | const int iS[NDIM]; // Input spatial dim 11 | const int wS[NDIM]; // Weight spatial dim 12 | const int oS[NDIM]; // Output spatial dim 13 | const int str[NDIM]; // Kernel strides 14 | const int pad[NDIM]; // Input padding 15 | const int kdil[NDIM]; // Kernel dilation 16 | const int idil[NDIM]; // Input dilation 17 | const int64_t in_strides[NDIM + 2]; // In strides 18 | const int64_t wt_strides[NDIM + 2]; // Wt strides 19 | const int64_t out_strides[NDIM + 2]; // Out strides 20 | const int groups; // Input channel groups 21 | const bool flip; 22 | }; 23 | 24 | namespace mlx { 25 | namespace steel { 26 | 27 | struct ImplicitGemmConv2DParams { 28 | const int M; 29 | const int N; 30 | const int K; 31 | 32 | const int gemm_k_iterations; 33 | 34 | const int inp_jump_w; 35 | const int inp_jump_h; 36 | const int inp_jump_c; 37 | 38 | const int tiles_n; 39 | const int tiles_m; 40 | const int swizzle_log; 41 | }; 42 | 43 | struct Conv2DGeneralJumpParams { 44 | const int f_wgt_jump_h; 45 | const int f_wgt_jump_w; 46 | 47 | const int f_out_jump_h; 48 | const int f_out_jump_w; 49 | 50 | const int adj_out_h; 51 | const int adj_out_w; 52 | const int adj_out_hw; 53 | const int adj_implicit_m; 54 | }; 55 | 56 | struct Conv2DGeneralBaseInfo { 57 | int weight_base; 58 | int weight_size; 59 | }; 60 | 61 | } // namespace steel 62 | } // namespace mlx 63 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/defines.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #define STEEL_CONST static constant constexpr const 4 | #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") 5 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/gemm/params.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | /////////////////////////////////////////////////////////////////////////////// 6 | // GEMM param classes 7 | /////////////////////////////////////////////////////////////////////////////// 8 | 9 | namespace mlx { 10 | namespace steel { 11 | 12 | struct GEMMParams { 13 | const int M; 14 | const int N; 15 | const int K; 16 | 17 | const int lda; 18 | const int ldb; 19 | const int ldd; 20 | 21 | const int tiles_n; 22 | const int tiles_m; 23 | 24 | const int64_t batch_stride_a; 25 | const int64_t batch_stride_b; 26 | const int64_t batch_stride_d; 27 | 28 | const int swizzle_log; 29 | const int gemm_k_iterations_aligned; 30 | 31 | const int batch_ndim; 32 | }; 33 | 34 | struct GEMMSpiltKParams { 35 | const int M; 36 | const int N; 37 | const int K; 38 | 39 | const int lda; 40 | const int ldb; 41 | const int ldc; 42 | 43 | const int tiles_n; 44 | const int tiles_m; 45 | 46 | const int split_k_partitions; 47 | const int split_k_partition_stride; 48 | const int split_k_partition_size; 49 | 50 | const int gemm_k_iterations_aligned; 51 | }; 52 | 53 | struct GEMMAddMMParams { 54 | const int ldc; 55 | const int fdc; 56 | 57 | const int64_t batch_stride_c; 58 | 59 | const float alpha; 60 | const float beta; 61 | }; 62 | 63 | } // namespace steel 64 | } // namespace mlx 65 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/gemm/transforms.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/steel/utils.h" 6 | 7 | /////////////////////////////////////////////////////////////////////////////// 8 | // Transforms and Epilogues 9 | /////////////////////////////////////////////////////////////////////////////// 10 | 11 | namespace mlx { 12 | namespace steel { 13 | 14 | template 15 | struct TransformNone { 16 | static METAL_FUNC OutT apply(InT x) { 17 | return static_cast(x); 18 | } 19 | 20 | static METAL_FUNC OutT apply(InT x, OutT) { 21 | return static_cast(x); 22 | } 23 | }; 24 | 25 | template 26 | struct TransformAdd { 27 | TransformAdd(const float, const float) {} 28 | 29 | static METAL_FUNC OutT apply(InT x) { 30 | return static_cast(x); 31 | } 32 | 33 | static METAL_FUNC OutT apply(InT x, OutT c) { 34 | return static_cast(x) + c; 35 | } 36 | }; 37 | 38 | template 39 | struct TransformAxpby { 40 | const float alpha; 41 | const float beta; 42 | 43 | TransformAxpby(const float alpha_, const float beta_) 44 | : alpha(alpha_), beta(beta_) {} 45 | 46 | static METAL_FUNC OutT apply(InT x) { 47 | return static_cast(x); 48 | } 49 | 50 | METAL_FUNC OutT apply(InT x, OutT c) const { 51 | return static_cast(x * alpha + (beta * c)); 52 | } 53 | }; 54 | 55 | template 56 | struct AccumHelper { 57 | typedef float accum_type; 58 | }; 59 | 60 | struct BlockSwizzle { 61 | static METAL_FUNC int2 62 | swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { 63 | const int tid_x = (tid.x) >> swizzle_log; 64 | const int tid_y = 65 | ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); 66 | return int2(tid_x, tid_y); 67 | } 68 | }; 69 | 70 | } // namespace steel 71 | } // namespace mlx -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | METAL_FUNC ulong2 elem_to_loc_broadcast( 8 | uint elem, 9 | constant const int* shape, 10 | constant const int64_t* a_strides, 11 | constant const int64_t* b_strides, 12 | int ndim) { 13 | ulong loc_a{0}; 14 | ulong loc_b{0}; 15 | for (int i = ndim - 1; i >= 0 && elem > 0; --i) { 16 | int pos_in_dim = (elem % shape[i]); 17 | elem /= shape[i]; 18 | loc_a += pos_in_dim * a_strides[i]; 19 | loc_b += pos_in_dim * b_strides[i]; 20 | } 21 | return ulong2(loc_a, loc_b); 22 | } 23 | 24 | METAL_FUNC ulong3 elem_to_loc_broadcast( 25 | uint elem, 26 | constant const int* shape, 27 | constant const int64_t* a_strides, 28 | constant const int64_t* b_strides, 29 | constant const int64_t* c_strides, 30 | int ndim) { 31 | ulong loc_a{0}; 32 | ulong loc_b{0}; 33 | ulong loc_c{0}; 34 | for (int i = ndim - 1; i >= 0 && elem > 0; --i) { 35 | int pos_in_dim = (elem % shape[i]); 36 | elem /= shape[i]; 37 | loc_a += pos_in_dim * a_strides[i]; 38 | loc_b += pos_in_dim * b_strides[i]; 39 | loc_c += pos_in_dim * c_strides[i]; 40 | } 41 | return ulong3(loc_a, loc_b, loc_c); 42 | } 43 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/utils/type_traits.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #pragma METAL internals : enable 8 | 9 | namespace metal { 10 | 11 | template 12 | struct is_empty : metal::bool_constant<__is_empty(T)> {}; 13 | 14 | #ifdef __cpp_variable_templates 15 | template 16 | constexpr constant bool is_empty_v = is_empty::value; 17 | #endif 18 | 19 | template 20 | struct make_void { 21 | typedef void type; 22 | }; 23 | 24 | template 25 | using void_t = typename make_void::type; 26 | 27 | template 28 | struct is_static : metal::bool_constant>::value> {}; 29 | 30 | template 31 | struct pointer_element {}; 32 | 33 | template 34 | struct pointer_element { 35 | using type = remove_cv_t; 36 | }; 37 | template 38 | struct pointer_element { 39 | using type = remove_cv_t; 40 | }; 41 | template 42 | struct pointer_element { 43 | using type = remove_cv_t; 44 | }; 45 | template 46 | struct pointer_element { 47 | using type = remove_cv_t; 48 | }; 49 | 50 | template 51 | using pointer_element_t = typename pointer_element>::type; 52 | 53 | } // namespace metal 54 | 55 | #pragma METAL internals : disable -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/ternary.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | // clang-format off 7 | #include "mlx/backend/metal/kernels/utils.h" 8 | #include "mlx/backend/metal/kernels/ternary_ops.h" 9 | #include "mlx/backend/metal/kernels/ternary.h" 10 | 11 | #define instantiate_ternary_all(op, tname, type) \ 12 | instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ 13 | instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ 14 | instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \ 15 | instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ 16 | instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \ 17 | instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \ 18 | instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \ 19 | instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \ 20 | instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ 21 | instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ 22 | 23 | #define instantiate_ternary_types(op) \ 24 | instantiate_ternary_all(op, bool_, bool) \ 25 | instantiate_ternary_all(op, uint8, uint8_t) \ 26 | instantiate_ternary_all(op, uint16, uint16_t) \ 27 | instantiate_ternary_all(op, uint32, uint32_t) \ 28 | instantiate_ternary_all(op, uint64, uint64_t) \ 29 | instantiate_ternary_all(op, int8, int8_t) \ 30 | instantiate_ternary_all(op, int16, int16_t) \ 31 | instantiate_ternary_all(op, int32, int32_t) \ 32 | instantiate_ternary_all(op, int64, int64_t) \ 33 | instantiate_ternary_all(op, float16, half) \ 34 | instantiate_ternary_all(op, float32, float) \ 35 | instantiate_ternary_all(op, bfloat16, bfloat16_t) \ 36 | instantiate_ternary_all(op, complex64, complex64_t) // clang-format on 37 | 38 | instantiate_ternary_types(Select) 39 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/ternary_ops.h: -------------------------------------------------------------------------------- 1 | // 
Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | struct Select { 6 | template 7 | T operator()(bool condition, T x, T y) { 8 | return condition ? x : y; 9 | } 10 | }; 11 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/unary.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | template ::n> 4 | [[kernel]] void unary_v( 5 | device const T* in, 6 | device U* out, 7 | constant uint& size, 8 | uint index [[thread_position_in_grid]]) { 9 | index *= N; 10 | for (int i = 0; i < N && (index + i) < size; ++i) { 11 | out[index + i] = Op()(in[index + i]); 12 | } 13 | } 14 | 15 | template ::n> 16 | [[kernel]] void unary_v2( 17 | device const T* in, 18 | device U* out, 19 | constant int64_t& size, 20 | uint2 index [[thread_position_in_grid]], 21 | uint2 grid_dim [[threads_per_grid]]) { 22 | auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); 23 | for (int i = 0; i < N && (offset + i) < size; ++i) { 24 | out[offset + i] = Op()(in[offset + i]); 25 | } 26 | } 27 | 28 | template < 29 | typename T, 30 | typename U, 31 | typename Op, 32 | int N = 1, 33 | typename IdxT = int64_t> 34 | [[kernel]] void unary_g( 35 | device const T* in, 36 | device U* out, 37 | constant const int* in_shape, 38 | constant const int64_t* in_strides, 39 | device const int& ndim, 40 | uint3 index [[thread_position_in_grid]], 41 | uint3 grid_dim [[threads_per_grid]]) { 42 | auto idx = elem_to_loc( 43 | {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); 44 | auto xshape = in_shape[ndim - 1]; 45 | IdxT xstride = in_strides[ndim - 1]; 46 | IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); 47 | for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { 48 | out[out_idx++] = Op()(in[idx]); 49 | idx += xstride; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- 
/mlx/backend/metal/make_compiled_preamble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script generates a C++ function that provides the Metal unary and binary 4 | # ops at runtime for use with kernel generation. 5 | # 6 | # Copyright © 2023-24 Apple Inc. 7 | 8 | OUTPUT_DIR=$1 9 | CC=$2 10 | SRC_DIR=$3 11 | SRC_FILE=$4 12 | CFLAGS=$5 13 | SRC_NAME=$(basename -- "${SRC_FILE}") 14 | JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit 15 | INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h 16 | OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp 17 | 18 | mkdir -p "$OUTPUT_DIR" 19 | CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null) 20 | 21 | cat << EOF > "$OUTPUT_FILE" 22 | namespace mlx::core::metal { 23 | 24 | const char* $SRC_NAME() { 25 | return R"preamble( 26 | $CONTENT 27 | )preamble"; 28 | } 29 | 30 | } // namespace mlx::core::metal 31 | EOF 32 | -------------------------------------------------------------------------------- /mlx/backend/metal/matmul.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/device.h" 6 | 7 | namespace mlx::core { 8 | 9 | void steel_matmul_regular( 10 | const Stream& s, 11 | metal::Device& d, 12 | const array& a, 13 | const array& b, 14 | array& out, 15 | int M, 16 | int N, 17 | int K, 18 | int batch_size_out, 19 | int lda, 20 | int ldb, 21 | int ldd, 22 | bool transpose_a, 23 | bool transpose_b, 24 | Shape batch_shape, 25 | Strides batch_strides, 26 | int64_t A_batch_stride, 27 | int64_t B_batch_stride, 28 | int64_t matrix_stride_out, 29 | std::vector& copies); 30 | 31 | void steel_matmul( 32 | const Stream& s, 33 | metal::Device& d, 34 | const array& a, 35 | const array& b, 36 | array& out, 37 | int M, 38 | int N, 39 | int K, 40 | int batch_size_out, 41 | int lda, 42 | int ldb, 43 | bool transpose_a, 44 | bool transpose_b, 45 | std::vector& copies, 46 | Shape batch_shape = {}, 47 | Strides A_batch_stride = {}, 48 | Strides B_batch_stride = {}); 49 | 50 | } // namespace mlx::core 51 | -------------------------------------------------------------------------------- /mlx/backend/metal/metal.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace mlx::core::metal { 10 | 11 | /* Check if the Metal backend is available. */ 12 | bool is_available(); 13 | 14 | /** Capture a GPU trace, saving it to an absolute file `path` */ 15 | void start_capture(std::string path = ""); 16 | void stop_capture(); 17 | 18 | /** Get information about the GPU and system settings. */ 19 | const std::unordered_map>& 20 | device_info(); 21 | 22 | } // namespace mlx::core::metal 23 | -------------------------------------------------------------------------------- /mlx/backend/metal/no_metal.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/metal.h" 6 | 7 | namespace mlx::core::metal { 8 | 9 | bool is_available() { 10 | return false; 11 | } 12 | 13 | void start_capture(std::string) {} 14 | void stop_capture() {} 15 | 16 | const std::unordered_map>& 17 | device_info() { 18 | throw std::runtime_error( 19 | "[metal::device_info] Cannot get device info without metal backend"); 20 | }; 21 | 22 | } // namespace mlx::core::metal 23 | -------------------------------------------------------------------------------- /mlx/backend/metal/reduce.h: -------------------------------------------------------------------------------- 1 | // Copyright @ 2023 - 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/common/reduce.h" 6 | #include "mlx/backend/metal/device.h" 7 | #include "mlx/stream.h" 8 | 9 | namespace mlx::core { 10 | 11 | using metal::CommandEncoder; 12 | 13 | void all_reduce_dispatch( 14 | const array& in, 15 | array& out, 16 | const std::string& op_name, 17 | CommandEncoder& compute_encoder, 18 | metal::Device& d, 19 | const Stream& s); 20 | 21 | void row_reduce_general_dispatch( 22 | const array& in, 23 | array& out, 24 | const std::string& op_name, 25 | const ReductionPlan& plan, 26 | const std::vector& axes, 27 | CommandEncoder& compute_encoder, 28 | metal::Device& d, 29 | const Stream& s); 30 | 31 | void strided_reduce_general_dispatch( 32 | const array& in, 33 | array& out, 34 | const std::string& op_name, 35 | const ReductionPlan& plan, 36 | const std::vector& axes, 37 | CommandEncoder& compute_encoder, 38 | metal::Device& d, 39 | const Stream& s); 40 | 41 | } // namespace mlx::core 42 | -------------------------------------------------------------------------------- /mlx/backend/metal/resident.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/device.h" 6 | 7 | namespace mlx::core::metal { 8 | 9 | class ResidencySet { 10 | public: 11 | ResidencySet(MTL::Device* d); 12 | ~ResidencySet(); 13 | 14 | ResidencySet(const ResidencySet&) = delete; 15 | ResidencySet& operator=(const ResidencySet&) = delete; 16 | 17 | const MTL::ResidencySet* mtl_residency_set() { 18 | return wired_set_; 19 | } 20 | 21 | void insert(MTL::Allocation* buf); 22 | void erase(MTL::Allocation* buf); 23 | 24 | void resize(size_t size); 25 | 26 | private: 27 | MTL::ResidencySet* wired_set_{nullptr}; 28 | std::unordered_set unwired_set_; 29 | size_t capacity_{0}; 30 | }; 31 | 32 | } // namespace mlx::core::metal 33 | -------------------------------------------------------------------------------- /mlx/backend/metal/slicing.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/backend/gpu/copy.h" 6 | #include "mlx/backend/gpu/slicing.h" 7 | #include "mlx/backend/metal/device.h" 8 | 9 | namespace mlx::core { 10 | 11 | void concatenate_gpu( 12 | const std::vector& inputs, 13 | array& out, 14 | int axis, 15 | const Stream& s) { 16 | std::vector sizes; 17 | sizes.push_back(0); 18 | for (auto& p : inputs) { 19 | sizes.push_back(p.shape(axis)); 20 | } 21 | std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); 22 | 23 | out.set_data(allocator::malloc(out.nbytes())); 24 | 25 | auto strides = out.strides(); 26 | auto flags = out.flags(); 27 | flags.row_contiguous = false; 28 | flags.col_contiguous = false; 29 | flags.contiguous = false; 30 | auto& d = metal::device(s.device); 31 | auto& compute_encoder = d.get_command_encoder(s.index); 32 | auto concurrent_ctx = compute_encoder.start_concurrent(); 33 | for (int i = 0; i < inputs.size(); i++) { 34 | array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); 35 | size_t data_offset = strides[axis] * sizes[i]; 36 | 
out_slice.copy_shared_buffer( 37 | out, strides, flags, out_slice.size(), data_offset); 38 | copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s); 39 | } 40 | } 41 | 42 | } // namespace mlx::core 43 | -------------------------------------------------------------------------------- /mlx/backend/metal/ternary.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void ternary_op_gpu( 10 | const std::vector& inputs, 11 | array& out, 12 | const std::string op, 13 | const Stream& s); 14 | 15 | void ternary_op_gpu_inplace( 16 | const std::vector& inputs, 17 | array& out, 18 | const std::string op, 19 | const Stream& s); 20 | 21 | } // namespace mlx::core 22 | -------------------------------------------------------------------------------- /mlx/backend/metal/unary.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void unary_op_gpu( 10 | const std::vector& inputs, 11 | array& out, 12 | const std::string op, 13 | const Stream& s); 14 | 15 | void unary_op_gpu_inplace( 16 | const std::vector& inputs, 17 | array& out, 18 | const std::string op, 19 | const Stream& s); 20 | 21 | } // namespace mlx::core 22 | -------------------------------------------------------------------------------- /mlx/backend/no_cpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) 8 | -------------------------------------------------------------------------------- /mlx/backend/no_cpu/available.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cpu/available.h" 4 | 5 | namespace mlx::core::cpu { 6 | 7 | bool is_available() { 8 | return false; 9 | } 10 | 11 | } // namespace mlx::core::cpu 12 | -------------------------------------------------------------------------------- /mlx/backend/no_cpu/compiled.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include "mlx/compile_impl.h" 4 | #include "mlx/primitives.h" 5 | 6 | namespace mlx::core { 7 | 8 | // GPU compile is always available if the GPU is available and since we are in 9 | // this file CPU compile is not available so check if the device is a GPU 10 | // device. 
11 | namespace detail { 12 | bool compile_available_for_device(const Device& device) { 13 | return device == Device::gpu; 14 | } 15 | } // namespace detail 16 | 17 | void Compiled::eval_cpu( 18 | const std::vector& inputs, 19 | std::vector& outputs) { 20 | throw std::runtime_error( 21 | "[Compiled::eval_cpu] CPU compilation not supported on the platform."); 22 | } 23 | 24 | } // namespace mlx::core 25 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) 8 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/apple_memory.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace { 8 | 9 | size_t get_memory_size() { 10 | size_t memsize = 0; 11 | size_t length = sizeof(memsize); 12 | sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); 13 | return memsize; 14 | } 15 | 16 | } // namespace 17 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/eval.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/backend/gpu/available.h" 6 | #include "mlx/backend/gpu/eval.h" 7 | 8 | namespace mlx::core::gpu { 9 | 10 | bool is_available() { 11 | return false; 12 | } 13 | 14 | void new_stream(Stream) {} 15 | 16 | void eval(array&) { 17 | throw std::runtime_error("[gpu::eval] GPU backend is not available"); 18 | } 19 | 20 | void finalize(Stream) { 21 | throw std::runtime_error("[gpu::finalize] GPU backend is not available"); 22 | } 23 | 24 | void synchronize(Stream) { 25 | throw std::runtime_error("[gpu::synchronize] GPU backend is not available"); 26 | } 27 | 28 | } // namespace mlx::core::gpu 29 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/event.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/event.h" 4 | #include "mlx/scheduler.h" 5 | 6 | #include 7 | #include 8 | 9 | namespace mlx::core { 10 | 11 | struct EventCounter { 12 | uint64_t value{0}; 13 | std::mutex mtx; 14 | std::condition_variable cv; 15 | }; 16 | 17 | Event::Event(Stream stream) : stream_(stream) { 18 | auto dtor = [](void* ptr) { delete static_cast(ptr); }; 19 | event_ = std::shared_ptr(new EventCounter{}, dtor); 20 | } 21 | 22 | void Event::wait() { 23 | auto ec = static_cast(event_.get()); 24 | std::unique_lock lk(ec->mtx); 25 | if (ec->value >= value()) { 26 | return; 27 | } 28 | ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; }); 29 | } 30 | 31 | void Event::wait(Stream stream) { 32 | scheduler::enqueue(stream, [*this]() mutable { wait(); }); 33 | } 34 | 35 | void Event::signal(Stream stream) { 36 | scheduler::enqueue(stream, [*this]() mutable { 37 | auto ec = static_cast(event_.get()); 38 | { 39 | std::lock_guard lk(ec->mtx); 40 | ec->value = value(); 41 | } 42 | ec->cv.notify_all(); 43 | }); 44 | } 45 | 46 | bool Event::is_signaled() const { 47 | auto ec = static_cast(event_.get()); 48 | { 49 | 
std::lock_guard lk(ec->mtx); 50 | return (ec->value >= value()); 51 | } 52 | } 53 | } // namespace mlx::core 54 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/fence.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/fence.h" 7 | #include "mlx/scheduler.h" 8 | 9 | namespace mlx::core { 10 | 11 | struct FenceImpl { 12 | uint32_t count{0}; 13 | uint32_t value{0}; 14 | std::mutex mtx; 15 | std::condition_variable cv; 16 | }; 17 | 18 | Fence::Fence(Stream) { 19 | auto dtor = [](void* ptr) { delete static_cast(ptr); }; 20 | fence_ = std::shared_ptr(new FenceImpl{}, dtor); 21 | } 22 | 23 | void Fence::wait(Stream stream, const array&) { 24 | auto& f = *static_cast(fence_.get()); 25 | if (stream.device == Device::cpu) { 26 | scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { 27 | auto& f = *static_cast(fence_.get()); 28 | std::unique_lock lk(f.mtx); 29 | if (f.value >= count) { 30 | return; 31 | } 32 | f.cv.wait(lk, [&f, count] { return f.value >= count; }); 33 | }); 34 | } else { 35 | throw std::runtime_error("[Fence::wait] Invalid stream."); 36 | } 37 | } 38 | 39 | void Fence::update(Stream stream, const array&) { 40 | auto& f = *static_cast(fence_.get()); 41 | f.count++; 42 | if (stream.device == Device::cpu) { 43 | scheduler::enqueue(stream, [count = f.count, fence_ = fence_]() mutable { 44 | auto& f = *static_cast(fence_.get()); 45 | std::unique_lock lk(f.mtx); 46 | f.value = count; 47 | f.cv.notify_all(); 48 | }); 49 | } else { 50 | throw std::runtime_error("[Fence::update] Invalid stream."); 51 | } 52 | } 53 | 54 | } // namespace mlx::core 55 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/linux_memory.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple 
Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace { 8 | 9 | size_t get_memory_size() { 10 | struct sysinfo info; 11 | 12 | if (sysinfo(&info) != 0) { 13 | return 0; 14 | } 15 | 16 | size_t total_ram = info.totalram; 17 | total_ram *= info.mem_unit; 18 | 19 | return total_ram; 20 | } 21 | 22 | } // namespace 23 | -------------------------------------------------------------------------------- /mlx/compile.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; 10 | 11 | /** Compile takes a function and returns a compiled function. */ 12 | std::function(const std::vector&)> compile( 13 | std::function(const std::vector&)> fun, 14 | bool shapeless = false); 15 | 16 | std::function(const std::vector&)> compile( 17 | std::vector (*fun)(const std::vector&), 18 | bool shapeless = false); 19 | 20 | // Convert capture-less lambdas to function pointers. 21 | template < 22 | typename F, 23 | typename = std::enable_if_t< 24 | std::is_convertible_v())>>> 25 | std::function(const std::vector&)> compile( 26 | F&& f, 27 | bool shapeless = false) { 28 | return compile(+f, shapeless); 29 | } 30 | 31 | /** Globally disable compilation. 32 | * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also 33 | * be used to disable compilation. 34 | */ 35 | void disable_compile(); 36 | 37 | /** Globally enable compilation. 38 | * This will override the environment variable ``MLX_DISABLE_COMPILE``. 39 | */ 40 | void enable_compile(); 41 | 42 | /** Set the compiler mode to the given value. 
*/ 43 | void set_compile_mode(CompileMode mode); 44 | } // namespace mlx::core 45 | -------------------------------------------------------------------------------- /mlx/compile_impl.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/array.h" 8 | 9 | namespace mlx::core::detail { 10 | 11 | // This is not part of the general C++ API as calling with a bad id is a bad 12 | // idea. 13 | std::function(const std::vector&)> compile( 14 | std::function(const std::vector&)> fun, 15 | std::uintptr_t fun_id, 16 | bool shapeless = false, 17 | std::vector constants = {}); 18 | 19 | // Erase cached compile functions 20 | void compile_erase(std::uintptr_t fun_id); 21 | 22 | // Clear the compiler cache causing a recompilation of all compiled functions 23 | // when called again. 24 | void compile_clear_cache(); 25 | 26 | bool compile_available_for_device(const Device& device); 27 | 28 | std::pair, std::vector> compile_trace( 29 | const std::function(const std::vector&)>& fun, 30 | const std::vector& inputs, 31 | bool shapeless); 32 | 33 | using ParentsMap = 34 | std::unordered_map>>; 35 | 36 | // Traverses the graph to build a tape and a map of array ids to their parents 37 | std::pair, ParentsMap> compile_dfs( 38 | const std::vector& inputs, 39 | const std::vector& outputs, 40 | const std::vector& original_inputs); 41 | 42 | // Simplify the tape. 
43 | void compile_simplify( 44 | std::vector& tape, 45 | ParentsMap& parents_map, 46 | std::vector& outputs, 47 | int passes); 48 | 49 | std::vector compile_replace( 50 | const std::vector& tape, 51 | const std::vector& trace_inputs, 52 | const std::vector& trace_outputs, 53 | const std::vector& inputs, 54 | bool shapeless); 55 | 56 | void compile_validate_shapeless(const std::vector& tape); 57 | 58 | } // namespace mlx::core::detail 59 | -------------------------------------------------------------------------------- /mlx/device.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/backend/cpu/available.h" 6 | #include "mlx/backend/gpu/available.h" 7 | #include "mlx/device.h" 8 | 9 | namespace mlx::core { 10 | 11 | Device& mutable_default_device() { 12 | static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu}; 13 | return default_device; 14 | } 15 | 16 | const Device& default_device() { 17 | return mutable_default_device(); 18 | } 19 | 20 | void set_default_device(const Device& d) { 21 | if (!gpu::is_available() && d == Device::gpu) { 22 | throw std::invalid_argument( 23 | "[set_default_device] Cannot set gpu device without gpu backend."); 24 | } 25 | mutable_default_device() = d; 26 | } 27 | 28 | bool operator==(const Device& lhs, const Device& rhs) { 29 | return lhs.type == rhs.type && lhs.index == rhs.index; 30 | } 31 | 32 | bool operator!=(const Device& lhs, const Device& rhs) { 33 | return !(lhs == rhs); 34 | } 35 | 36 | bool is_available(const Device& d) { 37 | switch (d.type) { 38 | case Device::cpu: 39 | return cpu::is_available(); 40 | case Device::gpu: 41 | return gpu::is_available(); 42 | } 43 | // appease compiler 44 | return false; 45 | } 46 | 47 | } // namespace mlx::core 48 | -------------------------------------------------------------------------------- /mlx/device.h: 
-------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core { 6 | 7 | struct Device { 8 | enum class DeviceType { 9 | cpu, 10 | gpu, 11 | }; 12 | 13 | static constexpr DeviceType cpu = DeviceType::cpu; 14 | static constexpr DeviceType gpu = DeviceType::gpu; 15 | 16 | Device(DeviceType type, int index = 0) : type(type), index(index) {} 17 | 18 | DeviceType type; 19 | int index; 20 | }; 21 | 22 | const Device& default_device(); 23 | 24 | void set_default_device(const Device& d); 25 | 26 | bool operator==(const Device& lhs, const Device& rhs); 27 | bool operator!=(const Device& lhs, const Device& rhs); 28 | 29 | bool is_available(const Device& d); 30 | 31 | } // namespace mlx::core 32 | -------------------------------------------------------------------------------- /mlx/distributed/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) 6 | 7 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) 8 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring) 9 | -------------------------------------------------------------------------------- /mlx/distributed/distributed.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/array.h" 8 | 9 | namespace mlx::core::distributed { 10 | 11 | // Forward declaration of the base group implementation. 12 | namespace detail { 13 | class GroupImpl; 14 | }; 15 | 16 | /* Check if a communication backend is available */ 17 | bool is_available(); 18 | 19 | /** 20 | * A distributed::Group represents a group of independent mlx processes that 21 | * can communicate. 
We must also be able to create sub-groups from a group in 22 | * order to define more granular communication. 23 | */ 24 | struct Group { 25 | Group(std::shared_ptr group) : group_(std::move(group)) {} 26 | 27 | int rank() const; 28 | int size() const; 29 | 30 | /** 31 | * Split the group according to the provided color. Namely processes that use 32 | * the same color will go to the same group. 33 | * 34 | * The key defines the rank of the processes in the new group. The smaller 35 | * the key the smaller the rank. If the provided key is negative, then the 36 | * rank in the current group is used. 37 | */ 38 | Group split(int color, int key = -1) const; 39 | 40 | const std::shared_ptr& raw_group() const { 41 | return group_; 42 | } 43 | 44 | private: 45 | std::shared_ptr group_{nullptr}; 46 | }; 47 | 48 | /** 49 | * Initialize the distributed backend and return the group containing all 50 | * discoverable processes. 51 | * 52 | * If strict is true then throw an error if we couldn't initialize the 53 | * distributed subsystem. Otherwise simply return a singleton group which will 54 | * render communication operations as no-op. 55 | */ 56 | Group init(bool strict = false, const std::string& bk = "any"); 57 | 58 | } // namespace mlx::core::distributed 59 | -------------------------------------------------------------------------------- /mlx/distributed/distributed_impl.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/distributed/distributed.h" 6 | 7 | namespace mlx::core::distributed::detail { 8 | 9 | /** 10 | * Abstract base class of a distributed group implementation. 
11 | */ 12 | class GroupImpl { 13 | public: 14 | virtual ~GroupImpl() {} 15 | 16 | virtual int rank() = 0; 17 | virtual int size() = 0; 18 | virtual std::shared_ptr split(int color, int key = -1) = 0; 19 | 20 | virtual void all_sum(const array& input, array& output, Stream stream) = 0; 21 | virtual void all_gather(const array& input, array& output, Stream stream) = 0; 22 | virtual void send(const array& input, int dst, Stream stream) = 0; 23 | virtual void recv(array& out, int src, Stream stream) = 0; 24 | virtual void all_max(const array& input, array& output, Stream stream) = 0; 25 | virtual void all_min(const array& input, array& output, Stream stream) = 0; 26 | }; 27 | 28 | /* Perform an all reduce sum operation */ 29 | void all_sum(Group group, const array& input, array& output, Stream stream); 30 | 31 | /* Perform an all gather operation */ 32 | void all_gather(Group group, const array& input, array& output, Stream stream); 33 | 34 | /** Send an array to the dst rank */ 35 | void send(Group group, const array& input, int dst, Stream stream); 36 | 37 | /** Recv an array from the src rank */ 38 | void recv(Group group, array& out, int src, Stream stream); 39 | 40 | /** Max reduction */ 41 | void all_max(Group group, const array& input, array& output, Stream stream); 42 | 43 | /** Min reduction */ 44 | void all_min(Group group, const array& input, array& output, Stream stream); 45 | 46 | } // namespace mlx::core::distributed::detail 47 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MLX_BUILD_CPU) 2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) 3 | else() 4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) 5 | endif() 6 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/mpi.h: 
-------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/distributed/distributed.h" 4 | 5 | namespace mlx::core::distributed::mpi { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available(); 10 | std::shared_ptr init(bool strict = false); 11 | 12 | } // namespace mlx::core::distributed::mpi 13 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/mpi_declarations.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | // Constants 4 | 5 | #define MPI_SUCCESS 0 6 | #define MPI_ANY_SOURCE -1 7 | #define MPI_ANY_TAG -1 8 | #define MPI_IN_PLACE ((void*)1) 9 | #define MPI_MAX_LIBRARY_VERSION_STRING 256 10 | 11 | // Define all the types that we use so that we don't include which 12 | // causes linker errors on some platforms. 13 | // 14 | // NOTE: We define everything for openmpi. 15 | 16 | typedef void* MPI_Comm; 17 | typedef void* MPI_Datatype; 18 | typedef void* MPI_Op; 19 | 20 | typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*); 21 | 22 | typedef struct ompi_status_public_t { 23 | int MPI_SOURCE; 24 | int MPI_TAG; 25 | int MPI_ERROR; 26 | int _cancelled; 27 | size_t _ucount; 28 | } MPI_Status; 29 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/no_mpi.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include "mlx/distributed/mpi/mpi.h" 4 | 5 | namespace mlx::core::distributed::mpi { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available() { 10 | return false; 11 | } 12 | 13 | std::shared_ptr init(bool strict /* = false */) { 14 | if (strict) { 15 | throw std::runtime_error("Cannot initialize MPI"); 16 | } 17 | return nullptr; 18 | } 19 | 20 | } // namespace mlx::core::distributed::mpi 21 | -------------------------------------------------------------------------------- /mlx/distributed/ops.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/distributed/distributed.h" 8 | #include "mlx/utils.h" 9 | 10 | namespace mlx::core::distributed { 11 | 12 | array all_sum( 13 | const array& x, 14 | std::optional group = std::nullopt, 15 | StreamOrDevice s = {}); 16 | 17 | array all_gather( 18 | const array& x, 19 | std::optional group = std::nullopt, 20 | StreamOrDevice S = {}); 21 | 22 | array send( 23 | const array& x, 24 | int dst, 25 | std::optional group = std::nullopt, 26 | StreamOrDevice s = {}); 27 | 28 | array recv( 29 | Shape shape, 30 | Dtype dtype, 31 | int src, 32 | std::optional group = std::nullopt, 33 | StreamOrDevice s = {}); 34 | 35 | array recv_like( 36 | const array& x, 37 | int src, 38 | std::optional group = std::nullopt, 39 | StreamOrDevice s = {}); 40 | 41 | array all_max( 42 | const array& x, 43 | std::optional group = std::nullopt, 44 | StreamOrDevice s = {}); 45 | 46 | array all_min( 47 | const array& x, 48 | std::optional group = std::nullopt, 49 | StreamOrDevice s = {}); 50 | 51 | } // namespace mlx::core::distributed 52 | -------------------------------------------------------------------------------- /mlx/distributed/ring/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MLX_BUILD_CPU AND NOT WIN32) 2 | 
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp) 3 | else() 4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp) 5 | endif() 6 | -------------------------------------------------------------------------------- /mlx/distributed/ring/no_ring.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/distributed/ring/ring.h" 4 | 5 | namespace mlx::core::distributed::ring { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available() { 10 | return false; 11 | } 12 | 13 | std::shared_ptr init(bool strict /* = false */) { 14 | if (strict) { 15 | throw std::runtime_error("Cannot initialize ring distributed backend."); 16 | } 17 | return nullptr; 18 | } 19 | 20 | } // namespace mlx::core::distributed::ring 21 | -------------------------------------------------------------------------------- /mlx/distributed/ring/ring.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/distributed/distributed.h" 4 | 5 | namespace mlx::core::distributed::ring { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available(); 10 | std::shared_ptr init(bool strict = false); 11 | 12 | } // namespace mlx::core::distributed::ring 13 | -------------------------------------------------------------------------------- /mlx/dtype_utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/dtype_utils.h" 4 | 5 | namespace mlx::core { 6 | 7 | const char* dtype_to_string(Dtype arg) { 8 | if (arg == bool_) { 9 | return "bool"; 10 | } 11 | #define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ 12 | if (DTYPE == arg) { \ 13 | return #DTYPE; \ 14 | } 15 | MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) 16 | #undef SPECIALIZE_DtypeToString 17 | return "(unknown)"; 18 | } 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /mlx/einsum.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | #pragma once 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "mlx/array.h" 9 | #include "mlx/utils.h" 10 | 11 | namespace mlx::core { 12 | 13 | std::pair>, std::string> einsum_path( 14 | const std::string& subscripts, 15 | const std::vector& operands); 16 | 17 | array einsum( 18 | const std::string& subscripts, 19 | const std::vector& operands, 20 | StreamOrDevice s = {}); 21 | 22 | } // namespace mlx::core 23 | -------------------------------------------------------------------------------- /mlx/event.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | #pragma once 3 | 4 | #include 5 | #include 6 | 7 | #include "mlx/stream.h" 8 | 9 | namespace mlx::core { 10 | 11 | class Event { 12 | public: 13 | Event() {}; 14 | explicit Event(Stream stream); 15 | 16 | // Wait for the event to be signaled at its current value 17 | void wait(); 18 | 19 | // Wait in the given stream for the event to be signaled at its current value 20 | void wait(Stream stream); 21 | 22 | // Signal the event at its current value in the given stream 23 | void signal(Stream stream); 24 | 25 | // Check if the event has been signaled at its current value 26 | bool is_signaled() const; 27 | 28 | // Check if the event is valid 29 | bool valid() const { 30 | return event_ != nullptr; 31 | } 32 | 33 | uint64_t value() const { 34 | return value_; 35 | } 36 | 37 | void set_value(uint64_t v) { 38 | value_ = v; 39 | } 40 | 41 | const Stream& stream() const { 42 | if (!valid()) { 43 | throw std::runtime_error( 44 | "[Event::stream] Cannot access stream on invalid event."); 45 | } 46 | return stream_; 47 | } 48 | 49 | private: 50 | // Default constructed stream should never be used 51 | // since the event is not yet valid 52 | Stream stream_{0, Device::cpu}; 53 | std::shared_ptr event_{nullptr}; 54 | uint64_t value_{0}; 55 | }; 56 | 57 | } // namespace mlx::core 58 | -------------------------------------------------------------------------------- /mlx/export.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include "mlx/array.h" 8 | 9 | namespace mlx::core { 10 | 11 | using Args = std::vector; 12 | using Kwargs = std::unordered_map; 13 | 14 | struct FunctionExporter; 15 | 16 | /** 17 | * Make an exporter to save multiple traces of a given function to 18 | * the same file. 
19 | */ 20 | FunctionExporter exporter( 21 | const std::string& file, 22 | const std::function(const Args&)>& fun, 23 | bool shapeless = false); 24 | 25 | FunctionExporter exporter( 26 | const std::string& file, 27 | const std::function(const Kwargs&)>& fun, 28 | bool shapeless = false); 29 | 30 | FunctionExporter exporter( 31 | const std::string& file, 32 | const std::function(const Args&, const Kwargs&)>& fun, 33 | bool shapeless = false); 34 | 35 | /** 36 | * Export a function to a file. 37 | */ 38 | void export_function( 39 | const std::string& file, 40 | const std::function(const Args&)>& fun, 41 | const Args& args, 42 | bool shapeless = false); 43 | 44 | void export_function( 45 | const std::string& file, 46 | const std::function(const Kwargs&)>& fun, 47 | const Kwargs& kwargs, 48 | bool shapeless = false); 49 | 50 | void export_function( 51 | const std::string& file, 52 | const std::function(const Args&, const Kwargs&)>& fun, 53 | const Args& args, 54 | const Kwargs& kwargs, 55 | bool shapeless = false); 56 | 57 | struct ImportedFunction; 58 | 59 | /** 60 | * Import a function from a file. 61 | */ 62 | ImportedFunction import_function(const std::string& file); 63 | 64 | } // namespace mlx::core 65 | 66 | #include "mlx/export_impl.h" 67 | -------------------------------------------------------------------------------- /mlx/fence.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | /* A fence to be used for synchronizing work between streams. 10 | * 11 | * Calls to `wait` wait in the given stream until all previous calls to update 12 | * are complete on their given stream. 13 | * 14 | * The array passed to `update` is computed and visible after the call to 15 | * `wait` returns. The array passed to `wait` will not be read until all 16 | * previous calls to `update` have completed. 
17 | * 18 | * Note, calls to `update` should always be made from the same thread or be explicitly 19 | * synchronized so that they occur in sequence. Calls to `wait` can be on any 20 | * thread. 21 | * 22 | * For the Metal back-end the fence supports slow (default) and fast mode. 23 | * Fast mode requires setting the environment variable 24 | * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+, 25 | * iOS 18+). 26 | */ 27 | class Fence { 28 | public: 29 | Fence() {}; 30 | explicit Fence(Stream stream); 31 | 32 | void update(Stream stream, const array& x); 33 | void wait(Stream stream, const array& x); 34 | 35 | private: 36 | std::shared_ptr fence_{nullptr}; 37 | }; 38 | 39 | } // namespace mlx::core 40 | -------------------------------------------------------------------------------- /mlx/io/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp) 2 | 3 | if(MLX_BUILD_SAFETENSORS) 4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp) 5 | else() 6 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp) 7 | endif() 8 | 9 | if(MLX_BUILD_GGUF) 10 | message(STATUS "Downloading gguflib") 11 | FetchContent_Declare( 12 | gguflib 13 | GIT_REPOSITORY https://github.com/antirez/gguf-tools/ 14 | GIT_TAG af7d88d808a7608a33723fba067036202910acb3) 15 | FetchContent_MakeAvailable(gguflib) 16 | target_include_directories(mlx 17 | PRIVATE $) 18 | add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c 19 | ${gguflib_SOURCE_DIR}/gguflib.c) 20 | target_link_libraries(mlx PRIVATE $) 21 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp 22 | ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp) 23 | else() 24 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp) 25 | endif() 26 | -------------------------------------------------------------------------------- /mlx/io/gguf.h: 
-------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | #pragma once 3 | 4 | #include "mlx/io.h" 5 | #include "mlx/primitives.h" 6 | #include "mlx/transforms.h" 7 | #include "mlx/utils.h" 8 | 9 | extern "C" { 10 | #include 11 | } 12 | 13 | namespace mlx::core { 14 | 15 | Shape get_shape(const gguf_tensor& tensor); 16 | void gguf_load_quantized( 17 | std::unordered_map& a, 18 | const gguf_tensor& tensor); 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /mlx/io/no_gguf.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include "mlx/io.h" 4 | 5 | namespace mlx::core { 6 | 7 | GGUFLoad load_gguf(const std::string&, StreamOrDevice s) { 8 | throw std::runtime_error( 9 | "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support."); 10 | } 11 | 12 | void save_gguf( 13 | std::string, 14 | std::unordered_map, 15 | std::unordered_map) { 16 | throw std::runtime_error( 17 | "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support."); 18 | } 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /mlx/io/no_safetensors.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include "mlx/io.h" 4 | 5 | namespace mlx::core { 6 | 7 | SafetensorsLoad load_safetensors(std::shared_ptr, StreamOrDevice) { 8 | throw std::runtime_error( 9 | "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 10 | "to enable safetensors support."); 11 | } 12 | 13 | SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) { 14 | throw std::runtime_error( 15 | "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 16 | "to enable safetensors support."); 17 | } 18 | 19 | void save_safetensors( 20 | std::shared_ptr, 21 | std::unordered_map, 22 | std::unordered_map) { 23 | throw std::runtime_error( 24 | "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 25 | "to enable safetensors support."); 26 | } 27 | 28 | void save_safetensors( 29 | std::string file, 30 | std::unordered_map, 31 | std::unordered_map) { 32 | throw std::runtime_error( 33 | "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 34 | "to enable safetensors support."); 35 | } 36 | 37 | } // namespace mlx::core 38 | -------------------------------------------------------------------------------- /mlx/mlx.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/backend/metal/metal.h" 7 | #include "mlx/compile.h" 8 | #include "mlx/device.h" 9 | #include "mlx/distributed/distributed.h" 10 | #include "mlx/distributed/ops.h" 11 | #include "mlx/einsum.h" 12 | #include "mlx/export.h" 13 | #include "mlx/fast.h" 14 | #include "mlx/fft.h" 15 | #include "mlx/io.h" 16 | #include "mlx/linalg.h" 17 | #include "mlx/memory.h" 18 | #include "mlx/ops.h" 19 | #include "mlx/random.h" 20 | #include "mlx/stream.h" 21 | #include "mlx/transforms.h" 22 | #include "mlx/utils.h" 23 | #include "mlx/version.h" 24 | -------------------------------------------------------------------------------- /mlx/stream.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/device.h" 6 | 7 | namespace mlx::core { 8 | 9 | struct Stream { 10 | int index; 11 | Device device; 12 | explicit Stream(int index, Device device) : index(index), device(device) {} 13 | }; 14 | 15 | /** Get the default stream for the given device. */ 16 | Stream default_stream(Device d); 17 | 18 | /** Make the stream the default for its device. */ 19 | void set_default_stream(Stream s); 20 | 21 | /** Make a new stream on the given device. */ 22 | Stream new_stream(Device d); 23 | 24 | /** Get the stream with the given index. */ 25 | Stream get_stream(int index); 26 | 27 | inline bool operator==(const Stream& lhs, const Stream& rhs) { 28 | return lhs.index == rhs.index; 29 | } 30 | 31 | inline bool operator!=(const Stream& lhs, const Stream& rhs) { 32 | return !(lhs == rhs); 33 | } 34 | 35 | /* Synchronize with the default stream. */ 36 | void synchronize(); 37 | 38 | /* Synchronize with the provided stream. 
*/ 39 | void synchronize(Stream); 40 | 41 | } // namespace mlx::core 42 | -------------------------------------------------------------------------------- /mlx/types/half_types.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC 6 | 7 | #include 8 | namespace mlx::core { 9 | using ::float16_t; 10 | } // namespace mlx::core 11 | 12 | #else 13 | 14 | #define ADD_HALF_BINOPS 15 | #include "mlx/types/fp16.h" 16 | namespace mlx::core { 17 | typedef struct _MLX_Float16 float16_t; 18 | } // namespace mlx::core 19 | 20 | #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC 21 | 22 | #ifdef __ARM_FEATURE_BF16 23 | 24 | #include 25 | namespace mlx::core { 26 | using ::bfloat16_t; 27 | } // namespace mlx::core 28 | 29 | #else 30 | 31 | #define ADD_HALF_BINOPS 32 | #include "mlx/types/bf16.h" 33 | namespace mlx::core { 34 | typedef struct _MLX_BFloat16 bfloat16_t; 35 | } // namespace mlx::core 36 | 37 | #endif // __ARM_FEATURE_BF16 38 | 39 | #ifdef ADD_HALF_BINOPS 40 | namespace mlx::core { 41 | 42 | // clang-format off 43 | #define fp16_bf16_binop_helper(__op__, __operator__) \ 44 | inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ 45 | return static_cast(lhs) __op__ static_cast(rhs); \ 46 | } \ 47 | inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ 48 | return static_cast(lhs) __op__ static_cast(rhs); \ 49 | } 50 | 51 | fp16_bf16_binop_helper(+, operator+) 52 | fp16_bf16_binop_helper(-, operator-) 53 | fp16_bf16_binop_helper(*, operator*) 54 | fp16_bf16_binop_helper(/, operator/) 55 | // clang-format on 56 | 57 | } // namespace mlx::core 58 | #endif 59 | -------------------------------------------------------------------------------- /mlx/types/limits.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | #pragma once 3 | 4 | #include 5 | #include "mlx/types/half_types.h" 6 | 7 | namespace mlx::core { 8 | 9 | template 10 | struct numeric_limits; 11 | 12 | template <> 13 | struct numeric_limits : public std::numeric_limits {}; 14 | 15 | template <> 16 | struct numeric_limits : public std::numeric_limits {}; 17 | 18 | template <> 19 | struct numeric_limits { 20 | private: 21 | union half_or_bits { 22 | uint16_t bits; 23 | float16_t value; 24 | }; 25 | constexpr static float16_t bits_to_half(uint16_t v) { 26 | return half_or_bits{v}.value; 27 | } 28 | 29 | public: 30 | constexpr static float16_t lowest() { 31 | return bits_to_half(0xFBFF); 32 | } 33 | static constexpr float16_t max() { 34 | return bits_to_half(0x7BFF); 35 | } 36 | static constexpr float16_t epsilon() { 37 | return bits_to_half(0x1400); 38 | } 39 | static constexpr float16_t infinity() { 40 | return bits_to_half(0x7C00); 41 | } 42 | }; 43 | 44 | template <> 45 | struct numeric_limits { 46 | private: 47 | union bfloat_or_bits { 48 | uint16_t bits; 49 | bfloat16_t value; 50 | }; 51 | constexpr static bfloat16_t bits_to_bfloat(uint16_t v) { 52 | return bfloat_or_bits{v}.value; 53 | } 54 | 55 | public: 56 | constexpr static bfloat16_t lowest() { 57 | return bits_to_bfloat(0xFF7F); 58 | } 59 | static constexpr bfloat16_t max() { 60 | return bits_to_bfloat(0x7F7F); 61 | } 62 | static constexpr bfloat16_t epsilon() { 63 | return bits_to_bfloat(0x3C00); 64 | } 65 | static constexpr bfloat16_t infinity() { 66 | return bits_to_bfloat(0x7F80); 67 | } 68 | }; 69 | 70 | } // namespace mlx::core 71 | -------------------------------------------------------------------------------- /mlx/version.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include 4 | 5 | namespace mlx::core { 6 | 7 | std::string version() { 8 | return MLX_VERSION; 9 | } 10 | 11 | } // namespace mlx::core 12 | -------------------------------------------------------------------------------- /mlx/version.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #define MLX_VERSION_MAJOR 0 6 | #define MLX_VERSION_MINOR 25 7 | #define MLX_VERSION_PATCH 2 8 | #define MLX_VERSION_NUMERIC \ 9 | (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) 10 | 11 | namespace mlx::core { 12 | 13 | /* A string representation of the MLX version in the format 14 | * "major.minor.patch". 15 | * 16 | * For dev builds, the version will include the suffix ".devYYYYMMDD+hash" 17 | */ 18 | std::string version(); 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "nanobind==2.4.0", 5 | "cmake>=3.25", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | -------------------------------------------------------------------------------- /python/mlx/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def main() -> None: 5 | from mlx.core import __version__ 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--version", 10 | action="version", 11 | version=__version__, 12 | help="Print the version number.", 13 | ) 14 | parser.add_argument( 15 | "--cmake-dir", 16 | action="store_true", 17 | help="Print the path to the MLX CMake module directory.", 18 | ) 19 | args = parser.parse_args() 20 | if args.cmake_dir: 21 | from pathlib import Path 22 | 23 | print(Path(__file__).parent) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 
28 | -------------------------------------------------------------------------------- /python/mlx/_os_warning.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | if platform.system() == "Darwin": 4 | version = tuple(map(int, platform.mac_ver()[0].split("."))) 5 | major, minor = version[0], version[1] 6 | if (major, minor) < (13, 5): 7 | raise ImportError( 8 | f"Only macOS 13.5 and newer are supported, not {major}.{minor}" 9 | ) 10 | -------------------------------------------------------------------------------- /python/mlx/_reprlib_fix.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import array 4 | import reprlib 5 | 6 | _old_repr_array = reprlib.Repr.repr_array 7 | 8 | 9 | def repr_array(self, x, maxlevel): 10 | if isinstance(x, array.array): 11 | return _old_repr_array(self, x, maxlevel) 12 | else: 13 | return self.repr_instance(x, maxlevel) 14 | 15 | 16 | reprlib.Repr.repr_array = repr_array 17 | -------------------------------------------------------------------------------- /python/mlx/_stub_patterns.txt: -------------------------------------------------------------------------------- 1 | mlx.core.distributed.__prefix__: 2 | from mlx.core import array, Dtype, Device, Stream 3 | from mlx.core.distributed import Group 4 | from typing import Sequence, Optional, Union 5 | 6 | mlx.core.fast.__prefix__: 7 | from mlx.core import array, Dtype, Device, Stream 8 | from typing import Sequence, Optional, Union 9 | 10 | mlx.core.linalg.__prefix__: 11 | from mlx.core import array, Dtype, Device, Stream 12 | from typing import Sequence, Optional, Tuple, Union 13 | 14 | mlx.core.metal.__prefix__: 15 | from mlx.core import array, Dtype, Device, Stream 16 | from typing import Sequence, Optional, Union 17 | 18 | mlx.core.random.__prefix__: 19 | from mlx.core import array, Dtype, Device, Stream 20 | from typing import Sequence, Optional, 
Union 21 | -------------------------------------------------------------------------------- /python/mlx/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from mlx.nn import init, losses 4 | from mlx.nn.layers import * 5 | from mlx.nn.utils import average_gradients, value_and_grad 6 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/containers.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from mlx.nn.layers.base import Module 4 | 5 | 6 | class Sequential(Module): 7 | """A layer that calls the passed callables in order. 8 | 9 | We can pass either modules or plain callables to the Sequential module. If 10 | our functions have learnable parameters they should be implemented as 11 | ``nn.Module`` instances. 12 | 13 | Args: 14 | modules (tuple of Callables): The modules to call in order 15 | """ 16 | 17 | def __init__(self, *modules): 18 | super().__init__() 19 | self.layers = list(modules) 20 | 21 | def __call__(self, x): 22 | for m in self.layers: 23 | x = m(x) 24 | return x 25 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import math 4 | 5 | import mlx.core as mx 6 | from mlx.nn.layers.base import Module 7 | from mlx.nn.layers.quantized import QuantizedEmbedding 8 | 9 | 10 | class Embedding(Module): 11 | """Implements a simple lookup table that maps each input integer to a 12 | high-dimensional vector. 13 | 14 | Typically used to embed discrete tokens for processing by neural networks. 15 | 16 | Args: 17 | num_embeddings (int): How many possible discrete tokens can we embed. 18 | Usually called the vocabulary size. 
19 | dims (int): The dimensionality of the embeddings. 20 | """ 21 | 22 | def __init__(self, num_embeddings: int, dims: int): 23 | super().__init__() 24 | scale = math.sqrt(1 / dims) 25 | self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) 26 | 27 | def _extra_repr(self): 28 | return f"{self.weight.shape[0]}, {self.weight.shape[1]}" 29 | 30 | def __call__(self, x): 31 | return self.weight[x] 32 | 33 | def as_linear(self, x): 34 | """ 35 | Call the embedding layer as a linear layer. 36 | 37 | Use this for example when input embedding and output projection 38 | weights are tied. 39 | """ 40 | return x @ self.weight.T 41 | 42 | def to_quantized(self, group_size: int = 64, bits: int = 4): 43 | """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" 44 | return QuantizedEmbedding.from_embedding(self, group_size, bits) 45 | -------------------------------------------------------------------------------- /python/mlx/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from mlx.optimizers.optimizers import * 4 | from mlx.optimizers.schedulers import * 5 | -------------------------------------------------------------------------------- /python/mlx/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /python/src/constants.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | namespace nb = nanobind; 7 | 8 | void init_constants(nb::module_& m) { 9 | m.attr("e") = 2.71828182845904523536028747135266249775724709369995; 10 | m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421; 11 | m.attr("inf") = std::numeric_limits::infinity(); 12 | m.attr("nan") = NAN; 13 | m.attr("newaxis") = nb::none(); 14 | m.attr("pi") = 3.1415926535897932384626433; 15 | } 16 | -------------------------------------------------------------------------------- /python/src/convert.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | #pragma once 3 | 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "mlx/array.h" 10 | #include "mlx/ops.h" 11 | 12 | namespace mx = mlx::core; 13 | namespace nb = nanobind; 14 | 15 | struct ArrayLike { 16 | ArrayLike(nb::object obj) : obj(obj) {}; 17 | nb::object obj; 18 | }; 19 | 20 | using ArrayInitType = std::variant< 21 | nb::bool_, 22 | nb::int_, 23 | nb::float_, 24 | // Must be above ndarray 25 | mx::array, 26 | // Must be above complex 27 | nb::ndarray, 28 | std::complex, 29 | nb::list, 30 | nb::tuple, 31 | ArrayLike>; 32 | 33 | mx::array nd_array_to_mlx( 34 | nb::ndarray nd_array, 35 | std::optional dtype); 36 | 37 | nb::ndarray mlx_to_np_array(const mx::array& a); 38 | nb::ndarray<> mlx_to_dlpack(const mx::array& a); 39 | 40 | nb::object to_scalar(mx::array& a); 41 | 42 | nb::object tolist(mx::array& a); 43 | 44 | mx::array create_array(ArrayInitType v, std::optional t); 45 | mx::array array_from_list(nb::list pl, std::optional dtype); 46 | mx::array array_from_list(nb::tuple pl, std::optional dtype); 47 | -------------------------------------------------------------------------------- /python/src/indexing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/array.h" 8 | #include "python/src/utils.h" 9 | 10 | namespace mx = mlx::core; 11 | namespace nb = nanobind; 12 | 13 | mx::array mlx_get_item(const mx::array& src, const nb::object& obj); 14 | void mlx_set_item( 15 | mx::array& src, 16 | const nb::object& obj, 17 | const ScalarOrArray& v); 18 | mx::array mlx_add_item( 19 | const mx::array& src, 20 | const nb::object& obj, 21 | const ScalarOrArray& v); 22 | mx::array mlx_subtract_item( 23 | const mx::array& src, 24 | const nb::object& obj, 25 | const ScalarOrArray& v); 26 | mx::array mlx_multiply_item( 27 | const mx::array& src, 28 | const nb::object& obj, 29 | const ScalarOrArray& v); 30 | mx::array mlx_divide_item( 31 | const mx::array& src, 32 | const nb::object& obj, 33 | const ScalarOrArray& v); 34 | mx::array mlx_maximum_item( 35 | const mx::array& src, 36 | const nb::object& obj, 37 | const ScalarOrArray& v); 38 | mx::array mlx_minimum_item( 39 | const mx::array& src, 40 | const nb::object& obj, 41 | const ScalarOrArray& v); 42 | -------------------------------------------------------------------------------- /python/src/load.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include "mlx/io.h" 16 | 17 | namespace mx = mlx::core; 18 | namespace nb = nanobind; 19 | 20 | using LoadOutputTypes = std::variant< 21 | mx::array, 22 | std::unordered_map, 23 | mx::SafetensorsLoad, 24 | mx::GGUFLoad>; 25 | 26 | mx::SafetensorsLoad mlx_load_safetensor_helper( 27 | nb::object file, 28 | mx::StreamOrDevice s); 29 | void mlx_save_safetensor_helper( 30 | nb::object file, 31 | nb::dict d, 32 | std::optional m); 33 | 34 | mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s); 35 | 36 | void mlx_save_gguf_helper( 37 | nb::object file, 38 | nb::dict d, 39 | std::optional m); 40 | 41 | LoadOutputTypes mlx_load_helper( 42 | nb::object file, 43 | std::optional format, 44 | bool return_metadata, 45 | mx::StreamOrDevice s); 46 | void mlx_save_helper(nb::object file, mx::array a); 47 | void mlx_savez_helper( 48 | nb::object file, 49 | nb::args args, 50 | const nb::kwargs& kwargs, 51 | bool compressed = false); 52 | -------------------------------------------------------------------------------- /python/src/mlx.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include 4 | 5 | #define STRINGIFY(x) #x 6 | #define TOSTRING(x) STRINGIFY(x) 7 | 8 | namespace nb = nanobind; 9 | 10 | void init_mlx_func(nb::module_&); 11 | void init_array(nb::module_&); 12 | void init_device(nb::module_&); 13 | void init_stream(nb::module_&); 14 | void init_metal(nb::module_&); 15 | void init_memory(nb::module_&); 16 | void init_ops(nb::module_&); 17 | void init_transforms(nb::module_&); 18 | void init_random(nb::module_&); 19 | void init_fft(nb::module_&); 20 | void init_linalg(nb::module_&); 21 | void init_constants(nb::module_&); 22 | void init_fast(nb::module_&); 23 | void init_distributed(nb::module_&); 24 | void init_export(nb::module_&); 25 | 26 | NB_MODULE(core, m) { 27 | m.doc() = "mlx: A framework for machine learning on Apple silicon."; 28 | 29 | auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix"); 30 | nb::module_::import_("mlx._os_warning"); 31 | nb::set_leak_warnings(false); 32 | 33 | init_mlx_func(m); 34 | init_device(m); 35 | init_stream(m); 36 | init_array(m); 37 | init_metal(m); 38 | init_memory(m); 39 | init_ops(m); 40 | init_transforms(m); 41 | init_random(m); 42 | init_fft(m); 43 | init_linalg(m); 44 | init_constants(m); 45 | init_fast(m); 46 | init_distributed(m); 47 | init_export(m); 48 | 49 | m.attr("__version__") = TOSTRING(_VERSION_); 50 | } 51 | -------------------------------------------------------------------------------- /python/src/mlx_func.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | namespace nb = nanobind; 11 | using namespace nb::literals; 12 | 13 | nb::callable mlx_func(nb::object func, std::vector deps); 14 | 15 | template 16 | nb::callable mlx_func(F func, Deps&&... deps) { 17 | return mlx_func( 18 | nb::cpp_function(std::move(func)), std::vector{deps.ptr()...}); 19 | } 20 | 21 | template 22 | nb::callable mlx_func(nb::object func, Deps&&... 
deps) { 23 | return mlx_func(std::move(func), std::vector{deps.ptr()...}); 24 | } 25 | -------------------------------------------------------------------------------- /python/src/trees.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | #pragma once 3 | #include 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mx = mlx::core; 8 | namespace nb = nanobind; 9 | 10 | void tree_visit( 11 | const std::vector& trees, 12 | std::function&)> visitor); 13 | void tree_visit(nb::handle tree, std::function visitor); 14 | 15 | nb::object tree_map( 16 | const std::vector& trees, 17 | std::function&)> transform); 18 | 19 | nb::object tree_map( 20 | nb::object tree, 21 | std::function transform); 22 | 23 | void tree_visit_update( 24 | nb::object tree, 25 | std::function visitor); 26 | 27 | /** 28 | * Fill a pytree (recursive dict or list of dict or list) in place with the 29 | * given arrays. */ 30 | void tree_fill(nb::object& tree, const std::vector& values); 31 | 32 | /** 33 | * Replace all the arrays from the src values with the dst values in the 34 | * tree. 35 | */ 36 | void tree_replace( 37 | nb::object& tree, 38 | const std::vector& src, 39 | const std::vector& dst); 40 | 41 | /** 42 | * Flatten a tree into a vector of arrays. If strict is true, then the 43 | * function will throw if the tree contains a leaf which is not an array. 44 | */ 45 | std::vector tree_flatten(nb::handle tree, bool strict = true); 46 | 47 | /** 48 | * Unflatten a tree from a vector of arrays. 
49 | */ 50 | nb::object tree_unflatten( 51 | nb::object tree, 52 | const std::vector& values, 53 | int index = 0); 54 | 55 | std::pair, nb::object> tree_flatten_with_structure( 56 | nb::object tree, 57 | bool strict = true); 58 | 59 | nb::object tree_unflatten_from_structure( 60 | nb::object structure, 61 | const std::vector& values, 62 | int index = 0); 63 | -------------------------------------------------------------------------------- /python/tests/test_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import unittest 4 | 5 | import mlx.core as mx 6 | import mlx_tests 7 | import numpy as np 8 | 9 | 10 | class TestConstants(mlx_tests.MLXTestCase): 11 | def test_constants_values(self): 12 | # Check if mlx constants match expected values 13 | self.assertAlmostEqual( 14 | mx.e, 2.71828182845904523536028747135266249775724709369995 15 | ) 16 | self.assertAlmostEqual( 17 | mx.euler_gamma, 0.5772156649015328606065120900824024310421 18 | ) 19 | self.assertAlmostEqual(mx.inf, float("inf")) 20 | self.assertTrue(np.isnan(mx.nan)) 21 | self.assertIsNone(mx.newaxis) 22 | self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433) 23 | 24 | def test_constants_availability(self): 25 | # Check if mlx constants are available 26 | self.assertTrue(hasattr(mx, "e")) 27 | self.assertTrue(hasattr(mx, "euler_gamma")) 28 | self.assertTrue(hasattr(mx, "inf")) 29 | self.assertTrue(hasattr(mx, "nan")) 30 | self.assertTrue(hasattr(mx, "newaxis")) 31 | self.assertTrue(hasattr(mx, "pi")) 32 | 33 | def test_newaxis_for_reshaping_arrays(self): 34 | arr_1d = mx.array([1, 2, 3, 4, 5]) 35 | arr_2d_column = arr_1d[:, mx.newaxis] 36 | expected_result = mx.array([[1], [2], [3], [4], [5]]) 37 | self.assertTrue(mx.array_equal(arr_2d_column, expected_result)) 38 | 39 | 40 | if __name__ == "__main__": 41 | unittest.main() 42 | -------------------------------------------------------------------------------- 
/python/tests/test_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import io 4 | import unittest 5 | 6 | import mlx.core as mx 7 | import mlx_tests 8 | 9 | 10 | class TestGraph(mlx_tests.MLXTestCase): 11 | def test_to_dot(self): 12 | # Simply test that a few cases run. 13 | # Nothing too specific about the graph format 14 | # for now to keep it flexible 15 | a = mx.array(1.0) 16 | f = io.StringIO() 17 | mx.export_to_dot(f, a) 18 | f.seek(0) 19 | self.assertTrue(len(f.read()) > 0) 20 | 21 | b = mx.array(2.0) 22 | c = a + b 23 | f = io.StringIO() 24 | mx.export_to_dot(f, c) 25 | f.seek(0) 26 | self.assertTrue(len(f.read()) > 0) 27 | 28 | # Multi output case 29 | c = mx.divmod(a, b) 30 | f = io.StringIO() 31 | mx.export_to_dot(f, *c) 32 | f.seek(0) 33 | self.assertTrue(len(f.read()) > 0) 34 | 35 | 36 | if __name__ == "__main__": 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /python/tests/test_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import unittest 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | import mlx.utils 8 | import mlx_tests 9 | 10 | 11 | class TestTreeUtils(mlx_tests.MLXTestCase): 12 | def test_tree_map(self): 13 | tree = {"a": 0, "b": 1, "c": 2} 14 | tree = mlx.utils.tree_map(lambda x: x + 1, tree) 15 | 16 | expected_tree = {"a": 1, "b": 2, "c": 3} 17 | self.assertEqual(tree, expected_tree) 18 | 19 | def test_tree_flatten(self): 20 | tree = [{"a": 1, "b": 2}, "c"] 21 | vals = (1, 2, "c") 22 | flat_tree = mlx.utils.tree_flatten(tree) 23 | self.assertEqual(list(zip(*flat_tree))[1], vals) 24 | self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree) 25 | 26 | def test_merge(self): 27 | t1 = {"a": 0} 28 | t2 = {"b": 1} 29 | t = mlx.utils.tree_merge(t1, t2) 30 | self.assertEqual({"a": 0, "b": 1}, t) 31 | with self.assertRaises(ValueError): 32 | mlx.utils.tree_merge(t1, t1) 33 | with self.assertRaises(ValueError): 34 | mlx.utils.tree_merge(t, t1) 35 | 36 | mod1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 37 | mod2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 38 | mod = nn.Sequential(mod1, mod2) 39 | 40 | params1 = {"layers": [mod1.parameters()]} 41 | params2 = {"layers": [None, mod2.parameters()]} 42 | params = mlx.utils.tree_merge(params1, params2) 43 | for (k1, v1), (k2, v2) in zip( 44 | mlx.utils.tree_flatten(params), mlx.utils.tree_flatten(mod.parameters()) 45 | ): 46 | self.assertEqual(k1, k2) 47 | self.assertTrue(mx.array_equal(v1, v2)) 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Doctest works fine with cmake 3.5 2 | set(CMAKE_POLICY_VERSION_MINIMUM 3.5) 3 | 4 | FetchContent_Declare( 5 | doctest 6 | GIT_REPOSITORY "https://github.com/onqtam/doctest" 7 | GIT_TAG "ae7a13539fb71f270b87eb2e874fbac80bc8dda2") 8 | 
FetchContent_MakeAvailable(doctest) 9 | 10 | add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) 11 | 12 | if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) 13 | set(METAL_TEST_SOURCES gpu_tests.cpp) 14 | endif() 15 | 16 | include(${doctest_SOURCE_DIR}/scripts/cmake/doctest.cmake) 17 | 18 | target_sources( 19 | tests 20 | PRIVATE allocator_tests.cpp 21 | array_tests.cpp 22 | arg_reduce_tests.cpp 23 | autograd_tests.cpp 24 | blas_tests.cpp 25 | compile_tests.cpp 26 | custom_vjp_tests.cpp 27 | creations_tests.cpp 28 | device_tests.cpp 29 | einsum_tests.cpp 30 | export_import_tests.cpp 31 | eval_tests.cpp 32 | fft_tests.cpp 33 | load_tests.cpp 34 | ops_tests.cpp 35 | random_tests.cpp 36 | scheduler_tests.cpp 37 | utils_tests.cpp 38 | vmap_tests.cpp 39 | linalg_tests.cpp 40 | ${METAL_TEST_SOURCES}) 41 | 42 | target_link_libraries(tests PRIVATE mlx doctest) 43 | doctest_discover_tests(tests) 44 | add_test(NAME tests COMMAND tests) 45 | -------------------------------------------------------------------------------- /tests/allocator_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "doctest/doctest.h" 6 | 7 | #include "mlx/allocator.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test simple allocations") { 12 | { 13 | auto buffer = allocator::malloc(sizeof(float)); 14 | auto fptr = static_cast(buffer.raw_ptr()); 15 | *fptr = 0.5f; 16 | CHECK_EQ(*fptr, 0.5f); 17 | allocator::free(buffer); 18 | } 19 | 20 | { 21 | auto buffer = allocator::malloc(128 * sizeof(int)); 22 | int* ptr = static_cast(buffer.raw_ptr()); 23 | for (int i = 0; i < 128; ++i) { 24 | ptr[i] = i; 25 | } 26 | allocator::free(buffer); 27 | } 28 | 29 | { 30 | auto buffer = allocator::malloc(0); 31 | allocator::free(buffer); 32 | } 33 | } 34 | 35 | TEST_CASE("test large allocations") { 36 | size_t size = 1 << 30; 37 | for (int i = 0; i < 100; ++i) { 38 | auto buffer = allocator::malloc(size); 39 | allocator::free(buffer); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /tests/custom_vjp_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include "mlx/mlx.h" 6 | 7 | using namespace mlx::core; 8 | 9 | TEST_CASE("test simple custom vjp") { 10 | auto one = array(1.0); 11 | auto x = array(2.0); 12 | auto y = array(3.0); 13 | 14 | auto fn = [](const std::vector& inputs) { 15 | return std::vector{inputs[0] * inputs[1], inputs[0] + inputs[1]}; 16 | }; 17 | auto transformed_fn = custom_vjp( 18 | fn, 19 | [&](const std::vector&, 20 | const std::vector&, 21 | const std::vector&) { return std::vector{one, one}; }); 22 | 23 | auto [z, g] = vjp(fn, {x, y}, {one, one}); 24 | CHECK_EQ(z[0].item(), 6.0f); 25 | CHECK_EQ(z[1].item(), 5.0f); 26 | CHECK_EQ(g[0].item(), 4.0f); 27 | CHECK_EQ(g[1].item(), 3.0f); 28 | 29 | std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one}); 30 | CHECK_EQ(z[0].item(), 6.0f); 31 | CHECK_EQ(z[1].item(), 5.0f); 32 | CHECK_EQ(g[0].item(), 1.0f); 33 | CHECK_EQ(g[1].item(), 1.0f); 34 | } 35 | 36 | TEST_CASE("test checkpointing") { 37 | auto one = array(1.0); 38 | auto x = array(2.0); 39 | auto y = array(3.0); 40 | 41 | int cnt = 0; 42 | auto fn = [&cnt](const std::vector& inputs) { 43 | cnt++; 44 | auto x = inputs[0] * inputs[1]; 45 | auto y = inputs[0] + inputs[1]; 46 | return std::vector{square(x + y)}; 47 | }; 48 | auto checkpointed_fn = checkpoint(fn); 49 | 50 | auto [z, g] = vjp(checkpointed_fn, {x, y}, {one}); 51 | CHECK_EQ(z[0].item(), 121.0f); 52 | CHECK_EQ(g[0].item(), 88.0f); 53 | CHECK_EQ(g[1].item(), 66.0f); 54 | CHECK_EQ(cnt, 2); 55 | } 56 | -------------------------------------------------------------------------------- /tests/device_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include 6 | 7 | #include "mlx/mlx.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test device placement") { 12 | auto device = default_device(); 13 | Device d = metal::is_available() ? 
Device::gpu : Device::cpu; 14 | if (std::getenv("DEVICE") == nullptr) { 15 | CHECK_EQ(device, d); 16 | } 17 | 18 | array x(1.0f); 19 | array y(1.0f); 20 | auto z = add(x, y, default_device()); 21 | if (metal::is_available()) { 22 | z = add(x, y, Device::gpu); 23 | z = add(x, y, Device(Device::gpu, 0)); 24 | } else { 25 | CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument); 26 | CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument); 27 | } 28 | 29 | // Set the default device to the CPU 30 | set_default_device(Device::cpu); 31 | CHECK_EQ(default_device(), Device::cpu); 32 | 33 | // Revert 34 | set_default_device(device); 35 | } 36 | -------------------------------------------------------------------------------- /tests/tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #define DOCTEST_CONFIG_IMPLEMENT 4 | #include "doctest/doctest.h" 5 | 6 | #include 7 | 8 | #include "mlx/mlx.h" 9 | 10 | using namespace mlx::core; 11 | 12 | int main(int argc, char** argv) { 13 | doctest::Context context; 14 | 15 | const char* device = std::getenv("DEVICE"); 16 | if (device != nullptr && std::string(device) == "cpu") { 17 | set_default_device(Device::cpu); 18 | } else if (metal::is_available()) { 19 | set_default_device(Device::gpu); 20 | } 21 | 22 | context.applyCommandLine(argc, argv); 23 | return context.run(); 24 | } 25 | --------------------------------------------------------------------------------