├── docs
│   ├── .nojekyll
│   ├── .gitignore
│   ├── .clang-format
│   ├── src
│   │   ├── cpp
│   │   │   └── ops.rst
│   │   ├── _static
│   │   │   └── mlx_logo.png
│   │   ├── python
│   │   │   ├── nn
│   │   │   │   └── module.rst
│   │   │   ├── transforms.rst
│   │   │   ├── fft.rst
│   │   │   ├── devices_and_streams.rst
│   │   │   ├── tree_utils.rst
│   │   │   ├── array.rst
│   │   │   ├── random.rst
│   │   │   ├── data_types.rst
│   │   │   ├── ops.rst
│   │   │   └── optimizers.rst
│   │   ├── _templates
│   │   │   ├── optimizers-template.rst
│   │   │   └── nn-module-template.rst
│   │   ├── using_streams.rst
│   │   ├── conf.py
│   │   ├── quick_start.rst
│   │   ├── index.rst
│   │   ├── examples
│   │   │   └── linear_regression.rst
│   │   └── install.rst
│   ├── index.html
│   ├── Makefile
│   └── README.md
├── mlx
│   ├── 3rdparty
│   │   └── .clang-format
│   ├── backend
│   │   ├── no_metal
│   │   │   ├── CMakeLists.txt
│   │   │   ├── allocator.cpp
│   │   │   ├── metal.cpp
│   │   │   └── primitives.cpp
│   │   ├── accelerate
│   │   │   ├── CMakeLists.txt
│   │   │   ├── conv.cpp
│   │   │   └── utils.h
│   │   ├── common
│   │   │   ├── erf.h
│   │   │   ├── threefry.h
│   │   │   ├── CMakeLists.txt
│   │   │   ├── utils.h
│   │   │   ├── copy.h
│   │   │   ├── threefry.cpp
│   │   │   ├── load.cpp
│   │   │   ├── erf.cpp
│   │   │   ├── arange.h
│   │   │   ├── fft.cpp
│   │   │   ├── softmax.cpp
│   │   │   ├── arg_reduce.cpp
│   │   │   ├── default_primitives.cpp
│   │   │   └── unary.h
│   │   └── metal
│   │       ├── fft.cpp
│   │       ├── copy.h
│   │       ├── kernels
│   │       │   ├── defines.h
│   │       │   ├── conv_params.h
│   │       │   ├── arange.metal
│   │       │   ├── CMakeLists.txt
│   │       │   ├── erf.h
│   │       │   ├── random.metal
│   │       │   └── complex.h
│   │       ├── metal.h
│   │       ├── matmul.h
│   │       ├── CMakeLists.txt
│   │       ├── allocator.h
│   │       ├── device.h
│   │       ├── metal.cpp
│   │       ├── softmax.cpp
│   │       └── copy.cpp
│   ├── mlx.h
│   ├── transforms_impl.h
│   ├── device.h
│   ├── graph_utils.h
│   ├── stream.h
│   ├── device.cpp
│   ├── CMakeLists.txt
│   ├── allocator.cpp
│   ├── utils.h
│   ├── scheduler.cpp
│   ├── types
│   │   ├── half_types.h
│   │   └── complex.h
│   ├── allocator.h
│   ├── dtype.h
│   ├── load.h
│   └── graph_utils.cpp
├── MANIFEST.in
├── pyproject.toml
├── examples
│   ├── extensions
│   │   ├── mlx_sample_extensions
│   │   │   └── __init__.py
│   │   ├── setup.py
│   │   ├── bindings.cpp
│   │   ├── CMakeLists.txt
│   │   └── axpby
│   │       ├── axpby.metal
│   │       └── axpby.h
│   ├── cpp
│   │   ├── CMakeLists.txt
│   │   ├── timer.h
│   │   ├── logistic_regression.cpp
│   │   ├── linear_regression.cpp
│   │   └── tutorial.cpp
│   └── python
│       ├── logistic_regression.py
│       └── linear_regression.py
├── python
│   ├── mlx
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── losses.py
│   │   │   ├── layers
│   │   │   │   ├── containers.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── embedding.py
│   │   │   │   ├── dropout.py
│   │   │   │   ├── linear.py
│   │   │   │   └── activations.py
│   │   │   └── utils.py
│   │   ├── _reprlib_fix.py
│   │   └── extension.py
│   ├── src
│   │   ├── metal.cpp
│   │   ├── indexing.h
│   │   ├── load.h
│   │   ├── mlx.cpp
│   │   ├── stream.cpp
│   │   ├── CMakeLists.txt
│   │   ├── device.cpp
│   │   └── utils.h
│   ├── tests
│   │   ├── mlx_tests.py
│   │   ├── test_tree.py
│   │   ├── test_optimizers.py
│   │   ├── test_eval.py
│   │   ├── test_fft.py
│   │   └── test_device.py
│   └── README.md
├── .pre-commit-config.yaml
├── benchmarks
│   ├── cpp
│   │   ├── CMakeLists.txt
│   │   ├── compare_devices.cpp
│   │   ├── autograd.cpp
│   │   └── time_utils.h
│   ├── numpy
│   │   ├── time_utils.py
│   │   └── single_ops.py
│   └── python
│       ├── time_utils.py
│       ├── comparative
│       │   └── README.md
│       ├── batch_matmul_bench.py
│       ├── single_ops.py
│       └── llama_mlx_bench.py
├── tests
│   ├── tests.cpp
│   ├── utils_tests.cpp
│   ├── graph_optimize_tests.cpp
│   ├── CMakeLists.txt
│   ├── device_tests.cpp
│   ├── allocator_tests.cpp
│   ├── load_tests.cpp
│   ├── eval_tests.cpp
│   ├── blas_tests.cpp
│   └── scheduler_tests.cpp
├── .gitignore
├── LICENSE
├── CONTRIBUTING.md
├── mlx.pc.in
├── cmake
│   └── extension.cmake
├── .clang-format
└── README.md
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | src/python/_autosummary*/
2 |
--------------------------------------------------------------------------------
/docs/.clang-format:
--------------------------------------------------------------------------------
1 | DisableFormat: true
2 | SortIncludes: Never
3 |
--------------------------------------------------------------------------------
/mlx/3rdparty/.clang-format:
--------------------------------------------------------------------------------
1 | DisableFormat: true
2 | SortIncludes: Never
3 |
--------------------------------------------------------------------------------
/docs/src/cpp/ops.rst:
--------------------------------------------------------------------------------
1 | .. _cpp_ops:
2 |
3 | Operations
4 | ==========
5 |
6 |
7 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include CMakeLists.txt
2 | recursive-include mlx/ *
3 | include python/src/*
4 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/src/_static/mlx_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/mlx/main/docs/src/_static/mlx_logo.png
--------------------------------------------------------------------------------
/docs/src/python/nn/module.rst:
--------------------------------------------------------------------------------
1 | mlx.nn.Module
2 | =============
3 |
4 | .. currentmodule:: mlx.nn
5 |
6 | .. autoclass:: Module
7 | :members:
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/examples/extensions/mlx_sample_extensions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import mlx.core as mx
4 | from .mlx_sample_extensions import *
5 |
--------------------------------------------------------------------------------
/python/mlx/nn/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from mlx.nn.layers import *
4 | from mlx.nn import losses
5 | from mlx.nn.utils import value_and_grad
6 |
--------------------------------------------------------------------------------
/mlx/backend/no_metal/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE
4 | ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
7 | )
8 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/mirrors-clang-format
3 | rev: v14.0.6
4 | hooks:
5 | - id: clang-format
6 | - repo: https://github.com/psf/black
7 | rev: 22.10.0
8 | hooks:
9 | - id: black
10 |
--------------------------------------------------------------------------------
/docs/src/python/transforms.rst:
--------------------------------------------------------------------------------
1 | .. _transforms:
2 |
3 | Transforms
4 | ==========
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | eval
12 | grad
13 | value_and_grad
14 | jvp
15 | vjp
16 | vmap
17 |
--------------------------------------------------------------------------------
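The transforms listed above compose like ordinary Python functions. A minimal sketch of `grad` and `value_and_grad` on a scalar-valued function, assuming only `mlx.core` from this repo (the example function and values are arbitrary):

```
import mlx.core as mx

def f(x):
    return mx.sum(mx.square(x))

dfdx = mx.grad(f)                   # gradient of f with respect to its input
loss_and_grad = mx.value_and_grad(f)

x = mx.array([1.0, 2.0, 3.0])
print(dfdx(x))                      # gradient is 2 * x, i.e. [2, 4, 6]
loss, grads = loss_and_grad(x)      # loss is 14.0, grads equals dfdx(x)
```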
/python/mlx/nn/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import mlx.core as mx
4 |
5 |
6 | def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
7 | score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
8 | return mx.logsumexp(logits, axis=axis) - score
9 |
--------------------------------------------------------------------------------
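The loss above is softmax cross entropy written in a numerically stable form: `logsumexp(logits) - logits[target]`. A short usage sketch with made-up logits and integer class targets:

```
import mlx.core as mx
from mlx.nn.losses import cross_entropy

logits = mx.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # (batch, classes)
targets = mx.array([0, 2])                              # class index per example

loss = cross_entropy(logits, targets)  # per-example losses, shape (2,)
```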
/docs/src/python/fft.rst:
--------------------------------------------------------------------------------
1 | .. _fft:
2 |
3 | FFT
4 | ===
5 |
6 | .. currentmodule:: mlx.core.fft
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | fft
12 | ifft
13 | fft2
14 | ifft2
15 | fftn
16 | ifftn
17 | rfft
18 | irfft
19 | rfft2
20 | irfft2
21 | rfftn
22 | irfftn
23 |
--------------------------------------------------------------------------------
/mlx/backend/accelerate/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE
4 | ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
8 | ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
9 | )
10 |
--------------------------------------------------------------------------------
/mlx/mlx.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 | #include "mlx/backend/metal/metal.h"
7 | #include "mlx/device.h"
8 | #include "mlx/fft.h"
9 | #include "mlx/ops.h"
10 | #include "mlx/random.h"
11 | #include "mlx/stream.h"
12 | #include "mlx/transforms.h"
13 | #include "mlx/utils.h"
14 |
--------------------------------------------------------------------------------
/mlx/backend/common/erf.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | namespace mlx::core {
4 |
5 | /* Approximation to the inverse error function.
6 | * Based on code from:
7 | * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
8 | */
9 | float erfinv(float a);
10 |
11 | } // namespace mlx::core
12 |
--------------------------------------------------------------------------------
/mlx/backend/metal/fft.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/primitives.h"
4 |
5 | namespace mlx::core {
6 |
7 | void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
8 | auto& in = inputs[0];
9 | throw std::runtime_error("[FFT] NYI for Metal backend.");
10 | }
11 |
12 | } // namespace mlx::core
13 |
--------------------------------------------------------------------------------
/docs/src/python/devices_and_streams.rst:
--------------------------------------------------------------------------------
1 | .. _devices_and_streams:
2 |
3 | Devices and Streams
4 | ===================
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | Device
12 | default_device
13 | set_default_device
14 | Stream
15 | default_stream
16 | new_stream
17 | set_default_stream
18 |
--------------------------------------------------------------------------------
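A short usage sketch of the device and stream API summarized above, assuming the `mlx.core` module from this repo:

```
import mlx.core as mx

print(mx.default_device())              # gpu if Metal is available, else cpu

s = mx.new_stream(mx.default_device())  # create a new stream on that device
mx.set_default_stream(s)                # make it the default for its device

mx.set_default_device(mx.cpu)           # switch the default device to the CPU
```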
/mlx/backend/no_metal/allocator.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/allocator.h"
4 |
5 | namespace mlx::core::allocator {
6 |
7 | Allocator& allocator() {
8 | static CommonAllocator allocator_;
9 | return allocator_;
10 | }
11 |
12 | void* Buffer::raw_ptr() {
13 | return ptr_;
14 | }
15 |
16 | } // namespace mlx::core::allocator
17 |
--------------------------------------------------------------------------------
/python/src/metal.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <pybind11/pybind11.h>
4 |
5 | #include "mlx/backend/metal/metal.h"
6 |
7 | namespace py = pybind11;
8 |
9 | using namespace mlx::core;
10 |
11 | void init_metal(py::module_& m) {
12 | py::module_ metal = m.def_submodule("metal", "mlx.metal");
13 | metal.def("is_available", &metal::is_available);
14 | }
15 |
--------------------------------------------------------------------------------
/examples/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | function(build_example SRCFILE)
2 | get_filename_component(src_name ${SRCFILE} NAME_WE)
3 | set(target "${src_name}")
4 | add_executable(${target} ${SRCFILE})
5 | target_link_libraries(${target} PRIVATE mlx)
6 | endfunction(build_example)
7 |
8 | build_example(tutorial.cpp)
9 | build_example(linear_regression.cpp)
10 | build_example(logistic_regression.cpp)
11 |
--------------------------------------------------------------------------------
/python/src/indexing.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <pybind11/pybind11.h>
6 |
7 | #include "mlx/array.h"
8 | #include "python/src/utils.h"
9 |
10 | namespace py = pybind11;
11 | using namespace mlx::core;
12 |
13 | array mlx_get_item(const array& src, const py::object& obj);
14 | void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
15 |
--------------------------------------------------------------------------------
/examples/cpp/timer.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <chrono>
6 |
7 | namespace timer {
8 |
9 | using namespace std::chrono;
10 |
11 | template <typename R, typename P>
12 | inline double seconds(duration<R, P> x) {
13 | return duration_cast<nanoseconds>(x).count() / 1e9;
14 | }
15 |
16 | inline auto time() {
17 | return high_resolution_clock::now();
18 | }
19 |
20 | } // namespace timer
21 |
--------------------------------------------------------------------------------
/benchmarks/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | function(build_benchmark SRCFILE)
2 | get_filename_component(src_name ${SRCFILE} NAME_WE)
3 | set(target "${src_name}")
4 | add_executable(${target} ${SRCFILE})
5 | target_link_libraries(${target} PRIVATE mlx)
6 | endfunction(build_benchmark)
7 |
8 | build_benchmark(single_ops.cpp)
9 | build_benchmark(irregular_strides.cpp)
10 | build_benchmark(compare_devices.cpp)
11 | build_benchmark(autograd.cpp)
12 |
--------------------------------------------------------------------------------
/benchmarks/numpy/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 |
5 |
6 | def time_fn(fn, *args):
7 | print(f"Timing {fn.__name__} ...", end=" ")
8 |
9 | # warmup
10 | for _ in range(5):
11 | fn(*args)
12 |
13 | num_iters = 100
14 | tic = time.perf_counter()
15 | for _ in range(num_iters):
16 | x = fn(*args)
17 | toc = time.perf_counter()
18 |
19 | msec = 1e3 * (toc - tic) / num_iters
20 | print(f"{msec:.5f} msec")
21 |
--------------------------------------------------------------------------------
/mlx/backend/metal/copy.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/backend/common/copy.h"
6 | #include "mlx/stream.h"
7 |
8 | namespace mlx::core {
9 |
10 | void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
11 | void copy_gpu(const array& src, array& out, CopyType ctype);
12 | void copy_gpu_inplace(
13 | const array& src,
14 | array& out,
15 | CopyType ctype,
16 | const Stream& s);
17 |
18 | } // namespace mlx::core
19 |
--------------------------------------------------------------------------------
/docs/src/_templates/optimizers-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 |
7 | {% block methods %}
8 |
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | {% for item in methods %}
14 | {%- if item not in inherited_members %}
15 | ~{{ name }}.{{ item }}
16 | {%- endif %}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}
20 |
21 |
--------------------------------------------------------------------------------
/python/tests/mlx_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import os
4 | import unittest
5 |
6 | import mlx.core as mx
7 |
8 |
9 | class MLXTestCase(unittest.TestCase):
10 | def setUp(self):
11 | self.default = mx.default_device()
12 | device = os.getenv("DEVICE", None)
13 | if device is not None:
14 | device = getattr(mx, device)
15 | mx.set_default_device(device)
16 |
17 | def tearDown(self):
18 | mx.set_default_device(self.default)
19 |
--------------------------------------------------------------------------------
/mlx/backend/accelerate/conv.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <cassert>
4 |
5 | #include <Accelerate/Accelerate.h>
6 | #include <simd/vector.h>
7 |
8 | #include "mlx/backend/common/copy.h"
9 | #include "mlx/primitives.h"
10 | #include "mlx/utils.h"
11 |
12 | namespace mlx::core {
13 |
14 | void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
15 | eval(inputs, out);
16 |
17 | // TODO: Add accelerate based optimizations for CPU conv
18 | }
19 |
20 | } // namespace mlx::core
21 |
--------------------------------------------------------------------------------
/docs/src/_templates/nn-module-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 |
7 | {#{% block methods %}
8 |
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | {% for item in methods %}
14 | {%- if item not in inherited_members and item != '__init__' %}
15 | ~{{ name }}.{{ item }}
16 | {%- endif %}
17 | {%- endfor %}
18 | {% endif %}
19 | {% endblock %}#}
20 |
--------------------------------------------------------------------------------
/mlx/backend/no_metal/metal.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <stdexcept>
4 |
5 | #include "mlx/backend/metal/metal.h"
6 |
7 | namespace mlx::core::metal {
8 |
9 | void new_stream(Stream) {}
10 |
11 | std::function<void()> make_task(
12 | array& arr,
13 | std::vector<std::shared_future<void>> deps,
14 | std::shared_ptr<std::promise<void>> p,
15 | bool retain_graph) {
16 | throw std::runtime_error(
17 | "[metal::make_task] Cannot make GPU task without metal backend");
18 | }
19 |
20 | } // namespace mlx::core::metal
21 |
--------------------------------------------------------------------------------
/benchmarks/python/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import time
4 |
5 | import mlx.core as mx
6 |
7 |
8 | def time_fn(fn, *args, **kwargs):
9 | print(f"Timing {fn.__name__} ...", end=" ")
10 |
11 | # warmup
12 | for _ in range(5):
13 | mx.eval(fn(*args, **kwargs))
14 |
15 | num_iters = 100
16 | tic = time.perf_counter()
17 | for _ in range(num_iters):
18 | x = mx.eval(fn(*args, **kwargs))
19 | toc = time.perf_counter()
20 |
21 | msec = 1e3 * (toc - tic) / num_iters
22 | print(f"{msec:.5f} msec")
23 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/defines.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #ifdef __METAL__
6 | #define MTL_CONST constant
7 | #else
8 | #define MTL_CONST
9 | #endif
10 |
11 | static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
12 | static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
13 | static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
14 | static MTL_CONST constexpr int REDUCE_N_READS = 16;
15 | static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
16 | static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
17 |
--------------------------------------------------------------------------------
/mlx/transforms_impl.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | namespace mlx::core::detail {
4 |
5 | std::pair<std::vector<array>, std::vector<array>> vmap_trace(
6 | const std::function<std::vector<array>(const std::vector<array>&)>& fun,
7 | const std::vector<array>& inputs,
8 | const std::vector<int>& in_axes);
9 |
10 | std::vector<array> vmap_replace(
11 | const std::vector<array>& inputs,
12 | const std::vector<array>& s_inputs,
13 | const std::vector<array>& s_outputs,
14 | const std::vector<int>& in_axes,
15 | const std::vector<int>& out_axes);
16 |
17 | } // namespace mlx::core::detail
18 |
--------------------------------------------------------------------------------
/python/mlx/_reprlib_fix.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import array
4 | import reprlib
5 |
6 |
7 | class FixedRepr(reprlib.Repr):
8 | """Only route python array instances to repr_array."""
9 |
10 | def repr_array(self, x, maxlevel):
11 | if isinstance(x, array.array):
12 | return super().repr_array(x, maxlevel)
13 | else:
14 | return self.repr_instance(x, maxlevel)
15 |
16 |
17 | # We need to monkey-patch reprlib so that we can use the debugger without
18 | # renaming the array to something else
19 | fixed_repr = FixedRepr()
20 | reprlib.repr = fixed_repr.repr
21 |
--------------------------------------------------------------------------------
/tests/tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #define DOCTEST_CONFIG_IMPLEMENT
4 | #include "doctest/doctest.h"
5 |
6 | #include <cstdlib>
7 |
8 | #include "mlx/mlx.h"
9 |
10 | using namespace mlx::core;
11 |
12 | int main(int argc, char** argv) {
13 | doctest::Context context;
14 |
15 | const char* device = std::getenv("DEVICE");
16 | if (device != nullptr && std::string(device) == "cpu") {
17 | set_default_device(Device::cpu);
18 | } else if (metal::is_available()) {
19 | set_default_device(Device::gpu);
20 | }
21 |
22 | context.applyCommandLine(argc, argv);
23 | return context.run();
24 | }
25 |
--------------------------------------------------------------------------------
/python/src/load.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <pybind11/pybind11.h>
6 | #include <unordered_map>
7 | #include <variant>
8 | #include "mlx/ops.h"
9 |
10 | namespace py = pybind11;
11 | using namespace mlx::core;
12 |
13 | using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
14 |
15 | DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
16 | void mlx_save_helper(py::object file, array a, bool retain_graph = true);
17 | void mlx_savez_helper(
18 | py::object file,
19 | py::args args,
20 | const py::kwargs& kwargs,
21 | bool compressed = false);
--------------------------------------------------------------------------------
/docs/src/python/tree_utils.rst:
--------------------------------------------------------------------------------
1 | .. _utils:
2 |
3 | Tree Utils
4 | ==========
5 |
6 | In MLX we consider a python tree to be an arbitrarily nested collection of
7 | dictionaries, lists and tuples without cycles. Functions in this module that
8 | return python trees will use the default python ``dict``, ``list`` and
9 | ``tuple``, but they can usually process objects that inherit from any of these.
10 |
11 | .. note::
12 | Dictionaries should have keys that are valid python identifiers.
13 |
14 | .. currentmodule:: mlx.utils
15 |
16 | .. autosummary::
17 | :toctree: _autosummary
18 |
19 | tree_flatten
20 | tree_unflatten
21 | tree_map
22 |
--------------------------------------------------------------------------------
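For illustration, a small sketch of the tree utilities documented above, mirroring `python/tests/test_tree.py` in this repo (the nested dictionary is arbitrary):

```
import mlx.utils

tree = {"a": 0, "b": [1, 2], "c": {"d": 3}}

flat = mlx.utils.tree_flatten(tree)      # a list of (key, value) pairs
nested = mlx.utils.tree_unflatten(flat)  # back to the nested structure
plus_one = mlx.utils.tree_map(lambda x: x + 1, tree)
```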
/mlx/backend/metal/metal.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <future>
6 | #include <memory>
7 | #include <vector>
8 |
9 | #include "mlx/array.h"
10 | #include "mlx/stream.h"
11 |
12 | namespace mlx::core::metal {
13 |
14 | constexpr bool is_available() {
15 | #ifdef _METAL_
16 | return true;
17 | #else
18 | return false;
19 | #endif
20 | }
21 |
22 | void new_stream(Stream stream);
23 |
24 | std::function<void()> make_task(
25 | array& arr,
26 | std::vector<std::shared_future<void>> deps,
27 | std::shared_ptr<std::promise<void>> p,
28 | bool retain_graph);
29 |
30 | } // namespace mlx::core::metal
31 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line.
4 | SPHINXOPTS =
5 | SPHINXBUILD = sphinx-build
6 | SOURCEDIR = src
7 | BUILDDIR = build
8 |
9 | # Put it first so that "make" without argument is like "make help".
10 | help:
11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
12 |
13 | .PHONY: help Makefile
14 |
15 | # Catch-all target: route all unknown targets to Sphinx using the new
16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
17 | %: Makefile
18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
19 |
--------------------------------------------------------------------------------
/mlx/backend/common/threefry.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <cstdint>
6 | #include <utility>
7 |
8 | namespace mlx::core::random {
9 |
10 | /** Applies the Threefry 2x32 hash function.
11 | * This code is based on the Jax counter-based and splittable PRNG
12 | * https://github.com/google/jax/blob/main/docs/jep/263-prng.md
13 | *
14 | * Original Threefry reference:
15 | * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
16 | */
17 | std::pair<uint32_t, uint32_t> threefry2x32_hash(
18 | const std::pair<uint32_t, uint32_t>& key,
19 | std::pair<uint32_t, uint32_t> count);
20 |
21 | } // namespace mlx::core::random
22 |
--------------------------------------------------------------------------------
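For reference, a pure-Python sketch of the same Threefry 2x32 rounds, mirroring `mlx/backend/common/threefry.cpp` later in this dump; it is illustrative only and keeps the 32-bit wraparound explicit with a mask:

```
MASK = 0xFFFFFFFF
ROTATIONS = [(13, 15, 26, 6), (17, 29, 16, 24)]

def threefry2x32_hash(key, count):
    """Hash a counter pair with a key pair; returns a pair of uint32 values."""
    k0, k1 = key
    ks = [k0, k1, k0 ^ k1 ^ 0x1BD11BDA]
    c0 = (count[0] + ks[0]) & MASK
    c1 = (count[1] + ks[1]) & MASK

    for i in range(5):
        for r in ROTATIONS[i % 2]:
            c0 = (c0 + c1) & MASK                          # add
            c1 = ((c1 << r) | (c1 >> (32 - r))) & MASK     # rotate left by r
            c1 ^= c0                                       # xor
        c0 = (c0 + ks[(i + 1) % 3]) & MASK                 # key injection
        c1 = (c1 + ks[(i + 2) % 3] + i + 1) & MASK

    return c0, c1
```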
/mlx/device.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | namespace mlx::core {
6 |
7 | struct Device {
8 | enum class DeviceType {
9 | cpu,
10 | gpu,
11 | };
12 |
13 | static constexpr DeviceType cpu = DeviceType::cpu;
14 | static constexpr DeviceType gpu = DeviceType::gpu;
15 |
16 | Device(DeviceType type, int index = 0) : type(type), index(index){};
17 |
18 | DeviceType type;
19 | int index;
20 | };
21 |
22 | const Device& default_device();
23 |
24 | void set_default_device(const Device& d);
25 |
26 | bool operator==(const Device& lhs, const Device& rhs);
27 | bool operator!=(const Device& lhs, const Device& rhs);
28 |
29 | } // namespace mlx::core
30 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/conv_params.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | template <int NDIM>
6 | struct MLXConvParams {
7 | const int N; // Batch size
8 | const int C; // In channels
9 | const int O; // Out channels
10 | const int iS[NDIM]; // Input spatial dim
11 | const int wS[NDIM]; // Weight spatial dim
12 | const int oS[NDIM]; // Output spatial dim
13 | const int str[NDIM]; // Kernel strides
14 | const int pad[NDIM]; // Input padding
15 | const int dil[NDIM]; // Kernel dilation
16 | const size_t in_strides[NDIM + 2]; // In strides
17 | const size_t wt_strides[NDIM + 2]; // Wt strides
18 | const size_t out_strides[NDIM + 2]; // Out strides
19 | };
20 |
--------------------------------------------------------------------------------
/mlx/graph_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 |
7 | namespace mlx::core {
8 |
9 | void print_graph(std::ostream& os, const std::vector<array>& outputs);
10 |
11 | template <typename... Arrays>
12 | void print_graph(std::ostream& os, Arrays... outputs) {
13 | print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
14 | }
15 |
16 | void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
17 |
18 | template <typename... Arrays>
19 | void export_to_dot(std::ostream& os, Arrays... outputs) {
20 | export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
21 | }
22 |
23 | } // namespace mlx::core
24 |
--------------------------------------------------------------------------------
/examples/extensions/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from mlx import extension
4 | from setuptools import setup
5 |
6 | if __name__ == "__main__":
7 | setup(
8 | name="mlx_sample_extensions",
9 | version="0.0.0",
10 | description="Sample C++ and Metal extensions for MLX primitives.",
11 | ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
12 | cmdclass={"build_ext": extension.CMakeBuild},
13 | packages=["mlx_sample_extensions"],
14 | package_dir={"": "."},
15 | package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
16 | zip_safe=False,
17 | python_requires=">=3.7",
18 | )
19 |
--------------------------------------------------------------------------------
/mlx/backend/common/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE
4 | ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
8 | ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
9 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
10 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
11 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
12 | ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
13 | ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
14 | ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
15 | ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
16 | ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
17 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
18 | )
19 |
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | ### Packaging for PyPI
2 |
3 | Install `build` and `twine`:
4 |
5 | ```
6 | pip install --user --upgrade build
7 | pip install --user --upgrade twine
8 | ```
9 |
10 | Generate the source distribution and wheel:
11 |
12 | ```
13 | python -m build
14 | ```
15 |
16 | *Warning*: use a test server first.
17 |
18 | #### Test Upload
19 |
20 | Upload to test server:
21 |
22 | ```
23 | python -m twine upload --repository testpypi dist/*
24 | ```
25 |
26 | Install from test server and check that it works:
27 |
28 | ```
29 | python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx
30 | ```
31 |
32 | #### Upload
33 |
34 | ```
35 | python -m twine upload dist/*
36 | ```
37 |
38 |
--------------------------------------------------------------------------------
/docs/src/using_streams.rst:
--------------------------------------------------------------------------------
1 | Using Streams
2 | =============
3 |
4 | .. currentmodule:: mlx.core
5 |
6 | Specifying the :obj:`Stream`
7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 |
9 | All operations (including random number generation) take an optional
10 | keyword argument ``stream``. The ``stream`` kwarg specifies which
11 | :obj:`Stream` the operation should run on. If the stream is unspecified then
12 | the operation is run on the default stream of the default device:
13 | ``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also
14 | be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is
15 | run on the default stream of the provided device
16 | ``mx.default_stream(my_device)``.
17 |
--------------------------------------------------------------------------------
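A short sketch of the behaviour described above, assuming the `mlx.core` module from this repo (array shapes are arbitrary):

```
import mlx.core as mx

a = mx.random.uniform(shape=(512, 512))
b = mx.random.uniform(shape=(512, 512))

s = mx.new_stream(mx.default_device())

c = mx.add(a, b)                 # default stream of the default device
d = mx.add(a, b, stream=s)       # run on an explicit stream
e = mx.add(a, b, stream=mx.cpu)  # a device works too: its default stream is used
mx.eval(c, d, e)
```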
/benchmarks/cpp/compare_devices.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <iostream>
4 | #include "mlx/mlx.h"
5 | #include "time_utils.h"
6 |
7 | using namespace mlx::core;
8 |
9 | void time_add_op() {
10 | std::vector<int> sizes(1, 1);
11 | for (int i = 0; i < 9; ++i) {
12 | sizes.push_back(10 * sizes.back());
13 | }
14 | set_default_device(Device::cpu);
15 | for (auto size : sizes) {
16 | auto a = random::uniform({size});
17 | auto b = random::uniform({size});
18 | eval(a, b);
19 | std::cout << "Size " << size << std::endl;
20 | TIMEM("cpu", add, a, b, Device::cpu);
21 | TIMEM("gpu", add, a, b, Device::gpu);
22 | }
23 | }
24 |
25 | int main() {
26 | time_add_op();
27 | }
28 |
--------------------------------------------------------------------------------
/benchmarks/python/comparative/README.md:
--------------------------------------------------------------------------------
1 | Microbenchmarks comparing MLX to PyTorch
2 | ========================================
3 |
4 | Implement the same microbenchmarks in MLX and PyTorch to compare and make a
5 | list of the biggest possible performance improvements and/or regressions.
6 |
7 | Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu`, for
8 | instance, to measure the time it takes to sum across the 3rd axis of the above
9 | tensor on the cpu.
10 |
11 | `compare.py` runs several benchmarks and compares the speed-up or lack thereof
12 | in comparison to PyTorch.
13 |
14 | Each bench script can be run with `--print-pid` to print the PID and wait for a
15 | key in order to ease attaching a debugger.
16 |
--------------------------------------------------------------------------------
/tests/utils_tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "doctest/doctest.h"
4 |
5 | #include "mlx/mlx.h"
6 |
7 | using namespace mlx::core;
8 |
9 | TEST_CASE("test type promotion") {
10 | for (auto t : {bool_, uint32, int32, int64, float32}) {
11 | auto a = array(0, t);
12 | CHECK_EQ(result_type({a}), t);
13 |
14 | std::vector<array> arrs = {array(0, t), array(0, t)};
15 | CHECK_EQ(result_type(arrs), t);
16 | }
17 |
18 | {
19 | std::vector<array> arrs = {array(false), array(0, int32)};
20 | CHECK_EQ(result_type(arrs), int32);
21 | }
22 |
23 | {
24 | std::vector<array> arrs = {array(0, int32), array(false), array(0.0f)};
25 | CHECK_EQ(result_type(arrs), float32);
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/mlx/backend/metal/matmul.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <algorithm>
4 | #include <cassert>
5 | #include <sstream>
6 |
7 | #include "mlx/backend/metal/copy.h"
8 | #include "mlx/backend/metal/device.h"
9 | #include "mlx/backend/metal/mps/gemm.h"
10 | #include "mlx/backend/metal/utils.h"
11 | #include "mlx/utils.h"
12 |
13 | namespace mlx::core {
14 |
15 | void mlx_matmul(
16 | const Stream& s,
17 | metal::Device& d,
18 | const array& a,
19 | const array& b,
20 | array& out,
21 | int M,
22 | int N,
23 | int K,
24 | int batch_size_out,
25 | int lda,
26 | int ldb,
27 | bool transpose_a,
28 | bool transpose_b,
29 | std::vector<array>& copies);
30 |
31 | } // namespace mlx::core
--------------------------------------------------------------------------------
/python/mlx/nn/layers/containers.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from mlx.nn.layers.base import Module
4 |
5 |
6 | class Sequential(Module):
7 | """A layer that calls the passed callables in order.
8 |
9 | We can pass either modules or plain callables to the Sequential module. If
10 | our functions have learnable parameters they should be implemented as
11 | ``nn.Module`` instances.
12 |
13 | Args:
14 | modules (tuple of Callables): The modules to call in order
15 | """
16 |
17 | def __init__(self, *modules):
18 | super().__init__()
19 | self.layers = list(modules)
20 |
21 | def __call__(self, x):
22 | for m in self.layers:
23 | x = m(x)
24 | return x
25 |
--------------------------------------------------------------------------------
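A brief usage sketch of `Sequential`, mixing a module and a plain callable as the docstring describes (the layer sizes here are arbitrary):

```
import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    lambda x: mx.maximum(x, 0.0),  # any plain callable works
    nn.Linear(8, 2),
)

y = model(mx.random.uniform(shape=(3, 4)))  # output has shape (3, 2)
```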
/mlx/backend/common/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <vector>
6 |
7 | #include "mlx/array.h"
8 |
9 | namespace mlx::core {
10 |
11 | inline size_t elem_to_loc(
12 | int elem,
13 | const std::vector<int>& shape,
14 | const std::vector<size_t>& strides) {
15 | size_t loc = 0;
16 | for (int i = shape.size() - 1; i >= 0; --i) {
17 | auto q_and_r = ldiv(elem, shape[i]);
18 | loc += q_and_r.rem * strides[i];
19 | elem = q_and_r.quot;
20 | }
21 | return loc;
22 | }
23 |
24 | inline size_t elem_to_loc(int elem, const array& a) {
25 | if (a.flags().row_contiguous) {
26 | return elem;
27 | }
28 | return elem_to_loc(elem, a.shape(), a.strides());
29 | }
30 |
31 | } // namespace mlx::core
32 |
--------------------------------------------------------------------------------
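For intuition, a hypothetical pure-Python rendering of `elem_to_loc`: it walks the shape from the innermost dimension outward, peeling off each coordinate with a div/mod and accumulating coordinate times stride, which is why a row-contiguous array can short-circuit to the element index itself.

```
def elem_to_loc(elem, shape, strides):
    """Map a flat element index to a buffer offset for the given strides."""
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        elem, coord = divmod(elem, dim)
        loc += coord * stride
    return loc

# A 2x3 array stored transposed (strides [1, 2]):
# flat element 4 has coordinates (row=1, col=1) -> offset 1*1 + 1*2 = 3
assert elem_to_loc(4, [2, 3], [1, 2]) == 3
```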
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Build the Docs
2 |
3 | ### Setup (do once)
4 |
5 | Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
6 | for example with `conda`:
7 |
8 | ```
9 | conda install sphinx
10 | pip install sphinx-book-theme
11 | ```
12 |
13 | ### Build
14 |
15 | Build the docs from `mlx/docs/`
16 |
17 | ```
18 | make html
19 | ```
20 |
21 | View the docs by running a server in `mlx/docs/build/html/`:
22 |
23 | ```
24 | python -m http.server
25 | ```
26 |
27 | and point your browser to `http://localhost:<port>`.
28 |
29 | ### Push to Github Pages
30 |
31 | Check-out the `gh-pages` branch (`git switch gh-pages`) and build
32 | the docs. Then force add the `build/html` directory:
33 |
34 | `git add -f build/html`
35 |
36 | Commit and push the changes to the `gh-pages` branch.
37 |
--------------------------------------------------------------------------------
/mlx/stream.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/device.h"
6 |
7 | namespace mlx::core {
8 |
9 | struct Stream {
10 | int index;
11 | Device device;
12 | explicit Stream(int index, Device device) : index(index), device(device) {}
13 | };
14 |
15 | /** Get the default stream for the given device. */
16 | Stream default_stream(Device d);
17 |
18 | /** Make the stream the default for its device. */
19 | void set_default_stream(Stream s);
20 |
21 | /** Make a new stream on the given device. */
22 | Stream new_stream(Device d);
23 |
24 | inline bool operator==(const Stream& lhs, const Stream& rhs) {
25 | return lhs.index == rhs.index;
26 | }
27 |
28 | inline bool operator!=(const Stream& lhs, const Stream& rhs) {
29 | return !(lhs == rhs);
30 | }
31 |
32 | } // namespace mlx::core
33 |
--------------------------------------------------------------------------------
/mlx/backend/common/copy.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "mlx/array.h"
6 | #include "mlx/backend/common/utils.h"
7 |
8 | namespace mlx::core {
9 |
10 | enum class CopyType {
11 | // Copy a raw scalar input into the full contiguous output
12 | Scalar,
13 |
14 | // Copy the raw input buffer contiguously into a raw output buffer of the same
15 | // size
16 | Vector,
17 |
18 | // Copy the full virtual input to the full contiguous output
19 | General,
20 |
21 | // Copy the full virtual input to the full virtual output. We assume the
22 | // input and output have the same shape.
23 | GeneralGeneral
24 | };
25 |
26 | void copy(const array& src, array& dst, CopyType ctype);
27 | void copy_inplace(const array& src, array& dst, CopyType ctype);
28 |
29 | } // namespace mlx::core
30 |
--------------------------------------------------------------------------------
/python/tests/test_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import unittest
4 |
5 | import mlx.core as mx
6 | import mlx.utils
7 |
8 | import mlx_tests
9 |
10 |
11 | class TestTreeUtils(mlx_tests.MLXTestCase):
12 | def test_tree_map(self):
13 | tree = {"a": 0, "b": 1, "c": 2}
14 | tree = mlx.utils.tree_map(lambda x: x + 1, tree)
15 |
16 | expected_tree = {"a": 1, "b": 2, "c": 3}
17 | self.assertEqual(tree, expected_tree)
18 |
19 | def test_tree_flatten(self):
20 | tree = [{"a": 1, "b": 2}, "c"]
21 | vals = (1, 2, "c")
22 | flat_tree = mlx.utils.tree_flatten(tree)
23 | self.assertEqual(list(zip(*flat_tree))[1], vals)
24 | self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
25 |
26 |
27 | if __name__ == "__main__":
28 | unittest.main()
29 |
--------------------------------------------------------------------------------
/python/mlx/nn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from mlx.nn.layers.base import Module
4 | from mlx.nn.layers.activations import (
5 | GELU,
6 | ReLU,
7 | SiLU,
8 | gelu,
9 | gelu_approx,
10 | gelu_fast_approx,
11 | relu,
12 | silu,
13 | )
14 | from mlx.nn.layers.containers import Sequential
15 | from mlx.nn.layers.convolution import Conv1d, Conv2d
16 | from mlx.nn.layers.dropout import Dropout
17 | from mlx.nn.layers.embedding import Embedding
18 | from mlx.nn.layers.linear import Linear
19 | from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
20 | from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
21 | from mlx.nn.layers.transformer import (
22 | MultiHeadAttention,
23 | TransformerEncoder,
24 | TransformerEncoderLayer,
25 | )
26 |
--------------------------------------------------------------------------------
/tests/graph_optimize_tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "doctest/doctest.h"
4 |
5 | #include "mlx/mlx.h"
6 |
7 | using namespace mlx::core;
8 |
9 | TEST_CASE("test simplify scalars") {
10 | auto a = array({-1.0f, 2.0f});
11 | auto b = maximum(a, array(0.0f));
12 | auto c = maximum(-a, array(0.0f));
13 | auto d = b + c;
14 | simplify({d});
15 | CHECK(b.inputs()[1].id() == c.inputs()[1].id());
16 | }
17 |
18 | TEST_CASE("test simplify") {
19 | auto a = array({1.0f, 2.0f});
20 | auto b = exp(a) + exp(a);
21 | simplify(b);
22 | eval(b);
23 | CHECK(b.inputs()[0].id() == b.inputs()[1].id());
24 | }
25 |
26 | TEST_CASE("test no simplify") {
27 | auto a = array({1.0f, 2.0f});
28 | auto b = cos(a) + sin(a);
29 | simplify(b);
30 | eval(b);
31 | CHECK(b.inputs()[0].id() != b.inputs()[1].id());
32 | }
33 |
--------------------------------------------------------------------------------
/docs/src/python/array.rst:
--------------------------------------------------------------------------------
1 | .. _array:
2 |
3 | Array
4 | =====
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | array
12 | array.astype
13 | array.item
14 | array.tolist
15 | array.dtype
16 | array.ndim
17 | array.shape
18 | array.size
19 | Dtype
20 | array.abs
21 | array.all
22 | array.any
23 | array.argmax
24 | array.argmin
25 | array.cos
26 | array.dtype
27 | array.exp
28 | array.log
29 | array.log1p
30 | array.logsumexp
31 | array.max
32 | array.mean
33 | array.min
34 | array.prod
35 | array.reciprocal
36 | array.reshape
37 | array.rsqrt
38 | array.sin
39 | array.split
40 | array.sqrt
41 | array.square
42 | array.sum
43 | array.transpose
44 | array.T
45 | array.var
46 |
--------------------------------------------------------------------------------
/mlx/device.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/device.h"
4 | #include "mlx/backend/metal/metal.h"
5 |
6 | namespace mlx::core {
7 |
8 | static Device default_device_{
9 | metal::is_available() ? Device::gpu : Device::cpu};
10 |
11 | const Device& default_device() {
12 | return default_device_;
13 | }
14 |
15 | void set_default_device(const Device& d) {
16 | if (!metal::is_available() && d == Device::gpu) {
17 | throw std::invalid_argument(
18 | "[set_default_device] Cannot set gpu device without gpu backend.");
19 | }
20 | default_device_ = d;
21 | }
22 |
23 | bool operator==(const Device& lhs, const Device& rhs) {
24 | return lhs.type == rhs.type && lhs.index == rhs.index;
25 | }
26 |
27 | bool operator!=(const Device& lhs, const Device& rhs) {
28 | return !(lhs == rhs);
29 | }
30 |
31 | } // namespace mlx::core
32 |
--------------------------------------------------------------------------------
/mlx/backend/accelerate/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <Accelerate/Accelerate.h>
6 | #include "mlx/dtype.h"
7 |
8 | namespace mlx::core {
9 |
10 | BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
11 | uint32_t size_bits = size_of(mlx_dtype) * 8;
12 | switch (kindof(mlx_dtype)) {
13 | case Dtype::Kind::b:
14 | return BNNSDataTypeBoolean;
15 | case Dtype::Kind::u:
16 | return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
17 | case Dtype::Kind::i:
18 | return BNNSDataType(BNNSDataTypeIntBit | size_bits);
19 | case Dtype::Kind::f:
20 | return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
21 | case Dtype::Kind::V:
22 | return BNNSDataTypeBFloat16;
23 | case Dtype::Kind::c:
24 | throw std::invalid_argument("BNNS does not support complex types");
25 | }
26 | }
27 |
28 | } // namespace mlx::core
--------------------------------------------------------------------------------
/python/src/mlx.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <pybind11/pybind11.h>
4 |
5 | #define STRINGIFY(x) #x
6 | #define TOSTRING(x) STRINGIFY(x)
7 |
8 | namespace py = pybind11;
9 |
10 | void init_array(py::module_&);
11 | void init_device(py::module_&);
12 | void init_stream(py::module_&);
13 | void init_metal(py::module_&);
14 | void init_ops(py::module_&);
15 | void init_transforms(py::module_&);
16 | void init_random(py::module_&);
17 | void init_fft(py::module_&);
18 |
19 | PYBIND11_MODULE(core, m) {
20 | m.doc() = "mlx: A framework for machine learning on Apple Silicon.";
21 |
22 | auto reprlib_fix = py::module_::import("mlx._reprlib_fix");
23 |
24 | init_device(m);
25 | init_stream(m);
26 | init_array(m);
27 | init_metal(m);
28 | init_ops(m);
29 | init_transforms(m);
30 | init_random(m);
31 | init_fft(m);
32 | m.attr("__version__") = TOSTRING(_VERSION_);
33 | }
34 |
--------------------------------------------------------------------------------
/mlx/backend/metal/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE
4 | ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
8 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
9 | ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
10 | ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
11 | ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
12 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
13 | ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
14 | ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
15 | ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
16 | ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
17 | )
18 |
19 | if (NOT MLX_METAL_PATH)
20 | set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/)
21 | endif()
22 |
23 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels)
24 |
25 | target_compile_definitions(
26 | mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib")
27 |
--------------------------------------------------------------------------------
/tests/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(
2 | doctest
3 | GIT_REPOSITORY "https://github.com/onqtam/doctest"
4 | GIT_TAG "b7c21ec5ceeadb4951b00396fc1e4642dd347e5f"
5 | )
6 | FetchContent_MakeAvailable(doctest)
7 |
8 | add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp)
9 |
10 | if (MLX_BUILD_METAL)
11 | set(
12 | METAL_TEST_SOURCES
13 | metal_tests.cpp
14 | )
15 | endif()
16 |
17 | target_sources(tests PRIVATE
18 | allocator_tests.cpp
19 | array_tests.cpp
20 | arg_reduce_tests.cpp
21 | autograd_tests.cpp
22 | blas_tests.cpp
23 | creations_tests.cpp
24 | device_tests.cpp
25 | eval_tests.cpp
26 | fft_tests.cpp
27 | graph_optimize_tests.cpp
28 | load_tests.cpp
29 | ops_tests.cpp
30 | random_tests.cpp
31 | scheduler_tests.cpp
32 | utils_tests.cpp
33 | vmap_tests.cpp
34 | ${METAL_TEST_SOURCES}
35 | )
36 |
37 | target_link_libraries(tests PRIVATE mlx doctest)
38 | add_test(NAME tests COMMAND tests)
39 |
--------------------------------------------------------------------------------
/mlx/backend/common/threefry.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/backend/common/threefry.h"
4 |
5 | namespace mlx::core::random {
6 |
7 | std::pair<uint32_t, uint32_t> threefry2x32_hash(
8 | const std::pair<uint32_t, uint32_t>& key,
9 | std::pair<uint32_t, uint32_t> count) {
10 | constexpr static uint32_t rotations[2][4] = {
11 | {13, 15, 26, 6}, {17, 29, 16, 24}};
12 |
13 | uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA};
14 |
15 | count.first += ks[0];
16 | count.second += ks[1];
17 |
18 | for (int i = 0; i < 5; ++i) {
19 | for (auto r : rotations[i % 2]) {
20 | count.first += count.second;
21 | count.second = (count.second << r) | (count.second >> (32 - r));
22 | count.second ^= count.first;
23 | }
24 | count.first += ks[(i + 1) % 3];
25 | count.second += ks[(i + 2) % 3] + i + 1;
26 | }
27 |
28 | return count;
29 | }
30 |
31 | } // namespace mlx::core::random
32 |
--------------------------------------------------------------------------------
/python/tests/test_optimizers.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import unittest
4 |
5 | import mlx.core as mx
6 | import mlx.optimizers as opt
7 | import mlx.utils
8 |
9 | import mlx_tests
10 |
11 |
12 | class TestOptimizers(mlx_tests.MLXTestCase):
13 | def test_optimizers(self):
14 | params = {
15 | "first": [mx.zeros((10,)), mx.zeros((1,))],
16 | "second": mx.zeros((1,)),
17 | }
18 | grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
19 |
20 | for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
21 | update = optim.apply_gradients(grads, params)
22 | mx.eval(update)
23 | equal_shape = mlx.utils.tree_map(
24 | lambda x, y: x.shape == y.shape, params, update
25 | )
26 | all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
27 | self.assertTrue(all_equal)
28 |
29 |
30 | if __name__ == "__main__":
31 | unittest.main()
32 |
--------------------------------------------------------------------------------
/tests/device_tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "doctest/doctest.h"
4 |
5 | #include <cstdlib>
6 |
7 | #include "mlx/mlx.h"
8 |
9 | using namespace mlx::core;
10 |
11 | TEST_CASE("test device placement") {
12 | auto device = default_device();
13 | Device d = metal::is_available() ? Device::gpu : Device::cpu;
14 | if (std::getenv("DEVICE") == nullptr) {
15 | CHECK_EQ(device, d);
16 | }
17 |
18 | array x(1.0f);
19 | array y(1.0f);
20 | auto z = add(x, y, default_device());
21 | if (metal::is_available()) {
22 | z = add(x, y, Device::gpu);
23 | z = add(x, y, Device(Device::gpu, 0));
24 | } else {
25 | CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);
26 | CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);
27 | }
28 |
29 | // Set the default device to the CPU
30 | set_default_device(Device::cpu);
31 | CHECK_EQ(default_device(), Device::cpu);
32 |
33 | // Revert
34 | set_default_device(device);
35 | }
36 |
--------------------------------------------------------------------------------
/python/tests/test_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from functools import partial
4 |
5 | import unittest
6 |
7 | import mlx.core as mx
8 |
9 | import mlx_tests
10 |
11 |
12 | class TestEval(mlx_tests.MLXTestCase):
13 | def test_eval(self):
14 | arrs = [mx.ones((2, 2)) for _ in range(4)]
15 | mx.eval(*arrs)
16 | for x in arrs:
17 | self.assertEqual(x.tolist(), [[1, 1], [1, 1]])
18 |
19 | def test_retain_graph(self):
20 | def fun(x, retain_graph):
21 | y = 3 * x
22 | mx.eval(y, retain_graph=retain_graph)
23 | return 2 * y
24 |
25 | dfun_dx_1 = mx.grad(partial(fun, retain_graph=False))
26 | dfun_dx_2 = mx.grad(partial(fun, retain_graph=True))
27 |
28 | with self.assertRaises(ValueError):
29 | dfun_dx_1(mx.array(1.0))
30 |
31 | y = dfun_dx_2(mx.array(1.0))
32 | self.assertEqual(y.item(), 6.0)
33 |
34 |
35 | if __name__ == "__main__":
36 | unittest.main()
37 |
--------------------------------------------------------------------------------
/benchmarks/numpy/single_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import numpy as np
4 |
5 | from time_utils import time_fn
6 |
7 |
8 | def time_add():
9 | a = np.ones((100, 100, 10), dtype=np.float32)
10 | b = np.ones((100, 100, 10), dtype=np.float32)
11 | time_fn(np.add, a, b)
12 |
13 |
14 | def time_matmul():
15 | a = np.random.rand(1000, 500).astype(np.float32)
16 | b = np.random.rand(500, 1000).astype(np.float32)
17 | time_fn(np.matmul, a, b)
18 |
19 |
20 | def time_exp():
21 | a = np.random.randn(1000, 100).astype(np.float32)
22 | time_fn(np.exp, a)
23 |
24 |
25 | def time_take():
26 | a = np.random.rand(10000, 500)
27 | ids = np.random.randint(0, 10000, (20, 10))
28 | ids = [idx.reshape(-1) for idx in np.split(ids, 20)]
29 |
30 | def random_take():
31 | return [np.take(a, idx, 0) for idx in ids]
32 |
33 | time_fn(random_take)
34 |
35 |
36 | if __name__ == "__main__":
37 | time_add()
38 | time_matmul()
39 | time_exp()
40 | time_take()
41 |
--------------------------------------------------------------------------------
/python/mlx/nn/layers/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import math
4 |
5 | import mlx.core as mx
6 | from mlx.nn.layers.base import Module
7 |
8 |
9 | class Embedding(Module):
10 | """Implements a simple lookup table that maps each input integer to a
11 | high-dimensional vector.
12 |
13 | Typically used to embed discrete tokens for processing by neural networks.
14 |
15 | Args:
16 | num_embeddings (int): How many possible discrete tokens can we embed.
17 | Usually called the vocabulary size.
18 | dims (int): The dimensionality of the embeddings.
19 | """
20 |
21 | def __init__(self, num_embeddings: int, dims: int):
22 | super().__init__()
23 | scale = math.sqrt(1 / dims)
24 | self.weight = mx.random.normal((num_embeddings, dims)) * scale
25 |
26 | def _extra_repr(self):
27 | return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
28 |
29 | def __call__(self, x):
30 | return self.weight[x]
31 |
--------------------------------------------------------------------------------
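A quick usage sketch of the `Embedding` layer above (vocabulary size, dimensions, and token ids chosen arbitrarily):

```
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(num_embeddings=100, dims=8)  # 100-token vocabulary, 8-dim vectors

tokens = mx.array([[3, 7, 42], [0, 1, 2]])      # (batch, sequence) of token ids
vectors = emb(tokens)                           # looked-up vectors, shape (2, 3, 8)
```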
/benchmarks/cpp/autograd.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <iostream>
4 |
5 | #include "mlx/mlx.h"
6 | #include "time_utils.h"
7 |
8 | using namespace mlx::core;
9 |
10 | void time_value_and_grad() {
11 | auto x = ones({200, 1000});
12 | eval(x);
13 | auto fn = [](array x) {
14 | for (int i = 0; i < 20; ++i) {
15 | x = log(exp(x));
16 | }
17 | return sum(x);
18 | };
19 |
20 | auto grad_fn = grad(fn);
21 | auto independent_value_and_grad = [&]() {
22 | auto value = fn(x);
23 | auto dfdx = grad_fn(x);
24 | return std::vector<array>{value, dfdx};
25 | };
26 | TIME(independent_value_and_grad);
27 |
28 | auto value_and_grad_fn = value_and_grad(fn);
29 | auto combined_value_and_grad = [&]() {
30 | auto [value, dfdx] = value_and_grad_fn(x);
31 | return std::vector<array>{value, dfdx};
32 | };
33 | TIME(combined_value_and_grad);
34 | }
35 |
36 | int main() {
37 | std::cout << "Benchmarks for " << default_device() << std::endl;
38 | time_value_and_grad();
39 | }
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Metal libraries
10 | *.metallib
11 |
12 | # Distribution / packaging
13 | python/mlx/share
14 | python/mlx/include
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # vim
35 | *.swp
36 |
37 | # Ignore build dir
38 | build/
39 |
40 | # Prerequisites
41 | *.d
42 |
43 | # Compiled Object files
44 | *.slo
45 | *.lo
46 | *.o
47 | *.obj
48 |
49 | # Precompiled Headers
50 | *.gch
51 | *.pch
52 |
53 | # Compiled Dynamic libraries
54 | *.so
55 | *.dylib
56 | *.dll
57 |
58 | # Fortran module files
59 | *.mod
60 | *.smod
61 |
62 | # Compiled Static libraries
63 | *.lai
64 | *.la
65 | *.a
66 | *.lib
67 |
68 | # Executables
69 | *.exe
70 | *.out
71 | *.app
72 |
73 | # VSCode
74 | .vscode/
75 | .DS_Store
76 |
--------------------------------------------------------------------------------
/python/src/stream.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <sstream>
4 |
5 | #include <pybind11/pybind11.h>
6 |
7 | #include "mlx/stream.h"
8 | #include "mlx/utils.h"
9 |
10 | namespace py = pybind11;
11 | using namespace py::literals;
12 | using namespace mlx::core;
13 |
14 | void init_stream(py::module_& m) {
15 |   py::class_<Stream>(m, "Stream")
16 |       .def(py::init<int, Device>(), "index"_a, "device"_a)
17 | .def_readonly("device", &Stream::device)
18 | .def(
19 | "__repr__",
20 | [](const Stream& s) {
21 | std::ostringstream os;
22 | os << s;
23 | return os.str();
24 | })
25 | .def("__eq__", [](const Stream& s1, const Stream& s2) {
26 | return s1 == s2;
27 | });
28 |
29 | py::implicitly_convertible();
30 |
31 | m.def("default_stream", &default_stream, "device"_a);
32 | m.def("set_default_stream", &set_default_stream, "stream"_a);
33 | m.def("new_stream", &new_stream, "device"_a);
34 | }
35 |
--------------------------------------------------------------------------------
/python/src/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | pybind11_add_module(
2 | core
3 | ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp
4 | ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
8 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
9 | ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
10 | ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
11 | ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
12 | ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
13 | ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
14 | )
15 |
16 | if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
17 | set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
18 | endif()
19 |
20 | set_target_properties(
21 | core
22 | PROPERTIES
23 | LIBRARY_OUTPUT_DIRECTORY
24 | ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY}
25 | )
26 |
27 | target_link_libraries(core PRIVATE mlx)
28 | target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION})
29 |
30 | if(BUILD_SHARED_LIBS)
31 | target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib)
32 | endif()
33 |
--------------------------------------------------------------------------------
/examples/python/logistic_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import mlx.core as mx
4 | import time
5 |
6 | num_features = 100
7 | num_examples = 1_000
8 | num_iters = 10_000
9 | lr = 0.1
10 |
11 | # True parameters
12 | w_star = mx.random.normal((num_features,))
13 |
14 | # Input examples
15 | X = mx.random.normal((num_examples, num_features))
16 |
17 | # Labels
18 | y = (X @ w_star) > 0
19 |
20 |
21 | # Initialize random parameters
22 | w = 1e-2 * mx.random.normal((num_features,))
23 |
24 |
25 | def loss_fn(w):
26 | logits = X @ w
27 | return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
28 |
29 |
30 | grad_fn = mx.grad(loss_fn)
31 |
32 | tic = time.time()
33 | for _ in range(num_iters):
34 | grad = grad_fn(w)
35 | w = w - lr * grad
36 | mx.eval(w)
37 |
38 | toc = time.time()
39 |
40 | loss = loss_fn(w)
41 | final_preds = (X @ w) > 0
42 | acc = mx.mean(final_preds == y)
43 |
44 | throughput = num_iters / (toc - tic)
45 | print(
46 | f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} "
47 | f"Throughput {throughput:.5f} (it/s)"
48 | )
49 |
--------------------------------------------------------------------------------
/mlx/backend/metal/kernels/arange.metal:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/backend/metal/kernels/bf16.h"
4 |
5 | template <typename T>
6 | [[kernel]] void arange(
7 | constant const T& start,
8 | constant const T& step,
9 | device T* out,
10 | uint index [[thread_position_in_grid]]) {
11 | out[index] = start + index * step;
12 | }
13 |
14 | #define instantiate_arange(tname, type) \
15 | template [[host_name("arange" #tname)]] \
16 |   [[kernel]] void arange<type>( \
17 | constant const type& start, \
18 | constant const type& step, \
19 | device type* out, \
20 | uint index [[thread_position_in_grid]]);
21 |
22 | instantiate_arange(uint8, uint8_t)
23 | instantiate_arange(uint16, uint16_t)
24 | instantiate_arange(uint32, uint32_t)
25 | instantiate_arange(uint64, uint64_t)
26 | instantiate_arange(int8, int8_t)
27 | instantiate_arange(int16, int16_t)
28 | instantiate_arange(int32, int32_t)
29 | instantiate_arange(int64, int64_t)
30 | instantiate_arange(float16, half)
31 | instantiate_arange(float32, float)
32 | instantiate_arange(bfloat16, bfloat16_t)
--------------------------------------------------------------------------------
/examples/python/linear_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import mlx.core as mx
4 | import time
5 |
6 | num_features = 100
7 | num_examples = 1_000
8 | num_iters = 10_000
9 | lr = 0.01
10 |
11 | # True parameters
12 | w_star = mx.random.normal((num_features,))
13 |
14 | # Input examples (design matrix)
15 | X = mx.random.normal((num_examples, num_features))
16 |
17 | # Noisy labels
18 | eps = 1e-2 * mx.random.normal((num_examples,))
19 | y = X @ w_star + eps
20 |
21 | # Initialize random parameters
22 | w = 1e-2 * mx.random.normal((num_features,))
23 |
24 |
25 | def loss_fn(w):
26 | return 0.5 * mx.mean(mx.square(X @ w - y))
27 |
28 |
29 | grad_fn = mx.grad(loss_fn)
30 |
31 | tic = time.time()
32 | for _ in range(num_iters):
33 | grad = grad_fn(w)
34 | w = w - lr * grad
35 | mx.eval(w)
36 | toc = time.time()
37 |
38 | loss = loss_fn(w)
39 | error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
40 | throughput = num_iters / (toc - tic)
41 |
42 | print(
43 | f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
44 | f"Throughput {throughput:.5f} (it/s)"
45 | )
46 |
--------------------------------------------------------------------------------
/python/mlx/nn/layers/dropout.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import mlx.core as mx
4 | from mlx.nn.layers.base import Module
5 |
6 |
7 | class Dropout(Module):
8 |     r"""Randomly zero a portion of the elements during training.
9 |
10 | The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
11 | :math:`p` is the probability of zeroing an element. This is done so the
12 | expected value of a given element will remain the same.
13 |
14 | Args:
15 | p (float): The probability to zero an element
16 | """
17 |
18 | def __init__(self, p: float = 0.5):
19 | super().__init__()
20 |
21 | if p < 0 or p >= 1:
22 | raise ValueError("The dropout probability should be in [0, 1)")
23 |
24 | self._p_1 = 1 - p
25 |
26 | def _extra_repr(self):
27 | return f"p={1-self._p_1}"
28 |
29 | def __call__(self, x):
30 | if self._p_1 == 1 or not self.training:
31 | return x
32 |
33 | mask = mx.random.bernoulli(self._p_1, x.shape)
34 |
35 | return (1 / self._p_1) * mask.astype(x.dtype) * x
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright © 2023 Apple Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/python/mlx/nn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | from typing import Callable
4 |
5 | import mlx.core as mx
6 |
7 |
8 | def value_and_grad(model: "mlx.nn.Module", fn: Callable):
9 | """Transform the passed function ``fn`` to a function that computes the
10 | gradients of ``fn`` wrt the model's trainable parameters and also its
11 | value.
12 |
13 | Args:
14 | model (mlx.nn.Module): The model whose trainable parameters to compute
15 | gradients for
16 | fn (Callable): The scalar function to compute gradients for
17 |
18 | Returns:
19 | A callable that returns the value of ``fn`` and the gradients wrt the
20 | trainable parameters of ``model``
21 | """
22 |
23 | def inner_fn(params, *args, **kwargs):
24 | model.update(params)
25 | return fn(*args, **kwargs)
26 |
27 | value_grad_fn = mx.value_and_grad(inner_fn)
28 |
29 | def wrapped_value_grad_fn(*args, **kwargs):
30 | value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
31 | return value, grad
32 |
33 | return wrapped_value_grad_fn
34 |
--------------------------------------------------------------------------------
/examples/extensions/bindings.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <pybind11/pybind11.h>
4 | #include <pybind11/stl.h>
5 |
6 | #include "axpby/axpby.h"
7 |
8 | namespace py = pybind11;
9 | using namespace py::literals;
10 | using namespace mlx::core;
11 |
12 | PYBIND11_MODULE(mlx_sample_extensions, m) {
13 | m.doc() = "Sample C++ and metal extensions for MLX";
14 |
15 | m.def(
16 | "axpby",
17 | &axpby,
18 | "x"_a,
19 | "y"_a,
20 | py::pos_only(),
21 | "alpha"_a,
22 | "beta"_a,
23 | py::kw_only(),
24 | "stream"_a = py::none(),
25 | R"pbdoc(
26 | Scale and sum two vectors elementwise
27 | ``z = alpha * x + beta * y``
28 |
29 | Follows numpy style broadcasting between ``x`` and ``y``
30 | Inputs are upcasted to floats if needed
31 |
32 | Args:
33 | x (array): Input array.
34 | y (array): Input array.
35 | alpha (float): Scaling factor for ``x``.
36 | beta (float): Scaling factor for ``y``.
37 |
38 | Returns:
39 | array: ``alpha * x + beta * y``
40 | )pbdoc");
41 | }
--------------------------------------------------------------------------------
/tests/allocator_tests.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <stdexcept>
4 |
5 | #include "doctest/doctest.h"
6 |
7 | #include "mlx/allocator.h"
8 |
9 | using namespace mlx::core;
10 |
11 | TEST_CASE("test simple allocations") {
12 | {
13 | auto buffer = allocator::malloc(sizeof(float));
14 |     auto fptr = static_cast<float*>(buffer.raw_ptr());
15 | *fptr = 0.5f;
16 | CHECK_EQ(*fptr, 0.5f);
17 | allocator::free(buffer);
18 | }
19 |
20 | {
21 | auto buffer = allocator::malloc(128 * sizeof(int));
22 |     int* ptr = static_cast<int*>(buffer.raw_ptr());
23 | for (int i = 0; i < 128; ++i) {
24 | ptr[i] = i;
25 | }
26 | allocator::free(buffer);
27 | }
28 |
29 | {
30 | auto buffer = allocator::malloc(0);
31 | allocator::free(buffer);
32 | }
33 | }
34 |
35 | TEST_CASE("test large allocations") {
36 | size_t size = 1 << 30;
37 | for (int i = 0; i < 100; ++i) {
38 | auto buffer = allocator::malloc(size);
39 | allocator::free(buffer);
40 | }
41 | // Shouldn't be able to allocate an exabyte anytime soon.
42 | CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error);
43 | }
44 |
--------------------------------------------------------------------------------
/python/mlx/nn/layers/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import math
4 |
5 | import mlx.core as mx
6 | from mlx.nn.layers.base import Module
7 |
8 |
9 | class Linear(Module):
10 | """Applies an affine transformation to the input.
11 |
12 | Args:
13 | input_dims (int): The dimensionality of the input features
14 | output_dims (int): The dimensionality of the output features
15 | bias (bool): If set to False then the layer will not use a bias
16 | """
17 |
18 | def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
19 | super().__init__()
20 | scale = math.sqrt(1 / input_dims)
21 | self.weight = mx.random.uniform(
22 | low=-scale,
23 | high=scale,
24 | shape=(output_dims, input_dims),
25 | )
26 | if bias:
27 | self.bias = mx.zeros((output_dims,))
28 |
29 | def _extra_repr(self):
30 | return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
31 |
32 | def __call__(self, x):
33 | x = x @ self.weight.T
34 | if "bias" in self:
35 | x = x + self.bias
36 | return x
37 |
--------------------------------------------------------------------------------
/docs/src/python/random.rst:
--------------------------------------------------------------------------------
1 | .. _random:
2 |
3 | Random
4 | ======
5 |
6 | Random sampling functions in MLX use an implicit global PRNG state by default.
7 | However, all functions take an optional ``key`` keyword argument for when more
8 | fine-grained control or explicit state management is needed.
9 |
10 | For example, you can generate random numbers with:
11 |
12 | .. code-block:: python
13 |
14 | for _ in range(3):
15 | print(mx.random.uniform())
16 |
17 | which will print a sequence of unique pseudo random numbers. Alternatively you
18 | can explicitly set the key:
19 |
20 | .. code-block:: python
21 |
22 | key = mx.random.key(0)
23 | for _ in range(3):
24 | print(mx.random.uniform(key=key))
25 |
26 | which will yield the same pseudo random number at each iteration.
27 |
28 | Following `JAX's PRNG design <https://jax.readthedocs.io/en/latest/jep/263-prng.html>`_
29 | we use a splittable version of Threefry, which is a counter-based PRNG.
30 |
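Keys can also be split into new, independent sub-keys with :func:`split`. The
following is a minimal sketch of drawing decorrelated samples from sub-keys
derived from a single seed key (it assumes the default of two sub-keys):

.. code-block:: python

    key = mx.random.key(0)

    # Split the seed key; each sub-key drives an independent sample.
    keys = mx.random.split(key)
    a = mx.random.uniform(key=keys[0])
    b = mx.random.uniform(key=keys[1])
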
31 | .. currentmodule:: mlx.core.random
32 |
33 | .. autosummary::
34 | :toctree: _autosummary
35 |
36 | seed
37 | key
38 | split
39 | bernoulli
40 | categorical
41 | gumbel
42 | normal
43 | randint
44 | uniform
45 | truncated_normal
46 |
--------------------------------------------------------------------------------
/mlx/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | target_sources(
2 | mlx
3 | PRIVATE
4 | ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
5 | ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
6 | ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
7 | ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
8 | ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
9 | ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
10 | ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
11 | ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
12 | ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
13 | ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
14 | ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
15 | ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
16 | ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
17 | ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
18 | )
19 |
20 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
21 |
22 | if (MLX_BUILD_ACCELERATE)
23 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
24 | else()
25 | target_sources(
26 | mlx
27 | PRIVATE
28 | ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
29 | )
30 | endif()
31 |
32 | if (MLX_BUILD_METAL)
33 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
34 | else()
35 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
36 | endif()
37 |
--------------------------------------------------------------------------------
/mlx/allocator.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <cstdlib>
4 | #include <sstream>
5 |
6 | #include "mlx/allocator.h"
7 | #include "mlx/scheduler.h"
8 |
9 | namespace mlx::core::allocator {
10 |
11 | Buffer malloc(size_t size) {
12 | auto buffer = allocator().malloc(size);
13 | if (size && !buffer.ptr()) {
14 | std::ostringstream msg;
15 | msg << "[malloc] Unable to allocate " << size << " bytes.";
16 | throw std::runtime_error(msg.str());
17 | }
18 | return buffer;
19 | }
20 |
21 | void free(Buffer buffer) {
22 | return allocator().free(buffer);
23 | }
24 |
25 | Buffer CommonAllocator::malloc(size_t size) {
26 | return Buffer{std::malloc(size)};
27 | }
28 |
29 | void CommonAllocator::free(Buffer buffer) {
30 | std::free(buffer.raw_ptr());
31 | }
32 |
33 | Buffer malloc_or_wait(size_t size) {
34 | auto buffer = allocator().malloc(size);
35 |
36 | while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
37 | scheduler::wait_for_one();
38 | buffer = allocator().malloc(size);
39 | }
40 |
41 | if (size && !buffer.ptr()) {
42 | std::ostringstream msg;
43 | msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
44 | throw std::runtime_error(msg.str());
45 | }
46 |
47 | return buffer;
48 | }
49 |
50 | } // namespace mlx::core::allocator
51 |
--------------------------------------------------------------------------------
/mlx/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include "array.h"
6 | #include "device.h"
7 | #include "dtype.h"
8 | #include "stream.h"
9 |
10 | namespace mlx::core {
11 |
12 | /** The type from promoting the arrays' types with one another. */
13 | Dtype result_type(const std::vector<array>& arrays);
14 |
15 | std::vector<int> broadcast_shapes(
16 |     const std::vector<int>& s1,
17 |     const std::vector<int>& s2);
18 |
19 | std::ostream& operator<<(std::ostream& os, const Device& d);
20 | std::ostream& operator<<(std::ostream& os, const Stream& s);
21 | std::ostream& operator<<(std::ostream& os, const Dtype& d);
22 | std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
23 | std::ostream& operator<<(std::ostream& os, array a);
24 | std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
25 | std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
26 | inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
27 | return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j";
28 | }
29 | inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
30 |   return os << static_cast<float>(v);
31 | }
32 | inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
33 |   return os << static_cast<float>(v);
34 | }
35 | } // namespace mlx::core
36 |
--------------------------------------------------------------------------------
/python/src/device.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <sstream>
4 |
5 | #include <pybind11/pybind11.h>
6 |
7 | #include "mlx/device.h"
8 | #include "mlx/utils.h"
9 |
10 | namespace py = pybind11;
11 | using namespace py::literals;
12 | using namespace mlx::core;
13 |
14 | void init_device(py::module_& m) {
15 |   py::enum_<Device::DeviceType>(m, "DeviceType")
16 | .value("cpu", Device::DeviceType::cpu)
17 | .value("gpu", Device::DeviceType::gpu)
18 | .export_values()
19 | .def(
20 | "__eq__",
21 | [](const Device::DeviceType& d1, const Device& d2) {
22 | return d1 == d2;
23 | },
24 | py::prepend());
25 |
26 |   py::class_<Device>(m, "Device")
27 |       .def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
28 | .def_readonly("type", &Device::type)
29 | .def(
30 | "__repr__",
31 | [](const Device& d) {
32 | std::ostringstream os;
33 | os << d;
34 | return os.str();
35 | })
36 | .def("__eq__", [](const Device& d1, const Device& d2) {
37 | return d1 == d2;
38 | });
39 |
40 |   py::implicitly_convertible<Device::DeviceType, Device>();
41 |
42 | m.def("default_device", &default_device);
43 | m.def("set_default_device", &set_default_device, "device"_a);
44 | }
45 |
--------------------------------------------------------------------------------
/docs/src/python/data_types.rst:
--------------------------------------------------------------------------------
1 | .. _data_types:
2 |
3 | :orphan:
4 |
5 | Data Types
6 | ==========
7 |
8 | .. currentmodule:: mlx.core
9 |
10 | The default floating point type is ``float32`` and the default integer type is
11 | ``int32``. The table below shows supported values for :obj:`Dtype`.
12 |
13 | .. list-table:: Supported Data Types
14 | :widths: 5 3 20
15 | :header-rows: 1
16 |
17 | * - Type
18 | - Bytes
19 | - Description
20 | * - ``bool_``
21 | - 1
22 | - Boolean (``True``, ``False``) data type
23 | * - ``uint8``
24 | - 1
25 | - 8-bit unsigned integer
26 | * - ``uint16``
27 | - 2
28 | - 16-bit unsigned integer
29 | * - ``uint32``
30 | - 4
31 | - 32-bit unsigned integer
32 |   * - ``uint64``
33 |     - 8
34 |     - 64-bit unsigned integer
35 | * - ``int8``
36 | - 1
37 | - 8-bit signed integer
38 | * - ``int16``
39 | - 2
40 | - 16-bit signed integer
41 | * - ``int32``
42 | - 4
43 | - 32-bit signed integer
44 | * - ``int64``
45 | - 8
46 | - 64-bit signed integer
47 | * - ``float16``
48 | - 2
49 | - 16-bit float, only available with `ARM C language extensions `_
50 | * - ``float32``
51 | - 4
52 | - 32-bit float
53 |
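When a different type is needed, it can be requested explicitly at
construction. A minimal sketch (the values and shapes are only illustrative):

.. code-block:: python

    import mlx.core as mx

    x = mx.array([1, 2, 3])                 # int32 by default
    y = mx.array([1.0, 2.0, 3.0])           # float32 by default
    z = mx.zeros((2, 2), dtype=mx.float16)  # explicitly request float16

    print(x.dtype, y.dtype, z.dtype)
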
--------------------------------------------------------------------------------
/mlx/scheduler.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/scheduler.h"
4 | #include "mlx/backend/metal/metal.h"
5 |
6 | namespace mlx::core {
7 |
8 | Stream default_stream(Device d) {
9 | if (!metal::is_available() && d == Device::gpu) {
10 | throw std::invalid_argument(
11 | "[default_stream] Cannot get gpu stream without gpu backend.");
12 | }
13 | return scheduler::scheduler().get_default_stream(d);
14 | }
15 |
16 | void set_default_stream(Stream s) {
17 | if (!metal::is_available() && s.device == Device::gpu) {
18 | throw std::invalid_argument(
19 | "[set_default_stream] Cannot set gpu stream without gpu backend.");
20 | }
21 | return scheduler::scheduler().set_default_stream(s);
22 | }
23 |
24 | Stream new_stream(Device d) {
25 | if (!metal::is_available() && d == Device::gpu) {
26 | throw std::invalid_argument(
27 | "[new_stream] Cannot make gpu stream without gpu backend.");
28 | }
29 | return scheduler::scheduler().new_stream(d);
30 | }
31 |
32 | Stream new_stream() {
33 | return scheduler::scheduler().new_stream(default_device());
34 | }
35 |
36 | namespace scheduler {
37 |
38 | /** A singleton scheduler to manage devices, streams, and task execution. */
39 | Scheduler& scheduler() {
40 | static Scheduler scheduler;
41 | return scheduler;
42 | }
43 |
44 | } // namespace scheduler
45 | } // namespace mlx::core
46 |
--------------------------------------------------------------------------------
/docs/src/python/ops.rst:
--------------------------------------------------------------------------------
1 | .. _ops:
2 |
3 | Operations
4 | ==========
5 |
6 | .. currentmodule:: mlx.core
7 |
8 | .. autosummary::
9 | :toctree: _autosummary
10 |
11 | abs
12 | add
13 | all
14 | allclose
15 | any
16 | arange
17 | arccos
18 | arccosh
19 | arcsin
20 | arcsinh
21 | arctan
22 | arctanh
23 | argmax
24 | argmin
25 | argpartition
26 | argsort
27 | array_equal
28 | broadcast_to
29 | concatenate
30 | convolve
31 | conv1d
32 | conv2d
33 | cos
34 | cosh
35 | divide
36 | equal
37 | erf
38 | erfinv
39 | exp
40 | expand_dims
41 | full
42 | greater
43 | greater_equal
44 | less
45 | less_equal
46 | load
47 | log
48 | log2
49 | log10
50 | log1p
51 | logaddexp
52 | logical_not
53 | logsumexp
54 | matmul
55 | max
56 | maximum
57 | mean
58 | min
59 | minimum
60 | multiply
61 | negative
62 | ones
63 | ones_like
64 | partition
65 | pad
66 | prod
67 | reciprocal
68 | reshape
69 | rsqrt
70 | save
71 | savez
72 | savez_compressed
73 | sigmoid
74 | sign
75 | sin
76 | sinh
77 | softmax
78 | sort
79 | split
80 | sqrt
81 | square
82 | squeeze
83 | stop_gradient
84 | subtract
85 | sum
86 | take
87 | take_along_axis
88 | tan
89 | tanh
90 | transpose
91 | var
92 | where
93 | zeros
94 | zeros_like
95 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MLX
2 |
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 |
6 | ## Pull Requests
7 |
8 | 1. Fork and submit pull requests to the repo.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If a change is likely to impact efficiency, run some of the benchmarks before
11 | and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
12 | 4. If you've changed APIs, update the documentation.
13 | 5. Every PR should have passing tests and at least one review.
14 | 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
15 | This should install hooks for running `black` and `clang-format` to ensure
16 |    consistent style for C++ and Python code.
17 |
18 | You can also run the formatters manually as follows:
19 |
20 | ```
21 | clang-format -i file.cpp
22 | ```
23 |
24 | ```
25 | black file.py
26 | ```
27 |
28 | or run `pre-commit run --all-files` to check all files in the repo.
29 |
30 | ## Issues
31 |
32 | We use GitHub issues to track public bugs. Please ensure your description is
33 | clear and has sufficient instructions to be able to reproduce the issue.
34 |
35 | ## License
36 |
37 | By contributing to MLX, you agree that your contributions will be licensed
38 | under the LICENSE file in the root directory of this source tree.
39 |
--------------------------------------------------------------------------------
/mlx/backend/common/load.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <algorithm>
4 | #include <cassert>
5 | #include <utility>
6 |
7 | #include "mlx/allocator.h"
8 | #include "mlx/load.h"
9 | #include "mlx/primitives.h"
10 |
11 | namespace mlx::core {
12 |
13 | namespace {
14 |
15 | template <size_t scalar_size>
16 | void swap_endianess(uint8_t* data_bytes, size_t N) {
17 | struct Elem {
18 | uint8_t bytes[scalar_size];
19 | };
20 |
21 |   Elem* data = reinterpret_cast<Elem*>(data_bytes);
22 |
23 | for (size_t i = 0; i < N; i++) {
24 | for (size_t j = 0; j < (scalar_size / 2); j++) {
25 | std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
26 | }
27 | }
28 | }
29 |
30 | } // namespace
31 |
32 | void Load::eval(const std::vector<array>& inputs, array& out) {
33 | assert(inputs.size() == 0);
34 | out.set_data(allocator::malloc_or_wait(out.nbytes()));
35 |
36 | reader_->seek(offset_, std::ios_base::beg);
37 |   reader_->read(out.data<char>(), out.nbytes());
38 |
39 | if (swap_endianness_) {
40 | switch (out.itemsize()) {
41 | case 2:
42 |         swap_endianess<2>(out.data<uint8_t>(), out.data_size());
43 | break;
44 | case 4:
45 |         swap_endianess<4>(out.data<uint8_t>(), out.data_size());
46 | break;
47 | case 8:
48 |         swap_endianess<8>(out.data<uint8_t>(), out.data_size());
49 | break;
50 | }
51 | }
52 | }
53 |
54 | } // namespace mlx::core
55 |
--------------------------------------------------------------------------------
/docs/src/python/optimizers.rst:
--------------------------------------------------------------------------------
1 | .. _optimizers:
2 |
3 | Optimizers
4 | ==========
5 |
6 | The optimizers in MLX can be used both with :mod:`mlx.nn` and with pure
7 | :mod:`mlx.core` functions. A typical example involves calling
8 | :meth:`Optimizer.update` to update a model's parameters based on the loss
9 | gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
10 | model's parameters and the **optimizer state**.
11 |
12 | .. code-block:: python
13 |
14 | # Create a model
15 | model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
16 | mx.eval(model.parameters())
17 |
18 | # Create the gradient function and the optimizer
19 | loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
20 | optimizer = optim.SGD(learning_rate=learning_rate)
21 |
22 | for e in range(num_epochs):
23 | for X, y in batch_iterate(batch_size, train_images, train_labels):
24 | loss, grads = loss_and_grad_fn(model, X, y)
25 |
26 | # Update the model with the gradients. So far no computation has happened.
27 | optimizer.update(model, grads)
28 |
29 | # Compute the new parameters but also the optimizer state.
30 | mx.eval(model.parameters(), optimizer.state)
31 |
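The optimizers can also be used without :mod:`mlx.nn`. Below is a minimal
sketch with plain :mod:`mlx.core` functions, assuming the optimizer's
``apply_gradients`` method maps a gradient tree onto a matching parameter tree
(the parameter names are only illustrative):

.. code-block:: python

    import mlx.core as mx
    import mlx.optimizers as optim

    params = {"w": mx.zeros((4,))}

    def loss_fn(params):
        return mx.sum(mx.square(params["w"] - 1.0))

    grad_fn = mx.grad(loss_fn)
    optimizer = optim.SGD(learning_rate=0.1)

    for _ in range(10):
        grads = grad_fn(params)
        # Assumes apply_gradients returns the updated parameter tree.
        params = optimizer.apply_gradients(grads, params)
        mx.eval(params, optimizer.state)
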
32 | .. currentmodule:: mlx.optimizers
33 |
34 | .. autosummary::
35 | :toctree: _autosummary
36 | :template: optimizers-template.rst
37 |
38 | OptimizerState
39 | Optimizer
40 | SGD
41 | Adam
42 |
--------------------------------------------------------------------------------
/benchmarks/cpp/time_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <chrono>
6 | #include <iomanip>
7 | #include <iostream>
8 |
9 | #include "mlx/mlx.h"
10 |
11 | #define milliseconds(x) \
12 |   (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
13 | #define time_now() std::chrono::high_resolution_clock::now()
14 |
15 | #define TIME(FUNC, ...) \
16 | std::cout << "Timing " << #FUNC << " ... " << std::flush \
17 | << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
18 | << std::endl;
19 |
20 | #define TIMEM(MSG, FUNC, ...) \
21 | std::cout << "Timing " \
22 | << "(" << MSG << ") " << #FUNC << " ... " << std::flush \
23 | << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
24 | << std::endl;
25 |
26 | template <typename F, typename... Args>
27 | double time_fn(F fn, Args... args) {
28 | // warmup
29 | for (int i = 0; i < 5; ++i) {
30 |     eval(fn(std::forward<Args>(args)...));
31 | }
32 |
33 | int num_iters = 100;
34 | auto start = time_now();
35 | for (int i = 0; i < num_iters; i++) {
36 |     eval(fn(std::forward<Args>(args)...));
37 | }
38 | auto end = time_now();
39 | return milliseconds(end - start) / static_cast(num_iters);
40 | }
41 |
--------------------------------------------------------------------------------
/docs/src/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | # -*- coding: utf-8 -*-
4 |
5 | import os
6 | import subprocess
7 |
8 | # -- Project information -----------------------------------------------------
9 |
10 | project = "MLX"
11 | copyright = "2023, MLX Contributors"
12 | author = "MLX Contributors"
13 | version = "0.0.3"
14 | release = "0.0.3"
15 |
16 | # -- General configuration ---------------------------------------------------
17 |
18 | extensions = [
19 | "sphinx.ext.autodoc",
20 | "sphinx.ext.autosummary",
21 | "sphinx.ext.intersphinx",
22 | "sphinx.ext.napoleon",
23 | ]
24 |
25 | python_use_unqualified_type_names = True
26 | autosummary_generate = True
27 |
28 | intersphinx_mapping = {
29 | "https://docs.python.org/3": None,
30 | "https://numpy.org/doc/stable/": None,
31 | }
32 |
33 | templates_path = ["_templates"]
34 | html_static_path = ["_static"]
35 | source_suffix = ".rst"
36 | master_doc = "index"
37 | highlight_language = "python"
38 | pygments_style = "sphinx"
39 |
40 | # -- Options for HTML output -------------------------------------------------
41 |
42 | html_theme = "sphinx_book_theme"
43 |
44 | html_theme_options = {
45 | "show_toc_level": 2,
46 | "repository_url": "https://github.com/ml-explore/mlx",
47 | "use_repository_button": True,
48 | "navigation_with_keys": False,
49 | }
50 |
51 | html_logo = "_static/mlx_logo.png"
52 |
53 |
54 | # -- Options for HTMLHelp output ---------------------------------------------
55 |
56 | htmlhelp_basename = "mlx_doc"
57 |
--------------------------------------------------------------------------------
/mlx.pc.in:
--------------------------------------------------------------------------------
1 | # Find MLX
2 | #
3 | # Defines the following variables:
4 | #
5 | # MLX_FOUND : True if MLX is found
6 | # MLX_INCLUDE_DIRS : Include directory
7 | # MLX_LIBRARIES : Libraries to link against
8 | # MLX_CXX_FLAGS : Additional compiler flags
9 | # MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
10 | # MLX_BUILD_METAL : True if MLX was built with metal
11 |
12 | @PACKAGE_INIT@
13 |
14 | include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/MLXTargets.cmake)
15 | include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/extension.cmake)
16 |
17 | set_and_check(MLX_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@)
18 | set_and_check(MLX_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@)
19 | set(MLX_LIBRARIES mlx)
20 |
21 | find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})
22 |
23 | if (@MLX_BUILD_ACCELERATE@)
24 | set(MLX_BUILD_ACCELERATE @MLX_BUILD_ACCELERATE@)
25 | set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
26 | endif()
27 |
28 | if (@MLX_BUILD_METAL@)
29 | set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
30 | set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
31 | set_and_check(MLX_INCLUDE_DIRS
32 | ${MLX_INCLUDE_DIRS}
33 | @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
34 | )
35 | endif()
36 |
37 | set_target_properties(mlx PROPERTIES
38 | CXX_STANDARD 17
39 | INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
40 | )
41 |
42 | include(FindPackageHandleStandardArgs)
43 | find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
--------------------------------------------------------------------------------
/examples/cpp/logistic_regression.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <chrono>
4 | #include <cmath>
5 | #include <iostream>
6 |
7 | #include "mlx/mlx.h"
8 | #include "timer.h"
9 |
10 | /**
11 | * An example of logistic regression with MLX.
12 | */
13 | using namespace mlx::core;
14 |
15 | int main() {
16 | int num_features = 100;
17 | int num_examples = 1'000;
18 | int num_iters = 10'000;
19 | float learning_rate = 0.1;
20 |
21 | // True parameters
22 | auto w_star = random::normal({num_features});
23 |
24 | // The input examples
25 | auto X = random::normal({num_examples, num_features});
26 |
27 | // Labels
28 | auto y = matmul(X, w_star) > 0;
29 |
30 | // Initialize random parameters
31 | array w = 1e-2 * random::normal({num_features});
32 |
33 | auto loss_fn = [&](array w) {
34 | auto logits = matmul(X, w);
35 | auto scale = (1.0f / num_examples);
36 | return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
37 | };
38 |
39 | auto grad_fn = grad(loss_fn);
40 |
41 | auto tic = timer::time();
42 | for (int it = 0; it < num_iters; ++it) {
43 | auto grad = grad_fn(w);
44 | w = w - learning_rate * grad;
45 | eval(w);
46 | }
47 | auto toc = timer::time();
48 |
49 | auto loss = loss_fn(w);
50 | auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
51 | auto throughput = num_iters / timer::seconds(toc - tic);
52 | std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
53 | << throughput << " (it/s)." << std::endl;
54 | }
55 |
--------------------------------------------------------------------------------
/examples/cpp/linear_regression.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <chrono>
4 | #include <cmath>
5 | #include <iostream>
6 |
7 | #include "mlx/mlx.h"
8 | #include "timer.h"
9 |
10 | /**
11 | * An example of linear regression with MLX.
12 | */
13 | using namespace mlx::core;
14 |
15 | int main() {
16 | int num_features = 100;
17 | int num_examples = 1'000;
18 | int num_iters = 10'000;
19 | float learning_rate = 0.01;
20 |
21 | // True parameters
22 | auto w_star = random::normal({num_features});
23 |
24 | // The input examples (design matrix)
25 | auto X = random::normal({num_examples, num_features});
26 |
27 | // Noisy labels
28 | auto eps = 1e-2 * random::normal({num_examples});
29 | auto y = matmul(X, w_star) + eps;
30 |
31 | // Initialize random parameters
32 | array w = 1e-2 * random::normal({num_features});
33 |
34 | auto loss_fn = [&](array w) {
35 | auto yhat = matmul(X, w);
36 | return (0.5f / num_examples) * sum(square(yhat - y));
37 | };
38 |
39 | auto grad_fn = grad(loss_fn);
40 |
41 | auto tic = timer::time();
42 | for (int it = 0; it < num_iters; ++it) {
43 | auto grad = grad_fn(w);
44 | w = w - learning_rate * grad;
45 | eval(w);
46 | }
47 | auto toc = timer::time();
48 |
49 | auto loss = loss_fn(w);
50 |   auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
51 | auto throughput = num_iters / timer::seconds(toc - tic);
52 | std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
53 | << ", Throughput " << throughput << " (it/s)." << std::endl;
54 | }
55 |
--------------------------------------------------------------------------------
/mlx/types/half_types.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 | #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
5 |
6 | #include <arm_fp16.h>
7 | namespace mlx::core {
8 | typedef __fp16 float16_t;
9 | } // namespace mlx::core
10 |
11 | #else
12 |
13 | #define ADD_HALF_BINOPS
14 | #include "mlx/types/fp16.h"
15 | namespace mlx::core {
16 | typedef struct _MLX_Float16 float16_t;
17 | } // namespace mlx::core
18 |
19 | #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
20 | #ifdef __ARM_FEATURE_BF16
21 |
22 | #include <arm_bf16.h>
23 | namespace mlx::core {
24 | typedef __bf16 bfloat16_t;
25 | } // namespace mlx::core
26 |
27 | #else
28 |
29 | #define ADD_HALF_BINOPS
30 | #include "mlx/types/bf16.h"
31 | namespace mlx::core {
32 | typedef struct _MLX_BFloat16 bfloat16_t;
33 | } // namespace mlx::core
34 |
35 | #endif // __ARM_FEATURE_BF16
36 |
37 | #ifdef ADD_HALF_BINOPS
38 | namespace mlx::core {
39 |
40 | // clang-format off
41 | #define fp16_bf16_binop_helper(__op__, __operator__) \
42 | inline float __operator__(float16_t lhs, bfloat16_t rhs) { \
43 |     return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
44 | } \
45 | inline float __operator__(bfloat16_t lhs, float16_t rhs) { \
46 |     return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
47 | }
48 |
49 | fp16_bf16_binop_helper(+, operator+)
50 | fp16_bf16_binop_helper(-, operator-)
51 | fp16_bf16_binop_helper(*, operator*)
52 | fp16_bf16_binop_helper(/, operator/)
53 | // clang-format on
54 |
55 | } // namespace mlx::core
56 | #endif
57 |
--------------------------------------------------------------------------------
/benchmarks/python/batch_matmul_bench.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | import argparse
4 | import mlx.core as mx
5 |
6 | from time_utils import time_fn
7 |
8 | B = 8
9 | T = 1024
10 | D = 512
11 |
12 |
13 | def time_batch_matmul():
14 | mx.random.seed(3)
15 | a = mx.random.uniform(shape=(B, T, D))
16 | b = mx.random.uniform(shape=(D, D))
17 | c = mx.random.uniform(shape=(B, T, D))
18 | mx.eval(a, b, c)
19 |
20 | time_fn(mx.matmul, a, b)
21 |
22 | def batch_vjp_first():
23 | return mx.vjp(mx.matmul, [a, b], [c])[1][0]
24 |
25 | time_fn(batch_vjp_first)
26 |
27 | def batch_vjp_second():
28 | return mx.vjp(mx.matmul, [a, b], [c])[1][1]
29 |
30 | time_fn(batch_vjp_second)
31 |
32 |
33 | def time_unbatch_matmul():
34 | mx.random.seed(3)
35 | a = mx.random.uniform(shape=(B * T, D))
36 | b = mx.random.uniform(shape=(D, D))
37 | c = mx.random.uniform(shape=(B * T, D))
38 | mx.eval(a, b, c)
39 | time_fn(mx.matmul, a, b)
40 |
41 | def unbatch_vjp_first():
42 | return mx.matmul(c, mx.transpose(b))
43 |
44 | time_fn(unbatch_vjp_first)
45 |
46 | def unbatch_vjp_second():
47 | return mx.matmul(mx.transpose(a), c)
48 |
49 | time_fn(unbatch_vjp_second)
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser("MLX benchmarks.")
54 | parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
55 | args = parser.parse_args()
56 | if args.gpu:
57 | mx.set_default_device(mx.gpu)
58 | else:
59 | mx.set_default_device(mx.cpu)
60 |
61 | time_batch_matmul()
62 | time_unbatch_matmul()
63 |
--------------------------------------------------------------------------------
/mlx/backend/no_metal/primitives.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include "mlx/primitives.h"
4 |
5 | #define NO_GPU(func) \
6 |   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
7 | throw std::runtime_error(#func " has no GPU implementation."); \
8 | }
9 |
10 | namespace mlx::core {
11 |
12 | NO_GPU(Abs)
13 | NO_GPU(Add)
14 | NO_GPU(Arange)
15 | NO_GPU(ArcCos)
16 | NO_GPU(ArcCosh)
17 | NO_GPU(ArcSin)
18 | NO_GPU(ArcSinh)
19 | NO_GPU(ArcTan)
20 | NO_GPU(ArcTanh)
21 | NO_GPU(ArgPartition)
22 | NO_GPU(ArgReduce)
23 | NO_GPU(ArgSort)
24 | NO_GPU(AsType)
25 | NO_GPU(AsStrided)
26 | NO_GPU(Broadcast)
27 | NO_GPU(Concatenate)
28 | NO_GPU(Convolution)
29 | NO_GPU(Copy)
30 | NO_GPU(Cos)
31 | NO_GPU(Cosh)
32 | NO_GPU(Divide)
33 | NO_GPU(Equal)
34 | NO_GPU(Erf)
35 | NO_GPU(ErfInv)
36 | NO_GPU(Exp)
37 | NO_GPU(FFT)
38 | NO_GPU(Full)
39 | NO_GPU(Gather)
40 | NO_GPU(Greater)
41 | NO_GPU(GreaterEqual)
42 | NO_GPU(Less)
43 | NO_GPU(LessEqual)
44 | NO_GPU(Load)
45 | NO_GPU(Log)
46 | NO_GPU(Log1p)
47 | NO_GPU(LogicalNot)
48 | NO_GPU(LogAddExp)
49 | NO_GPU(Matmul)
50 | NO_GPU(Maximum)
51 | NO_GPU(Minimum)
52 | NO_GPU(Multiply)
53 | NO_GPU(Negative)
54 | NO_GPU(NotEqual)
55 | NO_GPU(Pad)
56 | NO_GPU(Partition)
57 | NO_GPU(Power)
58 | NO_GPU(RandomBits)
59 | NO_GPU(Reduce)
60 | NO_GPU(Reshape)
61 | NO_GPU(Scan)
62 | NO_GPU(Scatter)
63 | NO_GPU(Sigmoid)
64 | NO_GPU(Sign)
65 | NO_GPU(Sin)
66 | NO_GPU(Sinh)
67 | NO_GPU(Slice)
68 | NO_GPU(Softmax)
69 | NO_GPU(Sort)
70 | NO_GPU(Square)
71 | NO_GPU(Sqrt)
72 | NO_GPU(StopGradient)
73 | NO_GPU(Subtract)
74 | NO_GPU(Tan)
75 | NO_GPU(Tanh)
76 | NO_GPU(Transpose)
77 |
78 | } // namespace mlx::core
79 |
--------------------------------------------------------------------------------
/mlx/backend/common/erf.cpp:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #include <cmath>
4 |
5 | namespace mlx::core {
6 |
7 | /* Approximation to the inverse error function.
8 | * Based on code from:
9 | * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
10 | */
11 | float erfinv(float a) {
12 | auto t = std::fma(a, 0.0f - a, 1.0f);
13 | t = std::log(t);
14 | float p;
15 | if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
16 | p = 3.03697567e-10f; // 0x1.4deb44p-32
17 | p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
18 | p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
19 | p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
20 | p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
21 | p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
22 | p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
23 | p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
24 | p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
25 | } else { // maximum ulp error = 2.35002
26 | p = 5.43877832e-9f; // 0x1.75c000p-28
27 | p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
28 | p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
29 | p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
30 | p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
31 | p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
32 | p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
33 | p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
34 | p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
35 | p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
36 | }
37 | return a * p;
38 | }
39 |
40 | } // namespace mlx::core
41 |
--------------------------------------------------------------------------------
/mlx/allocator.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include <cstdlib>
6 |
7 | namespace mlx::core::allocator {
8 |
9 | // Simple wrapper around buffer pointers
10 | // WARNING: Only Buffer objects constructed from and those that wrap
11 | // raw pointers from mlx::allocator are supported.
12 | class Buffer {
13 | private:
14 | void* ptr_;
15 |
16 | public:
17 | Buffer(void* ptr) : ptr_(ptr){};
18 |
19 | // Get the raw data pointer from the buffer
20 | void* raw_ptr();
21 |
22 | // Get the buffer pointer from the buffer
23 | const void* ptr() const {
24 | return ptr_;
25 | };
26 | void* ptr() {
27 | return ptr_;
28 | };
29 | };
30 |
31 | Buffer malloc(size_t size);
32 |
33 | void free(Buffer buffer);
34 |
35 | // Wait for running tasks to finish and free up memory
36 | // if allocation fails
37 | Buffer malloc_or_wait(size_t size);
38 |
39 | class Allocator {
40 |   /** Abstract base class for a memory allocator. */
41 | public:
42 | virtual Buffer malloc(size_t size) = 0;
43 | virtual void free(Buffer buffer) = 0;
44 |
45 | Allocator() = default;
46 | Allocator(const Allocator& other) = delete;
47 | Allocator(Allocator&& other) = delete;
48 | Allocator& operator=(const Allocator& other) = delete;
49 | Allocator& operator=(Allocator&& other) = delete;
50 | virtual ~Allocator() = default;
51 | };
52 |
53 | Allocator& allocator();
54 |
55 | class CommonAllocator : public Allocator {
56 | /** A general CPU allocator. */
57 | public:
58 | virtual Buffer malloc(size_t size) override;
59 | virtual void free(Buffer buffer) override;
60 |
61 | private:
62 | CommonAllocator() = default;
63 | friend Allocator& allocator();
64 | };
65 |
66 | } // namespace mlx::core::allocator
67 |
--------------------------------------------------------------------------------
/cmake/extension.cmake:
--------------------------------------------------------------------------------
1 | include(CMakeParseArguments)
2 |
3 | ###############################################################################
4 | # Build metal library
5 | #
6 | # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
7 | # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
8 | #
9 | # Args:
10 | # TARGET: Custom target to be added for the metal library
11 | # TITLE: Name of the .metallib
12 | # OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
13 | # SOURCES: List of source files
14 | # INCLUDE_DIRS: List of include dirs
15 | # DEPS: List of dependency files (like headers)
16 | #
17 | macro(mlx_build_metallib)
18 | # Parse args
19 | set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
20 | set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
21 | cmake_parse_arguments(
22 | MTLLIB
23 | ""
24 | "${oneValueArgs}"
25 | "${multiValueArgs}"
26 | ${ARGN}
27 | )
28 |
29 | # Set output
30 | set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
31 |
32 | # Collect compile options
33 | set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
34 |
35 |   # Prepare metallib build command
36 | add_custom_command(
37 | OUTPUT ${MTLLIB_BUILD_TARGET}
38 | COMMAND xcrun -sdk macosx metal
39 | "$"
40 | ${MTLLIB_COMPILE_OPTIONS}
41 | ${MTLLIB_SOURCES}
42 | -o ${MTLLIB_BUILD_TARGET}
43 | DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
44 | COMMAND_EXPAND_LISTS
45 | COMMENT "Building ${MTLLIB_TITLE}.metallib"
46 | VERBATIM
47 | )
48 |
49 | # Add metallib custom target
50 | add_custom_target(
51 | ${MTLLIB_TARGET}
52 | DEPENDS
53 | ${MTLLIB_BUILD_TARGET}
54 | )
55 |
56 | endmacro(mlx_build_metallib)
--------------------------------------------------------------------------------
/examples/extensions/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.24)
2 |
3 | project(mlx_sample_extensions LANGUAGES CXX)
4 |
5 | # ----------------------------- Setup -----------------------------
6 | set(CMAKE_CXX_STANDARD 17)
7 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
8 | set(CMAKE_POSITION_INDEPENDENT_CODE ON)
9 |
10 | option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
11 |
12 | # ----------------------------- Dependencies -----------------------------
13 | find_package(MLX CONFIG REQUIRED)
14 | find_package(Python COMPONENTS Interpreter Development)
15 | find_package(pybind11 CONFIG REQUIRED)
16 |
17 | # ----------------------------- Extensions -----------------------------
18 |
19 | # Add library
20 | add_library(mlx_ext)
21 |
22 | # Add sources
23 | target_sources(
24 | mlx_ext
25 | PUBLIC
26 | ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
27 | )
28 |
29 | # Add include headers
30 | target_include_directories(
31 | mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
32 | )
33 |
34 | # Link to mlx
35 | target_link_libraries(mlx_ext PUBLIC mlx)
36 |
37 | # ----------------------------- Metal -----------------------------
38 |
39 | # Build metallib
40 | if(MLX_BUILD_METAL)
41 |
42 | mlx_build_metallib(
43 | TARGET mlx_ext_metallib
44 | TITLE mlx_ext
45 | SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
46 | INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
47 | OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
48 | )
49 |
50 | add_dependencies(
51 | mlx_ext
52 | mlx_ext_metallib
53 | )
54 |
55 | endif()
56 |
57 | # ----------------------------- Pybind -----------------------------
58 | pybind11_add_module(
59 | mlx_sample_extensions
60 | ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
61 | )
62 | target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
63 |
64 | if(BUILD_SHARED_LIBS)
65 | target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
66 | endif()
--------------------------------------------------------------------------------
/mlx/backend/metal/allocator.h:
--------------------------------------------------------------------------------
1 | // Copyright © 2023 Apple Inc.
2 |
3 | #pragma once
4 |
5 | #include