├── Applications
├── LLMEval
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── LLMEvalApp.swift
│ ├── LLMEval.entitlements
│ ├── ViewModels
│ │ └── DeviceStat.swift
│ └── README.md
├── VLMEval
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── VLMEvalApp.swift
│ ├── VLMEval.entitlements
│ └── README.md
├── MNISTTrainer
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── MNISTTrainer-Info.plist
│ ├── MNISTTrainerApp.swift
│ ├── MNISTTrainer.entitlements
│ ├── README.md
│ ├── PredictionView.swift
│ └── ContentView.swift
├── LoRATrainingExample
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── LoRATrainingExampleApp.swift
│ ├── LoRATrainingExample.entitlements
│ └── README.md
└── StableDiffusionExample
│ ├── Assets.xcassets
│ ├── Contents.json
│ ├── AccentColor.colorset
│ │ └── Contents.json
│ └── AppIcon.appiconset
│ │ └── Contents.json
│ ├── Preview Content
│ └── Preview Assets.xcassets
│ │ └── Contents.json
│ ├── StableDiffusionExampleApp.swift
│ ├── StableDiffusionExample.entitlements
│ └── README.md
├── .spi.yml
├── .swift-format
├── Libraries
├── MLXVLM
│ └── VLMModel.swift
├── MLXLMCommon
│ ├── Documentation.docc
│ │ └── Documentation.md
│ ├── Module+Extensions.swift
│ ├── Registries
│ │ ├── ModelTypeRegistry.swift
│ │ ├── ProcessorTypeRegistry.swift
│ │ └── AbstractModelRegistry.swift
│ ├── ModelConfiguration.swift
│ ├── ModelContainer.swift
│ ├── StringOrNumber.swift
│ ├── Load.swift
│ ├── KVCache.swift
│ ├── README.md
│ ├── ModelFactory.swift
│ └── Tokenizer.swift
├── MLXMNIST
│ ├── README.md
│ ├── Random.swift
│ ├── MNIST.swift
│ └── Files.swift
├── MLXLLM
│ ├── Documentation.docc
│ │ ├── Documentation.md
│ │ ├── adding-model.md
│ │ └── using-model.md
│ ├── LLMModel.swift
│ ├── SuScaledRotaryEmbedding.swift
│ ├── Lora+Data.swift
│ ├── README.md
│ └── SwitchLayers.swift
├── Embedders
│ ├── README.md
│ ├── Tokenizer.swift
│ ├── Pooling.swift
│ ├── EmbeddingModel.swift
│ ├── Load.swift
│ ├── Configuration.swift
│ └── Models.swift
└── StableDiffusion
│ ├── README.md
│ ├── Sampler.swift
│ ├── Tokenizer.swift
│ ├── Clip.swift
│ └── Image.swift
├── mlx-swift-examples.xcodeproj
├── project.xcworkspace
│ ├── contents.xcworkspacedata
│ └── xcshareddata
│ │ ├── IDEWorkspaceChecks.plist
│ │ └── swiftpm
│ │ └── Package.resolved
└── xcshareddata
│ └── xcschemes
│ ├── LLMEval.xcscheme
│ ├── VLMEval.xcscheme
│ ├── StableDiffusionExample.xcscheme
│ └── llm-tool.xcscheme
├── .pre-commit-config.yaml
├── Tools
├── LinearModelTraining
│ ├── README.md
│ └── LinearModelTraining.swift
├── llm-tool
│ └── Arguments.swift
├── mnist-tool
│ ├── README.md
│ └── MNISTTool.swift
├── image-tool
│ └── Arguments.swift
└── Tutorial
│ └── Tutorial.swift
├── Configuration
└── Build.xcconfig
├── ACKNOWLEDGMENTS.md
├── LICENSE
├── CONTRIBUTING.md
├── mlx-run
├── .circleci
└── config.yml
├── Package.resolved
├── .gitignore
├── README.md
├── Data
└── lora
│ └── wikisql.py
├── CODE_OF_CONDUCT.md
└── Package.swift
/Applications/LLMEval/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/.spi.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | builder:
3 | configs:
4 | - documentation_targets: [MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST, MLXEmbedders, StableDiffusion]
5 |
--------------------------------------------------------------------------------
/.swift-format:
--------------------------------------------------------------------------------
1 | {
2 | "version": 1,
3 | "indentation": {
4 | "spaces": 4
5 | },
6 | "spacesAroundRangeFormationOperators": true,
7 | }
8 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Libraries/MLXVLM/VLMModel.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import MLX
4 | import MLXLMCommon
5 |
6 | public protocol VLMModel: LanguageModel, LoRAModel {
7 | }
8 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 |
2 |
4 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/slessans/pre-commit-swift-format
3 | rev: "fd627de92bdf84a75c924ed95691336d14e94cf1"
4 | hooks:
5 | - id: swift-format
6 | args: ["--configuration", ".swift-format"]
7 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/MNISTTrainer-Info.plist:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/MNISTTrainerApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct MNISTTrainerApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/LoRATrainingExampleApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct LoRATrainingExampleApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Applications/LLMEval/LLMEvalApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct LLMEvalApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | .environment(DeviceStat())
11 | }
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/StableDiffusionExampleApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct StableDiffusionExampleApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Applications/VLMEval/VLMEvalApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct VLMEvalApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | .environment(DeviceStat())
11 | }
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | IDEDidComputeMac32BitWarning
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Documentation.docc/Documentation.md:
--------------------------------------------------------------------------------
1 | # ``MLXLMCommon``
2 |
3 | Common language model code.
4 |
5 | ## Other MLX Libraries Packages
6 |
7 | - [MLXEmbedders](MLXEmbedders)
8 | - [MLXLLM](MLXLLM)
9 | - [MLXLMCommon](MLXLMCommon)
10 | - [MLXMNIST](MLXMNIST)
11 | - [MLXVLM](MLXVLM)
12 | - [StableDiffusion](StableDiffusion)
13 |
14 |
--------------------------------------------------------------------------------
/Tools/LinearModelTraining/README.md:
--------------------------------------------------------------------------------
1 | # LinearModelTraining
2 |
3 | A command line tool that creates a Model that represents:
4 |
5 | f(x) = mx + b
6 |
7 | and trains it against an unknown linear function. Very
8 | simple but illustrates:
9 |
10 | - a very simple model with parameters
11 | - a loss function
12 | - the gradient
13 | - use of an optimizer
14 | - the training loop
15 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/MNISTTrainer.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.security.app-sandbox
6 |
7 | com.apple.security.files.user-selected.read-only
8 |
9 | com.apple.security.network.client
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/Configuration/Build.xcconfig:
--------------------------------------------------------------------------------
1 | // The `DISAMBIGUATOR` configuration is to make it easier to build
2 | // and run a sample code project. Once you set your project's development team,
3 | // you'll have a unique bundle identifier. This is because the bundle identifier
4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this
5 | // approach in your own projects—it's only useful for example projects because
6 | // they are frequently downloaded and don't have a development team set.
7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM}
8 |
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/README.md:
--------------------------------------------------------------------------------
1 | # MNIST
2 |
3 | This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.
4 |
5 | It provides code to:
6 |
7 | - Download the MNIST test/train data
8 | - Build the LeNet
9 | - Some functions to shuffle and batch the data
10 |
11 | See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
12 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/LoRATrainingExample.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.files.user-selected.read-only
10 |
11 | com.apple.security.network.client
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/StableDiffusionExample.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.files.user-selected.read-only
10 |
11 | com.apple.security.network.client
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Applications/LLMEval/LLMEval.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.device.usb
10 |
11 | com.apple.security.files.user-selected.read-only
12 |
13 | com.apple.security.network.client
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/Applications/VLMEval/VLMEval.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.device.usb
10 |
11 | com.apple.security.files.user-selected.read-only
12 |
13 | com.apple.security.network.client
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/ACKNOWLEDGMENTS.md:
--------------------------------------------------------------------------------
1 | # Individual Contributors
2 |
3 | If you wish to be acknowledged for your contributions, please list your name
4 | with a short description of your contribution(s) below. For example:
5 |
6 | - Jane Smith: Added the `foo` and `bar` ops.
7 |
8 | MLX Swift was developed with contributions from the following individuals:
9 |
10 |
11 |
12 |
13 |
14 |
15 | SOFTWARE.
16 |
17 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/README.md:
--------------------------------------------------------------------------------
1 | # MNISTTrainer
2 |
3 | This is an example of model training that works on both macOS and iOS.
4 | The example will download the MNIST training data, create a LeNet, and train
5 | it. It will show the epoch time and test accuracy as it trains.
6 |
7 | You will need to set the Team on the MNISTTrainer target in order to build and
8 | run on iOS.
9 |
10 | Some notes about the setup:
11 |
12 | - This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
13 | - The website it connects to uses http rather than https so it has an "App Transport Security Settings" in the Info.plist
14 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/Documentation.docc/Documentation.md:
--------------------------------------------------------------------------------
1 | # ``MLXLLM``
2 |
3 | Example implementations of various Large Language Models (LLMs).
4 |
5 | ## Other MLX Libraries Packages
6 |
7 | - [MLXEmbedders](MLXEmbedders)
8 | - [MLXLLM](MLXLLM)
9 | - [MLXLMCommon](MLXLMCommon)
10 | - [MLXMNIST](MLXMNIST)
11 | - [MLXVLM](MLXVLM)
12 | - [StableDiffusion](StableDiffusion)
13 |
14 | ## Topics
15 |
16 | -
17 | -
18 |
19 | ### Models
20 |
21 | - ``CohereModel``
22 | - ``GemmaModel``
23 | - ``Gemma2Model``
24 | - ``InternLM2Model``
25 | - ``LlamaModel``
26 | - ``OpenELMModel``
27 | - ``PhiModel``
28 | - ``Phi3Model``
29 | - ``PhiMoEModel``
30 | - ``Qwen2Model``
31 | - ``Starcoder2Model``
32 |
33 |
34 |
--------------------------------------------------------------------------------
/Applications/LLMEval/ViewModels/DeviceStat.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 |
4 | @Observable
5 | final class DeviceStat: @unchecked Sendable {
6 |
7 | @MainActor
8 | var gpuUsage = GPU.snapshot()
9 |
10 | private let initialGPUSnapshot = GPU.snapshot()
11 | private var timer: Timer?
12 |
13 | init() {
14 | timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
15 | self?.updateGPUUsages()
16 | }
17 | }
18 |
19 | deinit {
20 | timer?.invalidate()
21 | }
22 |
23 | private func updateGPUUsages() {
24 | let gpuSnapshotDelta = initialGPUSnapshot.delta(GPU.snapshot())
25 | DispatchQueue.main.async { [weak self] in
26 | self?.gpuUsage = gpuSnapshotDelta
27 | }
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "appearances" : [
10 | {
11 | "appearance" : "luminosity",
12 | "value" : "dark"
13 | }
14 | ],
15 | "idiom" : "universal",
16 | "platform" : "ios",
17 | "size" : "1024x1024"
18 | },
19 | {
20 | "appearances" : [
21 | {
22 | "appearance" : "luminosity",
23 | "value" : "tinted"
24 | }
25 | ],
26 | "idiom" : "universal",
27 | "platform" : "ios",
28 | "size" : "1024x1024"
29 | }
30 | ],
31 | "info" : {
32 | "author" : "xcode",
33 | "version" : 1
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/Tools/llm-tool/Arguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 |
6 | /// Extension to allow URL command line arguments.
7 | #if swift(>=5.10)
8 | extension URL: @retroactive ExpressibleByArgument {
9 | public init?(argument: String) {
10 | if argument.contains("://") {
11 | self.init(string: argument)
12 | } else {
13 | self.init(filePath: argument)
14 | }
15 | }
16 | }
17 | #else
18 | extension URL: ExpressibleByArgument {
19 | public init?(argument: String) {
20 | if argument.contains("://") {
21 | self.init(string: argument)
22 | } else {
23 | self.init(filePath: argument)
24 | }
25 | }
26 | }
27 | #endif
28 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Module+Extensions.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import MLXNN
4 |
5 | extension Module {
6 |
7 | /// Compute the number of parameters in a possibly quantized model
8 | public func numParameters() -> Int {
9 | return leafModules().flattenedValues().map {
10 | mod -> Int in
11 | if let qlin = mod as? QuantizedLinear {
12 | return qlin.scales.size * qlin.groupSize
13 | } else if let qemb = mod as? QuantizedEmbedding {
14 | return qemb.scales.size * qemb.groupSize
15 | } else {
16 | return mod.parameters().flattenedValues().reduce(
17 | 0,
18 | {
19 | $0 + $1.size
20 | })
21 | }
22 | }.reduce(0, +)
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/README.md:
--------------------------------------------------------------------------------
1 | # LoRATrainingExample
2 |
3 | Example application that:
4 |
5 | - downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from huggingface
6 | - loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded)
7 | - adds LoRA adapters and trains the model
 8 | - lets you evaluate a prompt against the model
9 |
10 | This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and
11 | you can read more about LoRA there.
12 |
13 | This evaluates the LoRA adapted model rather than a fused model. This doesn't persist
14 | the LoRA weights or the fused model -- it will retrain it each time the program is launched.
15 |
16 | ### Troubleshooting
17 |
18 | The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of
19 | memory to load and train -- this may require ~6G of physical RAM.
20 |
21 |
22 |
--------------------------------------------------------------------------------
/Tools/mnist-tool/README.md:
--------------------------------------------------------------------------------
1 | # mnist-tool
2 |
 3 | See the [MNIST README.md](../../Libraries/MLXMNIST/README.md).
4 |
5 | ### Building
6 |
7 | `mnist-tool` has no dependencies outside of the package dependencies
8 | represented in xcode.
9 |
10 | When you run the tool it will download the test/train datasets and
11 | store them in a specified directory (see run arguments -- default is /tmp).
12 |
13 | Simply build the project in xcode.
14 |
15 | ### Running (Xcode)
16 |
17 | To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
18 |
19 | ```
20 | --data /tmp
21 | ```
22 |
23 | Then cmd-r to run.
24 |
25 | ### Running (CommandLine)
26 |
27 | Use the `mlx-run` script to run the command line tools:
28 |
29 | ```
30 | ./mlx-run mnist-tool --data /tmp
31 | ```
32 |
33 | By default this will find and run the tools built in _Release_ configuration. Specify `--debug`
34 | to find and run the tool built in _Debug_ configuration.
35 |
36 | See also:
37 |
38 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting)
39 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/LLMModel.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import MLX
4 | import MLXLMCommon
5 |
6 | /// Marker protocol for LLMModels
7 | public protocol LLMModel: LanguageModel, LoRAModel {
8 | }
9 |
10 | extension LLMModel {
11 |
12 | /// Default prepare step for ``LLMModel``.
13 | ///
14 | /// This will evaluate the prompt in chunks until there is a small amount of
15 | /// tokens left to feed into the `TokenIterator`.
16 | public func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws
17 | -> PrepareResult
18 | {
19 | let prefillStepSize = windowSize ?? 512
20 | var y = input.text
21 | var state: LMOutput.State? = nil
22 |
23 | // prepare the prompt in chunks if larger than the prefill size
24 | while y.tokens.size > prefillStepSize {
25 | let input = y[.newAxis, .. UInt64 {
24 | self.state &+= 0x9e37_79b9_7f4a_7c15
25 | var z: UInt64 = self.state
26 | z = (z ^ (z &>> 30)) &* 0xbf58_476d_1ce4_e5b9
27 | z = (z ^ (z &>> 27)) &* 0x94d0_49bb_1331_11eb
28 | return z ^ (z &>> 31)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/mlx-run:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # Wrapper to help run command line tools -- this will find the build directory
4 | # and set the DYLD_FRAMEWORK_PATH so that command line tools that link frameworks
5 | # can be run.
6 | #
7 | # Example:
8 | # ./mlx-run --debug llm-tool --help
9 |
10 | if [ "$#" -lt 1 ]; then
11 | echo "usage: mlx-run [--debug/--release] arguments"
12 | exit 1
13 | fi
14 |
15 | CONFIGURATION=Release
16 | if [ "$1" == "--release" ]; then
17 | CONFIGURATION=Release
18 | shift
19 | fi
20 | if [ "$1" == "--debug" ]; then
21 | CONFIGURATION=Debug
22 | shift
23 | fi
24 | if [ "$1" == "--list" ]; then
25 | xcodebuild -list
26 | exit 0
27 | fi
28 |
29 | COMMAND="$1"
30 | shift
31 |
32 | BUILD_DIR=`xcodebuild -configuration $CONFIGURATION -showBuildSettings -scheme $COMMAND | grep 'BUILT_PRODUCTS_DIR = /' | sed -e 's/^[^=]*= //g'`
33 |
34 | if [ -d "$BUILD_DIR/$COMMAND.app" ]; then
35 | exec $BUILD_DIR/$COMMAND.app/Contents/MacOS/$COMMAND "$@" &
36 | fi
37 |
38 | if [ -f "$BUILD_DIR/$COMMAND" ]; then
39 | export DYLD_FRAMEWORK_PATH=$BUILD_DIR/PackageFrameworks:$BUILD_DIR
40 | exec "$BUILD_DIR/$COMMAND" "$@"
41 | else
42 | echo "$BUILD_DIR/$COMMAND does not exist -- check build configuration ($CONFIGURATION)"
43 | exit 1
44 | fi
45 |
46 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/README.md:
--------------------------------------------------------------------------------
1 | # StableDiffusionExample
2 |
3 | An example application that runs the StableDiffusion example code.
4 |
5 | See also [image-tool](../../Tools/image-tool) for a command line example.
6 |
This example application accepts a prompt and uses the StableDiffusion example
8 | library to render an image using:
9 |
10 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
11 |
12 | Please refer to that model for license and other information.
13 |
14 | If you are interested in adjusting the generated images, look in
15 | [ContentView.swift](ContentView.swift) at this method:
16 |
17 | ```swift
18 | func generate(prompt: String, negativePrompt: String, showProgress: Bool) async
19 | ```
20 |
21 | ### Troubleshooting
22 |
Stable diffusion can run in less than 4G of available memory (typically a
24 | device or computer with 6G of memory or more) in a constrained mode -- it will
25 | load and unload parts of the model as it runs and it can only perform one step
26 | of diffusion. This is configured automatically, see `modelFactory.conserveMemory`
27 | in [ContentView.swift](ContentView.swift).
28 |
29 | On a device or computer with more memory the model will be kept resident and
30 | images can be regenerated much more efficiently.
31 |
32 | If the program exits while generating the image it may have exceeded the available
33 | memory.
34 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
/// Thread-safe registry mapping a model type string (e.g. from `config.json`)
/// to a factory closure that builds a `LanguageModel` from a configuration URL.
open class ModelTypeRegistry: @unchecked Sendable {

    // Note: an NSLock is sufficient here -- the critical sections are tiny
    // (dictionary get/set), no contention is expected, and it keeps the
    // methods synchronous.
    private let lock = NSLock()
    private var creators: [String: @Sendable (URL) throws -> any LanguageModel]

    /// Creates an empty registry.
    public init() {
        self.creators = [:]
    }

    /// Creates a registry with given creators.
    public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) {
        self.creators = creators
    }

    /// Add a new model to the type registry, replacing any creator previously
    /// registered under the same `type`.
    public func registerModelType(
        _ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
    ) {
        lock.lock()
        defer { lock.unlock() }
        creators[type] = creator
    }

    /// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
    ///
    /// - Throws: `ModelFactoryError.unsupportedModelType` when no creator is
    ///   registered for `modelType`, or whatever the creator itself throws.
    public func createModel(configuration: URL, modelType: String) throws -> LanguageModel {
        lock.lock()
        let creator = creators[modelType]
        lock.unlock()

        guard let creator else {
            throw ModelFactoryError.unsupportedModelType(modelType)
        }
        return try creator(configuration)
    }

}
44 |
--------------------------------------------------------------------------------
/Libraries/Embedders/README.md:
--------------------------------------------------------------------------------
1 | # MLXEmbedders
2 |
3 | This directory contains ports of popular Encoders / Embedding Models.
4 |
5 | ## Usage Example
6 |
7 | ```swift
8 |
9 | let modelContainer = try await MLXEmbedders.loadModelContainer(
10 | configuration: ModelConfiguration.nomic_text_v1_5)
11 | let result = await modelContainer.perform {
12 | (model: EmbeddingModel, tokenizer, pooling) -> [[Float]] in
13 | let inputs = [
14 | "search_query: Animals in Tropical Climates.",
15 | "search_document: Elephants",
16 | "search_document: Horses",
17 | "search_document: Polar Bears",
18 | ].map {
19 | tokenizer.encode(text: $0, addSpecialTokens: true)
20 | }
21 | // Pad to longest
22 | let maxLength = inputs.reduce(into: 16) { acc, elem in
23 | acc = max(acc, elem.count)
24 | }
25 |
26 | let padded = stacked(
27 | inputs.map { elem in
28 | MLXArray(
29 | elem
30 | + Array(
31 | repeating: tokenizer.eosTokenId ?? 0,
32 | count: maxLength - elem.count))
33 | })
34 | let mask = (padded .!= tokenizer.eosTokenId ?? 0)
35 | let tokenTypes = MLXArray.zeros(like: padded)
36 | let result = pooling(
37 | model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask),
38 | normalize: true, applyLayerNorm: true
39 | ).eval()
40 | return result.map { $0.asArray(Float.self) }
41 | }
42 | ```
43 |
44 |
45 | Ported to swift from [taylorai/mlx_embedding_models](https://github.com/taylorai/mlx_embedding_models/tree/main)
46 |
--------------------------------------------------------------------------------
/Applications/VLMEval/README.md:
--------------------------------------------------------------------------------
1 | # VLMEval
2 |
3 | An example that:
4 |
5 | - downloads a vision language model (Qwen-VL-2B)
6 | - processes an image with a prompt
7 |
8 | You will need to set the Team on the VLMEval target in order to build and run on macOS.
9 |
10 | Some notes about the setup:
11 |
12 | - This downloads models from hugging face so VLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
13 | - VLM models are large so this uses significant memory
14 | - The example processes images and provides detailed analysis
15 |
16 | ### Image Processing
17 |
18 | The example application uses Qwen-VL-2B model by default, see [ContentView.swift](ContentView.swift):
19 |
20 | ```swift
21 | self.modelContainer = try await VLMModelFactory.shared.loadContainer(
22 | configuration: ModelRegistry.qwen2VL2BInstruct4Bit)
23 | ```
24 |
25 | The application:
26 | 1. Downloads a sample image
27 | 2. Processes it through the vision language model
28 | 3. Describes the images based on the prompt, providing detailed analysis of the content, objects, colors, and composition.
29 |
30 | ### Troubleshooting
31 |
32 | If the program crashes with a very deep stack trace you may need to build
33 | in Release configuration. This seems to depend on the size of the model.
34 |
35 | There are a couple options:
36 |
37 | - Build Release
38 | - Force the model evaluation to run on the main thread, e.g. using @MainActor
39 | - Build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),`
40 |
41 | ### Performance
42 |
43 | You may find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable".
44 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Tokenizers
5 |
/// Thread-safe registry mapping a processor type string to a factory closure
/// that builds a `UserInputProcessor` from a configuration file and a tokenizer.
open class ProcessorTypeRegistry: @unchecked Sendable {

    /// Creates an empty registry.
    public init() {
        self.creators = [:]
    }

    /// Creates a registry with given creators.
    public init(creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor])
    {
        self.creators = creators
    }

    // Note: using NSLock as we have very small (just dictionary get/set)
    // critical sections and expect no contention. this allows the methods
    // to remain synchronous.
    private let lock = NSLock()

    private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]

    /// Add a new processor type to the registry, replacing any creator
    /// previously registered under the same `type`.
    public func registerProcessorType(
        _ type: String,
        creator: @Sendable @escaping (
            URL,
            any Tokenizer
        ) throws -> any UserInputProcessor
    ) {
        lock.withLock {
            creators[type] = creator
        }
    }

    /// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
    ///
    /// - Throws: `ModelFactoryError.unsupportedProcessorType` when no creator
    ///   is registered for `processorType`, or whatever the creator throws.
    public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
        throws -> any UserInputProcessor
    {
        let creator = lock.withLock {
            creators[processorType]
        }
        guard let creator else {
            throw ModelFactoryError.unsupportedProcessorType(processorType)
        }
        return try creator(configuration, tokenizer)
    }

}
53 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
/// Thread-safe registry of `ModelConfiguration` values keyed by model name.
open class AbstractModelRegistry: @unchecked Sendable {

    /// Creates an empty registry.
    public init() {
        self.registry = Dictionary()
    }

    /// Creates a new registry from the given model configurations.
    ///
    /// - Note: configuration names must be unique -- duplicates will trap in
    ///   `Dictionary(uniqueKeysWithValues:)`.
    public init(modelConfigurations: [ModelConfiguration]) {
        self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) })
    }

    // NSLock guards `registry`; critical sections are tiny dictionary
    // reads/writes so the methods can stay synchronous.
    private let lock = NSLock()
    private var registry: [String: ModelConfiguration]

    /// Adds (or replaces) the given configurations, keyed by their `name`.
    public func register(configurations: [ModelConfiguration]) {
        lock.withLock {
            for c in configurations {
                registry[c.name] = c
            }
        }
    }

    /// Returns configuration from ``modelRegistry``.
    ///
    /// - Note: If the id doesn't exist in the registry, this will return a new instance
    /// of `ModelConfiguration` with that id. If you want to check whether the configuration
    /// is in the registry, use ``contains(id:)``.
    public func configuration(id: String) -> ModelConfiguration {
        lock.withLock {
            if let c = registry[id] {
                return c
            } else {
                return ModelConfiguration(id: id)
            }
        }
    }

    /// Returns true if the registry contains a model with the id. Otherwise, false.
    public func contains(id: String) -> Bool {
        lock.withLock {
            registry[id] != nil
        }
    }

    /// A snapshot of all registered configurations (taken under the lock).
    public var models: some Collection & Sendable {
        lock.withLock {
            return registry.values
        }
    }
}
55 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | orbs:
4 | apple: ml-explore/pr-approval@0.1.0
5 |
6 | parameters:
7 | nightly_build:
8 | type: boolean
9 | default: false
10 | weekly_build:
11 | type: boolean
12 | default: false
13 |
14 | jobs:
15 |
16 | mac_build_and_test:
17 | macos:
18 | xcode: 16.0.0
19 | resource_class: macos.m1.medium.gen1
20 | steps:
21 | - checkout
22 | - run: git submodule sync
23 | - run: git submodule update --init
24 | - run:
25 | name: Run style checks
26 | command: |
27 | pip install pre-commit
28 | brew install swift-format
29 | pre-commit run --all
30 | if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
31 | - run:
32 | name: Build Examples
33 | command: |
34 | xcodebuild -version
35 | xcrun --show-sdk-build-version
36 | swift --version
37 | find . -name Package.resolved -exec rm {} \;
38 | xcodebuild -scheme llm-tool
39 | xcodebuild -scheme image-tool
40 | xcodebuild -scheme mnist-tool
41 |
42 | workflows:
43 | build_and_test:
44 | when:
45 | and:
46 | - matches:
47 | pattern: "^(?!pull/)[-\\w]+$"
48 | value: << pipeline.git.branch >>
49 | - not: << pipeline.parameters.nightly_build >>
50 | - not: << pipeline.parameters.weekly_build >>
51 | jobs:
52 | - mac_build_and_test
53 |
54 | prb:
55 | when:
56 | matches:
57 | pattern: "^pull/\\d+(/head)?$"
58 | value: << pipeline.git.branch >>
59 | jobs:
60 | - hold:
61 | type: approval
62 | - apple/authenticate:
63 | context: pr-approval
64 | - mac_build_and_test:
65 | requires: [ hold ]
66 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion
2 |
3 | Stable Diffusion in MLX. The implementation was ported from Hugging Face's
4 | [diffusers](https://huggingface.co/docs/diffusers/index) and
5 | [mlx-examples/stable_diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
6 | Model weights are downloaded directly from the Hugging Face hub. The implementation currently
7 | supports the following models:
8 |
9 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
11 |
12 | ## Usage
13 |
14 | See [StableDiffusionExample](../../Applications/StableDiffusionExample) and
15 | [image-tool](../../Tools/image-tool) for examples of using this code.
16 |
17 | The basic sequence is:
18 |
19 | - download & load the model
20 | - generate latents
21 | - evaluate the latents one by one
22 | - decode the last latent generated
23 | - you have an image!
24 |
25 | ```swift
26 | let configuration = StableDiffusionConfiguration.presetSDXLTurbo
27 |
28 | let generator = try configuration.textToImageGenerator(
29 | configuration: model.loadConfiguration)
30 |
31 | generator.ensureLoaded()
32 |
33 | // generate the latents -- these are the iterations for generating
34 | // the output image. this is just generating the evaluation graph
let parameters = generator.evaluateParameters(configuration: configuration)
36 | let latents = generator.generateLatents(parameters: parameters)
37 |
// evaluate the latents (evaluate the graph) and keep the last value generated
39 | var lastXt: MLXArray?
40 | for xt in latents {
41 | eval(xt)
42 | lastXt = xt
43 | }
44 |
45 | // decode the final latent into an image
46 | if let lastXt {
47 | var raster = decoder(lastXt[0])
    raster = (raster * 255).asType(.uint8).squeezed()
49 | eval(raster)
50 |
51 | // turn it into a CGImage
52 | let image = Image(raster).asCGImage()
53 |
54 | // or write it out
55 | try Image(raster).save(url: url)
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/SuScaledRotaryEmbedding.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 | import MLXFast
4 | import MLXNN
5 |
/// Rotary position embedding with "SuScaled" long-context scaling:
/// per-dimension frequency factors (`longFactor`) plus a scalar multiplier
/// applied to the rotated slice of the input.
public class SuScaledRotaryEmbedding: Module {
    // number of leading feature dimensions that are rotated (must be even)
    let dimensions: Int
    // target (extended) context length
    let maxPositionEmbeddings: Int
    // context length the base model was trained with
    let originalMaxPositionEmbeddings: Int
    // multiplier applied to the rotated slice before RoPE
    let scale: Float
    // per-pair frequencies: longFactor * base^(2i / dimensions)
    // NOTE(review): passed to MLXFast.RoPE(freqs:) -- assumed to act as
    // divisors of the position (larger value == slower rotation); confirm
    // against the MLXFast.RoPE documentation.
    let _freqs: MLXArray

    public init(
        dimensions: Int,
        base: Float = 10000.0,
        maxPositionEmbeddings: Int = 131072,
        originalMaxPositionEmbeddings: Int = 4096,
        longFactor: [Float] = [1.0],
        // shortMScale: Float? = nil,
        longMScale: Float? = nil
    ) {
        precondition(dimensions % 2 == 0, "Dimensions must be even")

        self.dimensions = dimensions
        self.maxPositionEmbeddings = maxPositionEmbeddings
        self.originalMaxPositionEmbeddings = originalMaxPositionEmbeddings

        // base^(2i / dimensions) for i = 0, 1, ..., dimensions/2 - 1,
        // stretched elementwise by the long-context factors
        let exponent =
            MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / Float(dimensions)
        let freqs = MLX.pow(MLXArray(base), exponent)
        self._freqs = MLXArray(longFactor).asType(.float32) * freqs

        // default scale: sqrt(1 + log(max / original) / log(original)),
        // used when no explicit longMScale is provided
        self.scale =
            longMScale
            ?? sqrt(
                1 + log(Float(maxPositionEmbeddings) / Float(originalMaxPositionEmbeddings))
                    / log(Float(originalMaxPositionEmbeddings))
            )
    }

    /// Applies the scaled rotary embedding to `x` starting at position `offset`.
    public func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
        // Apply scaling only to the dimensions that will be rotated
        var scaledX = x
        let sliceToScale = scaledX[.ellipsis, 0 ..< dimensions]
        scaledX[.ellipsis, 0 ..< dimensions] = scale * sliceToScale

        return MLXFast.RoPE(
            scaledX,
            dimensions: dimensions,
            traditional: false,
            base: nil,
            scale: 1.0,
            offset: offset,
            freqs: self._freqs
        )
    }
}
58 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/Lora+Data.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
/// Errors raised when locating LoRA training data files.
enum LoRADataError: LocalizedError {
    // no data file with the given base name was found in the directory
    case fileNotFound(URL, String)

    var errorDescription: String? {
        switch self {
        case .fileNotFound(let directory, let name):
            return String(
                localized: "Could not find data file '\(name)' in directory '\(directory.path())'.")
        }
    }
}
16 |
/// Load a LoRA data file.
///
/// Given a directory and a base name, e.g. `train`, this will load a `.jsonl` or `.txt` file
/// if possible (`.jsonl` is preferred when both exist).
///
/// - Throws: `LoRADataError.fileNotFound` when neither file exists.
public func loadLoRAData(directory: URL, name: String) throws -> [String] {
    let candidates = ["jsonl", "txt"].map { directory.appending(component: "\(name).\($0)") }

    guard let url = candidates.first(where: { FileManager.default.fileExists(atPath: $0.path()) })
    else {
        throw LoRADataError.fileNotFound(directory, name)
    }

    return try loadLoRAData(url: url)
}
33 |
/// Load a .txt or .jsonl file and return the contents
public func loadLoRAData(url: URL) throws -> [String] {
    let ext = url.pathExtension
    if ext == "jsonl" {
        return try loadJSONL(url: url)
    }
    if ext == "txt" {
        return try loadLines(url: url)
    }
    // any other extension is a programming error, not a recoverable condition
    fatalError("Unable to load data file, unknown type: \(url)")
}
48 |
/// Load a `.jsonl` file and return the `text` field of each JSON object.
///
/// Lines that do not start with `{` are skipped, as are objects whose
/// `text` field is missing.
///
/// - Throws: errors from reading the file or decoding a line that starts
///   with `{` but is not valid JSON.
func loadJSONL(url: URL) throws -> [String] {

    struct Line: Codable {
        let text: String?
    }

    // the decoder is loop-invariant -- create it once rather than per line
    let decoder = JSONDecoder()

    return try String(contentsOf: url)
        .components(separatedBy: .newlines)
        .filter {
            $0.first == "{"
        }
        .compactMap {
            try decoder.decode(Line.self, from: $0.data(using: .utf8)!).text
        }
}
64 |
/// Load a plain text file and return its non-empty lines.
func loadLines(url: URL) throws -> [String] {
    let contents = try String(contentsOf: url)
    return contents
        .components(separatedBy: .newlines)
        .filter { line in !line.isEmpty }
}
70 |
--------------------------------------------------------------------------------
/Libraries/Embedders/Tokenizer.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 | import Tokenizers
6 |
/// Load and instantiate a `PreTrainedTokenizer` for the given model configuration.
public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
    let (config, data) = try await loadTokenizerConfig(configuration: configuration, hub: hub)
    return try PreTrainedTokenizer(tokenizerConfig: config, tokenizerData: data)
}
15 |
/// Loads the tokenizer configuration and tokenizer data for a model,
/// falling back to the local model directory when the device is offline.
///
/// - Returns: the tuple `(tokenizerConfig, tokenizerData)`
/// - Throws: `EmbedderError` when no tokenizer config can be found; any
///   load error other than "not connected to the internet" is rethrown.
func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> (
    Config, Config
) {
    // from AutoTokenizer.from() -- this lets us override parts of the configuration
    let config: LanguageModelConfigurationFromHub

    switch configuration.id {
    case .id(let id):
        do {
            // the load can fail (async when we try to use it)
            let loaded = LanguageModelConfigurationFromHub(
                modelName: configuration.tokenizerId ?? id, hubApi: hub)
            _ = try await loaded.tokenizerConfig
            config = loaded
        } catch {
            let nserror = error as NSError
            if nserror.domain == NSURLErrorDomain
                && nserror.code == NSURLErrorNotConnectedToInternet
            {
                // Internet connection appears to be offline -- fall back to loading from
                // the local directory
                config = LanguageModelConfigurationFromHub(
                    modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub)
            } else {
                throw error
            }
        }
    case .directory(let directory):
        // explicit local directory -- no network involved
        config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
    }

    guard let tokenizerConfig = try await config.tokenizerConfig else {
        throw EmbedderError(message: "missing config")
    }
    let tokenizerData = try await config.tokenizerData
    return (tokenizerConfig, tokenizerData)
}
53 |
--------------------------------------------------------------------------------
/Applications/LLMEval/README.md:
--------------------------------------------------------------------------------
1 | # LLMEval
2 |
3 | An example that:
4 |
5 | - downloads a huggingface model (phi-2) and tokenizer
6 | - evaluates a prompt
7 | - displays the output as it generates text
8 |
9 | You will need to set the Team on the LLMEval target in order to build and run on iOS.
10 |
11 | Some notes about the setup:
12 |
13 | - this downloads models from hugging face so LLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
14 | - LLM models are large so this uses the Increased Memory Limit entitlement on iOS to allow ... increased memory limits for devices that have more memory
15 | - `MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)` is used to limit the buffer cache size
16 | - The Phi2 4 bit model is small enough to run on some iPhone models
17 | - this can be changed by editing `let modelConfiguration = ModelConfiguration.phi4bit`
18 |
19 | ### Trying Different Models
20 |
21 | The example application uses Phi2 model by default, see [ContentView.swift](ContentView.swift#L58):
22 |
23 | ```
24 | /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
25 | /// more devices
26 | let modelConfiguration = ModelConfiguration.phi4bit
27 | ```
28 |
29 | There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78)
30 | and you can load any weights from Hugging Face where there
31 | is a model architecture defined and you have enough
32 | memory.
33 |
34 | ### Troubleshooting
35 |
36 | If the program crashes with a very deep stack trace you may need to build
37 | in Release configuration. This seems to depend on the size of the model.
38 |
39 | There are a couple options:
40 |
41 | - build Release
42 | - force the model evaluation to run on the main thread, e.g. using @MainActor
43 | - build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),` around line 87
44 |
45 | See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3
46 |
47 | ### Performance
48 |
Different models have different performance characteristics. For example Gemma 2B may outperform Phi-2 in terms of tokens / second.
50 |
51 | You may also find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable".
52 |
--------------------------------------------------------------------------------
/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "pins" : [
3 | {
4 | "identity" : "gzipswift",
5 | "kind" : "remoteSourceControl",
6 | "location" : "https://github.com/1024jp/GzipSwift",
7 | "state" : {
8 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05",
9 | "version" : "6.0.1"
10 | }
11 | },
12 | {
13 | "identity" : "jinja",
14 | "kind" : "remoteSourceControl",
15 | "location" : "https://github.com/johnmai-dev/Jinja",
16 | "state" : {
17 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73",
18 | "version" : "1.1.1"
19 | }
20 | },
21 | {
22 | "identity" : "mlx-swift",
23 | "kind" : "remoteSourceControl",
24 | "location" : "https://github.com/ml-explore/mlx-swift",
25 | "state" : {
26 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab",
27 | "version" : "0.21.2"
28 | }
29 | },
30 | {
31 | "identity" : "swift-argument-parser",
32 | "kind" : "remoteSourceControl",
33 | "location" : "https://github.com/apple/swift-argument-parser.git",
34 | "state" : {
35 | "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
36 | "version" : "1.4.0"
37 | }
38 | },
39 | {
40 | "identity" : "swift-async-algorithms",
41 | "kind" : "remoteSourceControl",
42 | "location" : "https://github.com/apple/swift-async-algorithms",
43 | "state" : {
44 | "revision" : "da4e36f86544cdf733a40d59b3a2267e3a7bbf36",
45 | "version" : "1.0.0"
46 | }
47 | },
48 | {
49 | "identity" : "swift-collections",
50 | "kind" : "remoteSourceControl",
51 | "location" : "https://github.com/apple/swift-collections.git",
52 | "state" : {
53 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7",
54 | "version" : "1.1.4"
55 | }
56 | },
57 | {
58 | "identity" : "swift-numerics",
59 | "kind" : "remoteSourceControl",
60 | "location" : "https://github.com/apple/swift-numerics",
61 | "state" : {
62 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
63 | "version" : "1.0.2"
64 | }
65 | },
66 | {
67 | "identity" : "swift-transformers",
68 | "kind" : "remoteSourceControl",
69 | "location" : "https://github.com/huggingface/swift-transformers",
70 | "state" : {
71 | "revision" : "be855fac725dbae27264e47a3eb535cc422a4ba8",
72 | "version" : "0.1.18"
73 | }
74 | }
75 | ],
76 | "version" : 2
77 | }
78 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Xcode
2 | #
3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
4 |
5 | ## User settings
6 | xcuserdata/
7 |
8 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
9 | *.xcscmblueprint
10 | *.xccheckout
11 |
12 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
13 | build/
14 | DerivedData/
15 | *.moved-aside
16 | *.pbxuser
17 | !default.pbxuser
18 | *.mode1v3
19 | !default.mode1v3
20 | *.mode2v3
21 | !default.mode2v3
22 | *.perspectivev3
23 | !default.perspectivev3
24 |
25 | ## Obj-C/Swift specific
26 | *.hmap
27 |
28 | ## App packaging
29 | *.ipa
30 | *.dSYM.zip
31 | *.dSYM
32 |
33 | ## Playgrounds
34 | timeline.xctimeline
35 | playground.xcworkspace
36 |
37 | # Swift Package Manager
38 | #
39 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies.
40 | Packages/
41 | Package.pins
42 | Package.resolved
43 | # *.xcodeproj
44 | #
45 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
46 | # hence it is not needed unless you have added a package configuration file to your project
47 | .swiftpm
48 |
49 | .build/
50 |
51 | # CocoaPods
52 | #
53 | # We recommend against adding the Pods directory to your .gitignore. However
54 | # you should judge for yourself, the pros and cons are mentioned at:
55 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control
56 | #
57 | # Pods/
58 | #
59 | # Add this line if you want to avoid checking in source code from the Xcode workspace
60 | # *.xcworkspace
61 |
62 | # Carthage
63 | #
64 | # Add this line if you want to avoid checking in source code from Carthage dependencies.
65 | # Carthage/Checkouts
66 |
67 | Carthage/Build/
68 |
69 | # Accio dependency management
70 | Dependencies/
71 | .accio/
72 |
73 | # fastlane
74 | #
75 | # It is recommended to not store the screenshots in the git repo.
76 | # Instead, use fastlane to re-generate the screenshots whenever they are needed.
77 | # For more information about the recommended setup visit:
78 | # https://docs.fastlane.tools/best-practices/source-control/#source-control
79 |
80 | fastlane/report.xml
81 | fastlane/Preview.html
82 | fastlane/screenshots/**/*.png
83 | fastlane/test_output
84 |
85 | # Code Injection
86 | #
87 | # After new code Injection tools there's a generated folder /iOSInjectionProject
88 | # https://github.com/johnno1962/injectionforxcode
89 |
90 | iOSInjectionProject/
91 |
92 | # OS
93 | .DS_Store
94 |
95 | .idea
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/MNIST.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 | import MLXNN
6 |
7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
8 |
/// LeNet-style convolutional network for MNIST digit classification:
/// two conv/tanh/max-pool stages followed by three fully connected layers,
/// the last producing 10 class logits.
public class LeNet: Module, UnaryLayer {

    @ModuleInfo var conv1: Conv2d
    @ModuleInfo var conv2: Conv2d
    @ModuleInfo var pool1: MaxPool2d
    @ModuleInfo var pool2: MaxPool2d
    @ModuleInfo var fc1: Linear
    @ModuleInfo var fc2: Linear
    @ModuleInfo var fc3: Linear

    override public init() {
        conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
        conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
        pool1 = MaxPool2d(kernelSize: 2, stride: 2)
        pool2 = MaxPool2d(kernelSize: 2, stride: 2)
        fc1 = Linear(16 * 5 * 5, 120)
        fc2 = Linear(120, 84)
        fc3 = Linear(84, 10)
    }

    /// Forward pass producing the 10 class logits (no softmax applied).
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let stage1 = pool1(tanh(conv1(x)))
        let stage2 = pool2(tanh(conv2(stage1)))
        let flat = flattened(stage2, start: 1)
        let hidden1 = tanh(fc1(flat))
        let hidden2 = tanh(fc2(hidden1))
        return fc3(hidden2)
    }
}
40 |
/// Mean cross-entropy loss between the model's logits for `x` and targets `y`.
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
    crossEntropy(logits: model(x), targets: y, reduction: .mean)
}
44 |
/// Classification accuracy: the fraction of `argMax` predictions equal to `y`.
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
    mean(argMax(model(x), axis: 1) .== y)
}
48 |
/// Iterates over `(x, y)` in shuffled mini-batches; the final batch may be
/// smaller than `batchSize` when `y.size` is not a multiple of it.
private struct BatchSequence: Sequence, IteratorProtocol {

    let batchSize: Int
    let x: MLXArray
    let y: MLXArray

    // shuffled example indices, fixed for the lifetime of the sequence
    let indexes: MLXArray
    var index = 0

    init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator)
    {
        self.batchSize = batchSize
        self.x = x
        self.y = y
        self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator))
    }

    mutating func next() -> (MLXArray, MLXArray)? {
        if index >= y.size { return nil }

        let upper = Swift.min(index + batchSize, y.size)
        let ids = indexes[index ..< upper]
        index += batchSize
        return (x[ids], y[ids])
    }
}
75 |
/// Returns a sequence of shuffled `(x, y)` mini-batches over the full data set.
public func iterateBatches(
    batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator
) -> some Sequence<(MLXArray, MLXArray)> {
    return BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator)
}
81 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/Documentation.docc/adding-model.md:
--------------------------------------------------------------------------------
1 | # Adding a Model
2 |
3 | If the model follows the typical LLM pattern, you can add a new
4 | model in a few steps. The model's Hugging Face repository should
5 | provide the following files:
6 |
7 | - `config.json`, `tokenizer.json`, and `tokenizer_config.json`
8 | - `*.safetensors` (the model weights)
8 |
9 | You can follow the pattern of the models in the [Models](Models) directory
10 | and create a `.swift` file for your new model:
11 |
12 | ## Create a Configuration
13 |
14 | Create a configuration struct to match the `config.json` (any parameters needed).
15 |
16 | ```swift
17 | public struct YourModelConfiguration: Codable, Sendable {
18 | public let hiddenSize: Int
19 |
20 | // use this pattern for values that need defaults
21 | public let _layerNormEps: Float?
22 | public var layerNormEps: Float { _layerNormEps ?? 1e-6 }
23 |
24 | enum CodingKeys: String, CodingKey {
25 | case hiddenSize = "hidden_size"
26 | case _layerNormEps = "layer_norm_eps"
27 | }
28 | }
29 | ```
30 |
31 | ## Create the Model Class
32 |
33 | Create the model class. The top-level public class should have a
34 | structure something like this:
35 |
36 | ```swift
37 | public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel {
38 |
39 | public let kvHeads: [Int]
40 |
41 | @ModuleInfo var model: YourModelInner
42 |
43 | public func loraLinearLayers() -> LoRALinearLayers {
44 | // TODO: modify as needed
45 | model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
46 | }
47 |
48 | public init(_ args: YourModelConfiguration) {
49 | self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
50 | self.model = YourModelInner(args)
51 | }
52 |
53 | public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
54 | // TODO: modify as needed
55 | let out = model(inputs, cache: cache)
56 | return model.embedTokens.asLinear(out)
57 | }
58 | }
59 | ```
60 |
61 | ## Register the Model
62 |
63 | In [LLMModelFactory.swift](LLMModelFactory.swift) register the model type itself
64 | (this is independent of the model id):
65 |
66 | ```swift
67 | public class ModelTypeRegistry: @unchecked Sendable {
68 | ...
69 | private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
70 | "yourModel": create(YourModelConfiguration.self, YourModel.init),
71 | ```
72 |
73 | Add a constant for the model in the `ModelRegistry` (not strictly required but useful
74 | for callers to refer to it in code):
75 |
76 | ```swift
77 | public class ModelRegistry: @unchecked Sendable {
78 | ...
79 | static public let yourModel_4bit = ModelConfiguration(
80 | id: "mlx-community/YourModel-4bit",
81 | defaultPrompt: "What is the gravity on Mars and the moon?"
82 | )
83 | ```
84 |
85 | and finally add it to the all list -- this will let users find the model
86 | configuration by id:
87 |
88 | ```swift
89 | private static func all() -> [ModelConfiguration] {
90 | [
91 | codeLlama13b4bit,
92 | ...
93 | yourModel_4bit,
94 | ```
95 |
--------------------------------------------------------------------------------
/Tools/image-tool/Arguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLX
6 |
#if swift(>=5.10)
/// Extension to allow URL command line arguments.
///
/// Arguments containing "://" are parsed as full URLs; anything else is
/// treated as a local file path. `@retroactive` (Swift 5.10+) silences the
/// warning about conforming a type declared in another module.
extension URL: @retroactive ExpressibleByArgument {
    public init?(argument: String) {
        if argument.contains("://") {
            self.init(string: argument)
        } else {
            self.init(filePath: argument)
        }
    }
}
#else
/// Extension to allow URL command line arguments.
/// Same conformance for toolchains that predate `@retroactive`.
extension URL: ExpressibleByArgument {
    public init?(argument: String) {
        if argument.contains("://") {
            self.init(string: argument)
        } else {
            self.init(filePath: argument)
        }
    }
}
#endif
30 |
31 | /// Argument package for adjusting and reporting memory use.
/// Argument package for adjusting and reporting memory use.
struct MemoryArguments: ParsableArguments, Sendable {

    @Flag(name: .long, help: "Show memory stats")
    var memoryStats = false

    @Option(name: .long, help: "Maximum cache size in M")
    var cacheSize = 1024

    @Option(name: .long, help: "Maximum memory size in M")
    var memorySize: Int?

    // snapshot taken once limits are applied, used as the baseline for growth reporting
    var startMemory: GPU.Snapshot?

    /// Apply the configured GPU limits, run `load`, then record a baseline snapshot.
    ///
    /// The dump of this file had lost the generic parameter clause; restored
    /// `<L>` so the load result type flows through unchanged.
    ///
    /// - Parameter load: async work (typically model loading) performed after
    ///   the limits are configured
    /// - Returns: whatever `load` produced
    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
        GPU.set(cacheLimit: cacheSize * 1024 * 1024)

        if let memorySize {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }

        let result = try await load()
        // baseline is captured after loading so reports show post-load growth
        startMemory = GPU.snapshot()

        return result
    }

    /// Apply the configured GPU limits and record a baseline snapshot.
    mutating func start() {
        GPU.set(cacheLimit: cacheSize * 1024 * 1024)

        if let memorySize {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }

        startMemory = GPU.snapshot()
    }

    /// Print the current memory snapshot (no-op unless `--memory-stats`).
    func reportCurrent() {
        if memoryStats {
            let memory = GPU.snapshot()
            print(memory.description)
        }
    }

    /// Print start/end snapshots and their delta (no-op unless `--memory-stats`
    /// and a baseline was captured via `start`).
    func reportMemoryStatistics() {
        if memoryStats, let startMemory {
            let endMemory = GPU.snapshot()

            print("=======")
            print("Memory size: \(GPU.memoryLimit / 1024)K")
            print("Cache size: \(GPU.cacheLimit / 1024)K")

            print("")
            print("=======")
            print("Starting memory")
            print(startMemory.description)

            print("")
            print("=======")
            print("Ending memory")
            print(endMemory.description)

            print("")
            print("=======")
            print("Growth")
            print(startMemory.delta(endMemory).description)

        }
    }
}
101 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/ModelConfiguration.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 |
6 | /// Configuration for a given model name with overrides for prompts and tokens.
7 | ///
8 | /// See e.g. `MLXLM.ModelRegistry` for an example of use.
public struct ModelConfiguration: Sendable {

    /// Where the model lives: a Hub repo id or a local directory.
    public enum Identifier: Sendable {
        case id(String)
        case directory(URL)
    }

    public var id: Identifier

    /// Display name: the Hub id, or the last two path components of a directory.
    public var name: String {
        switch id {
        case .id(let string):
            string
        case .directory(let url):
            url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
        }
    }

    /// pull the tokenizer from an alternate id
    public let tokenizerId: String?

    /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
    public let overrideTokenizer: String?

    /// A reasonable default prompt for the model
    public var defaultPrompt: String

    /// Additional tokens to use for end of string
    /// (restored the element type `<String>` lost in the dump)
    public var extraEOSTokens: Set<String>

    /// Create a configuration referring to a Hub model id.
    ///
    /// - Note: `preparePrompt` is accepted for source compatibility but is
    ///   not stored or applied here.
    public init(
        id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
        defaultPrompt: String = "hello",
        extraEOSTokens: Set<String> = [],
        preparePrompt: (@Sendable (String) -> String)? = nil
    ) {
        self.id = .id(id)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
        self.defaultPrompt = defaultPrompt
        self.extraEOSTokens = extraEOSTokens
    }

    /// Create a configuration referring to a local directory of model files.
    public init(
        directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
        defaultPrompt: String = "hello",
        extraEOSTokens: Set<String> = []
    ) {
        self.id = .directory(directory)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
        self.defaultPrompt = defaultPrompt
        self.extraEOSTokens = extraEOSTokens
    }

    /// Resolve the on-disk directory for the model.
    ///
    /// For a Hub id this is the local repo location (it does not download);
    /// for a directory it is the directory itself.
    public func modelDirectory(hub: HubApi = HubApi()) -> URL {
        switch id {
        case .id(let id):
            // download the model weights and config
            let repo = Hub.Repo(id: id)
            return hub.localRepoLocation(repo)

        case .directory(let directory):
            return directory
        }
    }
}
76 |
// Equatable conformance is synthesized from the stored properties.
extension ModelConfiguration: Equatable {

}

extension ModelConfiguration.Identifier: Equatable {

    /// Identifiers are equal only when they are the same case with equal payloads.
    public static func == (lhs: ModelConfiguration.Identifier, rhs: ModelConfiguration.Identifier)
        -> Bool
    {
        switch (lhs, rhs) {
        case (.id(let left), .id(let right)):
            return left == right
        case (.directory(let left), .directory(let right)):
            return left == right
        default:
            return false
        }
    }
}
96 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/ModelContainer.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 | import MLX
6 | import MLXNN
7 | import Tokenizers
8 |
9 | /// Container for models that guarantees single threaded access.
10 | ///
11 | /// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
12 | /// the model and/or tokenizer (any values from the ``ModelContext``):
13 | ///
14 | /// ```swift
15 | /// let messages = [["role": "user", "content": prompt]]
16 | /// let promptTokens = try await modelContainer.perform { context in
17 | /// try context.tokenizer.applyChatTemplate(messages: messages)
18 | /// }
19 | /// ```
20 | ///
21 | /// or:
22 | ///
23 | /// ```swift
24 | /// let userInput: UserInput
25 | /// let result = await modelContainer.perform { context in
26 | /// let input = try await context.processor.prepare(input: userInput)
27 | /// return generate(
28 | /// input: input, parameters: generateParameters, context: context
29 | /// ) { tokens in
30 | /// ...
31 | /// }
32 | /// }
33 | /// ```
public actor ModelContainer {
    var context: ModelContext
    public var configuration: ModelConfiguration { context.configuration }

    public init(context: ModelContext) {
        self.context = context
    }

    // Note: the generic parameter clauses (<R>, <V, R>) below were lost in
    // the dump of this file and have been restored -- without them R and V
    // are unresolved identifiers.

    /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    @available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
    public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows
        -> R
    {
        try action(context.model, context.tokenizer)
    }

    /// Perform an action on the model and/or tokenizer with additional context values.
    /// Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    @available(*, deprecated, message: "prefer perform(values:_:) that uses a ModelContext")
    public func perform<V, R>(
        values: V, _ action: @Sendable (any LanguageModel, Tokenizer, V) throws -> R
    ) rethrows -> R {
        try action(context.model, context.tokenizer, values)
    }

    /// Perform an action on the ``ModelContext``. Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
    {
        try await action(context)
    }

    /// Perform an action on the ``ModelContext`` with additional context values.
    /// Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    public func perform<V, R>(
        values: V, _ action: @Sendable (ModelContext, V) async throws -> R
    ) async rethrows -> R {
        try await action(context, values)
    }

    /// Update the owned `ModelContext`.
    /// - Parameter action: update action
    public func update(_ action: @Sendable (inout ModelContext) -> Void) {
        action(&context)
    }
}
83 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/LLMEval.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
8 |
9 |
15 |
21 |
22 |
23 |
24 |
25 |
31 |
32 |
42 |
44 |
50 |
51 |
52 |
53 |
59 |
61 |
67 |
68 |
69 |
70 |
72 |
73 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/VLMEval.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
45 |
51 |
52 |
53 |
54 |
60 |
62 |
68 |
69 |
70 |
71 |
73 |
74 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/Files.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Gzip
5 | import MLX
6 |
7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py
8 |
/// Which split of the MNIST dataset a file belongs to.
public enum Use: String, Hashable, Sendable {
    case test
    case training
}

/// Whether a file holds image pixels or class labels.
public enum DataKind: String, Hashable, Sendable {
    case images
    case labels
}
18 |
/// Key identifying one MNIST data file: a (split, kind) pair,
/// e.g. (training, images) or (test, labels).
public struct FileKind: Hashable, CustomStringConvertible, Sendable {
    let use: Use
    let data: DataKind

    public init(_ use: Use, _ data: DataKind) {
        self.use = use
        self.data = data
    }

    /// e.g. "training-images" or "test-labels"
    public var description: String {
        [use.rawValue, data.rawValue].joined(separator: "-")
    }
}
32 |
/// How to fetch and decode one MNIST data file.
struct LoadInfo: Sendable {
    // remote file name (also used as the local file name)
    let name: String
    // bytes of idx-format header to skip before the payload
    let offset: Int
    // converts the raw flat u8 array into its final shape/dtype
    let convert: @Sendable (MLXArray) -> MLXArray
}

let baseURL = URL(string: "https://raw.githubusercontent.com/fgnt/mnist/master/")!

// One entry per (split, kind) file. Images become float32 NHWC in [0, 1];
// labels become uint32. The offsets (16 / 8) are the idx3/idx1 header sizes.
private let files = [
    FileKind(.training, .images): LoadInfo(
        name: "train-images-idx3-ubyte.gz",
        offset: 16,
        convert: {
            $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
        }),
    FileKind(.test, .images): LoadInfo(
        name: "t10k-images-idx3-ubyte.gz",
        offset: 16,
        convert: {
            $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
        }),
    FileKind(.training, .labels): LoadInfo(
        name: "train-labels-idx1-ubyte.gz",
        offset: 8,
        convert: {
            $0.asType(.uint32)
        }),
    FileKind(.test, .labels): LoadInfo(
        name: "t10k-labels-idx1-ubyte.gz",
        offset: 8,
        convert: {
            $0.asType(.uint32)
        }),
]
67 |
/// Download any missing MNIST data files into the given directory.
///
/// Files that already exist on disk are skipped, so this is cheap to call
/// on every launch.
///
/// The original code called `fatalError` on a bad response; since the
/// function already `throws`, transient network/server failures now throw
/// instead of crashing the process.
///
/// - Parameter into: destination directory (must already exist)
/// - Throws: `URLError(.badServerResponse)` for non-HTTP or non-200 responses,
///   plus any `URLSession`/file-write errors
public func download(into: URL) async throws {
    for (_, info) in files {
        let fileURL = into.appending(component: info.name)
        if FileManager.default.fileExists(atPath: fileURL.path()) {
            continue
        }

        print("Download: \(info.name)")
        let url = baseURL.appending(component: info.name)
        let (data, response) = try await URLSession.shared.data(from: url)

        // throw (rather than crash) so callers can retry or surface the error
        guard let httpResponse = response as? HTTPURLResponse,
            httpResponse.statusCode == 200
        else {
            throw URLError(.badServerResponse)
        }

        try data.write(to: fileURL)
    }
}
87 |
/// Read, gunzip, and decode the previously downloaded MNIST files.
///
/// - Parameter from: directory populated by `download(into:)`
/// - Returns: a dictionary keyed by (split, kind) with converted arrays
public func load(from: URL) throws -> [FileKind: MLXArray] {
    var result = [FileKind: MLXArray]()

    for (key, info) in files {
        let raw = try Data(contentsOf: from.appending(component: info.name)).gunzipped()

        // skip the idx header, wrap the payload as a flat u8 array,
        // then apply the per-file conversion (reshape / dtype)
        let payload = raw.dropFirst(info.offset)
        let array = MLXArray(payload, [raw.count - info.offset], type: UInt8.self)

        result[key] = info.convert(array)
    }

    return result
}
103 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/StableDiffusionExample.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
45 |
51 |
52 |
53 |
54 |
60 |
62 |
68 |
69 |
70 |
71 |
73 |
74 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/Tools/Tutorial/Tutorial.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 |
6 | /// mlx-swift tutorial based on:
7 | /// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp
@main
struct Tutorial {

    /// Demonstrates scalar MLXArray basics: dtype, item access, size/ndim/shape.
    static func scalarBasics() {
        // create a scalar array
        let x = MLXArray(1.0)

        // the datatype is .float32
        let dtype = x.dtype
        assert(dtype == .float32)

        // get the value
        let s = x.item(Float.self)
        assert(s == 1.0)

        // reading the value with a different type is a fatal error
        // let i = x.item(Int.self)

        // scalars have a size of 1
        let size = x.size
        assert(size == 1)

        // scalars have 0 dimensions
        let ndim = x.ndim
        assert(ndim == 0)

        // scalar shapes are empty arrays
        let shape = x.shape
        assert(shape == [])
    }

    /// Demonstrates multidimensional arrays, pointwise ops, and lazy evaluation.
    static func arrayBasics() {
        // make a multidimensional array.
        //
        // Note: the argument is a [Double] array literal, which is not
        // a supported type, but we can explicitly convert it to [Float]
        // when we create the MLXArray.
        let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2])

        // mlx is row-major by default so the first row of this array
        // is [1.0, 2.0] and the second row is [3.0, 4.0]
        print(x[0])
        print(x[1])

        // make an array of shape [2, 2] filled with ones
        let y = MLXArray.ones([2, 2])

        // pointwise add x and y
        let z = x + y

        // mlx is lazy by default. At this point `z` only
        // has a shape and a type but no actual data
        assert(z.dtype == .float32)
        assert(z.shape == [2, 2])

        // To actually run the computation you must evaluate `z`.
        // Under the hood, mlx records operations in a graph.
        // The variable `z` is a node in the graph which points to its operation
        // and inputs. When `eval` is called on an array (or arrays), the array and
        // all of its dependencies are recursively evaluated to produce the result.
        // Once an array is evaluated, it has data and is detached from its inputs.

        // Note: this is being called for demonstration purposes -- all reads
        // ensure the array is evaluated.
        z.eval()

        // this implicitly evaluates z before converting to a description
        print(z)
    }

    /// Demonstrates `grad`: first and second derivative of x^2 at x = 1.5.
    static func automaticDifferentiation() {
        func fn(_ x: MLXArray) -> MLXArray {
            x.square()
        }

        // grad(fn) produces a function computing d(fn)/dx
        let gradFn = grad(fn)

        let x = MLXArray(1.5)
        let dfdx = gradFn(x)
        print(dfdx)

        // d(x^2)/dx = 2x = 3.0 at x = 1.5
        assert(dfdx.item() == Float(2 * 1.5))

        // second derivative: d^2(x^2)/dx^2 = 2
        let df2dx2 = grad(grad(fn))(x)
        print(df2dx2)

        assert(df2dx2.item() == Float(2))
    }

    /// Entry point: run each demonstration in order.
    static func main() {
        scalarBasics()
        arrayBasics()
        automaticDifferentiation()
    }
}
103 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/StringOrNumber.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
5 | /// Representation of a heterogenous type in a JSON configuration file.
6 | ///
7 | /// This can be: a string, a numeric value or an array of numeric values.
8 | /// There are methods to do unwrapping, see e.g. ``asFloat()`` and
9 | /// ``asFloats()`` or callers can switch on the enum.
public enum StringOrNumber: Codable, Equatable, Sendable {
    case string(String)
    case int(Int)
    case float(Float)
    case ints([Int])
    case floats([Float])

    /// Decode from a single JSON value, trying the most specific numeric
    /// form first.
    ///
    /// Order matters: `Int` is attempted before `Float` so integral values
    /// decode as `.int`; likewise `[Int]` before `[Float]`. Anything
    /// non-numeric falls through to `.string` (and a non-string there throws).
    public init(from decoder: Decoder) throws {
        let values = try decoder.singleValueContainer()

        if let v = try? values.decode(Int.self) {
            self = .int(v)
        } else if let v = try? values.decode(Float.self) {
            self = .float(v)
        } else if let v = try? values.decode([Int].self) {
            self = .ints(v)
        } else if let v = try? values.decode([Float].self) {
            self = .floats(v)
        } else {
            let v = try values.decode(String.self)
            self = .string(v)
        }
    }

    /// Encode the payload directly as a single value (no case wrapper).
    public func encode(to encoder: Encoder) throws {
        var container = encoder.singleValueContainer()
        switch self {
        case .string(let v): try container.encode(v)
        case .int(let v): try container.encode(v)
        case .float(let v): try container.encode(v)
        case .ints(let v): try container.encode(v)
        case .floats(let v): try container.encode(v)
        }
    }

    // The switches below previously bound payloads they never used
    // (`case .string(let string): nil`), which produces compiler warnings;
    // unused bindings are dropped.

    /// Return the value as an optional array of integers.
    ///
    /// This will not coerce `Float` or `String` to `Int`.
    public func asInts() -> [Int]? {
        switch self {
        case .string: nil
        case .int(let v): [v]
        case .float: nil
        case .ints(let array): array
        case .floats: nil
        }
    }

    /// Return the value as an optional integer.
    ///
    /// This will not coerce `Float` or `String` to `Int`.
    public func asInt() -> Int? {
        switch self {
        case .string: nil
        case .int(let v): v
        case .float: nil
        case .ints(let array): array.count == 1 ? array[0] : nil
        case .floats: nil
        }
    }

    /// Return the value as an optional array of floats.
    ///
    /// This will not coerce `Int` or `String` to `Float`.
    public func asFloats() -> [Float]? {
        switch self {
        case .string: nil
        case .int(let v): [Float(v)]
        case .float(let float): [float]
        case .ints(let array): array.map { Float($0) }
        case .floats(let array): array
        }
    }

    /// Return the value as an optional float.
    ///
    /// This will not coerce `Int` or `String` to `Float`.
    public func asFloat() -> Float? {
        switch self {
        case .string: nil
        case .int(let v): Float(v)
        case .float(let float): float
        case .ints(let array): array.count == 1 ? Float(array[0]) : nil
        case .floats(let array): array.count == 1 ? array[0] : nil
        }
    }
}
97 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Load.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 | import MLX
6 | import MLXNN
7 | import Tokenizers
8 |
9 | /// Download the model using the `HubApi`.
10 | ///
11 | /// This will download `*.safetensors` and `*.json` if the ``ModelConfiguration``
12 | /// represents a Hub id, e.g. `mlx-community/gemma-2-2b-it-4bit`.
13 | ///
14 | /// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)``
15 | ///
16 | /// - Parameters:
17 | /// - hub: HubApi instance
18 | /// - configuration: the model identifier
19 | /// - progressHandler: callback for progress
20 | /// - Returns: URL for the directory containing downloaded files
public func downloadModel(
    hub: HubApi, configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void
) async throws -> URL {
    do {
        switch configuration.id {
        case .id(let id):
            // download the model weights and configuration/tokenizer json
            let repo = Hub.Repo(id: id)
            let modelFiles = ["*.safetensors", "*.json"]
            return try await hub.snapshot(
                from: repo, matching: modelFiles, progressHandler: progressHandler)

        case .directory(let directory):
            // local directory: nothing to download
            return directory
        }

    } catch Hub.HubClientError.authorizationRequired {
        // an authorizationRequired means (typically) that the named repo doesn't exist on
        // on the server so retry with local only configuration
        return configuration.modelDirectory(hub: hub)

    } catch {
        let nserror = error as NSError
        if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet {
            // Error Domain=NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline."
            // fall back to the local directory (which may hold a prior download)
            return configuration.modelDirectory(hub: hub)
        } else {
            throw error
        }
    }
}
54 |
55 | /// Load model weights.
56 | ///
57 | /// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)``.
58 | /// This function loads all `safetensor` files in the given `modelDirectory`,
59 | /// calls ``LanguageModel/sanitize(weights:)``, applies optional quantization, and
60 | /// updates the model with the weights.
public func loadWeights(
    modelDirectory: URL, model: LanguageModel, quantization: BaseConfiguration.Quantization? = nil
) throws {
    // load the weights from every *.safetensors file in the directory
    // NOTE(review): the enumerator is force-unwrapped -- presumably the
    // directory always exists by this point; confirm against callers
    var weights = [String: MLXArray]()
    let enumerator = FileManager.default.enumerator(
        at: modelDirectory, includingPropertiesForKeys: nil)!
    for case let url as URL in enumerator {
        if url.pathExtension == "safetensors" {
            let w = try loadArrays(url: url)
            for (key, value) in w {
                // later files overwrite duplicate keys from earlier files
                weights[key] = value
            }
        }
    }

    // per-model cleanup
    weights = model.sanitize(weights: weights)

    // quantize if needed -- only layers that have saved scales are quantized,
    // matching how the checkpoint was produced
    if let quantization {
        quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
            path, module in
            weights["\(path).scales"] != nil
        }
    }

    // apply the loaded weights; .all verifies every parameter was provided
    let parameters = ModuleParameters.unflattened(weights)
    try model.update(parameters: parameters, verify: [.all])

    // force evaluation so load errors surface here rather than at first use
    eval(model)
}
94 |
--------------------------------------------------------------------------------
/Tools/mnist-tool/MNISTTool.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLX
6 | import MLXMNIST
7 | import MLXNN
8 | import MLXOptimizers
9 | import MLXRandom
10 |
/// Command line entry point; `train` is the default (and only) subcommand.
@main
struct MNISTTool: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        abstract: "Command line tool for training mnist models",
        subcommands: [Train.self],
        defaultSubcommand: Train.self)
}
// Allow `--device cpu|gpu` arguments by mapping the raw value directly.
// `@retroactive` (Swift 5.10+) silences the warning about conforming a
// type declared in another module.
#if swift(>=5.10)
extension MLX.DeviceType: @retroactive ExpressibleByArgument {
    public init?(argument: String) {
        self.init(rawValue: argument)
    }
}
#else
extension MLX.DeviceType: ExpressibleByArgument {
    public init?(argument: String) {
        self.init(rawValue: argument)
    }
}
#endif
32 |
/// Train LeNet on MNIST: downloads data if needed, then runs SGD for
/// `epochs` passes, printing test accuracy after each epoch.
struct Train: AsyncParsableCommand {

