├── docs
├── .nojekyll
├── .clang-format
├── index.html
├── requirements.txt
├── src
│ ├── _static
│ │ ├── mlx_logo.png
│ │ ├── mlx_logo_dark.png
│ │ └── metal_debugger
│ │ │ ├── capture.png
│ │ │ └── schema.png
│ ├── cpp
│ │ └── ops.rst
│ ├── python
│ │ ├── cuda.rst
│ │ ├── metal.rst
│ │ ├── export.rst
│ │ ├── fast.rst
│ │ ├── optimizers
│ │ │ ├── schedulers.rst
│ │ │ ├── common_optimizers.rst
│ │ │ └── optimizer.rst
│ │ ├── fft.rst
│ │ ├── memory_management.rst
│ │ ├── transforms.rst
│ │ ├── devices_and_streams.rst
│ │ ├── linalg.rst
│ │ ├── nn
│ │ │ ├── losses.rst
│ │ │ ├── functions.rst
│ │ │ ├── module.rst
│ │ │ ├── init.rst
│ │ │ └── layers.rst
│ │ ├── distributed.rst
│ │ ├── tree_utils.rst
│ │ ├── array.rst
│ │ └── random.rst
│ ├── _templates
│ │ ├── optimizers-template.rst
│ │ ├── nn-module-template.rst
│ │ └── module-base-class.rst
│ └── usage
│ │ └── using_streams.rst
├── .gitignore
├── Makefile
└── README.md
├── python
├── mlx
│ ├── py.typed
│ ├── optimizers
│ │ └── __init__.py
│ ├── nn
│ │ ├── __init__.py
│ │ └── layers
│ │ │ └── containers.py
│ ├── _reprlib_fix.py
│ ├── __main__.py
│ └── _stub_patterns.txt
├── tests
│ ├── __main__.py
│ ├── test_graph.py
│ └── test_constants.py
├── src
│ ├── cuda.cpp
│ ├── constants.cpp
│ ├── mlx_func.h
│ ├── indexing.h
│ ├── mlx.cpp
│ ├── convert.h
│ └── load.h
└── scripts
│ ├── repair_linux.sh
│ ├── repair_cuda.sh
│ └── repair_record.py
├── mlx
├── 3rdparty
│ └── .clang-format
├── backend
│ ├── cpu
│ │ ├── simd
│ │ │ ├── simd.h
│ │ │ └── type.h
│ │ ├── available.h
│ │ ├── available.cpp
│ │ ├── eval.h
│ │ ├── compiled_preamble.h
│ │ ├── encoder.cpp
│ │ ├── slicing.h
│ │ ├── gemm.h
│ │ ├── jit_compiler.h
│ │ ├── threefry.h
│ │ ├── arange.h
│ │ ├── threefry.cpp
│ │ ├── make_compiled_preamble.sh
│ │ ├── copy.h
│ │ ├── gemms
│ │ │ ├── simd_fp16.cpp
│ │ │ └── simd_bf16.cpp
│ │ ├── eval.cpp
│ │ └── make_compiled_preamble.ps1
│ ├── cuda
│ │ ├── unary
│ │ │ ├── abs.cu
│ │ │ ├── ceil.cu
│ │ │ ├── cos.cu
│ │ │ ├── cosh.cu
│ │ │ ├── erf.cu
│ │ │ ├── exp.cu
│ │ │ ├── imag.cu
│ │ │ ├── real.cu
│ │ │ ├── sign.cu
│ │ │ ├── sin.cu
│ │ │ ├── sinh.cu
│ │ │ ├── tan.cu
│ │ │ ├── tanh.cu
│ │ │ ├── arccos.cu
│ │ │ ├── arcsin.cu
│ │ │ ├── arctan.cu
│ │ │ ├── erf_inv.cu
│ │ │ ├── expm1.cu
│ │ │ ├── floor.cu
│ │ │ ├── log1p.cu
│ │ │ ├── square.cu
│ │ │ ├── arccosh.cu
│ │ │ ├── arcsinh.cu
│ │ │ ├── arctanh.cu
│ │ │ ├── conjugate.cu
│ │ │ ├── negative.cu
│ │ │ ├── sigmoid.cu
│ │ │ ├── logical_not.cu
│ │ │ ├── bitwise_invert.cu
│ │ │ ├── sqrt.cu
│ │ │ ├── round.cu
│ │ │ └── log.cu
│ │ ├── binary
│ │ │ ├── add.cu
│ │ │ ├── less.cu
│ │ │ ├── arctan2.cu
│ │ │ ├── divide.cu
│ │ │ ├── greater.cu
│ │ │ ├── maximum.cu
│ │ │ ├── minimum.cu
│ │ │ ├── power.cu
│ │ │ ├── less_equal.cu
│ │ │ ├── logical_or.cu
│ │ │ ├── multiply.cu
│ │ │ ├── not_equal.cu
│ │ │ ├── remainder.cu
│ │ │ ├── subtract.cu
│ │ │ ├── log_add_exp.cu
│ │ │ ├── logical_and.cu
│ │ │ ├── greater_equal.cu
│ │ │ ├── equal.cu
│ │ │ ├── bitwise_binary.cu
│ │ │ └── CMakeLists.txt
│ │ ├── cuda.cpp
│ │ ├── cuda.h
│ │ ├── steel
│ │ │ └── defines.cuh
│ │ ├── detect_cuda_arch.sh
│ │ ├── device
│ │ │ ├── ternary_ops.cuh
│ │ │ ├── config.h
│ │ │ ├── indexing.cuh
│ │ │ └── scatter_ops.cuh
│ │ ├── gemms
│ │ │ └── gemv.h
│ │ ├── quantized
│ │ │ ├── convert_fp8.cu
│ │ │ ├── quantized.h
│ │ │ └── quantized_utils.cuh
│ │ ├── no_cuda.cpp
│ │ ├── utils.h
│ │ ├── primitives.cpp
│ │ ├── copy
│ │ │ └── copy.cuh
│ │ ├── worker.h
│ │ └── fence.cpp
│ ├── gpu
│ │ ├── available.h
│ │ ├── CMakeLists.txt
│ │ ├── eval.h
│ │ ├── slicing.h
│ │ └── slicing.cpp
│ ├── metal
│ │ ├── kernels
│ │ │ ├── reduce_utils.h
│ │ │ ├── steel
│ │ │ │ ├── conv
│ │ │ │ │ ├── loader.h
│ │ │ │ │ └── conv.h
│ │ │ │ ├── defines.h
│ │ │ │ ├── utils.h
│ │ │ │ ├── attn
│ │ │ │ │ ├── kernels
│ │ │ │ │ │ └── steel_attention.metal
│ │ │ │ │ └── params.h
│ │ │ │ ├── utils
│ │ │ │ │ └── type_traits.h
│ │ │ │ └── gemm
│ │ │ │ │ └── params.h
│ │ │ ├── ternary_ops.h
│ │ │ ├── reduction
│ │ │ │ └── reduce_init.h
│ │ │ ├── reduce.h
│ │ │ ├── arange.h
│ │ │ ├── bf16.h
│ │ │ ├── indexing
│ │ │ │ ├── indexing.h
│ │ │ │ ├── gather_front.h
│ │ │ │ └── masked_scatter.h
│ │ │ ├── logsumexp.metal
│ │ │ ├── arange.metal
│ │ │ ├── defines.h
│ │ │ ├── softmax.metal
│ │ │ └── fp4.h
│ │ ├── scan.h
│ │ ├── unary.h
│ │ ├── ternary.h
│ │ ├── metal.h
│ │ ├── binary.h
│ │ ├── resident.h
│ │ ├── make_compiled_preamble.sh
│ │ ├── no_metal.cpp
│ │ ├── reduce.h
│ │ ├── jit
│ │ │ └── includes.h
│ │ └── distributed.cpp
│ ├── common
│ │ ├── broadcasting.h
│ │ ├── CMakeLists.txt
│ │ ├── slicing.h
│ │ ├── broadcasting.cpp
│ │ ├── unary.h
│ │ └── copy.h
│ ├── no_cpu
│ │ ├── available.cpp
│ │ ├── CMakeLists.txt
│ │ └── compiled.cpp
│ └── no_gpu
│ │ ├── CMakeLists.txt
│ │ ├── apple_memory.h
│ │ ├── linux_memory.h
│ │ ├── eval.cpp
│ │ └── event.cpp
├── version.cpp
├── distributed
│ ├── nccl
│ │ ├── nccl_stub
│ │ │ ├── nccl_stubs.cpp
│ │ │ └── CMakeLists.txt
│ │ ├── nccl.h
│ │ ├── no_nccl.cpp
│ │ └── CMakeLists.txt
│ ├── mpi
│ │ ├── CMakeLists.txt
│ │ ├── mpi.h
│ │ ├── no_mpi.cpp
│ │ └── mpi_declarations.h
│ ├── ring
│ │ ├── CMakeLists.txt
│ │ ├── ring.h
│ │ └── no_ring.cpp
│ ├── jaccl
│ │ ├── jaccl.h
│ │ ├── CMakeLists.txt
│ │ └── no_jaccl.cpp
│ ├── CMakeLists.txt
│ ├── reduction_ops.h
│ └── ops.h
├── io
│ ├── gguf.h
│ ├── no_gguf.cpp
│ ├── CMakeLists.txt
│ └── no_safetensors.cpp
├── einsum.h
├── version.h
├── allocator.cpp
├── mlx.h
├── device.h
├── dtype_utils.cpp
├── stream.h
├── device.cpp
├── fence.h
├── allocator.h
├── event.h
└── compile.h
├── tests_jaccl
├── jaccl_config.json
├── run_jaccl.sh
├── deploy.sh
├── build_cpp_benchmark.sh
└── README.md
├── examples
├── extensions
│ ├── requirements.txt
│ ├── mlx_sample_extensions
│ │ └── __init__.py
│ ├── pyproject.toml
│ ├── README.md
│ ├── test.py
│ ├── setup.py
│ └── bindings.cpp
├── cmake_project
│ ├── example.cpp
│ ├── README.md
│ └── CMakeLists.txt
├── cpp
│ ├── timer.h
│ ├── CMakeLists.txt
│ ├── distributed.cpp
│ ├── metal_capture.cpp
│ └── logistic_regression.cpp
├── export
│ ├── eval_mlp.cpp
│ ├── CMakeLists.txt
│ ├── README.md
│ ├── train_mlp.cpp
│ └── eval_mlp.py
└── python
│ ├── logistic_regression.py
│ └── linear_regression.py
├── .github
├── dependabot.yml
├── actions
│ ├── build-cuda-release
│ │ └── action.yml
│ ├── setup-macos
│ │ └── action.yml
│ ├── build-linux
│ │ └── action.yml
│ ├── build-cuda
│ │ └── action.yml
│ ├── build-macos-release
│ │ └── action.yml
│ ├── build-docs
│ │ └── action.yml
│ └── build-linux-release
│ │ └── action.yml
├── pull_request_template.md
├── ISSUE_TEMPLATE
│ └── bug_report.md
├── workflows
│ └── documentation.yml
└── scripts
│ └── setup+build-cpp-linux-fedora-container.sh
├── pyproject.toml
├── MANIFEST.in
├── cmake
└── Findnvpl.cmake
├── requirements-dev.txt
├── benchmarks
├── cpp
│ ├── CMakeLists.txt
│ ├── compare_devices.cpp
│ ├── autograd.cpp
│ └── time_utils.h
├── numpy
│ ├── time_utils.py
│ └── single_ops.py
└── python
│ ├── comparative
│ └── README.md
│ ├── rope_bench.py
│ ├── time_utils.py
│ └── synchronize_bench.py
├── tests
├── tests.cpp
├── device_tests.cpp
└── allocator_tests.cpp
├── CITATION.cff
├── .pre-commit-config.yaml
├── LICENSE
├── create_library_symlinks.sh
├── .gitignore
├── CONTRIBUTING.md
└── sync_to_m3u.sh
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/python/mlx/py.typed:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/.clang-format:
--------------------------------------------------------------------------------
1 | DisableFormat: true
2 | SortIncludes: Never
3 |
--------------------------------------------------------------------------------
/mlx/3rdparty/.clang-format:
--------------------------------------------------------------------------------
1 | DisableFormat: true
2 | SortIncludes: Never
3 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_jaccl/jaccl_config.json:
--------------------------------------------------------------------------------
1 | [
2 | [null, "rdma_en2"],
3 | ["rdma_en5", null]
4 | ]
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | breathe
3 | sphinx-book-theme
4 | sphinx-copybutton
5 | mlx
6 |
--------------------------------------------------------------------------------
/examples/extensions/requirements.txt:
--------------------------------------------------------------------------------
1 | setuptools>=42
2 | cmake>=3.25
3 | mlx>=0.21.0
4 | nanobind==2.4.0
5 |
--------------------------------------------------------------------------------
/docs/src/_static/mlx_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/mlx_logo.png
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | src/python/_autosummary*/
2 | src/python/nn/_autosummary*/
3 | src/python/optimizers/_autosummary*/
4 |
--------------------------------------------------------------------------------
/docs/src/_static/mlx_logo_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/mlx_logo_dark.png
--------------------------------------------------------------------------------
/docs/src/cpp/ops.rst:
--------------------------------------------------------------------------------
1 | .. _cpp_ops:
2 |
3 | Operations
4 | ==========
5 |
6 | .. doxygengroup:: ops
7 | :content-only:
8 |
--------------------------------------------------------------------------------
/python/tests/__main__.py:
--------------------------------------------------------------------------------
1 | from . import mlx_tests
2 |
3 | __unittest = True
4 |
5 | mlx_tests.MLXTestRunner(module=None)
6 |
--------------------------------------------------------------------------------
/docs/src/_static/metal_debugger/capture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/metal_debugger/capture.png
--------------------------------------------------------------------------------
/docs/src/_static/metal_debugger/schema.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anemll/mlx-rdma/HEAD/docs/src/_static/metal_debugger/schema.png
--------------------------------------------------------------------------------
/mlx/backend/cpu/simd/simd.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "mlx/backend/cpu/simd/math.h"
4 | #include "mlx/backend/cpu/simd/type.h"
5 |
--------------------------------------------------------------------------------
/examples/extensions/mlx_sample_extensions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import mlx.core as mx
4 |
5 | from ._ext import axpby
6 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions"
4 | directory: "/"
5 | schedule:
6 | interval: "weekly"
7 |
--------------------------------------------------------------------------------
/python/mlx/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023-2024 Apple Inc.
2 |
3 | from mlx.optimizers.optimizers import *
4 | from mlx.optimizers.schedulers import *
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=80",
4 | "nanobind==2.4.0",
5 | "cmake>=3.25",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
--------------------------------------------------------------------------------
/docs/src/python/cuda.rst:
--------------------------------------------------------------------------------
1 | CUDA
2 | =====
3 |
4 | .. currentmodule:: mlx.core.cuda
5 |
6 | .. autosummary::
7 | :toctree: _autosummary
8 |
9 | is_available
10 |
--------------------------------------------------------------------------------
/mlx/version.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | namespace mlx::core {
4 |
5 | const char* version() {
6 | return MLX_VERSION;
7 | }
8 |
9 | } // namespace mlx::core
10 |
--------------------------------------------------------------------------------
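`version()` simply returns the `MLX_VERSION` token, so the build system is expected to inject it as a compile definition. A sketch of a caller; the CMake line in the comment is an assumption, not taken from this repo:

```cpp
// Hypothetical caller. Assumes the build injects the macro, e.g.
// target_compile_definitions(mlx PRIVATE MLX_VERSION="0.x.y").
#include <iostream>

namespace mlx::core {
const char* version();
}

int main() {
  std::cout << "MLX " << mlx::core::version() << std::endl;
  return 0;
}
```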
/mlx/backend/cpu/simd/type.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "mlx/backend/cpu/simd/base_simd.h"
4 |
5 | #ifdef MLX_USE_ACCELERATE
6 | #include "mlx/backend/cpu/simd/accelerate_simd.h"
7 | #endif
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/abs.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Abs)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
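Each elementwise CUDA op below is a one-line translation unit: it includes the shared `unary.cuh` header and expands a registration macro for its primitive. A plausible expansion, as a sketch only — the real macro lives in `mlx/backend/cuda/unary/unary.cuh`, and the launcher name and signature here are assumptions:

```cpp
// Sketch of what a UNARY_GPU(func) expansion could look like: define the
// primitive's eval_gpu override and forward to a shared launcher.
#define UNARY_GPU(func)                                                 \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) {   \
    unary_op_gpu(inputs, out, get_primitive_string(this));              \
  }
```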
/mlx/backend/cuda/unary/ceil.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Ceil)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/cos.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Cos)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/cosh.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Cosh)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/erf.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Erf)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/exp.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Exp)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/imag.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Imag)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/real.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Real)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/sign.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Sign)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/sin.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Sin)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/sinh.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Sinh)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/tan.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Tan)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/tanh.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Tanh)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/python/mlx/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from mlx.nn import init, losses
4 | from mlx.nn.layers import *
5 | from mlx.nn.utils import average_gradients, value_and_grad
6 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include CMakeLists.txt
2 | include mlx.pc.in
3 | recursive-include mlx/ *
4 | include cmake/*
5 | include python/src/*
6 | include python/mlx/py.typed # support type hinting as in PEP-561
7 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/available.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | namespace mlx::core::cpu {
6 |
7 | bool is_available();
8 |
9 | } // namespace mlx::core::cpu
10 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/add.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Add)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
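The binary family mirrors the unary one: one `.cu` file per op, each expanding `BINARY_GPU`. The per-op work itself would live in a small device functor, in the same shape as the `Select` functor shown further down; a hypothetical example:

```cpp
// Hypothetical device functor in the shape a BINARY_GPU launcher would
// instantiate (the real op functors live in the device headers).
struct Add {
  template <typename T>
  __device__ T operator()(T x, T y) {
    return x + y;
  }
};
```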
/mlx/backend/cuda/binary/less.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Less)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/arccos.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ArcCos)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/arcsin.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ArcSin)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/arctan.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ArcTan)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/erf_inv.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ErfInv)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/expm1.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Expm1)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/floor.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Floor)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/log1p.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Log1p)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/square.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Square)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/gpu/available.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | namespace mlx::core::gpu {
6 |
7 | bool is_available();
8 |
9 | } // namespace mlx::core::gpu
10 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/arctan2.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(ArcTan2)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/divide.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Divide)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/greater.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Greater)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/maximum.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Maximum)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/minimum.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Minimum)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/power.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Power)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/arccosh.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ArcCosh)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/arcsinh.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ArcSinh)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/arctanh.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(ArcTanh)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/conjugate.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Conjugate)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/negative.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Negative)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/sigmoid.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(Sigmoid)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/examples/extensions/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "cmake>=3.25",
5 | "mlx>=0.18.0",
6 | "nanobind==2.4.0",
7 | ]
8 | build-backend = "setuptools.build_meta"
9 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/less_equal.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(LessEqual)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/logical_or.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(LogicalOr)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/multiply.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Multiply)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/not_equal.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(NotEqual)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/remainder.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Remainder)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/subtract.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(Subtract)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/logical_not.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(LogicalNot)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/reduce_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/metal/kernels/atomic.h"
6 | #include "mlx/backend/metal/kernels/reduction/ops.h"
7 |
--------------------------------------------------------------------------------
/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp:
--------------------------------------------------------------------------------
1 | #include <nccl.h>
2 | #include
3 | #include
4 |
5 | ncclResult_t ncclGetUniqueId(ncclUniqueId*) {
6 | return ncclSuccess;
7 | }
8 |
--------------------------------------------------------------------------------
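The stub file gives the build no-op definitions of the NCCL entry points so the library links even when the real NCCL is absent. Each stubbed symbol follows the same shape; for example (the stub body here is hypothetical, though `ncclCommDestroy` is a real NCCL entry point):

```cpp
// Another stub in the same shape: accept the arguments, report success,
// do nothing. ncclComm_t and ncclSuccess come from <nccl.h>.
ncclResult_t ncclCommDestroy(ncclComm_t) {
  return ncclSuccess;
}
```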
/mlx/backend/cuda/binary/log_add_exp.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(LogAddExp)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/logical_and.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(LogicalAnd)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/bitwise_invert.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | UNARY_GPU(BitwiseInvert)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/greater_equal.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | BINARY_GPU(GreaterEqual)
7 | } // namespace mlx::core
8 |
--------------------------------------------------------------------------------
/mlx/backend/gpu/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
4 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
6 |
--------------------------------------------------------------------------------
/mlx/distributed/mpi/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | if(MLX_BUILD_CPU)
2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp)
3 | else()
4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_mpi.cpp)
5 | endif()
6 |
--------------------------------------------------------------------------------
/docs/src/python/metal.rst:
--------------------------------------------------------------------------------
1 | Metal
2 | =====
3 |
4 | .. currentmodule:: mlx.core.metal
5 |
6 | .. autosummary::
7 | :toctree: _autosummary
8 |
9 | is_available
10 | device_info
11 | start_capture
12 | stop_capture
13 |
--------------------------------------------------------------------------------
/cmake/Findnvpl.cmake:
--------------------------------------------------------------------------------
1 | # This file does nothing but suppress the cmake warning: "By not providing
2 | # Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
3 | # find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
4 |
--------------------------------------------------------------------------------
/mlx/distributed/ring/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | if(MLX_BUILD_CPU AND NOT WIN32)
2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)
3 | else()
4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_ring.cpp)
5 | endif()
6 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/cuda.h"
4 |
5 | namespace mlx::core::cu {
6 |
7 | bool is_available() {
8 | return true;
9 | }
10 |
11 | } // namespace mlx::core::cu
12 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/cuda.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | namespace mlx::core::cu {
6 |
7 | /* Check if the CUDA backend is available. */
8 | bool is_available();
9 |
10 | } // namespace mlx::core::cu
11 |
--------------------------------------------------------------------------------
/mlx/backend/common/broadcasting.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | void broadcast(const array& in, array& out);
10 |
11 | } // namespace mlx::core
12 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/available.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cpu/available.h"
4 |
5 | namespace mlx::core::cpu {
6 |
7 | bool is_available() {
8 | return true;
9 | }
10 |
11 | } // namespace mlx::core::cpu
12 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/steel/defines.cuh:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #define MLX_UNROLL _Pragma("unroll")
6 |
7 | #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
8 | #define MLX_CUDA_SM_80_ENABLED
9 | #endif
10 |
--------------------------------------------------------------------------------
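How these two defines would be used in kernel code — a hypothetical fragment, not taken from the repo: `MLX_UNROLL` asks the compiler to fully unroll a fixed-trip-count loop, and the SM 8.0 gate guards code paths that need Ampere-or-newer instructions.

```cpp
// Hypothetical kernel fragment using the defines above.
__device__ float dot4(const float* a, const float* b) {
  float acc = 0.f;
  MLX_UNROLL
  for (int i = 0; i < 4; ++i) {
    acc += a[i] * b[i];
  }
  return acc;
}

#ifdef MLX_CUDA_SM_80_ENABLED
// ... Ampere-specific path (e.g. asynchronous copy based loads) ...
#endif
```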
/mlx/backend/cpu/eval.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 | #include "mlx/stream.h"
7 |
8 | namespace mlx::core::cpu {
9 |
10 | void eval(array& arr);
11 |
12 | } // namespace mlx::core::cpu
13 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/conv/loader.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
6 | #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/ternary_ops.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | struct Select {
6 | template <typename T>
7 | T operator()(bool condition, T x, T y) {
8 | return condition ? x : y;
9 | }
10 | };
11 |
--------------------------------------------------------------------------------
/mlx/backend/no_cpu/available.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cpu/available.h"
4 |
5 | namespace mlx::core::cpu {
6 |
7 | bool is_available() {
8 | return false;
9 | }
10 |
11 | } // namespace mlx::core::cpu
12 |
--------------------------------------------------------------------------------
/docs/src/python/export.rst:
--------------------------------------------------------------------------------
1 | .. _export:
2 |
3 | Export Functions
4 | ================
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | export_function
12 | import_function
13 | exporter
14 | export_to_dot
15 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/reduction/reduce_init.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | template <typename T, typename Op>
4 | [[kernel]] void init_reduce(
5 | device T* out [[buffer(0)]],
6 | uint tid [[thread_position_in_grid]]) {
7 | out[tid] = Op::init;
8 | }
9 |
--------------------------------------------------------------------------------
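`init_reduce` seeds every output element with the reduction's identity, which each op type carries as a static `init` member. A hypothetical op in that shape — illustrative names only; the real ops live in `mlx/backend/metal/kernels/reduction/ops.h`:

```cpp
// Hypothetical reduction op in the shape init_reduce expects: a static
// identity value plus the combining operator.
template <typename T>
struct Sum {
  static constexpr T init = T(0); // identity: combining with 0 is a no-op
  T operator()(T a, T b) const {
    return a + b;
  }
};
```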
/docs/src/python/fast.rst:
--------------------------------------------------------------------------------
1 | .. _fast:
2 |
3 | Fast
4 | ====
5 |
6 | .. currentmodule:: mlx.core.fast
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | rms_norm
12 | layer_norm
13 | rope
14 | scaled_dot_product_attention
15 | metal_kernel
16 | cuda_kernel
17 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/detect_cuda_arch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | arch=`__nvcc_device_query`
4 | case "$arch" in
5 | "90")
6 | echo "90a" ;;
7 | "100")
8 | echo "100a" ;;
9 | "121")
10 | echo "121a" ;;
11 | *)
12 | echo "native" ;;
13 | esac
14 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/defines.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #define STEEL_CONST static constant constexpr const
6 | #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
7 | #define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")
8 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/reduce.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "mlx/backend/metal/kernels/reduction/reduce_all.h"
3 | #include "mlx/backend/metal/kernels/reduction/reduce_col.h"
4 | #include "mlx/backend/metal/kernels/reduction/reduce_init.h"
5 | #include "mlx/backend/metal/kernels/reduction/reduce_row.h"
6 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/arange.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | template <typename T>
3 | [[kernel]] void arange(
4 | constant const T& start,
5 | constant const T& step,
6 | device T* out,
7 | uint index [[thread_position_in_grid]]) {
8 | out[index] = start + index * step;
9 | }
10 |
--------------------------------------------------------------------------------
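Each thread of the kernel writes one element of the sequence. A host-side reference of the same mapping, for clarity:

```cpp
// Host-side reference of the per-thread mapping in arange:
// out[index] = start + index * step.
void arange_reference(float start, float step, float* out, unsigned n) {
  for (unsigned index = 0; index < n; ++index) {
    out[index] = start + index * step;
  }
}
```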
/mlx/backend/no_gpu/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
4 | ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
8 |
--------------------------------------------------------------------------------
/docs/src/python/optimizers/schedulers.rst:
--------------------------------------------------------------------------------
1 | .. _schedulers:
2 |
3 | Schedulers
4 | ==========
5 |
6 | .. currentmodule:: mlx.optimizers
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | cosine_decay
12 | exponential_decay
13 | join_schedules
14 | linear_schedule
15 | step_decay
16 |
--------------------------------------------------------------------------------
/examples/cmake_project/example.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <iostream>
4 |
5 | #include "mlx/mlx.h"
6 |
7 | namespace mx = mlx::core;
8 |
9 | int main() {
10 | auto x = mx::array({1, 2, 3});
11 | auto y = mx::array({1, 2, 3});
12 | std::cout << x + y << std::endl;
13 | return 0;
14 | }
15 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/device/ternary_ops.cuh:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | #pragma once
3 |
4 | namespace mlx::core::cu {
5 |
6 | struct Select {
7 | template <typename T>
8 | __device__ T operator()(bool condition, T x, T y) {
9 | return condition ? x : y;
10 | }
11 | };
12 |
13 | } // namespace mlx::core::cu
14 |
--------------------------------------------------------------------------------
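`Select` is the per-element body behind `where`-style ternaries. A hypothetical contiguous kernel applying it (assumes `Select` from the header above is in scope):

```cpp
// Hypothetical contiguous ternary kernel using Select:
// out[i] = cond[i] ? x[i] : y[i].
template <typename T>
__global__ void ternary_select(
    const bool* cond, const T* x, const T* y, T* out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = Select{}(cond[i], x[i], y[i]);
  }
}
```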
/mlx/backend/no_cpu/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
4 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/eval.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/../cpu/encoder.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp)
8 |
--------------------------------------------------------------------------------
/docs/src/python/fft.rst:
--------------------------------------------------------------------------------
1 | .. _fft:
2 |
3 | FFT
4 | ===
5 |
6 | .. currentmodule:: mlx.core.fft
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | fft
12 | ifft
13 | fft2
14 | ifft2
15 | fftn
16 | ifftn
17 | rfft
18 | irfft
19 | rfft2
20 | irfft2
21 | rfftn
22 | irfftn
23 | fftshift
24 | ifftshift
25 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/compiled_preamble.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-24 Apple Inc.
2 |
3 | #pragma once
4 |
5 | // clang-format off
6 | #include "mlx/types/half_types.h"
7 | #include "mlx/types/complex.h"
8 | #include "mlx/backend/cpu/unary_ops.h"
9 | #include "mlx/backend/cpu/binary_ops.h"
10 | // clang-format on
11 |
12 | const char* get_kernel_preamble();
13 |
--------------------------------------------------------------------------------
/docs/src/python/memory_management.rst:
--------------------------------------------------------------------------------
1 | Memory Management
2 | =================
3 |
4 | .. currentmodule:: mlx.core
5 |
6 | .. autosummary::
7 | :toctree: _autosummary
8 |
9 | get_active_memory
10 | get_peak_memory
11 | reset_peak_memory
12 | get_cache_memory
13 | set_memory_limit
14 | set_cache_limit
15 | set_wired_limit
16 | clear_cache
17 |
--------------------------------------------------------------------------------
/mlx/backend/no_gpu/apple_memory.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <sys/sysctl.h>
6 |
7 | namespace {
8 |
9 | size_t get_memory_size() {
10 | size_t memsize = 0;
11 | size_t length = sizeof(memsize);
12 | sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
13 | return memsize;
14 | }
15 |
16 | } // namespace
17 |
--------------------------------------------------------------------------------
/mlx/distributed/mpi/mpi.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/distributed/distributed.h"
4 |
5 | namespace mlx::core::distributed::mpi {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available();
10 | std::shared_ptr<GroupImpl> init(bool strict = false);
11 |
12 | } // namespace mlx::core::distributed::mpi
13 |
--------------------------------------------------------------------------------
/docs/src/python/transforms.rst:
--------------------------------------------------------------------------------
1 | .. _transforms:
2 |
3 | Transforms
4 | ==========
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | eval
12 | async_eval
13 | compile
14 | custom_function
15 | disable_compile
16 | enable_compile
17 | grad
18 | value_and_grad
19 | jvp
20 | vjp
21 | vmap
22 |
--------------------------------------------------------------------------------
/mlx/backend/metal/scan.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "mlx/array.h"
4 | #include "mlx/primitives.h"
5 |
6 | namespace mlx::core {
7 |
8 | void scan_gpu_inplace(
9 | array in,
10 | array& out,
11 | Scan::ReduceType reduce_type,
12 | int axis,
13 | bool reverse,
14 | bool inclusive,
15 | const Stream& s);
16 |
17 | } // namespace mlx::core
18 |
--------------------------------------------------------------------------------
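The `reverse` and `inclusive` flags pick the scan variant: an inclusive sum-scan of {1, 2, 3} gives {1, 3, 6}, an exclusive one gives {0, 1, 3}, and `reverse` runs the same recurrence from the last element. A CPU reference pins down the semantics (hypothetical helper, not part of MLX):

```cpp
// CPU reference for the inclusive/exclusive scan semantics.
#include <vector>

std::vector<int> sum_scan(const std::vector<int>& in, bool inclusive) {
  std::vector<int> out(in.size());
  int acc = 0;
  for (size_t i = 0; i < in.size(); ++i) {
    if (inclusive) {
      acc += in[i];
      out[i] = acc; // includes in[i]
    } else {
      out[i] = acc; // value accumulated before in[i]
      acc += in[i];
    }
  }
  return out;
}
```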
/mlx/distributed/nccl/nccl.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/distributed/distributed.h"
4 |
5 | namespace mlx::core::distributed::nccl {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available();
10 | std::shared_ptr<GroupImpl> init(bool strict = false);
11 |
12 | } // namespace mlx::core::distributed::nccl
13 |
--------------------------------------------------------------------------------
/mlx/distributed/ring/ring.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/distributed/distributed.h"
4 |
5 | namespace mlx::core::distributed::ring {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available();
10 | std::shared_ptr<GroupImpl> init(bool strict = false);
11 |
12 | } // namespace mlx::core::distributed::ring
13 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # Development dependencies for MLX
2 | # Install with: uv pip install -r requirements-dev.txt
3 |
4 | # Build dependencies
5 | cmake>=3.25
6 | setuptools>=80
7 | nanobind==2.4.0
8 |
9 | # Test dependencies
10 | numpy
11 | torch
12 | tensorflow
13 | unittest-xml-reporting
14 | typing_extensions
15 |
16 | # Development tools
17 | pre-commit
18 |
19 |
--------------------------------------------------------------------------------
/mlx/distributed/jaccl/jaccl.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/distributed/distributed.h"
4 |
5 | namespace mlx::core::distributed::jaccl {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available();
10 | std::shared_ptr<GroupImpl> init(bool strict = false);
11 |
12 | } // namespace mlx::core::distributed::jaccl
13 |
--------------------------------------------------------------------------------
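All four distributed backends (mpi, nccl, ring, jaccl) expose the same two-function surface: `is_available()` and `init(strict)`. A hedged sketch of how a dispatcher might probe them in order; MLX's real selection logic lives elsewhere in `mlx/distributed`:

```cpp
// Hypothetical probe over the common backend surface. Assumes the backend
// headers above are included and this sits inside mlx::core::distributed.
std::shared_ptr<detail::GroupImpl> init_any_backend() {
  if (nccl::is_available()) {
    return nccl::init(/* strict = */ false);
  }
  if (mpi::is_available()) {
    return mpi::init(false);
  }
  if (ring::is_available()) {
    return ring::init(false);
  }
  return nullptr; // fall back to single-process execution
}
```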
/examples/extensions/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Build
3 |
4 | ```
5 | pip install -e .
6 | ```
7 |
8 | For faster builds during development, you can also pre-install the requirements:
9 |
10 | ```
11 | pip install -r requirements.txt
12 | ```
13 |
14 | And then run:
15 |
16 | ```
17 | python setup.py build_ext -j8 --inplace
18 | ```
19 |
20 | ## Test
21 |
22 | ```
23 | python test.py
24 | ```
25 |
--------------------------------------------------------------------------------
/docs/src/python/devices_and_streams.rst:
--------------------------------------------------------------------------------
1 | .. _devices_and_streams:
2 |
3 | Devices and Streams
4 | ===================
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | Device
12 | Stream
13 | default_device
14 | set_default_device
15 | default_stream
16 | new_stream
17 | set_default_stream
18 | stream
19 | synchronize
20 |
--------------------------------------------------------------------------------
/mlx/backend/gpu/eval.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include
6 | #include
7 |
8 | #include "mlx/array.h"
9 | #include "mlx/stream.h"
10 |
11 | namespace mlx::core::gpu {
12 |
13 | void new_stream(Stream stream);
14 | void eval(array& arr);
15 | void finalize(Stream s);
16 | void synchronize(Stream s);
17 |
18 | } // namespace mlx::core::gpu
19 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/bf16.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <metal_stdlib>
6 |
7 | using namespace metal;
8 |
9 | typedef bfloat bfloat16_t;
10 | inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
11 | return as_type<uint16_t>(x);
12 | }
13 |
14 | inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
15 | return as_type<bfloat16_t>(x);
16 | }
17 |
--------------------------------------------------------------------------------
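Both helpers are pure bit reinterpretations via Metal's `as_type<T>`. A host-side analogue of the same idea (sketch, not MLX API):

```cpp
// Host-side analogue of as_type<T>(x): copy the bits into a same-size
// type with no value conversion.
#include <cstdint>
#include <cstring>

uint32_t float_bits(float x) {
  static_assert(sizeof(uint32_t) == sizeof(float), "same-size cast");
  uint32_t out;
  std::memcpy(&out, &x, sizeof(out)); // bit copy, no rounding
  return out;
}
```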
/python/mlx/_reprlib_fix.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import array
4 | import reprlib
5 |
6 | _old_repr_array = reprlib.Repr.repr_array
7 |
8 |
9 | def repr_array(self, x, maxlevel):
10 | if isinstance(x, array.array):
11 | return _old_repr_array(self, x, maxlevel)
12 | else:
13 | return self.repr_instance(x, maxlevel)
14 |
15 |
16 | reprlib.Repr.repr_array = repr_array
17 |
--------------------------------------------------------------------------------
/mlx/backend/no_gpu/linux_memory.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <sys/sysinfo.h>
6 |
7 | namespace {
8 |
9 | size_t get_memory_size() {
10 | struct sysinfo info;
11 |
12 | if (sysinfo(&info) != 0) {
13 | return 0;
14 | }
15 |
16 | size_t total_ram = info.totalram;
17 | total_ram *= info.mem_unit;
18 |
19 | return total_ram;
20 | }
21 |
22 | } // namespace
23 |
--------------------------------------------------------------------------------
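`apple_memory.h` and `linux_memory.h` expose the same anonymous-namespace `get_memory_size()`, one via `sysctlbyname("hw.memsize")` and one via `sysinfo`. A sketch of how a caller might select between them; the include site shown here is an assumption:

```cpp
// Hypothetical platform selection over the two headers above.
#include <cstddef>

#if defined(__APPLE__)
#include "mlx/backend/no_gpu/apple_memory.h"
#else
#include "mlx/backend/no_gpu/linux_memory.h"
#endif

size_t total_system_memory() {
  return get_memory_size(); // returns 0 if the query fails
}
```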
/docs/src/python/optimizers/common_optimizers.rst:
--------------------------------------------------------------------------------
1 | .. _common_optimizers:
2 |
3 | Common Optimizers
4 | =================
5 |
6 | .. currentmodule:: mlx.optimizers
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 | :template: optimizers-template.rst
11 |
12 | SGD
13 | RMSprop
14 | Adagrad
15 | Adafactor
16 | AdaDelta
17 | Adam
18 | AdamW
19 | Adamax
20 | Lion
21 | MultiOptimizer
22 | Muon
23 |
--------------------------------------------------------------------------------
/examples/cmake_project/README.md:
--------------------------------------------------------------------------------
1 | ## Build and Run
2 |
3 | Install MLX with Python:
4 |
5 | ```bash
6 | pip install mlx>=0.22
7 | ```
8 |
9 | Build the C++ example:
10 |
11 | ```bash
12 | cmake -B build -DCMAKE_BUILD_TYPE=Release
13 | cmake --build build
14 | ```
15 |
16 | Run the C++ example:
17 |
18 | ```
19 | ./build/example
20 | ```
21 |
22 | which should output:
23 |
24 | ```
25 | array([2, 4, 6], dtype=int32)
26 | ```
27 |
--------------------------------------------------------------------------------
/examples/cpp/timer.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <chrono>
6 |
7 | namespace timer {
8 |
9 | using namespace std::chrono;
10 |
11 | template <typename R, typename P>
12 | inline double seconds(duration<R, P> x) {
13 | return duration_cast<nanoseconds>(x).count() / 1e9;
14 | }
15 |
16 | inline auto time() {
17 | return high_resolution_clock::now();
18 | }
19 |
20 | } // namespace timer
21 |
--------------------------------------------------------------------------------
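Typical usage of the helpers above (assumes `timer.h` is included):

```cpp
#include <iostream>

int main() {
  auto t0 = timer::time();
  volatile double acc = 0; // some work to measure
  for (int i = 0; i < 1'000'000; ++i) {
    acc += i * 0.5;
  }
  auto t1 = timer::time();
  std::cout << "elapsed: " << timer::seconds(t1 - t0) << " s\n";
  return 0;
}
```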
/mlx/backend/common/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
4 | ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
8 | ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
9 | ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
10 |
--------------------------------------------------------------------------------
/examples/extensions/test.py:
--------------------------------------------------------------------------------
1 | import mlx.core as mx
2 | from mlx_sample_extensions import axpby
3 |
4 | a = mx.ones((3, 4))
5 | b = mx.ones((3, 4))
6 | c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
7 | c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
8 |
9 | print(f"c shape: {c_cpu.shape}")
10 | print(f"c dtype: {c_cpu.dtype}")
11 | print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
12 | print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
13 |
--------------------------------------------------------------------------------
/docs/src/python/linalg.rst:
--------------------------------------------------------------------------------
1 | .. _linalg:
2 |
3 | Linear Algebra
4 | ==============
5 |
6 | .. currentmodule:: mlx.core.linalg
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | inv
12 | tri_inv
13 | norm
14 | cholesky
15 | cholesky_inv
16 | cross
17 | qr
18 | svd
19 | eigvals
20 | eig
21 | eigvalsh
22 | eigh
23 | lu
24 | lu_factor
25 | pinv
26 | solve
27 | solve_triangular
28 |
--------------------------------------------------------------------------------
/mlx/backend/common/slicing.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | std::tuple<int64_t, Strides> prepare_slice(
10 | const array& in,
11 | const Shape& start_indices,
12 | const Shape& strides);
13 |
14 | void slice(
15 | const array& in,
16 | array& out,
17 | const Shape& start_indices,
18 | const Shape& strides);
19 |
20 | } // namespace mlx::core
21 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/device/config.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | // This file is used by both CUDA kernel code and host-only C++ code.
4 |
5 | #pragma once
6 |
7 | // The maximum dimensions of shape/strides passed as kernel parameters.
8 | #define MAX_NDIM 10
9 |
10 | // All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
11 | // warpSize variable exists, using it would prevent compile-time optimizations.
12 | #define WARP_SIZE 32
13 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/conv/conv.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/metal/kernels/steel/defines.h"
6 | #include "mlx/backend/metal/kernels/steel/utils.h"
7 |
8 | #include "mlx/backend/metal/kernels/steel/conv/loader.h"
9 | #include "mlx/backend/metal/kernels/steel/conv/params.h"
10 | #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
11 |
12 | using namespace metal;
13 | using namespace mlx::steel;
14 |
--------------------------------------------------------------------------------
/examples/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | function(build_example SRCFILE)
2 | get_filename_component(src_name ${SRCFILE} NAME_WE)
3 | set(target "${src_name}")
4 | add_executable(${target} ${SRCFILE})
5 | target_link_libraries(${target} PRIVATE mlx)
6 | endfunction(build_example)
7 |
8 | build_example(tutorial.cpp)
9 | build_example(linear_regression.cpp)
10 | build_example(logistic_regression.cpp)
11 | build_example(metal_capture.cpp)
12 | build_example(distributed.cpp)
13 |
--------------------------------------------------------------------------------
/mlx/backend/metal/unary.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | void unary_op_gpu(
10 | const std::vector<array>& inputs,
11 | array& out,
12 | const char* op,
13 | const Stream& s);
14 |
15 | void unary_op_gpu_inplace(
16 | const std::vector<array>& inputs,
17 | array& out,
18 | const char* op,
19 | const Stream& s);
20 |
21 | } // namespace mlx::core
22 |
--------------------------------------------------------------------------------
/docs/src/python/optimizers/optimizer.rst:
--------------------------------------------------------------------------------
1 | Optimizer
2 | =========
3 |
4 | .. currentmodule:: mlx.optimizers
5 |
6 | .. autoclass:: Optimizer
7 |
8 |
9 | .. rubric:: Attributes
10 |
11 | .. autosummary::
12 | :toctree: _autosummary
13 |
14 | Optimizer.state
15 |
16 | .. rubric:: Methods
17 |
18 | .. autosummary::
19 | :toctree: _autosummary
20 |
21 | Optimizer.apply_gradients
22 | Optimizer.init
23 | Optimizer.update
24 |
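25 | In practice these methods are rarely called in isolation. A minimal
26 | training-step sketch (``model``, ``loss_fn``, ``optimizer``, and the batch
27 | ``x``, ``y`` are assumed to exist):
28 |
29 | .. code-block:: python
30 |
31 |    import mlx.core as mx
32 |    import mlx.nn as nn
33 |
34 |    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
35 |    loss, grads = loss_and_grad_fn(model, x, y)
36 |    optimizer.update(model, grads)
37 |    mx.eval(model.parameters(), optimizer.state)
38 |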
--------------------------------------------------------------------------------
/mlx/backend/metal/ternary.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | void ternary_op_gpu(
10 | const std::vector<array>& inputs,
11 | array& out,
12 | const char* op,
13 | const Stream& s);
14 |
15 | void ternary_op_gpu_inplace(
16 | const std::vector<array>& inputs,
17 | array& out,
18 | const char* op,
19 | const Stream& s);
20 |
21 | } // namespace mlx::core
22 |
--------------------------------------------------------------------------------
/benchmarks/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | function(build_benchmark SRCFILE)
2 | get_filename_component(src_name ${SRCFILE} NAME_WE)
3 | set(target "${src_name}")
4 | add_executable(${target} ${SRCFILE})
5 | target_link_libraries(${target} PRIVATE mlx)
6 | endfunction(build_benchmark)
7 |
8 | build_benchmark(single_ops.cpp)
9 | build_benchmark(irregular_strides.cpp)
10 | build_benchmark(compare_devices.cpp)
11 | build_benchmark(autograd.cpp)
12 | build_benchmark(separate_tp_vs_single.cpp)
13 |
--------------------------------------------------------------------------------
/mlx/io/gguf.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 | #pragma once
3 |
4 | #include "mlx/io.h"
5 | #include "mlx/primitives.h"
6 | #include "mlx/transforms.h"
7 | #include "mlx/utils.h"
8 |
9 | extern "C" {
10 | #include <gguflib.h>
11 | }
12 |
13 | namespace mlx::core {
14 |
15 | Shape get_shape(const gguf_tensor& tensor);
16 | void gguf_load_quantized(
17 | std::unordered_map<std::string, array>& a,
18 | const gguf_tensor& tensor);
19 |
20 | } // namespace mlx::core
21 |
--------------------------------------------------------------------------------
/python/src/cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2025 Apple Inc.
2 |
3 | #include <nanobind/nanobind.h>
4 |
5 | #include "mlx/backend/cuda/cuda.h"
6 |
7 | namespace mx = mlx::core;
8 | namespace nb = nanobind;
9 |
10 | void init_cuda(nb::module_& m) {
11 | nb::module_ cuda = m.def_submodule("cuda", "mlx.cuda");
12 |
13 | cuda.def(
14 | "is_available",
15 | &mx::cu::is_available,
16 | R"pbdoc(
17 | Check if the CUDA back-end is available.
18 | )pbdoc");
19 | }
20 |
--------------------------------------------------------------------------------
/benchmarks/numpy/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 |
5 |
6 | def time_fn(fn, *args):
7 | print(f"Timing {fn.__name__} ...", end=" ")
8 |
9 | # warmup
10 | for _ in range(5):
11 | fn(*args)
12 |
13 | num_iters = 100
14 | tic = time.perf_counter()
15 | for _ in range(num_iters):
16 | x = fn(*args)
17 | toc = time.perf_counter()
18 |
19 | msec = 1e3 * (toc - tic) / num_iters
20 | print(f"{msec:.5f} msec")
21 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/encoder.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cpu/encoder.h"
4 |
5 | namespace mlx::core::cpu {
6 |
7 | CommandEncoder& get_command_encoder(Stream stream) {
8 | static std::unordered_map<int, CommandEncoder> encoder_map;
9 | auto it = encoder_map.find(stream.index);
10 | if (it == encoder_map.end()) {
11 | it = encoder_map.emplace(stream.index, stream).first;
12 | }
13 | return it->second;
14 | }
15 |
16 | } // namespace mlx::core::cpu
17 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/slicing.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | std::tuple<int64_t, Strides> prepare_slice(
10 | const array& in,
11 | const Shape& start_indices,
12 | const Shape& strides);
13 |
14 | void shared_buffer_slice(
15 | const array& in,
16 | const Strides& out_strides,
17 | size_t data_offset,
18 | size_t data_size,
19 | array& out);
20 |
21 | } // namespace mlx::core
22 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/sqrt.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
7 | nvtx3::scoped_range r("Sqrt::eval_gpu");
8 | auto& s = out.primitive().stream();
9 | if (recip_) {
10 | unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
11 | } else {
12 | unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
13 | }
14 | }
15 | } // namespace mlx::core
16 |
--------------------------------------------------------------------------------
/docs/src/_templates/optimizers-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 |
7 | {% block methods %}
8 |
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | {% for item in methods %}
14 | {%- if item not in inherited_members %}
15 | ~{{ name }}.{{ item }}
16 | {%- endif %}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/equal.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
7 | nvtx3::scoped_range r("Equal::eval_gpu");
8 | auto& s = out.primitive().stream();
9 | if (equal_nan_) {
10 | binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
11 | } else {
12 | binary_op_gpu<cu::Equal>(inputs, out, name(), s);
13 | }
14 | }
15 | } // namespace mlx::core
16 |
--------------------------------------------------------------------------------
/docs/src/_templates/nn-module-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 |
7 | {% block methods %}
8 |
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | {% for item in methods %}
14 | {%- if item not in inherited_members and item != "__init__" %}
15 | ~{{ name }}.{{ item }}
16 | {%- endif %}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 |
--------------------------------------------------------------------------------
/mlx/distributed/mpi/no_mpi.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/distributed/mpi/mpi.h"
4 |
5 | namespace mlx::core::distributed::mpi {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available() {
10 | return false;
11 | }
12 |
13 | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
14 | if (strict) {
15 | throw std::runtime_error("Cannot initialize MPI");
16 | }
17 | return nullptr;
18 | }
19 |
20 | } // namespace mlx::core::distributed::mpi
21 |
--------------------------------------------------------------------------------
/docs/src/python/nn/losses.rst:
--------------------------------------------------------------------------------
1 | .. _losses:
2 |
3 | .. currentmodule:: mlx.nn.losses
4 |
5 | Loss Functions
6 | --------------
7 |
8 | .. autosummary::
9 | :toctree: _autosummary_functions
10 | :template: nn-module-template.rst
11 |
12 | binary_cross_entropy
13 | cosine_similarity_loss
14 | cross_entropy
15 | gaussian_nll_loss
16 | hinge_loss
17 | huber_loss
18 | kl_div_loss
19 | l1_loss
20 | log_cosh_loss
21 | margin_ranking_loss
22 | mse_loss
23 | nll_loss
24 | smooth_l1_loss
25 | triplet_loss
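26 |
27 | Each loss can be called directly on arrays; a minimal sketch:
28 |
29 | .. code-block:: python
30 |
31 |    import mlx.core as mx
32 |    import mlx.nn as nn
33 |
34 |    logits = mx.array([[2.0, -1.0]])
35 |    targets = mx.array([0])
36 |    loss = nn.losses.cross_entropy(logits, targets)
37 |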
--------------------------------------------------------------------------------
/mlx/distributed/jaccl/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | if(MLX_BUILD_CPU
2 | AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
3 | AND MACOS_SDK_VERSION GREATER_EQUAL 26.2
4 | OR MLX_BUILD_JACCL)
5 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp)
6 | # On macOS, link against librdma which provides InfiniBand Verbs
7 | # Use PUBLIC so shared libraries linking to mlx also get rdma dependency
8 | target_link_libraries(mlx PUBLIC rdma)
9 | else()
10 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
11 | endif()
12 |
--------------------------------------------------------------------------------
/python/src/constants.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include <nanobind/nanobind.h>
4 | #include <limits>
5 |
6 | namespace nb = nanobind;
7 |
8 | void init_constants(nb::module_& m) {
9 | m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
10 | m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
11 | m.attr("inf") = std::numeric_limits::infinity();
12 | m.attr("nan") = NAN;
13 | m.attr("newaxis") = nb::none();
14 | m.attr("pi") = 3.1415926535897932384626433;
15 | }
16 |
--------------------------------------------------------------------------------
/docs/src/python/distributed.rst:
--------------------------------------------------------------------------------
1 | .. _distributed:
2 |
3 | .. currentmodule:: mlx.core.distributed
4 |
5 | Distributed Communication
6 | ==========================
7 |
8 | MLX provides a distributed communication package using MPI. The MPI library is
9 | loaded at runtime; if MPI is available then distributed communication is also
10 | made available.
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 |
15 | Group
16 | is_available
17 | init
18 | all_sum
19 | all_gather
20 | send
21 | recv
22 | recv_like
23 |
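24 | A minimal usage sketch (launch with e.g. ``mpirun -np 2 python example.py`` so
25 | that more than one process participates):
26 |
27 | .. code-block:: python
28 |
29 |    import mlx.core as mx
30 |
31 |    world = mx.distributed.init()
32 |    x = mx.ones(10)
33 |    y = mx.distributed.all_sum(x, group=world)
34 |    print(world.rank(), y)
35 |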
--------------------------------------------------------------------------------
/mlx/distributed/nccl/no_nccl.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/distributed/nccl/nccl.h"
4 |
5 | namespace mlx::core::distributed::nccl {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available() {
10 | return false;
11 | }
12 |
13 | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
14 | if (strict) {
15 | throw std::runtime_error("Cannot initialize nccl distributed backend.");
16 | }
17 | return nullptr;
18 | }
19 |
20 | } // namespace mlx::core::distributed::nccl
21 |
--------------------------------------------------------------------------------
/mlx/distributed/ring/no_ring.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/distributed/ring/ring.h"
4 |
5 | namespace mlx::core::distributed::ring {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available() {
10 | return false;
11 | }
12 |
13 | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
14 | if (strict) {
15 | throw std::runtime_error("Cannot initialize ring distributed backend.");
16 | }
17 | return nullptr;
18 | }
19 |
20 | } // namespace mlx::core::distributed::ring
21 |
--------------------------------------------------------------------------------
/mlx/distributed/jaccl/no_jaccl.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/distributed/jaccl/jaccl.h"
4 |
5 | namespace mlx::core::distributed::jaccl {
6 |
7 | using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8 |
9 | bool is_available() {
10 | return false;
11 | }
12 |
13 | std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
14 | if (strict) {
15 | throw std::runtime_error("Cannot initialize jaccl distributed backend.");
16 | }
17 | return nullptr;
18 | }
19 |
20 | } // namespace mlx::core::distributed::jaccl
21 |
--------------------------------------------------------------------------------
/mlx/einsum.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | #pragma once
3 |
4 | #include <string>
5 | #include <utility>
6 | #include <vector>
7 |
8 | #include "mlx/array.h"
9 | #include "mlx/utils.h"
10 |
11 | namespace mlx::core {
12 |
13 | std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
14 | const std::string& subscripts,
15 | const std::vector<array>& operands);
16 |
17 | array einsum(
18 | const std::string& subscripts,
19 | const std::vector<array>& operands,
20 | StreamOrDevice s = {});
21 |
22 | } // namespace mlx::core
23 |
--------------------------------------------------------------------------------
/mlx/distributed/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
4 | ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
6 |
7 | if(MLX_BUILD_CPU AND NOT WIN32)
8 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
9 | endif()
10 |
11 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
12 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
13 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
14 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
15 |
--------------------------------------------------------------------------------
/mlx/distributed/nccl/nccl_stub/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.25)
2 |
3 | project(nccl LANGUAGES C CXX)
4 |
5 | file(
6 | DOWNLOAD
7 | "https://raw.githubusercontent.com/NVIDIA/nccl/refs/tags/v2.27.5-1/src/nccl.h.in"
8 | "${CMAKE_CURRENT_BINARY_DIR}/nccl.h")
9 |
10 | add_library(nccl SHARED nccl_stubs.cpp)
11 | set_target_properties(nccl PROPERTIES SOVERSION 2)
12 | find_package(CUDAToolkit REQUIRED)
13 | target_include_directories(nccl PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
14 | target_include_directories(nccl PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
15 |
--------------------------------------------------------------------------------
/mlx/version.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #define MLX_VERSION_MAJOR 0
6 | #define MLX_VERSION_MINOR 30
7 | #define MLX_VERSION_PATCH 1
8 | #define MLX_VERSION_NUMERIC \
9 | (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
10 |
11 | namespace mlx::core {
12 |
13 | /* A string representation of the MLX version in the format
14 | * "major.minor.patch".
15 | *
16 | * For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
17 | */
18 | const char* version();
19 |
20 | } // namespace mlx::core
21 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/gemm.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 | #include "mlx/array.h"
5 |
6 | namespace mlx::core {
7 |
8 | template <typename T>
9 | void matmul(
10 | const T* a,
11 | const T* b,
12 | T* out,
13 | bool a_transposed,
14 | bool b_transposed,
15 | size_t lda,
16 | size_t ldb,
17 | size_t ldc,
18 | float alpha,
19 | float beta,
20 | size_t batch_size,
21 | const Shape& a_shape,
22 | const Strides& a_strides,
23 | const Shape& b_shape,
24 | const Strides& b_strides);
25 |
26 | } // namespace mlx::core
27 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/jit_compiler.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | #pragma once
3 |
4 | #include <filesystem>
5 |
6 | namespace mlx::core {
7 |
8 | class JitCompiler {
9 | public:
10 | // Build a shell command that compiles a source code file to a shared library.
11 | static std::string build_command(
12 | const std::filesystem::path& dir,
13 | const std::string& source_file_name,
14 | const std::string& shared_lib_name);
15 |
16 | // Run a command and get its output.
17 | static std::string exec(const std::string& cmd);
18 | };
19 |
20 | } // namespace mlx::core
21 |
--------------------------------------------------------------------------------
/examples/export/eval_mlp.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <mlx/mlx.h>
4 | #include <iostream>
5 |
6 | namespace mx = mlx::core;
7 |
8 | int main() {
9 | int batch_size = 8;
10 | int input_dim = 32;
11 |
12 | // Make the input
13 | mx::random::seed(42);
14 | auto example_x = mx::random::uniform({batch_size, input_dim});
15 |
16 | // Import the function
17 | auto forward = mx::import_function("eval_mlp.mlxfn");
18 |
19 | // Call the imported function
20 | auto out = forward({example_x})[0];
21 |
22 | std::cout << out << std::endl;
23 |
24 | return 0;
25 | }
26 |
--------------------------------------------------------------------------------
/mlx/allocator.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <cstdlib>
4 | #include <sstream>
5 |
6 | #include "mlx/allocator.h"
7 |
8 | namespace mlx::core::allocator {
9 |
10 | Buffer malloc(size_t size) {
11 | auto buffer = allocator().malloc(size);
12 | if (size && !buffer.ptr()) {
13 | std::ostringstream msg;
14 | msg << "[malloc] Unable to allocate " << size << " bytes.";
15 | throw std::runtime_error(msg.str());
16 | }
17 | return buffer;
18 | }
19 |
20 | void free(Buffer buffer) {
21 | allocator().free(buffer);
22 | }
23 |
24 | } // namespace mlx::core::allocator
25 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/round.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
7 | nvtx3::scoped_range r("Round::eval_gpu");
8 | assert(inputs.size() == 1);
9 | const auto& in = inputs[0];
10 | auto& s = out.primitive().stream();
11 | if (issubdtype(in.dtype(), inexact)) {
12 | unary_op_gpu<cu::Round>(inputs, out, name(), s);
13 | } else {
14 | // No-op for integer types
15 | out.copy_shared_buffer(in);
16 | }
17 | }
18 | } // namespace mlx::core
19 |
--------------------------------------------------------------------------------
/mlx/io/no_gguf.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include "mlx/io.h"
4 |
5 | namespace mlx::core {
6 |
7 | GGUFLoad load_gguf(const std::string&, StreamOrDevice s) {
8 | throw std::runtime_error(
9 | "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
10 | }
11 |
12 | void save_gguf(
13 | std::string,
14 | std::unordered_map<std::string, array>,
15 | std::unordered_map<std::string, GGUFMetaData>) {
16 | throw std::runtime_error(
17 | "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
18 | }
19 |
20 | } // namespace mlx::core
21 |
--------------------------------------------------------------------------------
/examples/cpp/distributed.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <iostream>
4 |
5 | #include "mlx/mlx.h"
6 |
7 | namespace mx = mlx::core;
8 |
9 | int main() {
10 | if (!mx::distributed::is_available()) {
11 | std::cout << "No communication backend found" << std::endl;
12 | return 1;
13 | }
14 |
15 | auto global_group = mx::distributed::init();
16 | std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
17 |
18 | mx::array x = mx::ones({10});
19 | mx::array out = mx::distributed::all_sum(x, global_group);
20 |
21 | std::cout << out << std::endl;
22 | }
23 |
--------------------------------------------------------------------------------
/tests_jaccl/run_jaccl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Example usage:
4 | # ./run_jaccl.sh <rank> <coordinator>
5 |
6 | RANK=$1
7 | COORDINATOR=$2
8 |
9 | if [ -z "$RANK" ] || [ -z "$COORDINATOR" ]; then
10 | echo "Usage: $0 "
11 | exit 1
12 | fi
13 |
14 | export MLX_RANK=$RANK
15 | export MLX_IBV_COORDINATOR=$COORDINATOR
16 | export MLX_IBV_DEVICES=$(pwd)/jaccl_config.json
17 | export MLX_IBV_VERBOSE=1
18 |
19 | # Activate virtual environment if it exists
20 | if [ -f "../.venv/bin/activate" ]; then
21 | source ../.venv/bin/activate
22 | fi
23 |
24 | python3 test_tp_mlp.py
25 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/indexing/indexing.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <metal_stdlib>
6 |
7 | template <typename IdxT, int NIDX>
8 | struct Indices {
9 | const array<const device IdxT*, NIDX> buffers;
10 | const constant int* shapes;
11 | const constant int64_t* strides;
12 | const constant bool* row_contiguous;
13 | const int ndim;
14 | };
15 |
16 | template <typename IdxT>
17 | METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
18 | if (is_unsigned_v<IdxT>) {
19 | return idx;
20 | } else {
21 | return (idx < 0) ? idx + size : idx;
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/python/scripts/repair_linux.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | auditwheel repair dist/* \
4 | --plat manylinux_2_35_${1} \
5 | --only-plat \
6 | --exclude libmlx* \
7 | -w wheel_tmp
8 |
9 | mkdir wheelhouse
10 | cd wheel_tmp
11 | repaired_wheel=$(find . -name "*.whl" -print -quit)
12 | unzip -q "${repaired_wheel}"
13 | rm "${repaired_wheel}"
14 | core_so=$(find mlx -name "core*.so" -print -quit)
15 | rpath="\$ORIGIN/lib"
16 | patchelf --force-rpath --set-rpath "$rpath" "$core_so"
17 | python ../python/scripts/repair_record.py ${core_so}
18 |
19 | # Re-zip the repaired wheel
20 | zip -r -q "../wheelhouse/${repaired_wheel}" .
21 |
--------------------------------------------------------------------------------
/tests/tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #define DOCTEST_CONFIG_IMPLEMENT
4 | #include "doctest/doctest.h"
5 |
6 | #include <cstdlib>
7 |
8 | #include "mlx/mlx.h"
9 |
10 | using namespace mlx::core;
11 |
12 | int main(int argc, char** argv) {
13 | doctest::Context context;
14 |
15 | const char* device = std::getenv("DEVICE");
16 | if (device != nullptr && std::string(device) == "cpu") {
17 | set_default_device(Device::cpu);
18 | } else if (metal::is_available()) {
19 | set_default_device(Device::gpu);
20 | }
21 |
22 | context.applyCommandLine(argc, argv);
23 | return context.run();
24 | }
25 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/gemms/gemv.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/cuda/device.h"
6 |
7 | namespace mlx::core::cu {
8 |
9 | bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);
10 |
11 | void gemv(
12 | const array& a,
13 | const array& b,
14 | array& out,
15 | int M,
16 | int N,
17 | int K,
18 | uint32_t batch_count,
19 | const mlx::core::Shape& batch_shape,
20 | const mlx::core::Strides& a_batch_strides,
21 | const mlx::core::Strides& b_batch_strides,
22 | CommandEncoder& encoder);
23 |
24 | } // namespace mlx::core::cu
25 |
--------------------------------------------------------------------------------
/mlx/backend/metal/metal.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <string>
6 | #include <unordered_map>
7 | #include <variant>
8 |
9 | namespace mlx::core::metal {
10 |
11 | /* Check if the Metal backend is available. */
12 | bool is_available();
13 |
14 | /** Capture a GPU trace, saving it to an absolute file `path` */
15 | void start_capture(std::string path = "");
16 | void stop_capture();
17 |
18 | /** Get information about the GPU and system settings. */
19 | const std::unordered_map<std::string, std::variant<std::string, size_t>>&
20 | device_info();
21 |
22 | } // namespace mlx::core::metal
23 |
--------------------------------------------------------------------------------
/.github/actions/build-cuda-release/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Build CUDA wheel'
2 | description: 'Build CUDA wheel'
3 |
4 | inputs:
5 | toolkit:
6 | description: 'The CUDA toolkit'
7 | required: true
8 |
9 | runs:
10 | using: "composite"
11 | steps:
12 | - name: Build package
13 | shell: bash
14 | env:
15 | CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
16 | run: |
17 | pip install auditwheel build patchelf setuptools
18 | python setup.py clean --all
19 | MLX_BUILD_STAGE=2 python -m build -w
20 | bash python/scripts/repair_cuda.sh
21 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Proposed changes
2 |
3 | Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
4 |
5 | ## Checklist
6 |
7 | Put an `x` in the boxes that apply.
8 |
9 | - [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
10 | - [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
11 | - [ ] I have added tests that prove my fix is effective or that my feature works
12 | - [ ] I have updated the necessary documentation (if needed)
13 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/quantized/convert_fp8.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | #include "mlx/backend/cuda/unary/unary.cuh"
3 | #include "mlx/fast_primitives.h"
4 |
5 | namespace mlx::core {
6 | void fast::ConvertFP8::eval_gpu(
7 | const std::vector<array>& inputs,
8 | std::vector<array>& outputs) {
9 | nvtx3::scoped_range r("ConvertFP8::eval_gpu");
10 | auto& in = inputs[0];
11 | auto& out = outputs[0];
12 | auto& s = out.primitive().stream();
13 | if (to_fp8_) {
14 | unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
15 | } else {
16 | unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
17 | }
18 | }
19 | } // namespace mlx::core
20 |
--------------------------------------------------------------------------------
/tests_jaccl/deploy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Syncs the current directory to the test hosts
4 | # Usage: ./deploy.sh
5 |
6 | REMOTE_PATH="/Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma"
7 |
8 | echo "Deploying to m4p.local..."
9 | rsync -avz --exclude '.git' --exclude 'build' ../ m4p.local:$REMOTE_PATH/
10 |
11 | echo "Deploying to m3u.local..."
12 | # Check if m3u is up first to avoid long timeout if it's still down
13 | if ping -c 1 -W 1 m3u.local &> /dev/null; then
14 | rsync -avz --exclude '.git' --exclude 'build' ../ m3u.local:$REMOTE_PATH/
15 | else
16 | echo "m3u.local seems to be down. Skipping."
17 | fi
18 |
19 | echo "Done."
20 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line.
4 | SPHINXOPTS =
5 | SPHINXBUILD = sphinx-build
6 | SOURCEDIR = src
7 | BUILDDIR = build
8 |
9 | # Put it first so that "make" without argument is like "make help".
10 | help:
11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
12 |
13 | .PHONY: help Makefile
14 |
15 | # Catch-all target: route all unknown targets to Sphinx using the new
16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
17 | %: Makefile
18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
19 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/unary/log.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/unary/unary.cuh"
4 |
5 | namespace mlx::core {
6 | void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
7 | nvtx3::scoped_range r("Log::eval_gpu");
8 | auto& s = out.primitive().stream();
9 | switch (base_) {
10 | case Base::e:
11 | unary_op_gpu<cu::Log>(inputs, out, name(), s);
12 | break;
13 | case Base::two:
14 | unary_op_gpu<cu::Log2>(inputs, out, name(), s);
15 | break;
16 | case Base::ten:
17 | unary_op_gpu<cu::Log10>(inputs, out, name(), s);
18 | break;
19 | }
20 | }
21 | } // namespace mlx::core
22 |
--------------------------------------------------------------------------------
/examples/export/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.27)
2 |
3 | project(import_mlx LANGUAGES CXX)
4 |
5 | set(CMAKE_CXX_STANDARD 17)
6 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 |
8 | find_package(
9 | Python 3.9
10 | COMPONENTS Interpreter Development.Module
11 | REQUIRED)
12 | execute_process(
13 | COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
14 | OUTPUT_STRIP_TRAILING_WHITESPACE
15 | OUTPUT_VARIABLE MLX_ROOT)
16 | find_package(MLX CONFIG REQUIRED)
17 |
18 | add_executable(eval_mlp eval_mlp.cpp)
19 | target_link_libraries(eval_mlp PRIVATE mlx)
20 |
21 | add_executable(train_mlp train_mlp.cpp)
22 | target_link_libraries(train_mlp PRIVATE mlx)
23 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/threefry.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <cstdint>
6 | #include <utility>
7 |
8 | namespace mlx::core::random {
9 |
10 | /** Applies the Threefry 2x32 hash function.
11 | * This code is based on the Jax counter-based and splittable PRNG
12 | * https://github.com/google/jax/blob/main/docs/jep/263-prng.md
13 | *
14 | * Original Threefry reference:
15 | * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
16 | */
17 | std::pair<uint32_t, uint32_t> threefry2x32_hash(
18 | const std::pair<uint32_t, uint32_t>& key,
19 | std::pair<uint32_t, uint32_t> count);
20 |
21 | } // namespace mlx::core::random
22 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report about an issue you've encountered
4 | title: "[BUG] "
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 |
15 | Include code snippet
16 | ```python
17 |
18 | ```
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Desktop (please complete the following information):**
24 | - OS Version: [e.g. MacOS 14.1.2]
25 | - Version [e.g. 0.7.0]
26 |
27 | **Additional context**
28 | Add any other context about the problem here.
29 |
--------------------------------------------------------------------------------
/.github/workflows/documentation.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | permissions:
7 | contents: read
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-22.04
12 | steps:
13 | - uses: actions/checkout@v5
14 | - uses: ./.github/actions/build-docs
15 |
16 | deploy:
17 | needs: build
18 | permissions:
19 | pages: write
20 | id-token: write
21 | runs-on: ubuntu-latest
22 | environment:
23 | name: github-pages
24 | url: ${{ steps.deployment.outputs.page_url }}
25 | steps:
26 | - name: Deploy to GitHub Pages
27 | id: deployment
28 | uses: actions/deploy-pages@v4
29 |
--------------------------------------------------------------------------------
/examples/extensions/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023-2024 Apple Inc.
2 |
3 | from setuptools import setup
4 |
5 | from mlx import extension
6 |
7 | if __name__ == "__main__":
8 | setup(
9 | name="mlx_sample_extensions",
10 | version="0.0.0",
11 | description="Sample C++ and Metal extensions for MLX primitives.",
12 | ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
13 | cmdclass={"build_ext": extension.CMakeBuild},
14 | packages=["mlx_sample_extensions"],
15 | package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
16 | zip_safe=False,
17 | python_requires=">=3.8",
18 | )
19 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: mlx
3 | message: >-
4 | If you use this software, please cite it using the
5 | metadata from this file.
6 | type: software
7 | authors:
8 | - given-names: Awni
9 | family-names: Hannun
10 | affiliation: Apple
11 | - given-names: Jagrit
12 | family-names: Digani
13 | affiliation: Apple
14 | - given-names: Angelos
15 | family-names: Katharopoulos
16 | affiliation: Apple
17 | - given-names: Ronan
18 | family-names: Collobert
19 | affiliation: Apple
20 | repository-code: 'https://github.com/ml-explore'
21 | abstract: >-
22 | MLX: efficient and flexible machine learning on Apple
23 | silicon
24 | license: MIT
25 |
--------------------------------------------------------------------------------
/docs/src/python/tree_utils.rst:
--------------------------------------------------------------------------------
1 | .. _utils:
2 |
3 | Tree Utils
4 | ==========
5 |
6 | In MLX we consider a python tree to be an arbitrarily nested collection of
7 | dictionaries, lists and tuples without cycles. Functions in this module that
8 | return python trees use the default python ``dict``, ``list`` and
9 | ``tuple``, but they can usually process objects that inherit from any of these.
10 |
11 | .. note::
12 | Dictionaries should have keys that are valid python identifiers.
13 |
14 | .. currentmodule:: mlx.utils
15 |
16 | .. autosummary::
17 | :toctree: _autosummary
18 |
19 | tree_flatten
20 | tree_unflatten
21 | tree_map
22 | tree_map_with_path
23 | tree_reduce
24 |
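25 | For example, ``tree_flatten`` turns a nested structure into a flat list of
26 | ``(key, value)`` pairs with dotted keys and ``tree_unflatten`` inverts it; a
27 | minimal sketch:
28 |
29 | .. code-block:: python
30 |
31 |    from mlx.utils import tree_flatten, tree_unflatten
32 |
33 |    tree = {"layers": [{"w": 1}, {"w": 2}]}
34 |    flat = tree_flatten(tree)  # [("layers.0.w", 1), ("layers.1.w", 2)]
35 |    assert tree_unflatten(flat) == tree
36 |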
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/logsumexp.metal:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include <metal_common>
4 | #include <metal_simdgroup>
5 |
6 | using namespace metal;
7 |
8 | // clang-format off
9 | #include "mlx/backend/metal/kernels/utils.h"
10 | #include "mlx/backend/metal/kernels/logsumexp.h"
11 |
12 | #define instantiate_logsumexp(name, itype) \
13 | instantiate_kernel("block_logsumexp_" #name, logsumexp, itype) \
14 | instantiate_kernel("looped_logsumexp_" #name, logsumexp_looped, itype) \
15 |
16 | instantiate_logsumexp(float32, float)
17 | instantiate_logsumexp(float16, half)
18 | instantiate_logsumexp(bfloat16, bfloat16_t) // clang-format on
19 |
--------------------------------------------------------------------------------
/mlx/backend/no_gpu/eval.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include <stdexcept>
4 |
5 | #include "mlx/backend/gpu/available.h"
6 | #include "mlx/backend/gpu/eval.h"
7 |
8 | namespace mlx::core::gpu {
9 |
10 | bool is_available() {
11 | return false;
12 | }
13 |
14 | void new_stream(Stream) {}
15 |
16 | void eval(array&) {
17 | throw std::runtime_error("[gpu::eval] GPU backend is not available");
18 | }
19 |
20 | void finalize(Stream) {
21 | throw std::runtime_error("[gpu::finalize] GPU backend is not available");
22 | }
23 |
24 | void synchronize(Stream) {
25 | throw std::runtime_error("[gpu::synchronize] GPU backend is not available");
26 | }
27 |
28 | } // namespace mlx::core::gpu
29 |
--------------------------------------------------------------------------------
/python/mlx/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def main() -> None:
5 | from mlx.core import __version__
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument(
9 | "--version",
10 | action="version",
11 | version=__version__,
12 | help="Print the version number.",
13 | )
14 | parser.add_argument(
15 | "--cmake-dir",
16 | action="store_true",
17 | help="Print the path to the MLX CMake module directory.",
18 | )
19 | args = parser.parse_args()
20 | if args.cmake_dir:
21 | from pathlib import Path
22 |
23 | print(Path(__file__).parent)
24 |
25 |
26 | if __name__ == "__main__":
27 | main()
28 |
--------------------------------------------------------------------------------
/examples/cmake_project/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.27)
2 |
3 | project(example LANGUAGES CXX)
4 |
5 | set(CMAKE_CXX_STANDARD 17)
6 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 |
8 | # Comment out the following two commands if the MLX C++ library is already
9 | # installed, and use set(MLX_ROOT "/path/to/mlx") directly instead.
10 | find_package(
11 | Python 3.9
12 | COMPONENTS Interpreter Development.Module
13 | REQUIRED)
14 | execute_process(
15 | COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
16 | OUTPUT_STRIP_TRAILING_WHITESPACE
17 | OUTPUT_VARIABLE MLX_ROOT)
18 |
19 | find_package(MLX CONFIG REQUIRED)
20 |
21 | add_executable(example example.cpp)
22 | target_link_libraries(example PRIVATE mlx)
23 |
--------------------------------------------------------------------------------
/mlx/mlx.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 | #include "mlx/backend/cuda/cuda.h"
7 | #include "mlx/backend/gpu/available.h"
8 | #include "mlx/backend/metal/metal.h"
9 | #include "mlx/compile.h"
10 | #include "mlx/device.h"
11 | #include "mlx/distributed/distributed.h"
12 | #include "mlx/distributed/ops.h"
13 | #include "mlx/einsum.h"
14 | #include "mlx/export.h"
15 | #include "mlx/fast.h"
16 | #include "mlx/fft.h"
17 | #include "mlx/io.h"
18 | #include "mlx/linalg.h"
19 | #include "mlx/memory.h"
20 | #include "mlx/ops.h"
21 | #include "mlx/random.h"
22 | #include "mlx/stream.h"
23 | #include "mlx/transforms.h"
24 | #include "mlx/utils.h"
25 | #include "mlx/version.h"
26 |
--------------------------------------------------------------------------------
/.github/scripts/setup+build-cpp-linux-fedora-container.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -ex
3 |
4 | # [Setup] Install dependencies inside the container.
5 | dnf update -y
6 | dnf install -y \
7 | blas-devel \
8 | lapack-devel \
9 | openblas-devel \
10 | make \
11 | cmake \
12 | clang \
13 | git
14 | dnf clean all
15 |
16 | # [C++] CI Build Sanity Check: Verifies code compilation, not for release.
17 | export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
18 | export DEBUG=1
19 | export CMAKE_C_COMPILER=/usr/bin/clang
20 | export CMAKE_CXX_COMPILER=/usr/bin/clang++
21 |
22 | mkdir -p build
23 | pushd build
24 | cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
25 | make -j $(nproc)
26 | ./tests/tests
27 | popd
28 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/arange.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 | #include "mlx/backend/cpu/encoder.h"
7 |
8 | namespace mlx::core {
9 |
10 | namespace {
11 |
12 | template <typename T>
13 | void arange(T start, T next, array& out, size_t size, Stream stream) {
14 | auto ptr = out.data<T>();
15 | auto step_size = next - start;
16 | auto& encoder = cpu::get_command_encoder(stream);
17 | encoder.set_output_array(out);
18 | encoder.dispatch([ptr, start, step_size, size]() mutable {
19 | for (int i = 0; i < size; ++i) {
20 | ptr[i] = start;
21 | start += step_size;
22 | }
23 | });
24 | }
25 |
26 | } // namespace
27 |
28 | } // namespace mlx::core
29 |
--------------------------------------------------------------------------------
/benchmarks/python/comparative/README.md:
--------------------------------------------------------------------------------
1 | Microbenchmarks comparing MLX to PyTorch
2 | ========================================
3 |
4 | Implement the same microbenchmarks in MLX and PyTorch to compare and make a
5 | list of the biggest possible performance improvements and/or regressions.
6 |
7 | Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu`, for
8 | instance, to measure the time it takes to sum an 8x1024x128 tensor across its
9 | 3rd axis on the CPU.
10 |
11 | `compare.py` runs several benchmarks and compares the speed-up or lack thereof
12 | in comparison to PyTorch.
13 |
14 | Each bench script can be run with `--print-pid` to print the PID and wait for a
15 | key in order to ease attaching a debugger.
16 |
--------------------------------------------------------------------------------
/.github/actions/setup-macos/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Setup macOS Environment'
2 | description: 'Install dependencies for macOS builds'
3 |
4 | inputs:
5 | python-version:
6 | description: 'Python version to use'
7 | required: false
8 | default: '3.10'
9 |
10 | runs:
11 | using: "composite"
12 | steps:
13 | - name: Install Homebrew packages
14 | shell: sh
15 | run: /opt/homebrew/bin/brew install openmpi
16 |
17 | - name: Verify MetalToolchain installed
18 | shell: bash
19 | run: xcodebuild -showComponent MetalToolchain
20 |
21 | - uses: conda-incubator/setup-miniconda@v3
22 | with:
23 | miniconda-version: "latest"
24 | python-version: ${{ inputs.python-version }}
25 |
--------------------------------------------------------------------------------
/mlx/device.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | namespace mlx::core {
6 |
7 | struct Device {
8 | enum class DeviceType {
9 | cpu,
10 | gpu,
11 | };
12 |
13 | static constexpr DeviceType cpu = DeviceType::cpu;
14 | static constexpr DeviceType gpu = DeviceType::gpu;
15 |
16 | Device(DeviceType type, int index = 0) : type(type), index(index) {}
17 |
18 | DeviceType type;
19 | int index;
20 | };
21 |
22 | const Device& default_device();
23 |
24 | void set_default_device(const Device& d);
25 |
26 | bool operator==(const Device& lhs, const Device& rhs);
27 | bool operator!=(const Device& lhs, const Device& rhs);
28 |
29 | bool is_available(const Device& d);
30 |
31 | } // namespace mlx::core
32 |
--------------------------------------------------------------------------------
/docs/src/python/nn/functions.rst:
--------------------------------------------------------------------------------
1 | .. _nn_functions:
2 |
3 | .. currentmodule:: mlx.nn
4 |
5 | Functions
6 | ---------
7 |
8 | Layers without parameters (e.g. activation functions) are also provided as
9 | simple functions.
10 |
11 | .. autosummary::
12 | :toctree: _autosummary_functions
13 | :template: nn-module-template.rst
14 |
15 | elu
16 | celu
17 | gelu
18 | gelu_approx
19 | gelu_fast_approx
20 | glu
21 | hard_shrink
22 | hard_tanh
23 | hardswish
24 | leaky_relu
25 | log_sigmoid
26 | log_softmax
27 | mish
28 | prelu
29 | relu
30 | relu2
31 | relu6
32 | selu
33 | sigmoid
34 | silu
35 | softmax
36 | softmin
37 | softplus
38 | softshrink
39 | step
40 | tanh
41 |
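42 | These mirror their layer counterparts; for example ``nn.relu(x)`` computes the
43 | same result as ``nn.ReLU()(x)`` without allocating a module. A minimal sketch:
44 |
45 | .. code-block:: python
46 |
47 |    import mlx.core as mx
48 |    import mlx.nn as nn
49 |
50 |    x = mx.array([-1.0, 0.0, 2.0])
51 |    print(nn.relu(x))  # array([0, 0, 2], dtype=float32)
52 |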
--------------------------------------------------------------------------------
/python/mlx/nn/layers/containers.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from mlx.nn.layers.base import Module
4 |
5 |
6 | class Sequential(Module):
7 | """A layer that calls the passed callables in order.
8 |
9 | We can pass either modules or plain callables to the Sequential module. If
10 | our functions have learnable parameters they should be implemented as
11 | ``nn.Module`` instances.
12 |
13 | Args:
14 | modules (tuple of Callables): The modules to call in order
15 | """
16 |
17 | def __init__(self, *modules):
18 | super().__init__()
19 | self.layers = list(modules)
20 |
21 | def __call__(self, x):
22 | for m in self.layers:
23 | x = m(x)
24 | return x
25 |
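26 | # A minimal usage sketch (illustrative, not part of the original module):
27 | #
28 | #   import mlx.core as mx
29 | #   import mlx.nn as nn
30 | #
31 | #   mlp = nn.Sequential(nn.Linear(4, 8), nn.relu, nn.Linear(8, 1))
32 | #   y = mlp(mx.zeros((2, 4)))  # shape (2, 1)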
--------------------------------------------------------------------------------
/docs/src/usage/using_streams.rst:
--------------------------------------------------------------------------------
1 | .. _using_streams:
2 |
3 | Using Streams
4 | =============
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | Specifying the :obj:`Stream`
9 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
10 |
11 | All operations (including random number generation) take an optional
12 | keyword argument ``stream``. The ``stream`` kwarg specifies which
13 | :obj:`Stream` the operation should run on. If the stream is unspecified then
14 | the operation is run on the default stream of the default device:
15 | ``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also
16 | be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is
17 | run on the default stream of the provided device
18 | ``mx.default_stream(my_device)``.
19 |
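20 | For example, a minimal sketch:
21 |
22 | .. code-block:: python
23 |
24 |    import mlx.core as mx
25 |
26 |    a = mx.random.uniform(shape=(32, 32))
27 |    b = mx.random.uniform(shape=(32, 32))
28 |
29 |    c = mx.add(a, b, stream=mx.cpu)  # default stream of the CPU device
30 |    d = mx.add(a, b, stream=mx.default_stream(mx.default_device()))
31 |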
--------------------------------------------------------------------------------
/mlx/backend/common/broadcasting.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include "mlx/backend/common/utils.h"
4 |
5 | namespace mlx::core {
6 |
7 | void broadcast(const array& in, array& out) {
8 | if (out.size() == 0) {
9 | out.set_data(allocator::malloc(0));
10 | return;
11 | }
12 | Strides strides(out.ndim(), 0);
13 | int diff = out.ndim() - in.ndim();
14 | for (int i = in.ndim() - 1; i >= 0; --i) {
15 | strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
16 | }
17 | auto flags = in.flags();
18 | if (out.size() > in.size()) {
19 | flags.row_contiguous = flags.col_contiguous = false;
20 | }
21 | out.copy_shared_buffer(in, strides, flags, in.data_size());
22 | }
23 |
24 | } // namespace mlx::core
25 |
--------------------------------------------------------------------------------
/benchmarks/cpp/compare_devices.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <iostream>
4 | #include "mlx/mlx.h"
5 | #include "time_utils.h"
6 |
7 | namespace mx = mlx::core;
8 |
9 | void time_add_op() {
10 | std::vector<int> sizes(1, 1);
11 | for (int i = 0; i < 9; ++i) {
12 | sizes.push_back(10 * sizes.back());
13 | }
14 | set_default_device(mx::Device::cpu);
15 | for (auto size : sizes) {
16 | auto a = mx::random::uniform({size});
17 | auto b = mx::random::uniform({size});
18 | mx::eval(a, b);
19 | std::cout << "Size " << size << std::endl;
20 | TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
21 | TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
22 | }
23 | }
24 |
25 | int main() {
26 | time_add_op();
27 | }
28 |
--------------------------------------------------------------------------------
/.github/actions/build-linux/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Build and Test on Linux'
2 | description: 'Build and test MLX on Linux'
3 |
4 | runs:
5 | using: "composite"
6 | steps:
7 | - name: Install Python package
8 | shell: sh
9 | env:
10 | CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
11 | DEBUG: 1
12 | run: pip install --no-build-isolation -e ".[dev]" -v
13 |
14 | - name: Generate package stubs
15 | shell: sh
16 | run: |
17 | pip install typing_extensions
18 | python setup.py generate_stubs
19 |
20 | - name: Build CPP only
21 | shell: bash
22 | run: |
23 | mkdir -p build && cd build
24 | cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
25 | make -j $(nproc)
26 |
--------------------------------------------------------------------------------
/mlx/backend/no_cpu/compiled.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include "mlx/compile_impl.h"
4 | #include "mlx/primitives.h"
5 |
6 | namespace mlx::core {
7 |
8 | // GPU compile is always available if the GPU is available and since we are in
9 | // this file CPU compile is not available so check if the device is a GPU
10 | // device.
11 | namespace detail {
12 | bool compile_available_for_device(const Device& device) {
13 | return device == Device::gpu;
14 | }
15 | } // namespace detail
16 |
17 | void Compiled::eval_cpu(
18 | const std::vector<array>& inputs,
19 | std::vector<array>& outputs) {
20 | throw std::runtime_error(
21 | "[Compiled::eval_cpu] CPU compilation not supported on the platform.");
22 | }
23 |
24 | } // namespace mlx::core
25 |
--------------------------------------------------------------------------------
/mlx/backend/common/unary.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/allocator.h"
6 | #include "mlx/backend/common/utils.h"
7 |
8 | namespace mlx::core {
9 |
10 | inline void set_unary_output_data(
11 | const array& in,
12 | array& out,
13 | std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
14 | if (in.flags().contiguous) {
15 | if (is_donatable(in, out)) {
16 | out.copy_shared_buffer(in);
17 | } else {
18 | out.set_data(
19 | mallocfn(in.data_size() * out.itemsize()),
20 | in.data_size(),
21 | in.strides(),
22 | in.flags());
23 | }
24 | } else {
25 | out.set_data(mallocfn(out.nbytes()));
26 | }
27 | }
28 |
29 | } // namespace mlx::core
30 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Build the Docs
2 |
3 | ### Setup (do once)
4 |
5 | Install Doxygen:
6 |
7 | ```
8 | brew install doxygen
9 | ```
10 |
11 | Install Python packages:
12 |
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ### Build
18 |
19 | Build the docs from `mlx/docs/`
20 |
21 | ```
22 | doxygen && make html
23 | ```
24 |
25 | View the docs by running a server in `mlx/docs/build/html/`:
26 |
27 | ```
28 | python -m http.server
29 | ```
30 |
31 | and point your browser to `http://localhost:<port>`.
32 |
33 | ### Push to GitHub Pages
34 |
35 | Check out the `gh-pages` branch (`git switch gh-pages`) and build
36 | the docs. Then force add the `build/html` directory:
37 |
38 | `git add -f build/html`
39 |
40 | Commit and push the changes to the `gh-pages` branch.
41 |
--------------------------------------------------------------------------------
/mlx/distributed/mpi/mpi_declarations.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | // Constants
4 |
5 | #define MPI_SUCCESS 0
6 | #define MPI_ANY_SOURCE -1
7 | #define MPI_ANY_TAG -1
8 | #define MPI_IN_PLACE ((void*)1)
9 | #define MPI_MAX_LIBRARY_VERSION_STRING 256
10 |
11 | // Define all the types that we use so that we don't include <mpi.h> which
12 | // causes linker errors on some platforms.
13 | //
14 | // NOTE: We define everything for openmpi.
15 |
16 | typedef void* MPI_Comm;
17 | typedef void* MPI_Datatype;
18 | typedef void* MPI_Op;
19 |
20 | typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);
21 |
22 | typedef struct ompi_status_public_t {
23 | int MPI_SOURCE;
24 | int MPI_TAG;
25 | int MPI_ERROR;
26 | int _cancelled;
27 | size_t _ucount;
28 | } MPI_Status;
29 |
--------------------------------------------------------------------------------
/mlx/backend/metal/binary.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | void binary_op_gpu(
10 | const std::vector<array>& inputs,
11 | std::vector<array>& outputs,
12 | const char* op,
13 | const Stream& s);
14 |
15 | void binary_op_gpu(
16 | const std::vector<array>& inputs,
17 | array& out,
18 | const char* op,
19 | const Stream& s);
20 |
21 | void binary_op_gpu_inplace(
22 | const std::vector<array>& inputs,
23 | std::vector<array>& outputs,
24 | const char* op,
25 | const Stream& s);
26 |
27 | void binary_op_gpu_inplace(
28 | const std::vector<array>& inputs,
29 | array& out,
30 | const char* op,
31 | const Stream& s);
32 |
33 | } // namespace mlx::core
34 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/arange.metal:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | // clang-format off
4 | #include "mlx/backend/metal/kernels/utils.h"
5 | #include "mlx/backend/metal/kernels/arange.h"
6 |
7 | #define instantiate_arange(tname, type) \
8 | instantiate_kernel("arange" #tname, arange, type)
9 |
10 | instantiate_arange(uint8, uint8_t)
11 | instantiate_arange(uint16, uint16_t)
12 | instantiate_arange(uint32, uint32_t)
13 | instantiate_arange(uint64, uint64_t)
14 | instantiate_arange(int8, int8_t)
15 | instantiate_arange(int16, int16_t)
16 | instantiate_arange(int32, int32_t)
17 | instantiate_arange(int64, int64_t)
18 | instantiate_arange(float16, half)
19 | instantiate_arange(float32, float)
20 | instantiate_arange(bfloat16, bfloat16_t) // clang-format on
21 |
--------------------------------------------------------------------------------
/benchmarks/python/rope_bench.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023-2024 Apple Inc.
2 |
3 | import mlx.core as mx
4 | import mlx.nn as nn
5 | from time_utils import time_fn
6 |
7 |
8 | def time_rope():
9 | rope = nn.RoPE(64)
10 |
11 | # vec
12 | x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
13 | mx.eval(x)
14 |
15 | def rope_vec(x):
16 | for _ in range(32):
17 | x = rope(x, offset=100)
18 | return x
19 |
20 | time_fn(rope_vec, x)
21 |
22 | # matrix
23 | x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
24 | mx.eval(x)
25 |
26 | def rope_mat(x):
27 | for _ in range(32):
28 | x = rope(x)
29 | return x
30 |
31 | time_fn(rope_mat, x)
32 |
33 |
34 | if __name__ == "__main__":
35 | time_rope()
36 |
--------------------------------------------------------------------------------
/mlx/backend/metal/resident.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/metal/device.h"
6 |
7 | namespace mlx::core::metal {
8 |
9 | class ResidencySet {
10 | public:
11 | ResidencySet(MTL::Device* d);
12 | ~ResidencySet();
13 |
14 | ResidencySet(const ResidencySet&) = delete;
15 | ResidencySet& operator=(const ResidencySet&) = delete;
16 |
17 | const MTL::ResidencySet* mtl_residency_set() {
18 | return wired_set_;
19 | }
20 |
21 | void insert(MTL::Allocation* buf);
22 | void erase(MTL::Allocation* buf);
23 |
24 | void resize(size_t size);
25 |
26 | private:
27 | MTL::ResidencySet* wired_set_{nullptr};
28 | std::unordered_set<const MTL::Allocation*> unwired_set_;
29 | size_t capacity_{0};
30 | };
31 |
32 | } // namespace mlx::core::metal
33 |
--------------------------------------------------------------------------------
/mlx/backend/gpu/slicing.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | void slice_gpu(
10 | const array& in,
11 | array& out,
12 | const Shape& start_indices,
13 | const Shape& strides,
14 | const Stream& s);
15 |
16 | void concatenate_gpu(
17 | const std::vector<array>& inputs,
18 | array& out,
19 | int axis,
20 | const Stream& s);
21 |
22 | void pad_gpu(
23 | const array& in,
24 | const array& val,
25 | array& out,
26 | const std::vector<int>& axes,
27 | const Shape& low_pad_size,
28 | const Stream& s);
29 |
30 | array compute_dynamic_offset(
31 | const array& indices,
32 | const Strides& strides,
33 | const std::vector<int>& axes,
34 | const Stream& s);
35 |
36 | } // namespace mlx::core
37 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v6.0.0
4 | hooks:
5 | - id: check-yaml
6 | # - id: end-of-file-fixer
7 | # - id: trailing-whitespace
8 | - repo: https://github.com/pre-commit/mirrors-clang-format
9 | rev: v19.1.7
10 | hooks:
11 | - id: clang-format
12 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
13 | - repo: https://github.com/psf/black-pre-commit-mirror
14 | rev: 25.1.0
15 | hooks:
16 | - id: black
17 |
18 | - repo: https://github.com/pycqa/isort
19 | rev: 6.0.0
20 | hooks:
21 | - id: isort
22 | args:
23 | - --profile=black
24 | - repo: https://github.com/cheshirekow/cmake-format-precommit
25 | rev: v0.6.13
26 | hooks:
27 | - id: cmake-format
28 |
--------------------------------------------------------------------------------
/python/scripts/repair_cuda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | auditwheel repair dist/* \
4 | --plat manylinux_2_35_x86_64 \
5 | --exclude libcublas* \
6 | --exclude libnvrtc* \
7 | --exclude libcuda* \
8 | --exclude libcudnn* \
9 | --exclude libnccl* \
10 | -w wheel_tmp
11 |
12 |
13 | mkdir wheelhouse
14 | cd wheel_tmp
15 | repaired_wheel=$(find . -name "*.whl" -print -quit)
16 | unzip -q "${repaired_wheel}"
17 | rm "${repaired_wheel}"
18 | mlx_so="mlx/lib/libmlx.so"
19 | rpath=$(patchelf --print-rpath "${mlx_so}")
20 | base="\$ORIGIN/../../nvidia"
21 | rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib
22 | patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
23 | python ../python/scripts/repair_record.py ${mlx_so}
24 |
25 | # Re-zip the repaired wheel
26 | zip -r -q "../wheelhouse/${repaired_wheel}" .
27 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/indexing/gather_front.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/metal/kernels/indexing/indexing.h"
6 |
7 | template <typename T, typename IdxT, typename LocT, int N>
8 | [[kernel]] void gather_front(
9 | const device T* src,
10 | const device IdxT* indices,
11 | device T* out,
12 | const constant int64_t& stride,
13 | const constant int& size,
14 | uint2 index [[thread_position_in_grid]],
15 | uint2 grid_dim [[threads_per_grid]]) {
16 | auto idx = offset_neg_idx(indices[index.y], size);
17 | LocT src_idx = static_cast<LocT>(stride) * idx;
18 | LocT out_idx = static_cast<LocT>(stride) * index.y;
19 |
20 | int s_idx = N * index.x;
21 | for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
22 | out[out_idx + s_idx] = src[src_idx + s_idx];
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/docs/src/_templates/module-base-class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. add toctree option to make autodoc generate the pages
6 |
7 | .. autoclass:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | .. rubric:: Attributes
12 |
13 | .. autosummary::
14 | :toctree: .
15 | {% for item in attributes %}
16 | ~{{ fullname }}.{{ item }}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 | {% block methods %}
22 | {% if methods %}
23 | .. rubric:: Methods
24 |
25 | .. autosummary::
26 | :toctree: .
27 | {% for item in methods %}
28 | {%- if item not in inherited_members and item != '__init__' %}
29 | ~{{ fullname }}.{{ item }}
30 | {%- endif -%}
31 | {%- endfor %}
32 | {% endif %}
33 | {% endblock %}
34 |
--------------------------------------------------------------------------------
/.github/actions/build-cuda/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Build and Test with CUDA'
2 | description: 'Build and test MLX with CUDA'
3 |
4 | inputs:
5 | toolkit:
6 | description: 'The CUDA toolkit'
7 | required: true
8 |
9 | runs:
10 | using: "composite"
11 | steps:
12 | - name: Install Python package
13 | shell: bash
14 | env:
15 | DEBUG: 1
16 | CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
17 | run: pip install --no-build-isolation -e ".[dev]" -v
18 |
19 | - name: Build CPP only
20 | shell: bash
21 | run: |
22 | cmake . -B build \
23 | -DMLX_BUILD_CUDA=ON \
24 | -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
25 | -DCMAKE_BUILD_TYPE=DEBUG
26 | cmake --build build -j $(nproc)
27 |
--------------------------------------------------------------------------------
/python/scripts/repair_record.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import glob
3 | import hashlib
4 | import sys
5 |
6 | filename = sys.argv[1]
7 |
8 |
9 | # Compute the new hash and size
10 | def urlsafe_b64encode(data: bytes) -> bytes:
11 | return base64.urlsafe_b64encode(data).rstrip(b"=")
12 |
13 |
14 | hasher = hashlib.sha256()
15 | with open(filename, "rb") as f:
16 | data = f.read()
17 | hasher.update(data)
18 | hash_str = urlsafe_b64encode(hasher.digest()).decode("ascii")
19 | size = len(data)
20 |
21 | # Update the record file
22 | record_file = glob.glob("*/RECORD")[0]
23 | with open(record_file, "r") as f:
24 | lines = [l.split(",") for l in f.readlines()]
25 |
26 | for l in lines:
27 | if filename == l[0]:
28 | l[1] = hash_str
29 | l[2] = f"{size}\n"
30 |
31 | with open(record_file, "w") as f:
32 | for l in lines:
33 | f.write(",".join(l))
34 |
--------------------------------------------------------------------------------
/docs/src/python/nn/module.rst:
--------------------------------------------------------------------------------
1 | Module
2 | ======
3 |
4 | .. currentmodule:: mlx.nn
5 |
6 | .. autoclass:: Module
7 |
8 | .. rubric:: Attributes
9 |
10 | .. autosummary::
11 | :toctree: _autosummary
12 |
13 | Module.training
14 | Module.state
15 |
16 | .. rubric:: Methods
17 |
18 | .. autosummary::
19 | :toctree: _autosummary
20 |
21 | Module.apply
22 | Module.apply_to_modules
23 | Module.children
24 | Module.eval
25 | Module.filter_and_map
26 | Module.freeze
27 | Module.leaf_modules
28 | Module.load_weights
29 | Module.modules
30 | Module.named_modules
31 | Module.parameters
32 | Module.save_weights
33 | Module.set_dtype
34 | Module.train
35 | Module.trainable_parameters
36 | Module.unfreeze
37 | Module.update
38 | Module.update_modules
39 |
--------------------------------------------------------------------------------
/mlx/backend/metal/make_compiled_preamble.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script generates a C++ function that provides the Metal unary and binary
4 | # ops at runtime for use with kernel generation.
5 | #
6 | # Copyright © 2023-24 Apple Inc.
7 |
8 | OUTPUT_DIR=$1
9 | CC=$2
10 | SRC_DIR=$3
11 | SRC_FILE=$4
12 | CFLAGS=$5
13 | SRC_NAME=$(basename -- "${SRC_FILE}")
14 | JIT_INCLUDES=${SRC_DIR}/mlx/backend/metal/kernels/jit
15 | INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
16 | OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
17 |
18 | mkdir -p "$OUTPUT_DIR"
19 | CONTENT=$($CC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P "$INPUT_FILE" $CFLAGS 2>/dev/null)
20 |
21 | cat << EOF > "$OUTPUT_FILE"
22 | namespace mlx::core::metal {
23 |
24 | const char* $SRC_NAME() {
25 | return R"preamble(
26 | $CONTENT
27 | )preamble";
28 | }
29 |
30 | } // namespace mlx::core::metal
31 | EOF
32 |
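The generated translation unit exposes a single `const char* <name>()` accessor. A hypothetical consumer (the names `utils` and `build_kernel_source` are illustrative assumptions, not the exact MLX call sites) prepends that string to JIT kernel source before handing it to the runtime Metal compiler:

```cpp
#include <string>

namespace mlx::core::metal {
// Generated by this script from, e.g., kernels/utils.h.
const char* utils();
} // namespace mlx::core::metal

// Concatenate the preamble with a kernel body to form a complete source
// string for runtime compilation.
std::string build_kernel_source(const std::string& kernel_body) {
  return std::string(mlx::core::metal::utils()) + kernel_body;
}
```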
--------------------------------------------------------------------------------
/python/src/mlx_func.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <vector>
6 |
7 | #include <nanobind/nanobind.h>
8 | #include <nanobind/stl/vector.h>
9 |
10 | namespace nb = nanobind;
11 | using namespace nb::literals;
12 |
13 | nb::callable mlx_func(
14 | nb::object func,
15 | const nb::callable& orig_func,
16 | std::vector<PyObject*> deps);
17 |
18 | template <typename F, typename... Deps>
19 | nb::callable mlx_func(F func, const nb::callable& orig_func, Deps&&... deps) {
20 | return mlx_func(
21 | nb::cpp_function(std::move(func)),
22 | orig_func,
23 | std::vector{deps.ptr()...});
24 | }
25 |
26 | template <typename... Deps>
27 | nb::callable
28 | mlx_func(nb::object func, const nb::callable& orig_func, Deps&&... deps) {
29 | return mlx_func(
30 | std::move(func), orig_func, std::vector{deps.ptr()...});
31 | }
32 |
--------------------------------------------------------------------------------
/examples/export/README.md:
--------------------------------------------------------------------------------
1 | ## Setup
2 |
3 | Install MLX:
4 |
5 | ```bash
6 | pip install "mlx>=0.22"
7 | ```
8 |
9 | Build the C++ examples:
10 |
11 | ```bash
12 | cmake -B build -DCMAKE_BUILD_TYPE=Release
13 | cmake --build build
14 | ```
15 |
16 | ## Run
17 |
18 | ### Eval MLP
19 |
20 | Run the Python script to export the eval function:
21 |
22 | ```bash
23 | python eval_mlp.py
24 | ```
25 |
26 | Then run the C++ program to import and run the function:
27 |
28 | ```bash
29 | ./build/eval_mlp
30 | ```
31 |
32 | The Python and C++ programs should output the same result.
33 |
34 | ### Train MLP
35 |
36 | Run the Python script to export the model initialization and training
37 | functions:
38 |
39 | ```bash
40 | python train_mlp.py
41 | ```
42 |
43 | Then run the C++ program to import and run the functions:
44 |
45 | ```bash
46 | ./build/train_mlp
47 | ```
48 |
49 | The Python and C++ programs should output the same results.
50 |
--------------------------------------------------------------------------------
/mlx/distributed/reduction_ops.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | namespace mlx::core::distributed::detail {
4 |
5 | template <typename T>
6 | struct SumOp {
7 | void operator()(const T* input, T* output, size_t N) const {
8 | while (N-- > 0) {
9 | *output += *input;
10 | input++;
11 | output++;
12 | }
13 | }
14 | };
15 |
16 | template <typename T>
17 | struct MaxOp {
18 | void operator()(const T* input, T* output, size_t N) const {
19 | while (N-- > 0) {
20 | *output = std::max(*output, *input);
21 | input++;
22 | output++;
23 | }
24 | }
25 | };
26 |
27 | template <typename T>
28 | struct MinOp {
29 | void operator()(const T* input, T* output, size_t N) const {
30 | while (N-- > 0) {
31 | *output = std::min(*output, *input);
32 | input++;
33 | output++;
34 | }
35 | }
36 | };
37 |
38 | } // namespace mlx::core::distributed::detail
39 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/defines.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #if defined __METAL__ || defined MLX_METAL_JIT
6 | #define MTL_CONST constant
7 | #else
8 | #define MTL_CONST
9 | #endif
10 |
11 | static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
12 | static MTL_CONST constexpr int REDUCE_N_READS = 4;
13 | static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
14 | static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
15 | static MTL_CONST constexpr int RMS_N_READS = 4;
16 | static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
17 |
18 | // Instantiate a templated kernel.
19 | // Extra args are used as template parameters:
20 | // e.g. instantiate_kernel(binary_int, binary, a, b) ->
21 | //   [[host_name(binary_int)]] [kernel] binary<a, b>
22 | #define instantiate_kernel(name, func, ...) \
23 | template [[host_name( \
24 | name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
25 |
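For a concrete expansion, `instantiate_arange(float32, float)` in `arange.metal` above passes `"arange" "float32"` as the host name, so the macro expands to `template [[host_name("arangefloat32")]] [[kernel]] decltype(arange<float>) arange<float>;`, i.e. one explicit template instantiation per exported kernel name.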
--------------------------------------------------------------------------------
/tests_jaccl/build_cpp_benchmark.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Build the C++ separate FFN TP vs single benchmark.
3 |
4 | set -euo pipefail
5 |
6 | ROOT="$(cd "$(dirname "$0")/.." && pwd)"
7 | BUILD_DIR="$ROOT/build_cpp"
8 | CACHE_DIR="$BUILD_DIR/clang_module_cache"
9 | FAKE_HOME="$BUILD_DIR/fake_home"
10 | mkdir -p "$CACHE_DIR" "$FAKE_HOME"
11 | export CLANG_MODULE_CACHE_PATH="$CACHE_DIR"
12 | export HOME="$FAKE_HOME"
13 |
14 | # Only rerun CMake configure when the build directory is missing a cache.
15 | if [ ! -f "$BUILD_DIR/CMakeCache.txt" ]; then
16 | cmake -B "$BUILD_DIR" -S "$ROOT" \
17 | -DCMAKE_BUILD_TYPE=RelWithDebInfo \
18 | -DMLX_BUILD_BENCHMARKS=ON \
19 | -DCMAKE_CXX_FLAGS="-fmodules-cache-path=$CACHE_DIR"
20 | else
21 | echo "Found existing build at $BUILD_DIR (skipping configure)."
22 | fi
23 |
24 | cmake --build "$BUILD_DIR" --target separate_tp_vs_single
25 |
26 | echo "Benchmark built at: $BUILD_DIR/benchmarks/cpp/separate_tp_vs_single"
27 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/threefry.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/backend/cpu/threefry.h"
4 |
5 | namespace mlx::core::random {
6 |
7 | std::pair<uint32_t, uint32_t> threefry2x32_hash(
8 | const std::pair<uint32_t, uint32_t>& key,
9 | std::pair<uint32_t, uint32_t> count) {
10 | constexpr static uint32_t rotations[2][4] = {
11 | {13, 15, 26, 6}, {17, 29, 16, 24}};
12 |
13 | uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
14 |
15 | count.first += ks[0];
16 | count.second += ks[1];
17 |
18 | for (int i = 0; i < 5; ++i) {
19 | for (auto r : rotations[i % 2]) {
20 | count.first += count.second;
21 | count.second = (count.second << r) | (count.second >> (32 - r));
22 | count.second ^= count.first;
23 | }
24 | count.first += ks[(i + 1) % 3];
25 | count.second += ks[(i + 2) % 3] + i + 1;
26 | }
27 |
28 | return count;
29 | }
30 |
31 | } // namespace mlx::core::random
32 |
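A minimal sketch of calling the hash directly; the key and counter values are arbitrary illustrative choices:

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>

#include "mlx/backend/cpu/threefry.h"

int main() {
  std::pair<uint32_t, uint32_t> key{0x12345678u, 0x9abcdef0u};
  std::pair<uint32_t, uint32_t> count{0u, 1u};
  // Each (key, count) pair deterministically yields two 32-bit words.
  auto bits = mlx::core::random::threefry2x32_hash(key, count);
  std::printf("%08x %08x\n", bits.first, bits.second);
  return 0;
}
```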
--------------------------------------------------------------------------------
/mlx/dtype_utils.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/dtype_utils.h"
4 |
5 | namespace mlx::core {
6 |
7 | const char* dtype_to_string(Dtype arg) {
8 | switch (arg) {
9 | case bool_:
10 | return "bool";
11 | case int8:
12 | return "int8";
13 | case int16:
14 | return "int16";
15 | case int32:
16 | return "int32";
17 | case int64:
18 | return "int64";
19 | case uint8:
20 | return "uint8";
21 | case uint16:
22 | return "uint16";
23 | case uint32:
24 | return "uint32";
25 | case uint64:
26 | return "uint64";
27 | case float16:
28 | return "float16";
29 | case bfloat16:
30 | return "bfloat16";
31 | case float32:
32 | return "float32";
33 | case float64:
34 | return "float64";
35 | case complex64:
36 | return "complex64";
37 | default:
38 | return "unknown";
39 | }
40 | }
41 |
42 | } // namespace mlx::core
43 |
--------------------------------------------------------------------------------
/benchmarks/python/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023-2024 Apple Inc.
2 |
3 | import time
4 |
5 | import mlx.core as mx
6 |
7 |
8 | def time_fn(fn, *args, **kwargs):
9 | msg = kwargs.pop("msg", None)
10 | if msg:
11 | print(f"Timing {msg} ...", end=" ")
12 | else:
13 | print(f"Timing {fn.__name__} ...", end=" ")
14 |
15 | # warmup
16 | for _ in range(5):
17 | mx.eval(fn(*args, **kwargs))
18 |
19 | num_iters = 100
20 | tic = time.perf_counter()
21 | for _ in range(num_iters):
22 | x = mx.eval(fn(*args, **kwargs))
23 | toc = time.perf_counter()
24 |
25 | msec = 1e3 * (toc - tic) / num_iters
26 | print(f"{msec:.5f} msec")
27 |
28 |
29 | def measure_runtime(fn, **kwargs):
30 | # Warmup
31 | for _ in range(5):
32 | fn(**kwargs)
33 |
34 | tic = time.time()
35 | iters = 100
36 | for _ in range(iters):
37 | fn(**kwargs)
38 | return (time.time() - tic) * 1000 / iters
39 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/bitwise_binary.cu:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/binary/binary.cuh"
4 |
5 | namespace mlx::core {
6 | void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
7 | nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
8 | auto& s = out.primitive().stream();
9 | switch (op_) {
10 | case BitwiseBinary::And:
11 | binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
12 | break;
13 | case BitwiseBinary::Or:
14 | binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
15 | break;
16 | case BitwiseBinary::Xor:
17 | binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
18 | break;
19 | case BitwiseBinary::LeftShift:
20 | binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
21 | break;
22 | case BitwiseBinary::RightShift:
23 | binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
24 | break;
25 | }
26 | }
27 | } // namespace mlx::core
28 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/device/indexing.cuh:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include <cuda/std/tuple>
4 | #include <cuda/std/type_traits>
5 |
6 | namespace mlx::core::cu {
7 |
8 | // Convert an absolute index to positions in a 3d grid, assuming the index is
9 | // calculated with:
10 | // index = x * dim1 * dim2 + y * dim2 + z
11 | template <typename T>
12 | inline __host__ __device__ cuda::std::tuple<T, T, T>
13 | index_to_dims(T index, T dim1, T dim2) {
14 | T x = index / (dim1 * dim2);
15 | T y = (index % (dim1 * dim2)) / dim2;
16 | T z = index % dim2;
17 | return cuda::std::make_tuple(x, y, z);
18 | }
19 |
20 | // Get absolute index from possible negative index.
21 | template <typename IdxT>
22 | inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
23 | if constexpr (cuda::std::is_unsigned_v<IdxT>) {
24 | return idx;
25 | } else {
26 | return static_cast<int32_t>(idx < 0 ? idx + size : idx);
27 | }
28 | }
29 |
30 | } // namespace mlx::core::cu
31 |
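A quick host-side check of the index arithmetic above, as plain C++ mirroring the device code: with `dim1 = 3` and `dim2 = 4`, flat index 23 decomposes to `(x, y, z) = (1, 2, 3)`, and `x * dim1 * dim2 + y * dim2 + z = 12 + 8 + 3` recovers 23. The reference function name is illustrative:

```cpp
#include <cassert>
#include <tuple>

// Reference re-implementation of index_to_dims for host-side testing.
std::tuple<int, int, int> index_to_dims_ref(int index, int dim1, int dim2) {
  int x = index / (dim1 * dim2);
  int y = (index % (dim1 * dim2)) / dim2;
  int z = index % dim2;
  return {x, y, z};
}

int main() {
  auto [x, y, z] = index_to_dims_ref(23, 3, 4);
  assert(x == 1 && y == 2 && z == 3);
  // Round trip: the flat index is recovered from the 3d position.
  assert(x * 3 * 4 + y * 4 + z == 23);
  return 0;
}
```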
--------------------------------------------------------------------------------
/mlx/backend/cpu/make_compiled_preamble.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script generates a C++ function that provides the CPU
4 | # code for use with kernel generation.
5 | #
6 | # Copyright © 2023-24 Apple Inc.
7 |
8 |
9 | OUTPUT_FILE=$1
10 | GCC=$2
11 | SRCDIR=$3
12 | CLANG=$4
13 | ARCH=$5
14 |
15 | if [ "$CLANG" = "TRUE" ]; then
16 | read -r -d '' INCLUDES <<- EOM
17 | #include <cmath>
18 | #include <complex>
19 | #include <cstdint>
20 | #include <vector>
21 | #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
22 | #include <arm_fp16.h>
23 | #endif
24 | EOM
25 | CC_FLAGS="-arch ${ARCH} -nobuiltininc -nostdinc"
26 | else
27 | CC_FLAGS="-std=c++17"
28 | fi
29 |
30 | CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E -P "$SRCDIR/mlx/backend/cpu/compiled_preamble.h" 2>/dev/null)
31 |
32 | cat << EOF > "$OUTPUT_FILE"
33 | const char* get_kernel_preamble() {
34 | return R"preamble(
35 | $INCLUDES
36 | $CONTENT
37 | using namespace mlx::core;
38 | using namespace mlx::core::detail;
39 | )preamble";
40 | }
41 | EOF
42 |
--------------------------------------------------------------------------------
/examples/export/train_mlp.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <iostream>
4 | #include <mlx/mlx.h>
5 |
6 | namespace mx = mlx::core;
7 |
8 | int main() {
9 | int batch_size = 8;
10 | int input_dim = 32;
11 | int output_dim = 10;
12 |
13 | auto state = mx::import_function("init_mlp.mlxfn")({});
14 |
15 | // Make the input
16 | mx::random::seed(42);
17 | auto example_X = mx::random::normal({batch_size, input_dim});
18 | auto example_y = mx::random::randint(0, output_dim, {batch_size});
19 |
20 | // Import the function
21 | auto step = mx::import_function("train_mlp.mlxfn");
22 |
23 | // Call the imported function
24 | for (int it = 0; it < 100; ++it) {
25 | state.insert(state.end(), {example_X, example_y});
26 | state = step(state);
27 | eval(state);
28 | auto loss = state.back();
29 | state.pop_back();
30 | if (it % 10 == 0) {
31 | std::cout << "Loss " << loss.item<float>() << std::endl;
32 | }
33 | }
34 | return 0;
35 | }
36 |
--------------------------------------------------------------------------------
/tests/device_tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "doctest/doctest.h"
4 |
5 | #include <cstdlib>
6 |
7 | #include "mlx/mlx.h"
8 |
9 | using namespace mlx::core;
10 |
11 | TEST_CASE("test device placement") {
12 | auto device = default_device();
13 | Device d = gpu::is_available() ? Device::gpu : Device::cpu;
14 | if (std::getenv("DEVICE") == nullptr) {
15 | CHECK_EQ(device, d);
16 | }
17 |
18 | array x(1.0f);
19 | array y(1.0f);
20 | auto z = add(x, y, default_device());
21 | if (gpu::is_available()) {
22 | z = add(x, y, Device::gpu);
23 | z = add(x, y, Device(Device::gpu, 0));
24 | } else {
25 | CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);
26 | CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);
27 | }
28 |
29 | // Set the default device to the CPU
30 | set_default_device(Device::cpu);
31 | CHECK_EQ(default_device(), Device::cpu);
32 |
33 | // Revert
34 | set_default_device(device);
35 | }
36 |
--------------------------------------------------------------------------------
/mlx/backend/metal/no_metal.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include <stdexcept>
4 |
5 | #include "mlx/backend/metal/metal.h"
6 | #include "mlx/fast.h"
7 |
8 | namespace mlx::core {
9 |
10 | namespace metal {
11 |
12 | bool is_available() {
13 | return false;
14 | }
15 |
16 | void start_capture(std::string) {}
17 | void stop_capture() {}
18 |
19 | const std::unordered_map<std::string, std::variant<std::string, size_t>>&
20 | device_info() {
21 | throw std::runtime_error(
22 | "[metal::device_info] Cannot get device info without metal backend");
23 | };
24 |
25 | } // namespace metal
26 |
27 | namespace fast {
28 |
29 | CustomKernelFunction metal_kernel(
30 | const std::string&,
31 | const std::vector<std::string>&,
32 | const std::vector<std::string>&,
33 | const std::string&,
34 | const std::string&,
35 | bool,
36 | bool) {
37 | throw std::runtime_error("[metal_kernel] No Metal back-end.");
38 | }
39 |
40 | } // namespace fast
41 |
42 | } // namespace mlx::core
43 |
--------------------------------------------------------------------------------
/benchmarks/numpy/single_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import numpy as np
4 | from time_utils import time_fn
5 |
6 |
7 | def time_add():
8 | a = np.ones((100, 100, 10), dtype=np.float32)
9 | b = np.ones((100, 100, 10), dtype=np.float32)
10 | time_fn(np.add, a, b)
11 |
12 |
13 | def time_matmul():
14 | a = np.random.rand(1000, 500).astype(np.float32)
15 | b = np.random.rand(500, 1000).astype(np.float32)
16 | time_fn(np.matmul, a, b)
17 |
18 |
19 | def time_exp():
20 | a = np.random.randn(1000, 100).astype(np.float32)
21 | time_fn(np.exp, a)
22 |
23 |
24 | def time_take():
25 | a = np.random.rand(10000, 500)
26 | ids = np.random.randint(0, 10000, (20, 10))
27 | ids = [idx.reshape(-1) for idx in np.split(ids, 20)]
28 |
29 | def random_take():
30 | return [np.take(a, idx, 0) for idx in ids]
31 |
32 | time_fn(random_take)
33 |
34 |
35 | if __name__ == "__main__":
36 | time_add()
37 | time_matmul()
38 | time_exp()
39 | time_take()
40 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/device/scatter_ops.cuh:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/cuda/device/atomic_ops.cuh"
6 |
7 | namespace mlx::core::cu {
8 |
9 | struct ScatterAssign {
10 | template <typename T>
11 | __device__ void operator()(T* out, T val) const {
12 | *out = val;
13 | }
14 | };
15 |
16 | struct ScatterSum {
17 | template <typename T>
18 | __device__ void operator()(T* out, T val) const {
19 | atomic_add(out, val);
20 | }
21 | };
22 |
23 | struct ScatterProd {
24 | template <typename T>
25 | __device__ void operator()(T* out, T val) const {
26 | atomic_prod(out, val);
27 | }
28 | };
29 |
30 | struct ScatterMax {
31 | template <typename T>
32 | __device__ void operator()(T* out, T val) const {
33 | atomic_max(out, val);
34 | }
35 | };
36 |
37 | struct ScatterMin {
38 | template <typename T>
39 | __device__ void operator()(T* out, T val) const {
40 | atomic_min(out, val);
41 | }
42 | };
43 |
44 | } // namespace mlx::core::cu
45 |
--------------------------------------------------------------------------------
/mlx/distributed/nccl/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | if(MLX_BUILD_CUDA)
2 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nccl.cpp)
3 | find_package(NCCL)
4 | if(NCCL_FOUND)
5 | target_link_libraries(mlx PRIVATE ${NCCL_LIBRARIES})
6 | target_include_directories(mlx PRIVATE ${NCCL_INCLUDE_DIRS})
7 | else()
8 | message(
9 | STATUS
10 | "NCCL not found, using stubs. To run distributed with NCCL backend, install NCCL."
11 | )
12 |
13 | include(ExternalProject)
14 | ExternalProject_Add(
15 | nccl_stub
16 | SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/nccl_stub"
17 | BUILD_COMMAND ${CMAKE_COMMAND} --build .
18 | INSTALL_COMMAND "")
19 | set(NCCL_PATH
20 | "${CMAKE_CURRENT_BINARY_DIR}/nccl_stub-prefix/src/nccl_stub-build/")
21 | target_link_libraries(mlx PRIVATE ${NCCL_PATH}/libnccl.so)
22 | target_include_directories(mlx PRIVATE ${NCCL_PATH})
23 | endif()
24 | else()
25 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_nccl.cpp)
26 | endif()
27 |
--------------------------------------------------------------------------------
/python/tests/test_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import io
4 | import unittest
5 |
6 | import mlx.core as mx
7 | import mlx_tests
8 |
9 |
10 | class TestGraph(mlx_tests.MLXTestCase):
11 | def test_to_dot(self):
12 | # Simply test that a few cases run.
13 | # Nothing too specific about the graph format
14 | # for now to keep it flexible
15 | a = mx.array(1.0)
16 | f = io.StringIO()
17 | mx.export_to_dot(f, a)
18 | f.seek(0)
19 | self.assertTrue(len(f.read()) > 0)
20 |
21 | b = mx.array(2.0)
22 | c = a + b
23 | f = io.StringIO()
24 | mx.export_to_dot(f, c)
25 | f.seek(0)
26 | self.assertTrue(len(f.read()) > 0)
27 |
28 | # Multi output case
29 | c = mx.divmod(a, b)
30 | f = io.StringIO()
31 | mx.export_to_dot(f, *c)
32 | f.seek(0)
33 | self.assertTrue(len(f.read()) > 0)
34 |
35 |
36 | if __name__ == "__main__":
37 | mlx_tests.MLXTestRunner()
38 |
--------------------------------------------------------------------------------
/tests/allocator_tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <stdexcept>
4 |
5 | #include "doctest/doctest.h"
6 |
7 | #include "mlx/allocator.h"
8 |
9 | using namespace mlx::core;
10 |
11 | TEST_CASE("test simple allocations") {
12 | {
13 | auto buffer = allocator::malloc(sizeof(float));
14 | auto fptr = static_cast(buffer.raw_ptr());
15 | *fptr = 0.5f;
16 | CHECK_EQ(*fptr, 0.5f);
17 | allocator::free(buffer);
18 | }
19 |
20 | {
21 | auto buffer = allocator::malloc(128 * sizeof(int));
22 | int* ptr = static_cast(buffer.raw_ptr());
23 | for (int i = 0; i < 128; ++i) {
24 | ptr[i] = i;
25 | }
26 | allocator::free(buffer);
27 | }
28 |
29 | {
30 | auto buffer = allocator::malloc(0);
31 | allocator::free(buffer);
32 | }
33 | }
34 |
35 | TEST_CASE("test large allocations") {
36 | size_t size = 1 << 30;
37 | for (int i = 0; i < 100; ++i) {
38 | auto buffer = allocator::malloc(size);
39 | allocator::free(buffer);
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/.github/actions/build-macos-release/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Build macOS release'
2 | description: 'Build MLX releases macOS'
3 |
4 | inputs:
5 | macos-target:
6 | description: 'macOS build target'
7 | required: false
8 | default: '15.0'
9 | build-backend:
10 | description: 'Build the backend mlx-metal package'
11 | type: boolean
12 | required: false
13 | default: false
14 |
15 | runs:
16 | using: "composite"
17 | steps:
18 | - name: Build Python package
19 | shell: bash -l {0}
20 | env:
21 | MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
22 | run: |
23 | pip install build
24 | python setup.py clean --all
25 | MLX_BUILD_STAGE=1 python -m build -w
26 |
27 | - name: Build backend package
28 | if: ${{ inputs.build-backend }}
29 | shell: bash -l {0}
30 | env:
31 | MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
32 | run: |
33 | python setup.py clean --all
34 | MLX_BUILD_STAGE=2 python -m build -w
35 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/quantized/quantized.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/device.h"
4 |
5 | namespace mlx::core {
6 |
7 | void affine_quantize(
8 | const array& w,
9 | array& wq,
10 | array& scales,
11 | array& biases,
12 | int group_size_,
13 | int bits_,
14 | cu::CommandEncoder& enc,
15 | const Stream& s);
16 |
17 | void affine_dequantize(
18 | const array& wq,
19 | const array& scales,
20 | const array& biases,
21 | array& w,
22 | int group_size_,
23 | int bits_,
24 | cu::CommandEncoder& enc,
25 | const Stream& s);
26 |
27 | void fp_quantize(
28 | const array& w,
29 | array& wq,
30 | array& scales,
31 | int group_size,
32 | int bits,
33 | cu::CommandEncoder& enc,
34 | const Stream& s);
35 |
36 | void fp_dequantize(
37 | const array& wq,
38 | const array& scales,
39 | array& w,
40 | int group_size,
41 | int bits,
42 | cu::CommandEncoder& enc,
43 | const Stream& s);
44 |
45 | } // namespace mlx::core
46 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/softmax.metal:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include <metal_common>
4 | #include <metal_simdgroup>
5 |
6 | using namespace metal;
7 |
8 | // clang-format off
9 | #include "mlx/backend/metal/kernels/utils.h"
10 | #include "mlx/backend/metal/kernels/softmax.h"
11 |
12 | #define instantiate_softmax(name, itype) \
13 | instantiate_kernel("block_softmax_" #name, softmax_single_row, itype) \
14 | instantiate_kernel("looped_softmax_" #name, softmax_looped, itype)
15 |
16 | #define instantiate_softmax_precise(name, itype) \
17 | instantiate_kernel("block_softmax_precise_" #name, softmax_single_row, itype, float) \
18 | instantiate_kernel("looped_softmax_precise_" #name, softmax_looped, itype, float)
19 |
20 | instantiate_softmax(float32, float)
21 | instantiate_softmax(float16, half)
22 | instantiate_softmax(bfloat16, bfloat16_t)
23 | instantiate_softmax_precise(float16, half)
24 | instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on
25 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/copy.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <optional>
6 |
7 | #include "mlx/array.h"
8 | #include "mlx/backend/common/copy.h"
9 | #include "mlx/backend/common/utils.h"
10 |
11 | namespace mlx::core {
12 |
13 | void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
14 | void copy_cpu_inplace(
15 | const array& src,
16 | array& dst,
17 | CopyType ctype,
18 | Stream stream);
19 |
20 | void copy_cpu_inplace(
21 | const array& src,
22 | array& dst,
23 | const Shape& data_shape,
24 | const Strides& i_strides,
25 | const Strides& o_strides,
26 | int64_t i_offset,
27 | int64_t o_offset,
28 | CopyType ctype,
29 | Stream stream,
30 | const std::optional<array>& dynamic_i_offset = std::nullopt,
31 | const std::optional<array>& dynamic_o_offset = std::nullopt);
32 |
33 | // Return a contiguous array with same shape that copies the data of |arr|.
34 | array contiguous_copy_cpu(const array& arr, Stream stream);
35 |
36 | } // namespace mlx::core
37 |
--------------------------------------------------------------------------------
/examples/extensions/bindings.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include
4 | #include
5 |
6 | #include "axpby/axpby.h"
7 |
8 | namespace nb = nanobind;
9 | using namespace nb::literals;
10 |
11 | NB_MODULE(_ext, m) {
12 | m.doc() = "Sample extension for MLX";
13 |
14 | m.def(
15 | "axpby",
16 | &my_ext::axpby,
17 | "x"_a,
18 | "y"_a,
19 | "alpha"_a,
20 | "beta"_a,
21 | nb::kw_only(),
22 | "stream"_a = nb::none(),
23 | R"(
24 | Scale and sum two vectors element-wise
25 | ``z = alpha * x + beta * y``
26 |
27 | Follows numpy style broadcasting between ``x`` and ``y``
28 | Inputs are upcasted to floats if needed
29 |
30 | Args:
31 | x (array): Input array.
32 | y (array): Input array.
33 | alpha (float): Scaling factor for ``x``.
34 | beta (float): Scaling factor for ``y``.
35 |
36 | Returns:
37 | array: ``alpha * x + beta * y``
38 | )");
39 | }
40 |
--------------------------------------------------------------------------------
/examples/cpp/metal_capture.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <cassert>
4 | #include <iostream>
5 |
6 | #include "mlx/mlx.h"
7 |
8 | namespace mx = mlx::core;
9 |
10 | int main() {
11 | // To use Metal debugging and profiling:
12 | // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
13 | // 2. Run with MTL_CAPTURE_ENABLED=1.
14 | mx::metal::start_capture("mlx_trace.gputrace");
15 |
16 | // Start at index two because the default GPU and CPU streams have indices
17 | // zero and one, respectively. This naming matches the label assigned to each
18 | // stream's command queue.
19 | auto s2 = new_stream(mx::Device::gpu);
20 | auto s3 = new_stream(mx::Device::gpu);
21 |
22 | auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
23 | auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
24 | auto x = mx::add(a, a, s2);
25 | auto y = mx::add(b, b, s3);
26 |
27 | // The multiply will happen on the default stream.
28 | std::cout << mx::multiply(x, y) << std::endl;
29 |
30 | mx::metal::stop_capture();
31 | }
32 |
--------------------------------------------------------------------------------
/mlx/stream.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/device.h"
6 |
7 | namespace mlx::core {
8 |
9 | struct Stream {
10 | int index;
11 | Device device;
12 | explicit Stream(int index, Device device) : index(index), device(device) {}
13 | };
14 |
15 | /** Get the default stream for the given device. */
16 | Stream default_stream(Device d);
17 |
18 | /** Make the stream the default for its device. */
19 | void set_default_stream(Stream s);
20 |
21 | /** Make a new stream on the given device. */
22 | Stream new_stream(Device d);
23 |
24 | /** Get the stream with the given index. */
25 | Stream get_stream(int index);
26 |
27 | inline bool operator==(const Stream& lhs, const Stream& rhs) {
28 | return lhs.index == rhs.index;
29 | }
30 |
31 | inline bool operator!=(const Stream& lhs, const Stream& rhs) {
32 | return !(lhs == rhs);
33 | }
34 |
35 | /* Synchronize with the default stream. */
36 | void synchronize();
37 |
38 | /* Synchronize with the provided stream. */
39 | void synchronize(Stream);
40 |
41 | } // namespace mlx::core
42 |
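A minimal usage sketch of the stream API declared above, assuming the `mlx/mlx.h` umbrella header and a C++17 build:

```cpp
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // Create an extra stream on the default device and schedule work on it.
  auto s = mx::new_stream(mx::default_device());
  auto a = mx::ones({4, 4});
  auto b = mx::add(a, a, s);
  // Block until everything queued on s has completed.
  mx::synchronize(s);
  return 0;
}
```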
--------------------------------------------------------------------------------
/benchmarks/cpp/autograd.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <iostream>
4 |
5 | #include "mlx/mlx.h"
6 | #include "time_utils.h"
7 |
8 | namespace mx = mlx::core;
9 |
10 | void time_value_and_grad() {
11 | auto x = mx::ones({200, 1000});
12 | mx::eval(x);
13 | auto fn = [](mx::array x) {
14 | for (int i = 0; i < 20; ++i) {
15 | x = mx::log(mx::exp(x));
16 | }
17 | return mx::sum(x);
18 | };
19 |
20 | auto grad_fn = mx::grad(fn);
21 | auto independent_value_and_grad = [&]() {
22 | auto value = fn(x);
23 | auto dfdx = grad_fn(x);
24 | return std::vector{value, dfdx};
25 | };
26 | TIME(independent_value_and_grad);
27 |
28 | auto value_and_grad_fn = mx::value_and_grad(fn);
29 | auto combined_value_and_grad = [&]() {
30 | auto [value, dfdx] = value_and_grad_fn(x);
31 | return std::vector{value, dfdx};
32 | };
33 | TIME(combined_value_and_grad);
34 | }
35 |
36 | int main() {
37 | std::cout << "Benchmarks for " << mx::default_device() << std::endl;
38 | time_value_and_grad();
39 | }
40 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/binary/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
4 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
5 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
6 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
7 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
8 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
9 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
10 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
11 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
12 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
13 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
14 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
15 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
16 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
17 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
18 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
19 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
20 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
21 | PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)
22 |
--------------------------------------------------------------------------------
/examples/python/logistic_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 |
5 | import mlx.core as mx
6 |
7 | num_features = 100
8 | num_examples = 1_000
9 | num_iters = 10_000
10 | lr = 0.1
11 |
12 | # True parameters
13 | w_star = mx.random.normal((num_features,))
14 |
15 | # Input examples
16 | X = mx.random.normal((num_examples, num_features))
17 |
18 | # Labels
19 | y = (X @ w_star) > 0
20 |
21 |
22 | # Initialize random parameters
23 | w = 1e-2 * mx.random.normal((num_features,))
24 |
25 |
26 | def loss_fn(w):
27 | logits = X @ w
28 | return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
29 |
30 |
31 | grad_fn = mx.grad(loss_fn)
32 |
33 | tic = time.time()
34 | for _ in range(num_iters):
35 | grad = grad_fn(w)
36 | w = w - lr * grad
37 | mx.eval(w)
38 |
39 | toc = time.time()
40 |
41 | loss = loss_fn(w)
42 | final_preds = (X @ w) > 0
43 | acc = mx.mean(final_preds == y)
44 |
45 | throughput = num_iters / (toc - tic)
46 | print(
47 | f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
48 | f"Throughput {throughput:.5f} (it/s)"
49 | )
50 |
--------------------------------------------------------------------------------
/mlx/backend/metal/reduce.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/common/reduce.h"
6 | #include "mlx/backend/metal/device.h"
7 | #include "mlx/stream.h"
8 |
9 | namespace mlx::core {
10 |
11 | using metal::CommandEncoder;
12 |
13 | void all_reduce_dispatch(
14 | const array& in,
15 | array& out,
16 | const std::string& op_name,
17 | CommandEncoder& compute_encoder,
18 | metal::Device& d,
19 | const Stream& s);
20 |
21 | void row_reduce_general_dispatch(
22 | const array& in,
23 | array& out,
24 | const std::string& op_name,
25 | const ReductionPlan& plan,
26 | const std::vector<int>& axes,
27 | CommandEncoder& compute_encoder,
28 | metal::Device& d,
29 | const Stream& s);
30 |
31 | void strided_reduce_general_dispatch(
32 | const array& in,
33 | array& out,
34 | const std::string& op_name,
35 | const ReductionPlan& plan,
36 | const std::vector<int>& axes,
37 | CommandEncoder& compute_encoder,
38 | metal::Device& d,
39 | const Stream& s);
40 |
41 | } // namespace mlx::core
42 |
--------------------------------------------------------------------------------
/examples/python/linear_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 |
5 | import mlx.core as mx
6 |
7 | num_features = 100
8 | num_examples = 1_000
9 | num_iters = 10_000
10 | lr = 0.01
11 |
12 | # True parameters
13 | w_star = mx.random.normal((num_features,))
14 |
15 | # Input examples (design matrix)
16 | X = mx.random.normal((num_examples, num_features))
17 |
18 | # Noisy labels
19 | eps = 1e-2 * mx.random.normal((num_examples,))
20 | y = X @ w_star + eps
21 |
22 | # Initialize random parameters
23 | w = 1e-2 * mx.random.normal((num_features,))
24 |
25 |
26 | def loss_fn(w):
27 | return 0.5 * mx.mean(mx.square(X @ w - y))
28 |
29 |
30 | grad_fn = mx.grad(loss_fn)
31 |
32 | tic = time.time()
33 | for _ in range(num_iters):
34 | grad = grad_fn(w)
35 | w = w - lr * grad
36 | mx.eval(w)
37 | toc = time.time()
38 |
39 | loss = loss_fn(w)
40 | error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
41 | throughput = num_iters / (toc - tic)
42 |
43 | print(
44 | f"Loss {loss.item():.5f}, L2 distance: |w-w*| = {error_norm:.5f}, "
45 | f"Throughput {throughput:.5f} (it/s)"
46 | )
47 |
--------------------------------------------------------------------------------
/mlx/io/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)
2 |
3 | if(MLX_BUILD_SAFETENSORS)
4 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
5 | else()
6 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
7 | endif()
8 |
9 | if(MLX_BUILD_GGUF)
10 | message(STATUS "Downloading gguflib")
11 | FetchContent_Declare(
12 | gguflib
13 | GIT_REPOSITORY https://github.com/antirez/gguf-tools/
14 | GIT_TAG 8fa6eb65236618e28fd7710a0fba565f7faa1848)
15 | FetchContent_MakeAvailable(gguflib)
16 | target_include_directories(mlx
17 | PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>)
18 | add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c
19 | ${gguflib_SOURCE_DIR}/gguflib.c)
20 | target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
21 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
22 | ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp)
23 | else()
24 | target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp)
25 | endif()
26 |
--------------------------------------------------------------------------------
/.github/actions/build-docs/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Build Documentation'
2 | description: 'Build documentation'
3 |
4 | runs:
5 | using: "composite"
6 | steps:
7 | - name: Setup machine
8 | uses: ./.github/actions/setup-linux
9 |
10 | - name: Install dependencies
11 | shell: bash
12 | run: |
13 | sudo apt-get install -y doxygen
14 | source .venv/bin/activate
15 | pip install -r docs/requirements.txt
16 | pip install . -v
17 |
18 | - name: Build documentation
19 | shell: bash
20 | run: |
21 | source .venv/bin/activate
22 | cd docs
23 | doxygen
24 | make html O=-W
25 |
26 | - name: Create artifact tar
27 | shell: bash
28 | run: tar -cf artifact.tar -C docs --dereference build/html index.html
29 |
30 | # Do it manually because upload-pages-artifact requires gtar
31 | - name: Upload artifact
32 | id: upload-artifact
33 | uses: actions/upload-artifact@v5
34 | with:
35 | name: github-pages
36 | path: artifact.tar
37 | retention-days: 1
38 | if-no-files-found: error
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright © 2023 Apple Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/src/python/nn/init.rst:
--------------------------------------------------------------------------------
1 | .. _init:
2 |
3 | .. currentmodule:: mlx.nn.init
4 |
5 | Initializers
6 | ------------
7 |
8 | The ``mlx.nn.init`` package contains commonly used initializers for neural
9 | network parameters. Initializers return a function which can be applied to any
10 | input :obj:`mlx.core.array` to produce an initialized output.
11 |
12 | For example:
13 |
14 | .. code:: python
15 |
16 | import mlx.core as mx
17 | import mlx.nn as nn
18 |
19 | init_fn = nn.init.uniform()
20 |
21 | # Produces a [2, 2] uniform matrix
22 | param = init_fn(mx.zeros((2, 2)))
23 |
24 | To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a
25 | uniform distribution, you can do:
26 |
27 | .. code:: python
28 |
29 | import mlx.nn as nn
30 | model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
31 | init_fn = nn.init.uniform(low=-0.1, high=0.1)
32 | model.apply(init_fn)
33 |
34 |
35 | .. autosummary::
36 | :toctree: _autosummary
37 |
38 | constant
39 | normal
40 | uniform
41 | identity
42 | glorot_normal
43 | glorot_uniform
44 | he_normal
45 | he_uniform
46 |
--------------------------------------------------------------------------------
/docs/src/python/nn/layers.rst:
--------------------------------------------------------------------------------
1 | .. _layers:
2 |
3 | .. currentmodule:: mlx.nn
4 |
5 | Layers
6 | ------
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 | :template: nn-module-template.rst
11 |
12 | ALiBi
13 | AvgPool1d
14 | AvgPool2d
15 | AvgPool3d
16 | BatchNorm
17 | CELU
18 | Conv1d
19 | Conv2d
20 | Conv3d
21 | ConvTranspose1d
22 | ConvTranspose2d
23 | ConvTranspose3d
24 | Dropout
25 | Dropout2d
26 | Dropout3d
27 | Embedding
28 | ELU
29 | GELU
30 | GLU
31 | GroupNorm
32 | GRU
33 | HardShrink
34 | HardTanh
35 | Hardswish
36 | InstanceNorm
37 | LayerNorm
38 | LeakyReLU
39 | Linear
40 | LogSigmoid
41 | LogSoftmax
42 | LSTM
43 | MaxPool1d
44 | MaxPool2d
45 | MaxPool3d
46 | Mish
47 | MultiHeadAttention
48 | PReLU
49 | QuantizedEmbedding
50 | QuantizedLinear
51 | RMSNorm
52 | ReLU
53 | ReLU2
54 | ReLU6
55 | RNN
56 | RoPE
57 | SELU
58 | Sequential
59 | Sigmoid
60 | SiLU
61 | SinusoidalPositionalEncoding
62 | Softmin
63 | Softshrink
64 | Softsign
65 | Softmax
66 | Softplus
67 | Step
68 | Tanh
69 | Transformer
70 | Upsample
71 |
--------------------------------------------------------------------------------
/python/src/indexing.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <nanobind/nanobind.h>
6 |
7 | #include "mlx/array.h"
8 | #include "python/src/utils.h"
9 |
10 | namespace mx = mlx::core;
11 | namespace nb = nanobind;
12 |
13 | mx::array mlx_get_item(const mx::array& src, const nb::object& obj);
14 | void mlx_set_item(
15 | mx::array& src,
16 | const nb::object& obj,
17 | const ScalarOrArray& v);
18 | mx::array mlx_add_item(
19 | const mx::array& src,
20 | const nb::object& obj,
21 | const ScalarOrArray& v);
22 | mx::array mlx_subtract_item(
23 | const mx::array& src,
24 | const nb::object& obj,
25 | const ScalarOrArray& v);
26 | mx::array mlx_multiply_item(
27 | const mx::array& src,
28 | const nb::object& obj,
29 | const ScalarOrArray& v);
30 | mx::array mlx_divide_item(
31 | const mx::array& src,
32 | const nb::object& obj,
33 | const ScalarOrArray& v);
34 | mx::array mlx_maximum_item(
35 | const mx::array& src,
36 | const nb::object& obj,
37 | const ScalarOrArray& v);
38 | mx::array mlx_minimum_item(
39 | const mx::array& src,
40 | const nb::object& obj,
41 | const ScalarOrArray& v);
42 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/no_cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/cuda/cuda.h"
4 | #include "mlx/fast.h"
5 |
6 | namespace mlx::core {
7 |
8 | namespace cu {
9 |
10 | bool is_available() {
11 | return false;
12 | }
13 |
14 | } // namespace cu
15 |
16 | namespace fast {
17 |
18 | CustomKernelFunction cuda_kernel(
19 | const std::string&,
20 | const std::vector<std::string>&,
21 | const std::vector<std::string>&,
22 | const std::string&,
23 | const std::string&,
24 | bool,
25 | int) {
26 | throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
27 | }
28 |
29 | std::vector<array> precompiled_cuda_kernel(
30 | const std::string&,
31 | const std::string&,
32 | const std::vector<array>&,
33 | const std::vector<Shape>&,
34 | const std::vector<Dtype>&,
35 | const std::vector&,
36 | std::tuple<int, int, int>,
37 | std::tuple<int, int, int>,
38 | int shared_memory,
39 | std::optional<float> init_value,
40 | bool ensure_row_contiguous,
41 | StreamOrDevice) {
42 | throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
43 | }
44 |
45 | } // namespace fast
46 |
47 | } // namespace mlx::core
48 |
--------------------------------------------------------------------------------
/tests_jaccl/README.md:
--------------------------------------------------------------------------------
1 | # JACCL Distributed MLP Test
2 |
3 | This directory contains scripts to test Tensor Parallel MLP using the JACCL backend on `m4p.local` and `m3u.local`.
4 |
5 | ## Configuration
6 |
7 | - **Hosts**:
8 | - Rank 0: `m4p.local` (IP: `10.1.12.4`, Device: `rdma_en2`)
9 | - Rank 1: `m3u.local` (IP: `10.1.12.3`, Device: `rdma_en5`)
10 | - **Config File**: `jaccl_config.json` maps ranks to devices.
11 |
12 | ## Files
13 |
14 | - `test_tp_mlp.py`: The distributed MLP test script.
15 | - `jaccl_config.json`: Device configuration.
16 | - `run_jaccl.sh`: Helper script to run the test.
17 | - `deploy.sh`: Helper script to sync files to hosts.
18 |
19 | ## How to Run
20 |
21 | 1. **Deploy files**:
22 | Run `./deploy.sh` from this directory to sync code to both hosts.
23 |
24 | 2. **Run on Rank 0 (m4p.local)**:
25 | ```bash
26 | ssh m4p.local
27 | cd /Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma/tests_jaccl
28 | ./run_jaccl.sh 0 10.1.12.4:1234
29 | ```
30 |
31 | 3. **Run on Rank 1 (m3u.local)**:
32 | ```bash
33 | ssh m3u.local
34 | cd /Users/anemll/SourceRelease/GITHUB/ML_playground/mlx-rdma/tests_jaccl
35 | ./run_jaccl.sh 1 10.1.12.4:1234
36 | ```
37 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/gemms/simd_fp16.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/common/utils.h"
4 | #include "mlx/backend/cpu/gemm.h"
5 | #include "mlx/backend/cpu/gemms/simd_gemm.h"
6 |
7 | namespace mlx::core {
8 |
9 | template <>
10 | void matmul<float16_t>(
11 | const float16_t* a,
12 | const float16_t* b,
13 | float16_t* out,
14 | bool a_transposed,
15 | bool b_transposed,
16 | size_t lda,
17 | size_t ldb,
18 | size_t ldc,
19 | float alpha,
20 | float beta,
21 | size_t batch_size,
22 | const Shape& a_shape,
23 | const Strides& a_strides,
24 | const Shape& b_shape,
25 | const Strides& b_strides) {
26 | auto ndim = a_shape.size();
27 | size_t M = a_shape[ndim - 2];
28 | size_t N = b_shape[ndim - 1];
29 | size_t K = a_shape[ndim - 1];
30 | for (int i = 0; i < batch_size; ++i) {
31 | simd_gemm(
32 | a + elem_to_loc(M * K * i, a_shape, a_strides),
33 | b + elem_to_loc(K * N * i, b_shape, b_strides),
34 | out + M * N * i,
35 | a_transposed,
36 | b_transposed,
37 | M,
38 | N,
39 | K,
40 | alpha,
41 | beta);
42 | }
43 | }
44 |
45 | } // namespace mlx::core
46 |
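The batch addressing above relies on `elem_to_loc` to honor broadcasted batch dimensions (a zero stride reuses the same matrix for every batch). A plain-Python sketch of that helper, assuming row-major shape/stride pairs:

```python
def elem_to_loc(elem, shape, strides):
    # Map a linear element index to a storage offset through (shape, strides);
    # a zero stride makes every batch read the same underlying matrix.
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (elem % dim) * stride
        elem //= dim
    return loc

# e.g. batch i of an (B, M, K) operand starts at elem_to_loc(M * K * i, shape, strides)
```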
--------------------------------------------------------------------------------
/mlx/backend/cpu/gemms/simd_bf16.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/common/utils.h"
4 | #include "mlx/backend/cpu/gemm.h"
5 | #include "mlx/backend/cpu/gemms/simd_gemm.h"
6 |
7 | namespace mlx::core {
8 |
9 | template <>
10 | void matmul(
11 | const bfloat16_t* a,
12 | const bfloat16_t* b,
13 | bfloat16_t* out,
14 | bool a_transposed,
15 | bool b_transposed,
16 | size_t lda,
17 | size_t ldb,
18 | size_t ldc,
19 | float alpha,
20 | float beta,
21 | size_t batch_size,
22 | const Shape& a_shape,
23 | const Strides& a_strides,
24 | const Shape& b_shape,
25 | const Strides& b_strides) {
26 | auto ndim = a_shape.size();
27 | size_t M = a_shape[ndim - 2];
28 | size_t N = b_shape[ndim - 1];
29 | size_t K = a_shape[ndim - 1];
30 | for (int i = 0; i < batch_size; ++i) {
31 | simd_gemm(
32 | a + elem_to_loc(M * K * i, a_shape, a_strides),
33 | b + elem_to_loc(K * N * i, b_shape, b_strides),
34 | out + M * N * i,
35 | a_transposed,
36 | b_transposed,
37 | M,
38 | N,
39 | K,
40 | alpha,
41 | beta);
42 | }
43 | }
44 |
45 | } // namespace mlx::core
46 |
--------------------------------------------------------------------------------
/.github/actions/build-linux-release/action.yml:
--------------------------------------------------------------------------------
1 | name: 'Build Linux wheel'
2 | description: 'Build Linux wheel'
3 |
4 | inputs:
5 | build-backend:
6 | description: 'Build the backend mlx-cpu package'
7 | type: boolean
8 | required: false
9 | default: false
10 | arch:
11 | description: 'Platform architecture tag'
12 | required: true
13 | type: choice
14 | options:
15 | - x86_64
16 | - aarch64
17 |
18 | runs:
19 | using: "composite"
20 | steps:
21 | - name: Generate package stubs
22 | shell: bash
23 | run: |
24 | pip install -e ".[dev]" -v
25 | pip install typing_extensions
26 | python setup.py generate_stubs
27 | - name: Build Python package
28 | shell: bash
29 | run: |
30 | pip install auditwheel patchelf build
31 | python setup.py clean --all
32 | MLX_BUILD_STAGE=1 python -m build -w
33 | bash python/scripts/repair_linux.sh ${{ inputs.arch }}
34 | - name: Build backend package
35 | if: ${{ inputs.build-backend }}
36 | shell: bash
37 | run: |
38 | python setup.py clean --all
39 | MLX_BUILD_STAGE=2 python -m build -w
40 | auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
41 |
--------------------------------------------------------------------------------
/mlx/backend/cuda/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | // This file includes utilities that are used by C++ code (i.e. .cpp files).
4 |
5 | #pragma once
6 |
7 | #include "mlx/array.h"
8 | #include "mlx/backend/cuda/allocator.h"
9 | #include "mlx/backend/cuda/cuda_utils.h"
10 |
11 | namespace mlx::core {
12 |
13 | template <typename T>
14 | inline uint max_occupancy_block_dim(T kernel) {
15 | int _, block_dim;
16 | if constexpr (std::is_same_v<T, CUfunction>) {
17 | CHECK_CUDA_ERROR(
18 | cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
19 | } else {
20 | CHECK_CUDA_ERROR(
21 | cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
22 | }
23 | return block_dim;
24 | }
25 |
26 | template <typename T>
27 | inline T* gpu_ptr(array& arr) {
28 | return reinterpret_cast<T*>(
29 | static_cast<char*>(
30 | static_cast<cu::CudaBuffer*>(arr.buffer().ptr())->data) +
31 | arr.offset());
32 | }
33 |
34 | template <typename T>
35 | inline const T* gpu_ptr(const array& arr) {
36 | return gpu_ptr<T>(const_cast<array&>(arr));
37 | }
38 |
39 | struct Dtype;
40 |
41 | // Convert Dtype to CUDA C++ types.
42 | const char* dtype_to_cuda_type(const Dtype& dtype);
43 |
44 | } // namespace mlx::core
45 |
--------------------------------------------------------------------------------
/mlx/io/no_safetensors.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include "mlx/io.h"
4 |
5 | namespace mlx::core {
6 |
7 | SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOrDevice) {
8 | throw std::runtime_error(
9 | "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
10 | "to enable safetensors support.");
11 | }
12 |
13 | SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) {
14 | throw std::runtime_error(
15 | "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
16 | "to enable safetensors support.");
17 | }
18 |
19 | void save_safetensors(
20 | std::shared_ptr<io::Writer>,
21 | std::unordered_map<std::string, array>,
22 | std::unordered_map<std::string, std::string>) {
23 | throw std::runtime_error(
24 | "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
25 | "to enable safetensors support.");
26 | }
27 |
28 | void save_safetensors(
29 | std::string file,
30 | std::unordered_map<std::string, array>,
31 | std::unordered_map<std::string, std::string>) {
32 | throw std::runtime_error(
33 | "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
34 | "to enable safetensors support.");
35 | }
36 |
37 | } // namespace mlx::core
38 |
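When MLX is built with `MLX_BUILD_SAFETENSORS=ON`, real implementations replace these stubs and back the Python API. A short usage sketch (the file name is illustrative):

```python
import mlx.core as mx

# Save a dict of arrays, then load it back; .safetensors files load as a dict.
mx.save_safetensors("weights.safetensors", {"w": mx.zeros((2, 2))})
tensors = mx.load("weights.safetensors")  # {'w': array(...)}
```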
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/indexing/masked_scatter.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #pragma once
4 |
5 | template <typename T, bool src_contiguous>
6 | [[kernel]] void masked_assign_impl(
7 | const device bool* mask [[buffer(0)]],
8 | const device uint* scatter_offsets [[buffer(1)]],
9 | const device T* src [[buffer(2)]],
10 | device T* out [[buffer(3)]],
11 | const constant int* src_shapes [[buffer(4)]],
12 | const constant int64_t* src_strides [[buffer(5)]],
13 | const constant int& src_ndim [[buffer(6)]],
14 | const constant int64_t& src_batch_size [[buffer(7)]],
15 | const constant int64_t& mask_batch_size [[buffer(8)]],
16 | uint idx [[thread_position_in_grid]]) {
17 | const bool mask_value = mask[idx];
18 | if (!mask_value) {
19 | return;
20 | }
21 |
22 | const uint src_index = scatter_offsets[idx];
23 | if (src_index >= src_batch_size) {
24 | return;
25 | }
26 |
27 | const uint batch_idx = idx / mask_batch_size;
28 |
29 | if (src_contiguous) {
30 | out[idx] = src[batch_idx * src_batch_size + src_index];
31 | } else {
32 | out[idx] = src[elem_to_loc(
33 | batch_idx * src_batch_size + src_index,
34 | src_shapes,
35 | src_strides,
36 | src_ndim)];
37 | }
38 | }
39 |
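A NumPy sketch of the single-batch semantics the kernel implements: positions where `mask` is true receive consecutive elements of `src` (`scatter_offsets` holds each position's running count of prior true values):

```python
import numpy as np

def masked_assign(out, mask, src):
    out = out.copy()
    out[mask] = src[: mask.sum()]  # consecutive src values land at the true positions
    return out

masked_assign(np.zeros(5), np.array([1, 0, 1, 0, 1], bool), np.array([7.0, 8.0, 9.0]))
# -> [7., 0., 8., 0., 9.]
```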
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <metal_stdlib>
6 |
7 | METAL_FUNC ulong2 elem_to_loc_broadcast(
8 | uint elem,
9 | constant const int* shape,
10 | constant const int64_t* a_strides,
11 | constant const int64_t* b_strides,
12 | int ndim) {
13 | ulong loc_a{0};
14 | ulong loc_b{0};
15 | for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
16 | int pos_in_dim = (elem % shape[i]);
17 | elem /= shape[i];
18 | loc_a += pos_in_dim * a_strides[i];
19 | loc_b += pos_in_dim * b_strides[i];
20 | }
21 | return ulong2(loc_a, loc_b);
22 | }
23 |
24 | METAL_FUNC ulong3 elem_to_loc_broadcast(
25 | uint elem,
26 | constant const int* shape,
27 | constant const int64_t* a_strides,
28 | constant const int64_t* b_strides,
29 | constant const int64_t* c_strides,
30 | int ndim) {
31 | ulong loc_a{0};
32 | ulong loc_b{0};
33 | ulong loc_c{0};
34 | for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
35 | int pos_in_dim = (elem % shape[i]);
36 | elem /= shape[i];
37 | loc_a += pos_in_dim * a_strides[i];
38 | loc_b += pos_in_dim * b_strides[i];
39 | loc_c += pos_in_dim * c_strides[i];
40 | }
41 | return ulong3(loc_a, loc_b, loc_c);
42 | }
43 |
--------------------------------------------------------------------------------
/mlx/device.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <stdexcept>
4 |
5 | #include "mlx/backend/cpu/available.h"
6 | #include "mlx/backend/gpu/available.h"
7 | #include "mlx/device.h"
8 |
9 | namespace mlx::core {
10 |
11 | Device& mutable_default_device() {
12 | static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
13 | return default_device;
14 | }
15 |
16 | const Device& default_device() {
17 | return mutable_default_device();
18 | }
19 |
20 | void set_default_device(const Device& d) {
21 | if (!gpu::is_available() && d == Device::gpu) {
22 | throw std::invalid_argument(
23 | "[set_default_device] Cannot set gpu device without gpu backend.");
24 | }
25 | mutable_default_device() = d;
26 | }
27 |
28 | bool operator==(const Device& lhs, const Device& rhs) {
29 | return lhs.type == rhs.type && lhs.index == rhs.index;
30 | }
31 |
32 | bool operator!=(const Device& lhs, const Device& rhs) {
33 | return !(lhs == rhs);
34 | }
35 |
36 | bool is_available(const Device& d) {
37 | switch (d.type) {
38 | case Device::cpu:
39 | return cpu::is_available();
40 | case Device::gpu:
41 | return gpu::is_available();
42 | }
43 | // appease compiler
44 | return false;
45 | }
46 |
47 | } // namespace mlx::core
48 |
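These functions back the Python-level device controls; for example:

```python
import mlx.core as mx

# Mirrors the C++ logic above: the default is gpu when a GPU backend is
# available, and requesting gpu without one raises.
print(mx.default_device())
mx.set_default_device(mx.cpu)  # always valid; cpu::is_available() is true
```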
--------------------------------------------------------------------------------
/docs/src/python/array.rst:
--------------------------------------------------------------------------------
1 | .. _array:
2 |
3 | Array
4 | =====
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | array
12 | array.astype
13 | array.at
14 | array.item
15 | array.tolist
16 | array.dtype
17 | array.itemsize
18 | array.nbytes
19 | array.ndim
20 | array.shape
21 | array.size
22 | array.real
23 | array.imag
24 | array.abs
25 | array.all
26 | array.any
27 | array.argmax
28 | array.argmin
29 | array.conj
30 | array.cos
31 | array.cummax
32 | array.cummin
33 | array.cumprod
34 | array.cumsum
35 | array.diag
36 | array.diagonal
37 | array.exp
38 | array.flatten
39 | array.log
40 | array.log10
41 | array.log1p
42 | array.log2
43 | array.logcumsumexp
44 | array.logsumexp
45 | array.max
46 | array.mean
47 | array.min
48 | array.moveaxis
49 | array.prod
50 | array.reciprocal
51 | array.reshape
52 | array.round
53 | array.rsqrt
54 | array.sin
55 | array.split
56 | array.sqrt
57 | array.square
58 | array.squeeze
59 | array.std
60 | array.sum
61 | array.swapaxes
62 | array.transpose
63 | array.T
64 | array.var
65 | array.view
66 |
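A few of the listed attributes and methods in action:

```python
import mlx.core as mx

a = mx.array([[1.0, 2.0], [3.0, 4.0]])
print(a.dtype, a.shape, a.ndim)  # mlx.core.float32 (2, 2) 2
print(a.sum().item())            # 10.0
print(a.T.tolist())              # [[1.0, 3.0], [2.0, 4.0]]
```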
--------------------------------------------------------------------------------
/docs/src/python/random.rst:
--------------------------------------------------------------------------------
1 | .. _random:
2 |
3 | Random
4 | ======
5 |
6 | Random sampling functions in MLX use an implicit global PRNG state by default.
7 | However, all functions take an optional ``key`` keyword argument for when more
8 | fine-grained control or explicit state management is needed.
9 |
10 | For example, you can generate random numbers with:
11 |
12 | .. code-block:: python
13 |
14 | for _ in range(3):
15 | print(mx.random.uniform())
16 |
17 | which will print a sequence of unique pseudo random numbers. Alternatively you
18 | can explicitly set the key:
19 |
20 | .. code-block:: python
21 |
22 | key = mx.random.key(0)
23 | for _ in range(3):
24 | print(mx.random.uniform(key=key))
25 |
26 | which will yield the same pseudo random number at each iteration.
27 |
28 | Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
29 | we use a splittable version of Threefry, which is a counter-based PRNG.
30 |
31 | .. currentmodule:: mlx.core.random
32 |
33 | .. autosummary::
34 | :toctree: _autosummary
35 |
36 | bernoulli
37 | categorical
38 | gumbel
39 | key
40 | normal
41 | multivariate_normal
42 | randint
43 | seed
44 | split
45 | truncated_normal
46 | uniform
47 | laplace
48 | permutation
49 |
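Since Threefry keys are splittable, one key can be turned into independent streams with the ``split`` function listed above:

```python
import mlx.core as mx

key = mx.random.key(0)
k1, k2 = mx.random.split(key)     # two independent keys from one
print(mx.random.uniform(key=k1))
print(mx.random.uniform(key=k2))  # different from the k1 sample
```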
--------------------------------------------------------------------------------
/mlx/backend/metal/jit/includes.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | namespace mlx::core::metal {
6 |
7 | const char* utils();
8 | const char* binary_ops();
9 | const char* unary_ops();
10 | const char* ternary_ops();
11 | const char* reduce_utils();
12 | const char* gather();
13 | const char* scatter();
14 | const char* masked_scatter();
15 |
16 | const char* arange();
17 | const char* unary();
18 | const char* binary();
19 | const char* binary_two();
20 | const char* copy();
21 | const char* fft();
22 | const char* gather_axis();
23 | const char* gather_front();
24 | const char* hadamard();
25 | const char* logsumexp();
26 | const char* quantized_utils();
27 | const char* quantized();
28 | const char* fp_quantized();
29 | const char* ternary();
30 | const char* scan();
31 | const char* scatter_axis();
32 | const char* softmax();
33 | const char* sort();
34 | const char* reduce();
35 |
36 | const char* gemm();
37 | const char* steel_gemm_fused();
38 | const char* steel_gemm_masked();
39 | const char* steel_gemm_splitk();
40 | const char* steel_gemm_gather();
41 | const char* steel_gemm_segmented();
42 | const char* conv();
43 | const char* steel_conv();
44 | const char* steel_conv_general();
45 | const char* gemv_masked();
46 |
47 | } // namespace mlx::core::metal
48 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/fp4.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | constexpr constant static float FP4_LUT[16] = {
4 | +0.0f,
5 | +0.5f,
6 | +1.0f,
7 | +1.5f,
8 | +2.0f,
9 | +3.0f,
10 | +4.0f,
11 | +6.0f,
12 | -0.0f,
13 | -0.5f,
14 | -1.0f,
15 | -1.5f,
16 | -2.0f,
17 | -3.0f,
18 | -4.0f,
19 | -6.0f};
20 |
21 | struct fp4_e2m1 {
22 | fp4_e2m1(float x) {
23 | if (metal::isnan(x)) {
24 | bits = 0x7;
25 | return;
26 | }
27 |
28 | const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0;
29 | x = metal::abs(x);
30 |
31 | if (x > 5.0f) {
32 | bits = 0x7;
33 | } else if (x >= 3.5f) {
34 | bits = 0x6;
35 | } else if (x > 2.5f) {
36 | bits = 0x5;
37 | } else if (x >= 1.75f) {
38 | bits = 0x4;
39 | } else if (x > 1.25f) {
40 | bits = 0x3;
41 | } else if (x >= 0.75f) {
42 | bits = 0x2;
43 | } else if (x > 0.25f) {
44 | bits = 0x1;
45 | } else {
46 | bits = 0x0;
47 | }
48 | bits |= sign_bit;
49 | }
50 |
51 | operator float() {
52 | half converted = as_type<half>(ushort((bits & 7) << 9));
53 | converted *= 16384.0;
54 | converted = bits & 8 ? -converted : converted;
55 | return converted;
56 | }
57 |
58 | uint8_t bits;
59 | };
60 |
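A Python sketch of the encode path, mirroring the thresholds above (round to the nearest LUT entry, ties toward the even mantissa bit, hence the alternating `>` and `>=`; NaN saturates to the 0x7 pattern):

```python
import math

FP4_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # positive half of the table

def fp4_e2m1_encode(x: float) -> int:
    if math.isnan(x):
        return 0x7                   # fp4 has no NaN encoding; saturate
    sign = 0x8 if math.copysign(1.0, x) < 0 else 0x0
    x = abs(x)
    if x > 5.0:     bits = 0x7       # -> 6.0
    elif x >= 3.5:  bits = 0x6       # -> 4.0 (tie at 3.5 rounds to even)
    elif x > 2.5:   bits = 0x5       # -> 3.0
    elif x >= 1.75: bits = 0x4       # -> 2.0
    elif x > 1.25:  bits = 0x3       # -> 1.5
    elif x >= 0.75: bits = 0x2       # -> 1.0
    elif x > 0.25:  bits = 0x1       # -> 0.5
    else:           bits = 0x0       # -> 0.0
    return bits | sign

assert FP4_LUT[fp4_e2m1_encode(2.6)] == 3.0
```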
--------------------------------------------------------------------------------
/mlx/fence.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <memory>
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | /* A fence to be used for synchronizing work between streams.
10 | *
11 | * Calls to `wait` wait in the given stream until all previous calls to update
12 | * are complete on their given stream.
13 | *
14 | * The array passed to `update` is computed and visible after the call to
15 | * `wait` returns. The array passed to `wait` will not be read until all
16 | * previous calls to `update` have completed.
17 | *
18 | * Note, calls to `update` should always be from the same thread or explicitly
19 | * synchronized so that they occur in sequence. Calls to `wait` can be on any
20 | * thread.
21 | *
22 | * For the Metal back-end the fence supports slow (default) and fast mode.
23 | * Fast mode requires setting the environment variable
24 | * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+,
25 | * iOS 18+).
26 | */
27 | class Fence {
28 | public:
29 | Fence() {};
30 | explicit Fence(Stream stream);
31 |
32 | void update(Stream stream, const array& x, bool cross_device);
33 | void wait(Stream stream, const array& x);
34 |
35 | private:
36 | std::shared_ptr<void> fence_{nullptr};
37 | };
38 |
39 | } // namespace mlx::core
40 |
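`Fence` is C++-internal with no Python binding; purely as an illustration of the update/wait contract described in the comment (not MLX API), a condition-variable sketch:

```python
import threading

class FenceSketch:
    """wait() blocks until every update() issued before the call has completed."""

    def __init__(self):
        self._issued = 0
        self._completed = 0
        self._cv = threading.Condition()

    def update(self) -> int:
        # Called from one thread (or externally synchronized), mirroring the
        # note above. Returns this update's sequence number.
        with self._cv:
            self._issued += 1
            return self._issued

    def complete(self, seq: int) -> None:
        # Called when the work for update `seq` finishes.
        with self._cv:
            self._completed = max(self._completed, seq)
            self._cv.notify_all()

    def wait(self) -> None:
        # Callable from any thread: blocks until all prior updates are done.
        with self._cv:
            target = self._issued
            self._cv.wait_for(lambda: self._completed >= target)
```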
--------------------------------------------------------------------------------
/mlx/backend/cpu/eval.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 | #include "mlx/backend/cpu/eval.h"
3 | #include "mlx/backend/cpu/encoder.h"
4 | #include "mlx/primitives.h"
5 | #include "mlx/scheduler.h"
6 | #include "mlx/utils.h"
7 |
8 | namespace mlx::core::cpu {
9 |
10 | void eval(array& arr) {
11 | auto s = arr.primitive().stream();
12 |
13 | auto outputs = arr.outputs();
14 | {
15 | // If the array is a tracer hold a reference
16 | // to its inputs so they don't get donated
17 | std::vector<array> inputs;
18 | if (arr.is_tracer()) {
19 | inputs = arr.inputs();
20 | }
21 | arr.primitive().eval_cpu(arr.inputs(), outputs);
22 | }
23 |
24 | std::unordered_set<std::shared_ptr<array::Data>> buffers;
25 | for (auto& in : arr.inputs()) {
26 | buffers.insert(in.data_shared_ptr());
27 | }
28 | for (auto& s : arr.siblings()) {
29 | buffers.insert(s.data_shared_ptr());
30 | }
31 | // Remove the output if it was donated to by an input
32 | if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
33 | buffers.erase(it);
34 | }
35 | auto& encoder = cpu::get_command_encoder(s);
36 | encoder.dispatch([buffers = std::move(buffers),
37 | temps = std::move(encoder.temporaries())]() {});
38 | }
39 |
40 | } // namespace mlx::core::cpu
41 |
--------------------------------------------------------------------------------
/mlx/backend/cpu/make_compiled_preamble.ps1:
--------------------------------------------------------------------------------
1 | # This script generates a C++ function that provides the CPU
2 | # code for use with kernel generation.
3 | #
4 | # Copyright © 2024 Apple Inc.
5 |
6 | $OUTPUT_FILE = $args[0]
7 | $CL = $args[1]
8 | $SRCDIR = $args[2]
9 |
10 | # Get command result as array.
11 | $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h"
12 | # Remove empty lines.
13 | # Otherwise there would be too many empty lines, making the result unreadable.
14 | $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
15 | # Concatenate to string.
16 | $CONTENT = $CONTENT -join "`n"
17 |
18 | # Append extra content.
19 | $CONTENT = @"
20 | $($CONTENT)
21 | using namespace mlx::core;
22 | using namespace mlx::core::detail;
23 | "@
24 |
25 | # Convert each char to ASCII code.
26 | # Unlike the unix script that outputs string literal directly, the output from
27 | # MSVC is way too large to be embedded as string and compilation will fail, so
28 | # we store it as static array instead.
29 | $CHARCODES = ([System.Text.Encoding]::ASCII.GetBytes($CONTENT) -join ', ') + ', 0'
30 |
31 | $OUTPUT = @"
32 | const char* get_kernel_preamble() {
33 | static char preamble[] = { $CHARCODES };
34 | return preamble;
35 | }
36 | "@
37 |
38 | Set-Content -Path $OUTPUT_FILE -Value $OUTPUT
39 |
--------------------------------------------------------------------------------
/python/mlx/_stub_patterns.txt:
--------------------------------------------------------------------------------
1 | mlx.core.__prefix__:
2 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
3 | import sys
4 | if sys.version_info >= (3, 10):
5 | from typing import TypeAlias
6 | else:
7 | from typing_extensions import TypeAlias
8 |
9 | mlx.core.__suffix__:
10 | from typing import Union
11 | scalar: TypeAlias = Union[int, float, bool]
12 | list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
13 | bool_: Dtype = ...
14 |
15 | mlx.core.distributed.__prefix__:
16 | from mlx.core import array, Dtype, Device, Stream, scalar
17 | from mlx.core.distributed import Group
18 | from typing import Sequence, Optional, Union
19 |
20 | mlx.core.fast.__prefix__:
21 | from mlx.core import array, Dtype, Device, Stream, scalar
22 | from typing import Sequence, Optional, Union
23 |
24 | mlx.core.linalg.__prefix__:
25 | from mlx.core import array, Dtype, Device, Stream, scalar
26 | from typing import Sequence, Optional, Tuple, Union
27 |
28 | mlx.core.metal.__prefix__:
29 | from mlx.core import array, Dtype, Device, Stream, scalar
30 | from typing import Sequence, Optional, Union
31 |
32 | mlx.core.random.__prefix__:
33 | from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
34 | from typing import Sequence, Optional, Union
35 |
--------------------------------------------------------------------------------
/mlx/backend/gpu/slicing.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | #include "mlx/backend/common/slicing.h"
4 | #include "mlx/backend/gpu/copy.h"
5 | #include "mlx/backend/gpu/slicing.h"
6 |
7 | namespace mlx::core {
8 |
9 | void slice_gpu(
10 | const array& in,
11 | array& out,
12 | const Shape& start_indices,
13 | const Shape& strides,
14 | const Stream&) {
15 | slice(in, out, start_indices, strides);
16 | }
17 |
18 | void pad_gpu(
19 | const array& in,
20 | const array& val,
21 | array& out,
22 | const std::vector<int>& axes,
23 | const Shape& low_pad_size,
24 | const Stream& s) {
25 | // Fill output with val
26 | fill_gpu(val, out, s);
27 |
28 | // Find offset for start of input values
29 | size_t data_offset = 0;
30 | for (int i = 0; i < axes.size(); i++) {
31 | auto ax = axes[i] < 0 ? out.ndim() + axes[i] : axes[i];
32 | data_offset += out.strides()[ax] * low_pad_size[i];
33 | }
34 |
35 | // Extract slice from output where input will be pasted
36 | array out_slice(in.shape(), out.dtype(), nullptr, {});
37 | out_slice.copy_shared_buffer(
38 | out, out.strides(), out.flags(), out_slice.size(), data_offset);
39 |
40 | // Copy input values into the slice
41 | copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, s);
42 | }
43 |
44 | } // namespace mlx::core
45 |
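A NumPy sketch of `pad_gpu`'s strategy: fill the whole output with `val`, then paste the input at the offset given by the low-side padding:

```python
import numpy as np

def pad(x, low_pad, high_pad, val):
    # Output is input shape grown by low/high padding on each axis.
    out_shape = tuple(l + d + h for d, l, h in zip(x.shape, low_pad, high_pad))
    out = np.full(out_shape, val, dtype=x.dtype)
    # Paste x into the slice starting at the low-pad offset (the data_offset
    # computed above plays this role on strided GPU memory).
    out[tuple(slice(l, l + d) for l, d in zip(low_pad, x.shape))] = x
    return out

pad(np.ones((2, 2)), (1, 0), (0, 1), 0.0)  # 3x3 with the ones block offset by one row
```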
--------------------------------------------------------------------------------
/benchmarks/python/synchronize_bench.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import mlx.core as mx
4 |
5 | rank = mx.distributed.init().rank()
6 |
7 |
8 | def timeit(fn, a):
9 |
10 | # warmup
11 | for _ in range(5):
12 | mx.eval(fn(a))
13 |
14 | its = 10
15 | tic = time.perf_counter()
16 | for _ in range(its):
17 | mx.eval(fn(a))
18 | toc = time.perf_counter()
19 | ms = 1000 * (toc - tic) / its
20 | return ms
21 |
22 |
23 | def all_reduce_benchmark():
24 | a = mx.ones((5, 5), mx.int32)
25 |
26 | its_per_eval = 100
27 |
28 | def fn(x):
29 | for _ in range(its_per_eval):
30 | x = mx.distributed.all_sum(x)
31 | x = x - 1
32 | return x
33 |
34 | ms = timeit(fn, a) / its_per_eval
35 | if rank == 0:
36 | print(f"All Reduce: time per iteration {ms:.6f} (ms)")
37 |
38 |
39 | def all_gather_benchmark():
40 | a = mx.ones((5, 5), mx.int32)
41 | its_per_eval = 100
42 |
43 | def fn(x):
44 | for _ in range(its_per_eval):
45 | x = mx.distributed.all_gather(x)[0]
46 | return x
47 |
48 | ms = timeit(fn, a) / its_per_eval
49 | if rank == 0:
50 | print(f"All gather: time per iteration {ms:.6f} (ms)")
51 |
52 |
53 | if __name__ == "__main__":
54 | all_reduce_benchmark()
55 | all_gather_benchmark()
56 |
--------------------------------------------------------------------------------
/create_library_symlinks.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Create symlinks for RDMA libraries from CommandLineTools SDK to Xcode-beta SDK
3 | # This is needed because Xcode-beta SDK is missing these library files
4 |
5 | set -e
6 |
7 | echo "=== Creating RDMA Library Symlinks ==="
8 | echo ""
9 | echo "This will create symlinks in Xcode-beta SDK pointing to CommandLineTools SDK"
10 | echo "Requires sudo access"
11 | echo ""
12 |
13 | # Create librdma.tbd symlink
14 | echo "Creating librdma.tbd symlink..."
15 | sudo ln -sf \
16 | /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib/librdma.tbd \
17 | /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX26.0.sdk/usr/lib/librdma.tbd
18 |
19 | echo "✓ Created: librdma.tbd"
20 |
21 | # Create rdma directory symlink
22 | echo "Creating rdma directory symlink..."
23 | sudo ln -sf \
24 | /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib/rdma \
25 | /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX26.0.sdk/usr/lib/rdma
26 |
27 | echo "✓ Created: rdma/"
28 |
29 | echo ""
30 | echo "=== Symlinks Created Successfully ==="
31 | echo ""
32 | echo "Verify with:"
33 | echo " ls -la /Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX26.0.sdk/usr/lib/ | grep rdma"
34 | echo ""
35 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal:
--------------------------------------------------------------------------------
1 | // Copyright © 2024-25 Apple Inc.
2 |
3 | // clang-format off
4 | #include "mlx/backend/metal/kernels/utils.h"
5 |
6 | #include "mlx/backend/metal/kernels/steel/attn/attn.h"
7 | #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
8 |
9 | #define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
10 | instantiate_kernel( \
11 | "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
12 | "_wm" #wm "_wn" #wn "_mask" #mname, \
13 | attention, dtype, bq, bk, bd, wm, wn, mtype, float)
14 |
15 | #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
16 | instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
17 | instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
18 | instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
19 |
20 | #define instantiate_attn_mask_helper(iname, itype) \
21 | instantiate_attn_shapes_helper(iname, itype, iname, itype) \
22 | instantiate_attn_shapes_helper(iname, itype, bool_, bool)
23 |
24 | instantiate_attn_mask_helper(float16, half);
25 | instantiate_attn_mask_helper(bfloat16, bfloat16_t);
26 |
27 | instantiate_attn_mask_helper(float32, float);
28 | // clang-format on
29 |
--------------------------------------------------------------------------------
/mlx/allocator.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <cstddef>
6 |
7 | namespace mlx::core::allocator {
8 |
9 | // Simple wrapper around buffer pointers
10 | // WARNING: Only Buffer objects that are constructed from, or wrap, raw
11 | // pointers obtained from mlx::allocator are supported.
12 | class Buffer {
13 | private:
14 | void* ptr_;
15 |
16 | public:
17 | explicit Buffer(void* ptr) : ptr_(ptr) {};
18 |
19 | // Get the raw data pointer from the buffer
20 | void* raw_ptr();
21 |
22 | // Get the buffer pointer from the buffer
23 | const void* ptr() const {
24 | return ptr_;
25 | };
26 | void* ptr() {
27 | return ptr_;
28 | };
29 | };
30 |
31 | Buffer malloc(size_t size);
32 |
33 | void free(Buffer buffer);
34 |
35 | class Allocator {
36 | /** Abstract base class for a memory allocator. */
37 | public:
38 | virtual Buffer malloc(size_t size) = 0;
39 | virtual void free(Buffer buffer) = 0;
40 | virtual size_t size(Buffer buffer) const = 0;
41 |
42 | Allocator() = default;
43 | Allocator(const Allocator& other) = delete;
44 | Allocator(Allocator&& other) = delete;
45 | Allocator& operator=(const Allocator& other) = delete;
46 | Allocator& operator=(Allocator&& other) = delete;
47 | virtual ~Allocator() = default;
48 | };
49 |
50 | Allocator& allocator();
51 |
52 | } // namespace mlx::core::allocator
53 |
--------------------------------------------------------------------------------
/mlx/backend/metal/distributed.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #include <cassert>
4 |
5 | #include "mlx/allocator.h"
6 | #include "mlx/backend/common/utils.h"
7 | #include "mlx/backend/gpu/copy.h"
8 | #include "mlx/backend/metal/device.h"
9 | #include "mlx/backend/metal/utils.h"
10 | #include "mlx/distributed/ops.h"
11 | #include "mlx/distributed/primitives.h"
12 | #include "mlx/fence.h"
13 | #include "mlx/scheduler.h"
14 |
15 | namespace mlx::core::distributed {
16 |
17 | void AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {
18 | throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation.");
19 | }
20 |
21 | void AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {
22 | throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation.");
23 | }
24 |
25 | void Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {
26 | throw std::runtime_error("[Send::eval_gpu] has no GPU implementation.");
27 | }
28 |
29 | void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
30 | throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
31 | }
32 |
33 | void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
34 | throw std::runtime_error(
35 | "[ReduceScatter::eval_gpu] has no GPU implementation.");
36 | }
37 |
38 | } // namespace mlx::core::distributed
39 |
--------------------------------------------------------------------------------
/mlx/distributed/ops.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <optional>
6 |
7 | #include "mlx/distributed/distributed.h"
8 | #include "mlx/utils.h"
9 |
10 | namespace mlx::core::distributed {
11 |
12 | array all_sum(
13 | const array& x,
14 | std::optional<Group> group = std::nullopt,
15 | StreamOrDevice s = {});
16 |
17 | array all_gather(
18 | const array& x,
19 | std::optional<Group> group = std::nullopt,
20 | StreamOrDevice s = {});
21 |
22 | array send(
23 | const array& x,
24 | int dst,
25 | std::optional<Group> group = std::nullopt,
26 | StreamOrDevice s = {});
27 |
28 | array recv(
29 | Shape shape,
30 | Dtype dtype,
31 | int src,
32 | std::optional<Group> group = std::nullopt,
33 | StreamOrDevice s = {});
34 |
35 | array recv_like(
36 | const array& x,
37 | int src,
38 | std::optional<Group> group = std::nullopt,
39 | StreamOrDevice s = {});
40 |
41 | array all_max(
42 | const array& x,
43 | std::optional<Group> group = std::nullopt,
44 | StreamOrDevice s = {});
45 |
46 | array all_min(
47 | const array& x,
48 | std::optional<Group> group = std::nullopt,
49 | StreamOrDevice s = {});
50 |
51 | array sum_scatter(
52 | const array& x,
53 | std::optional<Group> group = std::nullopt,
54 | StreamOrDevice s = {});
55 |
56 | } // namespace mlx::core::distributed
57 |
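These ops surface in Python under `mx.distributed`; a short sketch (it assumes a distributed launch with at least two ranks, e.g. via `mlx.launch`):

```python
import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((2, 2))
y = mx.distributed.all_sum(x, group=group)     # all_sum above
g = mx.distributed.all_gather(x, group=group)  # concatenated along axis 0
mx.eval(y, g)
```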
--------------------------------------------------------------------------------
/benchmarks/cpp/time_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <chrono>
6 | #include <iomanip>
7 | #include <iostream>
8 |
9 | #include "mlx/mlx.h"
10 |
11 | #define milliseconds(x) \
12 | (std::chrono::duration_cast(x).count() / 1e6)
13 | #define time_now() std::chrono::high_resolution_clock::now()
14 |
15 | #define TIME(FUNC, ...) \
16 | std::cout << "Timing " << #FUNC << " ... " << std::flush \
17 | << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
18 | << std::endl;
19 |
20 | #define TIMEM(MSG, FUNC, ...) \
21 | std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
22 | << std::flush << std::setprecision(5) \
23 | << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
24 |
25 | template <typename F, typename... Args>
26 | double time_fn(F fn, Args&&... args) {
27 | // warmup
28 | for (int i = 0; i < 5; ++i) {
29 | eval(fn(std::forward<Args>(args)...));
30 | }
31 |
32 | int num_iters = 100;
33 | auto start = time_now();
34 | for (int i = 0; i < num_iters; i++) {
35 | eval(fn(std::forward<Args>(args)...));
36 | }
37 | auto end = time_now();
38 | return milliseconds(end - start) / static_cast(num_iters);
39 | }
40 |
--------------------------------------------------------------------------------
/python/src/mlx.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023-2024 Apple Inc.
2 |
3 | #include <nanobind/nanobind.h>
4 |
5 | #include "mlx/version.h"
6 |
7 | namespace mx = mlx::core;
8 | namespace nb = nanobind;
9 |
10 | void init_mlx_func(nb::module_&);
11 | void init_array(nb::module_&);
12 | void init_device(nb::module_&);
13 | void init_stream(nb::module_&);
14 | void init_metal(nb::module_&);
15 | void init_cuda(nb::module_&);
16 | void init_memory(nb::module_&);
17 | void init_ops(nb::module_&);
18 | void init_transforms(nb::module_&);
19 | void init_random(nb::module_&);
20 | void init_fft(nb::module_&);
21 | void init_linalg(nb::module_&);
22 | void init_constants(nb::module_&);
23 | void init_fast(nb::module_&);
24 | void init_distributed(nb::module_&);
25 | void init_export(nb::module_&);
26 |
27 | NB_MODULE(core, m) {
28 | m.doc() = "mlx: A framework for machine learning on Apple silicon.";
29 |
30 | auto reprlib_fix = nb::module_::import_("mlx._reprlib_fix");
31 | nb::set_leak_warnings(false);
32 |
33 | init_mlx_func(m);
34 | init_device(m);
35 | init_stream(m);
36 | init_array(m);
37 | init_metal(m);
38 | init_cuda(m);
39 | init_memory(m);
40 | init_ops(m);
41 | init_transforms(m);
42 | init_random(m);
43 | init_fft(m);
44 | init_linalg(m);
45 | init_constants(m);
46 | init_fast(m);
47 | init_distributed(m);
48 | init_export(m);
49 |
50 | m.attr("__version__") = mx::version();
51 | }
52 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/steel/utils/type_traits.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <metal_stdlib>
6 |
7 | #pragma METAL internals : enable
8 |
9 | namespace metal {
10 |
11 | template <typename T>
12 | struct is_empty : metal::bool_constant<__is_empty(T)> {};
13 |
14 | #ifdef __cpp_variable_templates
15 | template <typename T>