├── docs ├── .nojekyll ├── .clang-format ├── index.html ├── requirements.txt ├── src │ ├── _static │ │ ├── mlx_logo.png │ │ ├── mlx_logo_dark.png │ │ └── metal_debugger │ │ │ ├── capture.png │ │ │ └── schema.png │ ├── cpp │ │ └── ops.rst │ ├── python │ │ ├── cuda.rst │ │ ├── metal.rst │ │ ├── export.rst │ │ ├── fast.rst │ │ ├── optimizers │ │ │ ├── schedulers.rst │ │ │ ├── common_optimizers.rst │ │ │ └── optimizer.rst │ │ ├── fft.rst │ │ ├── memory_management.rst │ │ ├── transforms.rst │ │ ├── devices_and_streams.rst │ │ ├── linalg.rst │ │ ├── nn │ │ │ ├── losses.rst │ │ │ ├── functions.rst │ │ │ ├── module.rst │ │ │ ├── init.rst │ │ │ └── layers.rst │ │ ├── distributed.rst │ │ ├── tree_utils.rst │ │ ├── array.rst │ │ └── random.rst │ ├── _templates │ │ ├── optimizers-template.rst │ │ ├── nn-module-template.rst │ │ └── module-base-class.rst │ └── usage │ │ └── using_streams.rst ├── .gitignore ├── Makefile └── README.md ├── python ├── mlx │ ├── py.typed │ ├── optimizers │ │ └── __init__.py │ ├── nn │ │ ├── __init__.py │ │ └── layers │ │ │ └── containers.py │ ├── _reprlib_fix.py │ ├── __main__.py │ └── _stub_patterns.txt ├── tests │ ├── __main__.py │ ├── test_graph.py │ └── test_constants.py ├── src │ ├── cuda.cpp │ ├── constants.cpp │ ├── mlx_func.h │ ├── indexing.h │ ├── mlx.cpp │ ├── convert.h │ └── load.h └── scripts │ ├── repair_linux.sh │ ├── repair_cuda.sh │ └── repair_record.py ├── mlx ├── 3rdparty │ └── .clang-format ├── backend │ ├── cpu │ │ ├── simd │ │ │ ├── simd.h │ │ │ └── type.h │ │ ├── available.h │ │ ├── available.cpp │ │ ├── eval.h │ │ ├── compiled_preamble.h │ │ ├── encoder.cpp │ │ ├── slicing.h │ │ ├── gemm.h │ │ ├── jit_compiler.h │ │ ├── threefry.h │ │ ├── arange.h │ │ ├── threefry.cpp │ │ ├── make_compiled_preamble.sh │ │ ├── copy.h │ │ ├── gemms │ │ │ ├── simd_fp16.cpp │ │ │ └── simd_bf16.cpp │ │ ├── eval.cpp │ │ └── make_compiled_preamble.ps1 │ ├── cuda │ │ ├── unary │ │ │ ├── abs.cu │ │ │ ├── ceil.cu │ │ │ ├── cos.cu │ │ │ ├── cosh.cu │ │ │ ├── erf.cu │ │ │ ├── exp.cu │ │ │ ├── imag.cu │ │ │ ├── real.cu │ │ │ ├── sign.cu │ │ │ ├── sin.cu │ │ │ ├── sinh.cu │ │ │ ├── tan.cu │ │ │ ├── tanh.cu │ │ │ ├── arccos.cu │ │ │ ├── arcsin.cu │ │ │ ├── arctan.cu │ │ │ ├── erf_inv.cu │ │ │ ├── expm1.cu │ │ │ ├── floor.cu │ │ │ ├── log1p.cu │ │ │ ├── square.cu │ │ │ ├── arccosh.cu │ │ │ ├── arcsinh.cu │ │ │ ├── arctanh.cu │ │ │ ├── conjugate.cu │ │ │ ├── negative.cu │ │ │ ├── sigmoid.cu │ │ │ ├── logical_not.cu │ │ │ ├── bitwise_invert.cu │ │ │ ├── sqrt.cu │ │ │ ├── round.cu │ │ │ └── log.cu │ │ ├── binary │ │ │ ├── add.cu │ │ │ ├── less.cu │ │ │ ├── arctan2.cu │ │ │ ├── divide.cu │ │ │ ├── greater.cu │ │ │ ├── maximum.cu │ │ │ ├── minimum.cu │ │ │ ├── power.cu │ │ │ ├── less_equal.cu │ │ │ ├── logical_or.cu │ │ │ ├── multiply.cu │ │ │ ├── not_equal.cu │ │ │ ├── remainder.cu │ │ │ ├── subtract.cu │ │ │ ├── log_add_exp.cu │ │ │ ├── logical_and.cu │ │ │ ├── greater_equal.cu │ │ │ ├── equal.cu │ │ │ ├── bitwise_binary.cu │ │ │ └── CMakeLists.txt │ │ ├── cuda.cpp │ │ ├── cuda.h │ │ ├── steel │ │ │ └── defines.cuh │ │ ├── detect_cuda_arch.sh │ │ ├── device │ │ │ ├── ternary_ops.cuh │ │ │ ├── config.h │ │ │ ├── indexing.cuh │ │ │ └── scatter_ops.cuh │ │ ├── gemms │ │ │ └── gemv.h │ │ ├── quantized │ │ │ ├── convert_fp8.cu │ │ │ ├── quantized.h │ │ │ └── quantized_utils.cuh │ │ ├── no_cuda.cpp │ │ ├── utils.h │ │ ├── primitives.cpp │ │ ├── copy │ │ │ └── copy.cuh │ │ ├── worker.h │ │ └── fence.cpp │ ├── gpu │ │ ├── available.h │ │ ├── CMakeLists.txt │ │ ├── eval.h │ │ ├── slicing.h │ │ 
└── slicing.cpp │ ├── metal │ │ ├── kernels │ │ │ ├── reduce_utils.h │ │ │ ├── steel │ │ │ │ ├── conv │ │ │ │ │ ├── loader.h │ │ │ │ │ └── conv.h │ │ │ │ ├── defines.h │ │ │ │ ├── utils.h │ │ │ │ ├── attn │ │ │ │ │ ├── kernels │ │ │ │ │ │ └── steel_attention.metal │ │ │ │ │ └── params.h │ │ │ │ ├── utils │ │ │ │ │ └── type_traits.h │ │ │ │ └── gemm │ │ │ │ │ └── params.h │ │ │ ├── ternary_ops.h │ │ │ ├── reduction │ │ │ │ └── reduce_init.h │ │ │ ├── reduce.h │ │ │ ├── arange.h │ │ │ ├── bf16.h │ │ │ ├── indexing │ │ │ │ ├── indexing.h │ │ │ │ ├── gather_front.h │ │ │ │ └── masked_scatter.h │ │ │ ├── logsumexp.metal │ │ │ ├── arange.metal │ │ │ ├── defines.h │ │ │ ├── softmax.metal │ │ │ └── fp4.h │ │ ├── scan.h │ │ ├── unary.h │ │ ├── ternary.h │ │ ├── metal.h │ │ ├── binary.h │ │ ├── resident.h │ │ ├── make_compiled_preamble.sh │ │ ├── no_metal.cpp │ │ ├── reduce.h │ │ ├── jit │ │ │ └── includes.h │ │ └── distributed.cpp │ ├── common │ │ ├── broadcasting.h │ │ ├── CMakeLists.txt │ │ ├── slicing.h │ │ ├── broadcasting.cpp │ │ ├── unary.h │ │ └── copy.h │ ├── no_cpu │ │ ├── available.cpp │ │ ├── CMakeLists.txt │ │ └── compiled.cpp │ └── no_gpu │ │ ├── CMakeLists.txt │ │ ├── apple_memory.h │ │ ├── linux_memory.h │ │ ├── eval.cpp │ │ └── event.cpp ├── version.cpp ├── distributed │ ├── nccl │ │ ├── nccl_stub │ │ │ ├── nccl_stubs.cpp │ │ │ └── CMakeLists.txt │ │ ├── nccl.h │ │ ├── no_nccl.cpp │ │ └── CMakeLists.txt │ ├── mpi │ │ ├── CMakeLists.txt │ │ ├── mpi.h │ │ ├── no_mpi.cpp │ │ └── mpi_declarations.h │ ├── ring │ │ ├── CMakeLists.txt │ │ ├── ring.h │ │ └── no_ring.cpp │ ├── jaccl │ │ ├── jaccl.h │ │ ├── CMakeLists.txt │ │ └── no_jaccl.cpp │ ├── CMakeLists.txt │ ├── reduction_ops.h │ └── ops.h ├── io │ ├── gguf.h │ ├── no_gguf.cpp │ ├── CMakeLists.txt │ └── no_safetensors.cpp ├── einsum.h ├── version.h ├── allocator.cpp ├── mlx.h ├── device.h ├── dtype_utils.cpp ├── stream.h ├── device.cpp ├── fence.h ├── allocator.h ├── event.h └── compile.h ├── tests_jaccl ├── jaccl_config.json ├── run_jaccl.sh ├── deploy.sh ├── build_cpp_benchmark.sh └── README.md ├── examples ├── extensions │ ├── requirements.txt │ ├── mlx_sample_extensions │ │ └── __init__.py │ ├── pyproject.toml │ ├── README.md │ ├── test.py │ ├── setup.py │ └── bindings.cpp ├── cmake_project │ ├── example.cpp │ ├── README.md │ └── CMakeLists.txt ├── cpp │ ├── timer.h │ ├── CMakeLists.txt │ ├── distributed.cpp │ ├── metal_capture.cpp │ └── logistic_regression.cpp ├── export │ ├── eval_mlp.cpp │ ├── CMakeLists.txt │ ├── README.md │ ├── train_mlp.cpp │ └── eval_mlp.py └── python │ ├── logistic_regression.py │ └── linear_regression.py ├── .github ├── dependabot.yml ├── actions │ ├── build-cuda-release │ │ └── action.yml │ ├── setup-macos │ │ └── action.yml │ ├── build-linux │ │ └── action.yml │ ├── build-cuda │ │ └── action.yml │ ├── build-macos-release │ │ └── action.yml │ ├── build-docs │ │ └── action.yml │ └── build-linux-release │ │ └── action.yml ├── pull_request_template.md ├── ISSUE_TEMPLATE │ └── bug_report.md ├── workflows │ └── documentation.yml └── scripts │ └── setup+build-cpp-linux-fedora-container.sh ├── pyproject.toml ├── MANIFEST.in ├── cmake └── Findnvpl.cmake ├── requirements-dev.txt ├── benchmarks ├── cpp │ ├── CMakeLists.txt │ ├── compare_devices.cpp │ ├── autograd.cpp │ └── time_utils.h ├── numpy │ ├── time_utils.py │ └── single_ops.py └── python │ ├── comparative │ └── README.md │ ├── rope_bench.py │ ├── time_utils.py │ └── synchronize_bench.py ├── tests ├── tests.cpp ├── device_tests.cpp └── allocator_tests.cpp 
├── CITATION.cff ├── .pre-commit-config.yaml ├── LICENSE ├── create_library_symlinks.sh ├── .gitignore ├── CONTRIBUTING.md └── sync_to_m3u.sh /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/mlx/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/.clang-format: -------------------------------------------------------------------------------- 1 | DisableFormat: true 2 | SortIncludes: Never 3 | -------------------------------------------------------------------------------- /mlx/3rdparty/.clang-format: -------------------------------------------------------------------------------- 1 | DisableFormat: true 2 | SortIncludes: Never 3 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_jaccl/jaccl_config.json: -------------------------------------------------------------------------------- 1 | [ 2 | [null, "rdma_en2"], 3 | ["rdma_en5", null] 4 | ] -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | breathe 3 | sphinx-book-theme 4 | sphinx-copybutton 5 | mlx 6 | -------------------------------------------------------------------------------- /examples/extensions/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools>=42 2 | cmake>=3.25 3 | mlx>=0.21.0 4 | nanobind==2.4.0 5 | -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/mlx_logo.png -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | src/python/_autosummary*/ 2 | src/python/nn/_autosummary*/ 3 | src/python/optimizers/_autosummary*/ 4 | -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/mlx_logo_dark.png -------------------------------------------------------------------------------- /docs/src/cpp/ops.rst: -------------------------------------------------------------------------------- 1 | .. _cpp_ops: 2 | 3 | Operations 4 | ========== 5 | 6 | .. doxygengroup:: ops 7 | :content-only: 8 | -------------------------------------------------------------------------------- /python/tests/__main__.py: -------------------------------------------------------------------------------- 1 | from . 
import mlx_tests 2 | 3 | __unittest = True 4 | 5 | mlx_tests.MLXTestRunner(module=None) 6 | -------------------------------------------------------------------------------- /docs/src/_static/metal_debugger/capture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/metal_debugger/capture.png -------------------------------------------------------------------------------- /docs/src/_static/metal_debugger/schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/metal_debugger/schema.png -------------------------------------------------------------------------------- /mlx/backend/cpu/simd/simd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlx/backend/cpu/simd/math.h" 4 | #include "mlx/backend/cpu/simd/type.h" 5 | -------------------------------------------------------------------------------- /examples/extensions/mlx_sample_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import mlx.core as mx 4 | 5 | from ._ext import axpby 6 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /python/mlx/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | from mlx.optimizers.optimizers import * 4 | from mlx.optimizers.schedulers import * 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=80", 4 | "nanobind==2.4.0", 5 | "cmake>=3.25", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | -------------------------------------------------------------------------------- /docs/src/python/cuda.rst: -------------------------------------------------------------------------------- 1 | CUDA 2 | ===== 3 | 4 | .. currentmodule:: mlx.core.cuda 5 | 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | 9 | is_available 10 | -------------------------------------------------------------------------------- /mlx/version.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | namespace mlx::core { 4 | 5 | const char* version() { 6 | return MLX_VERSION; 7 | } 8 | 9 | } // namespace mlx::core 10 | -------------------------------------------------------------------------------- /mlx/backend/cpu/simd/type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlx/backend/cpu/simd/base_simd.h" 4 | 5 | #ifdef MLX_USE_ACCELERATE 6 | #include "mlx/backend/cpu/simd/accelerate_simd.h" 7 | #endif 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/abs.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Abs) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/ceil.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Ceil) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/cos.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Cos) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/cosh.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Cosh) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/erf.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Erf) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/exp.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Exp) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/imag.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Imag) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/real.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Real) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/sign.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Sign) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/sin.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Sin) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/sinh.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Sinh) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/tan.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Tan) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/tanh.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Tanh) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /python/mlx/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from mlx.nn import init, losses 4 | from mlx.nn.layers import * 5 | from mlx.nn.utils import average_gradients, value_and_grad 6 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CMakeLists.txt 2 | include mlx.pc.in 3 | recursive-include mlx/ * 4 | include cmake/* 5 | include python/src/* 6 | include python/mlx/py.typed # support type hinting as in PEP-561 7 | -------------------------------------------------------------------------------- /mlx/backend/cpu/available.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core::cpu { 6 | 7 | bool is_available(); 8 | 9 | } // namespace mlx::core::cpu 10 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/add.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Add) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/less.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Less) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/arccos.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ArcCos) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/arcsin.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ArcSin) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/arctan.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ArcTan) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/erf_inv.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ErfInv) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/expm1.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Expm1) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/floor.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Floor) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/log1p.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Log1p) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/square.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Square) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/gpu/available.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core::gpu { 6 | 7 | bool is_available(); 8 | 9 | } // namespace mlx::core::gpu 10 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/arctan2.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(ArcTan2) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/divide.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Divide) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/greater.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Greater) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/maximum.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Maximum) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/minimum.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Minimum) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/power.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Power) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/arccosh.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ArcCosh) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/arcsinh.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ArcSinh) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/arctanh.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(ArcTanh) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/conjugate.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Conjugate) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/negative.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Negative) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/sigmoid.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(Sigmoid) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /examples/extensions/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "cmake>=3.25", 5 | "mlx>=0.18.0", 6 | "nanobind==2.4.0", 7 | ] 8 | build-backend = "setuptools.build_meta" 9 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/less_equal.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(LessEqual) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/logical_or.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(LogicalOr) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/multiply.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Multiply) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/not_equal.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(NotEqual) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/remainder.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Remainder) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/subtract.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(Subtract) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/logical_not.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(LogicalNot) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/reduce_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/atomic.h" 6 | #include "mlx/backend/metal/kernels/reduction/ops.h" 7 | -------------------------------------------------------------------------------- /mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | ncclResult_t ncclGetUniqueId(ncclUniqueId*) { 6 | return ncclSuccess; 7 | } 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/log_add_exp.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(LogAddExp) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/logical_and.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(LogicalAnd) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/unary/bitwise_invert.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | UNARY_GPU(BitwiseInvert) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/greater_equal.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | BINARY_GPU(GreaterEqual) 7 | } // namespace mlx::core 8 | -------------------------------------------------------------------------------- /mlx/backend/gpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) 6 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MLX_BUILD_CPU) 2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp) 3 | else() 4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp) 5 | endif() 6 | -------------------------------------------------------------------------------- /docs/src/python/metal.rst: -------------------------------------------------------------------------------- 1 | Metal 2 | ===== 3 | 4 | .. currentmodule:: mlx.core.metal 5 | 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | 9 | is_available 10 | device_info 11 | start_capture 12 | stop_capture 13 | -------------------------------------------------------------------------------- /cmake/Findnvpl.cmake: -------------------------------------------------------------------------------- 1 | # This file does nothing but to suppress the cmake warning: "By not providing 2 | # Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the 3 | # find_package(nvpl) from cmake's builtin FindLAPACK.cmake module. 4 | -------------------------------------------------------------------------------- /mlx/distributed/ring/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MLX_BUILD_CPU AND NOT WIN32) 2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp) 3 | else() 4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp) 5 | endif() 6 | -------------------------------------------------------------------------------- /mlx/backend/cuda/cuda.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/cuda.h" 4 | 5 | namespace mlx::core::cu { 6 | 7 | bool is_available() { 8 | return true; 9 | } 10 | 11 | } // namespace mlx::core::cu 12 | -------------------------------------------------------------------------------- /mlx/backend/cuda/cuda.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core::cu { 6 | 7 | /* Check if the CUDA backend is available. */ 8 | bool is_available(); 9 | 10 | } // namespace mlx::core::cu 11 | -------------------------------------------------------------------------------- /mlx/backend/common/broadcasting.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void broadcast(const array& in, array& out); 10 | 11 | } // namespace mlx::core 12 | -------------------------------------------------------------------------------- /mlx/backend/cpu/available.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/backend/cpu/available.h" 4 | 5 | namespace mlx::core::cpu { 6 | 7 | bool is_available() { 8 | return true; 9 | } 10 | 11 | } // namespace mlx::core::cpu 12 | -------------------------------------------------------------------------------- /mlx/backend/cuda/steel/defines.cuh: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #define MLX_UNROLL _Pragma("unroll") 6 | 7 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) 8 | #define MLX_CUDA_SM_80_ENABLED 9 | #endif 10 | -------------------------------------------------------------------------------- /mlx/backend/cpu/eval.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/stream.h" 7 | 8 | namespace mlx::core::cpu { 9 | 10 | void eval(array& arr); 11 | 12 | } // namespace mlx::core::cpu 13 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/conv/loader.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h" 6 | #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h" -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/ternary_ops.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | struct Select { 6 | template 7 | T operator()(bool condition, T x, T y) { 8 | return condition ? x : y; 9 | } 10 | }; 11 | -------------------------------------------------------------------------------- /mlx/backend/no_cpu/available.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cpu/available.h" 4 | 5 | namespace mlx::core::cpu { 6 | 7 | bool is_available() { 8 | return false; 9 | } 10 | 11 | } // namespace mlx::core::cpu 12 | -------------------------------------------------------------------------------- /docs/src/python/export.rst: -------------------------------------------------------------------------------- 1 | .. _export: 2 | 3 | Export Functions 4 | ================ 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | export_function 12 | import_function 13 | exporter 14 | export_to_dot 15 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/reduction/reduce_init.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | template 4 | [[kernel]] void init_reduce( 5 | device T* out [[buffer(0)]], 6 | uint tid [[thread_position_in_grid]]) { 7 | out[tid] = Op::init; 8 | } 9 | -------------------------------------------------------------------------------- /docs/src/python/fast.rst: -------------------------------------------------------------------------------- 1 | .. _fast: 2 | 3 | Fast 4 | ==== 5 | 6 | .. currentmodule:: mlx.core.fast 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | 11 | rms_norm 12 | layer_norm 13 | rope 14 | scaled_dot_product_attention 15 | metal_kernel 16 | cuda_kernel 17 | -------------------------------------------------------------------------------- /mlx/backend/cuda/detect_cuda_arch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | arch=`__nvcc_device_query` 4 | case "$arch" in 5 | "90") 6 | echo "90a" ;; 7 | "100") 8 | echo "100a" ;; 9 | "121") 10 | echo "121a" ;; 11 | *) 12 | echo "native" ;; 13 | esac 14 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/defines.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #define STEEL_CONST static constant constexpr const 6 | #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") 7 | #define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)") 8 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/reduce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "mlx/backend/metal/kernels/reduction/reduce_all.h" 3 | #include "mlx/backend/metal/kernels/reduction/reduce_col.h" 4 | #include "mlx/backend/metal/kernels/reduction/reduce_init.h" 5 | #include "mlx/backend/metal/kernels/reduction/reduce_row.h" 6 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/arange.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | template <typename T> 3 | [[kernel]] void arange( 4 | constant const T& start, 5 | constant const T& step, 6 | device T* out, 7 | uint index [[thread_position_in_grid]]) { 8 | out[index] = start + index * step; 9 | } 10 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp) 8 | -------------------------------------------------------------------------------- /docs/src/python/optimizers/schedulers.rst: -------------------------------------------------------------------------------- 1 | .. _schedulers: 2 | 3 | Schedulers 4 | ========== 5 | 6 | .. currentmodule:: mlx.optimizers 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | cosine_decay 12 | exponential_decay 13 | join_schedules 14 | linear_schedule 15 | step_decay 16 | -------------------------------------------------------------------------------- /examples/cmake_project/example.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include <iostream> 4 | 5 | #include "mlx/mlx.h" 6 | 7 | namespace mx = mlx::core; 8 | 9 | int main() { 10 | auto x = mx::array({1, 2, 3}); 11 | auto y = mx::array({1, 2, 3}); 12 | std::cout << x + y << std::endl; 13 | return 0; 14 | } 15 | -------------------------------------------------------------------------------- /mlx/backend/cuda/device/ternary_ops.cuh: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | #pragma once 3 | 4 | namespace mlx::core::cu { 5 | 6 | struct Select { 7 | template <typename T> 8 | __device__ T operator()(bool condition, T x, T y) { 9 | return condition ? x : y; 10 | } 11 | }; 12 | 13 | } // namespace mlx::core::cu 14 | -------------------------------------------------------------------------------- /mlx/backend/no_cpu/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp) 8 | -------------------------------------------------------------------------------- /docs/src/python/fft.rst: -------------------------------------------------------------------------------- 1 | .. _fft: 2 | 3 | FFT 4 | === 5 | 6 | .. currentmodule:: mlx.core.fft 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | fft 12 | ifft 13 | fft2 14 | ifft2 15 | fftn 16 | ifftn 17 | rfft 18 | irfft 19 | rfft2 20 | irfft2 21 | rfftn 22 | irfftn 23 | fftshift 24 | ifftshift 25 | -------------------------------------------------------------------------------- /mlx/backend/cpu/compiled_preamble.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-24 Apple Inc. 2 | 3 | #pragma once 4 | 5 | // clang-format off 6 | #include "mlx/types/half_types.h" 7 | #include "mlx/types/complex.h" 8 | #include "mlx/backend/cpu/unary_ops.h" 9 | #include "mlx/backend/cpu/binary_ops.h" 10 | // clang-format on 11 | 12 | const char* get_kernel_preamble(); 13 | -------------------------------------------------------------------------------- /docs/src/python/memory_management.rst: -------------------------------------------------------------------------------- 1 | Memory Management 2 | ================= 3 | 4 | .. currentmodule:: mlx.core 5 | 6 | .. autosummary:: 7 | :toctree: _autosummary 8 | 9 | get_active_memory 10 | get_peak_memory 11 | reset_peak_memory 12 | get_cache_memory 13 | set_memory_limit 14 | set_cache_limit 15 | set_wired_limit 16 | clear_cache 17 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/apple_memory.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include <sys/sysctl.h> 6 | 7 | namespace { 8 | 9 | size_t get_memory_size() { 10 | size_t memsize = 0; 11 | size_t length = sizeof(memsize); 12 | sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); 13 | return memsize; 14 | } 15 | 16 | } // namespace 17 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/mpi.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | #include "mlx/distributed/distributed.h" 4 | 5 | namespace mlx::core::distributed::mpi { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available(); 10 | std::shared_ptr init(bool strict = false); 11 | 12 | } // namespace mlx::core::distributed::mpi 13 | -------------------------------------------------------------------------------- /docs/src/python/transforms.rst: -------------------------------------------------------------------------------- 1 | .. _transforms: 2 | 3 | Transforms 4 | ========== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | eval 12 | async_eval 13 | compile 14 | custom_function 15 | disable_compile 16 | enable_compile 17 | grad 18 | value_and_grad 19 | jvp 20 | vjp 21 | vmap 22 | -------------------------------------------------------------------------------- /mlx/backend/metal/scan.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlx/array.h" 4 | #include "mlx/primitives.h" 5 | 6 | namespace mlx::core { 7 | 8 | void scan_gpu_inplace( 9 | array in, 10 | array& out, 11 | Scan::ReduceType reduce_type, 12 | int axis, 13 | bool reverse, 14 | bool inclusive, 15 | const Stream& s); 16 | 17 | } // namespace mlx::core 18 | -------------------------------------------------------------------------------- /mlx/distributed/nccl/nccl.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/distributed/distributed.h" 4 | 5 | namespace mlx::core::distributed::nccl { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available(); 10 | std::shared_ptr init(bool strict = false); 11 | 12 | } // namespace mlx::core::distributed::nccl 13 | -------------------------------------------------------------------------------- /mlx/distributed/ring/ring.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/distributed/distributed.h" 4 | 5 | namespace mlx::core::distributed::ring { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available(); 10 | std::shared_ptr init(bool strict = false); 11 | 12 | } // namespace mlx::core::distributed::ring 13 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Development dependencies for MLX 2 | # Install with: uv pip install -r requirements-dev.txt 3 | 4 | # Build dependencies 5 | cmake>=3.25 6 | setuptools>=80 7 | nanobind==2.4.0 8 | 9 | # Test dependencies 10 | numpy 11 | torch 12 | tensorflow 13 | unittest-xml-reporting 14 | typing_extensions 15 | 16 | # Development tools 17 | pre-commit 18 | 19 | -------------------------------------------------------------------------------- /mlx/distributed/jaccl/jaccl.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #include "mlx/distributed/distributed.h" 4 | 5 | namespace mlx::core::distributed::jaccl { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available(); 10 | std::shared_ptr init(bool strict = false); 11 | 12 | } // namespace mlx::core::distributed::jaccl 13 | -------------------------------------------------------------------------------- /examples/extensions/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Build 3 | 4 | ``` 5 | pip install -e . 6 | ``` 7 | 8 | For faster builds during development, you can also pre-install the requirements: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | And then run: 15 | 16 | ``` 17 | python setup.py build_ext -j8 --inplace 18 | ``` 19 | 20 | ## Test 21 | 22 | ``` 23 | python test.py 24 | ``` 25 | -------------------------------------------------------------------------------- /docs/src/python/devices_and_streams.rst: -------------------------------------------------------------------------------- 1 | .. _devices_and_streams: 2 | 3 | Devices and Streams 4 | =================== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | Device 12 | Stream 13 | default_device 14 | set_default_device 15 | default_stream 16 | new_stream 17 | set_default_stream 18 | stream 19 | synchronize 20 | -------------------------------------------------------------------------------- /mlx/backend/gpu/eval.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/array.h" 9 | #include "mlx/stream.h" 10 | 11 | namespace mlx::core::gpu { 12 | 13 | void new_stream(Stream stream); 14 | void eval(array& arr); 15 | void finalize(Stream s); 16 | void synchronize(Stream s); 17 | 18 | } // namespace mlx::core::gpu 19 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/bf16.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | using namespace metal; 8 | 9 | typedef bfloat bfloat16_t; 10 | inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { 11 | return as_type(x); 12 | } 13 | 14 | inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { 15 | return as_type(x); 16 | } 17 | -------------------------------------------------------------------------------- /python/mlx/_reprlib_fix.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import array 4 | import reprlib 5 | 6 | _old_repr_array = reprlib.Repr.repr_array 7 | 8 | 9 | def repr_array(self, x, maxlevel): 10 | if isinstance(x, array.array): 11 | return _old_repr_array(self, x, maxlevel) 12 | else: 13 | return self.repr_instance(x, maxlevel) 14 | 15 | 16 | reprlib.Repr.repr_array = repr_array 17 | -------------------------------------------------------------------------------- /mlx/backend/no_gpu/linux_memory.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include <sys/sysinfo.h> 6 | 7 | namespace { 8 | 9 | size_t get_memory_size() { 10 | struct sysinfo info; 11 | 12 | if (sysinfo(&info) != 0) { 13 | return 0; 14 | } 15 | 16 | size_t total_ram = info.totalram; 17 | total_ram *= info.mem_unit; 18 | 19 | return total_ram; 20 | } 21 | 22 | } // namespace 23 | -------------------------------------------------------------------------------- /docs/src/python/optimizers/common_optimizers.rst: -------------------------------------------------------------------------------- 1 | .. _common_optimizers: 2 | 3 | Common Optimizers 4 | ================= 5 | 6 | .. currentmodule:: mlx.optimizers 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | :template: optimizers-template.rst 11 | 12 | SGD 13 | RMSprop 14 | Adagrad 15 | Adafactor 16 | AdaDelta 17 | Adam 18 | AdamW 19 | Adamax 20 | Lion 21 | MultiOptimizer 22 | Muon 23 | -------------------------------------------------------------------------------- /examples/cmake_project/README.md: -------------------------------------------------------------------------------- 1 | ## Build and Run 2 | 3 | Install MLX with Python: 4 | 5 | ```bash 6 | pip install mlx>=0.22 7 | ``` 8 | 9 | Build the C++ example: 10 | 11 | ```bash 12 | cmake -B build -DCMAKE_BUILD_TYPE=Release 13 | cmake --build build 14 | ``` 15 | 16 | Run the C++ example: 17 | 18 | ``` 19 | ./build/example 20 | ``` 21 | 22 | which should output: 23 | 24 | ``` 25 | array([2, 4, 6], dtype=int32) 26 | ``` 27 | -------------------------------------------------------------------------------- /examples/cpp/timer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include <chrono> 6 | 7 | namespace timer { 8 | 9 | using namespace std::chrono; 10 | 11 | template <typename R, typename P> 12 | inline double seconds(duration<R, P> x) { 13 | return duration_cast<nanoseconds>(x).count() / 1e9; 14 | } 15 | 16 | inline auto time() { 17 | return high_resolution_clock::now(); 18 | } 19 | 20 | } // namespace timer 21 | -------------------------------------------------------------------------------- /mlx/backend/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp 9 | ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) 10 | -------------------------------------------------------------------------------- /examples/extensions/test.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | from mlx_sample_extensions import axpby 3 | 4 | a = mx.ones((3, 4)) 5 | b = mx.ones((3, 4)) 6 | c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu) 7 | c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu) 8 | 9 | print(f"c shape: {c_cpu.shape}") 10 | print(f"c dtype: {c_cpu.dtype}") 11 | print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}") 12 | print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}") 13 | -------------------------------------------------------------------------------- /docs/src/python/linalg.rst: -------------------------------------------------------------------------------- 1 | .. _linalg: 2 | 3 | Linear Algebra 4 | ============== 5 | 6 | .. currentmodule:: mlx.core.linalg 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | 11 | inv 12 | tri_inv 13 | norm 14 | cholesky 15 | cholesky_inv 16 | cross 17 | qr 18 | svd 19 | eigvals 20 | eig 21 | eigvalsh 22 | eigh 23 | lu 24 | lu_factor 25 | pinv 26 | solve 27 | solve_triangular 28 | -------------------------------------------------------------------------------- /mlx/backend/common/slicing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | std::tuple<int64_t, Strides> prepare_slice( 10 | const array& in, 11 | const Shape& start_indices, 12 | const Shape& strides); 13 | 14 | void slice( 15 | const array& in, 16 | array& out, 17 | const Shape& start_indices, 18 | const Shape& strides); 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /mlx/backend/cuda/device/config.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | // This file is used by both CUDA kernel code and host-only C++ code. 4 | 5 | #pragma once 6 | 7 | // The maximum dimensions of shape/strides passed as kernel parameters. 8 | #define MAX_NDIM 10 9 | 10 | // All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in 11 | // warpSize variable exists, using it would prevent compile-time optimizations. 12 | #define WARP_SIZE 32 13 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/conv/conv.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/steel/defines.h" 6 | #include "mlx/backend/metal/kernels/steel/utils.h" 7 | 8 | #include "mlx/backend/metal/kernels/steel/conv/loader.h" 9 | #include "mlx/backend/metal/kernels/steel/conv/params.h" 10 | #include "mlx/backend/metal/kernels/steel/gemm/mma.h" 11 | 12 | using namespace metal; 13 | using namespace mlx::steel; 14 | -------------------------------------------------------------------------------- /examples/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(build_example SRCFILE) 2 | get_filename_component(src_name ${SRCFILE} NAME_WE) 3 | set(target "${src_name}") 4 | add_executable(${target} ${SRCFILE}) 5 | target_link_libraries(${target} PRIVATE mlx) 6 | endfunction(build_example) 7 | 8 | build_example(tutorial.cpp) 9 | build_example(linear_regression.cpp) 10 | build_example(logistic_regression.cpp) 11 | build_example(metal_capture.cpp) 12 | build_example(distributed.cpp) 13 | -------------------------------------------------------------------------------- /mlx/backend/metal/unary.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
-------------------------------------------------------------------------------- /docs/src/python/optimizers/optimizer.rst: --------------------------------------------------------------------------------
Optimizer
=========

.. currentmodule:: mlx.optimizers

.. autoclass:: Optimizer


   .. rubric:: Attributes

   .. autosummary::
      :toctree: _autosummary

      Optimizer.state

   .. rubric:: Methods

   .. autosummary::
      :toctree: _autosummary

      Optimizer.apply_gradients
      Optimizer.init
      Optimizer.update
-------------------------------------------------------------------------------- /mlx/backend/metal/ternary.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void ternary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s);

void ternary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s);

} // namespace mlx::core
-------------------------------------------------------------------------------- /benchmarks/cpp/CMakeLists.txt: --------------------------------------------------------------------------------
function(build_benchmark SRCFILE)
  get_filename_component(src_name ${SRCFILE} NAME_WE)
  set(target "${src_name}")
  add_executable(${target} ${SRCFILE})
  target_link_libraries(${target} PRIVATE mlx)
endfunction(build_benchmark)

build_benchmark(single_ops.cpp)
build_benchmark(irregular_strides.cpp)
build_benchmark(compare_devices.cpp)
build_benchmark(autograd.cpp)
build_benchmark(separate_tp_vs_single.cpp)
-------------------------------------------------------------------------------- /mlx/io/gguf.h: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.
#pragma once

#include "mlx/io.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"

extern "C" {
#include <gguflib.h>
}

namespace mlx::core {

Shape get_shape(const gguf_tensor& tensor);
void gguf_load_quantized(
    std::unordered_map<std::string, array>& a,
    const gguf_tensor& tensor);

} // namespace mlx::core
-------------------------------------------------------------------------------- /python/src/cuda.cpp: --------------------------------------------------------------------------------
// Copyright © 2023-2025 Apple Inc.

#include <nanobind/nanobind.h>

#include "mlx/backend/cuda/cuda.h"

namespace mx = mlx::core;
namespace nb = nanobind;

void init_cuda(nb::module_& m) {
  nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");

  cuda.def(
      "is_available",
      &mx::cu::is_available,
      R"pbdoc(
        Check if the CUDA back-end is available.
      )pbdoc");
}
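A small Python sketch of how these availability checks are typically used when picking a default device (all names are the documented `mlx.core` APIs):

```python
import mlx.core as mx

# Query which back-ends this build can use before picking a device.
print(mx.cuda.is_available(), mx.metal.is_available())

has_gpu = mx.cuda.is_available() or mx.metal.is_available()
mx.set_default_device(mx.gpu if has_gpu else mx.cpu)
print(mx.default_device())
```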
-------------------------------------------------------------------------------- /benchmarks/numpy/time_utils.py: --------------------------------------------------------------------------------
# Copyright © 2023 Apple Inc.

import time


def time_fn(fn, *args):
    print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        fn(*args)

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = fn(*args)
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
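A hypothetical driver for the timing helper above; `sum_axis` and the array shape are made up for illustration, and the import assumes `benchmarks/numpy` is on `sys.path`:

```python
import numpy as np
from time_utils import time_fn  # assumes benchmarks/numpy is on sys.path


def sum_axis(x):
    return np.sum(x, axis=-1)


x = np.random.rand(1024, 1024).astype(np.float32)
time_fn(sum_axis, x)  # prints the average msec over 100 iterations
```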
2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | void Equal::eval_gpu(const std::vector& inputs, array& out) { 7 | nvtx3::scoped_range r("Equal::eval_gpu"); 8 | auto& s = out.primitive().stream(); 9 | if (equal_nan_) { 10 | binary_op_gpu(inputs, out, name(), s); 11 | } else { 12 | binary_op_gpu(inputs, out, name(), s); 13 | } 14 | } 15 | } // namespace mlx::core 16 | -------------------------------------------------------------------------------- /docs/src/_templates/nn-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | {% for item in methods %} 14 | {%- if item not in inherited_members and item != "__init__" %} 15 | ~{{ name }}.{{ item }} 16 | {%- endif %} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | -------------------------------------------------------------------------------- /mlx/distributed/mpi/no_mpi.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include "mlx/distributed/mpi/mpi.h" 4 | 5 | namespace mlx::core::distributed::mpi { 6 | 7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl; 8 | 9 | bool is_available() { 10 | return false; 11 | } 12 | 13 | std::shared_ptr init(bool strict /* = false */) { 14 | if (strict) { 15 | throw std::runtime_error("Cannot initialize MPI"); 16 | } 17 | return nullptr; 18 | } 19 | 20 | } // namespace mlx::core::distributed::mpi 21 | -------------------------------------------------------------------------------- /docs/src/python/nn/losses.rst: -------------------------------------------------------------------------------- 1 | .. _losses: 2 | 3 | .. currentmodule:: mlx.nn.losses 4 | 5 | Loss Functions 6 | -------------- 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary_functions 10 | :template: nn-module-template.rst 11 | 12 | binary_cross_entropy 13 | cosine_similarity_loss 14 | cross_entropy 15 | gaussian_nll_loss 16 | hinge_loss 17 | huber_loss 18 | kl_div_loss 19 | l1_loss 20 | log_cosh_loss 21 | margin_ranking_loss 22 | mse_loss 23 | nll_loss 24 | smooth_l1_loss 25 | triplet_loss -------------------------------------------------------------------------------- /mlx/distributed/jaccl/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MLX_BUILD_CPU 2 | AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" 3 | AND MACOS_SDK_VERSION GREATER_EQUAL 26.2 4 | OR MLX_BUILD_JACCL) 5 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp) 6 | # On macOS, link against librdma which provides InfiniBand Verbs 7 | # Use PUBLIC so shared libraries linking to mlx also get rdma dependency 8 | target_link_libraries(mlx PUBLIC rdma) 9 | else() 10 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp) 11 | endif() 12 | -------------------------------------------------------------------------------- /python/src/constants.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 
-------------------------------------------------------------------------------- /mlx/distributed/jaccl/CMakeLists.txt: --------------------------------------------------------------------------------
if(MLX_BUILD_CPU
   AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
   AND MACOS_SDK_VERSION GREATER_EQUAL 26.2
   OR MLX_BUILD_JACCL)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
  # On macOS, link against librdma which provides InfiniBand Verbs
  # Use PUBLIC so shared libraries linking to mlx also get rdma dependency
  target_link_libraries(mlx PUBLIC rdma)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
endif()
-------------------------------------------------------------------------------- /python/src/constants.cpp: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.

#include <limits>
#include <nanobind/nanobind.h>

namespace nb = nanobind;

void init_constants(nb::module_& m) {
  m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
  m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
  m.attr("inf") = std::numeric_limits<double>::infinity();
  m.attr("nan") = NAN;
  m.attr("newaxis") = nb::none();
  m.attr("pi") = 3.1415926535897932384626433;
}
-------------------------------------------------------------------------------- /docs/src/python/distributed.rst: --------------------------------------------------------------------------------
.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
==========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
   send
   recv
   recv_like
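A minimal sketch of the documented API in use. With no distributed backend it runs as a single process (rank 0 of 1); launched under e.g. `mpirun -np 4 python script.py`, each rank contributes to the sum:

```python
import mlx.core as mx

world = mx.distributed.init()  # falls back to a singleton group without MPI
x = mx.ones((4,))
total = mx.distributed.all_sum(x, group=world)
print(world.rank(), "/", world.size(), total)
```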
-------------------------------------------------------------------------------- /mlx/distributed/nccl/no_nccl.cpp: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/nccl/nccl.h"

namespace mlx::core::distributed::nccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize nccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::nccl
-------------------------------------------------------------------------------- /mlx/distributed/ring/no_ring.cpp: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#include "mlx/distributed/ring/ring.h"

namespace mlx::core::distributed::ring {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize ring distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::ring
-------------------------------------------------------------------------------- /mlx/distributed/jaccl/no_jaccl.cpp: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#include "mlx/distributed/jaccl/jaccl.h"

namespace mlx::core::distributed::jaccl {

using GroupImpl = mlx::core::distributed::detail::GroupImpl;

bool is_available() {
  return false;
}

std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
  if (strict) {
    throw std::runtime_error("Cannot initialize jaccl distributed backend.");
  }
  return nullptr;
}

} // namespace mlx::core::distributed::jaccl
-------------------------------------------------------------------------------- /mlx/einsum.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.
#pragma once

#include <string>
#include <utility>
#include <vector>

#include "mlx/array.h"
#include "mlx/utils.h"

namespace mlx::core {

std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
    const std::string& subscripts,
    const std::vector<array>& operands);

array einsum(
    const std::string& subscripts,
    const std::vector<array>& operands,
    StreamOrDevice s = {});

} // namespace mlx::core
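A quick Python sketch of the corresponding `mlx.core.einsum` / `einsum_path` bindings (shapes are illustrative; `einsum_path` is assumed to return the planned contraction plus a human-readable summary, mirroring NumPy's convention):

```python
import mlx.core as mx

a = mx.random.uniform(shape=(8, 16))
b = mx.random.uniform(shape=(16, 4))

c = mx.einsum("ij,jk->ik", a, b)                # a plain matrix multiply
path, info = mx.einsum_path("ij,jk->ik", a, b)  # inspect the contraction order
print(path, info)
```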
-------------------------------------------------------------------------------- /mlx/distributed/CMakeLists.txt: --------------------------------------------------------------------------------
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)

if(MLX_BUILD_CPU AND NOT WIN32)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
-------------------------------------------------------------------------------- /mlx/distributed/nccl/nccl_stub/CMakeLists.txt: --------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.25)

project(nccl LANGUAGES C CXX)

file(
  DOWNLOAD
  "https://raw.githubusercontent.com/NVIDIA/nccl/refs/tags/v2.27.5-1/src/nccl.h.in"
  "${CMAKE_CURRENT_BINARY_DIR}/nccl.h")

add_library(nccl SHARED nccl_stubs.cpp)
set_target_properties(nccl PROPERTIES SOVERSION 2)
find_package(CUDAToolkit REQUIRED)
target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
-------------------------------------------------------------------------------- /mlx/version.h: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#pragma once

#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 30
#define MLX_VERSION_PATCH 1
#define MLX_VERSION_NUMERIC \
  (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

namespace mlx::core {

/* A string representation of the MLX version in the format
 * "major.minor.patch".
 *
 * For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
 */
const char* version();

} // namespace mlx::core
-------------------------------------------------------------------------------- /mlx/backend/cpu/gemm.h: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#pragma once
#include "mlx/array.h"

namespace mlx::core {

template <typename T>
void matmul(
    const T* a,
    const T* b,
    T* out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    size_t ldc,
    float alpha,
    float beta,
    size_t batch_size,
    const Shape& a_shape,
    const Strides& a_strides,
    const Shape& b_shape,
    const Strides& b_strides);

} // namespace mlx::core
-------------------------------------------------------------------------------- /mlx/backend/cpu/jit_compiler.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.
#pragma once

#include <filesystem>
#include <string>

namespace mlx::core {

class JitCompiler {
 public:
  // Build a shell command that compiles a source code file to a shared
  // library.
  static std::string build_command(
      const std::filesystem::path& dir,
      const std::string& source_file_name,
      const std::string& shared_lib_name);

  // Run a command and get its output.
  static std::string exec(const std::string& cmd);
};

} // namespace mlx::core
-------------------------------------------------------------------------------- /examples/export/eval_mlp.cpp: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#include <iostream>

#include <mlx/mlx.h>

namespace mx = mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;

  // Make the input
  mx::random::seed(42);
  auto example_x = mx::random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = mx::import_function("eval_mlp.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

  std::cout << out << std::endl;

  return 0;
}
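For reference, a sketch of the Python side that could produce the `eval_mlp.mlxfn` file imported above. The MLP architecture here is an assumption for illustration, not the repository's actual `eval_mlp.py`:

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
mx.eval(model.parameters())

example_x = mx.random.uniform(shape=(8, 32))
mx.export_function("eval_mlp.mlxfn", model, example_x)
```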
2 | 3 | #include "mlx/backend/cuda/unary/unary.cuh" 4 | 5 | namespace mlx::core { 6 | void Round::eval_gpu(const std::vector& inputs, array& out) { 7 | nvtx3::scoped_range r("Round::eval_gpu"); 8 | assert(inputs.size() == 1); 9 | const auto& in = inputs[0]; 10 | auto& s = out.primitive().stream(); 11 | if (issubdtype(in.dtype(), inexact)) { 12 | unary_op_gpu(inputs, out, name(), s); 13 | } else { 14 | // No-op integer types 15 | out.copy_shared_buffer(in); 16 | } 17 | } 18 | } // namespace mlx::core 19 | -------------------------------------------------------------------------------- /mlx/io/no_gguf.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include "mlx/io.h" 4 | 5 | namespace mlx::core { 6 | 7 | GGUFLoad load_gguf(const std::string&, StreamOrDevice s) { 8 | throw std::runtime_error( 9 | "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support."); 10 | } 11 | 12 | void save_gguf( 13 | std::string, 14 | std::unordered_map, 15 | std::unordered_map) { 16 | throw std::runtime_error( 17 | "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support."); 18 | } 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /examples/cpp/distributed.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/mlx.h" 6 | 7 | namespace mx = mlx::core; 8 | 9 | int main() { 10 | if (!mx::distributed::is_available()) { 11 | std::cout << "No communication backend found" << std::endl; 12 | return 1; 13 | } 14 | 15 | auto global_group = mx::distributed::init(); 16 | std::cout << global_group.rank() << " / " << global_group.size() << std::endl; 17 | 18 | mx::array x = mx::ones({10}); 19 | mx::array out = mx::distributed::all_sum(x, global_group); 20 | 21 | std::cout << out << std::endl; 22 | } 23 | -------------------------------------------------------------------------------- /tests_jaccl/run_jaccl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Example usage: 4 | # ./run_jaccl.sh 5 | 6 | RANK=$1 7 | COORDINATOR=$2 8 | 9 | if [ -z "$RANK" ] || [ -z "$COORDINATOR" ]; then 10 | echo "Usage: $0 " 11 | exit 1 12 | fi 13 | 14 | export MLX_RANK=$RANK 15 | export MLX_IBV_COORDINATOR=$COORDINATOR 16 | export MLX_IBV_DEVICES=$(pwd)/jaccl_config.json 17 | export MLX_IBV_VERBOSE=1 18 | 19 | # Activate virtual environment if it exists 20 | if [ -f "../.venv/bin/activate" ]; then 21 | source ../.venv/bin/activate 22 | fi 23 | 24 | python3 test_tp_mlp.py 25 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/indexing/indexing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | template 8 | struct Indices { 9 | const array buffers; 10 | const constant int* shapes; 11 | const constant int64_t* strides; 12 | const constant bool* row_contiguous; 13 | const int ndim; 14 | }; 15 | 16 | template 17 | METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { 18 | if (is_unsigned_v) { 19 | return idx; 20 | } else { 21 | return (idx < 0) ? 
-------------------------------------------------------------------------------- /mlx/backend/metal/kernels/indexing/indexing.h: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_stdlib>

template <typename IdxT, int NIDX>
struct Indices {
  const array<const device IdxT*, NIDX> buffers;
  const constant int* shapes;
  const constant int64_t* strides;
  const constant bool* row_contiguous;
  const int ndim;
};

template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
  if (is_unsigned_v<IdxT>) {
    return idx;
  } else {
    return (idx < 0) ? idx + size : idx;
  }
}
-------------------------------------------------------------------------------- /python/scripts/repair_linux.sh: --------------------------------------------------------------------------------
#!/bin/bash

auditwheel repair dist/* \
  --plat manylinux_2_35_${1} \
  --only-plat \
  --exclude libmlx* \
  -w wheel_tmp

mkdir wheelhouse
cd wheel_tmp
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
rm "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath="\$ORIGIN/lib"
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
python ../python/scripts/repair_record.py ${core_so}

# Re-zip the repaired wheel
zip -r -q "../wheelhouse/${repaired_wheel}" .
-------------------------------------------------------------------------------- /tests/tests.cpp: --------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.

#define DOCTEST_CONFIG_IMPLEMENT
#include "doctest/doctest.h"

#include <cstdlib>

#include "mlx/mlx.h"

using namespace mlx::core;

int main(int argc, char** argv) {
  doctest::Context context;

  const char* device = std::getenv("DEVICE");
  if (device != nullptr && std::string(device) == "cpu") {
    set_default_device(Device::cpu);
  } else if (metal::is_available()) {
    set_default_device(Device::gpu);
  }

  context.applyCommandLine(argc, argv);
  return context.run();
}
-------------------------------------------------------------------------------- /mlx/backend/cuda/gemms/gemv.h: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"

namespace mlx::core::cu {

bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);

void gemv(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    uint32_t batch_count,
    const mlx::core::Shape& batch_shape,
    const mlx::core::Strides& a_batch_strides,
    const mlx::core::Strides& b_batch_strides,
    CommandEncoder& encoder);

} // namespace mlx::core::cu
-------------------------------------------------------------------------------- /mlx/backend/metal/metal.h: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <string>
#include <unordered_map>
#include <variant>

namespace mlx::core::metal {

/* Check if the Metal backend is available. */
bool is_available();

/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

/** Get information about the GPU and system settings. */
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info();

} // namespace mlx::core::metal
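A small Python sketch of the corresponding `mlx.core.metal` bindings for the capture and device-info APIs declared above (the trace path is arbitrary):

```python
import mlx.core as mx

if mx.metal.is_available():
    print(mx.metal.device_info())  # GPU and system settings

    # Capturing a GPU trace requires MTL_CAPTURE_ENABLED=1 in the environment.
    mx.metal.start_capture("/tmp/mlx_trace.gputrace")
    mx.eval(mx.square(mx.ones((256, 256))))
    mx.metal.stop_capture()
```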
-------------------------------------------------------------------------------- /.github/actions/build-cuda-release/action.yml: --------------------------------------------------------------------------------
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'

inputs:
  toolkit:
    description: 'The CUDA toolkit'
    required: true

runs:
  using: "composite"
  steps:
    - name: Build package
      shell: bash
      env:
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
      run: |
        pip install auditwheel build patchelf setuptools
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
        bash python/scripts/repair_cuda.sh
-------------------------------------------------------------------------------- /.github/pull_request_template.md: --------------------------------------------------------------------------------
## Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

## Checklist

Put an `x` in the boxes that apply.

- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
-------------------------------------------------------------------------------- /mlx/backend/cuda/quantized/convert_fp8.cu: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/unary/unary.cuh"
#include "mlx/fast_primitives.h"

namespace mlx::core {
void fast::ConvertFP8::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("ConvertFP8::eval_gpu");
  auto& in = inputs[0];
  auto& out = outputs[0];
  auto& s = out.primitive().stream();
  if (to_fp8_) {
    unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
  } else {
    unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
  }
}
} // namespace mlx::core
-------------------------------------------------------------------------------- /tests_jaccl/deploy.sh: --------------------------------------------------------------------------------
#!/bin/bash

# Syncs the current directory to the test hosts
# Usage: ./deploy.sh

REMOTE_PATH="/Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma"

echo "Deploying to m4p.local..."
rsync -avz --exclude '.git' --exclude 'build' ../ m4p.local:$REMOTE_PATH/

echo "Deploying to m3u.local..."
# Check if m3u is up first to avoid long timeout if it's still down
if ping -c 1 -W 1 m3u.local &> /dev/null; then
  rsync -avz --exclude '.git' --exclude 'build' ../ m3u.local:$REMOTE_PATH/
else
  echo "m3u.local seems to be down. Skipping."
fi

echo "Done."
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = src
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-------------------------------------------------------------------------------- /mlx/backend/cuda/unary/log.cu: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/unary/unary.cuh"

namespace mlx::core {
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Log::eval_gpu");
  auto& s = out.primitive().stream();
  switch (base_) {
    case Base::e:
      unary_op_gpu<cu::Log>(inputs, out, name(), s);
      break;
    case Base::two:
      unary_op_gpu<cu::Log2>(inputs, out, name(), s);
      break;
    case Base::ten:
      unary_op_gpu<cu::Log10>(inputs, out, name(), s);
      break;
  }
}
} // namespace mlx::core
-------------------------------------------------------------------------------- /examples/export/CMakeLists.txt: --------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)

add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
-------------------------------------------------------------------------------- /mlx/backend/cpu/threefry.h: --------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.

#pragma once

#include <cstdint>
#include <utility>

namespace mlx::core::random {

/** Applies the Threefry 2x32 hash function.
 * This code is based on the Jax counter-based and splittable PRNG
 * https://github.com/google/jax/blob/main/docs/jep/263-prng.md
 *
 * Original Threefry reference:
 * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
 */
std::pair<uint32_t, uint32_t> threefry2x32_hash(
    const std::pair<uint32_t, uint32_t>& key,
    std::pair<uint32_t, uint32_t> count);

} // namespace mlx::core::random
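The counter-based design above is what makes MLX random keys splittable. A minimal Python sketch of the user-facing side (a key is a pair of uint32 words, as in the header):

```python
import mlx.core as mx

key = mx.random.key(42)        # two uint32 words, as in the header above
k1, k2 = mx.random.split(key)  # derive independent streams from one key
a = mx.random.uniform(shape=(2,), key=k1)
b = mx.random.uniform(shape=(2,), key=k2)
print(a, b)
```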
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: --------------------------------------------------------------------------------
---
name: Bug report
about: Create a report about an issue you've encountered
title: "[BUG] "
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**

Include code snippet
```python

```

**Expected behavior**
A clear and concise description of what you expected to happen.

**Desktop (please complete the following information):**
- OS Version: [e.g. MacOS 14.1.2]
- Version [e.g. 0.7.0]

**Additional context**
Add any other context about the problem here.
-------------------------------------------------------------------------------- /.github/workflows/documentation.yml: --------------------------------------------------------------------------------
name: Documentation

on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/build-docs

  deploy:
    needs: build
    permissions:
      pages: write
      id-token: write
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4
-------------------------------------------------------------------------------- /examples/extensions/setup.py: --------------------------------------------------------------------------------
# Copyright © 2023-2024 Apple Inc.

from setuptools import setup

from mlx import extension

if __name__ == "__main__":
    setup(
        name="mlx_sample_extensions",
        version="0.0.0",
        description="Sample C++ and Metal extensions for MLX primitives.",
        ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
        cmdclass={"build_ext": extension.CMakeBuild},
        packages=["mlx_sample_extensions"],
        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
        zip_safe=False,
        python_requires=">=3.8",
    )
-------------------------------------------------------------------------------- /CITATION.cff: --------------------------------------------------------------------------------
cff-version: 1.2.0
title: mlx
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Awni
    family-names: Hannun
    affiliation: Apple
  - given-names: Jagrit
    family-names: Digani
    affiliation: Apple
  - given-names: Angelos
    family-names: Katharopoulos
    affiliation: Apple
  - given-names: Ronan
    family-names: Collobert
    affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
  MLX: efficient and flexible machine learning on Apple
  silicon
license: MIT
-------------------------------------------------------------------------------- /docs/src/python/tree_utils.rst: --------------------------------------------------------------------------------
.. _utils:

Tree Utils
==========

In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees will be using the default python ``dict``, ``list`` and
``tuple`` but they can usually process objects that inherit from any of these.

.. note::
   Dictionaries should have keys that are valid python identifiers.

.. currentmodule:: mlx.utils

.. autosummary::
   :toctree: _autosummary

   tree_flatten
   tree_unflatten
   tree_map
   tree_map_with_path
   tree_reduce
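A short sketch of the utilities listed above on a toy parameter tree:

```python
from mlx.utils import tree_flatten, tree_map, tree_unflatten

params = {"layers": [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}]}

flat = tree_flatten(params)              # [("layers.0.w", 1.0), ...]
doubled = tree_map(lambda x: 2 * x, params)
restored = tree_unflatten(flat)          # back to the nested form
```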
-------------------------------------------------------------------------------- /mlx/backend/metal/kernels/logsumexp.metal: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.

#include <metal_common>
#include <metal_simdgroup>

using namespace metal;

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"

#define instantiate_logsumexp(name, itype)                                \
  instantiate_kernel("block_logsumexp_" #name, logsumexp, itype)          \
  instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype)  \

instantiate_logsumexp(float32, float)
instantiate_logsumexp(float16, half)
instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on
-------------------------------------------------------------------------------- /mlx/backend/no_gpu/eval.cpp: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#include <stdexcept>

#include "mlx/backend/gpu/available.h"
#include "mlx/backend/gpu/eval.h"

namespace mlx::core::gpu {

bool is_available() {
  return false;
}

void new_stream(Stream) {}

void eval(array&) {
  throw std::runtime_error("[gpu::eval] GPU backend is not available");
}

void finalize(Stream) {
  throw std::runtime_error("[gpu::finalize] GPU backend is not available");
}

void synchronize(Stream) {
  throw std::runtime_error("[gpu::synchronize] GPU backend is not available");
}

} // namespace mlx::core::gpu
-------------------------------------------------------------------------------- /python/mlx/__main__.py: --------------------------------------------------------------------------------
import argparse


def main() -> None:
    from mlx.core import __version__

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--version",
        action="version",
        version=__version__,
        help="Print the version number.",
    )
    parser.add_argument(
        "--cmake-dir",
        action="store_true",
        help="Print the path to the MLX CMake module directory.",
    )
    args = parser.parse_args()
    if args.cmake_dir:
        from pathlib import Path

        print(Path(__file__).parent)


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /examples/cmake_project/CMakeLists.txt: --------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.27)

project(example LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Comment out the following two commands if the MLX C++ library is installed
# separately, and set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)

find_package(MLX CONFIG REQUIRED)

add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
-------------------------------------------------------------------------------- /mlx/mlx.h: --------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cuda/cuda.h"
#include "mlx/backend/gpu/available.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/compile.h"
#include "mlx/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "mlx/einsum.h"
#include "mlx/export.h"
#include "mlx/fast.h"
#include "mlx/fft.h"
#include "mlx/io.h"
#include "mlx/linalg.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
#include "mlx/version.h"
-------------------------------------------------------------------------------- /.github/scripts/setup+build-cpp-linux-fedora-container.sh: --------------------------------------------------------------------------------
#!/bin/bash
set -ex

# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
  blas-devel \
  lapack-devel \
  openblas-devel \
  make \
  cmake \
  clang \
  git
dnf clean all

# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++

mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd
-------------------------------------------------------------------------------- /mlx/backend/cpu/arange.h: --------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/cpu/encoder.h"

namespace mlx::core {

namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size, Stream stream) {
  auto ptr = out.data<T>();
  auto step_size = next - start;
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(out);
  encoder.dispatch([ptr, start, step_size, size]() mutable {
    for (int i = 0; i < size; ++i) {
      ptr[i] = start;
      start += step_size;
    }
  });
}

} // namespace

} // namespace mlx::core
-------------------------------------------------------------------------------- /benchmarks/python/comparative/README.md: --------------------------------------------------------------------------------
Microbenchmarks comparing MLX to PyTorch
========================================

Implement the same microbenchmarks in MLX and PyTorch to compare and make a
list of the biggest possible performance improvements and/or regressions.

Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for
instance to measure the time it takes to sum across the 3rd axis of the above
tensor on the cpu.

`compare.py` runs several benchmarks and compares the speed-up or lack thereof
in comparison to PyTorch.

Each bench script can be run with `--print-pid` to print the PID and wait for a
key in order to ease attaching a debugger.
-------------------------------------------------------------------------------- /.github/actions/setup-macos/action.yml: --------------------------------------------------------------------------------
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'

inputs:
  python-version:
    description: 'Python version to use'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Install Homebrew packages
      shell: sh
      run: /opt/homebrew/bin/brew install openmpi

    - name: Verify MetalToolchain installed
      shell: bash
      run: xcodebuild -showComponent MetalToolchain

    - uses: conda-incubator/setup-miniconda@v3
      with:
        miniconda-version: "latest"
        python-version: ${{ inputs.python-version }}
-------------------------------------------------------------------------------- /mlx/device.h: --------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.

#pragma once

namespace mlx::core {

struct Device {
  enum class DeviceType {
    cpu,
    gpu,
  };

  static constexpr DeviceType cpu = DeviceType::cpu;
  static constexpr DeviceType gpu = DeviceType::gpu;

  Device(DeviceType type, int index = 0) : type(type), index(index) {}

  DeviceType type;
  int index;
};

const Device& default_device();

void set_default_device(const Device& d);

bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs);

bool is_available(const Device& d);

} // namespace mlx::core
-------------------------------------------------------------------------------- /docs/src/python/nn/functions.rst: --------------------------------------------------------------------------------
.. _nn_functions:

.. currentmodule:: mlx.nn

Functions
---------

Layers without parameters (e.g. activation functions) are also provided as
simple functions.

.. autosummary::
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   elu
   celu
   gelu
   gelu_approx
   gelu_fast_approx
   glu
   hard_shrink
   hard_tanh
   hardswish
   leaky_relu
   log_sigmoid
   log_softmax
   mish
   prelu
   relu
   relu2
   relu6
   selu
   sigmoid
   silu
   softmax
   softmin
   softplus
   softshrink
   step
   tanh
-------------------------------------------------------------------------------- /python/mlx/nn/layers/containers.py: --------------------------------------------------------------------------------
# Copyright © 2023 Apple Inc.

from mlx.nn.layers.base import Module


class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order
    """

    def __init__(self, *modules):
        super().__init__()
        self.layers = list(modules)

    def __call__(self, x):
        for m in self.layers:
            x = m(x)
        return x
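A minimal usage sketch for ``Sequential`` (the layer sizes are arbitrary), mixing a parameterized module with a plain callable as the docstring describes:

```python
import mlx.core as mx
import mlx.nn as nn

mlp = nn.Sequential(
    nn.Linear(16, 32),
    nn.relu,          # a plain callable with no parameters
    nn.Linear(32, 1),
)
y = mlp(mx.random.uniform(shape=(4, 16)))
```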
-------------------------------------------------------------------------------- /docs/src/usage/using_streams.rst: --------------------------------------------------------------------------------
.. _using_streams:

Using Streams
=============

.. currentmodule:: mlx.core

Specifying the :obj:`Stream`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All operations (including random number generation) take an optional
keyword argument ``stream``. The ``stream`` kwarg specifies which
:obj:`Stream` the operation should run on. If the stream is unspecified then
the operation is run on the default stream of the default device:
``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also
be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is
run on the default stream of the provided device
``mx.default_stream(my_device)``.
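A brief sketch of the conventions described above; the arrays and ops are arbitrary, and everything is kept on the CPU so it runs on any build:

```python
import mlx.core as mx

s1 = mx.default_stream(mx.cpu)
s2 = mx.new_stream(mx.cpu)       # a second, independent CPU stream

a = mx.random.uniform(shape=(512, 512))
b = mx.add(a, 1.0, stream=s1)    # run on an explicit stream
c = mx.square(a, stream=mx.cpu)  # a Device also works: its default stream is used
mx.eval(b, c)
```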
-------------------------------------------------------------------------------- /mlx/backend/common/broadcasting.cpp: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

void broadcast(const array& in, array& out) {
  if (out.size() == 0) {
    out.set_data(allocator::malloc(0));
    return;
  }
  Strides strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

} // namespace mlx::core
-------------------------------------------------------------------------------- /benchmarks/cpp/compare_devices.cpp: --------------------------------------------------------------------------------
// Copyright © 2023 Apple Inc.

#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"

namespace mx = mlx::core;

void time_add_op() {
  std::vector<int> sizes(1, 1);
  for (int i = 0; i < 9; ++i) {
    sizes.push_back(10 * sizes.back());
  }
  set_default_device(mx::Device::cpu);
  for (auto size : sizes) {
    auto a = mx::random::uniform({size});
    auto b = mx::random::uniform({size});
    mx::eval(a, b);
    std::cout << "Size " << size << std::endl;
    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
  }
}

int main() {
  time_add_op();
}
-------------------------------------------------------------------------------- /.github/actions/build-linux/action.yml: --------------------------------------------------------------------------------
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'

runs:
  using: "composite"
  steps:
    - name: Install Python package
      shell: sh
      env:
        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
        DEBUG: 1
      run: pip install --no-build-isolation -e ".[dev]" -v

    - name: Generate package stubs
      shell: sh
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Build CPP only
      shell: bash
      run: |
        mkdir -p build && cd build
        cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
        make -j $(nproc)
-------------------------------------------------------------------------------- /mlx/backend/no_cpu/compiled.cpp: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.

#include "mlx/compile_impl.h"
#include "mlx/primitives.h"

namespace mlx::core {

// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is not available so check if the device is a GPU
// device.
namespace detail {
bool compile_available_for_device(const Device& device) {
  return device == Device::gpu;
}
} // namespace detail

void Compiled::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error(
      "[Compiled::eval_cpu] CPU compilation not supported on the platform.");
}

} // namespace mlx::core
-------------------------------------------------------------------------------- /mlx/backend/common/unary.h: --------------------------------------------------------------------------------
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

inline void set_unary_output_data(
    const array& in,
    array& out,
    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
  if (in.flags().contiguous) {
    if (is_donatable(in, out)) {
      out.copy_shared_buffer(in);
    } else {
      out.set_data(
          mallocfn(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
    }
  } else {
    out.set_data(mallocfn(out.nbytes()));
  }
}

} // namespace mlx::core
-------------------------------------------------------------------------------- /docs/README.md: --------------------------------------------------------------------------------
## Build the Docs

### Setup (do once)

Install Doxygen:

```
brew install doxygen
```

Install Python packages:

```
pip install -r requirements.txt
```

### Build

Build the docs from `mlx/docs/`

```
doxygen && make html
```

View the docs by running a server in `mlx/docs/build/html/`:

```
python -m http.server
```

and point your browser to `http://localhost:<port>`.

### Push to GitHub Pages

Check-out the `gh-pages` branch (`git switch gh-pages`) and build
the docs. Then force add the `build/html` directory:

`git add -f build/html`

Commit and push the changes to the `gh-pages` branch.
-------------------------------------------------------------------------------- /mlx/distributed/mpi/mpi_declarations.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

// Constants

#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE -1
#define MPI_ANY_TAG -1
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256

// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.

typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;

typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);

typedef struct ompi_status_public_t {
  int MPI_SOURCE;
  int MPI_TAG;
  int MPI_ERROR;
  int _cancelled;
  size_t _ucount;
} MPI_Status;
-------------------------------------------------------------------------------- /mlx/backend/metal/binary.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void binary_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s);

void binary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s);

void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s);

void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s);

} // namespace mlx::core
-------------------------------------------------------------------------------- /mlx/backend/metal/kernels/arange.metal: --------------------------------------------------------------------------------
// Copyright © 2023-2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/arange.h"

#define instantiate_arange(tname, type) \
  instantiate_kernel("arange" #tname, arange, type)

instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
instantiate_arange(uint64, uint64_t)
instantiate_arange(int8, int8_t)
instantiate_arange(int16, int16_t)
instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on
-------------------------------------------------------------------------------- /benchmarks/python/rope_bench.py: --------------------------------------------------------------------------------
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(64)

    # vec
    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x, offset=100)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
-------------------------------------------------------------------------------- /mlx/backend/metal/resident.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/device.h"

namespace mlx::core::metal {

class ResidencySet {
 public:
  ResidencySet(MTL::Device* d);
  ~ResidencySet();

  ResidencySet(const ResidencySet&) = delete;
  ResidencySet& operator=(const ResidencySet&) = delete;

  const MTL::ResidencySet* mtl_residency_set() {
    return wired_set_;
  }

  void insert(MTL::Allocation* buf);
  void erase(MTL::Allocation* buf);

  void resize(size_t size);

 private:
  MTL::ResidencySet* wired_set_{nullptr};
  std::unordered_set<const MTL::Allocation*> unwired_set_;
  size_t capacity_{0};
};

} // namespace mlx::core::metal
-------------------------------------------------------------------------------- /mlx/backend/gpu/slicing.h: --------------------------------------------------------------------------------
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void slice_gpu(
    const array& in,
    array& out,
    const Shape& start_indices,
    const Shape& strides,
    const Stream& s);

void concatenate_gpu(
    const std::vector<array>& inputs,
    array& out,
    int axis,
    const Stream& s);

void pad_gpu(
    const array& in,
    const array& val,
    array& out,
    const std::vector<int>& axes,
    const Shape& low_pad_size,
    const Stream& s);

array compute_dynamic_offset(
    const array& indices,
    const Strides& strides,
    const std::vector<int>& axes,
    const Stream& s);

} // namespace mlx::core
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: --------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-yaml
      # - id: end-of-file-fixer
      # - id: trailing-whitespace
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v19.1.7
    hooks:
      - id: clang-format
  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
  - repo: https://github.com/psf/black-pre-commit-mirror
    rev: 25.1.0
    hooks:
      - id: black

  - repo: https://github.com/pycqa/isort
    rev: 6.0.0
    hooks:
      - id: isort
        args:
          - --profile=black
  - repo: https://github.com/cheshirekow/cmake-format-precommit
    rev: v0.6.13
    hooks:
      - id: cmake-format
-name "*.whl" -print -quit) 16 | unzip -q "${repaired_wheel}" 17 | rm "${repaired_wheel}" 18 | mlx_so="mlx/lib/libmlx.so" 19 | rpath=$(patchelf --print-rpath "${mlx_so}") 20 | base="\$ORIGIN/../../nvidia" 21 | rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib 22 | patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" 23 | python ../python/scripts/repair_record.py ${mlx_so} 24 | 25 | # Re-zip the repaired wheel 26 | zip -r -q "../wheelhouse/${repaired_wheel}" . 27 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/indexing/gather_front.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/metal/kernels/indexing/indexing.h" 6 | 7 | template 8 | [[kernel]] void gather_front( 9 | const device T* src, 10 | const device IdxT* indices, 11 | device T* out, 12 | const constant int64_t& stride, 13 | const constant int& size, 14 | uint2 index [[thread_position_in_grid]], 15 | uint2 grid_dim [[threads_per_grid]]) { 16 | auto idx = offset_neg_idx(indices[index.y], size); 17 | LocT src_idx = static_cast(stride) * idx; 18 | LocT out_idx = static_cast(stride) * index.y; 19 | 20 | int s_idx = N * index.x; 21 | for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { 22 | out[out_idx + s_idx] = src[src_idx + s_idx]; 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /docs/src/_templates/module-base-class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: Attributes 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | .. rubric:: Methods 24 | 25 | .. autosummary:: 26 | :toctree: . 27 | {% for item in methods %} 28 | {%- if item not in inherited_members and item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /.github/actions/build-cuda/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Build and Test with CUDA' 2 | description: 'Build and test MLX with CUDA' 3 | 4 | inputs: 5 | toolkit: 6 | description: 'The CUDA toolkit' 7 | required: true 8 | 9 | runs: 10 | using: "composite" 11 | steps: 12 | - name: Install Python package 13 | shell: bash 14 | env: 15 | DEBUG: 1 16 | CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc 17 | run: pip install --no-build-isolation -e ".[dev]" -v 18 | 19 | - name: Build CPP only 20 | shell: bash 21 | run: | 22 | cmake . 
-B build \ 23 | -DMLX_BUILD_CUDA=ON \ 24 | -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \ 25 | -DCMAKE_BUILD_TYPE=DEBUG 26 | cmake --build build -j $(nproc) 27 | -------------------------------------------------------------------------------- /python/scripts/repair_record.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import glob 3 | import hashlib 4 | import sys 5 | 6 | filename = sys.argv[1] 7 | 8 | 9 | # Compute the new hash and size 10 | def urlsafe_b64encode(data: bytes) -> bytes: 11 | return base64.urlsafe_b64encode(data).rstrip(b"=") 12 | 13 | 14 | hasher = hashlib.sha256() 15 | with open(filename, "rb") as f: 16 | data = f.read() 17 | hasher.update(data) 18 | hash_str = urlsafe_b64encode(hasher.digest()).decode("ascii") 19 | size = len(data) 20 | 21 | # Update the record file 22 | record_file = glob.glob("*/RECORD")[0] 23 | with open(record_file, "r") as f: 24 | lines = [l.split(",") for l in f.readlines()] 25 | 26 | for l in lines: 27 | if filename == l[0]: 28 | l[1] = hash_str 29 | l[2] = f"{size}\n" 30 | 31 | with open(record_file, "w") as f: 32 | for l in lines: 33 | f.write(",".join(l)) 34 | -------------------------------------------------------------------------------- /docs/src/python/nn/module.rst: -------------------------------------------------------------------------------- 1 | Module 2 | ====== 3 | 4 | .. currentmodule:: mlx.nn 5 | 6 | .. autoclass:: Module 7 | 8 | .. rubric:: Attributes 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | Module.training 14 | Module.state 15 | 16 | .. rubric:: Methods 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | 21 | Module.apply 22 | Module.apply_to_modules 23 | Module.children 24 | Module.eval 25 | Module.filter_and_map 26 | Module.freeze 27 | Module.leaf_modules 28 | Module.load_weights 29 | Module.modules 30 | Module.named_modules 31 | Module.parameters 32 | Module.save_weights 33 | Module.set_dtype 34 | Module.train 35 | Module.trainable_parameters 36 | Module.unfreeze 37 | Module.update 38 | Module.update_modules 39 | -------------------------------------------------------------------------------- /mlx/backend/metal/make_compiled_preamble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script generates a C++ function that provides the Metal unary and binary 4 | # ops at runtime for use with kernel generation. 5 | # 6 | # Copyright © 2023-24 Apple Inc. 7 | 8 | OUTPUT_DIR=$1 9 | CC=$2 10 | SRC_DIR=$3 11 | SRC_FILE=$4 12 | CFLAGS=$5 13 | SRC_NAME=$(basename -- "${SRC_FILE}") 14 | JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit 15 | INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h 16 | OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp 17 | 18 | mkdir -p "$OUTPUT_DIR" 19 | CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null) 20 | 21 | cat << EOF > "$OUTPUT_FILE" 22 | namespace mlx::core::metal { 23 | 24 | const char* $SRC_NAME() { 25 | return R"preamble( 26 | $CONTENT 27 | )preamble"; 28 | } 29 | 30 | } // namespace mlx::core::metal 31 | EOF 32 | -------------------------------------------------------------------------------- /python/src/mlx_func.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
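// mlx_func wraps a callable for exposure to Python while also collecting the
// raw PyObject* dependencies passed as deps, so the returned nb::callable can
// keep those objects alive (e.g. for GC tracking). Illustrative sketch with
// hypothetical names:
//
//   auto wrapped = mlx_func(
//       [state](nb::object x) { return state["fn"](x); }, orig_fn, state);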
2 | 3 | #pragma once 4 | 5 | #include <vector> 6 | 7 | #include <nanobind/nanobind.h> 8 | #include <nanobind/stl/vector.h> 9 | 10 | namespace nb = nanobind; 11 | using namespace nb::literals; 12 | 13 | nb::callable mlx_func( 14 | nb::object func, 15 | const nb::callable& orig_func, 16 | std::vector<PyObject*> deps); 17 | 18 | template <typename F, typename... Deps> 19 | nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) { 20 | return mlx_func( 21 | nb::cpp_function(std::move(func)), 22 | orig_func, 23 | std::vector<PyObject*>{deps.ptr()...}); 24 | } 25 | 26 | template <typename... Deps> 27 | nb::callable 28 | mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) { 29 | return mlx_func( 30 | std::move(func), orig_func, std::vector<PyObject*>{deps.ptr()...}); 31 | } 32 | -------------------------------------------------------------------------------- /examples/export/README.md: -------------------------------------------------------------------------------- 1 | ## Setup 2 | 3 | Install MLX: 4 | 5 | ```bash 6 | pip install mlx>=0.22 7 | ```` 8 | 9 | Build the C++ examples: 10 | 11 | ```bash 12 | cmake -B build -DCMAKE_BUILD_TYPE=Release 13 | cmake --build build 14 | ``` 15 | 16 | ## Run 17 | 18 | ### Eval MLP 19 | 20 | Run the Python script to export the eval function: 21 | 22 | ```bash 23 | python eval_mlp.py 24 | ``` 25 | 26 | Then run the C++ program to import and run the function: 27 | 28 | ``` 29 | ./build/eval_mlp 30 | ``` 31 | 32 | The Python and C++ programs should output the same result. 33 | 34 | ### Train MLP 35 | 36 | Run the Python script to export the model initialization and training 37 | functions: 38 | 39 | ```bash 40 | python train_mlp.py 41 | ``` 42 | 43 | Then run the C++ program to import and run the functions: 44 | 45 | ``` 46 | ./build/train_mlp 47 | ``` 48 | 49 | The Python and C++ programs should output the same results. 50 | -------------------------------------------------------------------------------- /mlx/distributed/reduction_ops.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | namespace mlx::core::distributed::detail { 4 | 5 | template <typename T> 6 | struct SumOp { 7 | void operator()(const T* input, T* output, size_t N) const { 8 | while (N-- > 0) { 9 | *output += *input; 10 | input++; 11 | output++; 12 | } 13 | } 14 | }; 15 | 16 | template <typename T> 17 | struct MaxOp { 18 | void operator()(const T* input, T* output, size_t N) const { 19 | while (N-- > 0) { 20 | *output = std::max(*output, *input); 21 | input++; 22 | output++; 23 | } 24 | } 25 | }; 26 | 27 | template <typename T> 28 | struct MinOp { 29 | void operator()(const T* input, T* output, size_t N) const { 30 | while (N-- > 0) { 31 | *output = std::min(*output, *input); 32 | input++; 33 | output++; 34 | } 35 | } 36 | }; 37 | 38 | } // namespace mlx::core::distributed::detail 39 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/defines.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc.
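// The instantiate_kernel macro defined at the bottom of this header is the
// entry point the .metal files use to stamp out named template
// instantiations. For example (illustrative), instantiate_arange(float32,
// float) in arange.metal above expands through instantiate_kernel into:
//
//   template [[host_name("arangefloat32")]] [[kernel]]
//   decltype(arange<float>) arange<float>;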
2 | 3 | #pragma once 4 | 5 | #if defined __METAL__ || defined MLX_METAL_JIT 6 | #define MTL_CONST constant 7 | #else 8 | #define MTL_CONST 9 | #endif 10 | 11 | static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; 12 | static MTL_CONST constexpr int REDUCE_N_READS = 4; 13 | static MTL_CONST constexpr int REDUCE_N_WRITES = 4; 14 | static MTL_CONST constexpr int SOFTMAX_N_READS = 4; 15 | static MTL_CONST constexpr int RMS_N_READS = 4; 16 | static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; 17 | 18 | // Instantiate a templated kernel. 19 | // Extra args are used as template parameters: 20 | // e.g. instantiate_kernel(binary_int, binary, a, b) -> 21 | // [[host_name(binary_int)]] [kernel] binary<a, b> 22 | #define instantiate_kernel(name, func, ...) \ 23 | template [[host_name( \ 24 | name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; 25 | -------------------------------------------------------------------------------- /tests_jaccl/build_cpp_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Build the C++ separate FFN TP vs single benchmark. 3 | 4 | set -euo pipefail 5 | 6 | ROOT="$(cd "$(dirname "$0")/.." && pwd)" 7 | BUILD_DIR="$ROOT/build_cpp" 8 | CACHE_DIR="$BUILD_DIR/clang_module_cache" 9 | FAKE_HOME="$BUILD_DIR/fake_home" 10 | mkdir -p "$CACHE_DIR" "$FAKE_HOME" 11 | export CLANG_MODULE_CACHE_PATH="$CACHE_DIR" 12 | export HOME="$FAKE_HOME" 13 | 14 | # Only rerun CMake configure when the build directory is missing a cache. 15 | if [ ! -f "$BUILD_DIR/CMakeCache.txt" ]; then 16 | cmake -B "$BUILD_DIR" -S "$ROOT" \ 17 | -DCMAKE_BUILD_TYPE=RelWithDebInfo \ 18 | -DMLX_BUILD_BENCHMARKS=ON \ 19 | -DCMAKE_CXX_FLAGS="-fmodules-cache-path=$CACHE_DIR" 20 | else 21 | echo "Found existing build at $BUILD_DIR (skipping configure)." 22 | fi 23 | 24 | cmake --build "$BUILD_DIR" --target separate_tp_vs_single 25 | 26 | echo "Benchmark built at: $BUILD_DIR/benchmarks/cpp/separate_tp_vs_single" 27 | -------------------------------------------------------------------------------- /mlx/backend/cpu/threefry.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/backend/cpu/threefry.h" 4 | 5 | namespace mlx::core::random { 6 | 7 | std::pair<uint32_t, uint32_t> threefry2x32_hash( 8 | const std::pair<uint32_t, uint32_t>& key, 9 | std::pair<uint32_t, uint32_t> count) { 10 | constexpr static uint32_t rotations[2][4] = { 11 | {13, 15, 26, 6}, {17, 29, 16, 24}}; 12 | 13 | uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA}; 14 | 15 | count.first += ks[0]; 16 | count.second += ks[1]; 17 | 18 | for (int i = 0; i < 5; ++i) { 19 | for (auto r : rotations[i % 2]) { 20 | count.first += count.second; 21 | count.second = (count.second << r) | (count.second >> (32 - r)); 22 | count.second ^= count.first; 23 | } 24 | count.first += ks[(i + 1) % 3]; 25 | count.second += ks[(i + 2) % 3] + i + 1; 26 | } 27 | 28 | return count; 29 | } 30 | 31 | } // namespace mlx::core::random 32 | -------------------------------------------------------------------------------- /mlx/dtype_utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
2 | 3 | #include "mlx/dtype_utils.h" 4 | 5 | namespace mlx::core { 6 | 7 | const char* dtype_to_string(Dtype arg) { 8 | switch (arg) { 9 | case bool_: 10 | return "bool"; 11 | case int8: 12 | return "int8"; 13 | case int16: 14 | return "int16"; 15 | case int32: 16 | return "int32"; 17 | case int64: 18 | return "int64"; 19 | case uint8: 20 | return "uint8"; 21 | case uint16: 22 | return "uint16"; 23 | case uint32: 24 | return "uint32"; 25 | case uint64: 26 | return "uint64"; 27 | case float16: 28 | return "float16"; 29 | case bfloat16: 30 | return "bfloat16"; 31 | case float32: 32 | return "float32"; 33 | case float64: 34 | return "float64"; 35 | case complex64: 36 | return "complex64"; 37 | default: 38 | return "unknown"; 39 | } 40 | } 41 | 42 | } // namespace mlx::core 43 | -------------------------------------------------------------------------------- /benchmarks/python/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import time 4 | 5 | import mlx.core as mx 6 | 7 | 8 | def time_fn(fn, *args, **kwargs): 9 | msg = kwargs.pop("msg", None) 10 | if msg: 11 | print(f"Timing {msg} ...", end=" ") 12 | else: 13 | print(f"Timing {fn.__name__} ...", end=" ") 14 | 15 | # warmup 16 | for _ in range(5): 17 | mx.eval(fn(*args, **kwargs)) 18 | 19 | num_iters = 100 20 | tic = time.perf_counter() 21 | for _ in range(num_iters): 22 | x = mx.eval(fn(*args, **kwargs)) 23 | toc = time.perf_counter() 24 | 25 | msec = 1e3 * (toc - tic) / num_iters 26 | print(f"{msec:.5f} msec") 27 | 28 | 29 | def measure_runtime(fn, **kwargs): 30 | # Warmup 31 | for _ in range(5): 32 | fn(**kwargs) 33 | 34 | tic = time.time() 35 | iters = 100 36 | for _ in range(iters): 37 | fn(**kwargs) 38 | return (time.time() - tic) * 1000 / iters 39 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/bitwise_binary.cu: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/binary/binary.cuh" 4 | 5 | namespace mlx::core { 6 | void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) { 7 | nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); 8 | auto& s = out.primitive().stream(); 9 | switch (op_) { 10 | case BitwiseBinary::And: 11 | binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s); 12 | break; 13 | case BitwiseBinary::Or: 14 | binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s); 15 | break; 16 | case BitwiseBinary::Xor: 17 | binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s); 18 | break; 19 | case BitwiseBinary::LeftShift: 20 | binary_op_gpu<cu::LeftShift>(inputs, out, name(), s); 21 | break; 22 | case BitwiseBinary::RightShift: 23 | binary_op_gpu<cu::RightShift>(inputs, out, name(), s); 24 | break; 25 | } 26 | } 27 | } // namespace mlx::core 28 | -------------------------------------------------------------------------------- /mlx/backend/cuda/device/indexing.cuh: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
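// Worked example for index_to_dims below (illustrative): with dim1 = 3 and
// dim2 = 4, absolute index 17 decomposes as x = 17 / 12 = 1,
// y = (17 % 12) / 4 = 1, z = 17 % 4 = 1, i.e. cell (1, 1, 1) of the grid.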
2 | 3 | #include <cuda/std/tuple> 4 | #include <cuda/std/type_traits> 5 | 6 | namespace mlx::core::cu { 7 | 8 | // Convert an absolute index to positions in a 3d grid, assuming the index is 9 | // calculated with: 10 | // index = x * dim1 * dim2 + y * dim2 + z 11 | template <typename T> 12 | inline __host__ __device__ cuda::std::tuple<T, T, T> 13 | index_to_dims(T index, T dim1, T dim2) { 14 | T x = index / (dim1 * dim2); 15 | T y = (index % (dim1 * dim2)) / dim2; 16 | T z = index % dim2; 17 | return cuda::std::make_tuple(x, y, z); 18 | } 19 | 20 | // Get absolute index from possible negative index. 21 | template <typename IdxT> 22 | inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { 23 | if constexpr (cuda::std::is_unsigned_v<IdxT>) { 24 | return idx; 25 | } else { 26 | return static_cast<int32_t>(idx < 0 ? idx + size : idx); 27 | } 28 | } 29 | 30 | } // namespace mlx::core::cu 31 | -------------------------------------------------------------------------------- /mlx/backend/cpu/make_compiled_preamble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script generates a C++ function that provides the CPU 4 | # code for use with kernel generation. 5 | # 6 | # Copyright © 2023-24 Apple Inc. 7 | 8 | 9 | OUTPUT_FILE=$1 10 | GCC=$2 11 | SRCDIR=$3 12 | CLANG=$4 13 | ARCH=$5 14 | 15 | if [ "$CLANG" = "TRUE" ]; then 16 | read -r -d '' INCLUDES <<- EOM 17 | #include <cmath> 18 | #include <complex> 19 | #include <cstdint> 20 | #include <vector> 21 | #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC 22 | #include <arm_fp16.h> 23 | #endif 24 | EOM 25 | CC_FLAGS="-arch ${ARCH} -nobuiltininc -nostdinc" 26 | else 27 | CC_FLAGS="-std=c++17" 28 | fi 29 | 30 | CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null) 31 | 32 | cat << EOF > "$OUTPUT_FILE" 33 | const char* get_kernel_preamble() { 34 | return R"preamble( 35 | $INCLUDES 36 | $CONTENT 37 | using namespace mlx::core; 38 | using namespace mlx::core::detail; 39 | )preamble"; 40 | } 41 | EOF 42 | -------------------------------------------------------------------------------- /examples/export/train_mlp.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include <iostream> 4 | #include <mlx/mlx.h> 5 | 6 | namespace mx = mlx::core; 7 | 8 | int main() { 9 | int batch_size = 8; 10 | int input_dim = 32; 11 | int output_dim = 10; 12 | 13 | auto state = mx::import_function("init_mlp.mlxfn")({}); 14 | 15 | // Make the input 16 | mx::random::seed(42); 17 | auto example_X = mx::random::normal({batch_size, input_dim}); 18 | auto example_y = mx::random::randint(0, output_dim, {batch_size}); 19 | 20 | // Import the function 21 | auto step = mx::import_function("train_mlp.mlxfn"); 22 | 23 | // Call the imported function 24 | for (int it = 0; it < 100; ++it) { 25 | state.insert(state.end(), {example_X, example_y}); 26 | state = step(state); 27 | eval(state); 28 | auto loss = state.back(); 29 | state.pop_back(); 30 | if (it % 10 == 0) { 31 | std::cout << "Loss " << loss.item<float>() << std::endl; 32 | } 33 | } 34 | return 0; 35 | } 36 | -------------------------------------------------------------------------------- /tests/device_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include <cstdlib> 6 | 7 | #include "mlx/mlx.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test device placement") { 12 | auto device = default_device(); 13 | Device d = gpu::is_available() ?
Device::gpu : Device::cpu; 14 | if (std::getenv("DEVICE") == nullptr) { 15 | CHECK_EQ(device, d); 16 | } 17 | 18 | array x(1.0f); 19 | array y(1.0f); 20 | auto z = add(x, y, default_device()); 21 | if (gpu::is_available()) { 22 | z = add(x, y, Device::gpu); 23 | z = add(x, y, Device(Device::gpu, 0)); 24 | } else { 25 | CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument); 26 | CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument); 27 | } 28 | 29 | // Set the default device to the CPU 30 | set_default_device(Device::cpu); 31 | CHECK_EQ(default_device(), Device::cpu); 32 | 33 | // Revert 34 | set_default_device(device); 35 | } 36 | -------------------------------------------------------------------------------- /mlx/backend/metal/no_metal.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include <stdexcept> 4 | 5 | #include "mlx/backend/metal/metal.h" 6 | #include "mlx/fast.h" 7 | 8 | namespace mlx::core { 9 | 10 | namespace metal { 11 | 12 | bool is_available() { 13 | return false; 14 | } 15 | 16 | void start_capture(std::string) {} 17 | void stop_capture() {} 18 | 19 | const std::unordered_map<std::string, std::variant<std::string, size_t>>& 20 | device_info() { 21 | throw std::runtime_error( 22 | "[metal::device_info] Cannot get device info without metal backend"); 23 | }; 24 | 25 | } // namespace metal 26 | 27 | namespace fast { 28 | 29 | CustomKernelFunction metal_kernel( 30 | const std::string&, 31 | const std::vector<std::string>&, 32 | const std::vector<std::string>&, 33 | const std::string&, 34 | const std::string&, 35 | bool, 36 | bool) { 37 | throw std::runtime_error("[metal_kernel] No Metal back-end."); 38 | } 39 | 40 | } // namespace fast 41 | 42 | } // namespace mlx::core 43 | -------------------------------------------------------------------------------- /benchmarks/numpy/single_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import numpy as np 4 | from time_utils import time_fn 5 | 6 | 7 | def time_add(): 8 | a = np.ones((100, 100, 10), dtype=np.float32) 9 | b = np.ones((100, 100, 10), dtype=np.float32) 10 | time_fn(np.add, a, b) 11 | 12 | 13 | def time_matmul(): 14 | a = np.random.rand(1000, 500).astype(np.float32) 15 | b = np.random.rand(500, 1000).astype(np.float32) 16 | time_fn(np.matmul, a, b) 17 | 18 | 19 | def time_exp(): 20 | a = np.random.randn(1000, 100).astype(np.float32) 21 | time_fn(np.exp, a) 22 | 23 | 24 | def time_take(): 25 | a = np.random.rand(10000, 500) 26 | ids = np.random.randint(0, 10000, (20, 10)) 27 | ids = [idx.reshape(-1) for idx in np.split(ids, 20)] 28 | 29 | def random_take(): 30 | return [np.take(a, idx, 0) for idx in ids] 31 | 32 | time_fn(random_take) 33 | 34 | 35 | if __name__ == "__main__": 36 | time_add() 37 | time_matmul() 38 | time_exp() 39 | time_take() 40 | -------------------------------------------------------------------------------- /mlx/backend/cuda/device/scatter_ops.cuh: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
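// Each functor below folds one scattered value into *out; scatter kernels
// select the functor as a template parameter. Sketch of use inside a scatter
// loop (hypothetical names):
//
//   Op op;                          // e.g. ScatterSum
//   op(out + out_offset, vals[i]);  // atomic for Sum/Prod/Max/Min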
2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/cuda/device/atomic_ops.cuh" 6 | 7 | namespace mlx::core::cu { 8 | 9 | struct ScatterAssign { 10 | template <typename T> 11 | __device__ void operator()(T* out, T val) const { 12 | *out = val; 13 | } 14 | }; 15 | 16 | struct ScatterSum { 17 | template <typename T> 18 | __device__ void operator()(T* out, T val) const { 19 | atomic_add(out, val); 20 | } 21 | }; 22 | 23 | struct ScatterProd { 24 | template <typename T> 25 | __device__ void operator()(T* out, T val) const { 26 | atomic_prod(out, val); 27 | } 28 | }; 29 | 30 | struct ScatterMax { 31 | template <typename T> 32 | __device__ void operator()(T* out, T val) const { 33 | atomic_max(out, val); 34 | } 35 | }; 36 | 37 | struct ScatterMin { 38 | template <typename T> 39 | __device__ void operator()(T* out, T val) const { 40 | atomic_min(out, val); 41 | } 42 | }; 43 | 44 | } // namespace mlx::core::cu 45 | -------------------------------------------------------------------------------- /mlx/distributed/nccl/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(MLX_BUILD_CUDA) 2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp) 3 | find_package(NCCL) 4 | if(NCCL_FOUND) 5 | target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES}) 6 | target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS}) 7 | else() 8 | message( 9 | STATUS 10 | "NCCL not found, using stubs. To run distributed with NCCL backend, install NCCL." 11 | ) 12 | 13 | include(ExternalProject) 14 | ExternalProject_Add( 15 | nccl_stub 16 | SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/nccl_stub" 17 | BUILD_COMMAND ${CMAKE_COMMAND} --build . 18 | INSTALL_COMMAND "") 19 | set(NCCL_PATH 20 | "${CMAKE_CURRENT_BINARY_DIR}/nccl_stub-prefix/src/nccl_stub-build/") 21 | target_link_libraries(mlx PRIVATE ${NCCL_PATH}/libnccl.so) 22 | target_include_directories(mlx PRIVATE ${NCCL_PATH}) 23 | endif() 24 | else() 25 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp) 26 | endif() 27 | -------------------------------------------------------------------------------- /python/tests/test_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import io 4 | import unittest 5 | 6 | import mlx.core as mx 7 | import mlx_tests 8 | 9 | 10 | class TestGraph(mlx_tests.MLXTestCase): 11 | def test_to_dot(self): 12 | # Simply test that a few cases run. 13 | # Nothing too specific about the graph format 14 | # for now to keep it flexible 15 | a = mx.array(1.0) 16 | f = io.StringIO() 17 | mx.export_to_dot(f, a) 18 | f.seek(0) 19 | self.assertTrue(len(f.read()) > 0) 20 | 21 | b = mx.array(2.0) 22 | c = a + b 23 | f = io.StringIO() 24 | mx.export_to_dot(f, c) 25 | f.seek(0) 26 | self.assertTrue(len(f.read()) > 0) 27 | 28 | # Multi output case 29 | c = mx.divmod(a, b) 30 | f = io.StringIO() 31 | mx.export_to_dot(f, *c) 32 | f.seek(0) 33 | self.assertTrue(len(f.read()) > 0) 34 | 35 | 36 | if __name__ == "__main__": 37 | mlx_tests.MLXTestRunner() 38 | -------------------------------------------------------------------------------- /tests/allocator_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc.
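// These cases exercise the raw allocator interface directly: writing and
// reading through raw_ptr(), a zero-byte allocation, and repeated 1 GiB
// alloc/free cycles to stress large-buffer handling.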
2 | 3 | #include <cstdlib> 4 | 5 | #include "doctest/doctest.h" 6 | 7 | #include "mlx/allocator.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test simple allocations") { 12 | { 13 | auto buffer = allocator::malloc(sizeof(float)); 14 | auto fptr = static_cast<float*>(buffer.raw_ptr()); 15 | *fptr = 0.5f; 16 | CHECK_EQ(*fptr, 0.5f); 17 | allocator::free(buffer); 18 | } 19 | 20 | { 21 | auto buffer = allocator::malloc(128 * sizeof(int)); 22 | int* ptr = static_cast<int*>(buffer.raw_ptr()); 23 | for (int i = 0; i < 128; ++i) { 24 | ptr[i] = i; 25 | } 26 | allocator::free(buffer); 27 | } 28 | 29 | { 30 | auto buffer = allocator::malloc(0); 31 | allocator::free(buffer); 32 | } 33 | } 34 | 35 | TEST_CASE("test large allocations") { 36 | size_t size = 1 << 30; 37 | for (int i = 0; i < 100; ++i) { 38 | auto buffer = allocator::malloc(size); 39 | allocator::free(buffer); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /.github/actions/build-macos-release/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Build macOS release' 2 | description: 'Build MLX releases macOS' 3 | 4 | inputs: 5 | macos-target: 6 | description: 'macOS build target' 7 | required: false 8 | default: '15.0' 9 | build-backend: 10 | description: 'Build the backend mlx-metal package' 11 | type: boolean 12 | required: false 13 | default: false 14 | 15 | runs: 16 | using: "composite" 17 | steps: 18 | - name: Build Python package 19 | shell: bash -l {0} 20 | env: 21 | MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} 22 | run: | 23 | pip install build 24 | python setup.py clean --all 25 | MLX_BUILD_STAGE=1 python -m build -w 26 | 27 | - name: Build backend package 28 | if: ${{ inputs.build-backend }} 29 | shell: bash -l {0} 30 | env: 31 | MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} 32 | run: | 33 | python setup.py clean --all 34 | MLX_BUILD_STAGE=2 python -m build -w 35 | -------------------------------------------------------------------------------- /mlx/backend/cuda/quantized/quantized.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/cuda/device.h" 4 | 5 | namespace mlx::core { 6 | 7 | void affine_quantize( 8 | const array& w, 9 | array& wq, 10 | array& scales, 11 | array& biases, 12 | int group_size_, 13 | int bits_, 14 | cu::CommandEncoder& enc, 15 | const Stream& s); 16 | 17 | void affine_dequantize( 18 | const array& wq, 19 | const array& scales, 20 | const array& biases, 21 | array& w, 22 | int group_size_, 23 | int bits_, 24 | cu::CommandEncoder& enc, 25 | const Stream& s); 26 | 27 | void fp_quantize( 28 | const array& w, 29 | array& wq, 30 | array& scales, 31 | int group_size, 32 | int bits, 33 | cu::CommandEncoder& enc, 34 | const Stream& s); 35 | 36 | void fp_dequantize( 37 | const array& wq, 38 | const array& scales, 39 | array& w, 40 | int group_size, 41 | int bits, 42 | cu::CommandEncoder& enc, 43 | const Stream& s); 44 | 45 | } // namespace mlx::core 46 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/softmax.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc.
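// The "precise" variants instantiated below keep the input type (half or
// bfloat16) but accumulate in float, trading a little bandwidth for better
// numerical behavior on long softmax rows.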
2 | 3 | #include <metal_stdlib> 4 | #include <metal_simdgroup> 5 | 6 | using namespace metal; 7 | 8 | // clang-format off 9 | #include "mlx/backend/metal/kernels/utils.h" 10 | #include "mlx/backend/metal/kernels/softmax.h" 11 | 12 | #define instantiate_softmax(name, itype) \ 13 | instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \ 14 | instantiate_kernel("looped_softmax_" #name, softmax_looped, itype) 15 | 16 | #define instantiate_softmax_precise(name, itype) \ 17 | instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \ 18 | instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float) 19 | 20 | instantiate_softmax(float32, float) 21 | instantiate_softmax(float16, half) 22 | instantiate_softmax(bfloat16, bfloat16_t) 23 | instantiate_softmax_precise(float16, half) 24 | instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on 25 | -------------------------------------------------------------------------------- /mlx/backend/cpu/copy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include <optional> 6 | 7 | #include "mlx/array.h" 8 | #include "mlx/backend/common/copy.h" 9 | #include "mlx/backend/common/utils.h" 10 | 11 | namespace mlx::core { 12 | 13 | void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); 14 | void copy_cpu_inplace( 15 | const array& src, 16 | array& dst, 17 | CopyType ctype, 18 | Stream stream); 19 | 20 | void copy_cpu_inplace( 21 | const array& src, 22 | array& dst, 23 | const Shape& data_shape, 24 | const Strides& i_strides, 25 | const Strides& o_strides, 26 | int64_t i_offset, 27 | int64_t o_offset, 28 | CopyType ctype, 29 | Stream stream, 30 | const std::optional<array>& dynamic_i_offset = std::nullopt, 31 | const std::optional<array>& dynamic_o_offset = std::nullopt); 32 | 33 | // Return a contiguous array with same shape that copies the data of |arr|. 34 | array contiguous_copy_cpu(const array& arr, Stream stream); 35 | 36 | } // namespace mlx::core 37 | -------------------------------------------------------------------------------- /examples/extensions/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include <nanobind/nanobind.h> 4 | #include <nanobind/stl/variant.h> 5 | 6 | #include "axpby/axpby.h" 7 | 8 | namespace nb = nanobind; 9 | using namespace nb::literals; 10 | 11 | NB_MODULE(_ext, m) { 12 | m.doc() = "Sample extension for MLX"; 13 | 14 | m.def( 15 | "axpby", 16 | &my_ext::axpby, 17 | "x"_a, 18 | "y"_a, 19 | "alpha"_a, 20 | "beta"_a, 21 | nb::kw_only(), 22 | "stream"_a = nb::none(), 23 | R"( 24 | Scale and sum two vectors element-wise 25 | ``z = alpha * x + beta * y`` 26 | 27 | Follows numpy style broadcasting between ``x`` and ``y`` 28 | Inputs are upcasted to floats if needed 29 | 30 | Args: 31 | x (array): Input array. 32 | y (array): Input array. 33 | alpha (float): Scaling factor for ``x``. 34 | beta (float): Scaling factor for ``y``. 35 | 36 | Returns: 37 | array: ``alpha * x + beta * y`` 38 | )"); 39 | } 40 | -------------------------------------------------------------------------------- /examples/cpp/metal_capture.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #include <cassert> 4 | #include <iostream> 5 | 6 | #include "mlx/mlx.h" 7 | 8 | namespace mx = mlx::core; 9 | 10 | int main() { 11 | // To use Metal debugging and profiling: 12 | // 1. Build with the MLX_METAL_DEBUG CMake option (i.e.
-DMLX_METAL_DEBUG=ON). 13 | // 2. Run with MTL_CAPTURE_ENABLED=1. 14 | mx::metal::start_capture("mlx_trace.gputrace"); 15 | 16 | // Start at index two because the default GPU and CPU streams have indices 17 | // zero and one, respectively. This naming matches the label assigned to each 18 | // stream's command queue. 19 | auto s2 = new_stream(mx::Device::gpu); 20 | auto s3 = new_stream(mx::Device::gpu); 21 | 22 | auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2); 23 | auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3); 24 | auto x = mx::add(a, a, s2); 25 | auto y = mx::add(b, b, s3); 26 | 27 | // The multiply will happen on the default stream. 28 | std::cout << mx::multiply(x, y) << std::endl; 29 | 30 | mx::metal::stop_capture(); 31 | } 32 | -------------------------------------------------------------------------------- /mlx/stream.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/device.h" 6 | 7 | namespace mlx::core { 8 | 9 | struct Stream { 10 | int index; 11 | Device device; 12 | explicit Stream(int index, Device device) : index(index), device(device) {} 13 | }; 14 | 15 | /** Get the default stream for the given device. */ 16 | Stream default_stream(Device d); 17 | 18 | /** Make the stream the default for its device. */ 19 | void set_default_stream(Stream s); 20 | 21 | /** Make a new stream on the given device. */ 22 | Stream new_stream(Device d); 23 | 24 | /** Get the stream with the given index. */ 25 | Stream get_stream(int index); 26 | 27 | inline bool operator==(const Stream& lhs, const Stream& rhs) { 28 | return lhs.index == rhs.index; 29 | } 30 | 31 | inline bool operator!=(const Stream& lhs, const Stream& rhs) { 32 | return !(lhs == rhs); 33 | } 34 | 35 | /* Synchronize with the default stream. */ 36 | void synchronize(); 37 | 38 | /* Synchronize with the provided stream. */ 39 | void synchronize(Stream); 40 | 41 | } // namespace mlx::core 42 | -------------------------------------------------------------------------------- /benchmarks/cpp/autograd.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
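// Compares two ways to get f(x) and df/dx: calling fn and grad(fn)
// separately (two traversals of the graph) versus value_and_grad(fn), which
// returns both from a single traversal and is expected to be cheaper.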
2 | 3 | #include <iostream> 4 | 5 | #include "mlx/mlx.h" 6 | #include "time_utils.h" 7 | 8 | namespace mx = mlx::core; 9 | 10 | void time_value_and_grad() { 11 | auto x = mx::ones({200, 1000}); 12 | mx::eval(x); 13 | auto fn = [](mx::array x) { 14 | for (int i = 0; i < 20; ++i) { 15 | x = mx::log(mx::exp(x)); 16 | } 17 | return mx::sum(x); 18 | }; 19 | 20 | auto grad_fn = mx::grad(fn); 21 | auto independent_value_and_grad = [&]() { 22 | auto value = fn(x); 23 | auto dfdx = grad_fn(x); 24 | return std::vector<mx::array>{value, dfdx}; 25 | }; 26 | TIME(independent_value_and_grad); 27 | 28 | auto value_and_grad_fn = mx::value_and_grad(fn); 29 | auto combined_value_and_grad = [&]() { 30 | auto [value, dfdx] = value_and_grad_fn(x); 31 | return std::vector<mx::array>{value, dfdx}; 32 | }; 33 | TIME(combined_value_and_grad); 34 | } 35 | 36 | int main() { 37 | std::cout << "Benchmarks for " << mx::default_device() << std::endl; 38 | time_value_and_grad(); 39 | } 40 | -------------------------------------------------------------------------------- /mlx/backend/cuda/binary/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu 4 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu 5 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu 6 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu 7 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu 8 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu 9 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu 10 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu 11 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu 12 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu 13 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu 14 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu 15 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu 16 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu 17 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu 18 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu 19 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu 20 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu 21 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu) 22 | -------------------------------------------------------------------------------- /examples/python/logistic_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc.
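# The loss used below is binary cross-entropy written directly on logits:
# mean(log(1 + exp(z)) - y * z), with log(1 + exp(z)) computed as
# logaddexp(0, z) so that large logits do not overflow.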
2 | 3 | import time 4 | 5 | import mlx.core as mx 6 | 7 | num_features = 100 8 | num_examples = 1_000 9 | num_iters = 10_000 10 | lr = 0.1 11 | 12 | # True parameters 13 | w_star = mx.random.normal((num_features,)) 14 | 15 | # Input examples 16 | X = mx.random.normal((num_examples, num_features)) 17 | 18 | # Labels 19 | y = (X @ w_star) > 0 20 | 21 | 22 | # Initialize random parameters 23 | w = 1e-2 * mx.random.normal((num_features,)) 24 | 25 | 26 | def loss_fn(w): 27 | logits = X @ w 28 | return mx.mean(mx.logaddexp(0.0, logits) - y * logits) 29 | 30 | 31 | grad_fn = mx.grad(loss_fn) 32 | 33 | tic = time.time() 34 | for _ in range(num_iters): 35 | grad = grad_fn(w) 36 | w = w - lr * grad 37 | mx.eval(w) 38 | 39 | toc = time.time() 40 | 41 | loss = loss_fn(w) 42 | final_preds = (X @ w) > 0 43 | acc = mx.mean(final_preds == y) 44 | 45 | throughput = num_iters / (toc - tic) 46 | print( 47 | f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} " 48 | f"Throughput {throughput:.5f} (it/s)" 49 | ) 50 | -------------------------------------------------------------------------------- /mlx/backend/metal/reduce.h: -------------------------------------------------------------------------------- 1 | // Copyright @ 2023 - 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/common/reduce.h" 6 | #include "mlx/backend/metal/device.h" 7 | #include "mlx/stream.h" 8 | 9 | namespace mlx::core { 10 | 11 | using metal::CommandEncoder; 12 | 13 | void all_reduce_dispatch( 14 | const array& in, 15 | array& out, 16 | const std::string& op_name, 17 | CommandEncoder& compute_encoder, 18 | metal::Device& d, 19 | const Stream& s); 20 | 21 | void row_reduce_general_dispatch( 22 | const array& in, 23 | array& out, 24 | const std::string& op_name, 25 | const ReductionPlan& plan, 26 | const std::vector<int>& axes, 27 | CommandEncoder& compute_encoder, 28 | metal::Device& d, 29 | const Stream& s); 30 | 31 | void strided_reduce_general_dispatch( 32 | const array& in, 33 | array& out, 34 | const std::string& op_name, 35 | const ReductionPlan& plan, 36 | const std::vector<int>& axes, 37 | CommandEncoder& compute_encoder, 38 | metal::Device& d, 39 | const Stream& s); 40 | 41 | } // namespace mlx::core 42 | -------------------------------------------------------------------------------- /examples/python/linear_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc.
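# Plain gradient descent on the least-squares objective
# 0.5 * mean((X @ w - y) ** 2); mx.grad derives its gradient,
# X.T @ (X @ w - y) / num_examples, from loss_fn automatically.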
2 | 3 | import time 4 | 5 | import mlx.core as mx 6 | 7 | num_features = 100 8 | num_examples = 1_000 9 | num_iters = 10_000 10 | lr = 0.01 11 | 12 | # True parameters 13 | w_star = mx.random.normal((num_features,)) 14 | 15 | # Input examples (design matrix) 16 | X = mx.random.normal((num_examples, num_features)) 17 | 18 | # Noisy labels 19 | eps = 1e-2 * mx.random.normal((num_examples,)) 20 | y = X @ w_star + eps 21 | 22 | # Initialize random parameters 23 | w = 1e-2 * mx.random.normal((num_features,)) 24 | 25 | 26 | def loss_fn(w): 27 | return 0.5 * mx.mean(mx.square(X @ w - y)) 28 | 29 | 30 | grad_fn = mx.grad(loss_fn) 31 | 32 | tic = time.time() 33 | for _ in range(num_iters): 34 | grad = grad_fn(w) 35 | w = w - lr * grad 36 | mx.eval(w) 37 | toc = time.time() 38 | 39 | loss = loss_fn(w) 40 | error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5 41 | throughput = num_iters / (toc - tic) 42 | 43 | print( 44 | f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, " 45 | f"Throughput {throughput:.5f} (it/s)" 46 | ) 47 | -------------------------------------------------------------------------------- /mlx/io/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp) 2 | 3 | if(MLX_BUILD_SAFETENSORS) 4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp) 5 | else() 6 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp) 7 | endif() 8 | 9 | if(MLX_BUILD_GGUF) 10 | message(STATUS "Downloading gguflib") 11 | FetchContent_Declare( 12 | gguflib 13 | GIT_REPOSITORY https://github.com/antirez/gguf-tools/ 14 | GIT_TAG 8fa6eb65236618e28fd7710a0fba565f7faa1848) 15 | FetchContent_MakeAvailable(gguflib) 16 | target_include_directories(mlx 17 | PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>) 18 | add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c 19 | ${gguflib_SOURCE_DIR}/gguflib.c) 20 | target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>) 21 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp 22 | ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp) 23 | else() 24 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp) 25 | endif() 26 | -------------------------------------------------------------------------------- /.github/actions/build-docs/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Build Documentation' 2 | description: 'Build documentation' 3 | 4 | runs: 5 | using: "composite" 6 | steps: 7 | - name: Setup machine 8 | uses: ./.github/actions/setup-linux 9 | 10 | - name: Install dependencies 11 | shell: bash 12 | run: | 13 | sudo apt-get install -y doxygen 14 | source .venv/bin/activate 15 | pip install -r docs/requirements.txt 16 | pip install .
-v 17 | 18 | - name: Build documentation 19 | shell: bash 20 | run: | 21 | source .venv/bin/activate 22 | cd docs 23 | doxygen 24 | make html O=-W 25 | 26 | - name: Create artifact tar 27 | shell: bash 28 | run: tar -cf artifact.tar -C docs --dereference build/html index.html 29 | 30 | # Do it manually because upload-pages-artifact requires gtar 31 | - name: Upload artifact 32 | id: upload-artifact 33 | uses: actions/upload-artifact@v5 34 | with: 35 | name: github-pages 36 | path: artifact.tar 37 | retention-days: 1 38 | if-no-files-found: error 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/src/python/nn/init.rst: -------------------------------------------------------------------------------- 1 | .. _init: 2 | 3 | .. currentmodule:: mlx.nn.init 4 | 5 | Initializers 6 | ------------ 7 | 8 | The ``mlx.nn.init`` package contains commonly used initializers for neural 9 | network parameters. Initializers return a function which can be applied to any 10 | input :obj:`mlx.core.array` to produce an initialized output. 11 | 12 | For example: 13 | 14 | .. code:: python 15 | 16 | import mlx.core as mx 17 | import mlx.nn as nn 18 | 19 | init_fn = nn.init.uniform() 20 | 21 | # Produces a [2, 2] uniform matrix 22 | param = init_fn(mx.zeros((2, 2))) 23 | 24 | To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a 25 | uniform distribution, you can do: 26 | 27 | .. code:: python 28 | 29 | import mlx.nn as nn 30 | model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5)) 31 | init_fn = nn.init.uniform(low=-0.1, high=0.1) 32 | model.apply(init_fn) 33 | 34 | 35 | .. autosummary:: 36 | :toctree: _autosummary 37 | 38 | constant 39 | normal 40 | uniform 41 | identity 42 | glorot_normal 43 | glorot_uniform 44 | he_normal 45 | he_uniform 46 | -------------------------------------------------------------------------------- /docs/src/python/nn/layers.rst: -------------------------------------------------------------------------------- 1 | .. _layers: 2 | 3 | .. currentmodule:: mlx.nn 4 | 5 | Layers 6 | ------ 7 | 8 | ..
autosummary:: 9 | :toctree: _autosummary 10 | :template: nn-module-template.rst 11 | 12 | ALiBi 13 | AvgPool1d 14 | AvgPool2d 15 | AvgPool3d 16 | BatchNorm 17 | CELU 18 | Conv1d 19 | Conv2d 20 | Conv3d 21 | ConvTranspose1d 22 | ConvTranspose2d 23 | ConvTranspose3d 24 | Dropout 25 | Dropout2d 26 | Dropout3d 27 | Embedding 28 | ELU 29 | GELU 30 | GLU 31 | GroupNorm 32 | GRU 33 | HardShrink 34 | HardTanh 35 | Hardswish 36 | InstanceNorm 37 | LayerNorm 38 | LeakyReLU 39 | Linear 40 | LogSigmoid 41 | LogSoftmax 42 | LSTM 43 | MaxPool1d 44 | MaxPool2d 45 | MaxPool3d 46 | Mish 47 | MultiHeadAttention 48 | PReLU 49 | QuantizedEmbedding 50 | QuantizedLinear 51 | RMSNorm 52 | ReLU 53 | ReLU2 54 | ReLU6 55 | RNN 56 | RoPE 57 | SELU 58 | Sequential 59 | Sigmoid 60 | SiLU 61 | SinusoidalPositionalEncoding 62 | Softmin 63 | Softshrink 64 | Softsign 65 | Softmax 66 | Softplus 67 | Step 68 | Tanh 69 | Transformer 70 | Upsample 71 | -------------------------------------------------------------------------------- /python/src/indexing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include <nanobind/nanobind.h> 6 | 7 | #include "mlx/array.h" 8 | #include "python/src/utils.h" 9 | 10 | namespace mx = mlx::core; 11 | namespace nb = nanobind; 12 | 13 | mx::array mlx_get_item(const mx::array& src, const nb::object& obj); 14 | void mlx_set_item( 15 | mx::array& src, 16 | const nb::object& obj, 17 | const ScalarOrArray& v); 18 | mx::array mlx_add_item( 19 | const mx::array& src, 20 | const nb::object& obj, 21 | const ScalarOrArray& v); 22 | mx::array mlx_subtract_item( 23 | const mx::array& src, 24 | const nb::object& obj, 25 | const ScalarOrArray& v); 26 | mx::array mlx_multiply_item( 27 | const mx::array& src, 28 | const nb::object& obj, 29 | const ScalarOrArray& v); 30 | mx::array mlx_divide_item( 31 | const mx::array& src, 32 | const nb::object& obj, 33 | const ScalarOrArray& v); 34 | mx::array mlx_maximum_item( 35 | const mx::array& src, 36 | const nb::object& obj, 37 | const ScalarOrArray& v); 38 | mx::array mlx_minimum_item( 39 | const mx::array& src, 40 | const nb::object& obj, 41 | const ScalarOrArray& v); 42 | -------------------------------------------------------------------------------- /mlx/backend/cuda/no_cuda.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
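// Stub compiled in when MLX is built without CUDA: availability reports
// false and the fast kernel entry points throw, so callers fail loudly
// rather than silently doing nothing.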
2 | 3 | #include "mlx/backend/cuda/cuda.h" 4 | #include "mlx/fast.h" 5 | 6 | namespace mlx::core { 7 | 8 | namespace cu { 9 | 10 | bool is_available() { 11 | return false; 12 | } 13 | 14 | } // namespace cu 15 | 16 | namespace fast { 17 | 18 | CustomKernelFunction cuda_kernel( 19 | const std::string&, 20 | const std::vector<std::string>&, 21 | const std::vector<std::string>&, 22 | const std::string&, 23 | const std::string&, 24 | bool, 25 | int) { 26 | throw std::runtime_error("[cuda_kernel] No CUDA back-end."); 27 | } 28 | 29 | std::vector<array> precompiled_cuda_kernel( 30 | const std::string&, 31 | const std::string&, 32 | const std::vector<array>&, 33 | const std::vector<Shape>&, 34 | const std::vector<Dtype>&, 35 | const std::vector<char>&, 36 | std::tuple<int, int, int>, 37 | std::tuple<int, int, int>, 38 | int shared_memory, 39 | std::optional<float> init_value, 40 | bool ensure_row_contiguous, 41 | StreamOrDevice) { 42 | throw std::runtime_error("[cuda_kernel] No CUDA back-end."); 43 | } 44 | 45 | } // namespace fast 46 | 47 | } // namespace mlx::core 48 | -------------------------------------------------------------------------------- /tests_jaccl/README.md: -------------------------------------------------------------------------------- 1 | # JACCL Distributed MLP Test 2 | 3 | This directory contains scripts to test Tensor Parallel MLP using the JACCL backend on `m4p.local` and `m3u.local`. 4 | 5 | ## Configuration 6 | 7 | - **Hosts**: 8 | - Rank 0: `m4p.local` (IP: `10.1.12.4`, Device: `rdma_en2`) 9 | - Rank 1: `m3u.local` (IP: `10.1.12.3`, Device: `rdma_en5`) 10 | - **Config File**: `jaccl_config.json` maps ranks to devices. 11 | 12 | ## Files 13 | 14 | - `test_tp_mlp.py`: The distributed MLP test script. 15 | - `jaccl_config.json`: Device configuration. 16 | - `run_jaccl.sh`: Helper script to run the test. 17 | - `deploy.sh`: Helper script to sync files to hosts. 18 | 19 | ## How to Run 20 | 21 | 1. **Deploy files**: 22 | Run `./deploy.sh` from this directory to sync code to both hosts. 23 | 24 | 2. **Run on Rank 0 (m4p.local)**: 25 | ```bash 26 | ssh m4p.local 27 | cd /Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma/tests_jaccl 28 | ./run_jaccl.sh 0 10.1.12.4:1234 29 | ``` 30 | 31 | 3. **Run on Rank 1 (m3u.local)**: 32 | ```bash 33 | ssh m3u.local 34 | cd /Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma/tests_jaccl 35 | ./run_jaccl.sh 1 10.1.12.4:1234 36 | ``` 37 | -------------------------------------------------------------------------------- /mlx/backend/cpu/gemms/simd_fp16.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
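// Batched half-precision GEMM on top of the generic simd_gemm kernel. Each
// batch item resolves its A/B offsets through elem_to_loc so broadcasted
// strides are honored; outputs are assumed densely packed at M * N * i.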
2 | 3 | #include "mlx/backend/common/utils.h" 4 | #include "mlx/backend/cpu/gemm.h" 5 | #include "mlx/backend/cpu/gemms/simd_gemm.h" 6 | 7 | namespace mlx::core { 8 | 9 | template <> 10 | void matmul<float16_t>( 11 | const float16_t* a, 12 | const float16_t* b, 13 | float16_t* out, 14 | bool a_transposed, 15 | bool b_transposed, 16 | size_t lda, 17 | size_t ldb, 18 | size_t ldc, 19 | float alpha, 20 | float beta, 21 | size_t batch_size, 22 | const Shape& a_shape, 23 | const Strides& a_strides, 24 | const Shape& b_shape, 25 | const Strides& b_strides) { 26 | auto ndim = a_shape.size(); 27 | size_t M = a_shape[ndim - 2]; 28 | size_t N = b_shape[ndim - 1]; 29 | size_t K = a_shape[ndim - 1]; 30 | for (int i = 0; i < batch_size; ++i) { 31 | simd_gemm<float16_t, float>( 32 | a + elem_to_loc(M * K * i, a_shape, a_strides), 33 | b + elem_to_loc(K * N * i, b_shape, b_strides), 34 | out + M * N * i, 35 | a_transposed, 36 | b_transposed, 37 | M, 38 | N, 39 | K, 40 | alpha, 41 | beta); 42 | } 43 | } 44 | 45 | } // namespace mlx::core 46 | -------------------------------------------------------------------------------- /mlx/backend/cpu/gemms/simd_bf16.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | #include "mlx/backend/common/utils.h" 4 | #include "mlx/backend/cpu/gemm.h" 5 | #include "mlx/backend/cpu/gemms/simd_gemm.h" 6 | 7 | namespace mlx::core { 8 | 9 | template <> 10 | void matmul<bfloat16_t>( 11 | const bfloat16_t* a, 12 | const bfloat16_t* b, 13 | bfloat16_t* out, 14 | bool a_transposed, 15 | bool b_transposed, 16 | size_t lda, 17 | size_t ldb, 18 | size_t ldc, 19 | float alpha, 20 | float beta, 21 | size_t batch_size, 22 | const Shape& a_shape, 23 | const Strides& a_strides, 24 | const Shape& b_shape, 25 | const Strides& b_strides) { 26 | auto ndim = a_shape.size(); 27 | size_t M = a_shape[ndim - 2]; 28 | size_t N = b_shape[ndim - 1]; 29 | size_t K = a_shape[ndim - 1]; 30 | for (int i = 0; i < batch_size; ++i) { 31 | simd_gemm<bfloat16_t, float>( 32 | a + elem_to_loc(M * K * i, a_shape, a_strides), 33 | b + elem_to_loc(K * N * i, b_shape, b_strides), 34 | out + M * N * i, 35 | a_transposed, 36 | b_transposed, 37 | M, 38 | N, 39 | K, 40 | alpha, 41 | beta); 42 | } 43 | } 44 | 45 | } // namespace mlx::core 46 | -------------------------------------------------------------------------------- /.github/actions/build-linux-release/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Build Linux wheel' 2 | description: 'Build Linux wheel' 3 | 4 | inputs: 5 | build-backend: 6 | description: 'Build the backend mlx-cpu package' 7 | type: boolean 8 | required: false 9 | default: false 10 | arch: 11 | description: 'Platform architecture tag' 12 | required: true 13 | type: choice 14 | options: 15 | - x86_64 16 | - aarch64 17 | 18 | runs: 19 | using: "composite" 20 | steps: 21 | - name: Generate package stubs 22 | shell: bash 23 | run: | 24 | pip install -e ".[dev]" -v 25 | pip install typing_extensions 26 | python setup.py generate_stubs 27 | - name: Build Python package 28 | shell: bash 29 | run: | 30 | pip install auditwheel patchelf build 31 | python setup.py clean --all 32 | MLX_BUILD_STAGE=1 python -m build -w 33 | bash python/scripts/repair_linux.sh ${{ inputs.arch }} 34 | - name: Build backend package 35 | if: ${{ inputs.build-backend }} 36 | shell: bash 37 | run: | 38 | python setup.py clean --all 39 | MLX_BUILD_STAGE=2 python -m build -w 40 | auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
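# The two MLX_BUILD_STAGE passes above split the release into a pure-Python
# mlx wheel (stage 1) and a platform backend wheel (stage 2), each repaired
# with auditwheel for the manylinux_2_35 platform tag.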
-------------------------------------------------------------------------------- /mlx/backend/cuda/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | // This file includes utilities that are used by C++ code (i.e. .cpp files). 4 | 5 | #pragma once 6 | 7 | #include "mlx/array.h" 8 | #include "mlx/backend/cuda/allocator.h" 9 | #include "mlx/backend/cuda/cuda_utils.h" 10 | 11 | namespace mlx::core { 12 | 13 | template <typename T> 14 | inline uint max_occupancy_block_dim(T kernel) { 15 | int _, block_dim; 16 | if constexpr (std::is_same_v<T, CUfunction>) { 17 | CHECK_CUDA_ERROR( 18 | cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); 19 | } else { 20 | CHECK_CUDA_ERROR( 21 | cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); 22 | } 23 | return block_dim; 24 | } 25 | 26 | template <typename T> 27 | inline T* gpu_ptr(array& arr) { 28 | return reinterpret_cast<T*>( 29 | static_cast<char*>( 30 | static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) + 31 | arr.offset()); 32 | } 33 | 34 | template <typename T> 35 | inline const T* gpu_ptr(const array& arr) { 36 | return gpu_ptr<T>(const_cast<array&>(arr)); 37 | } 38 | 39 | struct Dtype; 40 | 41 | // Convert Dtype to CUDA C++ types. 42 | const char* dtype_to_cuda_type(const Dtype& dtype); 43 | 44 | } // namespace mlx::core 45 | -------------------------------------------------------------------------------- /mlx/io/no_safetensors.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023-2024 Apple Inc. 2 | 3 | #include "mlx/io.h" 4 | 5 | namespace mlx::core { 6 | 7 | SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOrDevice) { 8 | throw std::runtime_error( 9 | "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 10 | "to enable safetensors support."); 11 | } 12 | 13 | SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) { 14 | throw std::runtime_error( 15 | "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 16 | "to enable safetensors support."); 17 | } 18 | 19 | void save_safetensors( 20 | std::shared_ptr<io::Writer>, 21 | std::unordered_map<std::string, array>, 22 | std::unordered_map<std::string, std::string>) { 23 | throw std::runtime_error( 24 | "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 25 | "to enable safetensors support."); 26 | } 27 | 28 | void save_safetensors( 29 | std::string file, 30 | std::unordered_map<std::string, array>, 31 | std::unordered_map<std::string, std::string>) { 32 | throw std::runtime_error( 33 | "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON " 34 | "to enable safetensors support."); 35 | } 36 | 37 | } // namespace mlx::core 38 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/indexing/masked_scatter.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
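// Copies src values into out wherever mask is set. scatter_offsets maps each
// masked output position to its index within the current batch of src, and
// the src_contiguous template flag picks between a direct read and a strided
// elem_to_loc lookup.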
2 | 
3 | #pragma once
4 | 
5 | template <typename T, bool src_contiguous>
6 | [[kernel]] void masked_assign_impl(
7 |     const device bool* mask [[buffer(0)]],
8 |     const device uint* scatter_offsets [[buffer(1)]],
9 |     const device T* src [[buffer(2)]],
10 |     device T* out [[buffer(3)]],
11 |     const constant int* src_shapes [[buffer(4)]],
12 |     const constant int64_t* src_strides [[buffer(5)]],
13 |     const constant int& src_ndim [[buffer(6)]],
14 |     const constant int64_t& src_batch_size [[buffer(7)]],
15 |     const constant int64_t& mask_batch_size [[buffer(8)]],
16 |     uint idx [[thread_position_in_grid]]) {
17 |   const bool mask_value = mask[idx];
18 |   if (!mask_value) {
19 |     return;
20 |   }
21 | 
22 |   const uint src_index = scatter_offsets[idx];
23 |   if (src_index >= src_batch_size) {
24 |     return;
25 |   }
26 | 
27 |   const uint batch_idx = idx / mask_batch_size;
28 | 
29 |   if (src_contiguous) {
30 |     out[idx] = src[batch_idx * src_batch_size + src_index];
31 |   } else {
32 |     out[idx] = src[elem_to_loc(
33 |         batch_idx * src_batch_size + src_index,
34 |         src_shapes,
35 |         src_strides,
36 |         src_ndim)];
37 |   }
38 | }
39 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include <metal_stdlib>
6 | 
7 | METAL_FUNC ulong2 elem_to_loc_broadcast(
8 |     uint elem,
9 |     constant const int* shape,
10 |     constant const int64_t* a_strides,
11 |     constant const int64_t* b_strides,
12 |     int ndim) {
13 |   ulong loc_a{0};
14 |   ulong loc_b{0};
15 |   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
16 |     int pos_in_dim = (elem % shape[i]);
17 |     elem /= shape[i];
18 |     loc_a += pos_in_dim * a_strides[i];
19 |     loc_b += pos_in_dim * b_strides[i];
20 |   }
21 |   return ulong2(loc_a, loc_b);
22 | }
23 | 
24 | METAL_FUNC ulong3 elem_to_loc_broadcast(
25 |     uint elem,
26 |     constant const int* shape,
27 |     constant const int64_t* a_strides,
28 |     constant const int64_t* b_strides,
29 |     constant const int64_t* c_strides,
30 |     int ndim) {
31 |   ulong loc_a{0};
32 |   ulong loc_b{0};
33 |   ulong loc_c{0};
34 |   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
35 |     int pos_in_dim = (elem % shape[i]);
36 |     elem /= shape[i];
37 |     loc_a += pos_in_dim * a_strides[i];
38 |     loc_b += pos_in_dim * b_strides[i];
39 |     loc_c += pos_in_dim * c_strides[i];
40 |   }
41 |   return ulong3(loc_a, loc_b, loc_c);
42 | }
43 | 
--------------------------------------------------------------------------------
/mlx/device.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 | 
3 | #include <stdexcept>
4 | 
5 | #include "mlx/backend/cpu/available.h"
6 | #include "mlx/backend/gpu/available.h"
7 | #include "mlx/device.h"
8 | 
9 | namespace mlx::core {
10 | 
11 | Device& mutable_default_device() {
12 |   static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
13 |   return default_device;
14 | }
15 | 
16 | const Device& default_device() {
17 |   return mutable_default_device();
18 | }
19 | 
20 | void set_default_device(const Device& d) {
21 |   if (!gpu::is_available() && d == Device::gpu) {
22 |     throw std::invalid_argument(
23 |         "[set_default_device] Cannot set gpu device without gpu backend.");
24 |   }
25 |   mutable_default_device() = d;
26 | }
27 | 
28 | bool operator==(const Device& lhs, const Device& rhs) {
29 |   return lhs.type == rhs.type && lhs.index == rhs.index;
30 | }
31 | 
32 | bool operator!=(const Device& lhs, const Device& rhs) {
33 |   return !(lhs == rhs);
34 | }
35 | 
36 | bool is_available(const Device& d) {
37 |   switch (d.type) {
38 |     case Device::cpu:
39 |       return cpu::is_available();
40 |     case Device::gpu:
41 |       return gpu::is_available();
42 |   }
43 |   // appease compiler
44 |   return false;
45 | }
46 | 
47 | } // namespace mlx::core
48 | 
--------------------------------------------------------------------------------
/docs/src/python/array.rst:
--------------------------------------------------------------------------------
1 | .. _array:
2 | 
3 | Array
4 | =====
5 | 
6 | .. currentmodule:: mlx.core
7 | 
8 | .. autosummary::
9 |    :toctree: _autosummary
10 | 
11 |    array
12 |    array.astype
13 |    array.at
14 |    array.item
15 |    array.tolist
16 |    array.dtype
17 |    array.itemsize
18 |    array.nbytes
19 |    array.ndim
20 |    array.shape
21 |    array.size
22 |    array.real
23 |    array.imag
24 |    array.abs
25 |    array.all
26 |    array.any
27 |    array.argmax
28 |    array.argmin
29 |    array.conj
30 |    array.cos
31 |    array.cummax
32 |    array.cummin
33 |    array.cumprod
34 |    array.cumsum
35 |    array.diag
36 |    array.diagonal
37 |    array.exp
38 |    array.flatten
39 |    array.log
40 |    array.log10
41 |    array.log1p
42 |    array.log2
43 |    array.logcumsumexp
44 |    array.logsumexp
45 |    array.max
46 |    array.mean
47 |    array.min
48 |    array.moveaxis
49 |    array.prod
50 |    array.reciprocal
51 |    array.reshape
52 |    array.round
53 |    array.rsqrt
54 |    array.sin
55 |    array.split
56 |    array.sqrt
57 |    array.square
58 |    array.squeeze
59 |    array.std
60 |    array.sum
61 |    array.swapaxes
62 |    array.transpose
63 |    array.T
64 |    array.var
65 |    array.view
66 | 
--------------------------------------------------------------------------------
/docs/src/python/random.rst:
--------------------------------------------------------------------------------
1 | .. _random:
2 | 
3 | Random
4 | ======
5 | 
6 | Random sampling functions in MLX use an implicit global PRNG state by default.
7 | However, all functions take an optional ``key`` keyword argument for when more
8 | fine-grained control or explicit state management is needed.
9 | 
10 | For example, you can generate random numbers with:
11 | 
12 | .. code-block:: python
13 | 
14 |   for _ in range(3):
15 |     print(mx.random.uniform())
16 | 
17 | which will print a sequence of unique pseudo random numbers. Alternatively you
18 | can explicitly set the key:
19 | 
20 | .. code-block:: python
21 | 
22 |   key = mx.random.key(0)
23 |   for _ in range(3):
24 |     print(mx.random.uniform(key=key))
25 | 
26 | which will yield the same pseudo random number at each iteration.
27 | 
28 | Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
29 | we use a splittable version of Threefry, which is a counter-based PRNG.
30 | 
31 | .. currentmodule:: mlx.core.random
32 | 
33 | .. autosummary::
34 |    :toctree: _autosummary
35 | 
36 |    bernoulli
37 |    categorical
38 |    gumbel
39 |    key
40 |    normal
41 |    multivariate_normal
42 |    randint
43 |    seed
44 |    split
45 |    truncated_normal
46 |    uniform
47 |    laplace
48 |    permutation
49 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/jit/includes.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | namespace mlx::core::metal {
6 | 
7 | const char* utils();
8 | const char* binary_ops();
9 | const char* unary_ops();
10 | const char* ternary_ops();
11 | const char* reduce_utils();
12 | const char* gather();
13 | const char* scatter();
14 | const char* masked_scatter();
15 | 
16 | const char* arange();
17 | const char* unary();
18 | const char* binary();
19 | const char* binary_two();
20 | const char* copy();
21 | const char* fft();
22 | const char* gather_axis();
23 | const char* gather_front();
24 | const char* hadamard();
25 | const char* logsumexp();
26 | const char* quantized_utils();
27 | const char* quantized();
28 | const char* fp_quantized();
29 | const char* ternary();
30 | const char* scan();
31 | const char* scatter_axis();
32 | const char* softmax();
33 | const char* sort();
34 | const char* reduce();
35 | 
36 | const char* gemm();
37 | const char* steel_gemm_fused();
38 | const char* steel_gemm_masked();
39 | const char* steel_gemm_splitk();
40 | const char* steel_gemm_gather();
41 | const char* steel_gemm_segmented();
42 | const char* conv();
43 | const char* steel_conv();
44 | const char* steel_conv_general();
45 | const char* gemv_masked();
46 | 
47 | } // namespace mlx::core::metal
48 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/fp4.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | constexpr constant static float FP4_LUT[16] = {
4 |     +0.0f,
5 |     +0.5f,
6 |     +1.0f,
7 |     +1.5f,
8 |     +2.0f,
9 |     +3.0f,
10 |     +4.0f,
11 |     +6.0f,
12 |     -0.0f,
13 |     -0.5f,
14 |     -1.0f,
15 |     -1.5f,
16 |     -2.0f,
17 |     -3.0f,
18 |     -4.0f,
19 |     -6.0f};
20 | 
21 | struct fp4_e2m1 {
22 |   fp4_e2m1(float x) {
23 |     if (metal::isnan(x)) {
24 |       bits = 0x7;
25 |       return;
26 |     }
27 | 
28 |     const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
29 |     x = metal::abs(x);
30 | 
31 |     if (x > 5.0f) {
32 |       bits = 0x7;
33 |     } else if (x >= 3.5f) {
34 |       bits = 0x6;
35 |     } else if (x > 2.5f) {
36 |       bits = 0x5;
37 |     } else if (x >= 1.75f) {
38 |       bits = 0x4;
39 |     } else if (x > 1.25f) {
40 |       bits = 0x3;
41 |     } else if (x >= 0.75f) {
42 |       bits = 0x2;
43 |     } else if (x > 0.25f) {
44 |       bits = 0x1;
45 |     } else {
46 |       bits = 0x0;
47 |     }
48 |     bits |= sign_bit;
49 |   }
50 | 
51 |   operator float() {
52 |     half converted = as_type<half>(ushort((bits & 7) << 9));
53 |     converted *= 16384.0;
54 |     converted = bits & 8 ? -converted : converted;
55 |     return converted;
56 |   }
57 | 
58 |   uint8_t bits;
59 | };
60 | 
--------------------------------------------------------------------------------
/mlx/fence.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | 
3 | #include <memory>
4 | 
5 | #include "mlx/array.h"
6 | 
7 | namespace mlx::core {
8 | 
9 | /* A fence to be used for synchronizing work between streams.
10 |  *
11 |  * Calls to `wait` wait in the given stream until all previous calls to update
12 |  * are complete on their given stream.
13 |  *
14 |  * The array passed to `update` is computed and visible after the call to
15 |  * `wait` returns. The array passed to `wait` will not be read until all
16 |  * previous calls to `update` have completed.
17 |  *
18 |  * Note, calls to `update` should always be from the same thread or explicitly
19 |  * synchronized so that they occur in sequence. Calls to `wait` can be on any
20 |  * thread.
21 |  *
22 |  * For the Metal back-end the fence supports slow (default) and fast mode.
23 |  * Fast mode requires setting the environment variable
24 |  * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+,
25 |  * iOS 18+).
26 |  */
27 | class Fence {
28 |  public:
29 |   Fence() {};
30 |   explicit Fence(Stream stream);
31 | 
32 |   void update(Stream stream, const array& x, bool cross_device);
33 |   void wait(Stream stream, const array& x);
34 | 
35 |  private:
36 |   std::shared_ptr<void> fence_{nullptr};
37 | };
38 | 
39 | } // namespace mlx::core
40 | 
--------------------------------------------------------------------------------
/mlx/backend/cpu/eval.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | #include "mlx/backend/cpu/eval.h"
3 | #include "mlx/backend/cpu/encoder.h"
4 | #include "mlx/primitives.h"
5 | #include "mlx/scheduler.h"
6 | #include "mlx/utils.h"
7 | 
8 | namespace mlx::core::cpu {
9 | 
10 | void eval(array& arr) {
11 |   auto s = arr.primitive().stream();
12 | 
13 |   auto outputs = arr.outputs();
14 |   {
15 |     // If the array is a tracer hold a reference
16 |     // to its inputs so they don't get donated
17 |     std::vector<array> inputs;
18 |     if (arr.is_tracer()) {
19 |       inputs = arr.inputs();
20 |     }
21 |     arr.primitive().eval_cpu(arr.inputs(), outputs);
22 |   }
23 | 
24 |   std::unordered_set<std::shared_ptr<array::Data>> buffers;
25 |   for (auto& in : arr.inputs()) {
26 |     buffers.insert(in.data_shared_ptr());
27 |   }
28 |   for (auto& s : arr.siblings()) {
29 |     buffers.insert(s.data_shared_ptr());
30 |   }
31 |   // Remove the output if it was donated to by an input
32 |   if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
33 |     buffers.erase(it);
34 |   }
35 |   auto& encoder = cpu::get_command_encoder(s);
36 |   encoder.dispatch([buffers = std::move(buffers),
37 |                     temps = std::move(encoder.temporaries())]() {});
38 | }
39 | 
40 | } // namespace mlx::core::cpu
41 | 
--------------------------------------------------------------------------------
/mlx/backend/cpu/make_compiled_preamble.ps1:
--------------------------------------------------------------------------------
1 | # This script generates a C++ function that provides the CPU
2 | # code for use with kernel generation.
3 | #
4 | # Copyright © 2024 Apple Inc.
5 | 
6 | $OUTPUT_FILE = $args[0]
7 | $CL = $args[1]
8 | $SRCDIR = $args[2]
9 | 
10 | # Get command result as array.
11 | $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h"
12 | # Remove empty lines.
13 | # Otherwise there will be too many empty lines making the result unreadable.
14 | $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
15 | # Concatenate to string.
16 | $CONTENT = $CONTENT -join "`n"
17 | 
18 | # Append extra content.
19 | $CONTENT = @"
20 | $($CONTENT)
21 | using namespace mlx::core;
22 | using namespace mlx::core::detail;
23 | "@
24 | 
25 | # Convert each char to ASCII code.
26 | # Unlike the unix script that outputs string literal directly, the output from
27 | # MSVC is way too large to be embedded as string and compilation will fail, so
28 | # we store it as static array instead.
29 | $CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
30 | 
31 | $OUTPUT = @"
32 | const char* get_kernel_preamble() {
33 |   static char preamble[] = { $CHARCODES };
34 |   return preamble;
35 | }
36 | "@
37 | 
38 | Set-Content -Path $OUTPUT_FILE -Value $OUTPUT
39 | 
--------------------------------------------------------------------------------
/python/mlx/_stub_patterns.txt:
--------------------------------------------------------------------------------
1 | mlx.core.__prefix__:
2 |     from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
3 |     import sys
4 |     if sys.version_info >= (3, 10):
5 |         from typing import TypeAlias
6 |     else:
7 |         from typing_extensions import TypeAlias
8 | 
9 | mlx.core.__suffix__:
10 |     from typing import Union
11 |     scalar: TypeAlias = Union[int, float, bool]
12 |     list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
13 |     bool_: Dtype = ...
14 | 
15 | mlx.core.distributed.__prefix__:
16 |     from mlx.core import array, Dtype, Device, Stream, scalar
17 |     from mlx.core.distributed import Group
18 |     from typing import Sequence, Optional, Union
19 | 
20 | mlx.core.fast.__prefix__:
21 |     from mlx.core import array, Dtype, Device, Stream, scalar
22 |     from typing import Sequence, Optional, Union
23 | 
24 | mlx.core.linalg.__prefix__:
25 |     from mlx.core import array, Dtype, Device, Stream, scalar
26 |     from typing import Sequence, Optional, Tuple, Union
27 | 
28 | mlx.core.metal.__prefix__:
29 |     from mlx.core import array, Dtype, Device, Stream, scalar
30 |     from typing import Sequence, Optional, Union
31 | 
32 | mlx.core.random.__prefix__:
33 |     from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
34 |     from typing import Sequence, Optional, Union
35 | 
--------------------------------------------------------------------------------
/mlx/backend/gpu/slicing.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | 
3 | #include "mlx/backend/common/slicing.h"
4 | #include "mlx/backend/gpu/copy.h"
5 | #include "mlx/backend/gpu/slicing.h"
6 | 
7 | namespace mlx::core {
8 | 
9 | void slice_gpu(
10 |     const array& in,
11 |     array& out,
12 |     const Shape& start_indices,
13 |     const Shape& strides,
14 |     const Stream&) {
15 |   slice(in, out, start_indices, strides);
16 | }
17 | 
18 | void pad_gpu(
19 |     const array& in,
20 |     const array& val,
21 |     array& out,
22 |     const std::vector<int>& axes,
23 |     const Shape& low_pad_size,
24 |     const Stream& s) {
25 |   // Fill output with val
26 |   fill_gpu(val, out, s);
27 | 
28 |   // Find offset for start of input values
29 |   size_t data_offset = 0;
30 |   for (int i = 0; i < axes.size(); i++) {
31 |     auto ax = axes[i] < 0 ?
out.ndim() + axes[i] : axes[i]; 32 | data_offset += out.strides()[ax] * low_pad_size[i]; 33 | } 34 | 35 | // Extract slice from output where input will be pasted 36 | array out_slice(in.shape(), out.dtype(), nullptr, {}); 37 | out_slice.copy_shared_buffer( 38 | out, out.strides(), out.flags(), out_slice.size(), data_offset); 39 | 40 | // Copy input values into the slice 41 | copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s); 42 | } 43 | 44 | } // namespace mlx::core 45 | -------------------------------------------------------------------------------- /benchmarks/python/synchronize_bench.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import mlx.core as mx 4 | 5 | rank = mx.distributed.init().rank() 6 | 7 | 8 | def timeit(fn, a): 9 | 10 | # warmup 11 | for _ in range(5): 12 | mx.eval(fn(a)) 13 | 14 | its = 10 15 | tic = time.perf_counter() 16 | for _ in range(its): 17 | mx.eval(fn(a)) 18 | toc = time.perf_counter() 19 | ms = 1000 * (toc - tic) / its 20 | return ms 21 | 22 | 23 | def all_reduce_benchmark(): 24 | a = mx.ones((5, 5), mx.int32) 25 | 26 | its_per_eval = 100 27 | 28 | def fn(x): 29 | for _ in range(its_per_eval): 30 | x = mx.distributed.all_sum(x) 31 | x = x - 1 32 | return x 33 | 34 | ms = timeit(fn, a) / its_per_eval 35 | if rank == 0: 36 | print(f"All Reduce: time per iteration {ms:.6f} (ms)") 37 | 38 | 39 | def all_gather_benchmark(): 40 | a = mx.ones((5, 5), mx.int32) 41 | its_per_eval = 100 42 | 43 | def fn(x): 44 | for _ in range(its_per_eval): 45 | x = mx.distributed.all_gather(x)[0] 46 | return x 47 | 48 | ms = timeit(fn, a) / its_per_eval 49 | if rank == 0: 50 | print(f"All gather: time per iteration {ms:.6f} (ms)") 51 | 52 | 53 | if __name__ == "__main__": 54 | all_reduce_benchmark() 55 | all_gather_benchmark() 56 | -------------------------------------------------------------------------------- /create_library_symlinks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Create symlinks for RDMA libraries from CommandLineTools SDK to Xcode-beta SDK 3 | # This is needed because Xcode-beta SDK is missing these library files 4 | 5 | set -e 6 | 7 | echo "=== Creating RDMA Library Symlinks ===" 8 | echo "" 9 | echo "This will create symlinks in Xcode-beta SDK pointing to CommandLineTools SDK" 10 | echo "Requires sudo access" 11 | echo "" 12 | 13 | # Create librdma.tbd symlink 14 | echo "Creating librdma.tbd symlink..." 15 | sudo ln -sf \ 16 | /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib/librdma.tbd \ 17 | /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX26.0.sdk/usr/lib/librdma.tbd 18 | 19 | echo "✓ Created: librdma.tbd" 20 | 21 | # Create rdma directory symlink 22 | echo "Creating rdma directory symlink..." 
23 | sudo ln -sf \
24 |   /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib/rdma \
25 |   /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX26.0.sdk/usr/lib/rdma
26 | 
27 | echo "✓ Created: rdma/"
28 | 
29 | echo ""
30 | echo "=== Symlinks Created Successfully ==="
31 | echo ""
32 | echo "Verify with:"
33 | echo "  ls -la /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX26.0.sdk/usr/lib/ | grep rdma"
34 | echo ""
35 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal:
--------------------------------------------------------------------------------
1 | // Copyright © 2024-25 Apple Inc.
2 | 
3 | // clang-format off
4 | #include "mlx/backend/metal/kernels/utils.h"
5 | 
6 | #include "mlx/backend/metal/kernels/steel/attn/attn.h"
7 | #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
8 | 
9 | #define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
10 |   instantiate_kernel( \
11 |       "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
12 |       "_wm" #wm "_wn" #wn "_mask" #mname, \
13 |       attention, dtype, bq, bk, bd, wm, wn, mtype, float)
14 | 
15 | #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
16 |   instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
17 |   instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
18 |   instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
19 | 
20 | #define instantiate_attn_mask_helper(iname, itype) \
21 |   instantiate_attn_shapes_helper(iname, itype, iname, itype) \
22 |   instantiate_attn_shapes_helper(iname, itype, bool_, bool)
23 | 
24 | instantiate_attn_mask_helper(float16, half);
25 | instantiate_attn_mask_helper(bfloat16, bfloat16_t);
26 | 
27 | instantiate_attn_mask_helper(float32, float);
28 | // clang-format on
29 | 
--------------------------------------------------------------------------------
/mlx/allocator.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include <cstdlib>
6 | 
7 | namespace mlx::core::allocator {
8 | 
9 | // Simple wrapper around buffer pointers
10 | // WARNING: Only Buffer objects constructed from and those that wrap
11 | //          raw pointers from mlx::allocator are supported.
12 | class Buffer {
13 |  private:
14 |   void* ptr_;
15 | 
16 |  public:
17 |   explicit Buffer(void* ptr) : ptr_(ptr) {};
18 | 
19 |   // Get the raw data pointer from the buffer
20 |   void* raw_ptr();
21 | 
22 |   // Get the buffer pointer from the buffer
23 |   const void* ptr() const {
24 |     return ptr_;
25 |   };
26 |   void* ptr() {
27 |     return ptr_;
28 |   };
29 | };
30 | 
31 | Buffer malloc(size_t size);
32 | 
33 | void free(Buffer buffer);
34 | 
35 | class Allocator {
36 |   /** Abstract base class for a memory allocator. */
37 |  public:
38 |   virtual Buffer malloc(size_t size) = 0;
39 |   virtual void free(Buffer buffer) = 0;
40 |   virtual size_t size(Buffer buffer) const = 0;
41 | 
42 |   Allocator() = default;
43 |   Allocator(const Allocator& other) = delete;
44 |   Allocator(Allocator&& other) = delete;
45 |   Allocator& operator=(const Allocator& other) = delete;
46 |   Allocator& operator=(Allocator&& other) = delete;
47 |   virtual ~Allocator() = default;
48 | };
49 | 
50 | Allocator& allocator();
51 | 
52 | } // namespace mlx::core::allocator
53 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/distributed.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | 
3 | #include <cassert>
4 | 
5 | #include "mlx/allocator.h"
6 | #include "mlx/backend/common/utils.h"
7 | #include "mlx/backend/gpu/copy.h"
8 | #include "mlx/backend/metal/device.h"
9 | #include "mlx/backend/metal/utils.h"
10 | #include "mlx/distributed/ops.h"
11 | #include "mlx/distributed/primitives.h"
12 | #include "mlx/fence.h"
13 | #include "mlx/scheduler.h"
14 | 
15 | namespace mlx::core::distributed {
16 | 
17 | void AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {
18 |   throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation.");
19 | }
20 | 
21 | void AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {
22 |   throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation.");
23 | }
24 | 
25 | void Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {
26 |   throw std::runtime_error("[Send::eval_gpu] has no GPU implementation.");
27 | }
28 | 
29 | void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
30 |   throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
31 | }
32 | 
33 | void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
34 |   throw std::runtime_error(
35 |       "[ReduceScatter::eval_gpu] has no GPU implementation.");
36 | }
37 | 
38 | } // namespace mlx::core::distributed
39 | 
--------------------------------------------------------------------------------
/mlx/distributed/ops.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
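// A minimal usage sketch for the collectives declared below, assuming a
// distributed backend was initialized via distributed::init() (declared in
// mlx/distributed/distributed.h):
//
//   auto group = distributed::init();
//   auto x = ones({8});
//   auto sums = all_sum(x);        // elementwise sum of x across all ranks
//   auto gathered = all_gather(x); // concatenation of x from all ranks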
2 | 
3 | #pragma once
4 | 
5 | #include <optional>
6 | 
7 | #include "mlx/distributed/distributed.h"
8 | #include "mlx/utils.h"
9 | 
10 | namespace mlx::core::distributed {
11 | 
12 | array all_sum(
13 |     const array& x,
14 |     std::optional<Group> group = std::nullopt,
15 |     StreamOrDevice s = {});
16 | 
17 | array all_gather(
18 |     const array& x,
19 |     std::optional<Group> group = std::nullopt,
20 |     StreamOrDevice s = {});
21 | 
22 | array send(
23 |     const array& x,
24 |     int dst,
25 |     std::optional<Group> group = std::nullopt,
26 |     StreamOrDevice s = {});
27 | 
28 | array recv(
29 |     Shape shape,
30 |     Dtype dtype,
31 |     int src,
32 |     std::optional<Group> group = std::nullopt,
33 |     StreamOrDevice s = {});
34 | 
35 | array recv_like(
36 |     const array& x,
37 |     int src,
38 |     std::optional<Group> group = std::nullopt,
39 |     StreamOrDevice s = {});
40 | 
41 | array all_max(
42 |     const array& x,
43 |     std::optional<Group> group = std::nullopt,
44 |     StreamOrDevice s = {});
45 | 
46 | array all_min(
47 |     const array& x,
48 |     std::optional<Group> group = std::nullopt,
49 |     StreamOrDevice s = {});
50 | 
51 | array sum_scatter(
52 |     const array& x,
53 |     std::optional<Group> group = std::nullopt,
54 |     StreamOrDevice s = {});
55 | 
56 | } // namespace mlx::core::distributed
57 | 
--------------------------------------------------------------------------------
/benchmarks/cpp/time_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include <chrono>
6 | #include <iomanip>
7 | #include <iostream>
8 | 
9 | #include "mlx/mlx.h"
10 | 
11 | #define milliseconds(x) \
12 |   (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
13 | #define time_now() std::chrono::high_resolution_clock::now()
14 | 
15 | #define TIME(FUNC, ...) \
16 |   std::cout << "Timing " << #FUNC << " ... " << std::flush \
17 |             << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
18 |             << std::endl;
19 | 
20 | #define TIMEM(MSG, FUNC, ...) \
21 |   std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
22 |             << std::flush << std::setprecision(5) \
23 |             << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
24 | 
25 | template <typename F, typename... Args>
26 | double time_fn(F fn, Args&&... args) {
27 |   // warmup
28 |   for (int i = 0; i < 5; ++i) {
29 |     eval(fn(std::forward<Args>(args)...));
30 |   }
31 | 
32 |   int num_iters = 100;
33 |   auto start = time_now();
34 |   for (int i = 0; i < num_iters; i++) {
35 |     eval(fn(std::forward<Args>(args)...));
36 |   }
37 |   auto end = time_now();
38 |   return milliseconds(end - start) / static_cast<double>(num_iters);
39 | }
40 | 
--------------------------------------------------------------------------------
/python/src/mlx.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | 
3 | #include <nanobind/nanobind.h>
4 | 
5 | #include "mlx/version.h"
6 | 
7 | namespace mx = mlx::core;
8 | namespace nb = nanobind;
9 | 
10 | void init_mlx_func(nb::module_&);
11 | void init_array(nb::module_&);
12 | void init_device(nb::module_&);
13 | void init_stream(nb::module_&);
14 | void init_metal(nb::module_&);
15 | void init_cuda(nb::module_&);
16 | void init_memory(nb::module_&);
17 | void init_ops(nb::module_&);
18 | void init_transforms(nb::module_&);
19 | void init_random(nb::module_&);
20 | void init_fft(nb::module_&);
21 | void init_linalg(nb::module_&);
22 | void init_constants(nb::module_&);
23 | void init_fast(nb::module_&);
24 | void init_distributed(nb::module_&);
25 | void init_export(nb::module_&);
26 | 
27 | NB_MODULE(core, m) {
28 |   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
29 | 
30 |   auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
31 |   nb::set_leak_warnings(false);
32 | 
33 |   init_mlx_func(m);
34 |   init_device(m);
35 |   init_stream(m);
36 |   init_array(m);
37 |   init_metal(m);
38 |   init_cuda(m);
39 |   init_memory(m);
40 |   init_ops(m);
41 |   init_transforms(m);
42 |   init_random(m);
43 |   init_fft(m);
44 |   init_linalg(m);
45 |   init_constants(m);
46 |   init_fast(m);
47 |   init_distributed(m);
48 |   init_export(m);
49 | 
50 |   m.attr("__version__") = mx::version();
51 | }
52 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/utils/type_traits.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include <metal_stdlib>
6 | 
7 | #pragma METAL internals : enable
8 | 
9 | namespace metal {
10 | 
11 | template <typename T>
12 | struct is_empty : metal::bool_constant<__is_empty(T)> {};
13 | 
14 | #ifdef __cpp_variable_templates
15 | template <typename T>
16 | constexpr constant bool is_empty_v = is_empty<T>::value;
17 | #endif
18 | 
19 | template <typename... Ts>
20 | struct make_void {
21 |   typedef void type;
22 | };
23 | 
24 | template <typename... Ts>
25 | using void_t = typename make_void<Ts...>::type;
26 | 
27 | template <typename T>
28 | struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};
29 | 
30 | template <typename T>
31 | struct pointer_element {};
32 | 
33 | template <typename T>
34 | struct pointer_element<thread T*> {
35 |   using type = remove_cv_t<T>;
36 | };
37 | template <typename T>
38 | struct pointer_element<device T*> {
39 |   using type = remove_cv_t<T>;
40 | };
41 | template <typename T>
42 | struct pointer_element<constant T*> {
43 |   using type = remove_cv_t<T>;
44 | };
45 | template <typename T>
46 | struct pointer_element<threadgroup T*> {
47 |   using type = remove_cv_t<T>;
48 | };
49 | 
50 | template <typename T>
51 | using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;
52 | 
53 | } // namespace metal
54 | 
55 | #pragma METAL internals : disable
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # tensor files
10 | *.safe
11 | *.safetensors
12 | 
13 | # Metal libraries
14 | *.metallib
15 | venv/
16 | 
17 | # Distribution / packaging
18 | python/mlx/core
19 | python/mlx/share
20 | python/mlx/include
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 | uv.lock
40 | 
41 | # vim
42 | *.swp
43 | 
44 | # Ignore build dir
45 | build/
46 | 
47 | # Prerequisites
48 | *.d
49 | 
50 | # Compiled 
Object files 51 | *.slo 52 | *.lo 53 | *.o 54 | *.obj 55 | 56 | # Precompiled Headers 57 | *.gch 58 | *.pch 59 | 60 | # Compiled Dynamic libraries 61 | *.so 62 | *.dylib 63 | *.dll 64 | 65 | # Fortran module files 66 | *.mod 67 | *.smod 68 | 69 | # Compiled Static libraries 70 | *.lai 71 | *.la 72 | *.a 73 | *.lib 74 | 75 | # Executables 76 | *.exe 77 | *.out 78 | *.app 79 | 80 | # Debug symbols 81 | *.pdb 82 | 83 | # VSCode 84 | .vscode/ 85 | .DS_Store 86 | 87 | # Jetbrains 88 | .cache 89 | discover_ibv_devices 90 | check_rdma_peer_interface 91 | SCP_GUIDE.md 92 | check_rdma_peer_interface_simple 93 | tcp_roundtrip 94 | ibv_roundtrip 95 | .claude/CLAUDE.local.md 96 | QP.PLAN.md 97 | sync_hosts_debug.sh 98 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/steel/gemm/params.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | #pragma once 4 | 5 | /////////////////////////////////////////////////////////////////////////////// 6 | // GEMM param classes 7 | /////////////////////////////////////////////////////////////////////////////// 8 | 9 | namespace mlx { 10 | namespace steel { 11 | 12 | struct GEMMParams { 13 | const int M; 14 | const int N; 15 | const int K; 16 | 17 | const int lda; 18 | const int ldb; 19 | const int ldd; 20 | 21 | const int tiles_n; 22 | const int tiles_m; 23 | 24 | const int64_t batch_stride_a; 25 | const int64_t batch_stride_b; 26 | const int64_t batch_stride_d; 27 | 28 | const int swizzle_log; 29 | const int gemm_k_iterations_aligned; 30 | 31 | const int batch_ndim; 32 | }; 33 | 34 | struct GEMMSpiltKParams { 35 | const int M; 36 | const int N; 37 | const int K; 38 | 39 | const int lda; 40 | const int ldb; 41 | const int ldc; 42 | 43 | const int tiles_n; 44 | const int tiles_m; 45 | 46 | const int split_k_partitions; 47 | const int split_k_partition_stride; 48 | const int split_k_partition_size; 49 | 50 | const int gemm_k_iterations_aligned; 51 | }; 52 | 53 | struct GEMMAddMMParams { 54 | const int ldc; 55 | const int fdc; 56 | 57 | const int64_t batch_stride_c; 58 | 59 | const float alpha; 60 | const float beta; 61 | }; 62 | 63 | } // namespace steel 64 | } // namespace mlx 65 | -------------------------------------------------------------------------------- /python/src/convert.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | #pragma once
3 | 
4 | #include <variant>
5 | 
6 | #include <nanobind/nanobind.h>
7 | #include <nanobind/ndarray.h>
8 | 
9 | #include "mlx/array.h"
10 | #include "mlx/ops.h"
11 | 
12 | namespace mx = mlx::core;
13 | namespace nb = nanobind;
14 | 
15 | namespace nanobind {
16 | static constexpr dlpack::dtype bfloat16{4, 16, 1};
17 | }; // namespace nanobind
18 | 
19 | struct ArrayLike {
20 |   ArrayLike(nb::object obj) : obj(obj) {};
21 |   nb::object obj;
22 | };
23 | 
24 | using ArrayInitType = std::variant<
25 |     nb::bool_,
26 |     nb::int_,
27 |     nb::float_,
28 |     // Must be above ndarray
29 |     mx::array,
30 |     // Must be above complex
31 |     nb::ndarray<>,
32 |     std::complex<float>,
33 |     nb::list,
34 |     nb::tuple,
35 |     ArrayLike>;
36 | 
37 | mx::array nd_array_to_mlx(
38 |     nb::ndarray<> nd_array,
39 |     std::optional<mx::Dtype> dtype);
40 | 
41 | nb::ndarray<nb::numpy> mlx_to_np_array(const mx::array& a);
42 | nb::ndarray<> mlx_to_dlpack(const mx::array& a);
43 | 
44 | nb::object to_scalar(mx::array& a);
45 | 
46 | nb::object tolist(mx::array& a);
47 | 
48 | mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
49 | mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
50 | mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);
51 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MLX
2 | 
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 | 
6 | ## Pull Requests
7 | 
8 | 1. Fork and submit pull requests to the repo.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If a change is likely to impact efficiency, run some of the benchmarks before
11 |    and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
12 | 4. If you've changed APIs, update the documentation.
13 | 5. Every PR should have passing tests and at least one review.
14 | 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
15 |    This should install hooks for running `black` and `clang-format` to ensure
16 |    consistent style for C++ and python code.
17 | 
18 |    You can also run the formatters manually as follows:
19 | 
20 |     ```shell
21 |     clang-format -i file.cpp
22 |     ```
23 | 
24 |     ```shell
25 |     black file.py
26 |     ```
27 | 
28 |    or run `pre-commit run --all-files` to check all files in the repo.
29 | 
30 | ## Issues
31 | 
32 | We use GitHub issues to track public bugs. Please ensure your description is
33 | clear and has sufficient instructions to be able to reproduce the issue.
34 | 
35 | ## License
36 | 
37 | By contributing to MLX, you agree that your contributions will be licensed
38 | under the LICENSE file in the root directory of this source tree.
39 | 
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/attn/params.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
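// A hedged note on the tile counts in AttnParams below: with query/key block
// sizes BQ and BK (set by the attention kernels), one would expect roughly
//   NQ = ceil(qL / BQ);  NK = ceil(kL / BK);
// with NQ_aligned/NK_aligned counting only the full tiles, and qL_rem/kL_rem
// the leftover rows in the final, partial tile.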
2 | 
3 | #pragma once
4 | 
5 | ///////////////////////////////////////////////////////////////////////////////
6 | // Attn param classes
7 | ///////////////////////////////////////////////////////////////////////////////
8 | 
9 | namespace mlx {
10 | namespace steel {
11 | 
12 | struct AttnParams {
13 |   int B; ///< Batch Size
14 |   int H; ///< Heads
15 |   int D; ///< Head Dim
16 | 
17 |   int qL; ///< Query Sequence Length
18 |   int kL; ///< Key Sequence Length
19 | 
20 |   int gqa_factor; ///< Group Query factor
21 |   float scale; ///< Attention scale
22 | 
23 |   int NQ; ///< Number of query blocks
24 |   int NK; ///< Number of key/value blocks
25 | 
26 |   int NQ_aligned; ///< Number of full query blocks
27 |   int NK_aligned; ///< Number of full key/value blocks
28 | 
29 |   int qL_rem; ///< Remainder in last query block
30 |   int kL_rem; ///< Remainder in last key/value block
31 |   int qL_off; ///< Offset in query sequence start
32 | 
33 |   int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
34 |   int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
35 |   int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
36 |   int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
37 | };
38 | 
39 | struct AttnMaskParams {
40 |   int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
41 | };
42 | 
43 | } // namespace steel
44 | } // namespace mlx
45 | 
--------------------------------------------------------------------------------
/python/src/load.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include <optional>
6 | #include <string>
7 | #include <unordered_map>
8 | #include <variant>
9 | #include <vector>
10 | 
11 | #include <nanobind/nanobind.h>
12 | #include <nanobind/stl/optional.h>
13 | #include <nanobind/stl/string.h>
14 | #include <nanobind/stl/variant.h>
15 | #include "mlx/io.h"
16 | 
17 | namespace mx = mlx::core;
18 | namespace nb = nanobind;
19 | 
20 | using LoadOutputTypes = std::variant<
21 |     mx::array,
22 |     std::unordered_map<std::string, mx::array>,
23 |     mx::SafetensorsLoad,
24 |     mx::GGUFLoad>;
25 | 
26 | mx::SafetensorsLoad mlx_load_safetensor_helper(
27 |     nb::object file,
28 |     mx::StreamOrDevice s);
29 | void mlx_save_safetensor_helper(
30 |     nb::object file,
31 |     nb::dict d,
32 |     std::optional<nb::dict> m);
33 | 
34 | mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s);
35 | 
36 | void mlx_save_gguf_helper(
37 |     nb::object file,
38 |     nb::dict d,
39 |     std::optional<nb::dict> m);
40 | 
41 | LoadOutputTypes mlx_load_helper(
42 |     nb::object file,
43 |     std::optional<std::string> format,
44 |     bool return_metadata,
45 |     mx::StreamOrDevice s);
46 | void mlx_save_helper(nb::object file, mx::array a);
47 | void mlx_savez_helper(
48 |     nb::object file,
49 |     nb::args args,
50 |     const nb::kwargs& kwargs,
51 |     bool compressed = false);
52 | 
--------------------------------------------------------------------------------
/mlx/event.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
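// A minimal usage sketch for the Event declared below, assuming two existing
// streams s_prod and s_cons (illustrative only):
//
//   Event e(s_prod);
//   e.set_value(1);
//   e.signal(s_prod); // enqueue a signal of value 1 on the producer stream
//   e.wait(s_cons);   // the consumer stream waits until value 1 is signaled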
2 | 
3 | #pragma once
4 | 
5 | #include <memory>
6 | #include <stdexcept>
7 | 
8 | #include "mlx/stream.h"
9 | 
10 | namespace mlx::core {
11 | 
12 | class Event {
13 |  public:
14 |   Event() {};
15 |   explicit Event(Stream stream);
16 | 
17 |   // Wait for the event to be signaled at its current value
18 |   void wait();
19 | 
20 |   // Wait in the given stream for the event to be signaled at its current value
21 |   void wait(Stream stream);
22 | 
23 |   // Signal the event at its current value in the given stream
24 |   void signal(Stream stream);
25 | 
26 |   // Check if the event has been signaled at its current value
27 |   bool is_signaled() const;
28 | 
29 |   // Check if the event is valid
30 |   bool valid() const {
31 |     return event_ != nullptr;
32 |   }
33 | 
34 |   uint64_t value() const {
35 |     return value_;
36 |   }
37 | 
38 |   void set_value(uint64_t v) {
39 |     value_ = v;
40 |   }
41 | 
42 |   const Stream& stream() const {
43 |     if (!valid()) {
44 |       throw std::runtime_error(
45 |           "[Event::stream] Cannot access stream on invalid event.");
46 |     }
47 |     return stream_;
48 |   }
49 | 
50 |  private:
51 |   // Default constructed stream should never be used
52 |   // since the event is not yet valid
53 |   Stream stream_{0, Device::cpu};
54 |   std::shared_ptr<void> event_{nullptr};
55 |   uint64_t value_{0};
56 | };
57 | 
58 | } // namespace mlx::core
59 | 
--------------------------------------------------------------------------------
/mlx/backend/cuda/primitives.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | 
3 | #include "mlx/distributed/primitives.h"
4 | #include "mlx/fast_primitives.h"
5 | #include "mlx/primitives.h"
6 | 
7 | namespace mlx::core {
8 | 
9 | #define NO_GPU_MULTI(func) \
10 |   void func::eval_gpu( \
11 |       const std::vector<array>& inputs, std::vector<array>& outputs) { \
12 |     throw std::runtime_error(#func " has no CUDA implementation."); \
13 |   }
14 | 
15 | #define NO_GPU_USE_FALLBACK(func) \
16 |   bool func::use_fallback(Stream s) { \
17 |     return true; \
18 |   } \
19 |   NO_GPU_MULTI(func)
20 | 
21 | #define NO_GPU(func) \
22 |   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
23 |     throw std::runtime_error(#func " has no CUDA implementation."); \
24 |   }
25 | 
26 | NO_GPU(BlockMaskedMM)
27 | NO_GPU(FFT)
28 | NO_GPU(GatherMM)
29 | NO_GPU(GatherQMM)
30 | NO_GPU(Hadamard)
31 | NO_GPU_MULTI(LUF)
32 | NO_GPU_MULTI(QRF)
33 | NO_GPU(QuantizedMatmul)
34 | NO_GPU(SegmentedMM)
35 | NO_GPU_MULTI(SVD)
36 | NO_GPU(Inverse)
37 | NO_GPU(Cholesky)
38 | NO_GPU_MULTI(Eig)
39 | NO_GPU_MULTI(Eigh)
40 | NO_GPU(MaskedScatter)
41 | 
42 | namespace distributed {
43 | NO_GPU_MULTI(Send)
44 | NO_GPU_MULTI(Recv)
45 | } // namespace distributed
46 | 
47 | } // namespace mlx::core
48 | 
--------------------------------------------------------------------------------
/mlx/backend/common/copy.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include "mlx/backend/common/utils.h"
6 | 
7 | namespace mlx::core {
8 | 
9 | enum class CopyType {
10 |   // Copy a raw scalar input into the full contiguous output
11 |   Scalar,
12 | 
13 |   // Copy the raw input buffer contiguously into a raw output buffer of the same
14 |   // size
15 |   Vector,
16 | 
17 |   // Copy the full virtual input to the full contiguous output
18 |   General,
19 | 
20 |   // Copy the full virtual input to the full virtual output. We assume the
21 |   // input and output have the same shape.
22 |   GeneralGeneral
23 | };
24 | 
25 | inline bool set_copy_output_data(
26 |     const array& in,
27 |     array& out,
28 |     CopyType ctype,
29 |     std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
30 |   if (ctype == CopyType::Vector) {
31 |     // If the input is donatable, we are doing a vector copy and the types
32 |     // have the same size, then the input buffer can hold the output.
33 |     if (is_donatable(in, out)) {
34 |       out.copy_shared_buffer(in);
35 |       return true;
36 |     } else {
37 |       out.set_data(
38 |           mallocfn(in.data_size() * out.itemsize()),
39 |           in.data_size(),
40 |           in.strides(),
41 |           in.flags());
42 |       return false;
43 |     }
44 |   } else {
45 |     out.set_data(mallocfn(out.nbytes()));
46 |     return false;
47 |   }
48 | }
49 | 
50 | } // namespace mlx::core
51 | 
--------------------------------------------------------------------------------
/mlx/backend/no_gpu/event.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | 
3 | #include "mlx/event.h"
4 | #include "mlx/scheduler.h"
5 | 
6 | #include <condition_variable>
7 | #include <mutex>
8 | 
9 | namespace mlx::core {
10 | 
11 | struct EventCounter {
12 |   uint64_t value{0};
13 |   std::mutex mtx;
14 |   std::condition_variable cv;
15 | };
16 | 
17 | Event::Event(Stream stream) : stream_(stream) {
18 |   auto dtor = [](void* ptr) { delete static_cast<EventCounter*>(ptr); };
19 |   event_ = std::shared_ptr<void>(new EventCounter{}, dtor);
20 | }
21 | 
22 | void Event::wait() {
23 |   auto ec = static_cast<EventCounter*>(event_.get());
24 |   std::unique_lock lk(ec->mtx);
25 |   if (ec->value >= value()) {
26 |     return;
27 |   }
28 |   ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });
29 | }
30 | 
31 | void Event::wait(Stream stream) {
32 |   scheduler::enqueue(stream, [*this]() mutable { wait(); });
33 | }
34 | 
35 | void Event::signal(Stream stream) {
36 |   scheduler::enqueue(stream, [*this]() mutable {
37 |     auto ec = static_cast<EventCounter*>(event_.get());
38 |     {
39 |       std::lock_guard lk(ec->mtx);
40 |       ec->value = value();
41 |     }
42 |     ec->cv.notify_all();
43 |   });
44 | }
45 | 
46 | bool Event::is_signaled() const {
47 |   auto ec = static_cast<EventCounter*>(event_.get());
48 |   {
49 |     std::lock_guard lk(ec->mtx);
50 |     return (ec->value >= value());
51 |   }
52 | }
53 | } // namespace mlx::core
54 | 
--------------------------------------------------------------------------------
/mlx/compile.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include "mlx/array.h"
6 | 
7 | namespace mlx::core {
8 | 
9 | enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
10 | 
11 | /** Compile takes a function and returns a compiled function. */
12 | std::function<std::vector<array>(const std::vector<array>&)> compile(
13 |     std::function<std::vector<array>(const std::vector<array>&)> fun,
14 |     bool shapeless = false);
15 | 
16 | std::function<std::vector<array>(const std::vector<array>&)> compile(
17 |     std::vector<array> (*fun)(const std::vector<array>&),
18 |     bool shapeless = false);
19 | 
20 | // Convert capture-less lambdas to function pointers.
21 | template <
22 |     typename F,
23 |     typename = std::enable_if_t<
24 |         std::is_convertible_v<F, decltype(+std::declval<F>())>>>
25 | std::function<std::vector<array>(const std::vector<array>&)> compile(
26 |     F&& f,
27 |     bool shapeless = false) {
28 |   return compile(+f, shapeless);
29 | }
30 | 
31 | /** Globally disable compilation.
32 |  * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
33 |  * be used to disable compilation.
34 |  */
35 | void disable_compile();
36 | 
37 | /** Globally enable compilation.
38 |  * This will override the environment variable ``MLX_DISABLE_COMPILE``.
39 |  */
40 | void enable_compile();
41 | 
42 | /** Set the compiler mode to the given value. */
43 | void set_compile_mode(CompileMode mode);
44 | } // namespace mlx::core
45 | 
--------------------------------------------------------------------------------
/mlx/backend/cuda/copy/copy.cuh:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include "mlx/backend/cuda/device.h"
6 | #include "mlx/backend/cuda/device/cast_op.cuh"
7 | #include "mlx/backend/cuda/kernel_utils.cuh"
8 | #include "mlx/backend/gpu/copy.h"
9 | #include "mlx/dtype_utils.h"
10 | 
11 | namespace mlx::core {
12 | 
13 | void copy_contiguous(
14 |     cu::CommandEncoder& encoder,
15 |     CopyType ctype,
16 |     const array& in,
17 |     array& out,
18 |     int64_t offset_in,
19 |     int64_t offset_out);
20 | 
21 | void copy_general(
22 |     cu::CommandEncoder& encoder,
23 |     CopyType ctype,
24 |     const array& in,
25 |     array& out,
26 |     int64_t offset_in,
27 |     int64_t offset_out,
28 |     const Shape& shape,
29 |     const Strides& strides_in,
30 |     const Strides& strides_out);
31 | 
32 | void copy_general_dynamic(
33 |     cu::CommandEncoder& encoder,
34 |     CopyType ctype,
35 |     const array& in,
36 |     array& out,
37 |     int64_t offset_in,
38 |     int64_t offset_out,
39 |     const Shape& shape,
40 |     const Strides& strides_in,
41 |     const Strides& strides_out,
42 |     const array& dynamic_offset_in,
43 |     const array& dynamic_offset_out);
44 | 
45 | void copy_general_input(
46 |     cu::CommandEncoder& encoder,
47 |     CopyType ctype,
48 |     const array& in,
49 |     array& out,
50 |     int64_t offset_in,
51 |     int64_t offset_out,
52 |     const Shape& shape,
53 |     const Strides& strides_in);
54 | 
55 | } // namespace mlx::core
56 | 
--------------------------------------------------------------------------------
/mlx/backend/cuda/worker.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | 
3 | #pragma once
4 | 
5 | #include "mlx/backend/cuda/event.h"
6 | 
7 | #include <condition_variable>
8 | #include <functional>
9 | #include <map>
10 | #include <mutex>
11 | #include <thread>
12 | 
13 | namespace mlx::core::cu {
14 | 
15 | // Run tasks in worker thread, synchronized with cuda stream.
16 | class Worker {
17 |  public:
18 |   explicit Worker(Device& d);
19 |   ~Worker();
20 | 
21 |   Worker(const Worker&) = delete;
22 |   Worker& operator=(const Worker&) = delete;
23 | 
24 |   // Add a pending |task| that will run when consumed or committed.
25 |   void add_task(std::function<void()> task);
26 | 
27 |   // Inform worker thread to run current batches after kernels in |stream|
28 |   // finish running.
29 |   void commit(cudaStream_t stream);
30 | 
31 |  private:
32 |   static void signal(void*);
33 | 
34 |   void thread_fn();
35 |   std::mutex mtx_;
36 |   std::condition_variable cond_;
37 | 
38 |   uint64_t committed_batch_{0};
39 |   uint64_t signaled_batch_{0};
40 | 
41 |   // Cuda stream and event for signaling kernel completion.
42 |   CudaStream signal_stream_;
43 |   CudaEvent signal_event_;
44 | 
45 |   bool stop_{false};
46 | 
47 |   // Tasks are put in |pending_tasks_| first, and then moved to
48 |   // |worker_tasks_| when end_batch() is called.
49 |   using Tasks = std::vector<std::function<void()>>;
50 |   Tasks pending_tasks_;
51 |   std::map<uint64_t, Tasks> worker_tasks_;
52 |   std::thread worker_;
53 | };
54 | 
55 | } // namespace mlx::core::cu
56 | 
--------------------------------------------------------------------------------
/sync_to_m3u.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Sync files needed to rebuild MLX with JACCL on m3u.local
3 | 
4 | set -e
5 | 
6 | REMOTE_HOST="m3u.local"
7 | REMOTE_PATH="/Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma"
8 | 
9 | echo "=== Syncing MLX-RDMA to m3u.local ==="
10 | echo ""
11 | 
12 | # Check if m3u is reachable
13 | if ! ping -c 1 -W 1 $REMOTE_HOST &> /dev/null; then
14 |   echo "ERROR: $REMOTE_HOST is not reachable"
15 |   exit 1
16 | fi
17 | 
18 | echo "✓ $REMOTE_HOST is reachable"
19 | echo ""
20 | 
21 | # Files/directories to sync
22 | echo "Syncing source code and build scripts..."
23 | 
24 | rsync -avz --progress \
25 |   --exclude '.git' \
26 |   --exclude 'build/*' \
27 |   --exclude '.venv' \
28 |   --exclude '*.o' \
29 |   --exclude '*.a' \
30 |   --exclude '*.so' \
31 |   --exclude '__pycache__' \
32 |   --exclude 'results.json' \
33 |   --exclude 'results_plot.png' \
34 |   --exclude 'ibv_roundtrip' \
35 |   --exclude 'discover_ibv_devices' \
36 |   --exclude 'check_rdma_peer_interface*' \
37 |   --exclude 'tcp_roundtrip' \
38 |   --exclude 'test_jaccl_init*' \
39 |   ./ $REMOTE_HOST:$REMOTE_PATH/
40 | 
41 | echo ""
42 | echo "=== Sync Complete ==="
43 | echo ""
44 | echo "Files synced to: $REMOTE_HOST:$REMOTE_PATH"
45 | echo ""
46 | echo "Next steps on m3u.local:"
47 | echo "  1. ssh $REMOTE_HOST"
48 | echo "  2. cd $REMOTE_PATH"
49 | echo "  3. ./setup_and_build_jaccl.sh"
50 | echo ""
51 | echo "Or run these commands:"
52 | echo "  ssh $REMOTE_HOST 'cd $REMOTE_PATH && ./setup_and_build_jaccl.sh'"
53 | echo ""
54 | 
--------------------------------------------------------------------------------
/python/tests/test_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 | 
3 | import unittest
4 | 
5 | import mlx.core as mx
6 | import mlx_tests
7 | import numpy as np
8 | 
9 | 
10 | class TestConstants(mlx_tests.MLXTestCase):
11 |     def test_constants_values(self):
12 |         # Check if mlx constants match expected values
13 |         self.assertAlmostEqual(
14 |             mx.e, 2.71828182845904523536028747135266249775724709369995
15 |         )
16 |         self.assertAlmostEqual(
17 |             mx.euler_gamma, 0.5772156649015328606065120900824024310421
18 |         )
19 |         self.assertAlmostEqual(mx.inf, float("inf"))
20 |         self.assertTrue(np.isnan(mx.nan))
21 |         self.assertIsNone(mx.newaxis)
22 |         self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)
23 | 
24 |     def test_constants_availability(self):
25 |         # Check if mlx constants are available
26 |         self.assertTrue(hasattr(mx, "e"))
27 |         self.assertTrue(hasattr(mx, "euler_gamma"))
28 |         self.assertTrue(hasattr(mx, "inf"))
29 |         self.assertTrue(hasattr(mx, "nan"))
30 |         self.assertTrue(hasattr(mx, "newaxis"))
31 |         self.assertTrue(hasattr(mx, "pi"))
32 | 
33 |     def test_newaxis_for_reshaping_arrays(self):
34 |         arr_1d = mx.array([1, 2, 3, 4, 5])
35 |         arr_2d_column = arr_1d[:, mx.newaxis]
36 |         expected_result = mx.array([[1], [2], [3], [4], [5]])
37 |         self.assertTrue(mx.array_equal(arr_2d_column, expected_result))
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     mlx_tests.MLXTestRunner()
42 | 
--------------------------------------------------------------------------------
/mlx/backend/cuda/fence.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | 
3 | #include "mlx/fence.h"
4 | #include "mlx/backend/cuda/allocator.h"
5 | #include "mlx/backend/cuda/device.h"
6 | #include "mlx/backend/cuda/event.h"
7 | 
8 | namespace mlx::core {
9 | 
10 | struct FenceImpl {
11 |   uint32_t count;
12 |   cu::AtomicEvent event;
13 | };
14 | 
15 | Fence::Fence(Stream s) {
16 |   fence_ = std::shared_ptr<void>(
17 |       new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
18 | }
19 | 
20 | void Fence::wait(Stream s, const array&) {
21 |   auto* fence = static_cast<FenceImpl*>(fence_.get());
22 |   fence->event.wait(fence->count);
23 | }
24 | 
25 | void Fence::update(Stream s, const array& a, bool cross_device) {
26 |   auto* fence = static_cast<FenceImpl*>(fence_.get());
27 |   if (cross_device) {
28 |     // Move to managed memory if there is a device switch
29 |     auto& cbuf =
30 |         *static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
31 |     if (cbuf.device != -1) {
32 |       void* new_data;
33 |       CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
34 |       cbuf.device = -1;
35 |       auto& encoder = cu::device(s.device).get_command_encoder(s);
36 |       encoder.commit();
37 |       CHECK_CUDA_ERROR(cudaMemcpyAsync(
38 |           new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream()));
39 |       CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
40 |       cbuf.data = new_data;
41 |     }
42 |   }
43 |   fence->count++;
44 |   fence->event.signal(s, fence->count);
45 | }
46 | 
47 | } // namespace mlx::core
48 | 
--------------------------------------------------------------------------------
/mlx/backend/cuda/quantized/quantized_utils.cuh:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | 
3 | namespace mlx::core {
4 | 
5 | namespace cu {
6 | 
7 | template <int bits, int wsize = 8>
8 | inline constexpr __device__ short get_pack_factor() {
9 |   return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
10 | }
11 | 
12 | template <int bits, int wsize = 8>
13 | inline constexpr __device__ short get_bytes_per_pack() {
14 |   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
15 |   return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
16 | }
17 | 
18 | } // namespace cu
19 | 
20 | template <typename F>
21 | void dispatch_groups(int group_size, F&& f) {
22 |   switch (group_size) {
23 |     case 32:
24 |       f(std::integral_constant<int, 32>{});
25 |       break;
26 |     case 64:
27 |       f(std::integral_constant<int, 64>{});
28 |       break;
29 |     case 128:
30 |       f(std::integral_constant<int, 128>{});
31 |       break;
32 |   }
33 | }
34 | 
35 | template <typename F>
36 | void dispatch_bits(int bits, F&& f) {
37 |   switch (bits) {
38 |     case 2:
39 |       f(std::integral_constant<int, 2>{});
40 |       break;
41 |     case 3:
42 |       f(std::integral_constant<int, 3>{});
43 |       break;
44 |     case 4:
45 |       f(std::integral_constant<int, 4>{});
46 |       break;
47 |     case 5:
48 |       f(std::integral_constant<int, 5>{});
49 |       break;
50 |     case 6:
51 |       f(std::integral_constant<int, 6>{});
52 |       break;
53 |     case 8:
54 |       f(std::integral_constant<int, 8>{});
55 |       break;
56 |   }
57 | }
58 | 
59 | } // namespace mlx::core
60 | 
--------------------------------------------------------------------------------
/examples/cpp/logistic_regression.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 | 
3 | #include <chrono>
4 | #include <cmath>
5 | #include <iostream>
6 | 
7 | #include "mlx/mlx.h"
8 | #include "timer.h"
9 | 
10 | /**
11 |  * An example of logistic regression with MLX.
12 |  */
13 | namespace mx = mlx::core;
14 | 
15 | int main() {
16 |   int num_features = 100;
17 |   int num_examples = 1'000;
18 |   int num_iters = 10'000;
19 |   float learning_rate = 0.1;
20 | 
21 |   // True parameters
22 |   auto w_star = mx::random::normal({num_features});
23 | 
24 |   // The input examples
25 |   auto X = mx::random::normal({num_examples, num_features});
26 | 
27 |   // Labels
28 |   auto y = mx::matmul(X, w_star) > 0;
29 | 
30 |   // Initialize random parameters
31 |   mx::array w = 1e-2 * mx::random::normal({num_features});
32 | 
33 |   auto loss_fn = [&](mx::array w) {
34 |     auto logits = mx::matmul(X, w);
35 |     auto scale = (1.0f / num_examples);
36 |     return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
37 |   };
38 | 
39 |   auto grad_fn = mx::grad(loss_fn);
40 | 
41 |   auto tic = timer::time();
42 |   for (int it = 0; it < num_iters; ++it) {
43 |     auto grads = grad_fn(w);
44 |     w = w - learning_rate * grads;
45 |     mx::eval(w);
46 |   }
47 |   auto toc = timer::time();
48 | 
49 |   auto loss = loss_fn(w);
50 |   auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
51 |   auto throughput = num_iters / timer::seconds(toc - tic);
52 |   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
53 |             << throughput << " (it/s)." << std::endl;
54 | }
55 | 
--------------------------------------------------------------------------------
/examples/export/eval_mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2024 Apple Inc.
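# A hedged usage note: running this script writes eval_mlp.mlxfn next to it,
# and the exported function can then be loaded outside Python as well; the
# C++ side of this example (eval_mlp.cpp in this directory) is expected to
# import it via mx::import_function. Paths here are assumptions based on the
# repository layout.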
2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | import mlx.utils 6 | 7 | 8 | class MLP(nn.Module): 9 | """A simple MLP.""" 10 | 11 | def __init__( 12 | self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int 13 | ): 14 | super().__init__() 15 | layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] 16 | self.layers = [ 17 | nn.Linear(idim, odim) 18 | for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) 19 | ] 20 | 21 | def __call__(self, x): 22 | for l in self.layers[:-1]: 23 | x = nn.relu(l(x)) 24 | return self.layers[-1](x) 25 | 26 | 27 | if __name__ == "__main__": 28 | 29 | batch_size = 8 30 | input_dim = 32 31 | output_dim = 10 32 | 33 | # Load the model 34 | mx.random.seed(0) # Seed for params 35 | model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim) 36 | mx.eval(model) 37 | 38 | # Note, the model parameters are saved in the export function 39 | def forward(x): 40 | return model(x) 41 | 42 | mx.random.seed(42) # Seed for input 43 | example_x = mx.random.uniform(shape=(batch_size, input_dim)) 44 | 45 | mx.export_function("eval_mlp.mlxfn", forward, example_x) 46 | 47 | # Import in Python 48 | imported_forward = mx.import_function("eval_mlp.mlxfn") 49 | expected = forward(example_x) 50 | (out,) = imported_forward(example_x) 51 | assert mx.allclose(expected, out) 52 | print(out) 53 | --------------------------------------------------------------------------------