    @Option(name: .long, help: "Directory with the training data")
    var data: String

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @Option var batchSize = 256
    @Option var epochs = 20
    @Option var learningRate: Float = 1e-1

    @Option var device = DeviceType.gpu

    // compile the training step into a single graph for speed
    @Flag var compile = false

    func run() async throws {
        Device.setDefault(device: Device(device))

        // seed both MLX's RNG and the Swift-side generator used for shuffling
        MLXRandom.seed(seed)
        var generator: RandomNumberGenerator = SplitMix64(seed: seed)

        // load the data
        let url = URL(filePath: data)

        try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true)
        try await download(into: url)

        // note: shadows the `data` option string with the loaded dataset
        let data = try load(from: url)

        let trainImages = data[.init(.training, .images)]!
        let trainLabels = data[.init(.training, .labels)]!
        let testImages = data[.init(.test, .images)]!
        let testLabels = data[.init(.test, .labels)]!

        // create the model and force parameter initialization
        let model = LeNet()
        eval(model.parameters())

        let lg = valueAndGrad(model: model, loss)
        let optimizer = SGD(learningRate: learningRate)

        // one SGD step: compute loss + gradients, apply the update
        func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
            let (loss, grads) = lg(model, x, y)
            optimizer.update(model: model, gradients: grads)
            return loss
        }

        // optionally compile the step; model/optimizer state are declared as
        // both inputs and outputs because the step mutates them
        let resolvedStep =
            compile
            ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step

        for e in 0 ..< epochs {
            let start = Date.timeIntervalSinceReferenceDate

            for (x, y) in iterateBatches(
                batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
            {
                _ = resolvedStep(x, y)

                // eval the parameters so the next iteration is independent
                eval(model, optimizer)
            }

            let accuracy = eval(model: model, x: testImages, y: testLabels)

            let end = Date.timeIntervalSinceReferenceDate

            print(
                """
                Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
                Time: \((end - start).formatted())

                """
            )
        }
    }
}
111 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c",
3 | "pins" : [
4 | {
5 | "identity" : "gzipswift",
6 | "kind" : "remoteSourceControl",
7 | "location" : "https://github.com/1024jp/GzipSwift",
8 | "state" : {
9 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05",
10 | "version" : "6.0.1"
11 | }
12 | },
13 | {
14 | "identity" : "jinja",
15 | "kind" : "remoteSourceControl",
16 | "location" : "https://github.com/johnmai-dev/Jinja",
17 | "state" : {
18 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73",
19 | "version" : "1.1.1"
20 | }
21 | },
22 | {
23 | "identity" : "mlx-swift",
24 | "kind" : "remoteSourceControl",
25 | "location" : "https://github.com/ml-explore/mlx-swift",
26 | "state" : {
27 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab",
28 | "version" : "0.21.2"
29 | }
30 | },
31 | {
32 | "identity" : "networkimage",
33 | "kind" : "remoteSourceControl",
34 | "location" : "https://github.com/gonzalezreal/NetworkImage",
35 | "state" : {
36 | "revision" : "2849f5323265386e200484b0d0f896e73c3411b9",
37 | "version" : "6.0.1"
38 | }
39 | },
40 | {
41 | "identity" : "progress.swift",
42 | "kind" : "remoteSourceControl",
43 | "location" : "https://github.com/jkandzi/Progress.swift",
44 | "state" : {
45 | "revision" : "fed6598735d7982058690acf8f52a0a5fdaeb3e0",
46 | "version" : "0.4.0"
47 | }
48 | },
49 | {
50 | "identity" : "swift-argument-parser",
51 | "kind" : "remoteSourceControl",
52 | "location" : "https://github.com/apple/swift-argument-parser.git",
53 | "state" : {
54 | "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b",
55 | "version" : "1.4.0"
56 | }
57 | },
58 | {
59 | "identity" : "swift-cmark",
60 | "kind" : "remoteSourceControl",
61 | "location" : "https://github.com/swiftlang/swift-cmark",
62 | "state" : {
63 | "revision" : "3ccff77b2dc5b96b77db3da0d68d28068593fa53",
64 | "version" : "0.5.0"
65 | }
66 | },
67 | {
68 | "identity" : "swift-collections",
69 | "kind" : "remoteSourceControl",
70 | "location" : "https://github.com/apple/swift-collections.git",
71 | "state" : {
72 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7",
73 | "version" : "1.1.4"
74 | }
75 | },
76 | {
77 | "identity" : "swift-markdown-ui",
78 | "kind" : "remoteSourceControl",
79 | "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
80 | "state" : {
81 | "revision" : "5f613358148239d0292c0cef674a3c2314737f9e",
82 | "version" : "2.4.1"
83 | }
84 | },
85 | {
86 | "identity" : "swift-numerics",
87 | "kind" : "remoteSourceControl",
88 | "location" : "https://github.com/apple/swift-numerics",
89 | "state" : {
90 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
91 | "version" : "1.0.2"
92 | }
93 | },
94 | {
95 | "identity" : "swift-transformers",
96 | "kind" : "remoteSourceControl",
97 | "location" : "https://github.com/huggingface/swift-transformers",
98 | "state" : {
99 | "revision" : "55710ddfb1ae804b4b7ce973be75cf2e41272185",
100 | "version" : "0.1.17"
101 | }
102 | }
103 | ],
104 | "version" : 3
105 | }
106 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/KVCache.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 |
6 | /// Interface for Key/Value cache for LLMs.
7 | ///
8 | /// See ``LanguageModel/newCache(parameters:)``
public protocol KVCache: Evaluatable {

