├── docs ├── .nojekyll ├── .gitignore ├── .clang-format ├── src │ ├── cpp │ │ └── ops.rst │ ├── _static │ │ └── mlx_logo.png │ ├── python │ │ ├── nn │ │ │ └── module.rst │ │ ├── transforms.rst │ │ ├── fft.rst │ │ ├── devices_and_streams.rst │ │ ├── tree_utils.rst │ │ ├── array.rst │ │ ├── random.rst │ │ ├── data_types.rst │ │ ├── ops.rst │ │ └── optimizers.rst │ ├── _templates │ │ ├── optimizers-template.rst │ │ └── nn-module-template.rst │ ├── using_streams.rst │ ├── conf.py │ ├── quick_start.rst │ ├── index.rst │ ├── examples │ │ └── linear_regression.rst │ └── install.rst ├── index.html ├── Makefile └── README.md ├── mlx ├── 3rdparty │ └── .clang-format ├── backend │ ├── no_metal │ │ ├── CMakeLists.txt │ │ ├── allocator.cpp │ │ ├── metal.cpp │ │ └── primitives.cpp │ ├── accelerate │ │ ├── CMakeLists.txt │ │ ├── conv.cpp │ │ └── utils.h │ ├── common │ │ ├── erf.h │ │ ├── threefry.h │ │ ├── CMakeLists.txt │ │ ├── utils.h │ │ ├── copy.h │ │ ├── threefry.cpp │ │ ├── load.cpp │ │ ├── erf.cpp │ │ ├── arange.h │ │ ├── fft.cpp │ │ ├── softmax.cpp │ │ ├── arg_reduce.cpp │ │ ├── default_primitives.cpp │ │ └── unary.h │ └── metal │ │ ├── fft.cpp │ │ ├── copy.h │ │ ├── kernels │ │ ├── defines.h │ │ ├── conv_params.h │ │ ├── arange.metal │ │ ├── CMakeLists.txt │ │ ├── erf.h │ │ ├── random.metal │ │ └── complex.h │ │ ├── metal.h │ │ ├── matmul.h │ │ ├── CMakeLists.txt │ │ ├── allocator.h │ │ ├── device.h │ │ ├── metal.cpp │ │ ├── softmax.cpp │ │ └── copy.cpp ├── mlx.h ├── transforms_impl.h ├── device.h ├── graph_utils.h ├── stream.h ├── device.cpp ├── CMakeLists.txt ├── allocator.cpp ├── utils.h ├── scheduler.cpp ├── types │ ├── half_types.h │ └── complex.h ├── allocator.h ├── dtype.h ├── load.h └── graph_utils.cpp ├── MANIFEST.in ├── pyproject.toml ├── examples ├── extensions │ ├── mlx_sample_extensions │ │ └── __init__.py │ ├── setup.py │ ├── bindings.cpp │ ├── CMakeLists.txt │ └── axpby │ │ ├── axpby.metal │ │ └── axpby.h ├── cpp │ ├── CMakeLists.txt │ ├── timer.h │ ├── 
logistic_regression.cpp │ ├── linear_regression.cpp │ └── tutorial.cpp └── python │ ├── logistic_regression.py │ └── linear_regression.py ├── python ├── mlx │ ├── nn │ │ ├── __init__.py │ │ ├── losses.py │ │ ├── layers │ │ │ ├── containers.py │ │ │ ├── __init__.py │ │ │ ├── embedding.py │ │ │ ├── dropout.py │ │ │ ├── linear.py │ │ │ └── activations.py │ │ └── utils.py │ ├── _reprlib_fix.py │ └── extension.py ├── src │ ├── metal.cpp │ ├── indexing.h │ ├── load.h │ ├── mlx.cpp │ ├── stream.cpp │ ├── CMakeLists.txt │ ├── device.cpp │ └── utils.h ├── tests │ ├── mlx_tests.py │ ├── test_tree.py │ ├── test_optimizers.py │ ├── test_eval.py │ ├── test_fft.py │ └── test_device.py └── README.md ├── .pre-commit-config.yaml ├── benchmarks ├── cpp │ ├── CMakeLists.txt │ ├── compare_devices.cpp │ ├── autograd.cpp │ └── time_utils.h ├── numpy │ ├── time_utils.py │ └── single_ops.py └── python │ ├── time_utils.py │ ├── comparative │ └── README.md │ ├── batch_matmul_bench.py │ ├── single_ops.py │ └── llama_mlx_bench.py ├── tests ├── tests.cpp ├── utils_tests.cpp ├── graph_optimize_tests.cpp ├── CMakeLists.txt ├── device_tests.cpp ├── allocator_tests.cpp ├── load_tests.cpp ├── eval_tests.cpp ├── blas_tests.cpp └── scheduler_tests.cpp ├── .gitignore ├── LICENSE ├── CONTRIBUTING.md ├── mlx.pc.in ├── cmake └── extension.cmake ├── .clang-format └── README.md /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | src/python/_autosummary*/ 2 | -------------------------------------------------------------------------------- /docs/.clang-format: -------------------------------------------------------------------------------- 1 | DisableFormat: true 2 | SortIncludes: Never 3 | -------------------------------------------------------------------------------- 
/mlx/3rdparty/.clang-format: -------------------------------------------------------------------------------- 1 | DisableFormat: true 2 | SortIncludes: Never 3 | -------------------------------------------------------------------------------- /docs/src/cpp/ops.rst: -------------------------------------------------------------------------------- 1 | .. _cpp_ops: 2 | 3 | Operations 4 | ========== 5 | 6 | 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CMakeLists.txt 2 | recursive-include mlx/ * 3 | include python/src/* 4 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/src/_static/mlx_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/mlx/main/docs/src/_static/mlx_logo.png -------------------------------------------------------------------------------- /docs/src/python/nn/module.rst: -------------------------------------------------------------------------------- 1 | mlx.nn.Module 2 | ============= 3 | 4 | .. currentmodule:: mlx.nn 5 | 6 | .. 
autoclass:: Module 7 | :members: 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /examples/extensions/mlx_sample_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import mlx.core as mx 4 | from .mlx_sample_extensions import * 5 | -------------------------------------------------------------------------------- /python/mlx/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from mlx.nn.layers import * 4 | from mlx.nn import losses 5 | from mlx.nn.utils import value_and_grad 6 | -------------------------------------------------------------------------------- /mlx/backend/no_metal/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE 4 | ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 7 | ) 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-clang-format 3 | rev: v14.0.6 4 | hooks: 5 | - id: clang-format 6 | - repo: https://github.com/psf/black 7 | rev: 22.10.0 8 | hooks: 9 | - id: black 10 | -------------------------------------------------------------------------------- /docs/src/python/transforms.rst: -------------------------------------------------------------------------------- 1 
| .. _transforms: 2 | 3 | Transforms 4 | ========== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | eval 12 | grad 13 | value_and_grad 14 | jvp 15 | vjp 16 | vmap 17 | -------------------------------------------------------------------------------- /python/mlx/nn/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import mlx.core as mx 4 | 5 | 6 | def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1): 7 | score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1) 8 | return mx.logsumexp(logits, axis=axis) - score 9 | -------------------------------------------------------------------------------- /docs/src/python/fft.rst: -------------------------------------------------------------------------------- 1 | .. _fft: 2 | 3 | FFT 4 | === 5 | 6 | .. currentmodule:: mlx.core.fft 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | fft 12 | ifft 13 | fft2 14 | ifft2 15 | fftn 16 | ifftn 17 | rfft 18 | irfft 19 | rfft2 20 | irfft2 21 | rfftn 22 | irfftn 23 | -------------------------------------------------------------------------------- /mlx/backend/accelerate/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE 4 | ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp 9 | ) 10 | -------------------------------------------------------------------------------- /mlx/mlx.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/backend/metal/metal.h" 7 | #include "mlx/device.h" 8 | #include "mlx/fft.h" 9 | #include "mlx/ops.h" 10 | #include "mlx/random.h" 11 | #include "mlx/stream.h" 12 | #include "mlx/transforms.h" 13 | #include "mlx/utils.h" 14 | -------------------------------------------------------------------------------- /mlx/backend/common/erf.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | namespace mlx::core { 4 | 5 | /* Approximation to the inverse error function. 6 | * Based on code from: 7 | * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 8 | */ 9 | float erfinv(float a); 10 | 11 | } // namespace mlx::core 12 | -------------------------------------------------------------------------------- /mlx/backend/metal/fft.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/primitives.h" 4 | 5 | namespace mlx::core { 6 | 7 | void FFT::eval_gpu(const std::vector& inputs, array& out) { 8 | auto& in = inputs[0]; 9 | throw std::runtime_error("[FFT] NYI for Metal backend."); 10 | } 11 | 12 | } // namespace mlx::core 13 | -------------------------------------------------------------------------------- /docs/src/python/devices_and_streams.rst: -------------------------------------------------------------------------------- 1 | .. _devices_and_streams: 2 | 3 | Devices and Streams 4 | =================== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | 11 | Device 12 | default_device 13 | set_default_device 14 | Stream 15 | default_stream 16 | new_stream 17 | set_default_stream 18 | -------------------------------------------------------------------------------- /mlx/backend/no_metal/allocator.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/allocator.h" 4 | 5 | namespace mlx::core::allocator { 6 | 7 | Allocator& allocator() { 8 | static CommonAllocator allocator_; 9 | return allocator_; 10 | } 11 | 12 | void* Buffer::raw_ptr() { 13 | return ptr_; 14 | } 15 | 16 | } // namespace mlx::core::allocator 17 | -------------------------------------------------------------------------------- /python/src/metal.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/metal.h" 6 | 7 | namespace py = pybind11; 8 | 9 | using namespace mlx::core; 10 | 11 | void init_metal(py::module_& m) { 12 | py::module_ metal = m.def_submodule("metal", "mlx.metal"); 13 | metal.def("is_available", &metal::is_available); 14 | } 15 | -------------------------------------------------------------------------------- /examples/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(build_example SRCFILE) 2 | get_filename_component(src_name ${SRCFILE} NAME_WE) 3 | set(target "${src_name}") 4 | add_executable(${target} ${SRCFILE}) 5 | target_link_libraries(${target} PRIVATE mlx) 6 | endfunction(build_example) 7 | 8 | build_example(tutorial.cpp) 9 | build_example(linear_regression.cpp) 10 | build_example(logistic_regression.cpp) 11 | -------------------------------------------------------------------------------- /python/src/indexing.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 
Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/array.h" 8 | #include "python/src/utils.h" 9 | 10 | namespace py = pybind11; 11 | using namespace mlx::core; 12 | 13 | array mlx_get_item(const array& src, const py::object& obj); 14 | void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v); 15 | -------------------------------------------------------------------------------- /examples/cpp/timer.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace timer { 8 | 9 | using namespace std::chrono; 10 | 11 | template 12 | inline double seconds(duration x) { 13 | return duration_cast(x).count() / 1e9; 14 | } 15 | 16 | inline auto time() { 17 | return high_resolution_clock::now(); 18 | } 19 | 20 | } // namespace timer 21 | -------------------------------------------------------------------------------- /benchmarks/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(build_benchmark SRCFILE) 2 | get_filename_component(src_name ${SRCFILE} NAME_WE) 3 | set(target "${src_name}") 4 | add_executable(${target} ${SRCFILE}) 5 | target_link_libraries(${target} PRIVATE mlx) 6 | endfunction(build_benchmark) 7 | 8 | build_benchmark(single_ops.cpp) 9 | build_benchmark(irregular_strides.cpp) 10 | build_benchmark(compare_devices.cpp) 11 | build_benchmark(autograd.cpp) 12 | -------------------------------------------------------------------------------- /benchmarks/numpy/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import time 4 | 5 | 6 | def time_fn(fn, *args): 7 | print(f"Timing {fn.__name__} ...", end=" ") 8 | 9 | # warmup 10 | for _ in range(5): 11 | fn(*args) 12 | 13 | num_iters = 100 14 | tic = time.perf_counter() 15 | for _ in range(num_iters): 16 | x = fn(*args) 17 | toc = time.perf_counter() 18 | 19 | msec = 1e3 * (toc - tic) / num_iters 20 | print(f"{msec:.5f} msec") 21 | -------------------------------------------------------------------------------- /mlx/backend/metal/copy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/backend/common/copy.h" 6 | #include "mlx/stream.h" 7 | 8 | namespace mlx::core { 9 | 10 | void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); 11 | void copy_gpu(const array& src, array& out, CopyType ctype); 12 | void copy_gpu_inplace( 13 | const array& src, 14 | array& out, 15 | CopyType ctype, 16 | const Stream& s); 17 | 18 | } // namespace mlx::core 19 | -------------------------------------------------------------------------------- /docs/src/_templates/optimizers-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | {% for item in methods %} 14 | {%- if item not in inherited_members %} 15 | ~{{ name }}.{{ item }} 16 | {%- endif %} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | -------------------------------------------------------------------------------- /python/tests/mlx_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import os 4 | import unittest 5 | 6 | import mlx.core as mx 7 | 8 | 9 | class MLXTestCase(unittest.TestCase): 10 | def setUp(self): 11 | self.default = mx.default_device() 12 | device = os.getenv("DEVICE", None) 13 | if device is not None: 14 | device = getattr(mx, device) 15 | mx.set_default_device(device) 16 | 17 | def tearDown(self): 18 | mx.set_default_device(self.default) 19 | -------------------------------------------------------------------------------- /mlx/backend/accelerate/conv.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include "mlx/backend/common/copy.h" 9 | #include "mlx/primitives.h" 10 | #include "mlx/utils.h" 11 | 12 | namespace mlx::core { 13 | 14 | void Convolution::eval_cpu(const std::vector& inputs, array& out) { 15 | eval(inputs, out); 16 | 17 | // TODO: Add accelerate based optimizations for CPU conv 18 | } 19 | 20 | } // namespace mlx::core 21 | -------------------------------------------------------------------------------- /docs/src/_templates/nn-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {#{% block methods %} 8 | 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | {% for item in methods %} 14 | {%- if item not in inherited_members and item != '__init__' %} 15 | ~{{ name }}.{{ item }} 16 | {%- endif %} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %}#} 20 | -------------------------------------------------------------------------------- /mlx/backend/no_metal/metal.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/metal.h" 6 | 7 | namespace mlx::core::metal { 8 | 9 | void new_stream(Stream) {} 10 | 11 | std::function make_task( 12 | array& arr, 13 | std::vector> deps, 14 | std::shared_ptr> p, 15 | bool retain_graph) { 16 | throw std::runtime_error( 17 | "[metal::make_task] Cannot make GPU task without metal backend"); 18 | } 19 | 20 | } // namespace mlx::core::metal 21 | -------------------------------------------------------------------------------- /benchmarks/python/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import time 4 | 5 | import mlx.core as mx 6 | 7 | 8 | def time_fn(fn, *args, **kwargs): 9 | print(f"Timing {fn.__name__} ...", end=" ") 10 | 11 | # warmup 12 | for _ in range(5): 13 | mx.eval(fn(*args, **kwargs)) 14 | 15 | num_iters = 100 16 | tic = time.perf_counter() 17 | for _ in range(num_iters): 18 | x = mx.eval(fn(*args, **kwargs)) 19 | toc = time.perf_counter() 20 | 21 | msec = 1e3 * (toc - tic) / num_iters 22 | print(f"{msec:.5f} msec") 23 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/defines.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #ifdef __METAL__ 6 | #define MTL_CONST constant 7 | #else 8 | #define MTL_CONST 9 | #endif 10 | 11 | static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5; 12 | static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5; 13 | static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; 14 | static MTL_CONST constexpr int REDUCE_N_READS = 16; 15 | static MTL_CONST constexpr int SOFTMAX_N_READS = 4; 16 | static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096; 17 | -------------------------------------------------------------------------------- /mlx/transforms_impl.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | namespace mlx::core::detail { 4 | 5 | std::pair, std::vector> vmap_trace( 6 | const std::function(const std::vector&)>& fun, 7 | const std::vector& inputs, 8 | const std::vector& in_axes); 9 | 10 | std::vector vmap_replace( 11 | const std::vector& inputs, 12 | const std::vector& s_inputs, 13 | const std::vector& s_outputs, 14 | const std::vector& in_axes, 15 | const std::vector& out_axes); 16 | 17 | } // namespace mlx::core::detail 18 | -------------------------------------------------------------------------------- /python/mlx/_reprlib_fix.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import array 4 | import reprlib 5 | 6 | 7 | class FixedRepr(reprlib.Repr): 8 | """Only route python array instances to repr_array.""" 9 | 10 | def repr_array(self, x, maxlevel): 11 | if isinstance(x, array.array): 12 | return super().repr_array(x, maxlevel) 13 | else: 14 | return self.repr_instance(x, maxlevel) 15 | 16 | 17 | # We need to monkey-patch reprlib so that we can use the debugger without 18 | # renaming the array to something else 19 | fixed_repr = FixedRepr() 20 | reprlib.repr = fixed_repr.repr 21 | -------------------------------------------------------------------------------- /tests/tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #define DOCTEST_CONFIG_IMPLEMENT 4 | #include "doctest/doctest.h" 5 | 6 | #include 7 | 8 | #include "mlx/mlx.h" 9 | 10 | using namespace mlx::core; 11 | 12 | int main(int argc, char** argv) { 13 | doctest::Context context; 14 | 15 | const char* device = std::getenv("DEVICE"); 16 | if (device != nullptr && std::string(device) == "cpu") { 17 | set_default_device(Device::cpu); 18 | } else if (metal::is_available()) { 19 | set_default_device(Device::gpu); 20 | } 21 | 22 | context.applyCommandLine(argc, argv); 23 | return context.run(); 24 | } 25 | -------------------------------------------------------------------------------- /python/src/load.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include "mlx/ops.h" 9 | 10 | namespace py = pybind11; 11 | using namespace mlx::core; 12 | 13 | using DictOrArray = std::variant>; 14 | 15 | DictOrArray mlx_load_helper(py::object file, StreamOrDevice s); 16 | void mlx_save_helper(py::object file, array a, bool retain_graph = true); 17 | void mlx_savez_helper( 18 | py::object file, 19 | py::args args, 20 | const py::kwargs& kwargs, 21 | bool compressed = false); -------------------------------------------------------------------------------- /docs/src/python/tree_utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils: 2 | 3 | Tree Utils 4 | ========== 5 | 6 | In MLX we consider a python tree to be an arbitrarily nested collection of 7 | dictionaries, lists and tuples without cycles. Functions in this module that 8 | return python trees will be using the default python ``dict``, ``list`` and 9 | ``tuple`` but they can usually process objects that inherit from any of these. 10 | 11 | .. note:: 12 | Dictionaries should have keys that are valid python identifiers. 13 | 14 | .. currentmodule:: mlx.utils 15 | 16 | .. autosummary:: 17 | :toctree: _autosummary 18 | 19 | tree_flatten 20 | tree_unflatten 21 | tree_map 22 | -------------------------------------------------------------------------------- /mlx/backend/metal/metal.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/array.h" 10 | #include "mlx/stream.h" 11 | 12 | namespace mlx::core::metal { 13 | 14 | constexpr bool is_available() { 15 | #ifdef _METAL_ 16 | return true; 17 | #else 18 | return false; 19 | #endif 20 | } 21 | 22 | void new_stream(Stream stream); 23 | 24 | std::function make_task( 25 | array& arr, 26 | std::vector> deps, 27 | std::shared_ptr> p, 28 | bool retain_graph); 29 | 30 | } // namespace mlx::core::metal 31 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line. 4 | SPHINXOPTS = 5 | SPHINXBUILD = sphinx-build 6 | SOURCEDIR = src 7 | BUILDDIR = build 8 | 9 | # Put it first so that "make" without argument is like "make help". 10 | help: 11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 12 | 13 | .PHONY: help Makefile 14 | 15 | # Catch-all target: route all unknown targets to Sphinx using the new 16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 17 | %: Makefile 18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | -------------------------------------------------------------------------------- /mlx/backend/common/threefry.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | namespace mlx::core::random { 9 | 10 | /** Applies the Threefry 2x32 hash function. 
11 | * This code is based on the Jax counter-based and splittable PRNG 12 | * https://github.com/google/jax/blob/main/docs/jep/263-prng.md 13 | * 14 | * Original Threefry reference: 15 | * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf 16 | */ 17 | std::pair threefry2x32_hash( 18 | const std::pair& key, 19 | std::pair count); 20 | 21 | } // namespace mlx::core::random 22 | -------------------------------------------------------------------------------- /mlx/device.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | namespace mlx::core { 6 | 7 | struct Device { 8 | enum class DeviceType { 9 | cpu, 10 | gpu, 11 | }; 12 | 13 | static constexpr DeviceType cpu = DeviceType::cpu; 14 | static constexpr DeviceType gpu = DeviceType::gpu; 15 | 16 | Device(DeviceType type, int index = 0) : type(type), index(index){}; 17 | 18 | DeviceType type; 19 | int index; 20 | }; 21 | 22 | const Device& default_device(); 23 | 24 | void set_default_device(const Device& d); 25 | 26 | bool operator==(const Device& lhs, const Device& rhs); 27 | bool operator!=(const Device& lhs, const Device& rhs); 28 | 29 | } // namespace mlx::core 30 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/conv_params.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | template 6 | struct MLXConvParams { 7 | const int N; // Batch size 8 | const int C; // In channels 9 | const int O; // Out channels 10 | const int iS[NDIM]; // Input spatial dim 11 | const int wS[NDIM]; // Weight spatial dim 12 | const int oS[NDIM]; // Output spatial dim 13 | const int str[NDIM]; // Kernel strides 14 | const int pad[NDIM]; // Input padding 15 | const int dil[NDIM]; // Kernel dilation 16 | const size_t in_strides[NDIM + 2]; // In strides 17 | const size_t wt_strides[NDIM + 2]; // Wt strides 18 | const size_t out_strides[NDIM + 2]; // Out strides 19 | }; 20 | -------------------------------------------------------------------------------- /mlx/graph_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | 7 | namespace mlx::core { 8 | 9 | void print_graph(std::ostream& os, const std::vector& outputs); 10 | 11 | template 12 | void print_graph(std::ostream& os, Arrays... outputs) { 13 | print_graph(os, std::vector{std::forward(outputs)...}); 14 | } 15 | 16 | void export_to_dot(std::ostream& os, const std::vector& outputs); 17 | 18 | template 19 | void export_to_dot(std::ostream& os, Arrays... outputs) { 20 | export_to_dot(os, std::vector{std::forward(outputs)...}); 21 | } 22 | 23 | } // namespace mlx::core 24 | -------------------------------------------------------------------------------- /examples/extensions/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | from mlx import extension 4 | from setuptools import setup 5 | 6 | if __name__ == "__main__": 7 | setup( 8 | name="mlx_sample_extensions", 9 | version="0.0.0", 10 | description="Sample C++ and Metal extensions for MLX primitives.", 11 | ext_modules=[extension.CMakeExtension("mlx_sample_extensions")], 12 | cmdclass={"build_ext": extension.CMakeBuild}, 13 | packages=["mlx_sample_extensions"], 14 | package_dir={"": "."}, 15 | package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]}, 16 | zip_safe=False, 17 | python_requires=">=3.7", 18 | ) 19 | -------------------------------------------------------------------------------- /mlx/backend/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE 4 | ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp 9 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp 10 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 11 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp 12 | ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp 13 | ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp 14 | ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp 15 | ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp 16 | ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp 17 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp 18 | ) 19 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | ### Packaging for PyPI 2 | 3 | Install `build` and `twine`: 4 | 5 | ``` 6 | pip install --user --upgrade build 7 | pip install --user --upgrade twine 8 | ``` 9 | 10 | Generate the source distribution and wheel: 11 | 12 | ``` 13 | python -m build 14 | ``` 15 | 16 | *Warning* use a test server first 17 | 18 | #### Test Upload 19 | 20 | Upload to test server: 21 | 22 | ``` 23 | python 
-m twine upload --repository testpypi dist/* 24 | ``` 25 | 26 | Install from test server and check that it works: 27 | 28 | ``` 29 | python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx 30 | ``` 31 | 32 | #### Upload 33 | 34 | ``` 35 | python -m twine upload dist/* 36 | ``` 37 | 38 | -------------------------------------------------------------------------------- /docs/src/using_streams.rst: -------------------------------------------------------------------------------- 1 | Using Streams 2 | ============= 3 | 4 | .. currentmodule:: mlx.core 5 | 6 | Specifying the :obj:`Stream` 7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 8 | 9 | All operations (including random number generation) take an optional 10 | keyword argument ``stream``. The ``stream`` kwarg specifies which 11 | :obj:`Stream` the operation should run on. If the stream is unspecified then 12 | the operation is run on the default stream of the default device: 13 | ``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also 14 | be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is 15 | run on the default stream of the provided device 16 | ``mx.default_stream(my_device)``. 17 | -------------------------------------------------------------------------------- /benchmarks/cpp/compare_devices.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include "mlx/mlx.h" 5 | #include "time_utils.h" 6 | 7 | using namespace mlx::core; 8 | 9 | void time_add_op() { 10 | std::vector sizes(1, 1); 11 | for (int i = 0; i < 9; ++i) { 12 | sizes.push_back(10 * sizes.back()); 13 | } 14 | set_default_device(Device::cpu); 15 | for (auto size : sizes) { 16 | auto a = random::uniform({size}); 17 | auto b = random::uniform({size}); 18 | eval(a, b); 19 | std::cout << "Size " << size << std::endl; 20 | TIMEM("cpu", add, a, b, Device::cpu); 21 | TIMEM("gpu", add, a, b, Device::gpu); 22 | } 23 | } 24 | 25 | int main() { 26 | time_add_op(); 27 | } 28 | -------------------------------------------------------------------------------- /benchmarks/python/comparative/README.md: -------------------------------------------------------------------------------- 1 | Microbenchmarks comparing MLX to PyTorch 2 | ======================================== 3 | 4 | Implement the same microbenchmarks in MLX and PyTorch to compare and make a 5 | list of the biggest possible performance improvements and/or regressions. 6 | 7 | Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for 8 | instance to measure the times it takes to sum across the 3rd axis of the above 9 | tensor on the cpu. 10 | 11 | `compare.py` runs several benchmarks and compares the speed-up or lack thereof 12 | in comparison to PyTorch. 13 | 14 | Each bench script can be run with `--print-pid` to print the PID and wait for a 15 | key in order to ease attaching a debugger. 16 | -------------------------------------------------------------------------------- /tests/utils_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include "mlx/mlx.h" 6 | 7 | using namespace mlx::core; 8 | 9 | TEST_CASE("test type promotion") { 10 | for (auto t : {bool_, uint32, int32, int64, float32}) { 11 | auto a = array(0, t); 12 | CHECK_EQ(result_type({a}), t); 13 | 14 | std::vector arrs = {array(0, t), array(0, t)}; 15 | CHECK_EQ(result_type(arrs), t); 16 | } 17 | 18 | { 19 | std::vector arrs = {array(false), array(0, int32)}; 20 | CHECK_EQ(result_type(arrs), int32); 21 | } 22 | 23 | { 24 | std::vector arrs = {array(0, int32), array(false), array(0.0f)}; 25 | CHECK_EQ(result_type(arrs), float32); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /mlx/backend/metal/matmul.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/backend/metal/copy.h" 8 | #include "mlx/backend/metal/device.h" 9 | #include "mlx/backend/metal/mps/gemm.h" 10 | #include "mlx/backend/metal/utils.h" 11 | #include "mlx/utils.h" 12 | 13 | namespace mlx::core { 14 | 15 | void mlx_matmul( 16 | const Stream& s, 17 | metal::Device& d, 18 | const array& a, 19 | const array& b, 20 | array& out, 21 | int M, 22 | int N, 23 | int K, 24 | int batch_size_out, 25 | int lda, 26 | int ldb, 27 | bool transpose_a, 28 | bool transpose_b, 29 | std::vector& copies); 30 | 31 | } // namespace mlx::core -------------------------------------------------------------------------------- /python/mlx/nn/layers/containers.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from mlx.nn.layers.base import Module 4 | 5 | 6 | class Sequential(Module): 7 | """A layer that calls the passed callables in order. 8 | 9 | We can pass either modules or plain callables to the Sequential module. 
If 10 | our functions have learnable parameters they should be implemented as 11 | ``nn.Module`` instances. 12 | 13 | Args: 14 | modules (tuple of Callables): The modules to call in order 15 | """ 16 | 17 | def __init__(self, *modules): 18 | super().__init__() 19 | self.layers = list(modules) 20 | 21 | def __call__(self, x): 22 | for m in self.layers: 23 | x = m(x) 24 | return x 25 | -------------------------------------------------------------------------------- /mlx/backend/common/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "mlx/array.h" 8 | 9 | namespace mlx::core { 10 | 11 | inline size_t elem_to_loc( 12 | int elem, 13 | const std::vector& shape, 14 | const std::vector& strides) { 15 | size_t loc = 0; 16 | for (int i = shape.size() - 1; i >= 0; --i) { 17 | auto q_and_r = ldiv(elem, shape[i]); 18 | loc += q_and_r.rem * strides[i]; 19 | elem = q_and_r.quot; 20 | } 21 | return loc; 22 | } 23 | 24 | inline size_t elem_to_loc(int elem, const array& a) { 25 | if (a.flags().row_contiguous) { 26 | return elem; 27 | } 28 | return elem_to_loc(elem, a.shape(), a.strides()); 29 | } 30 | 31 | } // namespace mlx::core 32 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Build the Docs 2 | 3 | ### Setup (do once) 4 | 5 | Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html) 6 | for example with `conda`: 7 | 8 | ``` 9 | conda install sphinx 10 | pip install sphinx-book-theme 11 | ``` 12 | 13 | ### Build 14 | 15 | Build the docs from `mlx/docs/` 16 | 17 | ``` 18 | make html 19 | ``` 20 | 21 | View the docs by running a server in `mlx/docs/build/html/`: 22 | 23 | ``` 24 | python -m http.server 25 | ``` 26 | 27 | and point your browser to `http://localhost:<port>`.
28 | 29 | ### Push to Github Pages 30 | 31 | Check-out the `gh-pages` branch (`git switch gh-pages`) and build 32 | the docs. Then force add the `build/html` directory: 33 | 34 | `git add -f build/html` 35 | 36 | Commit and push the changes to the `gh-pages` branch. 37 | -------------------------------------------------------------------------------- /mlx/stream.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/device.h" 6 | 7 | namespace mlx::core { 8 | 9 | struct Stream { 10 | int index; 11 | Device device; 12 | explicit Stream(int index, Device device) : index(index), device(device) {} 13 | }; 14 | 15 | /** Get the default stream for the given device. */ 16 | Stream default_stream(Device d); 17 | 18 | /** Make the stream the default for its device. */ 19 | void set_default_stream(Stream s); 20 | 21 | /** Make a new stream on the given device. */ 22 | Stream new_stream(Device d); 23 | 24 | inline bool operator==(const Stream& lhs, const Stream& rhs) { 25 | return lhs.index == rhs.index; 26 | } 27 | 28 | inline bool operator!=(const Stream& lhs, const Stream& rhs) { 29 | return !(lhs == rhs); 30 | } 31 | 32 | } // namespace mlx::core 33 | -------------------------------------------------------------------------------- /mlx/backend/common/copy.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/backend/common/utils.h" 7 | 8 | namespace mlx::core { 9 | 10 | enum class CopyType { 11 | // Copy a raw scalar input into the full contiguous output 12 | Scalar, 13 | 14 | // Copy the raw input buffer contiguously into a raw output buffer of the same 15 | // size 16 | Vector, 17 | 18 | // Copy the full virtual input to the full contiguous output 19 | General, 20 | 21 | // Copy the full virtual input to the full virtual output. 
We assume the 22 | // input and output have the same shape. 23 | GeneralGeneral 24 | }; 25 | 26 | void copy(const array& src, array& dst, CopyType ctype); 27 | void copy_inplace(const array& src, array& dst, CopyType ctype); 28 | 29 | } // namespace mlx::core 30 | -------------------------------------------------------------------------------- /python/tests/test_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import unittest 4 | 5 | import mlx.core as mx 6 | import mlx.utils 7 | 8 | import mlx_tests 9 | 10 | 11 | class TestTreeUtils(mlx_tests.MLXTestCase): 12 | def test_tree_map(self): 13 | tree = {"a": 0, "b": 1, "c": 2} 14 | tree = mlx.utils.tree_map(lambda x: x + 1, tree) 15 | 16 | expected_tree = {"a": 1, "b": 2, "c": 3} 17 | self.assertEqual(tree, expected_tree) 18 | 19 | def test_tree_flatten(self): 20 | tree = [{"a": 1, "b": 2}, "c"] 21 | vals = (1, 2, "c") 22 | flat_tree = mlx.utils.tree_flatten(tree) 23 | self.assertEqual(list(zip(*flat_tree))[1], vals) 24 | self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree) 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | from mlx.nn.layers.base import Module 4 | from mlx.nn.layers.activations import ( 5 | GELU, 6 | ReLU, 7 | SiLU, 8 | gelu, 9 | gelu_approx, 10 | gelu_fast_approx, 11 | relu, 12 | silu, 13 | ) 14 | from mlx.nn.layers.containers import Sequential 15 | from mlx.nn.layers.convolution import Conv1d, Conv2d 16 | from mlx.nn.layers.dropout import Dropout 17 | from mlx.nn.layers.embedding import Embedding 18 | from mlx.nn.layers.linear import Linear 19 | from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm 20 | from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding 21 | from mlx.nn.layers.transformer import ( 22 | MultiHeadAttention, 23 | TransformerEncoder, 24 | TransformerEncoderLayer, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/graph_optimize_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include "mlx/mlx.h" 6 | 7 | using namespace mlx::core; 8 | 9 | TEST_CASE("test simplify scalars") { 10 | auto a = array({-1.0f, 2.0f}); 11 | auto b = maximum(a, array(0.0f)); 12 | auto c = maximum(-a, array(0.0f)); 13 | auto d = b + c; 14 | simplify({d}); 15 | CHECK(b.inputs()[1].id() == c.inputs()[1].id()); 16 | } 17 | 18 | TEST_CASE("test simplify") { 19 | auto a = array({1.0f, 2.0f}); 20 | auto b = exp(a) + exp(a); 21 | simplify(b); 22 | eval(b); 23 | CHECK(b.inputs()[0].id() == b.inputs()[1].id()); 24 | } 25 | 26 | TEST_CASE("test no simplify") { 27 | auto a = array({1.0f, 2.0f}); 28 | auto b = cos(a) + sin(a); 29 | simplify(b); 30 | eval(b); 31 | CHECK(b.inputs()[0].id() != b.inputs()[1].id()); 32 | } 33 | -------------------------------------------------------------------------------- /docs/src/python/array.rst: -------------------------------------------------------------------------------- 1 | .. _array: 2 | 3 | Array 4 | ===== 5 | 6 | .. 
currentmodule:: mlx.core 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | array 12 | array.astype 13 | array.item 14 | array.tolist 15 | array.dtype 16 | array.ndim 17 | array.shape 18 | array.size 19 | Dtype 20 | array.abs 21 | array.all 22 | array.any 23 | array.argmax 24 | array.argmin 25 | array.cos 26 | array.dtype 27 | array.exp 28 | array.log 29 | array.log1p 30 | array.logsumexp 31 | array.max 32 | array.mean 33 | array.min 34 | array.prod 35 | array.reciprocal 36 | array.reshape 37 | array.rsqrt 38 | array.sin 39 | array.split 40 | array.sqrt 41 | array.square 42 | array.sum 43 | array.transpose 44 | array.T 45 | array.var 46 | -------------------------------------------------------------------------------- /mlx/device.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/device.h" 4 | #include "mlx/backend/metal/metal.h" 5 | 6 | namespace mlx::core { 7 | 8 | static Device default_device_{ 9 | metal::is_available() ? Device::gpu : Device::cpu}; 10 | 11 | const Device& default_device() { 12 | return default_device_; 13 | } 14 | 15 | void set_default_device(const Device& d) { 16 | if (!metal::is_available() && d == Device::gpu) { 17 | throw std::invalid_argument( 18 | "[set_default_device] Cannot set gpu device without gpu backend."); 19 | } 20 | default_device_ = d; 21 | } 22 | 23 | bool operator==(const Device& lhs, const Device& rhs) { 24 | return lhs.type == rhs.type && lhs.index == rhs.index; 25 | } 26 | 27 | bool operator!=(const Device& lhs, const Device& rhs) { 28 | return !(lhs == rhs); 29 | } 30 | 31 | } // namespace mlx::core 32 | -------------------------------------------------------------------------------- /mlx/backend/accelerate/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include "mlx/dtype.h" 7 | 8 | namespace mlx::core { 9 | 10 | // Map an MLX dtype to the corresponding BNNS data type. Defined `inline` 11 | // because this function lives in a header and would otherwise be emitted 12 | // once per including translation unit (duplicate-symbol / ODR hazard). 13 | // Throws std::invalid_argument for complex dtypes. 14 | inline BNNSDataType to_bnns_dtype(Dtype mlx_dtype) { 15 | uint32_t size_bits = size_of(mlx_dtype) * 8; 16 | switch (kindof(mlx_dtype)) { 17 | case Dtype::Kind::b: 18 | return BNNSDataTypeBoolean; 19 | case Dtype::Kind::u: 20 | return BNNSDataType(BNNSDataTypeUIntBit | size_bits); 21 | case Dtype::Kind::i: 22 | return BNNSDataType(BNNSDataTypeIntBit | size_bits); 23 | case Dtype::Kind::f: 24 | return BNNSDataType(BNNSDataTypeFloatBit | size_bits); 25 | case Dtype::Kind::V: 26 | return BNNSDataTypeBFloat16; 27 | case Dtype::Kind::c: 28 | throw std::invalid_argument("BNNS does not support complex types"); 29 | } 30 | // Not reached for the kinds handled above; guards against falling off the 31 | // end of a non-void function if an unexpected kind value is passed. 32 | throw std::invalid_argument("Unsupported dtype kind for BNNS"); 33 | } 34 | 35 | } // namespace mlx::core -------------------------------------------------------------------------------- /python/src/mlx.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #define STRINGIFY(x) #x 6 | #define TOSTRING(x) STRINGIFY(x) 7 | 8 | namespace py = pybind11; 9 | 10 | void init_array(py::module_&); 11 | void init_device(py::module_&); 12 | void init_stream(py::module_&); 13 | void init_metal(py::module_&); 14 | void init_ops(py::module_&); 15 | void init_transforms(py::module_&); 16 | void init_random(py::module_&); 17 | void init_fft(py::module_&); 18 | 19 | PYBIND11_MODULE(core, m) { 20 | m.doc() = "mlx: A framework for machine learning on Apple Silicon."; 21 | 22 | auto reprlib_fix = py::module_::import("mlx._reprlib_fix"); 23 | 24 | init_device(m); 25 | init_stream(m); 26 | init_array(m); 27 | init_metal(m); 28 | init_ops(m); 29 | init_transforms(m); 30 | init_random(m); 31 | init_fft(m); 32 | m.attr("__version__") = TOSTRING(_VERSION_); 33 | } 34 | -------------------------------------------------------------------------------- /mlx/backend/metal/CMakeLists.txt: -------------------------------------------------------------------------------- 1
| target_sources( 2 | mlx 3 | PRIVATE 4 | ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp 9 | ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp 10 | ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp 11 | ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp 12 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 13 | ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp 14 | ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp 15 | ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp 16 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp 17 | ) 18 | 19 | if (NOT MLX_METAL_PATH) 20 | set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) 21 | endif() 22 | 23 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) 24 | 25 | target_compile_definitions( 26 | mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") 27 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | FetchContent_Declare( 2 | doctest 3 | GIT_REPOSITORY "https://github.com/onqtam/doctest" 4 | GIT_TAG "b7c21ec5ceeadb4951b00396fc1e4642dd347e5f" 5 | ) 6 | FetchContent_MakeAvailable(doctest) 7 | 8 | add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) 9 | 10 | if (MLX_BUILD_METAL) 11 | set( 12 | METAL_TEST_SOURCES 13 | metal_tests.cpp 14 | ) 15 | endif() 16 | 17 | target_sources(tests PRIVATE 18 | allocator_tests.cpp 19 | array_tests.cpp 20 | arg_reduce_tests.cpp 21 | autograd_tests.cpp 22 | blas_tests.cpp 23 | creations_tests.cpp 24 | device_tests.cpp 25 | eval_tests.cpp 26 | fft_tests.cpp 27 | graph_optimize_tests.cpp 28 | load_tests.cpp 29 | ops_tests.cpp 30 | random_tests.cpp 31 | scheduler_tests.cpp 32 | utils_tests.cpp 33 | vmap_tests.cpp 34 | ${METAL_TEST_SOURCES} 35 | ) 36 | 37 | target_link_libraries(tests PRIVATE mlx doctest) 38 | add_test(NAME tests COMMAND tests) 39 | 
-------------------------------------------------------------------------------- /mlx/backend/common/threefry.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "mlx/backend/common/threefry.h" 4 | 5 | namespace mlx::core::random { 6 | 7 | std::pair threefry2x32_hash( 8 | const std::pair& key, 9 | std::pair count) { 10 | constexpr static uint32_t rotations[2][4] = { 11 | {13, 15, 26, 6}, {17, 29, 16, 24}}; 12 | 13 | uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA}; 14 | 15 | count.first += ks[0]; 16 | count.second += ks[1]; 17 | 18 | for (int i = 0; i < 5; ++i) { 19 | for (auto r : rotations[i % 2]) { 20 | count.first += count.second; 21 | count.second = (count.second << r) | (count.second >> (32 - r)); 22 | count.second ^= count.first; 23 | } 24 | count.first += ks[(i + 1) % 3]; 25 | count.second += ks[(i + 2) % 3] + i + 1; 26 | } 27 | 28 | return count; 29 | } 30 | 31 | } // namespace mlx::core::random 32 | -------------------------------------------------------------------------------- /python/tests/test_optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import unittest 4 | 5 | import mlx.core as mx 6 | import mlx.optimizers as opt 7 | import mlx.utils 8 | 9 | import mlx_tests 10 | 11 | 12 | class TestOptimizers(mlx_tests.MLXTestCase): 13 | def test_optimizers(self): 14 | params = { 15 | "first": [mx.zeros((10,)), mx.zeros((1,))], 16 | "second": mx.zeros((1,)), 17 | } 18 | grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params) 19 | 20 | for optim in [opt.SGD(0.1), opt.Adam(0.1)]: 21 | update = optim.apply_gradients(grads, params) 22 | mx.eval(update) 23 | equal_shape = mlx.utils.tree_map( 24 | lambda x, y: x.shape == y.shape, params, update 25 | ) 26 | all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape)) 27 | self.assertTrue(all_equal) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/device_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include 6 | 7 | #include "mlx/mlx.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test device placement") { 12 | auto device = default_device(); 13 | Device d = metal::is_available() ? 
Device::gpu : Device::cpu; 14 | if (std::getenv("DEVICE") == nullptr) { 15 | CHECK_EQ(device, d); 16 | } 17 | 18 | array x(1.0f); 19 | array y(1.0f); 20 | auto z = add(x, y, default_device()); 21 | if (metal::is_available()) { 22 | z = add(x, y, Device::gpu); 23 | z = add(x, y, Device(Device::gpu, 0)); 24 | } else { 25 | CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument); 26 | CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument); 27 | } 28 | 29 | // Set the default device to the CPU 30 | set_default_device(Device::cpu); 31 | CHECK_EQ(default_device(), Device::cpu); 32 | 33 | // Revert 34 | set_default_device(device); 35 | } 36 | -------------------------------------------------------------------------------- /python/tests/test_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from functools import partial 4 | 5 | import unittest 6 | 7 | import mlx.core as mx 8 | 9 | import mlx_tests 10 | 11 | 12 | class TestEval(mlx_tests.MLXTestCase): 13 | def test_eval(self): 14 | arrs = [mx.ones((2, 2)) for _ in range(4)] 15 | mx.eval(*arrs) 16 | for x in arrs: 17 | self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) 18 | 19 | def test_retain_graph(self): 20 | def fun(x, retain_graph): 21 | y = 3 * x 22 | mx.eval(y, retain_graph=retain_graph) 23 | return 2 * y 24 | 25 | dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) 26 | dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) 27 | 28 | with self.assertRaises(ValueError): 29 | dfun_dx_1(mx.array(1.0)) 30 | 31 | y = dfun_dx_2(mx.array(1.0)) 32 | self.assertEqual(y.item(), 6.0) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /benchmarks/numpy/single_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import numpy as np 4 | 5 | from time_utils import time_fn 6 | 7 | 8 | def time_add(): 9 | a = np.ones((100, 100, 10), dtype=np.float32) 10 | b = np.ones((100, 100, 10), dtype=np.float32) 11 | time_fn(np.add, a, b) 12 | 13 | 14 | def time_matmul(): 15 | a = np.random.rand(1000, 500).astype(np.float32) 16 | b = np.random.rand(500, 1000).astype(np.float32) 17 | time_fn(np.matmul, a, b) 18 | 19 | 20 | def time_exp(): 21 | a = np.random.randn(1000, 100).astype(np.float32) 22 | time_fn(np.exp, a) 23 | 24 | 25 | def time_take(): 26 | a = np.random.rand(10000, 500) 27 | ids = np.random.randint(0, 10000, (20, 10)) 28 | ids = [idx.reshape(-1) for idx in np.split(ids, 20)] 29 | 30 | def random_take(): 31 | return [np.take(a, idx, 0) for idx in ids] 32 | 33 | time_fn(random_take) 34 | 35 | 36 | if __name__ == "__main__": 37 | time_add() 38 | time_matmul() 39 | time_exp() 40 | time_take() 41 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import math 4 | 5 | import mlx.core as mx 6 | from mlx.nn.layers.base import Module 7 | 8 | 9 | class Embedding(Module): 10 | """Implements a simple lookup table that maps each input integer to a 11 | high-dimensional vector. 12 | 13 | Typically used to embed discrete tokens for processing by neural networks. 14 | 15 | Args: 16 | num_embeddings (int): How many possible discrete tokens can we embed. 17 | Usually called the vocabulary size. 18 | dims (int): The dimensionality of the embeddings. 
19 | """ 20 | 21 | def __init__(self, num_embeddings: int, dims: int): 22 | super().__init__() 23 | scale = math.sqrt(1 / dims) 24 | self.weight = mx.random.normal((num_embeddings, dims)) * scale 25 | 26 | def _extra_repr(self): 27 | return f"{self.weight.shape[0]}, {self.weight.shape[1]}" 28 | 29 | def __call__(self, x): 30 | return self.weight[x] 31 | -------------------------------------------------------------------------------- /benchmarks/cpp/autograd.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/mlx.h" 6 | #include "time_utils.h" 7 | 8 | using namespace mlx::core; 9 | 10 | void time_value_and_grad() { 11 | auto x = ones({200, 1000}); 12 | eval(x); 13 | auto fn = [](array x) { 14 | for (int i = 0; i < 20; ++i) { 15 | x = log(exp(x)); 16 | } 17 | return sum(x); 18 | }; 19 | 20 | auto grad_fn = grad(fn); 21 | auto independent_value_and_grad = [&]() { 22 | auto value = fn(x); 23 | auto dfdx = grad_fn(x); 24 | return std::vector{value, dfdx}; 25 | }; 26 | TIME(independent_value_and_grad); 27 | 28 | auto value_and_grad_fn = value_and_grad(fn); 29 | auto combined_value_and_grad = [&]() { 30 | auto [value, dfdx] = value_and_grad_fn(x); 31 | return std::vector{value, dfdx}; 32 | }; 33 | TIME(combined_value_and_grad); 34 | } 35 | 36 | int main() { 37 | std::cout << "Benchmarks for " << default_device() << std::endl; 38 | time_value_and_grad(); 39 | } 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Metal libraries 10 | *.metallib 11 | 12 | # Distribution / packaging 13 | python/mlx/share 14 | python/mlx/include 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | 
.eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # vim 35 | *.swp 36 | 37 | # Ignore build dir 38 | build/ 39 | 40 | # Prerequisites 41 | *.d 42 | 43 | # Compiled Object files 44 | *.slo 45 | *.lo 46 | *.o 47 | *.obj 48 | 49 | # Precompiled Headers 50 | *.gch 51 | *.pch 52 | 53 | # Compiled Dynamic libraries 54 | *.so 55 | *.dylib 56 | *.dll 57 | 58 | # Fortran module files 59 | *.mod 60 | *.smod 61 | 62 | # Compiled Static libraries 63 | *.lai 64 | *.la 65 | *.a 66 | *.lib 67 | 68 | # Executables 69 | *.exe 70 | *.out 71 | *.app 72 | 73 | # VSCode 74 | .vscode/ 75 | .DS_Store 76 | -------------------------------------------------------------------------------- /python/src/stream.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "mlx/stream.h" 8 | #include "mlx/utils.h" 9 | 10 | namespace py = pybind11; 11 | using namespace py::literals; 12 | using namespace mlx::core; 13 | 14 | void init_stream(py::module_& m) { 15 | py::class_(m, "Stream") 16 | .def(py::init(), "index"_a, "device"_a) 17 | .def_readonly("device", &Stream::device) 18 | .def( 19 | "__repr__", 20 | [](const Stream& s) { 21 | std::ostringstream os; 22 | os << s; 23 | return os.str(); 24 | }) 25 | .def("__eq__", [](const Stream& s1, const Stream& s2) { 26 | return s1 == s2; 27 | }); 28 | 29 | py::implicitly_convertible(); 30 | 31 | m.def("default_stream", &default_stream, "device"_a); 32 | m.def("set_default_stream", &set_default_stream, "stream"_a); 33 | m.def("new_stream", &new_stream, "device"_a); 34 | } 35 | -------------------------------------------------------------------------------- /python/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | pybind11_add_module( 2 | core 3 | 
${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp 4 | ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp 9 | ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp 10 | ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp 11 | ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp 12 | ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp 13 | ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp 14 | ) 15 | 16 | if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) 17 | set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) 18 | endif() 19 | 20 | set_target_properties( 21 | core 22 | PROPERTIES 23 | LIBRARY_OUTPUT_DIRECTORY 24 | ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY} 25 | ) 26 | 27 | target_link_libraries(core PRIVATE mlx) 28 | target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) 29 | 30 | if(BUILD_SHARED_LIBS) 31 | target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) 32 | endif() 33 | -------------------------------------------------------------------------------- /examples/python/logistic_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import mlx.core as mx 4 | import time 5 | 6 | num_features = 100 7 | num_examples = 1_000 8 | num_iters = 10_000 9 | lr = 0.1 10 | 11 | # True parameters 12 | w_star = mx.random.normal((num_features,)) 13 | 14 | # Input examples 15 | X = mx.random.normal((num_examples, num_features)) 16 | 17 | # Labels 18 | y = (X @ w_star) > 0 19 | 20 | 21 | # Initialize random parameters 22 | w = 1e-2 * mx.random.normal((num_features,)) 23 | 24 | 25 | def loss_fn(w): 26 | logits = X @ w 27 | return mx.mean(mx.logaddexp(0.0, logits) - y * logits) 28 | 29 | 30 | grad_fn = mx.grad(loss_fn) 31 | 32 | tic = time.time() 33 | for _ in range(num_iters): 34 | grad = grad_fn(w) 35 | w = w - lr * grad 36 | mx.eval(w) 37 | 38 | toc = time.time() 39 | 40 | loss = loss_fn(w) 41 | final_preds = (X @ w) > 0 42 | acc = mx.mean(final_preds == y) 43 | 44 | throughput = num_iters / (toc - tic) 45 | print( 46 | f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} " 47 | f"Throughput {throughput:.5f} (it/s)" 48 | ) 49 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/arange.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/backend/metal/kernels/bf16.h" 4 | 5 | template 6 | [[kernel]] void arange( 7 | constant const T& start, 8 | constant const T& step, 9 | device T* out, 10 | uint index [[thread_position_in_grid]]) { 11 | out[index] = start + index * step; 12 | } 13 | 14 | #define instantiate_arange(tname, type) \ 15 | template [[host_name("arange" #tname)]] \ 16 | [[kernel]] void arange( \ 17 | constant const type& start, \ 18 | constant const type& step, \ 19 | device type* out, \ 20 | uint index [[thread_position_in_grid]]); 21 | 22 | instantiate_arange(uint8, uint8_t) 23 | instantiate_arange(uint16, uint16_t) 24 | instantiate_arange(uint32, uint32_t) 25 | instantiate_arange(uint64, uint64_t) 26 | instantiate_arange(int8, int8_t) 27 | instantiate_arange(int16, int16_t) 28 | instantiate_arange(int32, int32_t) 29 | instantiate_arange(int64, int64_t) 30 | instantiate_arange(float16, half) 31 | instantiate_arange(float32, float) 32 | instantiate_arange(bfloat16, bfloat16_t) -------------------------------------------------------------------------------- /examples/python/linear_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import mlx.core as mx 4 | import time 5 | 6 | num_features = 100 7 | num_examples = 1_000 8 | num_iters = 10_000 9 | lr = 0.01 10 | 11 | # True parameters 12 | w_star = mx.random.normal((num_features,)) 13 | 14 | # Input examples (design matrix) 15 | X = mx.random.normal((num_examples, num_features)) 16 | 17 | # Noisy labels 18 | eps = 1e-2 * mx.random.normal((num_examples,)) 19 | y = X @ w_star + eps 20 | 21 | # Initialize random parameters 22 | w = 1e-2 * mx.random.normal((num_features,)) 23 | 24 | 25 | def loss_fn(w): 26 | return 0.5 * mx.mean(mx.square(X @ w - y)) 27 | 28 | 29 | grad_fn = mx.grad(loss_fn) 30 | 31 | tic = time.time() 32 | for _ in range(num_iters): 33 | grad = grad_fn(w) 34 | w = w - lr * grad 35 | mx.eval(w) 36 | toc = time.time() 37 | 38 | loss = loss_fn(w) 39 | error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5 40 | throughput = num_iters / (toc - tic) 41 | 42 | print( 43 | f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, " 44 | f"Throughput {throughput:.5f} (it/s)" 45 | ) 46 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import mlx.core as mx 4 | from mlx.nn.layers.base import Module 5 | 6 | 7 | class Dropout(Module): 8 | """Randomly zero a portion of the elements during training. 9 | 10 | The remaining elements are multiplied with :math:`\frac{1}{1-p}` where 11 | :math:`p` is the probability of zeroing an element. This is done so the 12 | expected value of a given element will remain the same. 
13 | 14 | Args: 15 | p (float): The probability to zero an element 16 | """ 17 | 18 | def __init__(self, p: float = 0.5): 19 | super().__init__() 20 | 21 | if p < 0 or p >= 1: 22 | raise ValueError("The dropout probability should be in [0, 1)") 23 | 24 | self._p_1 = 1 - p 25 | 26 | def _extra_repr(self): 27 | return f"p={1-self._p_1}" 28 | 29 | def __call__(self, x): 30 | if self._p_1 == 1 or not self.training: 31 | return x 32 | 33 | mask = mx.random.bernoulli(self._p_1, x.shape) 34 | 35 | return (1 / self._p_1) * mask.astype(x.dtype) * x 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2023 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /python/mlx/nn/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | from typing import Callable 4 | 5 | import mlx.core as mx 6 | 7 | 8 | def value_and_grad(model: "mlx.nn.Module", fn: Callable): 9 | """Transform the passed function ``fn`` to a function that computes the 10 | gradients of ``fn`` wrt the model's trainable parameters and also its 11 | value. 12 | 13 | Args: 14 | model (mlx.nn.Module): The model whose trainable parameters to compute 15 | gradients for 16 | fn (Callable): The scalar function to compute gradients for 17 | 18 | Returns: 19 | A callable that returns the value of ``fn`` and the gradients wrt the 20 | trainable parameters of ``model`` 21 | """ 22 | 23 | def inner_fn(params, *args, **kwargs): 24 | model.update(params) 25 | return fn(*args, **kwargs) 26 | 27 | value_grad_fn = mx.value_and_grad(inner_fn) 28 | 29 | def wrapped_value_grad_fn(*args, **kwargs): 30 | value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs) 31 | return value, grad 32 | 33 | return wrapped_value_grad_fn 34 | -------------------------------------------------------------------------------- /examples/extensions/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | #include "axpby/axpby.h" 7 | 8 | namespace py = pybind11; 9 | using namespace py::literals; 10 | using namespace mlx::core; 11 | 12 | PYBIND11_MODULE(mlx_sample_extensions, m) { 13 | m.doc() = "Sample C++ and metal extensions for MLX"; 14 | 15 | m.def( 16 | "axpby", 17 | &axpby, 18 | "x"_a, 19 | "y"_a, 20 | py::pos_only(), 21 | "alpha"_a, 22 | "beta"_a, 23 | py::kw_only(), 24 | "stream"_a = py::none(), 25 | R"pbdoc( 26 | Scale and sum two vectors elementwise 27 | ``z = alpha * x + beta * y`` 28 | 29 | Follows numpy style broadcasting between ``x`` and ``y`` 30 | Inputs are upcasted to floats if needed 31 | 32 | Args: 33 | x (array): Input array. 34 | y (array): Input array. 35 | alpha (float): Scaling factor for ``x``. 36 | beta (float): Scaling factor for ``y``. 37 | 38 | Returns: 39 | array: ``alpha * x + beta * y`` 40 | )pbdoc"); 41 | } -------------------------------------------------------------------------------- /tests/allocator_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "doctest/doctest.h" 6 | 7 | #include "mlx/allocator.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test simple allocations") { 12 | { 13 | auto buffer = allocator::malloc(sizeof(float)); 14 | auto fptr = static_cast(buffer.raw_ptr()); 15 | *fptr = 0.5f; 16 | CHECK_EQ(*fptr, 0.5f); 17 | allocator::free(buffer); 18 | } 19 | 20 | { 21 | auto buffer = allocator::malloc(128 * sizeof(int)); 22 | int* ptr = static_cast(buffer.raw_ptr()); 23 | for (int i = 0; i < 128; ++i) { 24 | ptr[i] = i; 25 | } 26 | allocator::free(buffer); 27 | } 28 | 29 | { 30 | auto buffer = allocator::malloc(0); 31 | allocator::free(buffer); 32 | } 33 | } 34 | 35 | TEST_CASE("test large allocations") { 36 | size_t size = 1 << 30; 37 | for (int i = 0; i < 100; ++i) { 38 | auto buffer = allocator::malloc(size); 39 | allocator::free(buffer); 40 | } 41 | // Shouldn't be able to allocate an exabyte anytime soon. 42 | CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error); 43 | } 44 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import math 4 | 5 | import mlx.core as mx 6 | from mlx.nn.layers.base import Module 7 | 8 | 9 | class Linear(Module): 10 | """Applies an affine transformation to the input. 
11 | 12 | Args: 13 | input_dims (int): The dimensionality of the input features 14 | output_dims (int): The dimensionality of the output features 15 | bias (bool): If set to False then the layer will not use a bias 16 | """ 17 | 18 | def __init__(self, input_dims: int, output_dims: int, bias: bool = True): 19 | super().__init__() 20 | scale = math.sqrt(1 / input_dims) 21 | self.weight = mx.random.uniform( 22 | low=-scale, 23 | high=scale, 24 | shape=(output_dims, input_dims), 25 | ) 26 | if bias: 27 | self.bias = mx.zeros((output_dims,)) 28 | 29 | def _extra_repr(self): 30 | return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}" 31 | 32 | def __call__(self, x): 33 | x = x @ self.weight.T 34 | if "bias" in self: 35 | x = x + self.bias 36 | return x 37 | -------------------------------------------------------------------------------- /docs/src/python/random.rst: -------------------------------------------------------------------------------- 1 | .. _random: 2 | 3 | Random 4 | ====== 5 | 6 | Random sampling functions in MLX use an implicit global PRNG state by default. 7 | However, all functions take an optional ``key`` keyword argument for when more 8 | fine-grained control or explicit state management is needed. 9 | 10 | For example, you can generate random numbers with: 11 | 12 | .. code-block:: python 13 | 14 | for _ in range(3): 15 | print(mx.random.uniform()) 16 | 17 | which will print a sequence of unique pseudo random numbers. Alternatively you 18 | can explicitly set the key: 19 | 20 | .. code-block:: python 21 | 22 | key = mx.random.key(0) 23 | for _ in range(3): 24 | print(mx.random.uniform(key=key)) 25 | 26 | which will yield the same pseudo random number at each iteration. 27 | 28 | Following `JAX's PRNG design `_ 29 | we use a splittable version of Threefry, which is a counter-based PRNG. 30 | 31 | .. currentmodule:: mlx.core.random 32 | 33 | ..
autosummary:: 34 | :toctree: _autosummary 35 | 36 | seed 37 | key 38 | split 39 | bernoulli 40 | categorical 41 | gumbel 42 | normal 43 | randint 44 | uniform 45 | truncated_normal 46 | -------------------------------------------------------------------------------- /mlx/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | target_sources( 2 | mlx 3 | PRIVATE 4 | ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp 6 | ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp 7 | ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp 8 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp 9 | ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp 10 | ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp 11 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp 12 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp 13 | ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp 14 | ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp 15 | ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp 16 | ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp 17 | ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h 18 | ) 19 | 20 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) 21 | 22 | if (MLX_BUILD_ACCELERATE) 23 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate) 24 | else() 25 | target_sources( 26 | mlx 27 | PRIVATE 28 | ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp 29 | ) 30 | endif() 31 | 32 | if (MLX_BUILD_METAL) 33 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) 34 | else() 35 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) 36 | endif() 37 | -------------------------------------------------------------------------------- /mlx/allocator.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/allocator.h" 7 | #include "mlx/scheduler.h" 8 | 9 | namespace mlx::core::allocator { 10 | 11 | Buffer malloc(size_t size) { 12 | auto buffer = allocator().malloc(size); 13 | if (size && !buffer.ptr()) { 14 | std::ostringstream msg; 15 | msg << "[malloc] Unable to allocate " << size << " bytes."; 16 | throw std::runtime_error(msg.str()); 17 | } 18 | return buffer; 19 | } 20 | 21 | void free(Buffer buffer) { 22 | return allocator().free(buffer); 23 | } 24 | 25 | Buffer CommonAllocator::malloc(size_t size) { 26 | return Buffer{std::malloc(size)}; 27 | } 28 | 29 | void CommonAllocator::free(Buffer buffer) { 30 | std::free(buffer.raw_ptr()); 31 | } 32 | 33 | Buffer malloc_or_wait(size_t size) { 34 | auto buffer = allocator().malloc(size); 35 | 36 | while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) { 37 | scheduler::wait_for_one(); 38 | buffer = allocator().malloc(size); 39 | } 40 | 41 | if (size && !buffer.ptr()) { 42 | std::ostringstream msg; 43 | msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; 44 | throw std::runtime_error(msg.str()); 45 | } 46 | 47 | return buffer; 48 | } 49 | 50 | } // namespace mlx::core::allocator 51 | -------------------------------------------------------------------------------- /mlx/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "array.h" 6 | #include "device.h" 7 | #include "dtype.h" 8 | #include "stream.h" 9 | 10 | namespace mlx::core { 11 | 12 | /** The type from promoting the arrays' types with one another. 
*/ 13 | Dtype result_type(const std::vector& arrays); 14 | 15 | std::vector broadcast_shapes( 16 | const std::vector& s1, 17 | const std::vector& s2); 18 | 19 | std::ostream& operator<<(std::ostream& os, const Device& d); 20 | std::ostream& operator<<(std::ostream& os, const Stream& s); 21 | std::ostream& operator<<(std::ostream& os, const Dtype& d); 22 | std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); 23 | std::ostream& operator<<(std::ostream& os, array a); 24 | std::ostream& operator<<(std::ostream& os, const std::vector& v); 25 | std::ostream& operator<<(std::ostream& os, const std::vector& v); 26 | inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { 27 | return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; // ">=": a zero imaginary part prints as "1+0j" rather than the ambiguous "10j" 28 | } 29 | inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { 30 | return os << static_cast(v); 31 | } 32 | inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { 33 | return os << static_cast(v); 34 | } 35 | } // namespace mlx::core 36 | -------------------------------------------------------------------------------- /python/src/device.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "mlx/device.h" 8 | #include "mlx/utils.h" 9 | 10 | namespace py = pybind11; 11 | using namespace py::literals; 12 | using namespace mlx::core; 13 | 14 | void init_device(py::module_& m) { 15 | py::enum_(m, "DeviceType") 16 | .value("cpu", Device::DeviceType::cpu) 17 | .value("gpu", Device::DeviceType::gpu) 18 | .export_values() 19 | .def( 20 | "__eq__", 21 | [](const Device::DeviceType& d1, const Device& d2) { 22 | return d1 == d2; 23 | }, 24 | py::prepend()); 25 | 26 | py::class_(m, "Device") 27 | .def(py::init(), "type"_a, "index"_a = 0) 28 | .def_readonly("type", &Device::type) 29 | .def( 30 | "__repr__", 31 | [](const Device& d) { 32 | std::ostringstream os; 33 | os << d; 34 | return os.str(); 35 | }) 36 | .def("__eq__", [](const Device& d1, const Device& d2) { 37 | return d1 == d2; 38 | }); 39 | 40 | py::implicitly_convertible(); 41 | 42 | m.def("default_device", &default_device); 43 | m.def("set_default_device", &set_default_device, "device"_a); 44 | } 45 | -------------------------------------------------------------------------------- /docs/src/python/data_types.rst: -------------------------------------------------------------------------------- 1 | .. _data_types: 2 | 3 | :orphan: 4 | 5 | Data Types 6 | ========== 7 | 8 | .. currentmodule:: mlx.core 9 | 10 | The default floating point type is ``float32`` and the default integer type is 11 | ``int32``. The table below shows supported values for :obj:`Dtype`. 12 | 13 | .. 
list-table:: Supported Data Types 14 | :widths: 5 3 20 15 | :header-rows: 1 16 | 17 | * - Type 18 | - Bytes 19 | - Description 20 | * - ``bool_`` 21 | - 1 22 | - Boolean (``True``, ``False``) data type 23 | * - ``uint8`` 24 | - 1 25 | - 8-bit unsigned integer 26 | * - ``uint16`` 27 | - 2 28 | - 16-bit unsigned integer 29 | * - ``uint32`` 30 | - 4 31 | - 32-bit unsigned integer 32 | * - ``uint64`` 33 | - 8 34 | - 64-bit unsigned integer 35 | * - ``int8`` 36 | - 1 37 | - 8-bit signed integer 38 | * - ``int16`` 39 | - 2 40 | - 16-bit signed integer 41 | * - ``int32`` 42 | - 4 43 | - 32-bit signed integer 44 | * - ``int64`` 45 | - 8 46 | - 64-bit signed integer 47 | * - ``float16`` 48 | - 2 49 | - 16-bit float, only available with `ARM C language extensions `_ 50 | * - ``float32`` 51 | - 4 52 | - 32-bit float 53 | -------------------------------------------------------------------------------- /mlx/scheduler.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/scheduler.h" 4 | #include "mlx/backend/metal/metal.h" 5 | 6 | namespace mlx::core { 7 | 8 | Stream default_stream(Device d) { 9 | if (!metal::is_available() && d == Device::gpu) { 10 | throw std::invalid_argument( 11 | "[default_stream] Cannot get gpu stream without gpu backend."); 12 | } 13 | return scheduler::scheduler().get_default_stream(d); 14 | } 15 | 16 | void set_default_stream(Stream s) { 17 | if (!metal::is_available() && s.device == Device::gpu) { 18 | throw std::invalid_argument( 19 | "[set_default_stream] Cannot set gpu stream without gpu backend."); 20 | } 21 | return scheduler::scheduler().set_default_stream(s); 22 | } 23 | 24 | Stream new_stream(Device d) { 25 | if (!metal::is_available() && d == Device::gpu) { 26 | throw std::invalid_argument( 27 | "[new_stream] Cannot make gpu stream without gpu backend."); 28 | } 29 | return scheduler::scheduler().new_stream(d); 30 | } 31 | 32 | Stream new_stream() { 33 | return scheduler::scheduler().new_stream(default_device()); 34 | } 35 | 36 | namespace scheduler { 37 | 38 | /** A singleton scheduler to manage devices, streams, and task execution. */ 39 | Scheduler& scheduler() { 40 | static Scheduler scheduler; 41 | return scheduler; 42 | } 43 | 44 | } // namespace scheduler 45 | } // namespace mlx::core 46 | -------------------------------------------------------------------------------- /docs/src/python/ops.rst: -------------------------------------------------------------------------------- 1 | .. _ops: 2 | 3 | Operations 4 | ========== 5 | 6 | .. currentmodule:: mlx.core 7 | 8 | .. 
autosummary:: 9 | :toctree: _autosummary 10 | 11 | abs 12 | add 13 | all 14 | allclose 15 | any 16 | arange 17 | arccos 18 | arccosh 19 | arcsin 20 | arcsinh 21 | arctan 22 | arctanh 23 | argmax 24 | argmin 25 | argpartition 26 | argsort 27 | array_equal 28 | broadcast_to 29 | concatenate 30 | convolve 31 | conv1d 32 | conv2d 33 | cos 34 | cosh 35 | divide 36 | equal 37 | erf 38 | erfinv 39 | exp 40 | expand_dims 41 | full 42 | greater 43 | greater_equal 44 | less 45 | less_equal 46 | load 47 | log 48 | log2 49 | log10 50 | log1p 51 | logaddexp 52 | logical_not 53 | logsumexp 54 | matmul 55 | max 56 | maximum 57 | mean 58 | min 59 | minimum 60 | multiply 61 | negative 62 | ones 63 | ones_like 64 | partition 65 | pad 66 | prod 67 | reciprocal 68 | reshape 69 | rsqrt 70 | save 71 | savez 72 | savez_compressed 73 | sigmoid 74 | sign 75 | sin 76 | sinh 77 | softmax 78 | sort 79 | split 80 | sqrt 81 | square 82 | squeeze 83 | stop_gradient 84 | subtract 85 | sum 86 | take 87 | take_along_axis 88 | tan 89 | tanh 90 | transpose 91 | var 92 | where 93 | zeros 94 | zeros_like 95 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | 1. Fork and submit pull requests to the repo. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If a change is likely to impact efficiency, run some of the benchmarks before 11 | and after the change. Examples of benchmarks can be found in `benchmarks/python/`. 12 | 4. If you've changed APIs, update the documentation. 13 | 5. Every PR should have passing tests and at least one review. 14 | 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 
15 | This should install hooks for running `black` and `clang-format` to ensure 16 | consistent style for C++ and python code. 17 | 18 | You can also run the formatters manually as follows: 19 | 20 | ``` 21 | clang-format -i file.cpp 22 | ``` 23 | 24 | ``` 25 | black file.py 26 | ``` 27 | 28 | or run `pre-commit run --all-files` to check all files in the repo. 29 | 30 | ## Issues 31 | 32 | We use GitHub issues to track public bugs. Please ensure your description is 33 | clear and has sufficient instructions to be able to reproduce the issue. 34 | 35 | ## License 36 | 37 | By contributing to MLX, you agree that your contributions will be licensed 38 | under the LICENSE file in the root directory of this source tree. 39 | -------------------------------------------------------------------------------- /mlx/backend/common/load.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/allocator.h" 8 | #include "mlx/load.h" 9 | #include "mlx/primitives.h" 10 | 11 | namespace mlx::core { 12 | 13 | namespace { 14 | 15 | template 16 | void swap_endianess(uint8_t* data_bytes, size_t N) { 17 | struct Elem { 18 | uint8_t bytes[scalar_size]; 19 | }; 20 | 21 | Elem* data = reinterpret_cast(data_bytes); 22 | 23 | for (size_t i = 0; i < N; i++) { 24 | for (size_t j = 0; j < (scalar_size / 2); j++) { 25 | std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); 26 | } 27 | } 28 | } 29 | 30 | } // namespace 31 | 32 | void Load::eval(const std::vector& inputs, array& out) { 33 | assert(inputs.size() == 0); 34 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 35 | 36 | reader_->seek(offset_, std::ios_base::beg); 37 | reader_->read(out.data(), out.nbytes()); 38 | 39 | if (swap_endianness_) { 40 | switch (out.itemsize()) { 41 | case 2: 42 | swap_endianess<2>(out.data(), out.data_size()); 43 | break; 44 | case 4: 45 | 
swap_endianess<4>(out.data(), out.data_size()); 46 | break; 47 | case 8: 48 | swap_endianess<8>(out.data(), out.data_size()); 49 | break; 50 | } 51 | } 52 | } 53 | 54 | } // namespace mlx::core 55 | -------------------------------------------------------------------------------- /docs/src/python/optimizers.rst: -------------------------------------------------------------------------------- 1 | .. _optimizers: 2 | 3 | Optimizers 4 | ========== 5 | 6 | The optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure 7 | :mod:`mlx.core` functions. A typical example involves calling 8 | :meth:`Optimizer.update` to update a model's parameters based on the loss 9 | gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the 10 | model's parameters and the **optimizer state**. 11 | 12 | .. code-block:: python 13 | 14 | # Create a model 15 | model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) 16 | mx.eval(model.parameters()) 17 | 18 | # Create the gradient function and the optimizer 19 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn) 20 | optimizer = optim.SGD(learning_rate=learning_rate) 21 | 22 | for e in range(num_epochs): 23 | for X, y in batch_iterate(batch_size, train_images, train_labels): 24 | loss, grads = loss_and_grad_fn(model, X, y) 25 | 26 | # Update the model with the gradients. So far no computation has happened. 27 | optimizer.update(model, grads) 28 | 29 | # Compute the new parameters but also the optimizer state. 30 | mx.eval(model.parameters(), optimizer.state) 31 | 32 | .. currentmodule:: mlx.optimizers 33 | 34 | .. autosummary:: 35 | :toctree: _autosummary 36 | :template: optimizers-template.rst 37 | 38 | OptimizerState 39 | Optimizer 40 | SGD 41 | Adam 42 | -------------------------------------------------------------------------------- /benchmarks/cpp/time_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/mlx.h" 10 | 11 | #define milliseconds(x) \ 12 | (std::chrono::duration_cast(x).count() / 1e6) 13 | #define time_now() std::chrono::high_resolution_clock::now() 14 | 15 | #define TIME(FUNC, ...) \ 16 | std::cout << "Timing " << #FUNC << " ... " << std::flush \ 17 | << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ 18 | << std::endl; 19 | 20 | #define TIMEM(MSG, FUNC, ...) \ 21 | std::cout << "Timing " \ 22 | << "(" << MSG << ") " << #FUNC << " ... " << std::flush \ 23 | << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \ 24 | << std::endl; 25 | 26 | template 27 | double time_fn(F fn, Args... args) { 28 | // warmup 29 | for (int i = 0; i < 5; ++i) { 30 | eval(fn(args...)); // pass lvalues: forwarding would move from args, which are reused on every iteration 31 | } 32 | 33 | int num_iters = 100; 34 | auto start = time_now(); 35 | for (int i = 0; i < num_iters; i++) { 36 | eval(fn(args...)); 37 | } 38 | auto end = time_now(); 39 | return milliseconds(end - start) / static_cast(num_iters); 40 | } 41 | -------------------------------------------------------------------------------- /docs/src/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | # -*- coding: utf-8 -*- 4 | 5 | import os 6 | import subprocess 7 | 8 | # -- Project information ----------------------------------------------------- 9 | 10 | project = "MLX" 11 | copyright = "2023, MLX Contributors" 12 | author = "MLX Contributors" 13 | version = "0.0.3" 14 | release = "0.0.3" 15 | 16 | # -- General configuration --------------------------------------------------- 17 | 18 | extensions = [ 19 | "sphinx.ext.autodoc", 20 | "sphinx.ext.autosummary", 21 | "sphinx.ext.intersphinx", 22 | "sphinx.ext.napoleon", 23 | ] 24 | 25 | python_use_unqualified_type_names = True 26 | autosummary_generate = True 27 | 28 | intersphinx_mapping = { 29 | "https://docs.python.org/3": None, 30 | "https://numpy.org/doc/stable/": None, 31 | } 32 | 33 | templates_path = ["_templates"] 34 | html_static_path = ["_static"] 35 | source_suffix = ".rst" 36 | master_doc = "index" 37 | highlight_language = "python" 38 | pygments_style = "sphinx" 39 | 40 | # -- Options for HTML output ------------------------------------------------- 41 | 42 | html_theme = "sphinx_book_theme" 43 | 44 | html_theme_options = { 45 | "show_toc_level": 2, 46 | "repository_url": "https://github.com/ml-explore/mlx", 47 | "use_repository_button": True, 48 | "navigation_with_keys": False, 49 | } 50 | 51 | html_logo = "_static/mlx_logo.png" 52 | 53 | 54 | # -- Options for HTMLHelp output --------------------------------------------- 55 | 56 | htmlhelp_basename = "mlx_doc" 57 | -------------------------------------------------------------------------------- /mlx.pc.in: -------------------------------------------------------------------------------- 1 | # Find MLX 2 | # 3 | # Defines the following variables: 4 | # 5 | # MLX_FOUND : True if MLX is found 6 | # MLX_INCLUDE_DIRS : Include directory 7 | # MLX_LIBRARIES : Libraries to link against 8 | # MLX_CXX_FLAGS : Additional compiler flags 9 | # MLX_BUILD_ACCELERATE : True if MLX was built with accelerate 10 | # MLX_BUILD_METAL : True if MLX was built 
with metal 11 | 12 | @PACKAGE_INIT@ 13 | 14 | include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/MLXTargets.cmake) 15 | include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/extension.cmake) 16 | 17 | set_and_check(MLX_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@) 18 | set_and_check(MLX_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@) 19 | set(MLX_LIBRARIES mlx) 20 | 21 | find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS}) 22 | 23 | if (@MLX_BUILD_ACCELERATE@) 24 | set(MLX_BUILD_ACCELERATE @MLX_BUILD_ACCELERATE@) 25 | set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK) 26 | endif() 27 | 28 | if (@MLX_BUILD_METAL@) 29 | set(MLX_BUILD_METAL @MLX_BUILD_METAL@) 30 | set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_) 31 | set_and_check(MLX_INCLUDE_DIRS 32 | ${MLX_INCLUDE_DIRS} 33 | @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp 34 | ) 35 | endif() 36 | 37 | set_target_properties(mlx PROPERTIES 38 | CXX_STANDARD 17 39 | INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}" 40 | ) 41 | 42 | include(FindPackageHandleStandardArgs) 43 | find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) -------------------------------------------------------------------------------- /examples/cpp/logistic_regression.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/mlx.h" 8 | #include "timer.h" 9 | 10 | /** 11 | * An example of logistic regression with MLX. 
12 | */ 13 | using namespace mlx::core; 14 | 15 | int main() { 16 | int num_features = 100; 17 | int num_examples = 1'000; 18 | int num_iters = 10'000; 19 | float learning_rate = 0.1; 20 | 21 | // True parameters 22 | auto w_star = random::normal({num_features}); 23 | 24 | // The input examples 25 | auto X = random::normal({num_examples, num_features}); 26 | 27 | // Labels 28 | auto y = matmul(X, w_star) > 0; 29 | 30 | // Initialize random parameters 31 | array w = 1e-2 * random::normal({num_features}); 32 | 33 | auto loss_fn = [&](array w) { 34 | auto logits = matmul(X, w); 35 | auto scale = (1.0f / num_examples); 36 | return scale * sum(logaddexp(array(0.0f), logits) - y * logits); 37 | }; 38 | 39 | auto grad_fn = grad(loss_fn); 40 | 41 | auto tic = timer::time(); 42 | for (int it = 0; it < num_iters; ++it) { 43 | auto grad = grad_fn(w); 44 | w = w - learning_rate * grad; 45 | eval(w); 46 | } 47 | auto toc = timer::time(); 48 | 49 | auto loss = loss_fn(w); 50 | auto acc = sum((matmul(X, w) > 0) == y) / num_examples; 51 | auto throughput = num_iters / timer::seconds(toc - tic); 52 | std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput " 53 | << throughput << " (it/s)." << std::endl; 54 | } 55 | -------------------------------------------------------------------------------- /examples/cpp/linear_regression.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/mlx.h" 8 | #include "timer.h" 9 | 10 | /** 11 | * An example of linear regression with MLX. 
12 | */ 13 | using namespace mlx::core; 14 | 15 | int main() { 16 | int num_features = 100; 17 | int num_examples = 1'000; 18 | int num_iters = 10'000; 19 | float learning_rate = 0.01; 20 | 21 | // True parameters 22 | auto w_star = random::normal({num_features}); 23 | 24 | // The input examples (design matrix) 25 | auto X = random::normal({num_examples, num_features}); 26 | 27 | // Noisy labels 28 | auto eps = 1e-2 * random::normal({num_examples}); 29 | auto y = matmul(X, w_star) + eps; 30 | 31 | // Initialize random parameters 32 | array w = 1e-2 * random::normal({num_features}); 33 | 34 | auto loss_fn = [&](array w) { 35 | auto yhat = matmul(X, w); 36 | return (0.5f / num_examples) * sum(square(yhat - y)); 37 | }; 38 | 39 | auto grad_fn = grad(loss_fn); 40 | 41 | auto tic = timer::time(); 42 | for (int it = 0; it < num_iters; ++it) { 43 | auto grad = grad_fn(w); 44 | w = w - learning_rate * grad; 45 | eval(w); 46 | } 47 | auto toc = timer::time(); 48 | 49 | auto loss = loss_fn(w); 50 | auto error_norm = std::sqrt(sum(square(w - w_star)).item()); 51 | auto throughput = num_iters / timer::seconds(toc - tic); 52 | std::cout << "Loss " << loss << ", |w - w*| = " << error_norm 53 | << ", Throughput " << throughput << " (it/s)." << std::endl; 54 | } 55 | -------------------------------------------------------------------------------- /mlx/types/half_types.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC 5 | 6 | #include 7 | namespace mlx::core { 8 | typedef __fp16 float16_t; 9 | } // namespace mlx::core 10 | 11 | #else 12 | 13 | #define ADD_HALF_BINOPS 14 | #include "mlx/types/fp16.h" 15 | namespace mlx::core { 16 | typedef struct _MLX_Float16 float16_t; 17 | } // namespace mlx::core 18 | 19 | #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC 20 | #ifdef __ARM_FEATURE_BF16 21 | 22 | #include 23 | namespace mlx::core { 24 | typedef __bf16 bfloat16_t; 25 | } // namespace mlx::core 26 | 27 | #else 28 | 29 | #define ADD_HALF_BINOPS 30 | #include "mlx/types/bf16.h" 31 | namespace mlx::core { 32 | typedef struct _MLX_BFloat16 bfloat16_t; 33 | } // namespace mlx::core 34 | 35 | #endif // __ARM_FEATURE_BF16 36 | 37 | #ifdef ADD_HALF_BINOPS 38 | namespace mlx::core { 39 | 40 | // clang-format off 41 | #define fp16_bf16_binop_helper(__op__, __operator__) \ 42 | inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ 43 | return static_cast(lhs) __op__ static_cast(rhs); \ 44 | } \ 45 | inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ 46 | return static_cast(lhs) __op__ static_cast(rhs); \ 47 | } 48 | 49 | fp16_bf16_binop_helper(+, operator+) 50 | fp16_bf16_binop_helper(-, operator-) 51 | fp16_bf16_binop_helper(*, operator*) 52 | fp16_bf16_binop_helper(/, operator/) 53 | // clang-format on 54 | 55 | } // namespace mlx::core 56 | #endif 57 | -------------------------------------------------------------------------------- /benchmarks/python/batch_matmul_bench.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
import argparse
import mlx.core as mx

from time_utils import time_fn

# Problem size shared by both benchmarks.
B = 8
T = 1024
D = 512


def time_batch_matmul():
    """Benchmark a (B, T, D) @ (D, D) matmul and its VJP for each operand.

    Each callable is handed to ``time_fn`` (imported from ``time_utils``).
    """
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B, T, D))
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=(B, T, D))
    # Materialize the operands before any benchmark is run.
    mx.eval(lhs, rhs, cotangent)

    time_fn(mx.matmul, lhs, rhs)

    def batch_vjp_first():
        # VJP with respect to the first matmul operand.
        vjps = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])[1]
        return vjps[0]

    time_fn(batch_vjp_first)

    def batch_vjp_second():
        # VJP with respect to the second matmul operand.
        vjps = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])[1]
        return vjps[1]

    time_fn(batch_vjp_second)