    /// get the current offset
    ///
    /// The number of tokens already stored in the cache; used e.g. by
    /// ``createAttentionMask(h:cache:)`` to offset the causal mask.
    var offset: Int { get }

    /// Append `keys`/`values` to the cache and return the (keys, values)
    /// pair that attention should use for this step.
    /// NOTE(review): presumably the returned arrays include previously
    /// cached entries -- confirm against concrete cache implementations.
    func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray)
}
16 |
/// Build an additive causal mask for `n` query positions whose history is
/// `offset` tokens long: positions a query must not attend to receive a
/// large negative value, allowed positions receive 0.
func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
    let keyPositions = MLXArray(Int32(0) ..< Int32(offset + n))
    let queryPositions: MLXArray
    if offset == 0 {
        queryPositions = keyPositions
    } else {
        queryPositions = MLXArray(Int32(offset) ..< Int32(offset + n))
    }
    // true where the key position lies strictly in the query's future
    let isFuture = queryPositions[0..., .newAxis] .< keyPositions[.newAxis]
    return isFuture * Float32(-1e9)
}
23 |
24 | /// create an attention mask using the parameters from the KVCache.
25 | ///
26 | /// See also ``MultiHeadAttention/createAdditiveCausalMask(_:dtype:)`` -- same idea
27 | /// but doesn't honor the cache offset.
public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? {
    let queryLength = h.dim(1)
    // single-token (decode) steps need no mask
    guard queryLength > 1 else { return nil }

    let offset = cache?.first?.offset ?? 0
    return createAdditiveCausalMask(n: queryLength, offset: offset)
        .asType(h.dtype)
}
40 |
41 | /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11
42 | public class KVCacheSimple: KVCache, Evaluatable {
43 | var keys: MLXArray?
44 | var values: MLXArray?
45 |
46 | public var offset = 0
47 | var step = 256
48 |
49 | public init() {}
50 |
51 | public func innerState() -> [MLXArray] {
52 | [self.keys, self.values].compactMap { $0 }
53 | }
54 |
55 | public func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
56 | let previous = self.offset
57 |
58 | let reset =
59 | if let currentKeys = self.keys, (previous + keys.dim(2)) > currentKeys.dim(2) {
60 | true
61 | } else {
62 | self.keys == nil
63 | }
64 | if reset {
65 | let B = keys.dim(0)
66 | let kvHeads = keys.dim(1)
67 | let kHeadDim = keys.dim(3)
68 | let vHeadDim = values.dim(3)
69 |
70 | let nSteps = (step + keys.dim(2) - 1) / step
71 | let kShape = [B, kvHeads, nSteps * step, kHeadDim]
72 | let vShape = [B, kvHeads, nSteps * step, vHeadDim]
73 | let newK = MLXArray.zeros(kShape, dtype: keys.dtype)
74 | let newV = MLXArray.zeros(vShape, dtype: values.dtype)
75 |
76 | if var currentKeys = self.keys, var currentValues = self.values {
77 | if previous % step != 0 {
78 | currentKeys = currentKeys[.ellipsis, .. Pooling {
26 | let configurationURL = modelDirectory.appending(components: "1_Pooling", "config.json")
27 | guard
28 | let poolingConfig = try? JSONDecoder().decode(
29 | PoolingConfiguration.self, from: Data(contentsOf: configurationURL))
30 | else {
31 | return Pooling(strategy: .none)
32 | }
33 |
34 | return Pooling(config: poolingConfig)
35 | }
36 |
/// Pools per-token hidden states from an ``EmbeddingModelOutput`` into a
/// single embedding vector per sequence.
public class Pooling: Module {
    /// How to reduce the token dimension to a single vector.
    public enum Strategy {
        case mean   // masked mean over tokens
        case cls    // model-provided pooled output, falling back to token 0
        case first  // first token's hidden state
        case last   // last token's hidden state
        case max    // masked element-wise max over tokens
        case none   // pass through pooledOutput (or raw hidden states)
    }
    let strategy: Strategy
    // Optional truncation of the embedding to its first `dimension` components.
    let dimension: Int?

    public init(
        strategy: Strategy, dimension: Int? = nil
    ) {
        self.strategy = strategy
        self.dimension = dimension
    }

    /// Derive the strategy from a `1_Pooling/config.json`-style configuration;
    /// the first enabled mode wins, in the order cls > mean > max > last,
    /// defaulting to `.first`.
    public init(
        config: PoolingConfiguration
    ) {
        dimension = config.dimension
        if config.poolingModeClsToken {
            strategy = .cls
        } else if config.poolingModeMeanTokens {
            strategy = .mean
        } else if config.poolingModeMaxTokens {
            strategy = .max
        } else if config.poolingModeLastToken {
            strategy = .last
        } else {
            strategy = .first
        }
    }

    /// Pool `inputs` into one embedding per sequence.
    ///
    /// - Parameters:
    ///   - inputs: model output; `.mean`/`.max`/`.first`/`.last` force-unwrap
    ///     `hiddenStates`, so those strategies require it to be non-nil
    ///   - mask: optional attention mask; defaults to all-ones over the first
    ///     two dimensions of `hiddenStates`
    ///     (NOTE(review): presumably (batch, sequence) -- confirm)
    ///   - normalize: L2-normalize the pooled vector along the last axis
    ///   - applyLayerNorm: apply a parameter-free layer norm (eps 1e-5) first
    public func callAsFunction(
        _ inputs: EmbeddingModelOutput, mask: MLXArray? = nil, normalize: Bool = false,
        applyLayerNorm: Bool = false
    ) -> MLXArray {
        let _mask = mask ?? MLXArray.ones(Array(inputs.hiddenStates?.shape[0 ..< 2] ?? [0]))

        var pooled: MLXArray
        switch self.strategy {
        case .mean:
            // masked sum over tokens divided by the number of unmasked tokens
            pooled =
                sum(
                    inputs.hiddenStates! * _mask.expandedDimensions(axes: [-1]),
                    axis: 1)
                / sum(_mask, axis: -1, keepDims: true)
        case .max:
            pooled = MLX.max(
                inputs.hiddenStates! * _mask.expandedDimensions(axes: [-1]), axis: 1)
        case .first:
            pooled = inputs.hiddenStates![0..., 0, 0...]
        case .last:
            pooled = inputs.hiddenStates![0..., -1, 0...]
        case .cls:
            // prefer the model's own pooled output; fall back to the CLS (first) token
            pooled =
                inputs.pooledOutput
                ?? inputs.hiddenStates![0..., 0, 0...]
        case .none:
            pooled = inputs.pooledOutput ?? inputs.hiddenStates!
        }
        if applyLayerNorm {
            pooled = MLXFast.layerNorm(pooled, eps: 1e-5)
        }
        if let dimension {
            // truncate the embedding to the leading components
            pooled = pooled[0..., 0 ..< dimension]
        }
        if normalize {
            pooled = pooled / norm(pooled, axis: -1, keepDims: true)
        }
        return pooled
    }
}
113 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | Developers can use these examples in their own programs -- just import the swift package!
4 |
5 | - [MLXLLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxlmcommon) -- common API for LLM and VLM
6 | - [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxllm) -- large language model example implementations
7 | - [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxvlm) -- visual language model example implementations
8 | - [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxembedders) -- popular Encoders / Embedding models example implementations
- [StableDiffusion](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/stablediffusion) -- SDXL Turbo and Stable Diffusion model example implementations
10 | - [MLXMNIST](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxmnist) -- MNIST implementation for all your digit recognition needs
11 |
12 | # MLX Swift Examples
13 |
14 | Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs.
15 |
16 | - [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
17 | both iOS and macOS that downloads MNIST training data and trains a
18 | [LeNet](https://en.wikipedia.org/wiki/LeNet).
19 |
20 | - [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
21 | and macOS that downloads an LLM and tokenizer from Hugging Face and
22 | generates text from a given prompt.
23 |
24 | - [VLMEval](Applications/VLMEval/README.md): An example that runs on iOS, macOS and visionOS to download a VLM and tokenizer from Hugging Face and
25 | analyzes the given image and describe it in text.
26 |
27 | - [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
28 | trains a simple linear model.
29 |
30 | - [StableDiffusionExample](Applications/StableDiffusionExample/README.md): An
31 | example that runs on both iOS and macOS that downloads a stable diffusion model
  from Hugging Face and generates an image from a given prompt.
33 |
34 | - [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
35 | using a variety of LLMs available on the Hugging Face hub.
36 |
37 | - [image-tool](Tools/image-tool/README.md): A command line tool for generating images
38 | using a stable diffusion model from Hugging Face.
39 |
- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training
  a LeNet on MNIST.
42 |
43 | ## Running
44 |
45 | The application and command line tool examples can be run from Xcode or from
46 | the command line:
47 |
48 | ```
49 | ./mlx-run llm-tool --prompt "swift programming language"
50 | ```
51 |
Note: `mlx-run` is a shell script that uses the Xcode command line tools to
53 | locate the built binaries. It is equivalent to running from Xcode itself.
54 |
55 | See also:
56 |
57 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting)
58 |
59 | ## Installation of libraries
60 |
61 | The MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST, MLXEmbedders, and StableDiffusion libraries in the example repo are available
62 | as Swift Packages.
63 |
64 |
65 | Add the following dependency to your Package.swift
66 |
67 | ```swift
68 | .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main"),
69 | ```
70 |
71 | Then add one or more libraries to the target as a dependency:
72 |
73 | ```swift
74 | .target(
75 | name: "YourTargetName",
76 | dependencies: [
77 | .product(name: "MLXLLM", package: "mlx-swift-examples")
78 | ]),
79 | ```
80 |
81 | Alternatively, add `https://github.com/ml-explore/mlx-swift-examples/` to the `Project Dependencies` and set the `Dependency Rule` to `Branch` and `main` in Xcode.
82 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/Documentation.docc/using-model.md:
--------------------------------------------------------------------------------
1 | # Using a Model
2 |
3 | Using a model is easy: load the weights, tokenize and evaluate.
4 |
5 | ## Loading a Model
6 |
7 | A model is typically loaded by using a `ModelFactory` and a `ModelConfiguration`:
8 |
9 | ```swift
10 | // e.g. LLMModelFactory.shared
11 | let modelFactory: ModelFactory
12 |
13 | // e.g. MLXLLM.ModelRegistry.llama3_8B_4bit
14 | let modelConfiguration: ModelConfiguration
15 |
16 | let container = try await modelFactory.loadContainer(configuration: modelConfiguration)
17 | ```
18 |
19 | The `container` provides an isolation context (an `actor`) to run inference in the model.
20 |
21 | Predefined `ModelConfiguration` instances are provided as static variables
22 | on the `ModelRegistry` types or they can be created:
23 |
24 | ```swift
25 | let modelConfiguration = ModelConfiguration(id: "mlx-community/llama3_8B_4bit")
26 | ```
27 |
28 | The flow inside the `ModelFactory` goes like this:
29 |
30 | ```swift
31 | public class LLMModelFactory: ModelFactory {
32 |
33 | public func _load(
34 | hub: HubApi, configuration: ModelConfiguration,
35 | progressHandler: @Sendable @escaping (Progress) -> Void
36 | ) async throws -> ModelContext {
37 | // download the weight and config using HubApi
38 | // load the base configuration
39 | // using the typeRegistry create a model (random weights)
40 | // load the weights, apply quantization as needed, update the model
41 | // calls model.sanitize() for weight preparation
42 | // load the tokenizer
43 | // (vlm) load the processor configuration, create the processor
44 | }
45 | }
46 | ```
47 |
48 | Callers with specialized requirements can use these individual components to manually
49 | load models, if needed.
50 |
51 | ## Evaluation Flow
52 |
53 | - Load the Model
54 | - UserInput
55 | - LMInput
56 | - generate()
57 | - NaiveStreamingDetokenizer
58 | - TokenIterator
59 |
60 | ## Evaluating a Model
61 |
62 | Once a model is loaded you can evaluate a prompt or series of
63 | messages. Minimally you need to prepare the user input:
64 |
65 | ```swift
66 | let prompt = "Describe the image in English"
67 | var input = UserInput(prompt: prompt, images: image.map { .url($0) })
68 | input.processing.resize = .init(width: 256, height: 256)
69 | ```
70 |
71 | This example shows adding some images and processing instructions -- if
the model accepts text only then these parts can be omitted. The inference
73 | calls are the same.
74 |
75 | Assuming you are using a `ModelContainer` (an actor that holds
76 | a `ModelContext`, which is the bundled set of types that implement a
77 | model), the first step is to convert the `UserInput` into the
78 | `LMInput` (LanguageModel Input):
79 |
80 | ```swift
81 | let generateParameters: GenerateParameters
82 | let input: UserInput
83 |
84 | let result = try await modelContainer.perform { [input] context in
85 | let input = try context.processor.prepare(input: input)
86 |
87 | ```
88 |
89 | Given that `input` we can call `generate()` to produce a stream
90 | of tokens. In this example we use a `NaiveStreamingDetokenizer`
91 | to assist in converting a stream of tokens into text and print it.
92 | The stream is stopped after we hit a maximum number of tokens:
93 |
94 | ```
95 | var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
96 |
97 | return try MLXLMCommon.generate(
98 | input: input, parameters: generateParameters, context: context
99 | ) { tokens in
100 |
101 | if let last = tokens.last {
102 | detokenizer.append(token: last)
103 | }
104 |
105 | if let new = detokenizer.next() {
106 | print(new, terminator: "")
107 | fflush(stdout)
108 | }
109 |
110 | if tokens.count >= maxTokens {
111 | return .stop
112 | } else {
113 | return .more
114 | }
115 | }
116 | }
117 | ```
118 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/PredictionView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // PredictionView.swift
3 | // MNISTTrainer
4 | //
5 | // Created by Rounak Jain on 3/9/24.
6 | //
7 |
8 | import MLX
9 | import MLXMNIST
10 | import MLXNN
11 | import SwiftUI
12 |
/// A simple freehand drawing surface: white strokes on a black background,
/// driven by a drag gesture.
struct Canvas: View {

    // The path accumulating the user's strokes; owned by the parent view.
    @Binding var path: Path
    // Last drag sample, used to connect consecutive points into a stroke;
    // nil when no stroke is in progress.
    @State var lastPoint: CGPoint?

    var body: some View {
        path
            .stroke(.white, lineWidth: 10)
            .background(.black)
            .gesture(
                DragGesture(minimumDistance: 0.05)
                    .onChanged { touch in
                        add(point: touch.location)
                    }
                    .onEnded { touch in
                        // end the stroke so the next drag starts a new segment
                        lastPoint = nil
                    }
            )
    }

    /// Extend the current stroke to `point`, starting a new segment when no
    /// stroke is in progress.
    func add(point: CGPoint) {
        // copy-modify-assign so the @Binding publishes a single value change
        var newPath = path
        if let lastPoint {
            newPath.move(to: lastPoint)
            newPath.addLine(to: point)
        } else {
            newPath.move(to: point)
        }
        self.path = newPath
        lastPoint = point
    }
}
45 |
extension Path {
    /// Translate the path in place so that the center of its bounding
    /// rectangle coincides with `newMidPoint`.
    mutating func center(to newMidPoint: CGPoint) {
        let bounds = boundingRect
        let dx = newMidPoint.x - bounds.midX
        let dy = newMidPoint.y - bounds.midY
        self = offsetBy(dx: dx, dy: dy)
    }
}
53 |
/// Lets the user draw a digit on a small canvas and runs the trained LeNet
/// on the rasterized drawing to predict which digit it is.
struct PredictionView: View {
    // The digit being drawn; cleared by the "Clear" button.
    @State var path: Path = Path()
    // Most recent model prediction, or nil before the first "Predict".
    @State var prediction: Int?
    let model: LeNetContainer
    // Side length of the square drawing canvas, in points.
    let canvasSize = 150.0
    // MNIST inputs are 28x28 grayscale images.
    let mnistImageSize: CGSize = CGSize(width: 28, height: 28)

    var body: some View {
        VStack {
            if let prediction {
                Text("You've drawn a \(prediction)")
            } else {
                Text("Draw a digit")
            }
            Canvas(path: $path)
                .frame(width: canvasSize, height: canvasSize)
            HStack {
                Button("Predict") {
                    // recentering makes the drawing position-invariant
                    path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2))
                    predict()
                }
                Button("Clear") {
                    path = Path()
                    prediction = nil
                }
            }
        }
    }

    /// Rasterize the canvas to a `CGImage` and ask the model for a prediction.
    @MainActor
    func predict() {
        let imageRenderer = ImageRenderer(
            content: Canvas(path: $path).frame(width: 150, height: 150))

        if let image = imageRenderer.cgImage {
            Task {
                self.prediction = await model.evaluate(image: image)
            }
        }
    }
}
95 |
extension CGImage {
    /// Redraw the image into an 8-bit, alpha-free grayscale bitmap of
    /// `newSize` and return the result, or nil if the bitmap context
    /// cannot be created.
    func grayscaleImage(with newSize: CGSize) -> CGImage? {
        let colorSpace = CGColorSpaceCreateDeviceGray()
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)

        guard
            let context = CGContext(
                data: nil,
                width: Int(newSize.width),
                height: Int(newSize.height),
                bitsPerComponent: 8,
                bytesPerRow: Int(newSize.width),
                space: colorSpace,
                bitmapInfo: bitmapInfo.rawValue)
        else {
            return nil
        }
        // Bug fix: the destination rect previously used `newSize.width` for
        // its height, distorting non-square targets.
        context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.height))
        return context.makeImage()
    }

    /// Copy the image's backing bytes into a flat `MLXArray` of UInt8.
    /// Returns an empty array when the data provider is unavailable.
    func pixelData() -> MLXArray {
        guard let data = self.dataProvider?.data else {
            return []
        }
        let bytePtr = CFDataGetBytePtr(data)
        let count = CFDataGetLength(data)
        return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count))
    }
}
126 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/README.md:
--------------------------------------------------------------------------------
1 | # MLXLMCommon
2 |
3 | MLXLMCommon contains types and code that is generic across many types
4 | of language models, from LLMs to VLMs:
5 |
6 | - Evaluation
7 | - KVCache
8 | - Loading
9 | - UserInput
10 |
11 | ## Loading a Model
12 |
13 | A model is typically loaded by using a `ModelFactory` and a `ModelConfiguration`:
14 |
15 | ```swift
16 | // e.g. VLMModelFactory.shared
17 | let modelFactory: ModelFactory
18 |
19 | // e.g. MLXVLM.ModelRegistry.paligemma3bMix4488bit
20 | let modelConfiguration: ModelConfiguration
21 |
22 | let container = try await modelFactory.loadContainer(configuration: modelConfiguration)
23 | ```
24 |
25 | The `container` provides an isolation context (an `actor`) to run inference in the model.
26 |
27 | Predefined `ModelConfiguration` instances are provided as static variables
28 | on the `ModelRegistry` types or they can be created:
29 |
30 | ```swift
31 | let modelConfiguration = ModelConfiguration(id: "mlx-community/paligemma-3b-mix-448-8bit")
32 | ```
33 |
34 | The flow inside the `ModelFactory` goes like this:
35 |
36 | ```swift
37 | public class VLMModelFactory: ModelFactory {
38 |
39 | public func _load(
40 | hub: HubApi, configuration: ModelConfiguration,
41 | progressHandler: @Sendable @escaping (Progress) -> Void
42 | ) async throws -> ModelContext {
43 | // download the weight and config using HubApi
44 | // load the base configuration
45 | // using the typeRegistry create a model (random weights)
46 | // load the weights, apply quantization as needed, update the model
47 | // calls model.sanitize() for weight preparation
48 | // load the tokenizer
49 | // (vlm) load the processor configuration, create the processor
50 | }
51 | }
52 | ```
53 |
54 | Callers with specialized requirements can use these individual components to manually
55 | load models, if needed.
56 |
57 | ## Evaluation Flow
58 |
59 | - Load the Model
60 | - UserInput
61 | - LMInput
62 | - generate()
63 | - NaiveStreamingDetokenizer
64 | - TokenIterator
65 |
66 | ## Using a Model
67 |
68 | Once a model is loaded you can evaluate a prompt or series of
69 | messages. Minimally you need to prepare the user input:
70 |
71 | ```swift
72 | let prompt = "Describe the image in English"
73 | var input = UserInput(prompt: prompt, images: image.map { .url($0) })
74 | input.processing.resize = .init(width: 256, height: 256)
75 | ```
76 |
77 | This example shows adding some images and processing instructions -- if
the model accepts text only then these parts can be omitted. The inference
79 | calls are the same.
80 |
81 | Assuming you are using a `ModelContainer` (an actor that holds
82 | a `ModelContext`, which is the bundled set of types that implement a
83 | model), the first step is to convert the `UserInput` into the
84 | `LMInput` (LanguageModel Input):
85 |
86 | ```swift
87 | let generateParameters: GenerateParameters
88 | let input: UserInput
89 |
90 | let result = try await modelContainer.perform { [input] context in
91 | let input = try context.processor.prepare(input: input)
92 |
93 | ```
94 |
95 | Given that `input` we can call `generate()` to produce a stream
96 | of tokens. In this example we use a `NaiveStreamingDetokenizer`
97 | to assist in converting a stream of tokens into text and print it.
98 | The stream is stopped after we hit a maximum number of tokens:
99 |
100 | ```
101 | var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
102 |
103 | return try MLXLMCommon.generate(
104 | input: input, parameters: generateParameters, context: context
105 | ) { tokens in
106 |
107 | if let last = tokens.last {
108 | detokenizer.append(token: last)
109 | }
110 |
111 | if let new = detokenizer.next() {
112 | print(new, terminator: "")
113 | fflush(stdout)
114 | }
115 |
116 | if tokens.count >= maxTokens {
117 | return .stop
118 | } else {
119 | return .more
120 | }
121 | }
122 | }
123 | ```
124 |
125 |
--------------------------------------------------------------------------------
/Libraries/Embedders/EmbeddingModel.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | @preconcurrency import Hub
5 | import MLX
6 | import MLXNN
7 | import Tokenizers
8 |
9 | /// Container for models that guarantees single threaded access.
10 | ///
11 | /// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
12 | /// the model and/or tokenizer:
13 | ///
14 | /// ```swift
15 | /// let promptTokens = await modelContainer.perform { _, tokenizer in
16 | /// tokenizer.encode(text: prompt)
17 | /// }
18 | /// ```
19 | ///
20 | /// or:
21 | ///
22 | /// ```swift
23 | /// let result = await modelContainer.perform { model, tokenizer in
24 | /// LLM.generate(
25 | /// promptTokens: promptTokens, parameters: generateParameters, model: model,
26 | /// tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
27 | /// ) { tokens in
28 | /// ...
29 | /// }
30 | /// }
31 | /// ```
public actor ModelContainer {
    let model: EmbeddingModel
    let tokenizer: Tokenizer
    let pooler: Pooling

    /// Create a container from already-constructed components.
    public init(
        model: EmbeddingModel, tokenizer: Tokenizer, pooler: Pooling = Pooling(strategy: .none)
    ) {
        self.model = model
        self.tokenizer = tokenizer
        self.pooler = pooler
    }

    /// build the model and tokenizer without passing non-sendable data over isolation barriers
    public init(
        hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
    ) async throws {
        self.model = try loadSynchronous(modelDirectory: modelDirectory)

        let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
            configuration: configuration, hub: hub)
        self.tokenizer = try PreTrainedTokenizer(
            tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
        // loadPooling already falls back to Pooling(strategy: .none) when the
        // model directory has no pooling configuration
        self.pooler = loadPooling(modelDirectory: modelDirectory)
    }

    /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    ///
    /// Bug fix: the generic parameter `<R>` was missing from the signature,
    /// leaving `R` undeclared and the method uncompilable.
    public func perform<R>(_ action: @Sendable (EmbeddingModel, Tokenizer, Pooling) throws -> R)
        rethrows
        -> R
    {
        try action(model, tokenizer, pooler)
    }
}
67 |
extension Module {

    /// Compute the number of parameters in a possibly quantized model.
    ///
    /// Quantized layers store packed weights, so their element count is
    /// recovered as `scales.size * groupSize`; every other leaf module
    /// contributes the summed sizes of its parameter arrays.
    public func numParameters() -> Int {
        var total = 0
        for module in leafModules().flattenedValues() {
            switch module {
            case let quantized as QuantizedLinear:
                total += quantized.scales.size * quantized.groupSize
            case let quantized as QuantizedEmbedding:
                total += quantized.scales.size * quantized.groupSize
            default:
                for parameter in module.parameters().flattenedValues() {
                    total += parameter.size
                }
            }
        }
        return total
    }
}
88 |
/// Result of running an ``EmbeddingModel``.
public struct EmbeddingModelOutput {
    /// Per-token hidden states from the encoder, if produced.
    /// NOTE(review): presumably (batch, sequence, hidden) -- confirm in model implementations.
    let hiddenStates: MLXArray?
    /// Pooled embedding produced by the model itself (consumed by the
    /// `.cls` / `.none` pooling strategies), if any.
    let pooledOutput: MLXArray?
}
93 |
/// A model that maps input sequences to embeddings.
public protocol EmbeddingModel: Module {
    /// Size of the vocabulary the model was trained with.
    var vocabularySize: Int { get }
    /// Run the encoder over `inputs` (typically token ids) with optional
    /// position ids, token type ids, and attention mask.
    func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
        attentionMask: MLXArray?
    ) -> EmbeddingModelOutput
    /// Optionally preprocess the weights and modify / remove values as needed.
    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
}
103 |
extension EmbeddingModel {
    /// Convenience overload supplying nil defaults for the optional
    /// position / token-type / attention-mask inputs.
    func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray? = nil, tokenTypeIds: MLXArray? = nil,
        attentionMask: MLXArray? = nil
    ) -> EmbeddingModelOutput {
        return callAsFunction(
            inputs, positionIds: positionIds, tokenTypeIds: tokenTypeIds,
            attentionMask: attentionMask)
    }
}
114 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Sampler.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 | import MLXRandom
6 |
7 | // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/sampler.py
8 |
/// Linearly interpolate the table `y` (indexed by `0 ..< y.count`) at the
/// possibly-fractional positions `xNew`.
func interpolate(y: MLXArray, xNew: MLXArray) -> MLXArray {
    let lowerIndex = xNew.asType(.int32)
    // clamp the upper neighbor so positions at the end of the table are safe
    let upperIndex = minimum(lowerIndex + 1, y.count - 1)

    let lowerValue = y[lowerIndex]
    let upperValue = y[upperIndex]
    let fraction = xNew - lowerIndex

    return lowerValue * (1 - fraction) + fraction * upperValue
}
21 |
22 | /// A simple Euler integrator that can be used to sample from our diffusion models.
23 | ///
24 | /// The method ``step()`` performs one Euler step from `x_t` to `x_t_prev`.
class SimpleEulerSampler {

    // Per-timestep noise levels; index 0 is sigma = 0 (clean data).
    let sigmas: MLXArray

    public init(configuration: DiffusionConfiguration) {
        let betas: MLXArray

        // compute the noise schedule
        switch configuration.betaSchedule {
        case .linear:
            betas = MLXArray.linspace(
                configuration.betaStart, configuration.betaEnd, count: configuration.trainSteps)
        case .scaledLinear:
            // linear in sqrt(beta) space, then squared
            betas = MLXArray.linspace(
                sqrt(configuration.betaStart), sqrt(configuration.betaEnd),
                count: configuration.trainSteps
            ).square()
        }

        let alphas = 1 - betas
        let alphasCumprod = cumprod(alphas)

        // sigma_t = sqrt((1 - prod(alpha)) / prod(alpha)), with a leading 0
        // so that t = 0 corresponds to no noise
        self.sigmas = concatenated([
            MLXArray.zeros([1]), ((1 - alphasCumprod) / alphasCumprod).sqrt(),
        ])
    }

    /// Largest valid (integer) time index.
    public var maxTime: Int {
        sigmas.count - 1
    }

    /// Sample pure noise at the highest noise level, scaled to unit marginal
    /// variance.
    public func samplePrior(shape: [Int], dType: DType = .float32, key: MLXArray? = nil) -> MLXArray
    {
        let noise = MLXRandom.normal(shape, key: key)
        return (noise * sigmas[-1] * (sigmas[-1].square() + 1).rsqrt()).asType(dType)
    }