def time_unbatch_matmul():
    """Benchmark the same products with the batch folded into the rows.

    The operands are (B * T, D) and (D, D); the two "vjp" callables are the
    explicit matmuls ``c @ b.T`` and ``a.T @ c``.
    """
    mx.random.seed(3)
    rows = B * T
    lhs = mx.random.uniform(shape=(rows, D))
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=(rows, D))
    mx.eval(lhs, rhs, cotangent)
    time_fn(mx.matmul, lhs, rhs)

    def unbatch_vjp_first():
        rhs_t = mx.transpose(rhs)
        return mx.matmul(cotangent, rhs_t)

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        lhs_t = mx.transpose(lhs)
        return mx.matmul(lhs_t, cotangent)

    time_fn(unbatch_vjp_second)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    # Run on the GPU only when explicitly requested.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    for benchmark in (time_batch_matmul, time_unbatch_matmul):
        benchmark()
2 | 3 | #include "mlx/primitives.h" 4 | 5 | #define NO_GPU(func) \ 6 | void func::eval_gpu(const std::vector& inputs, array& out) { \ 7 | throw std::runtime_error(#func " has no GPU implementation."); \ 8 | } 9 | 10 | namespace mlx::core { 11 | 12 | NO_GPU(Abs) 13 | NO_GPU(Add) 14 | NO_GPU(Arange) 15 | NO_GPU(ArcCos) 16 | NO_GPU(ArcCosh) 17 | NO_GPU(ArcSin) 18 | NO_GPU(ArcSinh) 19 | NO_GPU(ArcTan) 20 | NO_GPU(ArcTanh) 21 | NO_GPU(ArgPartition) 22 | NO_GPU(ArgReduce) 23 | NO_GPU(ArgSort) 24 | NO_GPU(AsType) 25 | NO_GPU(AsStrided) 26 | NO_GPU(Broadcast) 27 | NO_GPU(Concatenate) 28 | NO_GPU(Convolution) 29 | NO_GPU(Copy) 30 | NO_GPU(Cos) 31 | NO_GPU(Cosh) 32 | NO_GPU(Divide) 33 | NO_GPU(Equal) 34 | NO_GPU(Erf) 35 | NO_GPU(ErfInv) 36 | NO_GPU(Exp) 37 | NO_GPU(FFT) 38 | NO_GPU(Full) 39 | NO_GPU(Gather) 40 | NO_GPU(Greater) 41 | NO_GPU(GreaterEqual) 42 | NO_GPU(Less) 43 | NO_GPU(LessEqual) 44 | NO_GPU(Load) 45 | NO_GPU(Log) 46 | NO_GPU(Log1p) 47 | NO_GPU(LogicalNot) 48 | NO_GPU(LogAddExp) 49 | NO_GPU(Matmul) 50 | NO_GPU(Maximum) 51 | NO_GPU(Minimum) 52 | NO_GPU(Multiply) 53 | NO_GPU(Negative) 54 | NO_GPU(NotEqual) 55 | NO_GPU(Pad) 56 | NO_GPU(Partition) 57 | NO_GPU(Power) 58 | NO_GPU(RandomBits) 59 | NO_GPU(Reduce) 60 | NO_GPU(Reshape) 61 | NO_GPU(Scan) 62 | NO_GPU(Scatter) 63 | NO_GPU(Sigmoid) 64 | NO_GPU(Sign) 65 | NO_GPU(Sin) 66 | NO_GPU(Sinh) 67 | NO_GPU(Slice) 68 | NO_GPU(Softmax) 69 | NO_GPU(Sort) 70 | NO_GPU(Square) 71 | NO_GPU(Sqrt) 72 | NO_GPU(StopGradient) 73 | NO_GPU(Subtract) 74 | NO_GPU(Tan) 75 | NO_GPU(Tanh) 76 | NO_GPU(Transpose) 77 | 78 | } // namespace mlx::core 79 | -------------------------------------------------------------------------------- /mlx/backend/common/erf.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | namespace mlx::core { 6 | 7 | /* Approximation to the inverse error function. 
/*
 * Single-precision approximation to the inverse error function.
 *
 * With t = log(1 - a * a), the result is a * P(t), where P is one of two
 * polynomials chosen by the magnitude of t and evaluated with fused
 * multiply-adds.
 *
 * Based on code from:
 * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
 */