    /// Add noise corresponding to (possibly fractional) time `t` to `x`.
    public func addNoise(x: MLXArray, t: MLXArray, key: MLXArray? = nil) -> MLXArray {
        let noise = MLXRandom.normal(x.shape, key: key)
        let s = sigmas(t)
        return (x + noise * s) * (s.square() + 1).rsqrt()
    }

    /// Noise level at (possibly fractional) time `t`, via linear interpolation
    /// of the precomputed table.
    public func sigmas(_ t: MLXArray) -> MLXArray {
        interpolate(y: sigmas, xNew: t)
    }

    /// Produce `steps` (t, tPrev) pairs walking linearly from `start`
    /// (default: the highest time) down to 0.
    public func timeSteps(steps: Int, start: Int? = nil, dType: DType = .float32) -> [(
        MLXArray, MLXArray
    )] {
        let start = start ?? (sigmas.count - 1)
        precondition(0 < start)
        precondition(start <= sigmas.count - 1)
        let steps = MLX.linspace(start, 0, count: steps + 1).asType(dType)

        return Array(zip(steps, steps[1...]))
    }

    /// One Euler step from `xt` at time `t` to time `tPrev`, given the model's
    /// noise prediction `epsPred`. The input is un-scaled, stepped along the
    /// predicted direction, then re-scaled to the new noise level.
    open func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray) -> MLXArray {
        let dtype = epsPred.dtype
        let sigma = sigmas(t).asType(dtype)
        let sigmaPrev = sigmas(tPrev).asType(dtype)

        let dt = sigmaPrev - sigma
        var xtPrev = (sigma.square() + 1).sqrt() * xt + epsPred * dt
        xtPrev = xtPrev * (sigmaPrev.square() + 1).rsqrt()

        return xtPrev
    }
}
95 |
/// Euler-ancestral variant: each step combines a deterministic Euler move to
/// a reduced noise level with freshly sampled noise, so the marginal variance
/// still matches `tPrev`.
class SimpleEulerAncestralSampler: SimpleEulerSampler {

    open override func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray)
        -> MLXArray
    {
        let dtype = epsPred.dtype
        let sigma = sigmas(t).asType(dtype)
        let sigmaPrev = sigmas(tPrev).asType(dtype)

        // split sigmaPrev into a deterministic part (sigmaDown) and a
        // stochastic part (sigmaUp) with sigmaDown^2 + sigmaUp^2 = sigmaPrev^2
        let sigma2 = sigma.square()
        let sigmaPrev2 = sigmaPrev.square()
        let sigmaUp = (sigmaPrev2 * (sigma2 - sigmaPrev2) / sigma2).sqrt()
        let sigmaDown = (sigmaPrev2 - sigmaUp ** 2).sqrt()

        // deterministic Euler move toward sigmaDown...
        let dt = sigmaDown - sigma
        var xtPrev = (sigma2 + 1).sqrt() * xt + epsPred * dt
        // ...then inject fresh noise scaled by sigmaUp and rescale
        let noise = MLXRandom.normal(xtPrev.shape).asType(xtPrev.dtype)
        xtPrev = xtPrev + noise * sigmaUp
        xtPrev = xtPrev * (sigmaPrev2 + 1).rsqrt()

        return xtPrev
    }
}
119 |
--------------------------------------------------------------------------------
/Data/lora/wikisql.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | """
4 | Code to preprocess the WikiSQL dataset adapted from
5 | https://github.com/salesforce/WikiSQL and
6 | https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
7 | """
8 |
9 |
10 | import json
11 | import os
12 |
13 |
def load():
    """Lazily load the train/dev/test splits of the WikiSQL dataset.

    Returns a generator yielding one ``WikiSQL`` instance per split,
    in the order train, dev, test.
    """
    splits = ("train", "dev", "test")
    return (WikiSQL(split) for split in splits)
19 |
20 |
class WikiSQL:
    """One split (train/dev/test) of the WikiSQL text-to-SQL dataset.

    Downloads and unpacks the raw data on first use, then exposes every
    example as a single prompt string combining the table description,
    the natural-language question, and the rendered SQL answer.
    """

    def __init__(self, dataset, save_dir="/tmp"):
        valid_sets = ("train", "dev", "test")
        if dataset not in valid_sets:
            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
        data_dir = os.path.join(save_dir, "wikisql")
        self._maybe_download(data_dir)

        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))

    def _maybe_download(self, data_dir):
        # Fetch and unpack the upstream archive only when it is not present.
        if os.path.exists(data_dir):
            return
        import io
        import tarfile
        from urllib import request

        url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
        response = request.urlopen(url)
        archive = io.BytesIO(response.read())
        # NOTE(review): extractall trusts member paths in the downloaded
        # archive; acceptable for this fixed upstream URL, but worth noting.
        with tarfile.open(fileobj=archive) as tf:
            tf.extractall(data_dir)

    def _parse_tables(self, tables):
        # Map table id -> column names, column types, and a text description.
        self._tables = {}
        with open(tables) as f:
            for raw_line in f:
                record = json.loads(raw_line)
                self._tables[record["id"]] = {
                    "columns": record["header"],
                    "types": record["types"],
                    "desc": f"table: {record['id']}\ncolumns: {', '.join(record['header'])}",
                }

    def _parse_queries(self, queries):
        # Render each example as "<table desc>\nQ: <question>\nA: <sql>".
        self._queries = []
        with open(queries) as f:
            for raw_line in f:
                record = json.loads(raw_line)
                table = self._tables[record["table_id"]]
                sql = self.query_to_text(
                    record["sql"], record["table_id"], table["columns"], table["types"]
                )
                self._queries.append(
                    f"{table['desc']}\nQ: {record['question']}\nA: {sql}"
                )

    def query_to_text(self, query, table, columns, types):
        """Render a structured WikiSQL query dict as a SQL string."""
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]

        selected = columns[query["sel"]]
        prefix = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
        sql = f"SELECT {prefix}{selected} FROM {table}"

        if query["conds"]:
            clauses = []
            for col_idx, op_idx, raw_value in query["conds"]:
                # text-typed values get quoted; numeric values are rendered as-is
                value = f"'{raw_value}'" if types[col_idx] == "text" else raw_value
                clauses.append(f"{columns[col_idx]} {condition_ops[op_idx]} {value}")
            sql += " WHERE " + " AND ".join(clauses)

        return sql

    def __getitem__(self, idx):
        return self._queries[idx]

    def __len__(self):
        return len(self._queries)
97 |
98 |
if __name__ == "__main__":
    # Sanity-check the split sizes against the known WikiSQL counts.
    # Bug fix: the comparison below was a bare expression (a no-op tuple);
    # it needs `assert` to actually check anything.
    datanames = ["train", "dev", "test"]
    sizes = [56355, 8421, 15878]
    for dataname, size in zip(datanames, sizes):
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."

    # Write the sets to jsonl (json is already imported at module level;
    # the redundant local import was removed)
    train, dev, test = load()
    datasets = [
        (train, "train", 1000),
        (dev, "valid", 100),
        (test, "test", 100),
    ]
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for _, t in zip(range(size), dataset):
                # Strip the first 3 / last 4 characters since the tokenizer
                # adds them. NOTE(review): the original comment was garbled;
                # this slice assumes fixed-width surrounding tokens -- confirm.
                json.dump({"text": t[3:-4]}, fid)
                fid.write("\n")
120 |
--------------------------------------------------------------------------------
/Libraries/Embedders/Load.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | @preconcurrency import Hub
5 | import MLX
6 | import MLXNN
7 | import MLXRandom
8 | import Tokenizers
9 |
/// Simple error type carrying a human-readable message for embedder
/// loading failures.
struct EmbedderError: Error {
    let message: String
}
13 |
/// Resolve the local directory holding the model files for `configuration`,
/// downloading from the Hub when needed.
///
/// A `.directory` configuration is returned as-is. For a `.id` configuration
/// the safetensors weights and `config.json` are fetched via `hub.snapshot`.
/// If the repo requires authorization (typically: it doesn't exist on the
/// server) or the network is offline, fall back to the local model directory.
func prepareModelDirectory(
    hub: HubApi, configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void
) async throws -> URL {
    do {
        switch configuration.id {
        case .directory(let directory):
            // local model -- nothing to download
            return directory

        case .id(let id):
            // download the model weights and configuration
            return try await hub.snapshot(
                from: Hub.Repo(id: id),
                matching: ["*.safetensors", "config.json"],
                progressHandler: progressHandler)
        }
    } catch Hub.HubClientError.authorizationRequired {
        // an authorizationRequired means (typically) that the named repo
        // doesn't exist on the server so retry with local only configuration
        return configuration.modelDirectory(hub: hub)
    } catch let error as NSError
        where error.domain == NSURLErrorDomain
            && error.code == NSURLErrorNotConnectedToInternet
    {
        // Error Domain=NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline."
        // fall back to the local directory; any other error propagates
        return configuration.modelDirectory(hub: hub)
    }
}
45 |
/// Load and return the model and tokenizer for `configuration`.
///
/// Downloads (or locates) the model directory, instantiates the model with
/// its weights, and loads the matching tokenizer.
public func load(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> (EmbeddingModel, Tokenizer) {
    let directory = try await prepareModelDirectory(
        hub: hub, configuration: configuration, progressHandler: progressHandler)

    let model = try loadSynchronous(modelDirectory: directory)
    let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
    return (model, tokenizer)
}
58 |
/// Create an `EmbeddingModel` from the files in `modelDirectory`.
///
/// Decodes `config.json` into a ``BaseConfiguration``, instantiates the model
/// for its `model_type`, loads every `*.safetensors` file in the directory,
/// applies per-model weight sanitization, quantizes layers when the
/// configuration requests it, then applies and evaluates the weights.
///
/// - Throws: decoding errors from `config.json`, weight-loading errors, or
///   a verification error if any model parameter has no matching weight.
func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
    // create the model (no weights loaded)
    let configurationURL = modelDirectory.appending(component: "config.json")
    let baseConfig = try JSONDecoder().decode(
        BaseConfiguration.self, from: Data(contentsOf: configurationURL))

    let model = try baseConfig.modelType.createModel(configuration: configurationURL)

    // load the weights from every safetensors file in the directory,
    // merging them into one flat [name: array] dictionary
    var weights = [String: MLXArray]()
    let enumerator = FileManager.default.enumerator(
        at: modelDirectory, includingPropertiesForKeys: nil)!
    for case let url as URL in enumerator {
        if url.pathExtension == "safetensors" {
            let w = try loadArrays(url: url)
            for (key, value) in w {
                weights[key] = value
            }
        }
    }

    // per-model cleanup
    weights = model.sanitize(weights: weights)

    // quantize if needed -- only layers that have ".scales" entries in the
    // checkpoint were quantized when the weights were exported
    if let quantization = baseConfig.quantization {
        quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
            path, module in
            weights["\(path).scales"] != nil
        }
    }

    // apply the loaded weights; [.all] verifies every parameter is provided
    let parameters = ModuleParameters.unflattened(weights)
    try model.update(parameters: parameters, verify: [.all])

    // force evaluation so load errors surface here, not lazily later
    eval(model)

    return model
}
99 |
/// Load and return the model and tokenizer wrapped in a ``ModelContainer``
/// (provides thread safe access).
public func loadModelContainer(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> ModelContainer {
    let directory = try await prepareModelDirectory(
        hub: hub, configuration: configuration, progressHandler: progressHandler)

    return try await ModelContainer(
        hub: hub, modelDirectory: directory, configuration: configuration)
}
111 |
--------------------------------------------------------------------------------
/Tools/LinearModelTraining/LinearModelTraining.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLX
6 | import MLXNN
7 | import MLXOptimizers
8 | import MLXRandom
9 |
// Allow `--device cpu|gpu` on the command line by conforming MLX's DeviceType
// to ArgumentParser's ExpressibleByArgument. Swift 5.10+ requires the
// @retroactive attribute for a conformance declared outside the defining
// module; earlier compilers reject that attribute, hence the #if.
#if swift(>=5.10)
    extension MLX.DeviceType: @retroactive ExpressibleByArgument {
        public init?(argument: String) {
            self.init(rawValue: argument)
        }
    }
#else
    extension MLX.DeviceType: ExpressibleByArgument {
        public init?(argument: String) {
            self.init(rawValue: argument)
        }
    }
#endif
23 |
/// Command line tool that trains a tiny linear model (y = m * x + b) with SGD
/// against a known target function, printing parameters and loss each epoch.
@main
struct Train: AsyncParsableCommand {

    // training hyper-parameters
    @Option var epochs = 20
    @Option var batchSize = 8

    // the target slope and intercept the model should converge toward
    @Option var m: Float = 0.25
    @Option var b: Float = 7

    // if set, compile the training step with MLX.compile
    @Flag var compile = false

    @Option var device = DeviceType.cpu

    func run() async throws {
        Device.setDefault(device: Device(device))

        // A very simple model that implements the equation
        // for a linear function: y = mx + b. This can be trained
        // to match data -- in this case an unknown (to the model)
        // linear function.
        //
        // This is a nice example because most people know how
        // linear functions work and we can see how the slope
        // and intercept converge.
        class LinearFunctionModel: Module, UnaryLayer {
            // trainable parameters, randomly initialized in [-5, 5)
            let m = MLXRandom.uniform(low: -5.0, high: 5.0)
            let b = MLXRandom.uniform(low: -5.0, high: 5.0)

            func callAsFunction(_ x: MLXArray) -> MLXArray {
                m * x + b
            }
        }

        // measure the distance from the prediction (model(x)) and the
        // ground truth (y). this gives feedback on how close the
        // prediction is from matching the truth
        func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray {
            mseLoss(predictions: model(x), targets: y, reduction: .mean)
        }

        let model = LinearFunctionModel()
        eval(model.parameters())

        // returns both the loss and the gradients w.r.t. the model parameters
        let lg = valueAndGrad(model: model, loss)

        // the optimizer will use the gradients update the model parameters
        let optimizer = SGD(learningRate: 1e-1)

        // the function to train our model against -- it doesn't have
        // to be linear, but matching what the model models is easy
        // to understand
        func f(_ x: MLXArray) -> MLXArray {
            // these are the target parameters
            let m = self.m
            let b = self.b

            // our actual function
            return m * x + b
        }

        // one optimization step: forward + backward + parameter update
        func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
            let (loss, grads) = lg(model, x, y)
            optimizer.update(model: model, gradients: grads)
            return loss
        }

        // optionally compile the step; model and optimizer state are both
        // inputs and outputs because the step mutates them
        let resolvedStep =
            self.compile
            ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step

        for _ in 0 ..< epochs {
            // we expect that the parameters will approach the targets
            print("target: b = \(b), m = \(m)")
            print("parameters: \(model.parameters())")

            // generate random training data along with the ground truth.
            // notice that the shape is [B, 1] where B is the batch
            // dimension -- this allows us to train on several samples simultaneously
            //
            // note: a very large batch size will take longer to converge because
            // the gradient will be representing too many samples down into
            // a single float parameter.
            let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1])
            let y = f(x)
            eval(x, y)

            // compute the loss and gradients. use the optimizer
            // to adjust the parameters closer to the target
            let loss = resolvedStep(x, y)

            eval(model, optimizer)

            // we should see this converge toward 0
            print("loss: \(loss)")
        }

    }
}
122 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Tokenizer.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
/// An ordered pair of BPE symbols, hashable so it can key the merge-rank
/// table in ``CLIPTokenizer``.
struct Bigram: Hashable {
    let a: String
    let b: String

    /// Parse from a merges-file line of the form "token1 token2".
    init(_ s: String) {
        let pieces = s.split(separator: " ")
        precondition(pieces.count == 2, "BPEPair expected two pieces for '\(s)'")
        a = String(pieces[0])
        b = String(pieces[1])
    }

    init(_ a: String, _ b: String) {
        self.a = a
        self.b = b
    }

    /// Convenience for building from a zipped `(String, String)` tuple.
    init(_ v: (String, String)) {
        self.init(v.0, v.1)
    }
}
26 |
/// A CLIP tokenizer.
///
/// Ported from:
///
/// - https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
/// - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
///
/// Ideally this would be a tokenizer from `swift-transformers` but this is too special purpose to be representable in
/// what exists there (at time of writing).
class CLIPTokenizer {

    /// Splits text into pieces: the BOS/EOS markers, common English
    /// contractions, letter runs, single digits, and punctuation runs.
    let pattern =
        #/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/#

    /// Merge priority per bigram -- lower rank merges first.
    let bpeRanks: [Bigram: Int]
    /// Maps BPE tokens to integer ids.
    let vocabulary: [String: Int]

    let bos = "<|startoftext|>"
    let eos = "<|endoftext|>"

    let bosToken: Int
    let eosToken: Int

    /// Memoizes bpe(text:) results; pre-seeded so the special tokens are
    /// never split.
    var cache = [String: [String]]()

    /// - Parameters:
    ///   - merges: lines of the BPE merges file, in priority order
    ///   - vocabulary: token -> id map; must contain the BOS/EOS tokens
    ///     (force-unwrapped below)
    init(merges: [String], vocabulary: [String: Int]) {
        // rank each merge rule by its position in the merges file
        self.bpeRanks = Dictionary(
            uniqueKeysWithValues:
                merges
                .map { Bigram($0) }
                .enumerated()
                .map { ($0.element, $0.offset) })

        self.vocabulary = vocabulary
        self.cache[bos] = [bos]
        self.cache[eos] = [eos]
        self.bosToken = vocabulary[bos]!
        self.eosToken = vocabulary[eos]!
    }

    /// Apply byte-pair merging to a single pre-tokenized piece of text,
    /// returning its subword units. Results are cached per input string.
    func bpe(text: String) -> [String] {
        if let result = cache[text] {
            return result
        }

        precondition(!text.isEmpty)

        // start from single-character unigrams
        var unigrams = text.dropLast().map { String($0) } + ["\(text.last!)"]
        var uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })

        // In every iteration try to merge the two most likely bigrams. If none
        // was merged we are done

        while !uniqueBigrams.isEmpty {
            // the bigram with the lowest (best) merge rank; unranked bigrams
            // sort last via Int.max
            let (bigram, _) =
                uniqueBigrams
                .map { ($0, bpeRanks[$0] ?? Int.max) }
                .min { $0.1 < $1.1 }!

            // even the best candidate has no merge rule -- nothing left to do
            if bpeRanks[bigram] == nil {
                break
            }

            var newUnigrams = [String]()
            var skip = false

            // merge every occurrence of `bigram`; `skip` drops the unigram
            // that was just consumed by a merge
            for (a, b) in zip(unigrams, unigrams.dropFirst()) {
                if skip {
                    skip = false
                    continue
                }

                if Bigram(a, b) == bigram {
                    newUnigrams.append(a + b)
                    skip = true
                } else {
                    newUnigrams.append(a)
                }
            }

            // the zip above stops one element short -- append the final
            // unigram unless the last merge already consumed it
            if !skip, let last = unigrams.last {
                newUnigrams.append(last)
            }

            unigrams = newUnigrams
            uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })
        }

        cache[text] = unigrams

        return unigrams
    }

    /// Encode `text` into token ids, bracketed by BOS/EOS.
    public func tokenize(text: String) -> [Int32] {
        // Lower case cleanup and split according to self.pat. Hugging Face does
        // a much more thorough job here but this should suffice for 95% of
        // cases.

        let clean = text.lowercased().replacing(#/\s+/#, with: " ")
        let tokens = clean.matches(of: pattern).map { $0.description }

        // Split the tokens according to the byte-pair merge file
        let bpeTokens = tokens.flatMap { bpe(text: String($0)) }

        // Map to token ids and return
        // note: tokens missing from the vocabulary are silently dropped
        // (compactMap)
        let result = [bosToken] + bpeTokens.compactMap { vocabulary[$0] } + [eosToken]

        return result.map { Int32($0) }
    }
}
136 |
--------------------------------------------------------------------------------
/Libraries/Embedders/Configuration.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
/// A JSON value that may be either a string or a number -- used for
/// configuration fields whose type varies between models.
public enum StringOrNumber: Codable, Equatable, Sendable {
    case string(String)
    case float(Float)

    public init(from decoder: Decoder) throws {
        let container = try decoder.singleValueContainer()

        // numbers are tried first; anything that is not a number must
        // decode as a string (or this throws)
        if let number = try? container.decode(Float.self) {
            self = .float(number)
        } else {
            self = .string(try container.decode(String.self))
        }
    }