float erfinv(float a) {
  // Coefficients are listed highest order first (Horner evaluation).
  // Used when |t| > 6.125; maximum ulp error = 2.35793
  static constexpr float kTail[] = {
      3.03697567e-10f, //  0x1.4deb44p-32
      2.93243101e-8f, //  0x1.f7c9aep-26
      1.22150334e-6f, //  0x1.47e512p-20
      2.84108955e-5f, //  0x1.dca7dep-16
      3.93552968e-4f, //  0x1.9cab92p-12
      3.02698812e-3f, //  0x1.8cc0dep-9
      4.83185798e-3f, //  0x1.3ca920p-8
      -2.64646143e-1f, // -0x1.0eff66p-2
      8.40016484e-1f, //  0x1.ae16a4p-1
  };
  // Used otherwise; maximum ulp error = 2.35002
  static constexpr float kCentral[] = {
      5.43877832e-9f, //  0x1.75c000p-28
      1.43285448e-7f, //  0x1.33b402p-23
      1.22774793e-6f, //  0x1.499232p-20
      1.12963626e-7f, //  0x1.e52cd2p-24
      -5.61530760e-5f, // -0x1.d70bd0p-15
      -1.47697632e-4f, // -0x1.35be90p-13
      2.31468678e-3f, //  0x1.2f6400p-9
      1.15392581e-2f, //  0x1.7a1e50p-7
      -2.32015476e-1f, // -0x1.db2aeep-3
      8.86226892e-1f, //  0x1.c5bf88p-1
  };

  // t = log(1 - a^2); the fma forms 1 - a*a with a single rounding.
  const float t = std::log(std::fma(a, 0.0f - a, 1.0f));

  const bool use_tail = std::abs(t) > 6.125f;
  const float* coeffs = use_tail ? kTail : kCentral;
  const int num_coeffs = use_tail ? 9 : 10;

  float p = coeffs[0];
  for (int i = 1; i < num_coeffs; ++i) {
    p = std::fma(p, t, coeffs[i]);
  }
  return a * p;
}
12 | class Buffer { 13 | private: 14 | void* ptr_; 15 | 16 | public: 17 | Buffer(void* ptr) : ptr_(ptr){}; 18 | 19 | // Get the raw data pointer from the buffer 20 | void* raw_ptr(); 21 | 22 | // Get the buffer pointer from the buffer 23 | const void* ptr() const { 24 | return ptr_; 25 | }; 26 | void* ptr() { 27 | return ptr_; 28 | }; 29 | }; 30 | 31 | Buffer malloc(size_t size); 32 | 33 | void free(Buffer buffer); 34 | 35 | // Wait for running tasks to finish and free up memory 36 | // if allocation fails 37 | Buffer malloc_or_wait(size_t size); 38 | 39 | class Allocator { 40 | /** Abstract base class for a memory allocator. */ 41 | public: 42 | virtual Buffer malloc(size_t size) = 0; 43 | virtual void free(Buffer buffer) = 0; 44 | 45 | Allocator() = default; 46 | Allocator(const Allocator& other) = delete; 47 | Allocator(Allocator&& other) = delete; 48 | Allocator& operator=(const Allocator& other) = delete; 49 | Allocator& operator=(Allocator&& other) = delete; 50 | virtual ~Allocator() = default; 51 | }; 52 | 53 | Allocator& allocator(); 54 | 55 | class CommonAllocator : public Allocator { 56 | /** A general CPU allocator.
*/ 57 | public: 58 | virtual Buffer malloc(size_t size) override; 59 | virtual void free(Buffer buffer) override; 60 | 61 | private: 62 | CommonAllocator() = default; 63 | friend Allocator& allocator(); 64 | }; 65 | 66 | } // namespace mlx::core::allocator 67 | -------------------------------------------------------------------------------- /cmake/extension.cmake: -------------------------------------------------------------------------------- 1 | include(CMakeParseArguments) 2 | 3 | ############################################################################### 4 | # Build metal library 5 | # 6 | # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/${TITLE}.metallib 7 | # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS} 8 | # 9 | # Args: 10 | # TARGET: Custom target to be added for the metal library 11 | # TITLE: Name of the .metallib 12 | # OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib 13 | # SOURCES: List of source files 14 | # INCLUDE_DIRS: List of include dirs 15 | # DEPS: List of dependency files (like headers) 16 | # 17 | macro(mlx_build_metallib) 18 | # Parse args 19 | set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) 20 | set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) 21 | cmake_parse_arguments( 22 | MTLLIB 23 | "" 24 | "${oneValueArgs}" 25 | "${multiValueArgs}" 26 | ${ARGN} 27 | ) 28 | 29 | # Set output 30 | set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") 31 | 32 | # Collect compile options 33 | set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math) 34 | 35 | # Prepare metallib build command 36 | add_custom_command( 37 | OUTPUT ${MTLLIB_BUILD_TARGET} 38 | COMMAND xcrun -sdk macosx metal 39 | "$" 40 | ${MTLLIB_COMPILE_OPTIONS} 41 | ${MTLLIB_SOURCES} 42 | -o ${MTLLIB_BUILD_TARGET} 43 | DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES} 44 | COMMAND_EXPAND_LISTS 45 | COMMENT "Building ${MTLLIB_TITLE}.metallib" 46 | VERBATIM 47 | ) 48 | 49 | # Add metallib custom target 50 | add_custom_target( 51 |
${MTLLIB_TARGET} 52 | DEPENDS 53 | ${MTLLIB_BUILD_TARGET} 54 | ) 55 | 56 | endmacro(mlx_build_metallib) -------------------------------------------------------------------------------- /examples/extensions/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.24) 2 | 3 | project(mlx_sample_extensions LANGUAGES CXX) 4 | 5 | # ----------------------------- Setup ----------------------------- 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 9 | 10 | option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) 11 | 12 | # ----------------------------- Dependencies ----------------------------- 13 | find_package(MLX CONFIG REQUIRED) 14 | find_package(Python COMPONENTS Interpreter Development) 15 | find_package(pybind11 CONFIG REQUIRED) 16 | 17 | # ----------------------------- Extensions ----------------------------- 18 | 19 | # Add library 20 | add_library(mlx_ext) 21 | 22 | # Add sources 23 | target_sources( 24 | mlx_ext 25 | PUBLIC 26 | ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp 27 | ) 28 | 29 | # Add include headers 30 | target_include_directories( 31 | mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR} 32 | ) 33 | 34 | # Link to mlx 35 | target_link_libraries(mlx_ext PUBLIC mlx) 36 | 37 | # ----------------------------- Metal ----------------------------- 38 | 39 | # Build metallib 40 | if(MLX_BUILD_METAL) 41 | 42 | mlx_build_metallib( 43 | TARGET mlx_ext_metallib 44 | TITLE mlx_ext 45 | SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal 46 | INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} 47 | OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} 48 | ) 49 | 50 | add_dependencies( 51 | mlx_ext 52 | mlx_ext_metallib 53 | ) 54 | 55 | endif() 56 | 57 | # ----------------------------- Pybind ----------------------------- 58 | pybind11_add_module( 59 | mlx_sample_extensions 60 | ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp 61 | ) 
62 | target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext) 63 | 64 | if(BUILD_SHARED_LIBS) 65 | target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path) 66 | endif() -------------------------------------------------------------------------------- /mlx/backend/metal/allocator.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/allocator.h" 10 | #include "mlx/backend/metal/device.h" 11 | 12 | namespace mlx::core::metal { 13 | 14 | using allocator::Buffer; 15 | 16 | namespace { 17 | 18 | class BufferCache { 19 | public: 20 | BufferCache(MTL::Device* device); 21 | ~BufferCache(); 22 | void clear(); 23 | 24 | MTL::Buffer* reuse_from_cache(size_t size); 25 | void recycle_to_cache(MTL::Buffer* buf); 26 | size_t release_cached_buffers(size_t min_bytes_to_free); 27 | 28 | bool can_garbage_collect() { 29 | return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_; 30 | } 31 | 32 | private: 33 | struct BufferHolder { 34 | public: 35 | BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {} 36 | 37 | BufferHolder* prev; 38 | BufferHolder* next; 39 | MTL::Buffer* buf; 40 | }; 41 | 42 | void add_at_head(BufferHolder* to_add); 43 | void remove_from_list(BufferHolder* to_remove); 44 | 45 | MTL::Device* device_; 46 | std::mutex cache_mutex_; 47 | 48 | std::multimap buffer_pool_; 49 | BufferHolder* head_; 50 | BufferHolder* tail_; 51 | size_t pool_size_; 52 | size_t gc_limit_; 53 | }; 54 | 55 | } // namespace 56 | 57 | class MetalAllocator : public allocator::Allocator { 58 | /** Allocator for Metal GPUs. 
*/ 59 | public: 60 | virtual Buffer malloc(size_t size) override; 61 | virtual void free(Buffer buffer) override; 62 | 63 | private: 64 | MTL::Device* device_; 65 | MetalAllocator(); 66 | friend MetalAllocator& allocator(); 67 | 68 | // Caching allocator 69 | BufferCache buffer_cache_; 70 | 71 | // Allocation stats 72 | size_t peak_allocated_size_; 73 | size_t block_limit_; 74 | }; 75 | 76 | MetalAllocator& allocator(); 77 | 78 | } // namespace mlx::core::metal 79 | -------------------------------------------------------------------------------- /tests/load_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "doctest/doctest.h" 8 | 9 | #include "mlx/mlx.h" 10 | 11 | using namespace mlx::core; 12 | 13 | std::string get_temp_file(const std::string& name) { 14 | return std::filesystem::temp_directory_path().append(name); 15 | } 16 | 17 | TEST_CASE("test single array serialization") { 18 | // Basic test 19 | { 20 | auto a = random::uniform(-5.f, 5.f, {2, 5, 12}, float32); 21 | 22 | std::string file_path = get_temp_file("test_arr.npy"); 23 | 24 | save(file_path, a); 25 | auto b = load(file_path); 26 | 27 | CHECK_EQ(a.dtype(), b.dtype()); 28 | CHECK_EQ(a.shape(), b.shape()); 29 | CHECK(array_equal(a, b).item()); 30 | } 31 | 32 | // Other shapes 33 | { 34 | auto a = random::uniform( 35 | -5.f, 36 | 5.f, 37 | { 38 | 1, 39 | }, 40 | float32); 41 | 42 | std::string file_path = get_temp_file("test_arr_0.npy"); 43 | 44 | save(file_path, a); 45 | auto b = load(file_path); 46 | 47 | CHECK_EQ(a.dtype(), b.dtype()); 48 | CHECK_EQ(a.shape(), b.shape()); 49 | CHECK(array_equal(a, b).item()); 50 | } 51 | 52 | { 53 | auto a = random::uniform( 54 | -5.f, 55 | 5.f, 56 | { 57 | 46, 58 | }, 59 | float32); 60 | 61 | std::string file_path = get_temp_file("test_arr_1.npy"); 62 | 63 | save(file_path, a); 64 | auto b = load(file_path); 65 | 66 | 
CHECK_EQ(a.dtype(), b.dtype()); 67 | CHECK_EQ(a.shape(), b.shape()); 68 | CHECK(array_equal(a, b).item()); 69 | } 70 | 71 | { 72 | auto a = random::uniform(-5.f, 5.f, {5, 2, 1, 3, 4}, float32); 73 | 74 | std::string file_path = get_temp_file("test_arr_2.npy"); 75 | 76 | save(file_path, a); 77 | auto b = load(file_path); 78 | 79 | CHECK_EQ(a.dtype(), b.dtype()); 80 | CHECK_EQ(a.shape(), b.shape()); 81 | CHECK(array_equal(a, b).item()); 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /docs/src/quick_start.rst: -------------------------------------------------------------------------------- 1 | Quick Start Guide 2 | ================= 3 | 4 | 5 | Basics 6 | ------ 7 | 8 | .. currentmodule:: mlx.core 9 | 10 | Import ``mlx.core`` and make an :class:`array`: 11 | 12 | .. code-block:: python 13 | 14 | >> import mlx.core as mx 15 | >> a = mx.array([1, 2, 3, 4]) 16 | >> a.shape 17 | [4] 18 | >> a.dtype 19 | int32 20 | >> b = mx.array([1.0, 2.0, 3.0, 4.0]) 21 | >> b.dtype 22 | float32 23 | 24 | Operations in MLX are lazy. The outputs of MLX operations are not computed 25 | until they are needed. To force an array to be evaluated use 26 | :func:`eval`. Arrays will automatically be evaluated in a few cases. For 27 | example, inspecting a scalar with :meth:`array.item`, printing an array, 28 | or converting an array from :class:`array` to :class:`numpy.ndarray` all 29 | automatically evaluate the array. 30 | 31 | .. code-block:: python 32 | 33 | >> c = a + b # c not yet evaluated 34 | >> mx.eval(c) # evaluates c 35 | >> c = a + b 36 | >> print(c) # Also evaluates c 37 | array([2, 4, 6, 8], dtype=float32) 38 | >> c = a + b 39 | >> import numpy as np 40 | >> np.array(c) # Also evaluates c 41 | array([2., 4., 6., 8.], dtype=float32) 42 | 43 | Function and Graph Transformations 44 | ---------------------------------- 45 | 46 | MLX has standard function transformations like :func:`grad` and :func:`vmap`. 
47 | Transformations can be composed arbitrarily. For example 48 | ``grad(vmap(grad(fn)))`` (or any other composition) is allowed. 49 | 50 | .. code-block:: python 51 | 52 | >> x = mx.array(0.0) 53 | >> mx.sin(x) 54 | array(0, dtype=float32) 55 | >> mx.grad(mx.sin)(x) 56 | array(1, dtype=float32) 57 | >> mx.grad(mx.grad(mx.sin))(x) 58 | array(-0, dtype=float32) 59 | 60 | Other gradient transformations include :func:`vjp` for vector-Jacobian products 61 | and :func:`jvp` for Jacobian-vector products. 62 | 63 | Use :func:`value_and_grad` to efficiently compute both a function's output and 64 | gradient with respect to the function's input. 65 | 66 | 67 | Devices and Streams 68 | ------------------- 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /mlx/backend/common/arange.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/allocator.h" 6 | #include "mlx/array.h" 7 | 8 | namespace mlx::core { 9 | 10 | namespace { 11 | 12 | template 13 | void arange(T start, T next, array& out, size_t size) { 14 | auto ptr = out.data(); 15 | auto step_size = next - start; 16 | for (int i = 0; i < size; ++i) { 17 | ptr[i] = start; 18 | start += step_size; 19 | } 20 | } 21 | 22 | } // namespace 23 | 24 | void arange( 25 | const std::vector& inputs, 26 | array& out, 27 | double start, 28 | double step) { 29 | assert(inputs.size() == 0); 30 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 31 | switch (out.dtype()) { 32 | case bool_: 33 | throw std::runtime_error("Bool type unsupported for arange."); 34 | break; 35 | case uint8: 36 | arange(start, start + step, out, out.size()); 37 | break; 38 | case uint16: 39 | arange(start, start + step, out, out.size()); 40 | break; 41 | case uint32: 42 | arange(start, start + step, out, out.size()); 43 | break; 44 | case uint64: 45 | arange(start, start + step, out, out.size()); 46 | 
break; 47 | case int8: 48 | arange(start, start + step, out, out.size()); 49 | break; 50 | case int16: 51 | arange(start, start + step, out, out.size()); 52 | break; 53 | case int32: 54 | arange(start, start + step, out, out.size()); 55 | break; 56 | case int64: 57 | arange(start, start + step, out, out.size()); 58 | break; 59 | case float16: 60 | arange(start, start + step, out, out.size()); 61 | break; 62 | case float32: 63 | arange(start, start + step, out, out.size()); 64 | break; 65 | case bfloat16: 66 | arange(start, start + step, out, out.size()); 67 | break; 68 | case complex64: 69 | arange(start, start + step, out, out.size()); 70 | break; 71 | } 72 | } 73 | 74 | } // namespace mlx::core 75 | -------------------------------------------------------------------------------- /docs/src/index.rst: -------------------------------------------------------------------------------- 1 | MLX 2 | === 3 | 4 | MLX is a NumPy-like array framework designed for efficient and flexible machine 5 | learning on Apple silicon, brought to you by Apple machine learning research. 6 | 7 | The Python API closely follows NumPy with a few exceptions. MLX also has a 8 | fully featured C++ API which closely follows the Python API. 9 | 10 | The main differences between MLX and NumPy are: 11 | 12 | - **Composable function transformations**: MLX has composable function 13 | transformations for automatic differentiation, automatic vectorization, 14 | and computation graph optimization. 15 | - **Lazy computation**: Computations in MLX are lazy. Arrays are only 16 | materialized when needed. 17 | - **Multi-device**: Operations can run on any of the supported devices (CPU, 18 | GPU, ...) 19 | 20 | The design of MLX is inspired by frameworks like `PyTorch 21 | <https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and 22 | `ArrayFire <https://arrayfire.org/>`_. A notable difference between these 23 | frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared 24 | memory.
Operations on MLX arrays can be performed on any of the supported 25 | device types without performing data copies. Currently supported device types 26 | are the CPU and GPU. 27 | 28 | .. toctree:: 29 | :caption: Install 30 | :maxdepth: 1 31 | 32 | install 33 | 34 | .. toctree:: 35 | :caption: Usage 36 | :maxdepth: 1 37 | 38 | quick_start 39 | using_streams 40 | 41 | .. toctree:: 42 | :caption: Examples 43 | :maxdepth: 1 44 | 45 | examples/linear_regression 46 | examples/mlp 47 | examples/llama-inference 48 | 49 | .. toctree:: 50 | :caption: Python API Reference 51 | :maxdepth: 1 52 | 53 | python/array 54 | python/devices_and_streams 55 | python/ops 56 | python/random 57 | python/transforms 58 | python/fft 59 | python/nn 60 | python/optimizers 61 | python/tree_utils 62 | 63 | .. toctree:: 64 | :caption: C++ API Reference 65 | :maxdepth: 1 66 | 67 | cpp/ops 68 | 69 | .. toctree:: 70 | :caption: Further Reading 71 | :maxdepth: 1 72 | 73 | dev/extensions 74 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set( 2 | HEADERS 3 | ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h 4 | ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h 5 | ${CMAKE_CURRENT_SOURCE_DIR}/complex.h 6 | ${CMAKE_CURRENT_SOURCE_DIR}/defines.h 7 | ${CMAKE_CURRENT_SOURCE_DIR}/erf.h 8 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h 9 | ${CMAKE_CURRENT_SOURCE_DIR}/utils.h 10 | ) 11 | 12 | set( 13 | KERNELS 14 | "arange" 15 | "arg_reduce" 16 | "binary" 17 | "conv" 18 | "copy" 19 | "gemm" 20 | "gemv" 21 | "random" 22 | "reduce" 23 | "scan" 24 | "softmax" 25 | "sort" 26 | "unary" 27 | "indexing" 28 | ) 29 | 30 | function(build_kernel KERNEL) 31 | set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) 32 | set(HEADERS_PADDED ${HEADERS}) 33 | if(${KERNEL} STREQUAL "gemm") 34 | set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h) 35 | endif() 36 | 
if(${KERNEL} STREQUAL "conv") 37 | set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h) 38 | endif() 39 | add_custom_command( 40 | COMMAND xcrun -sdk macosx metal -Wall -Wextra 41 | -fno-fast-math 42 | -c ${SRCFILE} 43 | -I${PROJECT_SOURCE_DIR} 44 | -o ${KERNEL}.air 45 | DEPENDS ${SRCFILE} ${HEADERS_PADDED} 46 | OUTPUT ${KERNEL}.air 47 | COMMENT "Building ${KERNEL}.air" 48 | VERBATIM 49 | ) 50 | endfunction(build_kernel) 51 | 52 | foreach(KERNEL ${KERNELS}) 53 | build_kernel(${KERNEL}) 54 | set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR}) 55 | endforeach() 56 | 57 | add_custom_command( 58 | OUTPUT ${MLX_METAL_PATH}/mlx.metallib 59 | COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib 60 | DEPENDS ${KERNEL_AIR} 61 | COMMENT "Building mlx.metallib" 62 | VERBATIM 63 | ) 64 | 65 | add_custom_target( 66 | mlx-metallib 67 | DEPENDS 68 | ${MLX_METAL_PATH}/mlx.metallib 69 | ) 70 | 71 | add_dependencies( 72 | mlx 73 | mlx-metallib 74 | ) 75 | 76 | # Install metallib 77 | include(GNUInstallDirs) 78 | 79 | install( 80 | FILES ${MLX_METAL_PATH}/mlx.metallib 81 | DESTINATION ${CMAKE_INSTALL_LIBDIR} 82 | COMPONENT metallib 83 | ) 84 | -------------------------------------------------------------------------------- /docs/src/examples/linear_regression.rst: -------------------------------------------------------------------------------- 1 | .. _linear_regression: 2 | 3 | Linear Regression 4 | ----------------- 5 | 6 | Let's implement a basic linear regression model as a starting point to 7 | learn MLX. First import the core package and setup some problem metadata: 8 | 9 | .. code-block:: python 10 | 11 | import mlx.core as mx 12 | 13 | num_features = 100 14 | num_examples = 1_000 15 | num_iters = 10_000 # iterations of SGD 16 | lr = 0.01 # learning rate for SGD 17 | 18 | 19 | We'll generate a synthetic dataset by: 20 | 21 | 1. Sampling the design matrix ``X``. 22 | 2. Sampling a ground truth parameter vector ``w_star``. 23 | 3. 
Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``. 24 | 25 | .. code-block:: python 26 | 27 | # True parameters 28 | w_star = mx.random.normal((num_features,)) 29 | 30 | # Input examples (design matrix) 31 | X = mx.random.normal((num_examples, num_features)) 32 | 33 | # Noisy labels 34 | eps = 1e-2 * mx.random.normal((num_examples,)) 35 | y = X @ w_star + eps 36 | 37 | 38 | We will use SGD to find the optimal weights. To start, define the squared loss 39 | and get the gradient function of the loss with respect to the parameters. 40 | 41 | .. code-block:: python 42 | 43 | def loss_fn(w): 44 | return 0.5 * mx.mean(mx.square(X @ w - y)) 45 | 46 | grad_fn = mx.grad(loss_fn) 47 | 48 | Start the optimization by initializing the parameters ``w`` randomly. Then 49 | repeatedly update the parameters for ``num_iters`` iterations. 50 | 51 | .. code-block:: python 52 | 53 | w = 1e-2 * mx.random.normal((num_features,)) 54 | 55 | for _ in range(num_iters): 56 | grad = grad_fn(w) 57 | w = w - lr * grad 58 | mx.eval(w) 59 | 60 | Finally, compute the loss of the learned parameters and verify that they are 61 | close to the ground truth parameters. 62 | 63 | .. code-block:: python 64 | 65 | loss = loss_fn(w) 66 | error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5 67 | 68 | print( 69 | f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, " 70 | ) 71 | # Should print something close to: Loss 0.00005, |w-w*| = 0.00364 72 | 73 | Complete `linear regression 74 | <https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_ 75 | and `logistic regression 76 | <https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_ 77 | examples are available in the MLX GitHub repo. 78 | -------------------------------------------------------------------------------- /tests/eval_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc.
2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include "mlx/mlx.h" 6 | 7 | using namespace mlx::core; 8 | 9 | TEST_CASE("test eval") { 10 | { 11 | array x(1.0); 12 | array y(1); 13 | array z(true); 14 | eval({x, y, z}); 15 | CHECK_EQ(x.item(), 1.0); 16 | } 17 | 18 | { 19 | array x(1.0); 20 | array y = ones({2, 2}); 21 | array z(true); 22 | eval({x, y, z}); 23 | CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item()); 24 | } 25 | } 26 | 27 | TEST_CASE("test eval multiple") { 28 | auto x = ones({10, 10}); 29 | auto y = ones({10, 10}); 30 | eval({x, y}); 31 | CHECK(array_equal(x, y).item()); 32 | 33 | auto a = x + y; 34 | auto b = x - y; 35 | eval({a, b}); 36 | CHECK(array_equal(a, full({10, 10}, 2.0f)).item()); 37 | CHECK(array_equal(b, full({10, 10}, 0.0f)).item()); 38 | 39 | x = ones({10, 10}); 40 | y = ones({10, 10}); 41 | eval(x, y); 42 | CHECK(array_equal(x, y).item()); 43 | 44 | a = x + y; 45 | b = x - y; 46 | eval(a, b); 47 | CHECK(array_equal(a, full({10, 10}, 2.0f)).item()); 48 | CHECK(array_equal(b, full({10, 10}, 0.0f)).item()); 49 | } 50 | 51 | TEST_CASE("test eval with tracer") { 52 | auto x = array(1); 53 | x.set_tracer(true); 54 | 55 | // Ok, x is not a node 56 | eval(x); 57 | 58 | x = ones({2, 3}); 59 | x.set_tracer(true); 60 | CHECK_THROWS(eval(x)); 61 | 62 | // Ok retain_graph=true 63 | eval({x}, true); 64 | 65 | // Make sure all arguments are checked 66 | auto y = ones({2, 3}); 67 | CHECK_THROWS(eval(x, y)); 68 | } 69 | 70 | TEST_CASE("test eval graph retention") { 71 | auto x = array(1); 72 | auto y = array(2); 73 | auto z = x + y; 74 | eval({z}, true); 75 | CHECK(z.has_primitive()); 76 | CHECK(z.is_evaled()); 77 | CHECK_EQ(z.item(true), 3); 78 | CHECK(z.has_primitive()); 79 | CHECK(z.is_evaled()); 80 | 81 | CHECK_EQ(z.item(), 3); 82 | CHECK(!z.has_primitive()); 83 | CHECK(z.is_evaled()); 84 | 85 | z = x + y; 86 | auto a = z + x; 87 | auto b = a + y; 88 | eval({b}, true); 89 | CHECK(z.has_primitive()); 90 | CHECK(z.is_evaled()); 91 | 
CHECK(a.has_primitive()); 92 | CHECK(a.is_evaled()); 93 | 94 | eval({b}, false); 95 | CHECK(!z.has_primitive()); 96 | CHECK(z.is_evaled()); 97 | CHECK(!a.has_primitive()); 98 | CHECK(a.is_evaled()); 99 | } 100 | -------------------------------------------------------------------------------- /mlx/backend/metal/device.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include "mlx/device.h" 15 | 16 | namespace fs = std::filesystem; 17 | 18 | namespace mlx::core::metal { 19 | 20 | inline std::string get_colocated_mtllib_path(const std::string& lib_name) { 21 | Dl_info info; 22 | std::string mtllib_path; 23 | std::string lib_ext = lib_name + ".metallib"; 24 | 25 | int success = dladdr((void*)get_colocated_mtllib_path, &info); 26 | if (success) { 27 | auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; 28 | mtllib_path = mtllib.c_str(); 29 | } 30 | 31 | return mtllib_path; 32 | } 33 | 34 | class Device { 35 | public: 36 | Device(); 37 | Device(const Device&) = delete; 38 | Device& operator=(const Device&) = delete; 39 | ~Device(); 40 | 41 | MTL::Device* mtl_device() { 42 | return device_; 43 | }; 44 | 45 | void new_queue(int index); 46 | MTL::CommandBuffer* new_command_buffer(int index); 47 | MTL::CommandBuffer* get_command_buffer(int index); 48 | int get_command_buffer_ops(int index); 49 | void increment_command_buffer_ops(int index); 50 | void commit_command_buffer(int index); 51 | MTL::ComputeCommandEncoder* get_command_encoder(int index); 52 | void end_encoding(int index); 53 | 54 | void register_library( 55 | const std::string& lib_name, 56 | const std::string& lib_path); 57 | void register_library( 58 | const std::string& lib_name, 59 | const std::function& lib_path_func = 60 | get_colocated_mtllib_path); 61 | 62 | MTL::ComputePipelineState* 
get_kernel( 63 | const std::string& name, 64 | const std::string& lib_name = "mlx"); 65 | 66 | MTL::ArgumentEncoder* argument_encoder( 67 | const std::vector& arg_descs) const; 68 | 69 | private: 70 | NS::AutoreleasePool* pool_; 71 | MTL::Device* device_; 72 | std::unordered_map queue_map_; 73 | std::unordered_map> buffer_map_; 74 | std::unordered_map encoder_map_; 75 | std::unordered_map kernel_map_; 76 | std::unordered_map library_map_; 77 | std::mutex mtx_; 78 | }; 79 | 80 | Device& device(mlx::core::Device); 81 | NS::AutoreleasePool*& thread_autorelease_pool(); 82 | 83 | } // namespace mlx::core::metal 84 | -------------------------------------------------------------------------------- /python/src/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "mlx/array.h" 12 | 13 | namespace py = pybind11; 14 | 15 | using namespace mlx::core; 16 | 17 | using IntOrVec = std::variant>; 18 | using ScalarOrArray = 19 | std::variant, array>; 20 | static constexpr std::monostate none{}; 21 | 22 | inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { 23 | std::vector axes; 24 | if (std::holds_alternative(v)) { 25 | axes.resize(dims); 26 | std::iota(axes.begin(), axes.end(), 0); 27 | } else if (auto pv = std::get_if(&v); pv) { 28 | axes.push_back(*pv); 29 | } else { 30 | axes = std::get>(v); 31 | } 32 | return axes; 33 | } 34 | 35 | inline array to_array( 36 | const ScalarOrArray& v, 37 | std::optional dtype = std::nullopt) { 38 | if (auto pv = std::get_if(&v); pv) { 39 | return array(py::cast(*pv), dtype.value_or(bool_)); 40 | } else if (auto pv = std::get_if(&v); pv) { 41 | auto out_t = dtype.value_or(int32); 42 | // bool_ is an exception and is always promoted 43 | return array(py::cast(*pv), (out_t == bool_) ? 
int32 : out_t); 44 | } else if (auto pv = std::get_if(&v); pv) { 45 | auto out_t = dtype.value_or(float32); 46 | return array( 47 | py::cast(*pv), is_floating_point(out_t) ? out_t : float32); 48 | } else if (auto pv = std::get_if>(&v); pv) { 49 | return array(static_cast(*pv), complex64); 50 | } else { 51 | return std::get(v); 52 | } 53 | } 54 | 55 | inline std::pair to_arrays( 56 | const ScalarOrArray& a, 57 | const ScalarOrArray& b) { 58 | // Four cases: 59 | // - If both a and b are arrays leave their types alone 60 | // - If a is an array but b is not, treat b as a weak python type 61 | // - If b is an array but a is not, treat a as a weak python type 62 | // - If neither is an array convert to arrays but leave their types alone 63 | if (auto pa = std::get_if(&a); pa) { 64 | if (auto pb = std::get_if(&b); pb) { 65 | return {*pa, *pb}; 66 | } 67 | return {*pa, to_array(b, pa->dtype())}; 68 | } else if (auto pb = std::get_if(&b); pb) { 69 | return {to_array(a, pb->dtype()), *pb}; 70 | } else { 71 | return {to_array(a), to_array(b)}; 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /mlx/backend/common/fft.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/3rdparty/pocketfft.h" 6 | #include "mlx/allocator.h" 7 | #include "mlx/primitives.h" 8 | 9 | namespace mlx::core { 10 | 11 | void FFT::eval(const std::vector& inputs, array& out) { 12 | auto& in = inputs[0]; 13 | std::vector strides_in( 14 | in.strides().begin(), in.strides().end()); 15 | for (auto& s : strides_in) { 16 | s *= in.itemsize(); 17 | } 18 | std::vector strides_out( 19 | out.strides().begin(), out.strides().end()); 20 | for (auto& s : strides_out) { 21 | s *= out.itemsize(); 22 | } 23 | 24 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 25 | 26 | std::vector shape; 27 | if (out.dtype() == float32) { 28 | shape.insert(shape.end(), out.shape().begin(), out.shape().end()); 29 | } else { 30 | shape.insert(shape.end(), in.shape().begin(), in.shape().end()); 31 | } 32 | 33 | float scale = 1.0f; 34 | if (inverse_) { 35 | size_t nelem = std::accumulate( 36 | axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) { 37 | return x * shape[y]; 38 | }); 39 | scale /= nelem; 40 | } 41 | if (in.dtype() == complex64 && out.dtype() == complex64) { 42 | auto in_ptr = 43 | reinterpret_cast*>(in.data()); 44 | auto out_ptr = 45 | reinterpret_cast*>(out.data()); 46 | pocketfft::c2c( 47 | shape, 48 | strides_in, 49 | strides_out, 50 | axes_, 51 | !inverse_, 52 | in_ptr, 53 | out_ptr, 54 | scale); 55 | } else if (in.dtype() == float32 && out.dtype() == complex64) { 56 | auto in_ptr = in.data(); 57 | auto out_ptr = 58 | reinterpret_cast*>(out.data()); 59 | pocketfft::r2c( 60 | shape, 61 | strides_in, 62 | strides_out, 63 | axes_, 64 | !inverse_, 65 | in_ptr, 66 | out_ptr, 67 | scale); 68 | } else if (in.dtype() == complex64 && out.dtype() == float32) { 69 | auto in_ptr = 70 | reinterpret_cast*>(in.data()); 71 | auto out_ptr = out.data(); 72 | pocketfft::c2r( 73 | shape, 74 | strides_in, 75 | strides_out, 76 | axes_, 77 | !inverse_, 78 | in_ptr, 79 | out_ptr, 80 | scale); 81 | } else { 82 | throw std::runtime_error( 83 | "[FFT] 
Received unexpected input and output type combination."); 84 | } 85 | } 86 | 87 | } // namespace mlx::core 88 | -------------------------------------------------------------------------------- /benchmarks/python/single_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import argparse 4 | import mlx.core as mx 5 | 6 | from time_utils import time_fn 7 | 8 | 9 | def time_add(): 10 | a = mx.random.uniform(shape=(32, 1024, 1024)) 11 | b = mx.random.uniform(shape=(32, 1024, 1024)) 12 | mx.eval(a, b) 13 | time_fn(mx.add, a, b) 14 | 15 | aT = mx.transpose(a, [0, 2, 1]) 16 | mx.eval(aT) 17 | 18 | def transpose_add(a, b): 19 | return mx.add(a, b) 20 | 21 | time_fn(transpose_add, aT, b) 22 | 23 | b = mx.random.uniform(shape=(1024,)) 24 | mx.eval(b) 25 | 26 | def slice_add(a, b): 27 | return mx.add(a, b) 28 | 29 | time_fn(slice_add, a, b) 30 | 31 | b = mx.reshape(b, (1, 1024, 1)) 32 | mx.eval(b) 33 | 34 | def mid_slice_add(a, b): 35 | return mx.add(a, b) 36 | 37 | time_fn(mid_slice_add, a, b) 38 | 39 | 40 | def time_matmul(): 41 | a = mx.random.uniform(shape=(1024, 1024)) 42 | b = mx.random.uniform(shape=(1024, 1024)) 43 | mx.eval(a, b) 44 | time_fn(mx.matmul, a, b) 45 | 46 | 47 | def time_negative(): 48 | a = mx.random.uniform(shape=(10000, 1000)) 49 | mx.eval(a) 50 | 51 | def negative(a): 52 | return -a 53 | 54 | mx.eval(a) 55 | 56 | time_fn(negative, a) 57 | 58 | 59 | def time_exp(): 60 | a = mx.random.uniform(shape=(1000, 100)) 61 | mx.eval(a) 62 | time_fn(mx.exp, a) 63 | 64 | 65 | def time_logsumexp(): 66 | a = mx.random.uniform(shape=(64, 10, 10000)) 67 | mx.eval(a) 68 | time_fn(mx.logsumexp, a, axis=-1) 69 | 70 | 71 | def time_take(): 72 | a = mx.random.uniform(shape=(10000, 500)) 73 | ids = mx.random.randint(low=0, high=10000, shape=(20, 10)) 74 | ids = [mx.reshape(idx, (-1,)) for idx in ids] 75 | mx.eval(ids) 76 | 77 | def random_take(): 78 | return [mx.take(a, idx, 0) for idx in ids] 79 
| 80 | time_fn(random_take) 81 | 82 | 83 | def time_reshape_transposed(): 84 | x = mx.random.uniform(shape=(256, 256, 128)) 85 | mx.eval(x) 86 | 87 | def reshape_transposed(): 88 | return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,)) 89 | 90 | time_fn(reshape_transposed) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser("MLX benchmarks.") 95 | parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") 96 | args = parser.parse_args() 97 | if args.gpu: 98 | mx.set_default_device(mx.gpu) 99 | else: 100 | mx.set_default_device(mx.cpu) 101 | 102 | time_add() 103 | time_matmul() 104 | time_exp() 105 | time_negative() 106 | time_logsumexp() 107 | time_take() 108 | time_reshape_transposed() 109 | -------------------------------------------------------------------------------- /examples/extensions/axpby/axpby.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/kernels/bf16.h" 6 | #include "mlx/backend/metal/kernels/utils.h" 7 | 8 | template 9 | [[kernel]] void axpby_general( 10 | device const T* x [[buffer(0)]], 11 | device const T* y [[buffer(1)]], 12 | device T* out [[buffer(2)]], 13 | constant const float& alpha [[buffer(3)]], 14 | constant const float& beta [[buffer(4)]], 15 | constant const int* shape [[buffer(5)]], 16 | constant const size_t* x_strides [[buffer(6)]], 17 | constant const size_t* y_strides [[buffer(7)]], 18 | constant const int& ndim [[buffer(8)]], 19 | uint index [[thread_position_in_grid]]) { 20 | auto x_offset = elem_to_loc(index, shape, x_strides, ndim); 21 | auto y_offset = elem_to_loc(index, shape, y_strides, ndim); 22 | out[index] = 23 | static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; 24 | } 25 | 26 | template 27 | [[kernel]] void axpby_contiguous( 28 | device const T* x [[buffer(0)]], 29 | device const T* y [[buffer(1)]], 30 | device T* out 
[[buffer(2)]], 31 | constant const float& alpha [[buffer(3)]], 32 | constant const float& beta [[buffer(4)]], 33 | uint index [[thread_position_in_grid]]) { 34 | out[index] = 35 | static_cast(alpha) * x[index] + static_cast(beta) * y[index]; 36 | } 37 | 38 | #define instantiate_axpby(type_name, type) \ 39 | template [[host_name("axpby_general_" #type_name)]] \ 40 | [[kernel]] void axpby_general( \ 41 | device const type* x [[buffer(0)]], \ 42 | device const type* y [[buffer(1)]], \ 43 | device type* out [[buffer(2)]], \ 44 | constant const float& alpha [[buffer(3)]], \ 45 | constant const float& beta [[buffer(4)]], \ 46 | constant const int* shape [[buffer(5)]], \ 47 | constant const size_t* x_strides [[buffer(6)]], \ 48 | constant const size_t* y_strides [[buffer(7)]], \ 49 | constant const int& ndim [[buffer(8)]], \ 50 | uint index [[thread_position_in_grid]]); \ 51 | template [[host_name("axpby_contiguous_" #type_name)]] \ 52 | [[kernel]] void axpby_contiguous( \ 53 | device const type* x [[buffer(0)]], \ 54 | device const type* y [[buffer(1)]], \ 55 | device type* out [[buffer(2)]], \ 56 | constant const float& alpha [[buffer(3)]], \ 57 | constant const float& beta [[buffer(4)]], \ 58 | uint index [[thread_position_in_grid]]); 59 | 60 | instantiate_axpby(float32, float); 61 | instantiate_axpby(float16, half); 62 | instantiate_axpby(bfloat16, bfloat16_t); 63 | instantiate_axpby(complex64, complex64_t); -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -1 3 | AlignAfterOpenBracket: AlwaysBreak 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlinesLeft: true 7 | AlignOperands: false 8 | AlignTrailingComments: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | 
AllowShortFunctionsOnASingleLine: Empty 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterReturnType: None 16 | AlwaysBreakBeforeMultilineStrings: true 17 | AlwaysBreakTemplateDeclarations: true 18 | BinPackArguments: false 19 | BinPackParameters: false 20 | BraceWrapping: 21 | AfterClass: false 22 | AfterControlStatement: false 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterObjCDeclaration: false 27 | AfterStruct: false 28 | AfterUnion: false 29 | BeforeCatch: false 30 | BeforeElse: false 31 | IndentBraces: false 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeBraces: Attach 34 | BreakBeforeTernaryOperators: true 35 | BreakConstructorInitializersBeforeComma: false 36 | BreakAfterJavaFieldAnnotations: false 37 | BreakStringLiterals: false 38 | ColumnLimit: 80 39 | CommentPragmas: '^ IWYU pragma:' 40 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 41 | ConstructorInitializerIndentWidth: 4 42 | ContinuationIndentWidth: 4 43 | Cpp11BracedListStyle: true 44 | DerivePointerAlignment: false 45 | DisableFormat: false 46 | ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] 47 | IncludeCategories: 48 | - Regex: '^<.*\.h(pp)?>' 49 | Priority: 1 50 | - Regex: '^<.*' 51 | Priority: 2 52 | - Regex: '.*' 53 | Priority: 3 54 | IndentCaseLabels: true 55 | IndentWidth: 2 56 | IndentWrappedFunctionNames: false 57 | KeepEmptyLinesAtTheStartOfBlocks: false 58 | MacroBlockBegin: '' 59 | MacroBlockEnd: '' 60 | MaxEmptyLinesToKeep: 1 61 | NamespaceIndentation: None 62 | ObjCBlockIndentWidth: 2 63 | ObjCSpaceAfterProperty: false 64 | ObjCSpaceBeforeProtocolList: false 65 | PenaltyBreakBeforeFirstCallParameter: 1 66 | PenaltyBreakComment: 300 67 | PenaltyBreakFirstLessLess: 120 68 | PenaltyBreakString: 1000 69 | PenaltyExcessCharacter: 1000000 70 | PenaltyReturnTypeOnItsOwnLine: 200 71 | PointerAlignment: Left 72 | ReflowComments: true 73 | SortIncludes: true 74 | 
SpaceAfterCStyleCast: false 75 | SpaceBeforeAssignmentOperators: true 76 | SpaceBeforeParens: ControlStatements 77 | SpaceInEmptyParentheses: false 78 | SpacesBeforeTrailingComments: 1 79 | SpacesInAngles: false 80 | SpacesInContainerLiterals: true 81 | SpacesInCStyleCastParentheses: false 82 | SpacesInParentheses: false 83 | SpacesInSquareBrackets: false 84 | Standard: Cpp11 85 | TabWidth: 8 86 | UseTab: Never 87 | ... 88 | -------------------------------------------------------------------------------- /examples/extensions/axpby/axpby.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include "mlx/ops.h" 6 | #include "mlx/primitives.h" 7 | 8 | namespace mlx::core { 9 | 10 | /////////////////////////////////////////////////////////////////////////////// 11 | // Operation 12 | /////////////////////////////////////////////////////////////////////////////// 13 | 14 | /** 15 | * Scale and sum two vectors elementwise 16 | * z = alpha * x + beta * y 17 | * 18 | * Follow numpy style broadcasting between x and y 19 | * Inputs are upcasted to floats if needed 20 | **/ 21 | array axpby( 22 | const array& x, // Input array x 23 | const array& y, // Input array y 24 | const float alpha, // Scaling factor for x 25 | const float beta, // Scaling factor for y 26 | StreamOrDevice s = {} // Stream on which to schedule the operation 27 | ); 28 | 29 | /////////////////////////////////////////////////////////////////////////////// 30 | // Primitive 31 | /////////////////////////////////////////////////////////////////////////////// 32 | 33 | class Axpby : public Primitive { 34 | public: 35 | explicit Axpby(Stream stream, float alpha, float beta) 36 | : Primitive(stream), alpha_(alpha), beta_(beta){}; 37 | 38 | /** 39 | * A primitive must know how to evaluate itself on the CPU/GPU 40 | * for the given inputs and populate the output array. 
41 | * 42 | * To avoid unnecessary allocations, the evaluation function 43 | * is responsible for allocating space for the array. 44 | */ 45 | void eval_cpu(const std::vector& inputs, array& out) override; 46 | void eval_gpu(const std::vector& inputs, array& out) override; 47 | 48 | /** The Jacobian-vector product. */ 49 | array jvp( 50 | const std::vector& primals, 51 | const std::vector& tangents, 52 | const std::vector& argnums) override; 53 | 54 | /** The vector-Jacobian product. */ 55 | std::vector vjp( 56 | const std::vector& primals, 57 | const array& cotan, 58 | const std::vector& argnums) override; 59 | 60 | /** 61 | * The primitive must know how to vectorize itself across 62 | * the given axes. The output is a pair containing the array 63 | * representing the vectorized computation and the axis which 64 | * corresponds to the output vectorized dimension. 65 | */ 66 | std::pair vmap( 67 | const std::vector& inputs, 68 | const std::vector& axes) override; 69 | 70 | /** Print the primitive. */ 71 | void print(std::ostream& os) override { 72 | os << "Axpby"; 73 | } 74 | 75 | /** Equivalence check **/ 76 | bool is_equivalent(const Primitive& other) const override; 77 | 78 | private: 79 | float alpha_; 80 | float beta_; 81 | 82 | /** Fall back implementation for evaluation on CPU */ 83 | void eval(const std::vector& inputs, array& out); 84 | }; 85 | 86 | } // namespace mlx::core -------------------------------------------------------------------------------- /mlx/backend/common/softmax.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/backend/common/copy.h" 7 | #include "mlx/primitives.h" 8 | 9 | namespace mlx::core { 10 | 11 | namespace { 12 | 13 | template 14 | void softmax(const array& in, array& out) { 15 | const T* in_ptr = in.data(); 16 | T* out_ptr = out.data(); 17 | int N = in.shape().back(); 18 | int M = in.data_size() / N; 19 | const T* current_in_ptr; 20 | T* current_out_ptr; 21 | 22 | for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) { 23 | // Find the maximum 24 | current_in_ptr = in_ptr; 25 | T maximum = *current_in_ptr; 26 | for (int j = 0; j < N; j++, current_in_ptr++) { 27 | maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum; 28 | } 29 | 30 | // Compute the normalizer and the exponentials 31 | T normalizer = 0; 32 | current_out_ptr = out_ptr; 33 | current_in_ptr = in_ptr; 34 | for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) { 35 | T expv = std::exp(*current_in_ptr - maximum); 36 | normalizer += expv; 37 | *current_out_ptr = expv; 38 | } 39 | normalizer = 1 / normalizer; 40 | 41 | // Normalize 42 | current_out_ptr = out_ptr; 43 | for (int j = 0; j < N; j++, current_out_ptr++) { 44 | *current_out_ptr *= normalizer; 45 | } 46 | } 47 | } 48 | 49 | } // namespace 50 | 51 | void Softmax::eval(const std::vector& inputs, array& out) { 52 | assert(inputs.size() == 1); 53 | 54 | // Make sure that the last dimension is contiguous 55 | auto check_input = [](array x) { 56 | if (x.strides().back() == 1) { 57 | return x; 58 | } else { 59 | array x_copy(x.shape(), x.dtype(), nullptr, {}); 60 | copy(x, x_copy, CopyType::General); 61 | return x_copy; 62 | } 63 | }; 64 | array in = check_input(std::move(inputs[0])); 65 | out.set_data( 66 | allocator::malloc_or_wait(in.data_size() * in.itemsize()), 67 | in.data_size(), 68 | in.strides(), 69 | in.flags()); 70 | 71 | switch (in.dtype()) { 72 | case bool_: 73 | case uint8: 74 | case uint16: 75 | case uint32: 76 | case uint64: 77 | case int8: 78 | case int16: 79 | 
case int32: 80 | case int64: 81 | throw std::invalid_argument( 82 | "Softmax is defined only for floating point types"); 83 | break; 84 | case float32: 85 | softmax(in, out); 86 | break; 87 | case float16: 88 | softmax(in, out); 89 | break; 90 | case bfloat16: 91 | softmax(in, out); 92 | break; 93 | case complex64: 94 | throw std::invalid_argument( 95 | "[Softmax] Not yet implemented for complex64"); 96 | break; 97 | } 98 | } 99 | 100 | } // namespace mlx::core 101 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/erf.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | /* 8 | * Approximation to the error function. 9 | * Based on code from: 10 | * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 11 | */ 12 | float erf(float a) { 13 | float r, s, t, u; 14 | t = metal::abs(a); 15 | s = a * a; 16 | if (t > 0.927734375f) { 17 | // maximum error 0.99527 ulp 18 | r = metal::fma( 19 | -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 20 | u = metal::fma( 21 | -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 22 | r = metal::fma(r, s, u); 23 | r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 24 | r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 25 | r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 26 | r = metal::fma(r, t, -t); 27 | // TODO, replace with expm1 when implemented 28 | r = 1.0f - metal::exp(r); 29 | r = metal::copysign(r, a); 30 | } else { 31 | // maximum error 0.98929 ulp 32 | r = -5.96761703e-4f; // -0x1.38e000p-11 33 | r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 34 | r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 35 | r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 36 | r = metal::fma(r, s, -3.76125336e-1f); // 
-0x1.812700p-2 37 | r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 38 | r = metal::fma(r, a, a); 39 | } 40 | return r; 41 | } 42 | 43 | float erfinv(float a) { 44 | auto t = metal::fma(a, 0.0f - a, 1.0f); 45 | t = metal::log(t); 46 | float p; 47 | if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793 48 | p = 3.03697567e-10f; // 0x1.4deb44p-32 49 | p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 50 | p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 51 | p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 52 | p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 53 | p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 54 | p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 55 | p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 56 | p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 57 | } else { // maximum ulp error = 2.35002 58 | p = 5.43877832e-9f; // 0x1.75c000p-28 59 | p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 60 | p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 61 | p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 62 | p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 63 | p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 64 | p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 65 | p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 66 | p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 67 | p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 68 | } 69 | return a * p; 70 | } -------------------------------------------------------------------------------- /mlx/dtype.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "mlx/types/complex.h" 11 | #include "mlx/types/half_types.h" 12 | 13 | namespace mlx::core { 14 | 15 | struct Dtype { 16 | enum class Val { 17 | bool_, 18 | uint8, 19 | uint16, 20 | uint32, 21 | uint64, 22 | int8, 23 | int16, 24 | int32, 25 | int64, 26 | float16, 27 | float32, 28 | bfloat16, 29 | complex64, 30 | }; 31 | 32 | enum class Kind { 33 | b, /* bool */ 34 | u, /* unsigned int */ 35 | i, /* signed int */ 36 | f, /* float */ 37 | c, /* complex */ 38 | V, /* void - used for brain float */ 39 | }; 40 | 41 | Val val; 42 | const uint8_t size; 43 | constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){}; 44 | constexpr operator Val() const { 45 | return val; 46 | }; 47 | }; 48 | 49 | inline bool is_available(const Dtype& dtype) { 50 | return true; 51 | } 52 | 53 | static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; 54 | 55 | static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)}; 56 | static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)}; 57 | static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)}; 58 | static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)}; 59 | 60 | static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)}; 61 | static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)}; 62 | static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)}; 63 | static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)}; 64 | 65 | static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)}; 66 | static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; 67 | static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; 68 | static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; 69 | 70 | Dtype promote_types(const Dtype& t1, const Dtype& t2); 71 | 72 | inline uint8_t size_of(const Dtype& t) { 73 | return t.size; 74 | } 75 | 
76 | Dtype::Kind kindof(const Dtype& t); 77 | 78 | inline bool is_unsigned(const Dtype& t) { 79 | return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b; 80 | } 81 | 82 | inline bool is_floating_point(const Dtype& t) { 83 | return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V || 84 | kindof(t) == Dtype::Kind::c; 85 | } 86 | 87 | inline bool is_integral(const Dtype& t) { 88 | return !(is_floating_point(t)); 89 | } 90 | 91 | template 92 | struct TypeToDtype { 93 | operator Dtype(); 94 | }; 95 | 96 | // Array protocol typestring for Dtype 97 | std::string dtype_to_array_protocol(const Dtype& t); 98 | // Dtype from array protocol type string 99 | Dtype dtype_from_array_protocol(const std::string& t); 100 | 101 | } // namespace mlx::core 102 | -------------------------------------------------------------------------------- /mlx/load.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace mlx::core { 10 | 11 | namespace io { 12 | 13 | class Reader { 14 | public: 15 | virtual bool is_open() const = 0; 16 | virtual bool good() const = 0; 17 | virtual size_t tell() const = 0; 18 | virtual void seek( 19 | int64_t off, 20 | std::ios_base::seekdir way = std::ios_base::beg) = 0; 21 | virtual void read(char* data, size_t n) = 0; 22 | virtual std::string label() const = 0; 23 | }; 24 | 25 | class Writer { 26 | public: 27 | virtual bool is_open() const = 0; 28 | virtual bool good() const = 0; 29 | virtual size_t tell() const = 0; 30 | virtual void seek( 31 | int64_t off, 32 | std::ios_base::seekdir way = std::ios_base::beg) = 0; 33 | virtual void write(const char* data, size_t n) = 0; 34 | virtual std::string label() const = 0; 35 | }; 36 | 37 | class FileReader : public Reader { 38 | public: 39 | explicit FileReader(const std::shared_ptr& is) 40 | : is_(is), label_("stream") {} 41 | explicit FileReader(const 
std::string& file_path) 42 | : is_(std::make_shared(file_path, std::ios::binary)), 43 | label_(file_path) {} 44 | 45 | bool is_open() const override { 46 | return is_->is_open(); 47 | } 48 | 49 | bool good() const override { 50 | return is_->good(); 51 | } 52 | 53 | size_t tell() const override { 54 | return is_->tellg(); 55 | } 56 | 57 | void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) 58 | override { 59 | is_->seekg(off, way); 60 | } 61 | 62 | void read(char* data, size_t n) override { 63 | is_->read(data, n); 64 | } 65 | 66 | std::string label() const override { 67 | return "file " + label_; 68 | } 69 | 70 | private: 71 | std::shared_ptr is_; 72 | std::string label_; 73 | }; 74 | 75 | class FileWriter : public Writer { 76 | public: 77 | explicit FileWriter(const std::shared_ptr& is) 78 | : os_(is), label_("stream") {} 79 | explicit FileWriter(const std::string& file_path) 80 | : os_(std::make_shared(file_path, std::ios::binary)), 81 | label_(file_path) {} 82 | 83 | bool is_open() const override { 84 | return os_->is_open(); 85 | } 86 | 87 | bool good() const override { 88 | return os_->good(); 89 | } 90 | 91 | size_t tell() const override { 92 | return os_->tellp(); 93 | } 94 | 95 | void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) 96 | override { 97 | os_->seekp(off, way); 98 | } 99 | 100 | void write(const char* data, size_t n) override { 101 | os_->write(data, n); 102 | } 103 | 104 | std::string label() const override { 105 | return "file " + label_; 106 | } 107 | 108 | private: 109 | std::shared_ptr os_; 110 | std::string label_; 111 | }; 112 | 113 | } // namespace io 114 | } // namespace mlx::core -------------------------------------------------------------------------------- /mlx/types/complex.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | #include 5 | #include "mlx/types/half_types.h" 6 | 7 | namespace mlx::core { 8 | 9 | struct complex64_t; 10 | 11 | template 12 | static constexpr bool can_convert_to_complex64 = 13 | !std::is_same_v && std::is_convertible_v; 14 | 15 | struct complex64_t : public std::complex { 16 | complex64_t(float v, float u) : std::complex(v, u){}; 17 | complex64_t(std::complex v) : std::complex(v){}; 18 | 19 | template < 20 | typename T, 21 | typename = typename std::enable_if>::type> 22 | complex64_t(T x) : std::complex(x){}; 23 | 24 | operator float() const { 25 | return real(); 26 | }; 27 | }; 28 | 29 | inline bool operator>=(const complex64_t& a, const complex64_t& b) { 30 | return (a.real() > b.real()) || 31 | (a.real() == b.real() && a.imag() >= b.imag()); 32 | } 33 | 34 | inline bool operator>(const complex64_t& a, const complex64_t& b) { 35 | return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); 36 | } 37 | 38 | inline bool operator<=(const complex64_t& a, const complex64_t& b) { 39 | return operator>=(b, a); 40 | } 41 | 42 | inline bool operator<(const complex64_t& a, const complex64_t& b) { 43 | return operator>(b, a); 44 | } 45 | 46 | inline complex64_t operator-(const complex64_t& v) { 47 | return -static_cast>(v); 48 | } 49 | 50 | // clang-format off 51 | #define complex_binop_helper(_op_, _operator_, itype) \ 52 | inline complex64_t _operator_(itype x, const complex64_t& y) { \ 53 | return x _op_ static_cast>(y); \ 54 | } \ 55 | inline complex64_t _operator_(const complex64_t& x, itype y) { \ 56 | return static_cast>(x) _op_ y; \ 57 | } 58 | 59 | #define complex_binop(_op_, _operator_) \ 60 | inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ 61 | return static_cast>(x) \ 62 | _op_ static_cast>(y); \ 63 | } \ 64 | complex_binop_helper(_op_, _operator_, bool) \ 65 | complex_binop_helper(_op_, _operator_, uint32_t) \ 66 | complex_binop_helper(_op_, _operator_, uint64_t) \ 67 | 
complex_binop_helper(_op_, _operator_, int32_t) \ 68 | complex_binop_helper(_op_, _operator_, int64_t) \ 69 | complex_binop_helper(_op_, _operator_, float16_t) \ 70 | complex_binop_helper(_op_, _operator_, bfloat16_t) \ 71 | complex_binop_helper(_op_, _operator_, const std::complex&) \ 72 | complex_binop_helper(_op_, _operator_, float) 73 | // clang-format on 74 | 75 | complex_binop(+, operator+) 76 | 77 | } // namespace mlx::core 78 | -------------------------------------------------------------------------------- /mlx/backend/metal/metal.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mlx/array.h" 8 | #include "mlx/backend/metal/device.h" 9 | #include "mlx/primitives.h" 10 | #include "mlx/scheduler.h" 11 | 12 | namespace mlx::core::metal { 13 | 14 | int max_ops_per_buffer() { 15 | auto get_val = []() { 16 | if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) { 17 | return atoi(buff_str); 18 | } else { 19 | return 10; 20 | } 21 | }; 22 | static int max_ops_per_buffer_ = get_val(); 23 | return max_ops_per_buffer_; 24 | } 25 | 26 | #define MAX_OPS_PER_BUFFER max_ops_per_buffer() 27 | 28 | MTL::CommandBuffer* increment_command_buffer(Stream s) { 29 | auto& d = metal::device(s.device); 30 | auto command_buffer = d.get_command_buffer(s.index); 31 | if (command_buffer == nullptr || 32 | d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) { 33 | if (command_buffer != nullptr) { 34 | d.end_encoding(s.index); 35 | scheduler::notify_new_task(s); 36 | command_buffer->addCompletedHandler( 37 | [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); }); 38 | d.commit_command_buffer(s.index); 39 | } 40 | command_buffer = d.new_command_buffer(s.index); 41 | } 42 | d.increment_command_buffer_ops(s.index); 43 | return command_buffer; 44 | } 45 | 46 | std::function make_task( 47 | array& arr, 48 | std::vector> deps, 49 | 
std::shared_ptr> p, 50 | bool retain_graph) { 51 | auto task = 52 | [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { 53 | for (auto& d : deps) { 54 | d.wait(); 55 | } 56 | auto s = arr.primitive().stream(); 57 | auto command_buffer = increment_command_buffer(s); 58 | arr.primitive().eval_gpu(arr.inputs(), arr); 59 | if (p) { 60 | metal::device(s.device).end_encoding(s.index); 61 | scheduler::notify_new_task(s); 62 | command_buffer->addCompletedHandler( 63 | [retain_graph, s, arr, p = std::move(p)]( 64 | MTL::CommandBuffer*) mutable { 65 | if (!retain_graph) { 66 | arr.detach(); 67 | } 68 | p->set_value(); 69 | // Signal this thread to clear the pool on a synchronization. 70 | scheduler::enqueue(s, []() { 71 | thread_autorelease_pool()->release(); 72 | thread_autorelease_pool() = 73 | NS::AutoreleasePool::alloc()->init(); 74 | }); 75 | scheduler::notify_task_completion(s); 76 | }); 77 | metal::device(s.device).commit_command_buffer(s.index); 78 | } else { 79 | command_buffer->addCompletedHandler( 80 | [retain_graph, s, arr](MTL::CommandBuffer*) mutable { 81 | if (!retain_graph) { 82 | arr.detach(); 83 | } 84 | }); 85 | } 86 | }; 87 | return task; 88 | } 89 | 90 | } // namespace mlx::core::metal 91 | -------------------------------------------------------------------------------- /examples/cpp/tutorial.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | #include 5 | 6 | #include "mlx/mlx.h" 7 | 8 | using namespace mlx::core; 9 | 10 | void array_basics() { 11 | // Make a scalar array: 12 | array x(1.0); 13 | 14 | // Get the value out of it: 15 | auto s = x.item(); 16 | assert(s == 1.0); 17 | 18 | // Scalars have a size of 1: 19 | size_t size = x.size(); 20 | assert(size == 1); 21 | 22 | // Scalars have 0 dimensions: 23 | int ndim = x.ndim(); 24 | assert(ndim == 0); 25 | 26 | // The shape should be an empty vector: 27 | auto shape = x.shape(); 28 | assert(shape.empty()); 29 | 30 | // The datatype should be float32: 31 | auto dtype = x.dtype(); 32 | assert(dtype == float32); 33 | 34 | // Specify the dtype when constructing the array: 35 | x = array(1, int32); 36 | assert(x.dtype() == int32); 37 | x.item(); // OK 38 | // x.item(); // Undefined! 39 | 40 | // Make a multidimensional array: 41 | x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); 42 | // mlx is row-major by default so the first row of this array 43 | // is [1.0, 2.0] and the second row is [3.0, 4.0] 44 | 45 | // Make an array of shape {2, 2} filled with ones: 46 | auto y = ones({2, 2}); 47 | 48 | // Pointwise add x and y: 49 | auto z = add(x, y); 50 | 51 | // Same thing: 52 | z = x + y; 53 | 54 | // mlx is lazy by default. At this point `z` only 55 | // has a shape and a type but no actual data: 56 | assert(z.dtype() == float32); 57 | assert(z.shape(0) == 2); 58 | assert(z.shape(1) == 2); 59 | 60 | // To actually run the computation you must evaluate `z`. 61 | // Under the hood, mlx records operations in a graph. 62 | // The variable `z` is a node in the graph which points to its operation 63 | // and inputs. When `eval` is called on an array (or arrays), the array and 64 | // all of its dependencies are recursively evaluated to produce the result. 65 | // Once an array is evaluated, it has data and is detached from its inputs. 66 | eval(z); 67 | 68 | // Of course the array can still be an input to other operations. 
You can even 69 | // call eval on the array again, this will just be a no-op: 70 | eval(z); // no-op 71 | 72 | // Some functions or methods on arrays implicitly evaluate them. For example 73 | // accessing a value in an array or printing the array implicitly evaluate it: 74 | z = ones({1}); 75 | z.item(); // implicit evaluation 76 | 77 | z = ones({2, 2}); 78 | std::cout << z << std::endl; // implicit evaluation 79 | } 80 | 81 | void automatic_differentiation() { 82 | auto fn = [](array x) { return square(x); }; 83 | 84 | // Computing the derivative function of a function 85 | auto grad_fn = grad(fn); 86 | // Call grad_fn on the input to get the derivative 87 | auto x = array(1.5); 88 | auto dfdx = grad_fn(x); 89 | // dfdx is 2 * x 90 | 91 | // Get the second derivative by composing grad with grad 92 | auto df2dx2 = grad(grad(fn))(x); 93 | // df2dx2 is 2 94 | } 95 | 96 | int main() { 97 | array_basics(); 98 | automatic_differentiation(); 99 | } 100 | -------------------------------------------------------------------------------- /mlx/backend/metal/softmax.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/copy.h" 6 | #include "mlx/backend/metal/device.h" 7 | #include "mlx/backend/metal/kernels/defines.h" 8 | #include "mlx/backend/metal/utils.h" 9 | #include "mlx/primitives.h" 10 | 11 | namespace mlx::core { 12 | 13 | void Softmax::eval_gpu(const std::vector& inputs, array& out) { 14 | assert(inputs.size() == 1); 15 | if (!is_floating_point(out.dtype())) { 16 | throw std::runtime_error( 17 | "[softmax] Does not support non-floating point types."); 18 | } 19 | auto& s = stream(); 20 | auto& d = metal::device(s.device); 21 | 22 | // Make sure that the last dimension is contiguous 23 | std::vector copies; 24 | auto check_input = [&copies, &s](const array& x) { 25 | if (x.strides()[x.ndim() - 1] == 1) { 26 | return x; 27 | } else { 28 | array x_copy(x.shape(), x.dtype(), nullptr, {}); 29 | copy_gpu(x, x_copy, CopyType::General, s); 30 | copies.push_back(x_copy); 31 | return x_copy; 32 | } 33 | }; 34 | const array& in = check_input(inputs[0]); 35 | out.set_data( 36 | allocator::malloc_or_wait(in.data_size() * in.itemsize()), 37 | in.data_size(), 38 | in.strides(), 39 | in.flags()); 40 | 41 | int axis_size = in.shape().back(); 42 | int n_rows = in.data_size() / axis_size; 43 | 44 | const int simd_size = 32; 45 | const int n_reads = SOFTMAX_N_READS; 46 | const int looped_limit = SOFTMAX_LOOPED_LIMIT; 47 | std::string op_name = "softmax_"; 48 | if (axis_size > looped_limit) { 49 | op_name += "looped_"; 50 | } 51 | op_name += type_to_name(out); 52 | auto compute_encoder = d.get_command_encoder(s.index); 53 | { 54 | auto kernel = d.get_kernel(op_name); 55 | 56 | MTL::Size grid_dims, group_dims; 57 | if (axis_size <= looped_limit) { 58 | size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads; 59 | size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size; 60 | size_t threadgroup_size = simd_size * simds_needed; 61 | assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup()); 62 | size_t 
n_threads = n_rows * threadgroup_size; 63 | grid_dims = MTL::Size(n_threads, 1, 1); 64 | group_dims = MTL::Size(threadgroup_size, 1, 1); 65 | } else { 66 | size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup(); 67 | size_t n_threads = n_rows * threadgroup_size; 68 | grid_dims = MTL::Size(n_threads, 1, 1); 69 | group_dims = MTL::Size(threadgroup_size, 1, 1); 70 | } 71 | 72 | compute_encoder->setComputePipelineState(kernel); 73 | set_array_buffer(compute_encoder, in, 0); 74 | set_array_buffer(compute_encoder, out, 1); 75 | compute_encoder->setBytes(&axis_size, sizeof(int), 2); 76 | compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0); 77 | compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1); 78 | compute_encoder->dispatchThreads(grid_dims, group_dims); 79 | } 80 | d.get_command_buffer(s.index)->addCompletedHandler( 81 | [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); 82 | } 83 | 84 | } // namespace mlx::core 85 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/random.metal: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "mlx/backend/metal/kernels/utils.h" 4 | 5 | static constexpr constant uint32_t rotations[2][4] = { 6 | {13, 15, 26, 6}, 7 | {17, 29, 16, 24} 8 | }; 9 | 10 | union rbits { 11 | uint2 val; 12 | uchar4 bytes[2]; 13 | }; 14 | 15 | rbits threefry2x32_hash(const thread uint2& key, uint2 count) { 16 | 17 | uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; 18 | 19 | rbits v; 20 | v.val.x = count.x + ks[0]; 21 | v.val.y = count.y + ks[1]; 22 | 23 | for (int i = 0; i < 5; ++i) { 24 | for (auto r : rotations[i % 2]) { 25 | v.val.x += v.val.y; 26 | v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); 27 | v.val.y ^= v.val.x; 28 | } 29 | v.val.x += ks[(i + 1) % 3]; 30 | v.val.y += ks[(i + 2) % 3] + i + 1; 31 | } 32 | 33 | return v; 34 | } 35 | 36 | [[kernel]] void rbitsc( 37 | device const uint32_t* keys, 38 | device char* out, 39 | device const bool& odd, 40 | device const uint& bytes_per_key, 41 | uint2 grid_dim [[threads_per_grid]], 42 | uint2 index [[thread_position_in_grid]]) { 43 | auto kidx = 2 * index.x; 44 | auto key = uint2(keys[kidx], keys[kidx + 1]); 45 | auto half_size = grid_dim.y - odd; 46 | out += index.x * bytes_per_key; 47 | bool drop_last = odd && (index.y == half_size); 48 | auto count = uint2(index.y, drop_last ? 
0 : index.y + grid_dim.y); 49 | auto bits = threefry2x32_hash(key, count); 50 | for (int i = 0; i < 4; ++i) { 51 | out[4 * count.x + i] = bits.bytes[0][i]; 52 | } 53 | if (!drop_last) { 54 | if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { 55 | int edge_bytes = (bytes_per_key % 4); 56 | for (int i = 0; i < edge_bytes; ++i) { 57 | out[4 * count.y + i] = bits.bytes[1][i]; 58 | } 59 | } else { 60 | for (int i = 0; i < 4; ++i) { 61 | out[4 * count.y + i] = bits.bytes[1][i]; 62 | } 63 | } 64 | } 65 | } 66 | 67 | [[kernel]] void rbits( 68 | device const uint32_t* keys, 69 | device char* out, 70 | device const bool& odd, 71 | device const uint& bytes_per_key, 72 | device const int& ndim, 73 | device const int* key_shape, 74 | device const size_t* key_strides, 75 | uint2 grid_dim [[threads_per_grid]], 76 | uint2 index [[thread_position_in_grid]]) { 77 | auto kidx = 2 * index.x; 78 | auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim); 79 | auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim); 80 | auto key = uint2(keys[k1_elem], keys[k2_elem]); 81 | auto half_size = grid_dim.y - odd; 82 | out += index.x * bytes_per_key; 83 | bool drop_last = odd && (index.y == half_size); 84 | auto count = uint2(index.y, drop_last ? 
0 : index.y + grid_dim.y); 85 | auto bits = threefry2x32_hash(key, count); 86 | for (int i = 0; i < 4; ++i) { 87 | out[4 * count.x + i] = bits.bytes[0][i]; 88 | } 89 | if (!drop_last) { 90 | if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { 91 | int edge_bytes = (bytes_per_key % 4); 92 | for (int i = 0; i < edge_bytes; ++i) { 93 | out[4 * count.y + i] = bits.bytes[1][i]; 94 | } 95 | } else { 96 | for (int i = 0; i < 4; ++i) { 97 | out[4 * count.y + i] = bits.bytes[1][i]; 98 | } 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /tests/blas_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | 5 | #include "doctest/doctest.h" 6 | 7 | #include "mlx/mlx.h" 8 | 9 | using namespace mlx::core; 10 | 11 | TEST_CASE("test matmul") { 12 | auto a = array(1); 13 | auto b = array({1.0}); 14 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 15 | 16 | a = array({1.0}); 17 | b = array({1.0}); 18 | auto out = matmul(a, b); 19 | CHECK_EQ(out.shape(), std::vector{}); 20 | CHECK_EQ(out.size(), 1); 21 | CHECK_EQ(out.dtype(), float32); 22 | CHECK_EQ(out.item(), 1.0f); 23 | 24 | a = ones({2, 4}); 25 | b = ones({2}); 26 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 27 | 28 | a = ones({2, 4}); 29 | b = ones({3, 2}); 30 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 31 | 32 | a = ones({2, 4}); 33 | b = ones({4, 3, 2}); 34 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 35 | 36 | a = ones({2}); 37 | b = ones({4, 2}); 38 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 39 | 40 | a = ones({2, 3}); 41 | b = ones({4, 2}); 42 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 43 | 44 | a = ones({2, 4, 3}); 45 | b = ones({4, 2}); 46 | CHECK_THROWS_AS(matmul(a, b), std::invalid_argument); 47 | 48 | a = ones({2, 4}); 49 | b = ones({4, 2}); 50 | out = matmul(a, b); 51 | CHECK(array_equal(out, 
full({2, 2}, 4.0f)).item()); 52 | 53 | a = ones({2, 4}, int32); 54 | b = ones({4, 2}, float32); 55 | out = matmul(a, b); 56 | CHECK(array_equal(out, full({2, 2}, 4.0f)).item()); 57 | 58 | // Check single dimensions 59 | a = ones({4}); 60 | b = ones({4, 2}); 61 | out = matmul(a, b); 62 | CHECK(array_equal(out, full({2}, 4.0f)).item()); 63 | 64 | a = ones({2, 4}); 65 | b = ones({4}); 66 | out = matmul(a, b); 67 | CHECK(array_equal(out, full({2}, 4.0f)).item()); 68 | 69 | a = ones({4}); 70 | b = ones({4}); 71 | out = matmul(a, b); 72 | CHECK(array_equal(out, full({}, 4.0f)).item()); 73 | 74 | // Test transposed arrays 75 | a = array({1.0f, 1.0f, 1.0f, 1.0f}, {1, 4}); 76 | b = array({1.0f, 1.0f, 1.0f, 1.0f}, {4, 1}); 77 | out = matmul(transpose(a), transpose(b)); 78 | CHECK(array_equal(out, ones({4, 4})).item()); 79 | 80 | a = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); 81 | b = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2}); 82 | out = matmul(transpose(a), b); 83 | CHECK( 84 | array_equal(out, array({4.0f, 8.0f, 6.0f, 12.0f}, {2, 2})).item()); 85 | 86 | out = matmul(a, transpose(b)); 87 | CHECK( 88 | array_equal(out, array({5.0f, 5.0f, 11.0f, 11.0f}, {2, 2})).item()); 89 | 90 | out = matmul(transpose(a), transpose(b)); 91 | CHECK( 92 | array_equal(out, array({7.0f, 7.0f, 10.0f, 10.0f}, {2, 2})).item()); 93 | 94 | // Test broadcasting for both arrays 95 | a = ones({5, 4, 2}); 96 | b = ones({2, 3}); 97 | out = matmul(a, b); 98 | CHECK(array_equal(out, full({5, 4, 3}, 2.0f)).item()); 99 | 100 | a = ones({5, 1, 4, 2}); 101 | b = ones({1, 7, 2, 3}); 102 | out = matmul(a, b); 103 | CHECK(array_equal(out, full({5, 7, 4, 3}, 2.0f)).item()); 104 | 105 | // Test batched matmul with transpose 106 | a = ones({2, 2, 4}); 107 | b = ones({2, 4, 2}); 108 | out = matmul(transpose(a, {0, 2, 1}), transpose(b, {0, 2, 1})); 109 | CHECK(array_equal(out, full({2, 4, 4}, 2.0f)).item()); 110 | } 111 | -------------------------------------------------------------------------------- 
/python/tests/test_fft.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import unittest 4 | 5 | import itertools 6 | import mlx.core as mx 7 | import numpy as np 8 | 9 | import mlx_tests 10 | 11 | 12 | class TestFFT(mlx_tests.MLXTestCase): 13 | def check_mx_np(self, op, a_np, axes, s): 14 | with self.subTest(op=op, axes=axes, s=s): 15 | op_np = getattr(np.fft, op) 16 | op_mx = getattr(mx.fft, op) 17 | out_np = op_np(a_np, s=s, axes=axes) 18 | a_mx = mx.array(a_np) 19 | out_mx = op_mx(a_mx, s=s, axes=axes) 20 | self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) 21 | 22 | def test_fft(self): 23 | default = mx.default_device() 24 | mx.set_default_device(mx.cpu) 25 | 26 | def check_mx_np(op_mx, op_np, a_np, **kwargs): 27 | out_np = op_np(a_np, **kwargs) 28 | a_mx = mx.array(a_np) 29 | out_mx = op_mx(a_mx, **kwargs) 30 | self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) 31 | 32 | r = np.random.rand(100).astype(np.float32) 33 | i = np.random.rand(100).astype(np.float32) 34 | a_np = r + 1j * i 35 | check_mx_np(mx.fft.fft, np.fft.fft, a_np) 36 | 37 | # Check with slicing and padding 38 | r = np.random.rand(100).astype(np.float32) 39 | i = np.random.rand(100).astype(np.float32) 40 | a_np = r + 1j * i 41 | check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) 42 | check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) 43 | 44 | # Check different axes 45 | r = np.random.rand(100, 100).astype(np.float32) 46 | i = np.random.rand(100, 100).astype(np.float32) 47 | a_np = r + 1j * i 48 | check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) 49 | check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) 50 | 51 | # Check real fft 52 | a_np = np.random.rand(100).astype(np.float32) 53 | check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) 54 | check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) 55 | check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) 56 | 57 | # Check real inverse 58 | r = np.random.rand(100, 
100).astype(np.float32) 59 | i = np.random.rand(100, 100).astype(np.float32) 60 | a_np = r + 1j * i 61 | check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) 62 | check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) 63 | check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) 64 | check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) 65 | check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) 66 | check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) 67 | 68 | mx.set_default_device(default) 69 | 70 | def test_fftn(self): 71 | default = mx.default_device() 72 | mx.set_default_device(mx.cpu) 73 | 74 | r = np.random.randn(8, 8, 8).astype(np.float32) 75 | i = np.random.randn(8, 8, 8).astype(np.float32) 76 | a = r + 1j * i 77 | 78 | axes = [None, (1, 2), (2, 1), (0, 2)] 79 | shapes = [None, (10, 5), (5, 10)] 80 | ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] 81 | 82 | for op, ax, s in itertools.product(ops, axes, shapes): 83 | x = a 84 | if op in ["rfft2", "rfftn"]: 85 | x = r 86 | self.check_mx_np(op, x, axes=ax, s=s) 87 | 88 | mx.set_default_device(default) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLX 2 | 3 | [**Quickstart**](#quickstart) | [**Installation**](#installation) | 4 | [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) | 5 | [**Examples**](#examples) 6 | 7 | MLX is an array framework for machine learning on Apple silicon, brought to you 8 | by Apple machine learning research. 9 | 10 | Some key features of MLX include: 11 | 12 | - **Familiar APIs**: MLX has a Python API that closely follows NumPy. 13 | MLX also has a fully featured C++ API, which closely mirrors the Python API. 
14 | MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs 15 | that closely follow PyTorch to simplify building more complex models. 16 | 17 | - **Composable function transformations**: MLX has composable function 18 | transformations for automatic differentiation, automatic vectorization, 19 | and computation graph optimization. 20 | 21 | - **Lazy computation**: Computations in MLX are lazy. Arrays are only 22 | materialized when needed. 23 | 24 | - **Dynamic graph construction**: Computation graphs in MLX are built 25 | dynamically. Changing the shapes of function arguments does not trigger 26 | slow compilations, and debugging is simple and intuitive. 27 | 28 | - **Multi-device**: Operations can run on any of the supported devices 29 | (currently, the CPU and GPU). 30 | 31 | - **Unified memory**: A notable difference between MLX and other frameworks 32 | is the *unified memory model*. Arrays in MLX live in shared memory. 33 | Operations on MLX arrays can be performed on any of the supported 34 | device types without moving data. 35 | 36 | MLX is designed by machine learning researchers for machine learning 37 | researchers. The framework is intended to be user-friendly, but still efficient 38 | to train and deploy models. The design of the framework itself is also 39 | conceptually simple. We intend to make it easy for researchers to extend and 40 | improve MLX with the goal of quickly exploring new ideas. 41 | 42 | The design of MLX is inspired by frameworks like 43 | [NumPy](https://numpy.org/doc/stable/index.html), 44 | [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and 45 | [ArrayFire](https://arrayfire.org/). 46 | 47 | ## Examples 48 | 49 | The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a 50 | variety of examples, including: 51 | 52 | - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training. 
53 | - Large-scale text generation with 54 | [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and 55 | finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora). 56 | - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion). 57 | - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper). 58 | 59 | ## Quickstart 60 | 61 | See the [quick start 62 | guide](https://ml-explore.github.io/mlx/build/html/quick_start.html) 63 | in the documentation. 64 | 65 | ## Installation 66 | 67 | MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run: 68 | 69 | ``` 70 | pip install mlx 71 | ``` 72 | 73 | Check out the 74 | [documentation](https://ml-explore.github.io/mlx/build/html/install.html#) 75 | for more information on building the C++ and Python APIs from source. 76 | 77 | ## Contributing 78 | 79 | Check out the [contribution guidelines](CONTRIBUTING.md) for more information 80 | on contributing to MLX. 81 | -------------------------------------------------------------------------------- /mlx/backend/common/arg_reduce.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/primitives.h" 6 | #include "utils.h" 7 | 8 | namespace mlx::core { 9 | 10 | namespace { 11 | 12 | template 13 | void arg_reduce(const array& in, array& out, const OpT& op, int axis) { 14 | auto axis_size = in.shape()[axis]; 15 | auto axis_stride = in.strides()[axis]; 16 | std::vector strides = in.strides(); 17 | std::vector shape = in.shape(); 18 | strides.erase(strides.begin() + axis); 19 | shape.erase(shape.begin() + axis); 20 | for (uint32_t i = 0; i < out.size(); ++i) { 21 | auto loc = elem_to_loc(i, shape, strides); 22 | auto in_ptr = in.data() + loc; 23 | uint32_t ind_v = 0; 24 | InT v = (*in_ptr); 25 | for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) { 26 | op(j, (*in_ptr), &ind_v, &v); 27 | } 28 | out.data()[i] = ind_v; 29 | } 30 | } 31 | 32 | template 33 | void arg_reduce_dispatch( 34 | const array& in, 35 | array& out, 36 | ArgReduce::ReduceType rtype, 37 | int axis) { 38 | switch (rtype) { 39 | case ArgReduce::ArgMin: { 40 | auto op = [](auto ind_x, auto x, auto ind_y, auto y) { 41 | if (x < (*y)) { 42 | (*y) = x; 43 | (*ind_y) = ind_x; 44 | } 45 | }; 46 | arg_reduce(in, out, op, axis); 47 | break; 48 | } 49 | case ArgReduce::ArgMax: { 50 | auto op = [](auto ind_x, auto x, auto ind_y, auto y) { 51 | if (x > (*y)) { 52 | (*y) = x; 53 | (*ind_y) = ind_x; 54 | } 55 | }; 56 | arg_reduce(in, out, op, axis); 57 | break; 58 | } 59 | } 60 | } 61 | 62 | } // namespace 63 | 64 | void ArgReduce::eval(const std::vector& inputs, array& out) { 65 | assert(inputs.size() == 1); 66 | auto& in = inputs[0]; 67 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 68 | 69 | switch (in.dtype()) { 70 | case bool_: 71 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 72 | break; 73 | case uint8: 74 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 75 | break; 76 | case uint16: 77 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 78 | break; 79 | case uint32: 80 | arg_reduce_dispatch(in, out, reduce_type_, 
axis_); 81 | break; 82 | case uint64: 83 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 84 | break; 85 | case int8: 86 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 87 | break; 88 | case int16: 89 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 90 | break; 91 | case int32: 92 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 93 | break; 94 | case int64: 95 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 96 | break; 97 | case float16: 98 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 99 | break; 100 | case float32: 101 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 102 | break; 103 | case bfloat16: 104 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 105 | break; 106 | case complex64: 107 | arg_reduce_dispatch(in, out, reduce_type_, axis_); 108 | break; 109 | } 110 | } 111 | 112 | } // namespace mlx::core 113 | -------------------------------------------------------------------------------- /mlx/backend/metal/kernels/complex.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | using namespace metal; 8 | 9 | struct complex64_t; 10 | 11 | template 12 | static constexpr constant bool can_convert_to_complex64 = 13 | !is_same_v && is_convertible_v; 14 | 15 | template 16 | static constexpr constant bool can_convert_from_complex64 = 17 | !is_same_v && 18 | (is_convertible_v || is_convertible_v); 19 | 20 | struct complex64_t { 21 | float real; 22 | float imag; 23 | 24 | // Constructors 25 | constexpr complex64_t(float real, float imag) : real(real), imag(imag){}; 26 | 27 | // Conversions to complex64_t 28 | template < 29 | typename T, 30 | typename = typename enable_if>::type> 31 | constexpr complex64_t(T x) thread : real(x), imag(0) {} 32 | 33 | template < 34 | typename T, 35 | typename = typename enable_if>::type> 36 | constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} 37 | 38 | template < 39 | typename T, 40 | typename = typename enable_if>::type> 41 | constexpr complex64_t(T x) device : real(x), imag(0) {} 42 | 43 | template < 44 | typename T, 45 | typename = typename enable_if>::type> 46 | constexpr complex64_t(T x) constant : real(x), imag(0) {} 47 | 48 | // Converstions from complex64_t 49 | template < 50 | typename T, 51 | typename = typename enable_if>::type> 52 | constexpr operator T() const thread { 53 | return static_cast(real); 54 | } 55 | 56 | template < 57 | typename T, 58 | typename = typename enable_if>::type> 59 | constexpr operator T() const threadgroup { 60 | return static_cast(real); 61 | } 62 | 63 | template < 64 | typename T, 65 | typename = typename enable_if>::type> 66 | constexpr operator T() const device { 67 | return static_cast(real); 68 | } 69 | 70 | template < 71 | typename T, 72 | typename = typename enable_if>::type> 73 | constexpr operator T() const constant { 74 | return static_cast(real); 75 | } 76 | }; 77 | 78 | constexpr complex64_t operator-(complex64_t x) { 79 | return {-x.real, -x.imag}; 80 | } 81 | 82 | constexpr bool operator>=(complex64_t a, 
complex64_t b) { 83 | return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); 84 | } 85 | 86 | constexpr bool operator>(complex64_t a, complex64_t b) { 87 | return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); 88 | } 89 | 90 | constexpr bool operator<=(complex64_t a, complex64_t b) { 91 | return operator>=(b, a); 92 | } 93 | 94 | constexpr bool operator<(complex64_t a, complex64_t b) { 95 | return operator>(b, a); 96 | } 97 | 98 | constexpr bool operator==(complex64_t a, complex64_t b) { 99 | return a.real == b.real && a.imag == b.imag; 100 | } 101 | 102 | constexpr complex64_t operator+(complex64_t a, complex64_t b) { 103 | return {a.real + b.real, a.imag + b.imag}; 104 | } 105 | 106 | constexpr complex64_t operator-(complex64_t a, complex64_t b) { 107 | return {a.real - b.real, a.imag - b.imag}; 108 | } 109 | 110 | constexpr complex64_t operator*(complex64_t a, complex64_t b) { 111 | return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; 112 | } 113 | -------------------------------------------------------------------------------- /tests/scheduler_tests.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include "doctest/doctest.h" 4 | 5 | #include "mlx/mlx.h" 6 | #include "mlx/scheduler.h" 7 | 8 | using namespace mlx::core; 9 | 10 | TEST_CASE("test stream management") { 11 | auto s1 = default_stream(default_device()); 12 | CHECK_EQ(s1.device, default_device()); 13 | 14 | auto s2 = new_stream(default_device()); 15 | CHECK_EQ(s2.device, default_device()); 16 | CHECK_NE(s1, s2); 17 | 18 | // Check that default streams have the correct devices 19 | if (metal::is_available()) { 20 | auto s_gpu = default_stream(Device::gpu); 21 | CHECK_EQ(s_gpu.device, Device::gpu); 22 | } else { 23 | CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument); 24 | } 25 | auto s_cpu = default_stream(Device::cpu); 26 | CHECK_EQ(s_cpu.device, Device::cpu); 27 | 28 | s_cpu = new_stream(Device::cpu); 29 | CHECK_EQ(s_cpu.device, Device::cpu); 30 | 31 | if (metal::is_available()) { 32 | auto s_gpu = new_stream(Device::gpu); 33 | CHECK_EQ(s_gpu.device, Device::gpu); 34 | } else { 35 | CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument); 36 | } 37 | } 38 | 39 | TEST_CASE("test asynchronous launch") { 40 | auto s1 = default_stream(default_device()); 41 | auto s2 = new_stream(default_device()); 42 | 43 | // Make sure streams execute asynchronously 44 | int x = 1; 45 | auto p1 = std::make_shared>(); 46 | auto p2 = std::make_shared>(); 47 | auto f1 = p1->get_future().share(); 48 | auto f2 = p2->get_future().share(); 49 | auto fn1 = [&x, p = std::move(p1)]() { 50 | x++; 51 | p->set_value(); 52 | }; 53 | auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() { 54 | f.wait(); 55 | x *= 5; 56 | p->set_value(); 57 | }; 58 | 59 | // fn2 is launched first and is waiting on fn1 but since 60 | // they are on different streams there is no deadlock. 
61 | scheduler::enqueue(s2, std::move(fn2)); 62 | scheduler::enqueue(s1, std::move(fn1)); 63 | 64 | f2.wait(); 65 | 66 | CHECK_EQ(x, 10); 67 | } 68 | 69 | TEST_CASE("test stream placement") { 70 | auto s1 = default_stream(default_device()); 71 | auto s2 = new_stream(default_device()); 72 | 73 | { 74 | // Wait on stream 1 75 | auto p = std::make_shared>(); 76 | auto f = p->get_future().share(); 77 | scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); }); 78 | 79 | // Do some work on stream 2 80 | auto x = zeros({100}, float32, s2); 81 | auto y = ones({100}, float32, s2); 82 | auto z = add(x, y, s2); 83 | eval(z); 84 | p->set_value(); 85 | } 86 | 87 | { 88 | // Wait on stream 1 89 | auto p = std::make_shared>(); 90 | auto f = p->get_future().share(); 91 | scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); }); 92 | 93 | // Do some work on stream 2 94 | auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); }; 95 | auto x = zeros({100}, s2); 96 | 97 | // The whole vjp computation should happen 98 | // on the second stream otherwise this will hang. 99 | auto [out, dout] = vjp(fn, x, ones({100}, s2)); 100 | 101 | // The whole jvp computation should happen on the 102 | // second stream. 103 | std::tie(out, dout) = jvp(fn, x, ones({100}, s2)); 104 | eval(out, dout); 105 | 106 | p->set_value(); 107 | } 108 | } 109 | 110 | TEST_CASE("test scheduler races") { 111 | auto x = zeros({1}); 112 | auto y = zeros({100}); 113 | eval(x, y); 114 | auto a = exp(x); 115 | eval(a); 116 | a = exp(x); 117 | for (int i = 0; i < 10000; ++i) { 118 | y = exp(y); 119 | } 120 | eval(a, y); 121 | } 122 | -------------------------------------------------------------------------------- /docs/src/install.rst: -------------------------------------------------------------------------------- 1 | Build and Install 2 | ================= 3 | 4 | Install from PyPI 5 | ----------------- 6 | 7 | MLX is available on PyPI. 
All you have to do to use MLX with your own Apple 8 | silicon computer is: 9 | 10 | .. code-block:: shell 11 | 12 | pip install mlx 13 | 14 | .. note:: 15 | MLX is only available on devices running macOS >= 13.3. 16 | It is highly recommended to use macOS 14 (Sonoma). 17 | 18 | Build from source 19 | ----------------- 20 | 21 | Build Requirements 22 | ^^^^^^^^^^^^^^^^^^ 23 | 24 | - A C++ compiler with C++17 support (e.g. Clang >= 5.0) 25 | - `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make`` 26 | - Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above) 27 | 28 | 29 | Python API 30 | ^^^^^^^^^^ 31 | 32 | To build and install the MLX Python library from source, first clone MLX from 33 | `its GitHub repo <https://github.com/ml-explore/mlx>`_: 34 | 35 | .. code-block:: shell 36 | 37 | git clone git@github.com:ml-explore/mlx.git mlx && cd mlx 38 | 39 | Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_ 40 | installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows: 41 | 42 | .. code-block:: shell 43 | 44 | pip install "pybind11[global]" 45 | conda install pybind11 46 | brew install pybind11 47 | 48 | Then simply build and install it using pip: 49 | 50 | .. code-block:: shell 51 | 52 | env CMAKE_BUILD_PARALLEL_LEVEL="" pip install . 53 | 54 | For development, use an editable install: 55 | 56 | .. code-block:: shell 57 | 58 | env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . 59 | 60 | To make sure the install is working, run the tests with: 61 | 62 | .. code-block:: shell 63 | 64 | python -m unittest discover python/tests 65 | 66 | C++ API 67 | ^^^^^^^ 68 | 69 | Currently, MLX must be built and installed from source. 70 | 71 | Similarly to the Python library, to build and install the MLX C++ library, start 72 | by cloning MLX from `its GitHub repo 73 | <https://github.com/ml-explore/mlx>`_: 74 | 75 | .. code-block:: shell 76 | 77 | git clone git@github.com:ml-explore/mlx.git mlx && cd mlx 78 | 79 | Create a build directory and run CMake and make: 80 | 81 | .. code-block:: shell 82 | 83 | mkdir -p build && cd build 84 | cmake .. 
&& make -j 85 | 86 | Run tests with: 87 | 88 | .. code-block:: shell 89 | 90 | make test 91 | 92 | Install with: 93 | 94 | .. code-block:: shell 95 | 96 | make install 97 | 98 | Note that the built ``mlx.metallib`` file should either be in the same 99 | directory as the executable statically linked to ``libmlx.a``, or the 100 | preprocessor constant ``METAL_PATH`` should be defined at build time and it 101 | should point to the path to the built Metal library. 102 | 103 | .. list-table:: Build Options 104 | :widths: 25 8 105 | :header-rows: 1 106 | 107 | * - Option 108 | - Default 109 | * - MLX_BUILD_TESTS 110 | - ON 111 | * - MLX_BUILD_EXAMPLES 112 | - OFF 113 | * - MLX_BUILD_BENCHMARKS 114 | - OFF 115 | * - MLX_BUILD_METAL 116 | - ON 117 | * - MLX_BUILD_PYTHON_BINDINGS 118 | - OFF 119 | 120 | 121 | .. note:: 122 | 123 | If you have multiple Xcode installations and wish to use 124 | a specific one while building, you can do so by adding the 125 | following environment variable before building: 126 | 127 | .. code-block:: shell 128 | 129 | export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/" 130 | 131 | Further, you can use the following command to find out which 132 | macOS SDK will be used: 133 | 134 | .. code-block:: shell 135 | 136 | xcrun -sdk macosx --show-sdk-version 137 | -------------------------------------------------------------------------------- /mlx/backend/common/default_primitives.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/array.h" 6 | #include "mlx/backend/common/copy.h" 7 | #include "mlx/backend/common/utils.h" 8 | #include "mlx/primitives.h" 9 | 10 | #define DEFAULT(primitive) \ 11 | void primitive::eval_cpu(const std::vector& inputs, array& out) { \ 12 | primitive::eval(inputs, out); \ 13 | } 14 | 15 | namespace mlx::core { 16 | 17 | DEFAULT(Abs) 18 | DEFAULT(Add) 19 | DEFAULT(Arange) 20 | DEFAULT(ArcCos) 21 | DEFAULT(ArcCosh) 22 | DEFAULT(ArcSin) 23 | DEFAULT(ArcSinh) 24 | DEFAULT(ArcTan) 25 | DEFAULT(ArcTanh) 26 | DEFAULT(ArgPartition) 27 | DEFAULT(ArgReduce) 28 | DEFAULT(ArgSort) 29 | DEFAULT(AsType) 30 | DEFAULT(AsStrided) 31 | DEFAULT(Broadcast) 32 | DEFAULT(Concatenate) 33 | DEFAULT(Convolution) 34 | DEFAULT(Copy) 35 | DEFAULT(Cos) 36 | DEFAULT(Cosh) 37 | DEFAULT(Divide) 38 | DEFAULT(Equal) 39 | DEFAULT(Erf) 40 | DEFAULT(ErfInv) 41 | DEFAULT(Exp) 42 | DEFAULT(FFT) 43 | DEFAULT(Full) 44 | DEFAULT(Gather) 45 | DEFAULT(Greater) 46 | DEFAULT(GreaterEqual) 47 | DEFAULT(Less) 48 | DEFAULT(LessEqual) 49 | DEFAULT(Load) 50 | DEFAULT(Log) 51 | DEFAULT(Log1p) 52 | DEFAULT(LogicalNot) 53 | DEFAULT(LogAddExp) 54 | DEFAULT(Maximum) 55 | DEFAULT(Minimum) 56 | DEFAULT(Multiply) 57 | DEFAULT(Negative) 58 | DEFAULT(NotEqual) 59 | DEFAULT(Pad) 60 | DEFAULT(Partition) 61 | DEFAULT(Power) 62 | DEFAULT(RandomBits) 63 | DEFAULT(Reduce) 64 | DEFAULT(Reshape) 65 | DEFAULT(Scan) 66 | DEFAULT(Scatter) 67 | DEFAULT(Sigmoid) 68 | DEFAULT(Sign) 69 | DEFAULT(Sin) 70 | DEFAULT(Sinh) 71 | DEFAULT(Slice) 72 | DEFAULT(Softmax) 73 | DEFAULT(Sort) 74 | DEFAULT(Square) 75 | DEFAULT(Sqrt) 76 | DEFAULT(StopGradient) 77 | DEFAULT(Subtract) 78 | DEFAULT(Tan) 79 | DEFAULT(Tanh) 80 | DEFAULT(Transpose) 81 | 82 | void Matmul::eval_cpu(const std::vector& inputs, array& out) { 83 | if (out.dtype() != float32) { 84 | throw std::runtime_error( 85 | "[Matmul::eval_cpu] Currently only supports float32."); 86 | } 87 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 88 | 89 | 
auto& a_pre = inputs[0]; 90 | auto& b_pre = inputs[1]; 91 | 92 | auto check_transpose = [](const array& arr) { 93 | auto stx = arr.strides()[arr.ndim() - 2]; 94 | auto sty = arr.strides()[arr.ndim() - 1]; 95 | if (stx == arr.shape(-1) && sty == 1) { 96 | return std::make_tuple(false, stx, arr); 97 | } else if (stx == 1 && sty == arr.shape(-2)) { 98 | return std::make_tuple(true, sty, arr); 99 | } else { 100 | array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); 101 | copy(arr, arr_copy, CopyType::General); 102 | size_t stx = arr.shape(-1); 103 | return std::make_tuple(false, stx, arr_copy); 104 | } 105 | }; 106 | 107 | auto [a_transposed, lda, a] = check_transpose(a_pre); 108 | auto [b_transposed, ldb, b] = check_transpose(b_pre); 109 | int M = a.shape(-2); 110 | int N = b.shape(-1); 111 | int K = a.shape(-1); 112 | for (int i = 0; i < (a.size() / (M * K)); ++i) { 113 | cblas_sgemm( 114 | CblasRowMajor, 115 | a_transposed ? CblasTrans : CblasNoTrans, // transA 116 | b_transposed ? CblasTrans : CblasNoTrans, // transB 117 | M, 118 | N, 119 | K, 120 | 1.0f, // alpha 121 | a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), 122 | lda, 123 | b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), 124 | ldb, 125 | 0.0f, // beta 126 | out.data() + M * N * i, 127 | out.shape(-1) // ldc 128 | ); 129 | } 130 | } 131 | 132 | } // namespace mlx::core 133 | -------------------------------------------------------------------------------- /python/tests/test_device.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import unittest 4 | 5 | import mlx.core as mx 6 | 7 | import mlx_tests 8 | 9 | 10 | # Don't inherit from MLXTestCase to avoid call to setUp 11 | class TestDefaultDevice(unittest.TestCase): 12 | def test_mlx_default_device(self): 13 | device = mx.default_device() 14 | if mx.metal.is_available(): 15 | self.assertEqual(device, mx.Device(mx.gpu)) 16 | self.assertEqual(str(device), "Device(gpu, 0)") 17 | self.assertEqual(device, mx.gpu) 18 | self.assertEqual(mx.gpu, device) 19 | else: 20 | self.assertEqual(device.type, mx.Device(mx.cpu)) 21 | with self.assertRaises(ValueError): 22 | mx.set_default_device(mx.gpu) 23 | 24 | 25 | class TestDevice(mlx_tests.MLXTestCase): 26 | def test_device(self): 27 | device = mx.default_device() 28 | 29 | cpu = mx.Device(mx.cpu) 30 | mx.set_default_device(cpu) 31 | self.assertEqual(mx.default_device(), cpu) 32 | self.assertEqual(str(cpu), "Device(cpu, 0)") 33 | 34 | mx.set_default_device(mx.cpu) 35 | self.assertEqual(mx.default_device(), mx.cpu) 36 | self.assertEqual(cpu, mx.cpu) 37 | self.assertEqual(mx.cpu, cpu) 38 | 39 | # Restore device 40 | mx.set_default_device(device) 41 | 42 | def test_op_on_device(self): 43 | x = mx.array(1.0) 44 | y = mx.array(1.0) 45 | 46 | a = mx.add(x, y, stream=None) 47 | b = mx.add(x, y, stream=mx.default_device()) 48 | self.assertEqual(a.item(), b.item()) 49 | b = mx.add(x, y, stream=mx.cpu) 50 | self.assertEqual(a.item(), b.item()) 51 | 52 | if mx.metal.is_available(): 53 | b = mx.add(x, y, stream=mx.gpu) 54 | self.assertEqual(a.item(), b.item()) 55 | 56 | 57 | class TestStream(mlx_tests.MLXTestCase): 58 | def test_stream(self): 59 | s1 = mx.default_stream(mx.default_device()) 60 | self.assertEqual(s1.device, mx.default_device()) 61 | 62 | s2 = mx.new_stream(mx.default_device()) 63 | self.assertEqual(s2.device, mx.default_device()) 64 | self.assertNotEqual(s1, s2) 65 | 66 | if mx.metal.is_available(): 67 | s_gpu = mx.default_stream(mx.gpu) 68 | self.assertEqual(s_gpu.device, mx.gpu) 69 | else: 70 
| with self.assertRaises(ValueError): 71 | mx.default_stream(mx.gpu) 72 | 73 | s_cpu = mx.default_stream(mx.cpu) 74 | self.assertEqual(s_cpu.device, mx.cpu) 75 | 76 | s_cpu = mx.new_stream(mx.cpu) 77 | self.assertEqual(s_cpu.device, mx.cpu) 78 | 79 | if mx.metal.is_available(): 80 | s_gpu = mx.new_stream(mx.gpu) 81 | self.assertEqual(s_gpu.device, mx.gpu) 82 | else: 83 | with self.assertRaises(ValueError): 84 | mx.new_stream(mx.gpu) 85 | 86 | def test_op_on_stream(self): 87 | x = mx.array(1.0) 88 | y = mx.array(1.0) 89 | 90 | a = mx.add(x, y, stream=mx.default_stream(mx.default_device())) 91 | 92 | if mx.metal.is_available(): 93 | b = mx.add(x, y, stream=mx.default_stream(mx.gpu)) 94 | self.assertEqual(a.item(), b.item()) 95 | s_gpu = mx.new_stream(mx.gpu) 96 | b = mx.add(x, y, stream=s_gpu) 97 | self.assertEqual(a.item(), b.item()) 98 | 99 | b = mx.add(x, y, stream=mx.default_stream(mx.cpu)) 100 | self.assertEqual(a.item(), b.item()) 101 | s_cpu = mx.new_stream(mx.cpu) 102 | b = mx.add(x, y, stream=s_cpu) 103 | self.assertEqual(a.item(), b.item()) 104 | 105 | 106 | if __name__ == "__main__": 107 | unittest.main() 108 | -------------------------------------------------------------------------------- /python/mlx/nn/layers/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | import math 4 | 5 | import mlx.core as mx 6 | from mlx.nn.layers.base import Module 7 | 8 | 9 | def _make_activation_module(f): 10 | def decorator(klass): 11 | klass.__doc__ = f.__doc__ 12 | klass.__call__ = lambda self, x: f(x) 13 | return klass 14 | 15 | return decorator 16 | 17 | 18 | def relu(x): 19 | """Applies the Rectified Linear Unit. 20 | 21 | Simply ``mx.maximum(x, 0)``. 22 | """ 23 | return mx.maximum(x, 0) 24 | 25 | 26 | def silu(x): 27 | r"""Applies the Sigmoid Linear Unit. 28 | 29 | Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is 30 | the logistic sigmoid. 
31 | """ 32 | return x * mx.sigmoid(x) 33 | 34 | 35 | def gelu(x): 36 | """Applies the Gaussian Error Linear Units function. 37 | 38 | .. math:: 39 | \\textrm{GELU}(x) = x * \Phi(x) 40 | 41 | where :math:`\Phi(x)` is the Gaussian CDF. 42 | 43 | See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster 44 | approximations. 45 | """ 46 | return x * (1 + mx.erf(x / math.sqrt(2))) / 2 47 | 48 | 49 | def gelu_approx(x): 50 | r"""An approximation to Gaussian Error Linear Unit. 51 | 52 | See :func:`gelu` for the exact computation. 53 | 54 | This function approximates ``gelu`` with a maximum absolute error :math:`< 55 | 0.0003` in the range :math:`[-6, 6]` using the following 56 | 57 | .. math:: 58 | 59 | x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right) 60 | 61 | where :math:`\sigma(\cdot)` is the logistic sigmoid. 62 | """ 63 | return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square())) 64 | 65 | 66 | def gelu_fast_approx(x): 67 | r"""A fast approximation to Gaussian Error Linear Unit. 68 | 69 | See :func:`gelu` for the exact computation. 70 | 71 | This function approximates ``gelu`` with a maximum absolute error :math:`< 72 | 0.015` in the range :math:`[-6, 6]` using the following 73 | 74 | .. math:: 75 | 76 | x = x \sigma\left(1.773 x\right) 77 | 78 | where :math:`\sigma(\cdot)` is the logistic sigmoid. 79 | """ 80 | return x * mx.sigmoid(1.773 * x) 81 | 82 | 83 | @_make_activation_module(relu) 84 | class ReLU(Module): 85 | pass 86 | 87 | 88 | @_make_activation_module(silu) 89 | class SiLU(Module): 90 | pass 91 | 92 | 93 | class GELU(Module): 94 | r"""Applies the Gaussian Error Linear Units. 95 | 96 | .. math:: 97 | \textrm{GELU}(x) = x * \Phi(x) 98 | 99 | where :math:`\Phi(x)` is the Gaussian CDF. 100 | 101 | However, if ``approx`` is set to 'precise' or 'fast' it applies 102 | 103 | .. 
math:: 104 | \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\ 105 | \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right) 106 | 107 | respectively. 108 | 109 | See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the 110 | functional equivalents and information regarding error bounds. 111 | 112 | Args: 113 | approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. 114 | """ 115 | 116 | def __init__(self, approx="none"): 117 | super().__init__() 118 | 119 | if approx == "none": 120 | self._act = gelu 121 | elif approx == "precise": 122 | self._act = gelu_approx 123 | elif approx == "fast": 124 | self._act = gelu_fast_approx 125 | else: 126 | raise ValueError( 127 | f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given" 128 | ) 129 | 130 | def __call__(self, x): 131 | return self._act(x) 132 | -------------------------------------------------------------------------------- /mlx/backend/common/unary.h: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #pragma once 4 | 5 | #include "mlx/allocator.h" 6 | #include "mlx/array.h" 7 | #include "mlx/backend/common/utils.h" 8 | #include "mlx/utils.h" 9 | 10 | namespace mlx::core { 11 | 12 | namespace { 13 | 14 | struct AbsOp { 15 | template 16 | T operator()(T x) { 17 | return std::abs(x); 18 | } 19 | uint8_t operator()(uint8_t x) { 20 | return x; 21 | } 22 | uint16_t operator()(uint16_t x) { 23 | return x; 24 | } 25 | uint32_t operator()(uint32_t x) { 26 | return x; 27 | } 28 | uint64_t operator()(uint64_t x) { 29 | return x; 30 | } 31 | bool operator()(bool x) { 32 | return x; 33 | } 34 | }; 35 | 36 | struct SignOp { 37 | template 38 | T operator()(T x) { 39 | return (x > T(0)) - (x < T(0)); 40 | } 41 | 42 | uint8_t operator()(uint8_t x) { 43 | return x != 0; 44 | } 45 | uint16_t operator()(uint16_t x) { 46 | return x != 0; 47 | } 48 | uint32_t operator()(uint32_t x) { 49 | return x != 0; 50 | } 51 | uint64_t operator()(uint64_t x) { 52 | return x != 0; 53 | } 54 | }; 55 | 56 | template 57 | void unary_op(const array& a, array& out, Op op) { 58 | const T* a_ptr = a.data(); 59 | if (a.flags().contiguous) { 60 | out.set_data( 61 | allocator::malloc_or_wait(a.data_size() * out.itemsize()), 62 | a.data_size(), 63 | a.strides(), 64 | a.flags()); 65 | T* dst = out.data(); 66 | for (size_t i = 0; i < a.data_size(); ++i) { 67 | dst[i] = op(a_ptr[i]); 68 | } 69 | } else { 70 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 71 | T* dst = out.data(); 72 | for (size_t i = 0; i < out.size(); ++i) { 73 | // TODO this is super inefficient, need to fix. 
74 | int a_idx = elem_to_loc(i, a.shape(), a.strides()); 75 | dst[i] = op(a_ptr[a_idx]); 76 | } 77 | } 78 | } 79 | 80 | template 81 | void unary(const array& a, array& out, Op op) { 82 | switch (out.dtype()) { 83 | case bool_: 84 | unary_op(a, out, op); 85 | break; 86 | case uint8: 87 | unary_op(a, out, op); 88 | break; 89 | case uint16: 90 | unary_op(a, out, op); 91 | break; 92 | case uint32: 93 | unary_op(a, out, op); 94 | break; 95 | case uint64: 96 | unary_op(a, out, op); 97 | break; 98 | case int8: 99 | unary_op(a, out, op); 100 | break; 101 | case int16: 102 | unary_op(a, out, op); 103 | break; 104 | case int32: 105 | unary_op(a, out, op); 106 | break; 107 | case int64: 108 | unary_op(a, out, op); 109 | break; 110 | case float16: 111 | unary_op(a, out, op); 112 | break; 113 | case float32: 114 | unary_op(a, out, op); 115 | break; 116 | case bfloat16: 117 | unary_op(a, out, op); 118 | break; 119 | case complex64: 120 | unary_op(a, out, op); 121 | break; 122 | } 123 | } 124 | 125 | template 126 | void unary_fp(const array& a, array& out, Op op) { 127 | switch (out.dtype()) { 128 | case bfloat16: 129 | unary_op(a, out, op); 130 | break; 131 | case float16: 132 | unary_op(a, out, op); 133 | break; 134 | case float32: 135 | unary_op(a, out, op); 136 | break; 137 | case complex64: 138 | unary_op(a, out, op); 139 | break; 140 | default: 141 | std::ostringstream err; 142 | err << "[unary_fp] Does not support " << out.dtype(); 143 | throw std::runtime_error(err.str()); 144 | } 145 | } 146 | 147 | } // namespace 148 | 149 | } // namespace mlx::core 150 | -------------------------------------------------------------------------------- /python/mlx/extension.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import os 4 | import re 5 | import subprocess 6 | import sys 7 | from pathlib import Path 8 | 9 | from setuptools import Extension, setup, find_namespace_packages 10 | from setuptools.command.build_ext import build_ext 11 | 12 | import mlx 13 | 14 | _MLX_PATH = str(mlx.__path__[0]) 15 | 16 | 17 | # A CMakeExtension needs a sourcedir instead of a file list. 18 | class CMakeExtension(Extension): 19 | def __init__(self, name: str, sourcedir: str = "") -> None: 20 | super().__init__(name, sources=[]) 21 | self.sourcedir = os.fspath(Path(sourcedir).resolve()) 22 | 23 | 24 | class CMakeBuild(build_ext): 25 | def build_extension(self, ext: CMakeExtension) -> None: 26 | # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ 27 | ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call] 28 | extdir = ext_fullpath.parent.resolve() 29 | 30 | debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug 31 | cfg = "Debug" if debug else "Release" 32 | 33 | # CMake lets you override the generator - we need to check this. 34 | # Can be set with Conda-Build, for example. 35 | cmake_generator = os.environ.get("CMAKE_GENERATOR", "") 36 | 37 | # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON 38 | # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code 39 | # from Python. 40 | cmake_args = [ 41 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", 42 | f"-DCMAKE_BUILD_TYPE={cfg}", 43 | "-DBUILD_SHARED_LIBS=ON", 44 | ] 45 | build_args = [] 46 | # Adding CMake arguments set as environment variable 47 | # (needed e.g. 
to build for ARM OSx on conda-forge) 48 | if "CMAKE_ARGS" in os.environ: 49 | cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] 50 | 51 | if sys.platform.startswith("darwin"): 52 | # Cross-compile support for macOS - respect ARCHFLAGS if set 53 | archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) 54 | if archs: 55 | cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] 56 | 57 | # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level 58 | # across all generators. 59 | if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: 60 | # self.parallel is a Python 3 only way to set parallel jobs by hand 61 | # using -j in the build_ext call, not supported by pip or PyPA-build. 62 | if hasattr(self, "parallel") and self.parallel: 63 | # CMake 3.12+ only. 64 | build_args += [f"-j{self.parallel}"] 65 | 66 | build_temp = Path(self.build_temp) / ext.name 67 | if not build_temp.exists(): 68 | build_temp.mkdir(parents=True) 69 | 70 | # Make sure cmake can find MLX 71 | os.environ["MLX_DIR"] = _MLX_PATH 72 | 73 | subprocess.run( 74 | ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True 75 | ) 76 | subprocess.run( 77 | ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True 78 | ) 79 | 80 | def run(self): 81 | super().run() 82 | 83 | # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102 84 | if self.inplace: 85 | for ext in self.extensions: 86 | if isinstance(ext, CMakeExtension): 87 | # Resolve inplace package dir 88 | build_py = self.get_finalized_command("build_py") 89 | inplace_file, regular_file = self._get_inplace_equivalent( 90 | build_py, ext 91 | ) 92 | 93 | inplace_dir = str(Path(inplace_file).parent.resolve()) 94 | regular_dir = str(Path(regular_file).parent.resolve()) 95 | 96 | self.copy_tree(regular_dir, inplace_dir) 97 | -------------------------------------------------------------------------------- /mlx/graph_utils.cpp: 
-------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "mlx/graph_utils.h" 10 | #include "mlx/primitives.h" 11 | #include "mlx/utils.h" 12 | 13 | namespace mlx::core { 14 | 15 | using OptionalArrayRef = std::optional>; 16 | 17 | struct ArrayNames { 18 | std::unordered_map names; 19 | 20 | std::string get_name(const array& x) { 21 | auto it = names.find(x.id()); 22 | if (it == names.end()) { 23 | // Get the next name in the sequence 24 | // [A, B, ..., Z, AA, AB, ...] 25 | std::vector letters; 26 | auto var_num = names.size() + 1; 27 | while (var_num > 0) { 28 | letters.push_back('A' + (var_num - 1) % 26); 29 | var_num = (var_num - 1) / 26; 30 | } 31 | std::string name(letters.rbegin(), letters.rend()); 32 | names.insert({x.id(), name}); 33 | return name; 34 | } 35 | return it->second; 36 | } 37 | }; 38 | 39 | void depth_first_traversal( 40 | std::function callback, 41 | const std::vector& outputs) { 42 | std::function recurse; 43 | std::unordered_set cache; 44 | recurse = [&](OptionalArrayRef parent, const array& x, int input_index) { 45 | auto id = x.id(); 46 | if (cache.find(id) != cache.end()) { 47 | return; 48 | } 49 | cache.insert(id); 50 | for (int i = 0; i < x.inputs().size(); i++) { 51 | recurse(x, x.inputs()[i], i); 52 | } 53 | callback(parent, x, input_index); 54 | }; 55 | 56 | for (auto x : outputs) { 57 | recurse(std::nullopt, x, 0); 58 | } 59 | } 60 | 61 | void depth_first_traversal( 62 | std::function callback, 63 | const std::vector& outputs) { 64 | depth_first_traversal( 65 | [&callback](OptionalArrayRef p, const array& x, int input_index) { 66 | callback(x); 67 | }, 68 | outputs); 69 | } 70 | 71 | void print_graph(std::ostream& os, const std::vector& outputs) { 72 | std::vector tape; 73 | std::vector inputs; 74 | 75 | depth_first_traversal( 76 | [&](const array& x) { 77 | if (x.has_primitive()) { 78 | 
tape.push_back(x); 79 | } else { 80 | inputs.push_back(x); 81 | } 82 | }, 83 | outputs); 84 | 85 | ArrayNames namer; 86 | auto print_arr = [&namer, &os](const array& a) { 87 | os << namer.get_name(a); 88 | os << " [" << a.shape() << ", " << a.dtype() << "]"; 89 | }; 90 | 91 | auto print_arrs = [&](const std::vector& arrs) { 92 | for (auto& arr : arrs) { 93 | print_arr(arr); 94 | if (&arr != &arrs.back()) { 95 | os << ", "; 96 | } 97 | } 98 | }; 99 | 100 | os << "Inputs: "; 101 | print_arrs(inputs); 102 | os << "\nOutputs: "; 103 | print_arrs(outputs); 104 | os << "\n"; 105 | 106 | for (auto& arr : tape) { 107 | arr.primitive().print(os); 108 | os << " "; 109 | print_arrs(arr.inputs()); 110 | os << " -> "; 111 | print_arr(arr); 112 | os << "\n"; 113 | } 114 | } 115 | 116 | void export_to_dot(std::ostream& os, const std::vector& outputs) { 117 | os << "digraph {" << std::endl; 118 | 119 | ArrayNames namer; 120 | depth_first_traversal( 121 | [&namer, &os](auto parent, const array& x, int input_index) { 122 | os << "{ "; 123 | if (!x.has_primitive()) { 124 | os << "rank=source; "; 125 | } 126 | if (!parent) { 127 | os << "rank=sink; "; 128 | } 129 | os << namer.get_name(x); 130 | if (x.has_primitive()) { 131 | os << " [label =\""; 132 | x.primitive().print(os); 133 | os << "\"]"; 134 | } 135 | os << "; }" << std::endl; 136 | 137 | for (auto c : x.inputs()) { 138 | os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl; 139 | } 140 | }, 141 | outputs); 142 | 143 | os << "}"; 144 | } 145 | 146 | } // namespace mlx::core 147 | -------------------------------------------------------------------------------- /benchmarks/python/llama_mlx_bench.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 
2 | 3 | import math 4 | import time 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | import mlx.utils 9 | 10 | 11 | class LlamaAttention(nn.Module): 12 | def __init__(self, dims: int, num_heads: int): 13 | super().__init__() 14 | self.num_heads = num_heads 15 | self.rope = nn.RoPE(dims // num_heads, True) 16 | self.query_proj = nn.Linear(dims, dims, False) 17 | self.key_proj = nn.Linear(dims, dims, False) 18 | self.value_proj = nn.Linear(dims, dims, False) 19 | self.out_proj = nn.Linear(dims, dims, False) 20 | 21 | def __call__(self, queries, keys, values, mask=None, cache=None): 22 | queries = self.query_proj(queries) 23 | keys = self.key_proj(keys) 24 | values = self.value_proj(values) 25 | 26 | num_heads = self.num_heads 27 | B, L, D = queries.shape 28 | queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3)) 29 | keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3)) 30 | values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3)) 31 | 32 | if cache is not None: 33 | key_cache, value_cache = cache 34 | queries = self.rope(queries, offset=key_cache.shape[2]) 35 | keys = self.rope(keys, offset=key_cache.shape[2]) 36 | keys = mx.concatenate([key_cache, keys], axis=2) 37 | values = mx.concatenate([value_cache, values], axis=2) 38 | else: 39 | queries = self.rope(queries) 40 | keys = self.rope(keys) 41 | 42 | # Dimensions are [batch x num heads x sequence x hidden dim] 43 | scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype) 44 | scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2)) 45 | if mask is not None: 46 | scores = scores + mask 47 | scores = mx.softmax(scores, axis=-1) 48 | values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1)) 49 | 50 | return self.out_proj(values_hat), (keys, values) 51 | 52 | 53 | class LlamaEncoderLayer(nn.Module): 54 | def __init__(self, dims: int, mlp_dims: int, num_heads: int): 55 | super().__init__() 56 | 57 
| self.attention = LlamaAttention(dims, num_heads) 58 | 59 | self.norm1 = nn.RMSNorm(dims) 60 | self.norm2 = nn.RMSNorm(dims) 61 | 62 | self.linear1 = nn.Linear(dims, mlp_dims, False) 63 | self.linear2 = nn.Linear(dims, mlp_dims, False) 64 | self.linear3 = nn.Linear(mlp_dims, dims, False) 65 | 66 | def __call__(self, x, mask=None, cache=None): 67 | y = self.norm1(x) 68 | y, cache = self.attention(y, y, y, mask, cache) 69 | x = x + y 70 | 71 | y = self.norm2(x) 72 | a = self.linear1(y) 73 | b = self.linear2(y) 74 | y = a * mx.sigmoid(a) * b 75 | y = self.linear3(y) 76 | x = x + y 77 | 78 | return x, cache 79 | 80 | 81 | def measure(model, x, cache): 82 | for i in range(5): 83 | y, c = model(x, mask=None, cache=cache) 84 | mx.eval(y, c) 85 | 86 | start = time.time() 87 | rs = [] 88 | for i in range(5): 89 | y, c = model(x, mask=None, cache=cache) 90 | rs.append((y, c)) 91 | mx.eval(rs) 92 | end = time.time() 93 | 94 | return (end - start) * 1000 / 5 95 | 96 | 97 | if __name__ == "__main__": 98 | H = 32 99 | D = 4096 100 | F = 43 * 256 101 | C = 1000 102 | mx.set_default_device(mx.gpu) 103 | dtype = mx.float16 104 | 105 | layer = LlamaEncoderLayer(D, F, H) 106 | layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters())) 107 | k1, k2, k3 = mx.random.split(mx.random.key(0), 3) 108 | x = mx.random.normal([1, 1, D], dtype=dtype) 109 | cache = [ 110 | mx.random.normal([1, H, C, D // H], dtype=dtype), 111 | mx.random.normal([1, H, C, D // H], dtype=dtype), 112 | ] 113 | mx.eval(x, cache) 114 | 115 | T = measure(layer, x, cache) 116 | 117 | print("Time per layer per token:", T, "ms") 118 | print("Lower bound total time per token:", T * 32, "ms") 119 | -------------------------------------------------------------------------------- /mlx/backend/metal/copy.cpp: -------------------------------------------------------------------------------- 1 | // Copyright © 2023 Apple Inc. 
2 | 3 | #include 4 | 5 | #include "mlx/backend/metal/copy.h" 6 | #include "mlx/backend/metal/device.h" 7 | #include "mlx/backend/metal/kernels/defines.h" 8 | #include "mlx/backend/metal/utils.h" 9 | #include "mlx/primitives.h" 10 | 11 | namespace mlx::core { 12 | 13 | void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { 14 | if (ctype == CopyType::Vector) { 15 | out.set_data( 16 | allocator::malloc_or_wait(in.data_size() * out.itemsize()), 17 | in.data_size(), 18 | in.strides(), 19 | in.flags()); 20 | } else { 21 | out.set_data(allocator::malloc_or_wait(out.nbytes())); 22 | } 23 | if (ctype == CopyType::GeneralGeneral) { 24 | ctype = CopyType::General; 25 | } 26 | copy_gpu_inplace(in, out, ctype, s); 27 | } 28 | 29 | void copy_gpu(const array& in, array& out, CopyType ctype) { 30 | copy_gpu(in, out, ctype, out.primitive().stream()); 31 | } 32 | 33 | void copy_gpu_inplace( 34 | const array& in, 35 | array& out, 36 | CopyType ctype, 37 | const Stream& s) { 38 | // Try to collapse contiguous dims 39 | auto [shape, strides] = collapse_contiguous_dims(in, out); 40 | auto& strides_in = strides[0]; 41 | auto& strides_out = strides[1]; 42 | 43 | auto& d = metal::device(s.device); 44 | std::ostringstream kname; 45 | switch (ctype) { 46 | case CopyType::Scalar: 47 | kname << "scopy"; 48 | break; 49 | case CopyType::Vector: 50 | kname << "vcopy"; 51 | break; 52 | case CopyType::General: 53 | kname << "gcopy"; 54 | break; 55 | case CopyType::GeneralGeneral: 56 | kname << "ggcopy"; 57 | break; 58 | } 59 | kname << type_to_name(in) << type_to_name(out); 60 | if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) && 61 | shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { 62 | kname << "_" << shape.size(); 63 | } 64 | auto kernel = d.get_kernel(kname.str()); 65 | auto compute_encoder = d.get_command_encoder(s.index); 66 | compute_encoder->setComputePipelineState(kernel); 67 | set_array_buffer(compute_encoder, in, 0); 68 | 
set_array_buffer(compute_encoder, out, 1); 69 | 70 | if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { 71 | size_t ndim = shape.size(); 72 | if (ndim > 3) { 73 | compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); 74 | compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3); 75 | if (ctype == CopyType::GeneralGeneral) { 76 | compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4); 77 | } 78 | } else { 79 | // The shape is implicit in the grid for <= 3D 80 | compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2); 81 | if (ctype == CopyType::GeneralGeneral) { 82 | compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3); 83 | } 84 | } 85 | 86 | if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { 87 | compute_encoder->setBytes( 88 | &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4); 89 | } 90 | 91 | int dim0 = ndim > 0 ? shape[ndim - 1] : 1; 92 | int dim1 = ndim > 1 ? shape[ndim - 2] : 1; 93 | int rest = in.size() / (dim0 * dim1); 94 | 95 | // NB assuming thread_group_size is a power of 2 larger than 32 x 32 96 | NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); 97 | if (thread_group_size != 1024) { 98 | throw std::runtime_error("[Metal::copy] Must use 1024 sized block"); 99 | } 100 | auto group_dims = get_block_dims(dim0, dim1, rest); 101 | MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); 102 | compute_encoder->dispatchThreads(grid_dims, group_dims); 103 | } else { 104 | size_t nthreads = out.data_size(); 105 | MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); 106 | NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); 107 | if (thread_group_size > nthreads) { 108 | thread_group_size = nthreads; 109 | } 110 | MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); 111 | compute_encoder->dispatchThreads(grid_dims, group_dims); 112 | } 113 | } 114 | 115 | } // namespace mlx::core 116 | 
--------------------------------------------------------------------------------