    public func encode(to encoder: Encoder) throws {
        var container = encoder.singleValueContainer()
        switch self {
        case .string(let value): try container.encode(value)
        case .float(let value): try container.encode(value)
        }
    }
}
28 |
/// Thread-safe registry mapping `model_type` strings to factory closures
/// that build an ``EmbeddingModel`` from a `config.json` URL.
private class ModelTypeRegistry: @unchecked Sendable {

    // Note: using NSLock as we have very small (just dictionary get/set)
    // critical sections and expect no contention. this allows the methods
    // to remain synchronous.
    private let lock = NSLock()

    /// Shared creator for the BERT family -- bert, roberta, xlm-roberta and
    /// distilbert all use the same configuration and model implementation,
    /// so the previously duplicated closures are folded into one.
    private static let createBert: @Sendable (URL) throws -> EmbeddingModel = { url in
        let configuration = try JSONDecoder().decode(
            BertConfiguration.self, from: Data(contentsOf: url))
        return BertModel(configuration)
    }

    private var creators: [String: @Sendable (URL) throws -> EmbeddingModel] = [
        "bert": ModelTypeRegistry.createBert,
        "roberta": ModelTypeRegistry.createBert,
        "xlm-roberta": ModelTypeRegistry.createBert,
        "distilbert": ModelTypeRegistry.createBert,
        "nomic_bert": { url in
            let configuration = try JSONDecoder().decode(
                NomicBertConfiguration.self, from: Data(contentsOf: url))
            return NomicBertModel(configuration)
        },
    ]

    /// Register (or replace) the creator for a model type. Thread safe.
    public func registerModelType(
        _ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
    ) {
        lock.withLock {
            creators[type] = creator
        }
    }

    /// Instantiate the model for `rawValue` using its registered creator.
    /// - Throws: ``EmbedderError`` when no creator is registered for the type.
    public func createModel(configuration: URL, rawValue: String) throws -> EmbeddingModel {
        let creator = lock.withLock {
            creators[rawValue]
        }
        guard let creator else {
            // include the offending type, matching the error style used by
            // ModelFactoryError elsewhere in the project
            throw EmbedderError(message: "Unsupported model type: \(rawValue)")
        }
        return try creator(configuration)
    }

}
93 |
94 | private let modelTypeRegistry = ModelTypeRegistry()
95 |
/// A `model_type` value from `config.json`, with a static registry of
/// model creators keyed by that value.
public struct ModelType: RawRepresentable, Codable, Sendable {
    public let rawValue: String

    public init(rawValue: String) {
        self.rawValue = rawValue
    }

    /// Register a creator for a new model type (thread safe).
    public static func registerModelType(
        _ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
    ) {
        modelTypeRegistry.registerModelType(type, creator: creator)
    }

    /// Instantiate the model for this type from a `config.json` URL.
    /// - Throws: `EmbedderError` when the type has no registered creator.
    public func createModel(configuration: URL) throws -> EmbeddingModel {
        try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
    }
}
113 |
/// Minimal subset of a model's `config.json`: the model type plus optional
/// quantization parameters.
public struct BaseConfiguration: Codable, Sendable {
    public let modelType: ModelType

    /// Quantization parameters as written to `config.json` by the export
    /// tooling.
    public struct Quantization: Codable, Sendable {
        public init(groupSize: Int, bits: Int) {
            self.groupSize = groupSize
            self.bits = bits
        }

        // number of weights that share one scale/bias group
        let groupSize: Int
        // bits per quantized weight
        let bits: Int

        enum CodingKeys: String, CodingKey {
            case groupSize = "group_size"
            case bits = "bits"
        }
    }

    /// nil when the model is not quantized
    public var quantization: Quantization?

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case quantization
    }
}
139 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/ModelFactory.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 | import Tokenizers
6 |
/// Errors thrown when a ``ModelFactory`` encounters a configuration value
/// it has no registered implementation for.
public enum ModelFactoryError: LocalizedError {
    case unsupportedModelType(String)
    case unsupportedProcessorType(String)

    /// Human-readable description surfaced through `LocalizedError`.
    public var errorDescription: String? {
        switch self {
        case .unsupportedModelType(let type): "Unsupported model type: \(type)"
        case .unsupportedProcessorType(let type): "Unsupported processor type: \(type)"
        }
    }
}
18 |
/// Context of types that work together to provide a ``LanguageModel``.
///
/// A ``ModelContext`` is created by ``ModelFactory/load(hub:configuration:progressHandler:)``.
/// This contains the following:
///
/// - ``ModelConfiguration`` -- identifier for the model
/// - ``LanguageModel`` -- the model itself, see ``generate(input:parameters:context:didGenerate:)``
/// - ``UserInputProcessor`` -- can convert ``UserInput`` into ``LMInput``
/// - `Tokenizer` -- the tokenizer used by ``UserInputProcessor``
///
/// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and
/// ``ModelContainer``.
public struct ModelContext {
    public var configuration: ModelConfiguration
    public var model: any LanguageModel
    public var processor: any UserInputProcessor
    public var tokenizer: Tokenizer

    /// Memberwise initializer -- all four pieces are required.
    public init(
        configuration: ModelConfiguration, model: any LanguageModel,
        processor: any UserInputProcessor, tokenizer: any Tokenizer
    ) {
        self.configuration = configuration
        self.model = model
        self.processor = processor
        self.tokenizer = tokenizer
    }
}
47 |
/// Factory for loading language models.
///
/// `_load` and `_loadContainer` are the primitive requirements; callers
/// normally use the `load` / `loadContainer` convenience methods provided
/// by the extension below.
public protocol ModelFactory: Sendable {

    /// Registry used to resolve model ids into ``ModelConfiguration`` values.
    var modelRegistry: AbstractModelRegistry { get }

    func _load(
        hub: HubApi, configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> ModelContext

    func _loadContainer(
        hub: HubApi, configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> ModelContainer

}
63 |
extension ModelFactory {

    /// Resolve a model identifier, e.g. "mlx-community/Llama-3.2-3B-Instruct-4bit", into
    /// a ``ModelConfiguration``.
    ///
    /// This will either create a new (mostly unconfigured) ``ModelConfiguration`` or
    /// return a registered instance that matches the id.
    ///
    /// - Note: If the id doesn't exist in the registry, this will return a new instance.
    ///   To check whether a configuration is in the model registry, use ``contains(id:)``.
    public func configuration(id: String) -> ModelConfiguration {
        modelRegistry.configuration(id: id)
    }

    /// Returns true if ``modelRegistry`` contains a model with the id. Otherwise, false.
    public func contains(id: String) -> Bool {
        modelRegistry.contains(id: id)
    }

}
84 |
extension ModelFactory {

    /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContext``.
    ///
    /// This method returns a ``ModelContext``. See also
    /// ``loadContainer(hub:configuration:progressHandler:)`` for a method that
    /// returns a ``ModelContainer``.
    public func load(
        hub: HubApi = HubApi(), configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
    ) async throws -> ModelContext {
        try await _load(hub: hub, configuration: configuration, progressHandler: progressHandler)
    }

    /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContainer``.
    public func loadContainer(
        hub: HubApi = HubApi(), configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
    ) async throws -> ModelContainer {
        try await _loadContainer(
            hub: hub, configuration: configuration, progressHandler: progressHandler)
    }

    /// Default implementation of the `_loadContainer` requirement: load a
    /// ``ModelContext`` and wrap it in a ``ModelContainer``.
    public func _loadContainer(
        hub: HubApi = HubApi(), configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
    ) async throws -> ModelContainer {
        let context = try await _load(
            hub: hub, configuration: configuration, progressHandler: progressHandler)
        return ModelContainer(context: context)
    }

}
118 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/ContentView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import MLX
4 | import MLXMNIST
5 | import MLXNN
6 | import MLXOptimizers
7 | import MLXRandom
8 | import SwiftUI
9 |
/// Shows the training log plus a state-dependent action button:
/// "Train" while untrained, "Draw a digit" once a trained model exists.
struct TrainingView: View {

    @Binding var trainer: ModelState

    var body: some View {
        VStack {
            Spacer()

            // scrolling log of per-epoch training messages
            ScrollView(.vertical) {
                ForEach(trainer.messages, id: \.self) {
                    Text($0)
                }
            }

            HStack {
                Spacer()
                switch trainer.state {
                case .untrained:
                    Button("Train") {
                        Task {
                            // note: `try!` -- a training error terminates
                            // the example app
                            try! await trainer.train()
                        }
                    }
                case .trained(let model), .predict(let model):
                    Button("Draw a digit") {
                        trainer.state = .predict(model)
                    }
                }

                Spacer()
            }
            Spacer()
        }
        .padding()
    }
}
46 |
/// Root view: switches between the training UI and the drawing/prediction
/// UI based on the model state.
struct ContentView: View {
    // the training loop
    @State var trainer = ModelState()

    var body: some View {
        switch trainer.state {
        case .untrained, .trained:
            TrainingView(trainer: $trainer)
        case .predict(let model):
            PredictionView(model: model)
        }
    }
}
60 |
/// Observable model/training state driving ``ContentView``.
@MainActor
@Observable
class ModelState {

    enum State {
        // no model trained yet
        case untrained
        // training finished; model available
        case trained(LeNetContainer)
        // user is drawing digits for prediction with the trained model
        case predict(LeNetContainer)
    }

    var state: State = .untrained
    // training progress lines displayed by TrainingView
    var messages = [String]()

    /// Train a fresh model and advance to `.trained` when it completes.
    func train() async throws {
        let model = LeNetContainer()
        try await model.train(output: self)
        self.state = .trained(model)
    }
}
80 |
/// Owns a LeNet model and serializes training/inference access to it
/// (it is an actor).
actor LeNetContainer {

    private let model = LeNet()

    // MNIST images are 28x28 grayscale
    let mnistImageSize: CGSize = CGSize(width: 28, height: 28)

    /// Download MNIST into the temporary directory, train for 10 epochs with
    /// SGD, and append per-epoch test accuracy to `output.messages`.
    func train(output: ModelState) async throws {
        // Note: this is pretty close to the code in `mnist-tool`, just
        // wrapped in an Observable to make it easy to display in SwiftUI

        // download & load the training data
        let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
        try await download(into: url)
        let data = try load(from: url)

        let trainImages = data[.init(.training, .images)]!
        let trainLabels = data[.init(.training, .labels)]!
        let testImages = data[.init(.test, .images)]!
        let testLabels = data[.init(.test, .labels)]!

        eval(model.parameters())

        // the training loop
        let lg = valueAndGrad(model: model, loss)
        let optimizer = SGD(learningRate: 0.1)

        // using a consistent random seed so it behaves the same way each time
        MLXRandom.seed(0)
        var generator: RandomNumberGenerator = SplitMix64(seed: 0)

        for e in 0 ..< 10 {
            let start = Date.timeIntervalSinceReferenceDate

            for (x, y) in iterateBatches(
                batchSize: 256, x: trainImages, y: trainLabels, using: &generator)
            {
                // loss and gradients
                let (_, grads) = lg(model, x, y)

                // use SGD to update the weights
                optimizer.update(model: model, gradients: grads)

                // eval the parameters so the next iteration is independent
                eval(model, optimizer)
            }

            let accuracy = eval(model: model, x: testImages, y: testLabels)

            let end = Date.timeIntervalSinceReferenceDate

            // add to messages -- triggers display
            let accuracyItem = accuracy.item(Float.self)
            await MainActor.run {
                output.messages.append(
                    """
                    Epoch \(e): test accuracy \(accuracyItem.formatted())
                    Time: \((end - start).formatted())

                    """
                )
            }
        }
    }

    /// Classify a drawn digit: grayscale-resize to 28x28, normalize to
    /// [0, 1], and return the argmax class -- or nil when pixel data cannot
    /// be extracted from the image.
    func evaluate(image: CGImage) -> Int? {
        let pixelData = image.grayscaleImage(with: mnistImageSize)?.pixelData()
        if let pixelData {
            let x = pixelData.reshaped([1, 28, 28, 1]).asType(.float32) / 255.0
            return argMax(model(x)).item()
        } else {
            return nil
        }
    }
}
155 |
--------------------------------------------------------------------------------
/Libraries/MLXLMCommon/Tokenizer.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 | import Tokenizers
6 |
/// Error thrown when tokenizer configuration cannot be loaded or is missing
/// required pieces.
struct TokenizerError: Error {
    let message: String
}
10 |
/// Load the tokenizer for `configuration`, applying the project's
/// replacement-tokenizer workarounds (see ``loadTokenizerConfig``).
public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
    let (config, data) = try await loadTokenizerConfig(
        configuration: configuration, hub: hub)

    return try PreTrainedTokenizer(tokenizerConfig: config, tokenizerData: data)
}
19 |
/// Load the tokenizer configuration and data for a model, with an offline
/// fallback to the local model directory.
///
/// - Returns: the `(tokenizerConfig, tokenizerData)` pair; the config has
///   been passed through ``updateTokenizerConfig(_:)`` replacements
/// - Throws: ``TokenizerError`` when no tokenizer config can be found, or
///   any non-offline network error
public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> (
    Config, Config
) {
    // from AutoTokenizer.from() -- this lets us override parts of the configuration
    let config: LanguageModelConfigurationFromHub

    switch configuration.id {
    case .id(let id):
        do {
            // the load can fail (async when we try to use it)
            let loaded = LanguageModelConfigurationFromHub(
                modelName: configuration.tokenizerId ?? id, hubApi: hub)
            _ = try await loaded.tokenizerConfig
            config = loaded
        } catch {
            let nserror = error as NSError
            if nserror.domain == NSURLErrorDomain
                && nserror.code == NSURLErrorNotConnectedToInternet
            {
                // Internet connection appears to be offline -- fall back to loading from
                // the local directory
                config = LanguageModelConfigurationFromHub(
                    modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub)
            } else {
                throw error
            }
        }
    case .directory(let directory):
        // local model directory -- no network access needed
        config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
    }

    guard var tokenizerConfig = try await config.tokenizerConfig else {
        throw TokenizerError(message: "missing config")
    }
    let tokenizerData = try await config.tokenizerData

    // substitute tokenizer classes that swift-transformers does not handle
    tokenizerConfig = updateTokenizerConfig(tokenizerConfig)

    return (tokenizerConfig, tokenizerData)
}
60 |
/// Apply the ``replacementTokenizers`` overrides to a tokenizer config.
///
/// swift-transformers does not implement every tokenizer class; when the
/// config names one with a registered replacement, rewrite the
/// `tokenizer_class` entry to the replacement class.
private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config {
    guard
        let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue,
        let replacement = replacementTokenizers[tokenizerClass]
    else {
        // no override registered -- return the config unchanged
        return tokenizerConfig
    }

    var dictionary = tokenizerConfig.dictionary
    dictionary["tokenizer_class"] = replacement
    return Config(dictionary)
}
73 |
/// Thread-safe registry of tokenizer-class replacements, keyed by the
/// original `tokenizer_class` name from the tokenizer config.
public class TokenizerReplacementRegistry: @unchecked Sendable {

    // Note: using NSLock as we have very small (just dictionary get/set)
    // critical sections and expect no contention. this allows the methods
    // to remain synchronous.
    private let lock = NSLock()

    /// overrides for TokenizerModel/knownTokenizers
    private var replacementTokenizers = [
        "InternLM2Tokenizer": "PreTrainedTokenizer",
        "Qwen2Tokenizer": "PreTrainedTokenizer",
        "CohereTokenizer": "PreTrainedTokenizer",
    ]

    /// Get or set the replacement class for a tokenizer class name.
    public subscript(key: String) -> String? {
        get {
            lock.withLock {
                replacementTokenizers[key]
            }
        }
        set {
            lock.withLock {
                replacementTokenizers[key] = newValue
            }
        }
    }
}
101 |
102 | public let replacementTokenizers = TokenizerReplacementRegistry()
103 |
/// Incremental detokenizer: append token ids as they are generated and use
/// the `IteratorProtocol` conformance (`next()`) to obtain newly decoded text.
public protocol StreamingDetokenizer: IteratorProtocol {

    mutating func append(token: Int)

}
109 |
/// A naive streaming detokenizer: after each appended token it re-decodes
/// the current segment and yields only the newly added text. Segments are
/// reset on newline boundaries to keep the decoded prefix short.
public struct NaiveStreamingDetokenizer: StreamingDetokenizer {
    let tokenizer: Tokenizer

    // tokens of the current segment and the text they decoded to last time
    var segmentTokens = [Int]()
    var segment = ""

    public init(tokenizer: Tokenizer) {
        self.tokenizer = tokenizer
    }

    mutating public func append(token: Int) {
        segmentTokens.append(token)
    }

    /// Reset the segment, keeping only the last token so the next decode
    /// retains its left context.
    mutating func startNewSegment() {
        let lastToken = segmentTokens.last
        segmentTokens.removeAll()
        if let lastToken {
            segmentTokens.append(lastToken)
            segment = tokenizer.decode(tokens: segmentTokens)
        } else {
            segment = ""
        }
    }

    public mutating func next() -> String? {
        let newSegment = tokenizer.decode(tokens: segmentTokens)

        // bug fix: the decode of the extended token list can be *shorter*
        // than the previous decode (e.g. while a multi-byte character is
        // still incomplete). String.suffix(_:) traps on a negative length,
        // so clamp to zero and emit nothing for that step.
        let new = newSegment.suffix(max(newSegment.count - segment.count, 0))

        if new.hasSuffix("\n") {
            startNewSegment()
        } else {
            self.segment = newSegment
        }

        return String(new)
    }

}
149 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Clip.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 | import MLXNN
6 |
7 | // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
8 |
/// Outputs of the CLIP text encoder.
struct CLIPOutput {
    /// The lastHiddenState indexed at the EOS token and possibly projected if
    /// the model has a projection layer
    public var pooledOutput: MLXArray

    /// The full sequence output of the transformer after the final layernorm
    public var lastHiddenState: MLXArray

    /// A list of hidden states corresponding to the outputs of the transformer layers
    public var hiddenStates: [MLXArray]
}
20 |
/// The transformer encoder layer from CLIP: pre-norm self-attention followed
/// by a pre-norm MLP (linear -> activation -> linear), each with a residual
/// connection.
class CLIPEncoderLayer: Module {

    @ModuleInfo(key: "layer_norm1") var layerNorm1: LayerNorm
    @ModuleInfo(key: "layer_norm2") var layerNorm2: LayerNorm

    let attention: MultiHeadAttention

    // the MLP expands to 4x the model width then projects back
    @ModuleInfo var linear1: Linear
    @ModuleInfo var linear2: Linear

    let activation: (MLXArray) -> MLXArray

    init(modelDimensions: Int, numHeads: Int, activation: @escaping (MLXArray) -> MLXArray) {
        self._layerNorm1.wrappedValue = LayerNorm(dimensions: modelDimensions)
        self._layerNorm2.wrappedValue = LayerNorm(dimensions: modelDimensions)

        self.attention = MultiHeadAttention(
            dimensions: modelDimensions, numHeads: numHeads, bias: true)

        self.linear1 = Linear(modelDimensions, 4 * modelDimensions)
        self.linear2 = Linear(4 * modelDimensions, modelDimensions)

        self.activation = activation
    }

    func callAsFunction(_ x: MLXArray, attentionMask: MLXArray? = nil) -> MLXArray {
        // self-attention block with residual
        var y = layerNorm1(x)
        y = attention(y, keys: y, values: y, mask: attentionMask)
        var x = y + x

        // MLP block with residual
        y = layerNorm2(x)
        y = linear1(y)
        y = activation(y)
        y = linear2(y)
        x = y + x

        return x
    }
}
61 |
62 | /// Implements the text encoder transformer from CLIP
63 | class CLIPTextModel: Module {
64 |
65 | @ModuleInfo(key: "token_embedding") var tokenEmbedding: Embedding
66 | @ModuleInfo(key: "position_embedding") var positionEmbedding: Embedding
67 |
68 | let layers: [CLIPEncoderLayer]
69 |
70 | @ModuleInfo(key: "final_layer_norm") var finalLayerNorm: LayerNorm
71 |
72 | @ModuleInfo(key: "text_projection") var textProjection: Linear?
73 |
74 | init(configuration: CLIPTextModelConfiguration) {
75 | self._tokenEmbedding.wrappedValue = Embedding(
76 | embeddingCount: configuration.vocabularySize, dimensions: configuration.modelDimensions)
77 | self._positionEmbedding.wrappedValue = Embedding(
78 | embeddingCount: configuration.maxLength, dimensions: configuration.modelDimensions)
79 |
80 | self.layers = (0 ..< configuration.numLayers)
81 | .map { _ in
82 | CLIPEncoderLayer(
83 | modelDimensions: configuration.modelDimensions,
84 | numHeads: configuration.numHeads,
85 | activation: configuration.hiddenActivation.activation)
86 | }
87 |
88 | self._finalLayerNorm.wrappedValue = LayerNorm(dimensions: configuration.modelDimensions)
89 |
90 | if let projectionDimensions = configuration.projectionDimensions {
91 | self._textProjection.wrappedValue = Linear(
92 | configuration.modelDimensions, projectionDimensions, bias: false)
93 | } else {
94 | self._textProjection.wrappedValue = nil
95 | }
96 | }
97 |
98 | func mask(_ N: Int, _ dType: DType) -> MLXArray {
99 | let indices = MLXArray(0 ..< Int32(N))
100 | var mask = indices[0..., .newAxis] .< indices[.newAxis]
101 | mask = mask.asType(dType) * (dType == .float16 ? -6e4 : -1e9)
102 | return mask
103 | }
104 |
105 | func callAsFunction(_ x: MLXArray) -> CLIPOutput {
106 | var x = x
107 | let (_, N) = x.shape2
108 | let eosTokens = x.argMax(axis: -1)
109 |
110 | // compute the embeddings
111 | x = tokenEmbedding(x)
112 | x = x + positionEmbedding.weight[.. LoRALinearLayers {
94 | // TODO: modify as needed
95 | model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
96 | }
97 |
98 | public init(_ args: YourModelConfiguration) {
99 | self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
100 | self.model = YourModelInner(args)
101 | }
102 |
103 | public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
104 | // TODO: modify as needed
105 | let out = model(inputs, cache: cache)
106 | return model.embedTokens.asLinear(out)
107 | }
108 | }
109 | ```
110 |
111 | ## Register the Model
112 |
113 | In [LLMModelFactory.swift](LLMModelFactory.swift) register the model type itself
114 | (this is independent of the model id):
115 |
116 | ```swift
117 | public class ModelTypeRegistry: @unchecked Sendable {
118 | ...
119 | private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
120 | "yourModel": create(YourModelConfiguration.self, YourModel.init),
121 | ```
122 |
123 | Add a constant for the model in the `ModelRegistry` (not strictly required but useful
124 | for callers to refer to it in code):
125 |
126 | ```swift
127 | public class ModelRegistry: @unchecked Sendable {
128 | ...
129 | static public let yourModel_4bit = ModelConfiguration(
130 | id: "mlx-community/YourModel-4bit",
131 | defaultPrompt: "What is the gravity on Mars and the moon?"
132 | )
133 | ```
134 |
135 | and finally add it to the all list -- this will let users find the model
136 | configuration by id:
137 |
138 | ```swift
139 | private static func all() -> [ModelConfiguration] {
140 | [
141 | codeLlama13b4bit,
142 | ...
143 | yourModel_4bit,
144 | ```
145 |
146 | # Using a Model
147 |
148 | See [MLXLMCommon/README.md](../MLXLMCommon/README.md#using-a-model).
149 |
150 | # LoRA
151 |
152 | [Lora.swift](Lora.swift) contains an implementation of LoRA based on this example:
153 |
154 | - https://github.com/ml-explore/mlx-examples/tree/main/lora
155 |
156 | See [llm-tool/LoraCommands.swift](../../Tools/llm-tool/LoraCommands.swift) for an example of a driver and
157 | [llm-tool](../../Tools/llm-tool) for examples of how to run it.
158 |
--------------------------------------------------------------------------------
/Libraries/MLXLLM/SwitchLayers.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 | import MLXFast
4 | import MLXNN
5 | import MLXRandom
6 |
7 | // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/switch_layers.py
8 |
9 | // MARK: - SwitchGLU
10 |
/// A gated (GLU-style) MLP whose projections are routed per token to one of
/// `numExperts` expert weight sets, selected by indices at call time.
class SwitchGLU: Module {
    @ModuleInfo(key: "gate_proj") var gateProj: SwitchLinear
    @ModuleInfo(key: "up_proj") var upProj: SwitchLinear
    @ModuleInfo(key: "down_proj") var downProj: SwitchLinear

    let inputDims: Int
    let hiddenDims: Int
    let numExperts: Int
    let activation: (MLXArray) -> MLXArray

    init(
        inputDims: Int,
        hiddenDims: Int,
        numExperts: Int,
        activation: @escaping (MLXArray) -> MLXArray = MLXNN.silu,
        bias: Bool = false
    ) {
        self.inputDims = inputDims
        self.hiddenDims = hiddenDims
        self.numExperts = numExperts
        self.activation = activation

        // gate/up map inputDims -> hiddenDims; down projects back to inputDims
        self._gateProj.wrappedValue = SwitchLinear(
            inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias)
        self._upProj.wrappedValue = SwitchLinear(
            inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias)
        self._downProj.wrappedValue = SwitchLinear(
            inputDims: hiddenDims, outputDims: inputDims, numExperts: numExperts, bias: bias)

        super.init()
    }

    func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
        // insert singleton axes so the expert-gathered matmuls broadcast
        let expanded = MLX.expandedDimensions(x, axes: [-2, -3])

        let gate = gateProj(expanded, indices)
        let up = upProj(expanded, indices)
        let down = downProj(activation(gate) * up, indices)

        return MLX.squeezed(down, axis: -2)
    }
}
53 |
/// A linear layer that stores one weight matrix (and optional bias) per
/// expert; the expert applied to each row is chosen by `indices` at call time.
class SwitchLinear: Module, Quantizable {
    @ModuleInfo(key: "weight") var weight: MLXArray
    @ModuleInfo(key: "bias") var bias: MLXArray?

    let inputDims: Int
    let outputDims: Int
    let numExperts: Int

    init(inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true) {
        self.inputDims = inputDims
        self.outputDims = outputDims
        self.numExperts = numExperts

        // uniform initialization with the usual 1/sqrt(fanIn) bound
        let bound = sqrt(1.0 / Float(inputDims))
        self._weight.wrappedValue = MLXRandom.uniform(
            low: -bound, high: bound, [numExperts, outputDims, inputDims])

        if bias {
            self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims])
        }

        super.init()
    }

    /// Initializer meant for subclasses to provide weight and bias arrays directly.
    ///
    /// This is used e.g. by ``QuantizedSwitchLinear`` to provide quantized weights and biases
    /// rather than have ``SwitchLinear`` compute them.
    init(
        inputDims: Int, outputDims: Int, numExperts: Int,
        weight: MLXArray, bias: MLXArray? = nil
    ) {
        self.inputDims = inputDims
        self.outputDims = outputDims
        self.numExperts = numExperts

        self._weight.wrappedValue = weight
        self._bias.wrappedValue = bias
    }

    func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
        // gather the selected experts' (transposed) weight matrices and multiply
        var output = MLX.gatherMatmul(x, weight.swappedAxes(-1, -2), rhsIndices: indices)

        if let bias {
            output = output + MLX.expandedDimensions(bias[indices], axis: -2)
        }

        return output
    }

    func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module {
        QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits)
    }
}
112 |
/// A ``SwitchLinear`` whose per-expert weights are stored in quantized form
/// and used directly via `gatherQuantizedMatmul`.
class QuantizedSwitchLinear: SwitchLinear {
    @ModuleInfo(key: "scales") var scales: MLXArray
    @ModuleInfo(key: "biases") var biases: MLXArray

    let groupSize: Int
    let bits: Int

    init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) {
        self.groupSize = groupSize
        self.bits = bits

        // quantize the dense weights, keeping the scales/biases needed to dequantize
        let (packedWeight, quantScales, quantBiases) = MLX.quantized(
            other.weight, groupSize: groupSize, bits: bits)

        self._scales.wrappedValue = quantScales
        self._biases.wrappedValue = quantBiases

        super.init(
            inputDims: other.inputDims, outputDims: other.outputDims,
            numExperts: other.numExperts,
            weight: packedWeight, bias: other.bias)

        // quantized parameters are not trainable
        self.freeze()
    }

    override func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
        var output = MLX.gatherQuantizedMatmul(
            x, self.weight,
            scales: self.scales,
            biases: self.biases,
            rhsIndices: indices,
            transpose: true,
            groupSize: self.groupSize,
            bits: self.bits)

        if let bias = self.bias {
            output = output + MLX.expandedDimensions(bias[indices], axis: -2)
        }

        return output
    }
}
156 |
--------------------------------------------------------------------------------
/Libraries/Embedders/Models.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Hub
5 |
/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
///
/// The python tokenizers have a very rich set of implementations and configuration. The
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration: Sendable {

    public enum Identifier: Sendable {
        case id(String)
        case directory(URL)
    }

    public var id: Identifier

    /// Human-readable name: the hub id, or the last two path components for a directory.
    public var name: String {
        switch id {
        case .id(let value):
            return value
        case .directory(let url):
            let parent = url.deletingLastPathComponent().lastPathComponent
            return "\(parent)/\(url.lastPathComponent)"
        }
    }

    /// pull the tokenizer from an alternate id
    public let tokenizerId: String?

    /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
    public let overrideTokenizer: String?

    public init(
        id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil
    ) {
        self.id = .id(id)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
    }

    public init(
        directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil
    ) {
        self.id = .directory(directory)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
    }

    /// Local directory for the model: the hub cache location for `.id`
    /// configurations, or the directory itself for `.directory`.
    public func modelDirectory(hub: HubApi = HubApi()) -> URL {
        switch id {
        case .id(let id):
            // where the model weights and config are downloaded
            return hub.localRepoLocation(Hub.Repo(id: id))

        case .directory(let directory):
            return directory
        }
    }

    @MainActor
    public static var registry = [String: ModelConfiguration]()

    @MainActor
    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()

        configurations.forEach { registry[$0.name] = $0 }
    }

    @MainActor
    public static func configuration(id: String) -> ModelConfiguration {
        bootstrap()

        // unknown ids fall back to an as-is configuration
        return registry[id] ?? ModelConfiguration(id: id)
    }

    @MainActor
    public static var models: some Collection & Sendable {
        bootstrap()
        return Self.registry.values
    }
}
93 |
extension ModelConfiguration {
    public static let bge_micro = ModelConfiguration(id: "TaylorAI/bge-micro-v2")
    public static let gte_tiny = ModelConfiguration(id: "TaylorAI/gte-tiny")
    public static let minilm_l6 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L6-v2")
    public static let snowflake_xs = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-xs")
    public static let minilm_l12 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L12-v2")
    public static let bge_small = ModelConfiguration(id: "BAAI/bge-small-en-v1.5")
    public static let multilingual_e5_small = ModelConfiguration(
        id: "intfloat/multilingual-e5-small")
    public static let bge_base = ModelConfiguration(id: "BAAI/bge-base-en-v1.5")
    public static let nomic_text_v1 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1")
    public static let nomic_text_v1_5 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1.5")
    public static let bge_large = ModelConfiguration(id: "BAAI/bge-large-en-v1.5")
    public static let snowflake_lg = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-l")
    public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")
    public static let mixedbread_large = ModelConfiguration(
        id: "mixedbread-ai/mxbai-embed-large-v1")

    // tracks one-time registration of the built-in configurations; the
    // intermediate state guards against re-entrancy from register()
    private enum BootstrapState: Sendable {
        case idle
        case bootstrapping
        case bootstrapped
    }

    @MainActor
    static private var bootstrapState = BootstrapState.idle

    /// Populate the registry with the built-in configurations exactly once.
    @MainActor
    static func bootstrap() {
        guard case .idle = bootstrapState else { return }

        bootstrapState = .bootstrapping
        register(configurations: [
            bge_micro,
            gte_tiny,
            minilm_l6,
            snowflake_xs,
            minilm_l12,
            bge_small,
            multilingual_e5_small,
            bge_base,
            nomic_text_v1,
            nomic_text_v1_5,
            bge_large,
            snowflake_lg,
            bge_m3,
            mixedbread_large,
        ])
        bootstrapState = .bootstrapped
    }
}
152 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Image.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import CoreGraphics
4 | import CoreImage
5 | import Foundation
6 | import ImageIO
7 | import MLX
8 | import UniformTypeIdentifiers
9 |
/// Errors thrown by ``Image`` when opening or saving image files.
enum ImageError: LocalizedError {
    case failedToSave
    case unableToOpen

    var errorDescription: String? {
        switch self {
        case .failedToSave:
            String(localized: "Failed to save the image to the specified location.")
        case .unableToOpen:
            String(localized: "Unable to open the image file.")
        }
    }
}
23 |
/// Conversion utilities for moving between `MLXArray`, `CGImage` and files.
public struct Image {

    /// Pixel data with shape [H, W, C].
    public let data: MLXArray

    /// Create an Image from a MLXArray with ndim == 3
    public init(_ data: MLXArray) {
        precondition(data.ndim == 3)
        self.data = data
    }

    /// Create an Image by loading from a file.
    ///
    /// - Parameters:
    ///   - url: location of the image file
    ///   - maximumEdge: optional maximum length for the longer image edge (aspect fit)
    /// - Throws: `ImageError.unableToOpen` if the file cannot be read as an image
    public init(url: URL, maximumEdge: Int? = nil) throws {
        guard let source = CGImageSourceCreateWithURL(url as CFURL, nil),
            let image = CGImageSourceCreateImageAtIndex(source, 0, nil)
        else {
            throw ImageError.unableToOpen
        }

        // bug fix: maximumEdge was previously accepted here but never forwarded,
        // so the size limit was silently ignored when loading from a URL
        self.init(image: image, maximumEdge: maximumEdge)
    }

    /// Create an image from a CGImage.
    ///
    /// The image is optionally aspect-fit inside `maximumEdge`, then both edges
    /// are truncated down to multiples of 64 (without regard to aspect ratio).
    public init(image: CGImage, maximumEdge: Int? = nil) {
        var width = image.width
        var height = image.height

        if let maximumEdge {
            // scale `edge` by the factor that maps `maxEdge` to `maximumEdge`
            func scale(_ edge: Int, _ maxEdge: Int) -> Int {
                Int(round(Float(maximumEdge) / Float(maxEdge) * Float(edge)))
            }

            // aspect fit inside the given maximum
            if width >= height {
                width = scale(width, image.width)
                height = scale(height, image.width)
            } else {
                width = scale(width, image.height)
                height = scale(height, image.height)
            }
        }

        // size must be multiples of 64 -- coerce without regard to aspect ratio
        width = width - width % 64
        height = height - height % 64

        // draw into an RGBX raster (4 bytes per pixel, alpha byte unused)
        var raster = Data(count: width * 4 * height)
        raster.withUnsafeMutableBytes { ptr in
            let cs = CGColorSpace(name: CGColorSpace.sRGB)!
            let context = CGContext(
                data: ptr.baseAddress, width: width, height: height, bitsPerComponent: 8,
                bytesPerRow: width * 4, space: cs,
                bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
                    | CGBitmapInfo.byteOrder32Big.rawValue)!

            context.draw(
                image, in: CGRect(origin: .zero, size: .init(width: width, height: height)))
        }

        // keep only the RGB channels
        self.data = MLXArray(raster, [height, width, 4], type: UInt8.self)[0..., 0..., ..<3]
    }

    /// Convert the image data to a CGImage.
    public func asCGImage() -> CGImage {
        var raster = data

        // we need 4 bytes per pixel; pad a zero alpha channel
        // (the bitmap format skips the last byte, so the value is ignored)
        if data.dim(-1) == 3 {
            raster = padded(raster, widths: [0, 0, [0, 1]])
        }

        // keeps the backing Data alive until the CGContext releases it
        // via the releaseCallback below
        class DataHolder {
            var data: Data
            init(_ data: Data) {
                self.data = data
            }
        }

        let holder = DataHolder(raster.asData(access: .copy).data)

        let payload = Unmanaged.passRetained(holder).toOpaque()
        func release(payload: UnsafeMutableRawPointer?, data: UnsafeMutableRawPointer?) {
            Unmanaged.fromOpaque(payload!).release()
        }

        return holder.data.withUnsafeMutableBytes { ptr in
            let (H, W, C) = raster.shape3
            let cs = CGColorSpace(name: CGColorSpace.sRGB)!

            let context = CGContext(
                data: ptr.baseAddress, width: W, height: H, bitsPerComponent: 8, bytesPerRow: W * C,
                space: cs,
                bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
                    | CGBitmapInfo.byteOrder32Big.rawValue, releaseCallback: release,
                releaseInfo: payload)!
            return context.makeImage()!
        }
    }

    /// Convert the image data to a CIImage.
    public func asCIImage() -> CIImage {
        // we need 4 bytes per pixel; pad an opaque (255) alpha channel
        // since the RGBA8 format here does not ignore alpha
        var raster = data
        if data.dim(-1) == 3 {
            raster = padded(raster, widths: [0, 0, [0, 1]], value: MLXArray(255))
        }

        let arrayData = raster.asData()
        let (H, W, C) = raster.shape3
        let cs = CGColorSpace(name: CGColorSpace.sRGB)!

        return CIImage(
            bitmapData: arrayData.data, bytesPerRow: W * 4, size: .init(width: W, height: H),
            format: .RGBA8, colorSpace: cs)
    }

    /// Save the image to a file; the format is inferred from the path
    /// extension (falling back to PNG).
    ///
    /// - Throws: `ImageError.failedToSave` if the destination cannot be finalized
    public func save(url: URL) throws {
        let uti = UTType(filenameExtension: url.pathExtension) ?? UTType.png

        let destination = CGImageDestinationCreateWithURL(
            url as CFURL, uti.identifier as CFString, 1, nil)!
        CGImageDestinationAddImage(destination, asCGImage(), nil)
        if !CGImageDestinationFinalize(destination) {
            throw ImageError.failedToSave
        }
    }
}
155 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
8 |
9 |
15 |
21 |
22 |
23 |
24 |
25 |
31 |
32 |
43 |
45 |
51 |
52 |
53 |
54 |
57 |
58 |
61 |
62 |
65 |
66 |
69 |
70 |
73 |
74 |
77 |
78 |
81 |
82 |
85 |
86 |
89 |
90 |
93 |
94 |
97 |
98 |
101 |
102 |
105 |
106 |
107 |
108 |
114 |
116 |
122 |
123 |
124 |
125 |
127 |
128 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, caste, color, religion, or sexual
10 | identity and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the overall
26 | community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or advances of
31 | any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email address,
35 | without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series of
86 | actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or permanent
93 | ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within the
113 | community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.1, available at
119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
120 |
121 | Community Impact Guidelines were inspired by
122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
123 |
124 | For answers to common questions about this code of conduct, see the FAQ at
125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
126 | [https://www.contributor-covenant.org/translations][translations].
127 |
128 | [homepage]: https://www.contributor-covenant.org
129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
130 | [Mozilla CoC]: https://github.com/mozilla/diversity
131 | [FAQ]: https://www.contributor-covenant.org/faq
132 | [translations]: https://www.contributor-covenant.org/translations
133 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version: 5.9
2 | // The swift-tools-version declares the minimum version of Swift required to build this package.
3 |
4 | import PackageDescription
5 |
// Declares the libraries under Libraries/ as SwiftPM products, their external
// dependencies, and one target per library directory.
let package = Package(
    name: "mlx-libraries",
    platforms: [.macOS(.v14), .iOS(.v16)],
    products: [
        .library(
            name: "MLXLLM",
            targets: ["MLXLLM"]),
        .library(
            name: "MLXVLM",
            targets: ["MLXVLM"]),
        .library(
            name: "MLXLMCommon",
            targets: ["MLXLMCommon"]),
        .library(
            name: "MLXMNIST",
            targets: ["MLXMNIST"]),
        .library(
            name: "MLXEmbedders",
            targets: ["MLXEmbedders"]),
        .library(
            name: "StableDiffusion",
            targets: ["StableDiffusion"]),
    ],
    dependencies: [
        // pinned with upToNextMinor -- patch updates only
        .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.21.2")),
        .package(
            url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.17")
        ),
        .package(
            url: "https://github.com/apple/swift-async-algorithms", .upToNextMinor(from: "1.0.0")),
        // pinned to exactly 6.0.1
        .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"), // Only needed by MLXMNIST
    ],
    targets: [
        // LLM support, built on the shared MLXLMCommon target
        .target(
            name: "MLXLLM",
            dependencies: [
                "MLXLMCommon",
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/MLXLLM",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // vision-language model support, also built on MLXLMCommon
        .target(
            name: "MLXVLM",
            dependencies: [
                "MLXLMCommon",
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/MLXVLM",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // code shared between MLXLLM and MLXVLM
        .target(
            name: "MLXLMCommon",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/MLXLMCommon",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // text embedding models (Libraries/Embedders)
        .target(
            name: "MLXEmbedders",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "MLXLinalg", package: "mlx-swift"),
            ],
            path: "Libraries/Embedders",
            exclude: [
                "README.md"
            ]
        ),
        // MNIST training example library
        .target(
            name: "MLXMNIST",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "Gzip", package: "GzipSwift"),
            ],
            path: "Libraries/MLXMNIST",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // stable diffusion image generation
        .target(
            name: "StableDiffusion",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/StableDiffusion",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
    ]
)
145 |
// Only pull in the DocC plugin when documentation builds are requested via
// these environment variables, so normal builds avoid the extra dependency.
if Context.environment["MLX_SWIFT_BUILD_DOC"] == "1"
    || Context.environment["SPI_GENERATE_DOCS"] == "1"
{
    // docc builder
    package.dependencies.append(
        .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.3.0")
    )
}
154 |
--------------------------------------------------------------------------------