├── Applications ├── LLMEval │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── LLMEvalApp.swift │ ├── LLMEval.entitlements │ ├── ViewModels │ │ └── DeviceStat.swift │ └── README.md ├── VLMEval │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── VLMEvalApp.swift │ ├── VLMEval.entitlements │ └── README.md ├── MNISTTrainer │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── MNISTTrainer-Info.plist │ ├── MNISTTrainerApp.swift │ ├── MNISTTrainer.entitlements │ ├── README.md │ ├── PredictionView.swift │ └── ContentView.swift ├── LoRATrainingExample │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── LoRATrainingExampleApp.swift │ ├── LoRATrainingExample.entitlements │ └── README.md └── StableDiffusionExample │ ├── Assets.xcassets │ ├── Contents.json │ ├── AccentColor.colorset │ │ └── Contents.json │ └── AppIcon.appiconset │ │ └── Contents.json │ ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json │ ├── StableDiffusionExampleApp.swift │ ├── StableDiffusionExample.entitlements │ └── README.md ├── .spi.yml ├── .swift-format ├── Libraries ├── MLXVLM │ └── VLMModel.swift ├── MLXLMCommon │ ├── Documentation.docc │ │ └── Documentation.md │ ├── Module+Extensions.swift │ ├── Registries │ │ ├── ModelTypeRegistry.swift │ │ ├── ProcessorTypeRegistry.swift 
│ │ └── AbstractModelRegistry.swift │ ├── ModelConfiguration.swift │ ├── ModelContainer.swift │ ├── StringOrNumber.swift │ ├── Load.swift │ ├── KVCache.swift │ ├── README.md │ ├── ModelFactory.swift │ └── Tokenizer.swift ├── MLXMNIST │ ├── README.md │ ├── Random.swift │ ├── MNIST.swift │ └── Files.swift ├── MLXLLM │ ├── Documentation.docc │ │ ├── Documentation.md │ │ ├── adding-model.md │ │ └── using-model.md │ ├── LLMModel.swift │ ├── SuScaledRotaryEmbedding.swift │ ├── Lora+Data.swift │ ├── README.md │ └── SwitchLayers.swift ├── Embedders │ ├── README.md │ ├── Tokenizer.swift │ ├── Pooling.swift │ ├── EmbeddingModel.swift │ ├── Load.swift │ ├── Configuration.swift │ └── Models.swift └── StableDiffusion │ ├── README.md │ ├── Sampler.swift │ ├── Tokenizer.swift │ ├── Clip.swift │ └── Image.swift ├── mlx-swift-examples.xcodeproj ├── project.xcworkspace │ ├── contents.xcworkspacedata │ └── xcshareddata │ │ ├── IDEWorkspaceChecks.plist │ │ └── swiftpm │ │ └── Package.resolved └── xcshareddata │ └── xcschemes │ ├── LLMEval.xcscheme │ ├── VLMEval.xcscheme │ ├── StableDiffusionExample.xcscheme │ └── llm-tool.xcscheme ├── .pre-commit-config.yaml ├── Tools ├── LinearModelTraining │ ├── README.md │ └── LinearModelTraining.swift ├── llm-tool │ └── Arguments.swift ├── mnist-tool │ ├── README.md │ └── MNISTTool.swift ├── image-tool │ └── Arguments.swift └── Tutorial │ └── Tutorial.swift ├── Configuration └── Build.xcconfig ├── ACKNOWLEDGMENTS.md ├── LICENSE ├── CONTRIBUTING.md ├── mlx-run ├── .circleci └── config.yml ├── Package.resolved ├── .gitignore ├── README.md ├── Data └── lora │ └── wikisql.py ├── CODE_OF_CONDUCT.md └── Package.swift /Applications/LLMEval/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- 
/Applications/VLMEval/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.spi.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | builder: 3 | configs: 4 | - documentation_targets: [MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST, MLXEmbedders, StableDiffusion] 5 | -------------------------------------------------------------------------------- /.swift-format: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "indentation": { 4 | "spaces": 4 5 | }, 6 | "spacesAroundRangeFormationOperators": true, 7 | } 8 | -------------------------------------------------------------------------------- /Applications/LLMEval/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/VLMEval/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Libraries/MLXVLM/VLMModel.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import MLX 4 | import MLXLMCommon 5 | 6 | public protocol VLMModel: LanguageModel, LoRAModel { 7 | } 8 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/slessans/pre-commit-swift-format 3 | rev: "fd627de92bdf84a75c924ed95691336d14e94cf1" 4 | hooks: 5 | - id: swift-format 6 | args: ["--configuration", ".swift-format"] 7 | -------------------------------------------------------------------------------- /Applications/LLMEval/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/MNISTTrainer-Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Applications/VLMEval/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Assets.xcassets/AccentColor.colorset/Contents.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/MNISTTrainerApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct MNISTTrainerApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/LoRATrainingExampleApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct LoRATrainingExampleApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Applications/LLMEval/LLMEvalApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct LLMEvalApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | .environment(DeviceStat()) 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/StableDiffusionExampleApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct StableDiffusionExampleApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Applications/VLMEval/VLMEvalApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct VLMEvalApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | .environment(DeviceStat()) 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Documentation.docc/Documentation.md: -------------------------------------------------------------------------------- 1 | # ``MLXLMCommon`` 2 | 3 | Common language model code. 4 | 5 | ## Other MLX Libraries Packages 6 | 7 | - [MLXEmbedders](MLXEmbedders) 8 | - [MLXLLM](MLXLLM) 9 | - [MLXLMCommon](MLXLMCommon) 10 | - [MLXMNIST](MLXMNIST) 11 | - [MLXVLM](MLXVLM) 12 | - [StableDiffusion](StableDiffusion) 13 | 14 | -------------------------------------------------------------------------------- /Tools/LinearModelTraining/README.md: -------------------------------------------------------------------------------- 1 | # LinearModelTraining 2 | 3 | A command line tool that creates a Model that represents: 4 | 5 | f(x) = mx + b 6 | 7 | and trains it against an unknown linear function. 
Very 8 | simple but illustrates: 9 | 10 | - a very simple model with parameters 11 | - a loss function 12 | - the gradient 13 | - use of an optimizers 14 | - the training loop 15 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/MNISTTrainer.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.security.app-sandbox 6 | 7 | com.apple.security.files.user-selected.read-only 8 | 9 | com.apple.security.network.client 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /Configuration/Build.xcconfig: -------------------------------------------------------------------------------- 1 | // The `DISAMBIGUATOR` configuration is to make it easier to build 2 | // and run a sample code project. Once you set your project's development team, 3 | // you'll have a unique bundle identifier. This is because the bundle identifier 4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this 5 | // approach in your own projects—it's only useful for example projects because 6 | // they are frequently downloaded and don't have a development team set. 7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM} 8 | -------------------------------------------------------------------------------- /Libraries/MLXMNIST/README.md: -------------------------------------------------------------------------------- 1 | # MNIST 2 | 3 | This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP. 4 | 5 | It provides code to: 6 | 7 | - Download the MNIST test/train data 8 | - Build the LeNet 9 | - Some functions to shuffle and batch the data 10 | 11 | See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there. 
12 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/LoRATrainingExample.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.files.user-selected.read-only 10 | 11 | com.apple.security.network.client 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/StableDiffusionExample.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.files.user-selected.read-only 10 | 11 | com.apple.security.network.client 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /Applications/LLMEval/LLMEval.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.device.usb 10 | 11 | com.apple.security.files.user-selected.read-only 12 | 13 | com.apple.security.network.client 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /Applications/VLMEval/VLMEval.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.device.usb 10 | 11 | com.apple.security.files.user-selected.read-only 12 | 13 | com.apple.security.network.client 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /ACKNOWLEDGMENTS.md: 
-------------------------------------------------------------------------------- 1 | # Individual Contributors 2 | 3 | If you wish to be acknowledged for your contributions, please list your name 4 | with a short description of your contribution(s) below. For example: 5 | 6 | - Jane Smith: Added the `foo` and `bar` ops. 7 | 8 | MLX Swift was developed with contributions from the following individuals: 9 | 10 | 11 | 12 | 13 | 14 | 15 | SOFTWARE. 16 | 17 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/README.md: -------------------------------------------------------------------------------- 1 | # MNISTTrainer 2 | 3 | This is an example of model training that works on both macOS and iOS. 4 | The example will download the MNIST training data, create a LeNet, and train 5 | it. It will show the epoch time and test accuracy as it trains. 6 | 7 | You will need to set the Team on the MNISTTrainer target in order to build and 8 | run on iOS. 9 | 10 | Some notes about the setup: 11 | 12 | - This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox 13 | - The website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist 14 | -------------------------------------------------------------------------------- /Libraries/MLXLLM/Documentation.docc/Documentation.md: -------------------------------------------------------------------------------- 1 | # ``MLXLLM`` 2 | 3 | Example implementations of various Large Language Models (LLMs). 
4 | 5 | ## Other MLX Libraries Packages 6 | 7 | - [MLXEmbedders](MLXEmbedders) 8 | - [MLXLLM](MLXLLM) 9 | - [MLXLMCommon](MLXLMCommon) 10 | - [MLXMNIST](MLXMNIST) 11 | - [MLXVLM](MLXVLM) 12 | - [StableDiffusion](StableDiffusion) 13 | 14 | ## Topics 15 | 16 | - 17 | - 18 | 19 | ### Models 20 | 21 | - ``CohereModel`` 22 | - ``GemmaModel`` 23 | - ``Gemma2Model`` 24 | - ``InternLM2Model`` 25 | - ``LlamaModel`` 26 | - ``OpenELMModel`` 27 | - ``PhiModel`` 28 | - ``Phi3Model`` 29 | - ``PhiMoEModel`` 30 | - ``Qwen2Model`` 31 | - ``Starcoder2Model`` 32 | 33 | 34 | -------------------------------------------------------------------------------- /Applications/LLMEval/ViewModels/DeviceStat.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | 4 | @Observable 5 | final class DeviceStat: @unchecked Sendable { 6 | 7 | @MainActor 8 | var gpuUsage = GPU.snapshot() 9 | 10 | private let initialGPUSnapshot = GPU.snapshot() 11 | private var timer: Timer? 
12 | 13 | init() { 14 | timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in 15 | self?.updateGPUUsages() 16 | } 17 | } 18 | 19 | deinit { 20 | timer?.invalidate() 21 | } 22 | 23 | private func updateGPUUsages() { 24 | let gpuSnapshotDelta = initialGPUSnapshot.delta(GPU.snapshot()) 25 | DispatchQueue.main.async { [weak self] in 26 | self?.gpuUsage = gpuSnapshotDelta 27 | } 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /Applications/VLMEval/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "appearances" : [ 10 | { 11 | "appearance" : "luminosity", 12 | "value" : "dark" 13 | } 14 | ], 15 | "idiom" : "universal", 16 | "platform" : "ios", 17 | "size" : "1024x1024" 18 | }, 19 | { 20 | "appearances" : [ 21 | { 22 | "appearance" : "luminosity", 23 | "value" : "tinted" 24 | } 25 | ], 26 | "idiom" : "universal", 27 | "platform" : "ios", 28 | "size" : "1024x1024" 29 | } 30 | ], 31 | "info" : { 32 | "author" : "xcode", 33 | "version" : 1 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /Tools/llm-tool/Arguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | 6 | /// Extension to allow URL command line arguments. 
7 | #if swift(>=5.10) 8 | extension URL: @retroactive ExpressibleByArgument { 9 | public init?(argument: String) { 10 | if argument.contains("://") { 11 | self.init(string: argument) 12 | } else { 13 | self.init(filePath: argument) 14 | } 15 | } 16 | } 17 | #else 18 | extension URL: ExpressibleByArgument { 19 | public init?(argument: String) { 20 | if argument.contains("://") { 21 | self.init(string: argument) 22 | } else { 23 | self.init(filePath: argument) 24 | } 25 | } 26 | } 27 | #endif 28 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Module+Extensions.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import MLXNN 4 | 5 | extension Module { 6 | 7 | /// Compute the number of parameters in a possibly quantized model 8 | public func numParameters() -> Int { 9 | return leafModules().flattenedValues().map { 10 | mod -> Int in 11 | if let qlin = mod as? QuantizedLinear { 12 | return qlin.scales.size * qlin.groupSize 13 | } else if let qemb = mod as? 
QuantizedEmbedding { 14 | return qemb.scales.size * qemb.groupSize 15 | } else { 16 | return mod.parameters().flattenedValues().reduce( 17 | 0, 18 | { 19 | $0 + $1.size 20 | }) 21 | } 22 | }.reduce(0, +) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/README.md: -------------------------------------------------------------------------------- 1 | # LoRATrainingExample 2 | 3 | Example application that: 4 | 5 | - downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from huggingface 6 | - loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded) 7 | - adds LoRA adapters and trains the model 8 | - let's you evaluate a prompt against the model 9 | 10 | This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and 11 | you can read more about LoRA there. 12 | 13 | This evaluates the LoRA adapted model rather than a fused model. This doesn't persist 14 | the LoRA weights or the fused model -- it will retrain it each time the program is launched. 15 | 16 | ### Troubleshooting 17 | 18 | The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of 19 | memory to load an train -- this may require ~6G of physical RAM. 20 | 21 | 22 | -------------------------------------------------------------------------------- /Tools/mnist-tool/README.md: -------------------------------------------------------------------------------- 1 | # mnist-tool 2 | 3 | See the [MNIST README.md](../../Libraries/MNIST/README.md). 4 | 5 | ### Building 6 | 7 | `mnist-tool` has no dependencies outside of the package dependencies 8 | represented in xcode. 9 | 10 | When you run the tool it will download the test/train datasets and 11 | store them in a specified directory (see run arguments -- default is /tmp). 12 | 13 | Simply build the project in xcode. 
14 | 15 | ### Running (Xcode) 16 | 17 | To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example: 18 | 19 | ``` 20 | --data /tmp 21 | ``` 22 | 23 | Then cmd-r to run. 24 | 25 | ### Running (CommandLine) 26 | 27 | Use the `mlx-run` script to run the command line tools: 28 | 29 | ``` 30 | ./mlx-run mnist-tool --data /tmp 31 | ``` 32 | 33 | By default this will find and run the tools built in _Release_ configuration. Specify `--debug` 34 | to find and run the tool built in _Debug_ configuration. 35 | 36 | See also: 37 | 38 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting) 39 | -------------------------------------------------------------------------------- /Libraries/MLXLLM/LLMModel.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import MLX 4 | import MLXLMCommon 5 | 6 | /// Marker protocol for LLMModels 7 | public protocol LLMModel: LanguageModel, LoRAModel { 8 | } 9 | 10 | extension LLMModel { 11 | 12 | /// Default prepare step for ``LLMModel``. 13 | /// 14 | /// This will evaluate the prompt in chunks until there is a small amount of 15 | /// tokens left to feed into the `TokenIterator`. 16 | public func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws 17 | -> PrepareResult 18 | { 19 | let prefillStepSize = windowSize ?? 512 20 | var y = input.text 21 | var state: LMOutput.State? = nil 22 | 23 | // prepare the prompt in chunks if larger than the prefill size 24 | while y.tokens.size > prefillStepSize { 25 | let input = y[.newAxis, .. 
UInt64 { 24 | self.state &+= 0x9e37_79b9_7f4a_7c15 25 | var z: UInt64 = self.state 26 | z = (z ^ (z &>> 30)) &* 0xbf58_476d_1ce4_e5b9 27 | z = (z ^ (z &>> 27)) &* 0x94d0_49bb_1331_11eb 28 | return z ^ (z &>> 31) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /mlx-run: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Wrapper to help run command line tools -- this will find the build directory 4 | # and set the DYLD_FRAMEWORK_PATH so that command line tools that link frameworks 5 | # can be run. 6 | # 7 | # Example: 8 | # ./mlx-run --debug llm-tool --help 9 | 10 | if [ "$#" -lt 1 ]; then 11 | echo "usage: mlx-run [--debug/--release] arguments" 12 | exit 1 13 | fi 14 | 15 | CONFIGURATION=Release 16 | if [ "$1" == "--release" ]; then 17 | CONFIGURATION=Release 18 | shift 19 | fi 20 | if [ "$1" == "--debug" ]; then 21 | CONFIGURATION=Debug 22 | shift 23 | fi 24 | if [ "$1" == "--list" ]; then 25 | xcodebuild -list 26 | exit 0 27 | fi 28 | 29 | COMMAND="$1" 30 | shift 31 | 32 | BUILD_DIR=`xcodebuild -configuration $CONFIGURATION -showBuildSettings -scheme $COMMAND | grep 'BUILT_PRODUCTS_DIR = /' | sed -e 's/^[^=]*= //g'` 33 | 34 | if [ -d "$BUILD_DIR/$COMMAND.app" ]; then 35 | exec $BUILD_DIR/$COMMAND.app/Contents/MacOS/$COMMAND "$@" & 36 | fi 37 | 38 | if [ -f "$BUILD_DIR/$COMMAND" ]; then 39 | export DYLD_FRAMEWORK_PATH=$BUILD_DIR/PackageFrameworks:$BUILD_DIR 40 | exec "$BUILD_DIR/$COMMAND" "$@" 41 | else 42 | echo "$BUILD_DIR/$COMMAND does not exist -- check build configuration ($CONFIGURATION)" 43 | exit 1 44 | fi 45 | 46 | -------------------------------------------------------------------------------- /Applications/LLMEval/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | 
}, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : 
"mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 
23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/README.md: -------------------------------------------------------------------------------- 1 | # StableDiffusionExample 2 | 3 | An example application that runs the StableDiffusion example code. 4 | 5 | See also [image-tool](../../Tools/image-tool) for a command line example. 6 | 7 | This example application accepts a prompt and used the StableDiffusion example 8 | library to render an image using: 9 | 10 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo) 11 | 12 | Please refer to that model for license and other information. 13 | 14 | If you are interested in adjusting the generated images, look in 15 | [ContentView.swift](ContentView.swift) at this method: 16 | 17 | ```swift 18 | func generate(prompt: String, negativePrompt: String, showProgress: Bool) async 19 | ``` 20 | 21 | ### Troubleshooting 22 | 23 | Stable diffusion can run in less that 4G available memory (typically a 24 | device or computer with 6G of memory or more) in a constrained mode -- it will 25 | load and unload parts of the model as it runs and it can only perform one step 26 | of diffusion. This is configured automatically, see `modelFactory.conserveMemory` 27 | in [ContentView.swift](ContentView.swift). 
28 | 29 | On a device or computer with more memory the model will be kept resident and 30 | images can be regenerated much more efficiently. 31 | 32 | If the program exits while generating the image it may have exceeded the available 33 | memory. 34 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Registries/ModelTypeRegistry.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | 5 | open class ModelTypeRegistry: @unchecked Sendable { 6 | 7 | /// Creates an empty registry. 8 | public init() { 9 | self.creators = [:] 10 | } 11 | 12 | /// Creates a registry with given creators. 13 | public init(creators: [String: @Sendable (URL) throws -> any LanguageModel]) { 14 | self.creators = creators 15 | } 16 | 17 | // Note: using NSLock as we have very small (just dictionary get/set) 18 | // critical sections and expect no contention. this allows the methods 19 | // to remain synchronous. 20 | private let lock = NSLock() 21 | private var creators: [String: @Sendable (URL) throws -> any LanguageModel] 22 | 23 | /// Add a new model to the type registry. 24 | public func registerModelType( 25 | _ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel 26 | ) { 27 | lock.withLock { 28 | creators[type] = creator 29 | } 30 | } 31 | 32 | /// Given a `modelType` and configuration file instantiate a new `LanguageModel`. 
33 | public func createModel(configuration: URL, modelType: String) throws -> LanguageModel { 34 | let creator = lock.withLock { 35 | creators[modelType] 36 | } 37 | guard let creator else { 38 | throw ModelFactoryError.unsupportedModelType(modelType) 39 | } 40 | return try creator(configuration) 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /Libraries/Embedders/README.md: -------------------------------------------------------------------------------- 1 | # MLXEmbedders 2 | 3 | This directory contains ports of popular Encoders / Embedding Models. 4 | 5 | ## Usage Example 6 | 7 | ```swift 8 | 9 | let modelContainer = try await MLXEmbedders.loadModelContainer( 10 | configuration: ModelConfiguration.nomic_text_v1_5) 11 | let result = await modelContainer.perform { 12 | (model: EmbeddingModel, tokenizer, pooling) -> [[Float]] in 13 | let inputs = [ 14 | "search_query: Animals in Tropical Climates.", 15 | "search_document: Elephants", 16 | "search_document: Horses", 17 | "search_document: Polar Bears", 18 | ].map { 19 | tokenizer.encode(text: $0, addSpecialTokens: true) 20 | } 21 | // Pad to longest 22 | let maxLength = inputs.reduce(into: 16) { acc, elem in 23 | acc = max(acc, elem.count) 24 | } 25 | 26 | let padded = stacked( 27 | inputs.map { elem in 28 | MLXArray( 29 | elem 30 | + Array( 31 | repeating: tokenizer.eosTokenId ?? 0, 32 | count: maxLength - elem.count)) 33 | }) 34 | let mask = (padded .!= tokenizer.eosTokenId ?? 
0) 35 | let tokenTypes = MLXArray.zeros(like: padded) 36 | let result = pooling( 37 | model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), 38 | normalize: true, applyLayerNorm: true 39 | ).eval() 40 | return result.map { $0.asArray(Float.self) } 41 | } 42 | ``` 43 | 44 | 45 | Ported to swift from [taylorai/mlx_embedding_models](https://github.com/taylorai/mlx_embedding_models/tree/main) 46 | -------------------------------------------------------------------------------- /Applications/VLMEval/README.md: -------------------------------------------------------------------------------- 1 | # VLMEval 2 | 3 | An example that: 4 | 5 | - downloads a vision language model (Qwen-VL-2B) 6 | - processes an image with a prompt 7 | 8 | You will need to set the Team on the VLMEval target in order to build and run on macOS. 9 | 10 | Some notes about the setup: 11 | 12 | - This downloads models from hugging face so VLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox 13 | - VLM models are large so this uses significant memory 14 | - The example processes images and provides detailed analysis 15 | 16 | ### Image Processing 17 | 18 | The example application uses Qwen-VL-2B model by default, see [ContentView.swift](ContentView.swift): 19 | 20 | ```swift 21 | self.modelContainer = try await VLMModelFactory.shared.loadContainer( 22 | configuration: ModelRegistry.qwen2VL2BInstruct4Bit) 23 | ``` 24 | 25 | The application: 26 | 1. Downloads a sample image 27 | 2. Processes it through the vision language model 28 | 3. Describes the images based on the prompt, providing detailed analysis of the content, objects, colors, and composition. 29 | 30 | ### Troubleshooting 31 | 32 | If the program crashes with a very deep stack trace you may need to build 33 | in Release configuration. This seems to depend on the size of the model. 
34 | 35 | There are a couple options: 36 | 37 | - Build Release 38 | - Force the model evaluation to run on the main thread, e.g. using @MainActor 39 | - Build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),` 40 | 41 | ### Performance 42 | 43 | You may find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable". 44 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import Tokenizers 5 | 6 | open class ProcessorTypeRegistry: @unchecked Sendable { 7 | 8 | /// Creates an empty registry. 9 | public init() { 10 | self.creators = [:] 11 | } 12 | 13 | /// Creates a registry with given creators. 14 | public init(creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor]) 15 | { 16 | self.creators = creators 17 | } 18 | 19 | // Note: using NSLock as we have very small (just dictionary get/set) 20 | // critical sections and expect no contention. this allows the methods 21 | // to remain synchronous. 22 | private let lock = NSLock() 23 | 24 | private var creators: [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] 25 | 26 | /// Add a new model to the type registry. 27 | public func registerProcessorType( 28 | _ type: String, 29 | creator: @Sendable @escaping ( 30 | URL, 31 | any Tokenizer 32 | ) throws -> any UserInputProcessor 33 | ) { 34 | lock.withLock { 35 | creators[type] = creator 36 | } 37 | } 38 | 39 | /// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`. 
40 | public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer) 41 | throws -> any UserInputProcessor 42 | { 43 | let creator = lock.withLock { 44 | creators[processorType] 45 | } 46 | guard let creator else { 47 | throw ModelFactoryError.unsupportedProcessorType(processorType) 48 | } 49 | return try creator(configuration, tokenizer) 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | 5 | open class AbstractModelRegistry: @unchecked Sendable { 6 | 7 | /// Creates an empty registry. 8 | public init() { 9 | self.registry = Dictionary() 10 | } 11 | 12 | /// Creates a new registry with from given model configurations. 13 | public init(modelConfigurations: [ModelConfiguration]) { 14 | self.registry = Dictionary(uniqueKeysWithValues: modelConfigurations.map { ($0.name, $0) }) 15 | } 16 | 17 | private let lock = NSLock() 18 | private var registry: [String: ModelConfiguration] 19 | 20 | public func register(configurations: [ModelConfiguration]) { 21 | lock.withLock { 22 | for c in configurations { 23 | registry[c.name] = c 24 | } 25 | } 26 | } 27 | 28 | /// Returns configuration from ``modelRegistry``. 29 | /// 30 | /// - Note: If the id doesn't exists in the configuration, this will return a new instance of it. 31 | /// If you want to check if the configuration in model registry, you should use ``contains(id:)``. 32 | public func configuration(id: String) -> ModelConfiguration { 33 | lock.withLock { 34 | if let c = registry[id] { 35 | return c 36 | } else { 37 | return ModelConfiguration(id: id) 38 | } 39 | } 40 | } 41 | 42 | /// Returns true if the registry contains a model with the id. Otherwise, false. 
43 | public func contains(id: String) -> Bool { 44 | lock.withLock { 45 | registry[id] != nil 46 | } 47 | } 48 | 49 | public var models: some Collection & Sendable { 50 | lock.withLock { 51 | return registry.values 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | apple: ml-explore/pr-approval@0.1.0 5 | 6 | parameters: 7 | nightly_build: 8 | type: boolean 9 | default: false 10 | weekly_build: 11 | type: boolean 12 | default: false 13 | 14 | jobs: 15 | 16 | mac_build_and_test: 17 | macos: 18 | xcode: 16.0.0 19 | resource_class: macos.m1.medium.gen1 20 | steps: 21 | - checkout 22 | - run: git submodule sync 23 | - run: git submodule update --init 24 | - run: 25 | name: Run style checks 26 | command: | 27 | pip install pre-commit 28 | brew install swift-format 29 | pre-commit run --all 30 | if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi 31 | - run: 32 | name: Build Examples 33 | command: | 34 | xcodebuild -version 35 | xcrun --show-sdk-build-version 36 | swift --version 37 | find . 
-name Package.resolved -exec rm {} \; 38 | xcodebuild -scheme llm-tool 39 | xcodebuild -scheme image-tool 40 | xcodebuild -scheme mnist-tool 41 | 42 | workflows: 43 | build_and_test: 44 | when: 45 | and: 46 | - matches: 47 | pattern: "^(?!pull/)[-\\w]+$" 48 | value: << pipeline.git.branch >> 49 | - not: << pipeline.parameters.nightly_build >> 50 | - not: << pipeline.parameters.weekly_build >> 51 | jobs: 52 | - mac_build_and_test 53 | 54 | prb: 55 | when: 56 | matches: 57 | pattern: "^pull/\\d+(/head)?$" 58 | value: << pipeline.git.branch >> 59 | jobs: 60 | - hold: 61 | type: approval 62 | - apple/authenticate: 63 | context: pr-approval 64 | - mac_build_and_test: 65 | requires: [ hold ] 66 | -------------------------------------------------------------------------------- /Libraries/StableDiffusion/README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion 2 | 3 | Stable Diffusion in MLX. The implementation was ported from Hugging Face's 4 | [diffusers](https://huggingface.co/docs/diffusers/index) and 5 | [mlx-examples/stable_diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion). 6 | Model weights are downloaded directly from the Hugging Face hub. The implementation currently 7 | supports the following models: 8 | 9 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo) 10 | - [stabilitiai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) 11 | 12 | ## Usage 13 | 14 | See [StableDiffusionExample](../../Applications/StableDiffusionExample) and 15 | [image-tool](../../Tools/image-tool) for examples of using this code. 16 | 17 | The basic sequence is: 18 | 19 | - download & load the model 20 | - generate latents 21 | - evaluate the latents one by one 22 | - decode the last latent generated 23 | - you have an image! 
24 | 25 | ```swift 26 | let configuration = StableDiffusionConfiguration.presetSDXLTurbo 27 | 28 | let generator = try configuration.textToImageGenerator( 29 | configuration: model.loadConfiguration) 30 | 31 | generator.ensureLoaded() 32 | 33 | // generate the latents -- these are the iterations for generating 34 | // the output image. this is just generating the evaluation graph 35 | let parameters = generate.evaluateParameters(configuration: configuration) 36 | let latents = generator.generateLatents(parameters: parameters) 37 | 38 | // evaluate the latents (evalue the graph) and keep the last value generated 39 | var lastXt: MLXArray? 40 | for xt in latents { 41 | eval(xt) 42 | lastXt = xt 43 | } 44 | 45 | // decode the final latent into an image 46 | if let lastXt { 47 | var raster = decoder(lastXt[0]) 48 | raster = (image * 255).asType(.uint8).squeezed() 49 | eval(raster) 50 | 51 | // turn it into a CGImage 52 | let image = Image(raster).asCGImage() 53 | 54 | // or write it out 55 | try Image(raster).save(url: url) 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /Libraries/MLXLLM/SuScaledRotaryEmbedding.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | import MLXFast 4 | import MLXNN 5 | 6 | public class SuScaledRotaryEmbedding: Module { 7 | let dimensions: Int 8 | let maxPositionEmbeddings: Int 9 | let originalMaxPositionEmbeddings: Int 10 | let scale: Float 11 | let _freqs: MLXArray 12 | 13 | public init( 14 | dimensions: Int, 15 | base: Float = 10000.0, 16 | maxPositionEmbeddings: Int = 131072, 17 | originalMaxPositionEmbeddings: Int = 4096, 18 | longFactor: [Float] = [1.0], 19 | // shortMScale: Float? = nil, 20 | longMScale: Float? 
= nil 21 | ) { 22 | precondition(dimensions % 2 == 0, "Dimensions must be even") 23 | 24 | self.dimensions = dimensions 25 | self.maxPositionEmbeddings = maxPositionEmbeddings 26 | self.originalMaxPositionEmbeddings = originalMaxPositionEmbeddings 27 | 28 | let exponent = 29 | MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / Float(dimensions) 30 | let freqs = MLX.pow(MLXArray(base), exponent) 31 | self._freqs = MLXArray(longFactor).asType(.float32) * freqs 32 | 33 | self.scale = 34 | longMScale 35 | ?? sqrt( 36 | 1 + log(Float(maxPositionEmbeddings) / Float(originalMaxPositionEmbeddings)) 37 | / log(Float(originalMaxPositionEmbeddings)) 38 | ) 39 | } 40 | 41 | public func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray { 42 | // Apply scaling only to the dimensions that will be rotated 43 | var scaledX = x 44 | let sliceToScale = scaledX[.ellipsis, 0 ..< dimensions] 45 | scaledX[.ellipsis, 0 ..< dimensions] = scale * sliceToScale 46 | 47 | return MLXFast.RoPE( 48 | scaledX, 49 | dimensions: dimensions, 50 | traditional: false, 51 | base: nil, 52 | scale: 1.0, 53 | offset: offset, 54 | freqs: self._freqs 55 | ) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /Libraries/MLXLLM/Lora+Data.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | 5 | enum LoRADataError: LocalizedError { 6 | case fileNotFound(URL, String) 7 | 8 | var errorDescription: String? { 9 | switch self { 10 | case .fileNotFound(let directory, let name): 11 | return String( 12 | localized: "Could not find data file '\(name)' in directory '\(directory.path())'.") 13 | } 14 | } 15 | } 16 | 17 | /// Load a LoRA data file. 18 | /// 19 | /// Given a directory and a base name, e.g. `train`, this will load a `.jsonl` or `.txt` file 20 | /// if possible. 
21 | public func loadLoRAData(directory: URL, name: String) throws -> [String] { 22 | let extensions = ["jsonl", "txt"] 23 | 24 | for ext in extensions { 25 | let url = directory.appending(component: "\(name).\(ext)") 26 | if FileManager.default.fileExists(atPath: url.path()) { 27 | return try loadLoRAData(url: url) 28 | } 29 | } 30 | 31 | throw LoRADataError.fileNotFound(directory, name) 32 | } 33 | 34 | /// Load a .txt or .jsonl file and return the contents 35 | public func loadLoRAData(url: URL) throws -> [String] { 36 | switch url.pathExtension { 37 | case "jsonl": 38 | return try loadJSONL(url: url) 39 | 40 | case "txt": 41 | return try loadLines(url: url) 42 | 43 | default: 44 | fatalError("Unable to load data file, unknown type: \(url)") 45 | 46 | } 47 | } 48 | 49 | func loadJSONL(url: URL) throws -> [String] { 50 | 51 | struct Line: Codable { 52 | let text: String? 53 | } 54 | 55 | return try String(contentsOf: url) 56 | .components(separatedBy: .newlines) 57 | .filter { 58 | $0.first == "{" 59 | } 60 | .compactMap { 61 | try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text 62 | } 63 | } 64 | 65 | func loadLines(url: URL) throws -> [String] { 66 | try String(contentsOf: url) 67 | .components(separatedBy: .newlines) 68 | .filter { !$0.isEmpty } 69 | } 70 | -------------------------------------------------------------------------------- /Libraries/Embedders/Tokenizer.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import Foundation 4 | import Hub 5 | import Tokenizers 6 | 7 | public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer 8 | { 9 | let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( 10 | configuration: configuration, hub: hub) 11 | 12 | return try PreTrainedTokenizer( 13 | tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) 14 | } 15 | 16 | func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> ( 17 | Config, Config 18 | ) { 19 | // from AutoTokenizer.from() -- this lets us override parts of the configuration 20 | let config: LanguageModelConfigurationFromHub 21 | 22 | switch configuration.id { 23 | case .id(let id): 24 | do { 25 | // the load can fail (async when we try to use it) 26 | let loaded = LanguageModelConfigurationFromHub( 27 | modelName: configuration.tokenizerId ?? id, hubApi: hub) 28 | _ = try await loaded.tokenizerConfig 29 | config = loaded 30 | } catch { 31 | let nserror = error as NSError 32 | if nserror.domain == NSURLErrorDomain 33 | && nserror.code == NSURLErrorNotConnectedToInternet 34 | { 35 | // Internet connection appears to be offline -- fall back to loading from 36 | // the local directory 37 | config = LanguageModelConfigurationFromHub( 38 | modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub) 39 | } else { 40 | throw error 41 | } 42 | } 43 | case .directory(let directory): 44 | config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub) 45 | } 46 | 47 | guard let tokenizerConfig = try await config.tokenizerConfig else { 48 | throw EmbedderError(message: "missing config") 49 | } 50 | let tokenizerData = try await config.tokenizerData 51 | return (tokenizerConfig, tokenizerData) 52 | } 53 | -------------------------------------------------------------------------------- /Applications/LLMEval/README.md: -------------------------------------------------------------------------------- 1 | # 
LLMEval 2 | 3 | An example that: 4 | 5 | - downloads a huggingface model (phi-2) and tokenizer 6 | - evaluates a prompt 7 | - displays the output as it generates text 8 | 9 | You will need to set the Team on the LLMEval target in order to build and run on iOS. 10 | 11 | Some notes about the setup: 12 | 13 | - this downloads models from hugging face so LLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox 14 | - LLM models are large so this uses the Increased Memory Limit entitlement on iOS to allow ... increased memory limits for devices that have more memory 15 | - `MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)` is used to limit the buffer cache size 16 | - The Phi2 4 bit model is small enough to run on some iPhone models 17 | - this can be changed by editing `let modelConfiguration = ModelConfiguration.phi4bit` 18 | 19 | ### Trying Different Models 20 | 21 | The example application uses Phi2 model by default, see [ContentView.swift](ContentView.swift#L58): 22 | 23 | ``` 24 | /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on 25 | /// more devices 26 | let modelConfiguration = ModelConfiguration.phi4bit 27 | ``` 28 | 29 | There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78) 30 | and you can load any weights from Hugging Face where there 31 | is a model architecture defined and you have enough 32 | memory. 33 | 34 | ### Troubleshooting 35 | 36 | If the program crashes with a very deep stack trace you may need to build 37 | in Release configuration. This seems to depend on the size of the model. 38 | 39 | There are a couple options: 40 | 41 | - build Release 42 | - force the model evaluation to run on the main thread, e.g. 
using @MainActor 43 | - build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),` around line 87 44 | 45 | See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3 46 | 47 | ### Performance 48 | 49 | Different models have difference performance characteristics. For example Gemma 2B may outperform Phi-2 in terms of tokens / second. 50 | 51 | You may also find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable". 52 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "gzipswift", 5 | "kind" : "remoteSourceControl", 6 | "location" : "https://github.com/1024jp/GzipSwift", 7 | "state" : { 8 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05", 9 | "version" : "6.0.1" 10 | } 11 | }, 12 | { 13 | "identity" : "jinja", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/johnmai-dev/Jinja", 16 | "state" : { 17 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73", 18 | "version" : "1.1.1" 19 | } 20 | }, 21 | { 22 | "identity" : "mlx-swift", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/ml-explore/mlx-swift", 25 | "state" : { 26 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", 27 | "version" : "0.21.2" 28 | } 29 | }, 30 | { 31 | "identity" : "swift-argument-parser", 32 | "kind" : "remoteSourceControl", 33 | "location" : "https://github.com/apple/swift-argument-parser.git", 34 | "state" : { 35 | "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b", 36 | "version" : "1.4.0" 37 | } 38 | }, 39 | { 40 | "identity" : "swift-async-algorithms", 41 | "kind" : "remoteSourceControl", 42 | "location" : "https://github.com/apple/swift-async-algorithms", 43 | "state" : { 44 | "revision" : 
"da4e36f86544cdf733a40d59b3a2267e3a7bbf36", 45 | "version" : "1.0.0" 46 | } 47 | }, 48 | { 49 | "identity" : "swift-collections", 50 | "kind" : "remoteSourceControl", 51 | "location" : "https://github.com/apple/swift-collections.git", 52 | "state" : { 53 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", 54 | "version" : "1.1.4" 55 | } 56 | }, 57 | { 58 | "identity" : "swift-numerics", 59 | "kind" : "remoteSourceControl", 60 | "location" : "https://github.com/apple/swift-numerics", 61 | "state" : { 62 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 63 | "version" : "1.0.2" 64 | } 65 | }, 66 | { 67 | "identity" : "swift-transformers", 68 | "kind" : "remoteSourceControl", 69 | "location" : "https://github.com/huggingface/swift-transformers", 70 | "state" : { 71 | "revision" : "be855fac725dbae27264e47a3eb535cc422a4ba8", 72 | "version" : "0.1.18" 73 | } 74 | } 75 | ], 76 | "version" : 2 77 | } 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | # 3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 4 | 5 | ## User settings 6 | xcuserdata/ 7 | 8 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) 9 | *.xcscmblueprint 10 | *.xccheckout 11 | 12 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) 13 | build/ 14 | DerivedData/ 15 | *.moved-aside 16 | *.pbxuser 17 | !default.pbxuser 18 | *.mode1v3 19 | !default.mode1v3 20 | *.mode2v3 21 | !default.mode2v3 22 | *.perspectivev3 23 | !default.perspectivev3 24 | 25 | ## Obj-C/Swift specific 26 | *.hmap 27 | 28 | ## App packaging 29 | *.ipa 30 | *.dSYM.zip 31 | *.dSYM 32 | 33 | ## Playgrounds 34 | timeline.xctimeline 35 | playground.xcworkspace 36 | 37 | # Swift Package Manager 38 | # 39 | # Add this line if you want to avoid checking in source 
code from Swift Package Manager dependencies. 40 | Packages/ 41 | Package.pins 42 | Package.resolved 43 | # *.xcodeproj 44 | # 45 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 46 | # hence it is not needed unless you have added a package configuration file to your project 47 | .swiftpm 48 | 49 | .build/ 50 | 51 | # CocoaPods 52 | # 53 | # We recommend against adding the Pods directory to your .gitignore. However 54 | # you should judge for yourself, the pros and cons are mentioned at: 55 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 56 | # 57 | # Pods/ 58 | # 59 | # Add this line if you want to avoid checking in source code from the Xcode workspace 60 | # *.xcworkspace 61 | 62 | # Carthage 63 | # 64 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 65 | # Carthage/Checkouts 66 | 67 | Carthage/Build/ 68 | 69 | # Accio dependency management 70 | Dependencies/ 71 | .accio/ 72 | 73 | # fastlane 74 | # 75 | # It is recommended to not store the screenshots in the git repo. 76 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 77 | # For more information about the recommended setup visit: 78 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 79 | 80 | fastlane/report.xml 81 | fastlane/Preview.html 82 | fastlane/screenshots/**/*.png 83 | fastlane/test_output 84 | 85 | # Code Injection 86 | # 87 | # After new code Injection tools there's a generated folder /iOSInjectionProject 88 | # https://github.com/johnno1962/injectionforxcode 89 | 90 | iOSInjectionProject/ 91 | 92 | # OS 93 | .DS_Store 94 | 95 | .idea -------------------------------------------------------------------------------- /Libraries/MLXMNIST/MNIST.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import Foundation 4 | import MLX 5 | import MLXNN 6 | 7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py 8 | 9 | public class LeNet: Module, UnaryLayer { 10 | 11 | @ModuleInfo var conv1: Conv2d 12 | @ModuleInfo var conv2: Conv2d 13 | @ModuleInfo var pool1: MaxPool2d 14 | @ModuleInfo var pool2: MaxPool2d 15 | @ModuleInfo var fc1: Linear 16 | @ModuleInfo var fc2: Linear 17 | @ModuleInfo var fc3: Linear 18 | 19 | override public init() { 20 | conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2) 21 | conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0) 22 | pool1 = MaxPool2d(kernelSize: 2, stride: 2) 23 | pool2 = MaxPool2d(kernelSize: 2, stride: 2) 24 | fc1 = Linear(16 * 5 * 5, 120) 25 | fc2 = Linear(120, 84) 26 | fc3 = Linear(84, 10) 27 | } 28 | 29 | public func callAsFunction(_ x: MLXArray) -> MLXArray { 30 | var x = x 31 | x = pool1(tanh(conv1(x))) 32 | x = pool2(tanh(conv2(x))) 33 | x = flattened(x, start: 1) 34 | x = tanh(fc1(x)) 35 | x = tanh(fc2(x)) 36 | x = fc3(x) 37 | return x 38 | } 39 | } 40 | 41 | public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray { 42 | crossEntropy(logits: model(x), targets: y, reduction: .mean) 43 | } 44 | 45 | public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray { 46 | mean(argMax(model(x), axis: 1) .== y) 47 | } 48 | 49 | private struct BatchSequence: Sequence, IteratorProtocol { 50 | 51 | let batchSize: Int 52 | let x: MLXArray 53 | let y: MLXArray 54 | 55 | let indexes: MLXArray 56 | var index = 0 57 | 58 | init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator) 59 | { 60 | self.batchSize = batchSize 61 | self.x = x 62 | self.y = y 63 | self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator)) 64 | } 65 | 66 | mutating func next() -> (MLXArray, MLXArray)? 
{ 67 | guard index < y.size else { return nil } 68 | 69 | let range = index ..< Swift.min(index + batchSize, y.size) 70 | index += batchSize 71 | let ids = indexes[range] 72 | return (x[ids], y[ids]) 73 | } 74 | } 75 | 76 | public func iterateBatches( 77 | batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator 78 | ) -> some Sequence<(MLXArray, MLXArray)> { 79 | BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator) 80 | } 81 | -------------------------------------------------------------------------------- /Libraries/MLXLLM/Documentation.docc/adding-model.md: -------------------------------------------------------------------------------- 1 | # Adding a Model 2 | 3 | If the model follows the typical LLM pattern you can add a new 4 | model in a few steps. 5 | 6 | - `config.json`, `tokenizer.json`, and `tokenizer_config.json` 7 | - `*.safetensors` 8 | 9 | You can follow the pattern of the models in the [Models](Models) directory 10 | and create a `.swift` file for your new model: 11 | 12 | ## Create a Configuration 13 | 14 | Create a configuration struct to match the `config.json` (any parameters needed). 15 | 16 | ```swift 17 | public struct YourModelConfiguration: Codable, Sendable { 18 | public let hiddenSize: Int 19 | 20 | // use this pattern for values that need defaults 21 | public let _layerNormEps: Float? 22 | public var layerNormEps: Float { _layerNormEps ?? 1e-6 } 23 | 24 | enum CodingKeys: String, CodingKey { 25 | case hiddenSize = "hidden_size" 26 | case _layerNormEps = "layer_norm_eps" 27 | } 28 | } 29 | ``` 30 | 31 | ## Create the Model Class 32 | 33 | Create the model class. 
The top-level public class should have a 34 | structure something like this: 35 | 36 | ```swift 37 | public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel { 38 | 39 | public let kvHeads: [Int] 40 | 41 | @ModuleInfo var model: YourModelInner 42 | 43 | public func loraLinearLayers() -> LoRALinearLayers { 44 | // TODO: modify as needed 45 | model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } 46 | } 47 | 48 | public init(_ args: YourModelConfiguration) { 49 | self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) 50 | self.model = YourModelInner(args) 51 | } 52 | 53 | public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { 54 | // TODO: modify as needed 55 | let out = model(inputs, cache: cache) 56 | return model.embedTokens.asLinear(out) 57 | } 58 | } 59 | ``` 60 | 61 | ## Register the Model 62 | 63 | In [LLMModelFactory.swift](LLMModelFactory.swift) register the model type itself 64 | (this is independent of the model id): 65 | 66 | ```swift 67 | public class ModelTypeRegistry: @unchecked Sendable { 68 | ... 69 | private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ 70 | "yourModel": create(YourModelConfiguration.self, YourModel.init), 71 | ``` 72 | 73 | Add a constant for the model in the `ModelRegistry` (not strictly required but useful 74 | for callers to refer to it in code): 75 | 76 | ```swift 77 | public class ModelRegistry: @unchecked Sendable { 78 | ... 79 | static public let yourModel_4bit = ModelConfiguration( 80 | id: "mlx-community/YourModel-4bit", 81 | defaultPrompt: "What is the gravity on Mars and the moon?" 82 | ) 83 | ``` 84 | 85 | and finally add it to the all list -- this will let users find the model 86 | configuration by id: 87 | 88 | ```swift 89 | private static func all() -> [ModelConfiguration] { 90 | [ 91 | codeLlama13b4bit, 92 | ... 
93 | yourModel_4bit, 94 | ``` 95 | -------------------------------------------------------------------------------- /Tools/image-tool/Arguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLX 6 | 7 | #if swift(>=5.10) 8 | /// Extension to allow URL command line arguments. 9 | extension URL: @retroactive ExpressibleByArgument { 10 | public init?(argument: String) { 11 | if argument.contains("://") { 12 | self.init(string: argument) 13 | } else { 14 | self.init(filePath: argument) 15 | } 16 | } 17 | } 18 | #else 19 | /// Extension to allow URL command line arguments. 20 | extension URL: ExpressibleByArgument { 21 | public init?(argument: String) { 22 | if argument.contains("://") { 23 | self.init(string: argument) 24 | } else { 25 | self.init(filePath: argument) 26 | } 27 | } 28 | } 29 | #endif 30 | 31 | /// Argument package for adjusting and reporting memory use. 32 | struct MemoryArguments: ParsableArguments, Sendable { 33 | 34 | @Flag(name: .long, help: "Show memory stats") 35 | var memoryStats = false 36 | 37 | @Option(name: .long, help: "Maximum cache size in M") 38 | var cacheSize = 1024 39 | 40 | @Option(name: .long, help: "Maximum memory size in M") 41 | var memorySize: Int? 42 | 43 | var startMemory: GPU.Snapshot? 
44 | 45 | mutating func start<L>(_ load: () async throws -> L) async throws -> L { 46 | GPU.set(cacheLimit: cacheSize * 1024 * 1024) 47 | 48 | if let memorySize { 49 | GPU.set(memoryLimit: memorySize * 1024 * 1024) 50 | } 51 | 52 | let result = try await load() 53 | startMemory = GPU.snapshot() 54 | 55 | return result 56 | } 57 | 58 | mutating func start() { 59 | GPU.set(cacheLimit: cacheSize * 1024 * 1024) 60 | 61 | if let memorySize { 62 | GPU.set(memoryLimit: memorySize * 1024 * 1024) 63 | } 64 | 65 | startMemory = GPU.snapshot() 66 | } 67 | 68 | func reportCurrent() { 69 | if memoryStats { 70 | let memory = GPU.snapshot() 71 | print(memory.description) 72 | } 73 | } 74 | 75 | func reportMemoryStatistics() { 76 | if memoryStats, let startMemory { 77 | let endMemory = GPU.snapshot() 78 | 79 | print("=======") 80 | print("Memory size: \(GPU.memoryLimit / 1024)K") 81 | print("Cache size: \(GPU.cacheLimit / 1024)K") 82 | 83 | print("") 84 | print("=======") 85 | print("Starting memory") 86 | print(startMemory.description) 87 | 88 | print("") 89 | print("=======") 90 | print("Ending memory") 91 | print(endMemory.description) 92 | 93 | print("") 94 | print("=======") 95 | print("Growth") 96 | print(startMemory.delta(endMemory).description) 97 | 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/ModelConfiguration.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import Hub 5 | 6 | /// Configuration for a given model name with overrides for prompts and tokens. 7 | /// 8 | /// See e.g. `MLXLM.ModelRegistry` for an example of use.
9 | public struct ModelConfiguration: Sendable { 10 | 11 | public enum Identifier: Sendable { 12 | case id(String) 13 | case directory(URL) 14 | } 15 | 16 | public var id: Identifier 17 | 18 | public var name: String { 19 | switch id { 20 | case .id(let string): 21 | string 22 | case .directory(let url): 23 | url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent 24 | } 25 | } 26 | 27 | /// pull the tokenizer from an alternate id 28 | public let tokenizerId: String? 29 | 30 | /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated 31 | public let overrideTokenizer: String? 32 | 33 | /// A reasonable default prompt for the model 34 | public var defaultPrompt: String 35 | 36 | /// Additional tokens to use for end of string 37 | public var extraEOSTokens: Set<String> 38 | 39 | public init( 40 | id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil, 41 | defaultPrompt: String = "hello", 42 | extraEOSTokens: Set<String> = [], 43 | preparePrompt: (@Sendable (String) -> String)? = nil 44 | ) { 45 | self.id = .id(id) 46 | self.tokenizerId = tokenizerId 47 | self.overrideTokenizer = overrideTokenizer 48 | self.defaultPrompt = defaultPrompt 49 | self.extraEOSTokens = extraEOSTokens 50 | } 51 | 52 | public init( 53 | directory: URL, tokenizerId: String? = nil, overrideTokenizer: String?
= nil, 54 | defaultPrompt: String = "hello", 55 | extraEOSTokens: Set<String> = [] 56 | ) { 57 | self.id = .directory(directory) 58 | self.tokenizerId = tokenizerId 59 | self.overrideTokenizer = overrideTokenizer 60 | self.defaultPrompt = defaultPrompt 61 | self.extraEOSTokens = extraEOSTokens 62 | } 63 | 64 | public func modelDirectory(hub: HubApi = HubApi()) -> URL { 65 | switch id { 66 | case .id(let id): 67 | // download the model weights and config 68 | let repo = Hub.Repo(id: id) 69 | return hub.localRepoLocation(repo) 70 | 71 | case .directory(let directory): 72 | return directory 73 | } 74 | } 75 | } 76 | 77 | extension ModelConfiguration: Equatable { 78 | 79 | } 80 | 81 | extension ModelConfiguration.Identifier: Equatable { 82 | 83 | public static func == (lhs: ModelConfiguration.Identifier, rhs: ModelConfiguration.Identifier) 84 | -> Bool 85 | { 86 | switch (lhs, rhs) { 87 | case (.id(let lhsID), .id(let rhsID)): 88 | lhsID == rhsID 89 | case (.directory(let lhsURL), .directory(let rhsURL)): 90 | lhsURL == rhsURL 91 | default: 92 | false 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/ModelContainer.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import Hub 5 | import MLX 6 | import MLXNN 7 | import Tokenizers 8 | 9 | /// Container for models that guarantees single threaded access. 10 | /// 11 | /// Wrap models used by e.g. the UI in a ModelContainer.
Callers can access 12 | /// the model and/or tokenizer (any values from the ``ModelContext``): 13 | /// 14 | /// ```swift 15 | /// let messages = [["role": "user", "content": prompt]] 16 | /// let promptTokens = try await modelContainer.perform { context in 17 | /// try context.tokenizer.applyChatTemplate(messages: messages) 18 | /// } 19 | /// ``` 20 | /// 21 | /// or: 22 | /// 23 | /// ```swift 24 | /// let userInput: UserInput 25 | /// let result = await modelContainer.perform { context in 26 | /// let input = try await context.processor.prepare(input: userInput) 27 | /// return generate( 28 | /// input: input, parameters: generateParameters, context: context 29 | /// ) { tokens in 30 | /// ... 31 | /// } 32 | /// } 33 | /// ``` 34 | public actor ModelContainer { 35 | var context: ModelContext 36 | public var configuration: ModelConfiguration { context.configuration } 37 | 38 | public init(context: ModelContext) { 39 | self.context = context 40 | } 41 | 42 | /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as 43 | /// `MLXArray` is not `Sendable`. 44 | @available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext") 45 | public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows 46 | -> R 47 | { 48 | try action(context.model, context.tokenizer) 49 | } 50 | 51 | /// Perform an action on the model and/or tokenizer with additional context values. 52 | /// Callers _must_ eval any `MLXArray` before returning as 53 | /// `MLXArray` is not `Sendable`. 54 | @available(*, deprecated, message: "prefer perform(values:_:) that uses a ModelContext") 55 | public func perform<V, R>( 56 | values: V, _ action: @Sendable (any LanguageModel, Tokenizer, V) throws -> R 57 | ) rethrows -> R { 58 | try action(context.model, context.tokenizer, values) 59 | } 60 | 61 | /// Perform an action on the ``ModelContext``.
Callers _must_ eval any `MLXArray` before returning as 62 | /// `MLXArray` is not `Sendable`. 63 | public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R 64 | { 65 | try await action(context) 66 | } 67 | 68 | /// Perform an action on the ``ModelContext`` with additional context values. 69 | /// Callers _must_ eval any `MLXArray` before returning as 70 | /// `MLXArray` is not `Sendable`. 71 | public func perform<V, R>( 72 | values: V, _ action: @Sendable (ModelContext, V) async throws -> R 73 | ) async rethrows -> R { 74 | try await action(context, values) 75 | } 76 | 77 | /// Update the owned `ModelContext`. 78 | /// - Parameter action: update action 79 | public func update(_ action: @Sendable (inout ModelContext) -> Void) { 80 | action(&context) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/LLMEval.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 8 | 9 | 15 | 21 | 22 | 23 | 24 | 25 | 31 | 32 | 42 | 44 | 50 | 51 | 52 | 53 | 59 | 61 | 67 | 68 | 69 | 70 | 72 | 73 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/VLMEval.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 43 | 45 | 51 | 52 | 53 | 54 | 60 | 62 | 68 | 69 | 70 | 71 | 73 | 74 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /Libraries/MLXMNIST/Files.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc.
2 | 3 | import Foundation 4 | import Gzip 5 | import MLX 6 | 7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py 8 | 9 | public enum Use: String, Hashable, Sendable { 10 | case test 11 | case training 12 | } 13 | 14 | public enum DataKind: String, Hashable, Sendable { 15 | case images 16 | case labels 17 | } 18 | 19 | public struct FileKind: Hashable, CustomStringConvertible, Sendable { 20 | let use: Use 21 | let data: DataKind 22 | 23 | public init(_ use: Use, _ data: DataKind) { 24 | self.use = use 25 | self.data = data 26 | } 27 | 28 | public var description: String { 29 | "\(use.rawValue)-\(data.rawValue)" 30 | } 31 | } 32 | 33 | struct LoadInfo: Sendable { 34 | let name: String 35 | let offset: Int 36 | let convert: @Sendable (MLXArray) -> MLXArray 37 | } 38 | 39 | let baseURL = URL(string: "https://raw.githubusercontent.com/fgnt/mnist/master/")! 40 | 41 | private let files = [ 42 | FileKind(.training, .images): LoadInfo( 43 | name: "train-images-idx3-ubyte.gz", 44 | offset: 16, 45 | convert: { 46 | $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0 47 | }), 48 | FileKind(.test, .images): LoadInfo( 49 | name: "t10k-images-idx3-ubyte.gz", 50 | offset: 16, 51 | convert: { 52 | $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0 53 | }), 54 | FileKind(.training, .labels): LoadInfo( 55 | name: "train-labels-idx1-ubyte.gz", 56 | offset: 8, 57 | convert: { 58 | $0.asType(.uint32) 59 | }), 60 | FileKind(.test, .labels): LoadInfo( 61 | name: "t10k-labels-idx1-ubyte.gz", 62 | offset: 8, 63 | convert: { 64 | $0.asType(.uint32) 65 | }), 66 | ] 67 | 68 | public func download(into: URL) async throws { 69 | for (_, info) in files { 70 | let fileURL = into.appending(component: info.name) 71 | if !FileManager.default.fileExists(atPath: fileURL.path()) { 72 | print("Download: \(info.name)") 73 | let url = baseURL.appending(component: info.name) 74 | let (data, response) = try await URLSession.shared.data(from: url) 75 | 76 | guard let 
httpResponse = response as? HTTPURLResponse else { 77 | fatalError("Unable to download \(url), not an http response: \(response)") 78 | } 79 | guard httpResponse.statusCode == 200 else { 80 | fatalError("Unable to download \(url): \(httpResponse)") 81 | } 82 | 83 | try data.write(to: fileURL) 84 | } 85 | } 86 | } 87 | 88 | public func load(from: URL) throws -> [FileKind: MLXArray] { 89 | var result = [FileKind: MLXArray]() 90 | 91 | for (key, info) in files { 92 | let fileURL = from.appending(component: info.name) 93 | let data = try Data(contentsOf: fileURL).gunzipped() 94 | 95 | let array = MLXArray( 96 | data.dropFirst(info.offset), [data.count - info.offset], type: UInt8.self) 97 | 98 | result[key] = info.convert(array) 99 | } 100 | 101 | return result 102 | } 103 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/StableDiffusionExample.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 43 | 45 | 51 | 52 | 53 | 54 | 60 | 62 | 68 | 69 | 70 | 71 | 73 | 74 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /Tools/Tutorial/Tutorial.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import Foundation 4 | import MLX 5 | 6 | /// mlx-swift tutorial based on: 7 | /// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp 8 | @main 9 | struct Tutorial { 10 | 11 | static func scalarBasics() { 12 | // create a scalar array 13 | let x = MLXArray(1.0) 14 | 15 | // the datatype is .float32 16 | let dtype = x.dtype 17 | assert(dtype == .float32) 18 | 19 | // get the value 20 | let s = x.item(Float.self) 21 | assert(s == 1.0) 22 | 23 | // reading the value with a different type is a fatal error 24 | // let i = x.item(Int.self) 25 | 26 | // scalars have a size of 1 27 | let size = x.size 28 | assert(size == 1) 29 | 30 | // scalars have 0 dimensions 31 | let ndim = x.ndim 32 | assert(ndim == 0) 33 | 34 | // scalar shapes are empty arrays 35 | let shape = x.shape 36 | assert(shape == []) 37 | } 38 | 39 | static func arrayBasics() { 40 | // make a multidimensional array. 41 | // 42 | // Note: the argument is a [Double] array literal, which is not 43 | // a supported type, but we can explicitly convert it to [Float] 44 | // when we create the MLXArray. 45 | let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2]) 46 | 47 | // mlx is row-major by default so the first row of this array 48 | // is [1.0, 2.0] and the second row is [3.0, 4.0] 49 | print(x[0]) 50 | print(x[1]) 51 | 52 | // make an array of shape [2, 2] filled with ones 53 | let y = MLXArray.ones([2, 2]) 54 | 55 | // pointwise add x and y 56 | let z = x + y 57 | 58 | // mlx is lazy by default. At this point `z` only 59 | // has a shape and a type but no actual data 60 | assert(z.dtype == .float32) 61 | assert(z.shape == [2, 2]) 62 | 63 | // To actually run the computation you must evaluate `z`. 64 | // Under the hood, mlx records operations in a graph. 65 | // The variable `z` is a node in the graph which points to its operation 66 | // and inputs. 
When `eval` is called on an array (or arrays), the array and 67 | // all of its dependencies are recursively evaluated to produce the result. 68 | // Once an array is evaluated, it has data and is detached from its inputs. 69 | 70 | // Note: this is being called for demonstration purposes -- all reads 71 | // ensure the array is evaluated. 72 | z.eval() 73 | 74 | // this implicitly evaluates z before converting to a description 75 | print(z) 76 | } 77 | 78 | static func automaticDifferentiation() { 79 | func fn(_ x: MLXArray) -> MLXArray { 80 | x.square() 81 | } 82 | 83 | let gradFn = grad(fn) 84 | 85 | let x = MLXArray(1.5) 86 | let dfdx = gradFn(x) 87 | print(dfdx) 88 | 89 | assert(dfdx.item() == Float(2 * 1.5)) 90 | 91 | let df2dx2 = grad(grad(fn))(x) 92 | print(df2dx2) 93 | 94 | assert(df2dx2.item() == Float(2)) 95 | } 96 | 97 | static func main() { 98 | scalarBasics() 99 | arrayBasics() 100 | automaticDifferentiation() 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/StringOrNumber.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | 5 | /// Representation of a heterogenous type in a JSON configuration file. 6 | /// 7 | /// This can be: a string, a numeric value or an array of numeric values. 8 | /// There are methods to do unwrapping, see e.g. ``asFloat()`` and 9 | /// ``asFloats()`` or callers can switch on the enum. 10 | public enum StringOrNumber: Codable, Equatable, Sendable { 11 | case string(String) 12 | case int(Int) 13 | case float(Float) 14 | case ints([Int]) 15 | case floats([Float]) 16 | 17 | public init(from decoder: Decoder) throws { 18 | let values = try decoder.singleValueContainer() 19 | 20 | if let v = try? values.decode(Int.self) { 21 | self = .int(v) 22 | } else if let v = try? values.decode(Float.self) { 23 | self = .float(v) 24 | } else if let v = try? 
values.decode([Int].self) { 25 | self = .ints(v) 26 | } else if let v = try? values.decode([Float].self) { 27 | self = .floats(v) 28 | } else { 29 | let v = try values.decode(String.self) 30 | self = .string(v) 31 | } 32 | } 33 | 34 | public func encode(to encoder: Encoder) throws { 35 | var container = encoder.singleValueContainer() 36 | switch self { 37 | case .string(let v): try container.encode(v) 38 | case .int(let v): try container.encode(v) 39 | case .float(let v): try container.encode(v) 40 | case .ints(let v): try container.encode(v) 41 | case .floats(let v): try container.encode(v) 42 | } 43 | } 44 | 45 | /// Return the value as an optional array of integers. 46 | /// 47 | /// This will not coerce `Float` or `String` to `Int`. 48 | public func asInts() -> [Int]? { 49 | switch self { 50 | case .string(let string): nil 51 | case .int(let v): [v] 52 | case .float(let float): nil 53 | case .ints(let array): array 54 | case .floats(let array): nil 55 | } 56 | } 57 | 58 | /// Return the value as an optional integer. 59 | /// 60 | /// This will not coerce `Float` or `String` to `Int`. 61 | public func asInt() -> Int? { 62 | switch self { 63 | case .string(let string): nil 64 | case .int(let v): v 65 | case .float(let float): nil 66 | case .ints(let array): array.count == 1 ? array[0] : nil 67 | case .floats(let array): nil 68 | } 69 | } 70 | 71 | /// Return the value as an optional array of floats. 72 | /// 73 | /// This will not coerce `Int` or `String` to `Float`. 74 | public func asFloats() -> [Float]? { 75 | switch self { 76 | case .string(let string): nil 77 | case .int(let v): [Float(v)] 78 | case .float(let float): [float] 79 | case .ints(let array): array.map { Float($0) } 80 | case .floats(let array): array 81 | } 82 | } 83 | 84 | /// Return the value as an optional float. 85 | /// 86 | /// This will not coerce `Int` or `String` to `Float`. 87 | public func asFloat() -> Float? 
{ 88 | switch self { 89 | case .string(let string): nil 90 | case .int(let v): Float(v) 91 | case .float(let float): float 92 | case .ints(let array): array.count == 1 ? Float(array[0]) : nil 93 | case .floats(let array): array.count == 1 ? array[0] : nil 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Load.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import Hub 5 | import MLX 6 | import MLXNN 7 | import Tokenizers 8 | 9 | /// Download the model using the `HubApi`. 10 | /// 11 | /// This will download `*.safetensors` and `*.json` if the ``ModelConfiguration`` 12 | /// represents a Hub id, e.g. `mlx-community/gemma-2-2b-it-4bit`. 13 | /// 14 | /// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)`` 15 | /// 16 | /// - Parameters: 17 | /// - hub: HubApi instance 18 | /// - configuration: the model identifier 19 | /// - progressHandler: callback for progress 20 | /// - Returns: URL for the directory containing downloaded files 21 | public func downloadModel( 22 | hub: HubApi, configuration: ModelConfiguration, 23 | progressHandler: @Sendable @escaping (Progress) -> Void 24 | ) async throws -> URL { 25 | do { 26 | switch configuration.id { 27 | case .id(let id): 28 | // download the model weights 29 | let repo = Hub.Repo(id: id) 30 | let modelFiles = ["*.safetensors", "*.json"] 31 | return try await hub.snapshot( 32 | from: repo, matching: modelFiles, progressHandler: progressHandler) 33 | 34 | case .directory(let directory): 35 | return directory 36 | } 37 | 38 | } catch Hub.HubClientError.authorizationRequired { 39 | // an authorizationRequired means (typically) that the named repo doesn't exist on 40 | // on the server so retry with local only configuration 41 | return configuration.modelDirectory(hub: hub) 42 | 43 | } catch { 44 | let nserror = 
error as NSError 45 | if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet { 46 | // Error Domain=NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline." 47 | // fall back to the local directory 48 | return configuration.modelDirectory(hub: hub) 49 | } else { 50 | throw error 51 | } 52 | } 53 | } 54 | 55 | /// Load model weights. 56 | /// 57 | /// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)``. 58 | /// This function loads all `safetensor` files in the given `modelDirectory`, 59 | /// calls ``LanguageModel/sanitize(weights:)``, applies optional quantization, and 60 | /// updates the model with the weights. 61 | public func loadWeights( 62 | modelDirectory: URL, model: LanguageModel, quantization: BaseConfiguration.Quantization? = nil 63 | ) throws { 64 | // load the weights 65 | var weights = [String: MLXArray]() 66 | let enumerator = FileManager.default.enumerator( 67 | at: modelDirectory, includingPropertiesForKeys: nil)! 68 | for case let url as URL in enumerator { 69 | if url.pathExtension == "safetensors" { 70 | let w = try loadArrays(url: url) 71 | for (key, value) in w { 72 | weights[key] = value 73 | } 74 | } 75 | } 76 | 77 | // per-model cleanup 78 | weights = model.sanitize(weights: weights) 79 | 80 | // quantize if needed 81 | if let quantization { 82 | quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) { 83 | path, module in 84 | weights["\(path).scales"] != nil 85 | } 86 | } 87 | 88 | // apply the loaded weights 89 | let parameters = ModuleParameters.unflattened(weights) 90 | try model.update(parameters: parameters, verify: [.all]) 91 | 92 | eval(model) 93 | } 94 | -------------------------------------------------------------------------------- /Tools/mnist-tool/MNISTTool.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLX 6 | import MLXMNIST 7 | import MLXNN 8 | import MLXOptimizers 9 | import MLXRandom 10 | 11 | @main 12 | struct MNISTTool: AsyncParsableCommand { 13 | static let configuration = CommandConfiguration( 14 | abstract: "Command line tool for training mnist models", 15 | subcommands: [Train.self], 16 | defaultSubcommand: Train.self) 17 | } 18 | 19 | #if swift(>=5.10) 20 | extension MLX.DeviceType: @retroactive ExpressibleByArgument { 21 | public init?(argument: String) { 22 | self.init(rawValue: argument) 23 | } 24 | } 25 | #else 26 | extension MLX.DeviceType: ExpressibleByArgument { 27 | public init?(argument: String) { 28 | self.init(rawValue: argument) 29 | } 30 | } 31 | #endif 32 | 33 | struct Train: AsyncParsableCommand { 34 | 35 | @Option(name: .long, help: "Directory with the training data") 36 | var data: String 37 | 38 | @Option(name: .long, help: "The PRNG seed") 39 | var seed: UInt64 = 0 40 | 41 | @Option var batchSize = 256 42 | @Option var epochs = 20 43 | @Option var learningRate: Float = 1e-1 44 | 45 | @Option var device = DeviceType.gpu 46 | 47 | @Flag var compile = false 48 | 49 | func run() async throws { 50 | Device.setDefault(device: Device(device)) 51 | 52 | MLXRandom.seed(seed) 53 | var generator: RandomNumberGenerator = SplitMix64(seed: seed) 54 | 55 | // load the data 56 | let url = URL(filePath: data) 57 | 58 | try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true) 59 | try await download(into: url) 60 | 61 | let data = try load(from: url) 62 | 63 | let trainImages = data[.init(.training, .images)]! 64 | let trainLabels = data[.init(.training, .labels)]! 65 | let testImages = data[.init(.test, .images)]! 66 | let testLabels = data[.init(.test, .labels)]! 
67 | 68 | // create the model 69 | let model = LeNet() 70 | eval(model.parameters()) 71 | 72 | let lg = valueAndGrad(model: model, loss) 73 | let optimizer = SGD(learningRate: learningRate) 74 | 75 | func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { 76 | let (loss, grads) = lg(model, x, y) 77 | optimizer.update(model: model, gradients: grads) 78 | return loss 79 | } 80 | 81 | let resolvedStep = 82 | compile 83 | ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step 84 | 85 | for e in 0 ..< epochs { 86 | let start = Date.timeIntervalSinceReferenceDate 87 | 88 | for (x, y) in iterateBatches( 89 | batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator) 90 | { 91 | _ = resolvedStep(x, y) 92 | 93 | // eval the parameters so the next iteration is independent 94 | eval(model, optimizer) 95 | } 96 | 97 | let accuracy = eval(model: model, x: testImages, y: testLabels) 98 | 99 | let end = Date.timeIntervalSinceReferenceDate 100 | 101 | print( 102 | """ 103 | Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted()) 104 | Time: \((end - start).formatted()) 105 | 106 | """ 107 | ) 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c", 3 | "pins" : [ 4 | { 5 | "identity" : "gzipswift", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/1024jp/GzipSwift", 8 | "state" : { 9 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05", 10 | "version" : "6.0.1" 11 | } 12 | }, 13 | { 14 | "identity" : "jinja", 15 | "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/johnmai-dev/Jinja", 17 | "state" : { 18 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73", 19 | 
"version" : "1.1.1" 20 | } 21 | }, 22 | { 23 | "identity" : "mlx-swift", 24 | "kind" : "remoteSourceControl", 25 | "location" : "https://github.com/ml-explore/mlx-swift", 26 | "state" : { 27 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", 28 | "version" : "0.21.2" 29 | } 30 | }, 31 | { 32 | "identity" : "networkimage", 33 | "kind" : "remoteSourceControl", 34 | "location" : "https://github.com/gonzalezreal/NetworkImage", 35 | "state" : { 36 | "revision" : "2849f5323265386e200484b0d0f896e73c3411b9", 37 | "version" : "6.0.1" 38 | } 39 | }, 40 | { 41 | "identity" : "progress.swift", 42 | "kind" : "remoteSourceControl", 43 | "location" : "https://github.com/jkandzi/Progress.swift", 44 | "state" : { 45 | "revision" : "fed6598735d7982058690acf8f52a0a5fdaeb3e0", 46 | "version" : "0.4.0" 47 | } 48 | }, 49 | { 50 | "identity" : "swift-argument-parser", 51 | "kind" : "remoteSourceControl", 52 | "location" : "https://github.com/apple/swift-argument-parser.git", 53 | "state" : { 54 | "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b", 55 | "version" : "1.4.0" 56 | } 57 | }, 58 | { 59 | "identity" : "swift-cmark", 60 | "kind" : "remoteSourceControl", 61 | "location" : "https://github.com/swiftlang/swift-cmark", 62 | "state" : { 63 | "revision" : "3ccff77b2dc5b96b77db3da0d68d28068593fa53", 64 | "version" : "0.5.0" 65 | } 66 | }, 67 | { 68 | "identity" : "swift-collections", 69 | "kind" : "remoteSourceControl", 70 | "location" : "https://github.com/apple/swift-collections.git", 71 | "state" : { 72 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", 73 | "version" : "1.1.4" 74 | } 75 | }, 76 | { 77 | "identity" : "swift-markdown-ui", 78 | "kind" : "remoteSourceControl", 79 | "location" : "https://github.com/gonzalezreal/swift-markdown-ui", 80 | "state" : { 81 | "revision" : "5f613358148239d0292c0cef674a3c2314737f9e", 82 | "version" : "2.4.1" 83 | } 84 | }, 85 | { 86 | "identity" : "swift-numerics", 87 | "kind" : "remoteSourceControl", 88 | "location" : 
"https://github.com/apple/swift-numerics", 89 | "state" : { 90 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 91 | "version" : "1.0.2" 92 | } 93 | }, 94 | { 95 | "identity" : "swift-transformers", 96 | "kind" : "remoteSourceControl", 97 | "location" : "https://github.com/huggingface/swift-transformers", 98 | "state" : { 99 | "revision" : "55710ddfb1ae804b4b7ce973be75cf2e41272185", 100 | "version" : "0.1.17" 101 | } 102 | } 103 | ], 104 | "version" : 3 105 | } 106 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/KVCache.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import MLX 5 | 6 | /// Interface for Key/Value cache for LLMs. 7 | /// 8 | /// See ``LanguageModel/newCache(parameters:)`` 9 | public protocol KVCache: Evaluatable { 10 | 11 | /// get the current offset 12 | var offset: Int { get } 13 | 14 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) 15 | } 16 | 17 | func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray { 18 | let rinds = MLXArray(Int32(0) ..< Int32(offset + n)) 19 | let linds = offset != 0 ? MLXArray(Int32(offset) ..< Int32(offset + n)) : rinds 20 | let mask = linds[0..., .newAxis] .< rinds[.newAxis] 21 | return mask * Float32(-1e9) 22 | } 23 | 24 | /// create an attention mask using the parameters from the KVCache. 25 | /// 26 | /// See also ``MultiHeadAttention/createAdditiveCausalMask(_:dtype:)`` -- same idea 27 | /// but doesn't honor the cache offset. 28 | public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? 
{ 29 | let t = h.dim(1) 30 | if t > 1 { 31 | var offset = 0 32 | if let c = cache?.first { 33 | offset = c.offset 34 | } 35 | return createAdditiveCausalMask(n: t, offset: offset) 36 | .asType(h.dtype) 37 | } 38 | return nil 39 | } 40 | 41 | /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11 42 | public class KVCacheSimple: KVCache, Evaluatable { 43 | var keys: MLXArray? 44 | var values: MLXArray? 45 | 46 | public var offset = 0 47 | var step = 256 48 | 49 | public init() {} 50 | 51 | public func innerState() -> [MLXArray] { 52 | [self.keys, self.values].compactMap { $0 } 53 | } 54 | 55 | public func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { 56 | let previous = self.offset 57 | 58 | let reset = 59 | if let currentKeys = self.keys, (previous + keys.dim(2)) > currentKeys.dim(2) { 60 | true 61 | } else { 62 | self.keys == nil 63 | } 64 | if reset { 65 | let B = keys.dim(0) 66 | let kvHeads = keys.dim(1) 67 | let kHeadDim = keys.dim(3) 68 | let vHeadDim = values.dim(3) 69 | 70 | let nSteps = (step + keys.dim(2) - 1) / step 71 | let kShape = [B, kvHeads, nSteps * step, kHeadDim] 72 | let vShape = [B, kvHeads, nSteps * step, vHeadDim] 73 | let newK = MLXArray.zeros(kShape, dtype: keys.dtype) 74 | let newV = MLXArray.zeros(vShape, dtype: values.dtype) 75 | 76 | if var currentKeys = self.keys, var currentValues = self.values { 77 | if previous % step != 0 { 78 | currentKeys = currentKeys[.ellipsis, .. Pooling { 26 | let configurationURL = modelDirectory.appending(components: "1_Pooling", "config.json") 27 | guard 28 | let poolingConfig = try? 
JSONDecoder().decode( 29 | PoolingConfiguration.self, from: Data(contentsOf: configurationURL)) 30 | else { 31 | return Pooling(strategy: .none) 32 | } 33 | 34 | return Pooling(config: poolingConfig) 35 | } 36 | 37 | public class Pooling: Module { 38 | public enum Strategy { 39 | case mean 40 | case cls 41 | case first 42 | case last 43 | case max 44 | case none 45 | } 46 | let strategy: Strategy 47 | let dimension: Int? 48 | 49 | public init( 50 | strategy: Strategy, dimension: Int? = nil 51 | ) { 52 | self.strategy = strategy 53 | self.dimension = dimension 54 | } 55 | 56 | public init( 57 | config: PoolingConfiguration 58 | ) { 59 | dimension = config.dimension 60 | if config.poolingModeClsToken { 61 | strategy = .cls 62 | } else if config.poolingModeMeanTokens { 63 | strategy = .mean 64 | } else if config.poolingModeMaxTokens { 65 | strategy = .max 66 | } else if config.poolingModeLastToken { 67 | strategy = .last 68 | } else { 69 | strategy = .first 70 | } 71 | } 72 | 73 | public func callAsFunction( 74 | _ inputs: EmbeddingModelOutput, mask: MLXArray? = nil, normalize: Bool = false, 75 | applyLayerNorm: Bool = false 76 | ) -> MLXArray { 77 | let _mask = mask ?? MLXArray.ones(Array(inputs.hiddenStates?.shape[0 ..< 2] ?? [0])) 78 | 79 | var pooled: MLXArray 80 | switch self.strategy { 81 | case .mean: 82 | pooled = 83 | sum( 84 | inputs.hiddenStates! * _mask.expandedDimensions(axes: [-1]), 85 | axis: 1) 86 | / sum(_mask, axis: -1, keepDims: true) 87 | case .max: 88 | pooled = MLX.max( 89 | inputs.hiddenStates! * _mask.expandedDimensions(axes: [-1]), axis: 1) 90 | case .first: 91 | pooled = inputs.hiddenStates![0..., 0, 0...] 92 | case .last: 93 | pooled = inputs.hiddenStates![0..., -1, 0...] 94 | case .cls: 95 | pooled = 96 | inputs.pooledOutput 97 | ?? inputs.hiddenStates![0..., 0, 0...] 98 | case .none: 99 | pooled = inputs.pooledOutput ?? inputs.hiddenStates! 
100 | } 101 | if applyLayerNorm { 102 | pooled = MLXFast.layerNorm(pooled, eps: 1e-5) 103 | } 104 | if let dimension { 105 | pooled = pooled[0..., 0 ..< dimension] 106 | } 107 | if normalize { 108 | pooled = pooled / norm(pooled, axis: -1, keepDims: true) 109 | } 110 | return pooled 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | Developers can use these examples in their own programs -- just import the swift package! 4 | 5 | - [MLXLLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxlmcommon) -- common API for LLM and VLM 6 | - [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxllm) -- large language model example implementations 7 | - [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxvlm) -- visual language model example implementations 8 | - [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxembedders) -- popular Encoders / Embedding models example implementations 9 | - [StableDiffusion](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/stablediffusion) -- SDXL Turbo and Stable Diffusion mdeol example implementations 10 | - [MLXMNIST](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxmnist) -- MNIST implementation for all your digit recognition needs 11 | 12 | # MLX Swift Examples 13 | 14 | Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs. 15 | 16 | - [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on 17 | both iOS and macOS that downloads MNIST training data and trains a 18 | [LeNet](https://en.wikipedia.org/wiki/LeNet). 
19 | 20 | - [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS 21 | and macOS that downloads an LLM and tokenizer from Hugging Face and 22 | generates text from a given prompt. 23 | 24 | - [VLMEval](Applications/VLMEval/README.md): An example that runs on iOS, macOS and visionOS to download a VLM and tokenizer from Hugging Face and 25 | analyzes the given image and describe it in text. 26 | 27 | - [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that 28 | trains a simple linear model. 29 | 30 | - [StableDiffusionExample](Applications/StableDiffusionExample/README.md): An 31 | example that runs on both iOS and macOS that downloads a stable diffusion model 32 | from Hugging Face and and generates an image from a given prompt. 33 | 34 | - [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text 35 | using a variety of LLMs available on the Hugging Face hub. 36 | 37 | - [image-tool](Tools/image-tool/README.md): A command line tool for generating images 38 | using a stable diffusion model from Hugging Face. 39 | 40 | - [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training a 41 | a LeNet on MNIST. 42 | 43 | ## Running 44 | 45 | The application and command line tool examples can be run from Xcode or from 46 | the command line: 47 | 48 | ``` 49 | ./mlx-run llm-tool --prompt "swift programming language" 50 | ``` 51 | 52 | Note: `mlx-run` is a shell script that uses `xcode` command line tools to 53 | locate the built binaries. It is equivalent to running from Xcode itself. 54 | 55 | See also: 56 | 57 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting) 58 | 59 | ## Installation of libraries 60 | 61 | The MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST, MLXEmbedders, and StableDiffusion libraries in the example repo are available 62 | as Swift Packages. 
63 | 64 | 65 | Add the following dependency to your Package.swift 66 | 67 | ```swift 68 | .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main"), 69 | ``` 70 | 71 | Then add one or more libraries to the target as a dependency: 72 | 73 | ```swift 74 | .target( 75 | name: "YourTargetName", 76 | dependencies: [ 77 | .product(name: "MLXLLM", package: "mlx-swift-examples") 78 | ]), 79 | ``` 80 | 81 | Alternatively, add `https://github.com/ml-explore/mlx-swift-examples/` to the `Project Dependencies` and set the `Dependency Rule` to `Branch` and `main` in Xcode. 82 | -------------------------------------------------------------------------------- /Libraries/MLXLLM/Documentation.docc/using-model.md: -------------------------------------------------------------------------------- 1 | # Using a Model 2 | 3 | Using a model is easy: load the weights, tokenize and evaluate. 4 | 5 | ## Loading a Model 6 | 7 | A model is typically loaded by using a `ModelFactory` and a `ModelConfiguration`: 8 | 9 | ```swift 10 | // e.g. LLMModelFactory.shared 11 | let modelFactory: ModelFactory 12 | 13 | // e.g. MLXLLM.ModelRegistry.llama3_8B_4bit 14 | let modelConfiguration: ModelConfiguration 15 | 16 | let container = try await modelFactory.loadContainer(configuration: modelConfiguration) 17 | ``` 18 | 19 | The `container` provides an isolation context (an `actor`) to run inference in the model. 
20 | 21 | Predefined `ModelConfiguration` instances are provided as static variables 22 | on the `ModelRegistry` types or they can be created: 23 | 24 | ```swift 25 | let modelConfiguration = ModelConfiguration(id: "mlx-community/llama3_8B_4bit") 26 | ``` 27 | 28 | The flow inside the `ModelFactory` goes like this: 29 | 30 | ```swift 31 | public class LLMModelFactory: ModelFactory { 32 | 33 | public func _load( 34 | hub: HubApi, configuration: ModelConfiguration, 35 | progressHandler: @Sendable @escaping (Progress) -> Void 36 | ) async throws -> ModelContext { 37 | // download the weight and config using HubApi 38 | // load the base configuration 39 | // using the typeRegistry create a model (random weights) 40 | // load the weights, apply quantization as needed, update the model 41 | // calls model.sanitize() for weight preparation 42 | // load the tokenizer 43 | // (vlm) load the processor configuration, create the processor 44 | } 45 | } 46 | ``` 47 | 48 | Callers with specialized requirements can use these individual components to manually 49 | load models, if needed. 50 | 51 | ## Evaluation Flow 52 | 53 | - Load the Model 54 | - UserInput 55 | - LMInput 56 | - generate() 57 | - NaiveStreamingDetokenizer 58 | - TokenIterator 59 | 60 | ## Evaluating a Model 61 | 62 | Once a model is loaded you can evaluate a prompt or series of 63 | messages. Minimally you need to prepare the user input: 64 | 65 | ```swift 66 | let prompt = "Describe the image in English" 67 | var input = UserInput(prompt: prompt, images: image.map { .url($0) }) 68 | input.processing.resize = .init(width: 256, height: 256) 69 | ``` 70 | 71 | This example shows adding some images and processing instructions -- if 72 | model accepts text only then these parts can be omitted. The inference 73 | calls are the same. 
74 | 75 | Assuming you are using a `ModelContainer` (an actor that holds 76 | a `ModelContext`, which is the bundled set of types that implement a 77 | model), the first step is to convert the `UserInput` into the 78 | `LMInput` (LanguageModel Input): 79 | 80 | ```swift 81 | let generateParameters: GenerateParameters 82 | let input: UserInput 83 | 84 | let result = try await modelContainer.perform { [input] context in 85 | let input = try context.processor.prepare(input: input) 86 | 87 | ``` 88 | 89 | Given that `input` we can call `generate()` to produce a stream 90 | of tokens. In this example we use a `NaiveStreamingDetokenizer` 91 | to assist in converting a stream of tokens into text and print it. 92 | The stream is stopped after we hit a maximum number of tokens: 93 | 94 | ``` 95 | var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer) 96 | 97 | return try MLXLMCommon.generate( 98 | input: input, parameters: generateParameters, context: context 99 | ) { tokens in 100 | 101 | if let last = tokens.last { 102 | detokenizer.append(token: last) 103 | } 104 | 105 | if let new = detokenizer.next() { 106 | print(new, terminator: "") 107 | fflush(stdout) 108 | } 109 | 110 | if tokens.count >= maxTokens { 111 | return .stop 112 | } else { 113 | return .more 114 | } 115 | } 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/PredictionView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // PredictionView.swift 3 | // MNISTTrainer 4 | // 5 | // Created by Rounak Jain on 3/9/24. 6 | // 7 | 8 | import MLX 9 | import MLXMNIST 10 | import MLXNN 11 | import SwiftUI 12 | 13 | struct Canvas: View { 14 | 15 | @Binding var path: Path 16 | @State var lastPoint: CGPoint? 
17 | 18 | var body: some View { 19 | path 20 | .stroke(.white, lineWidth: 10) 21 | .background(.black) 22 | .gesture( 23 | DragGesture(minimumDistance: 0.05) 24 | .onChanged { touch in 25 | add(point: touch.location) 26 | } 27 | .onEnded { touch in 28 | lastPoint = nil 29 | } 30 | ) 31 | } 32 | 33 | func add(point: CGPoint) { 34 | var newPath = path 35 | if let lastPoint { 36 | newPath.move(to: lastPoint) 37 | newPath.addLine(to: point) 38 | } else { 39 | newPath.move(to: point) 40 | } 41 | self.path = newPath 42 | lastPoint = point 43 | } 44 | } 45 | 46 | extension Path { 47 | mutating func center(to newMidPoint: CGPoint) { 48 | let middleX = boundingRect.midX 49 | let middleY = boundingRect.midY 50 | self = offsetBy(dx: newMidPoint.x - middleX, dy: newMidPoint.y - middleY) 51 | } 52 | } 53 | 54 | struct PredictionView: View { 55 | @State var path: Path = Path() 56 | @State var prediction: Int? 57 | let model: LeNetContainer 58 | let canvasSize = 150.0 59 | let mnistImageSize: CGSize = CGSize(width: 28, height: 28) 60 | 61 | var body: some View { 62 | VStack { 63 | if let prediction { 64 | Text("You've drawn a \(prediction)") 65 | } else { 66 | Text("Draw a digit") 67 | } 68 | Canvas(path: $path) 69 | .frame(width: canvasSize, height: canvasSize) 70 | HStack { 71 | Button("Predict") { 72 | path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2)) 73 | predict() 74 | } 75 | Button("Clear") { 76 | path = Path() 77 | prediction = nil 78 | } 79 | } 80 | } 81 | } 82 | 83 | @MainActor 84 | func predict() { 85 | let imageRenderer = ImageRenderer( 86 | content: Canvas(path: $path).frame(width: 150, height: 150)) 87 | 88 | if let image = imageRenderer.cgImage { 89 | Task { 90 | self.prediction = await model.evaluate(image: image) 91 | } 92 | } 93 | } 94 | } 95 | 96 | extension CGImage { 97 | func grayscaleImage(with newSize: CGSize) -> CGImage? 
extension CGImage {

    /// Redraw this image into a fresh 8-bit, single-channel grayscale bitmap.
    ///
    /// - Parameter newSize: target pixel dimensions (e.g. 28x28 for MNIST input).
    /// - Returns: the resized grayscale image, or `nil` if the bitmap context
    ///   could not be created.
    func grayscaleImage(with newSize: CGSize) -> CGImage? {
        let colorSpace = CGColorSpaceCreateDeviceGray()
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)

        guard
            let context = CGContext(
                data: nil,
                width: Int(newSize.width),
                height: Int(newSize.height),
                bitsPerComponent: 8,
                bytesPerRow: Int(newSize.width),
                space: colorSpace,
                bitmapInfo: bitmapInfo.rawValue)
        else {
            return nil
        }
        // Bug fix: the destination rect previously passed `newSize.width` as its
        // height, which distorted the drawing for any non-square target size.
        context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.height))
        return context.makeImage()
    }

    /// Copy the raw backing bytes of this image into a flat `MLXArray`,
    /// one element per byte.
    ///
    /// Returns an empty array when the image has no accessible data provider.
    /// NOTE(review): assumes the image is 8-bit single-channel, as produced by
    /// `grayscaleImage(with:)` — confirm callers never pass multi-channel data,
    /// since every byte (including padding/extra channels) would be copied.
    func pixelData() -> MLXArray {
        guard let data = self.dataProvider?.data else {
            return []
        }
        let bytePtr = CFDataGetBytePtr(data)
        let count = CFDataGetLength(data)
        return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count))
    }
}
26 | 27 | Predefined `ModelConfiguration` instances are provided as static variables 28 | on the `ModelRegistry` types or they can be created: 29 | 30 | ```swift 31 | let modelConfiguration = ModelConfiguration(id: "mlx-community/paligemma-3b-mix-448-8bit") 32 | ``` 33 | 34 | The flow inside the `ModelFactory` goes like this: 35 | 36 | ```swift 37 | public class VLMModelFactory: ModelFactory { 38 | 39 | public func _load( 40 | hub: HubApi, configuration: ModelConfiguration, 41 | progressHandler: @Sendable @escaping (Progress) -> Void 42 | ) async throws -> ModelContext { 43 | // download the weight and config using HubApi 44 | // load the base configuration 45 | // using the typeRegistry create a model (random weights) 46 | // load the weights, apply quantization as needed, update the model 47 | // calls model.sanitize() for weight preparation 48 | // load the tokenizer 49 | // (vlm) load the processor configuration, create the processor 50 | } 51 | } 52 | ``` 53 | 54 | Callers with specialized requirements can use these individual components to manually 55 | load models, if needed. 56 | 57 | ## Evaluation Flow 58 | 59 | - Load the Model 60 | - UserInput 61 | - LMInput 62 | - generate() 63 | - NaiveStreamingDetokenizer 64 | - TokenIterator 65 | 66 | ## Using a Model 67 | 68 | Once a model is loaded you can evaluate a prompt or series of 69 | messages. Minimally you need to prepare the user input: 70 | 71 | ```swift 72 | let prompt = "Describe the image in English" 73 | var input = UserInput(prompt: prompt, images: image.map { .url($0) }) 74 | input.processing.resize = .init(width: 256, height: 256) 75 | ``` 76 | 77 | This example shows adding some images and processing instructions -- if 78 | model accepts text only then these parts can be omitted. The inference 79 | calls are the same. 
80 | 81 | Assuming you are using a `ModelContainer` (an actor that holds 82 | a `ModelContext`, which is the bundled set of types that implement a 83 | model), the first step is to convert the `UserInput` into the 84 | `LMInput` (LanguageModel Input): 85 | 86 | ```swift 87 | let generateParameters: GenerateParameters 88 | let input: UserInput 89 | 90 | let result = try await modelContainer.perform { [input] context in 91 | let input = try context.processor.prepare(input: input) 92 | 93 | ``` 94 | 95 | Given that `input` we can call `generate()` to produce a stream 96 | of tokens. In this example we use a `NaiveStreamingDetokenizer` 97 | to assist in converting a stream of tokens into text and print it. 98 | The stream is stopped after we hit a maximum number of tokens: 99 | 100 | ``` 101 | var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer) 102 | 103 | return try MLXLMCommon.generate( 104 | input: input, parameters: generateParameters, context: context 105 | ) { tokens in 106 | 107 | if let last = tokens.last { 108 | detokenizer.append(token: last) 109 | } 110 | 111 | if let new = detokenizer.next() { 112 | print(new, terminator: "") 113 | fflush(stdout) 114 | } 115 | 116 | if tokens.count >= maxTokens { 117 | return .stop 118 | } else { 119 | return .more 120 | } 121 | } 122 | } 123 | ``` 124 | 125 | -------------------------------------------------------------------------------- /Libraries/Embedders/EmbeddingModel.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | @preconcurrency import Hub 5 | import MLX 6 | import MLXNN 7 | import Tokenizers 8 | 9 | /// Container for models that guarantees single threaded access. 10 | /// 11 | /// Wrap models used by e.g. the UI in a ModelContainer. 
Callers can access 12 | /// the model and/or tokenizer: 13 | /// 14 | /// ```swift 15 | /// let promptTokens = await modelContainer.perform { _, tokenizer in 16 | /// tokenizer.encode(text: prompt) 17 | /// } 18 | /// ``` 19 | /// 20 | /// or: 21 | /// 22 | /// ```swift 23 | /// let result = await modelContainer.perform { model, tokenizer in 24 | /// LLM.generate( 25 | /// promptTokens: promptTokens, parameters: generateParameters, model: model, 26 | /// tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens 27 | /// ) { tokens in 28 | /// ... 29 | /// } 30 | /// } 31 | /// ``` 32 | public actor ModelContainer { 33 | let model: EmbeddingModel 34 | let tokenizer: Tokenizer 35 | let pooler: Pooling 36 | 37 | public init( 38 | model: EmbeddingModel, tokenizer: Tokenizer, pooler: Pooling = Pooling(strategy: .none) 39 | ) { 40 | self.model = model 41 | self.tokenizer = tokenizer 42 | self.pooler = pooler 43 | } 44 | 45 | /// build the model and tokenizer without passing non-sendable data over isolation barriers 46 | public init( 47 | hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration 48 | ) async throws { 49 | self.model = try loadSynchronous(modelDirectory: modelDirectory) 50 | 51 | let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( 52 | configuration: configuration, hub: hub) 53 | self.tokenizer = try PreTrainedTokenizer( 54 | tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) 55 | self.pooler = loadPooling(modelDirectory: modelDirectory) //?? Pooling(strategy: .none) 56 | } 57 | 58 | /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as 59 | /// `MLXArray` is not `Sendable`. 
60 | public func perform(_ action: @Sendable (EmbeddingModel, Tokenizer, Pooling) throws -> R) 61 | rethrows 62 | -> R 63 | { 64 | try action(model, tokenizer, pooler) 65 | } 66 | } 67 | 68 | extension Module { 69 | 70 | /// Compute the number of parameters in a possibly quantized model 71 | public func numParameters() -> Int { 72 | return leafModules().flattenedValues().map { 73 | mod -> Int in 74 | if let qlin = mod as? QuantizedLinear { 75 | return qlin.scales.size * qlin.groupSize 76 | } else if let qemb = mod as? QuantizedEmbedding { 77 | return qemb.scales.size * qemb.groupSize 78 | } else { 79 | return mod.parameters().flattenedValues().reduce( 80 | 0, 81 | { 82 | $0 + $1.size 83 | }) 84 | } 85 | }.reduce(0, +) 86 | } 87 | } 88 | 89 | public struct EmbeddingModelOutput { 90 | let hiddenStates: MLXArray? 91 | let pooledOutput: MLXArray? 92 | } 93 | 94 | public protocol EmbeddingModel: Module { 95 | var vocabularySize: Int { get } 96 | func callAsFunction( 97 | _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?, 98 | attentionMask: MLXArray? 99 | ) -> EmbeddingModelOutput 100 | /// Optionally preprocess the weights and modify / remove values as needed. 101 | func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] 102 | } 103 | 104 | extension EmbeddingModel { 105 | func callAsFunction( 106 | _ inputs: MLXArray, positionIds: MLXArray? = nil, tokenTypeIds: MLXArray? = nil, 107 | attentionMask: MLXArray? = nil 108 | ) -> EmbeddingModelOutput { 109 | return callAsFunction( 110 | inputs, positionIds: positionIds, tokenTypeIds: tokenTypeIds, 111 | attentionMask: attentionMask) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /Libraries/StableDiffusion/Sampler.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import Foundation 4 | import MLX 5 | import MLXRandom 6 | 7 | // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/sampler.py 8 | 9 | /// Interpolate the function defined by `(0 ..< y.count) y)` at positions `xNew`. 10 | func interpolate(y: MLXArray, xNew: MLXArray) -> MLXArray { 11 | let xLow = xNew.asType(.int32) 12 | let xHigh = minimum(xLow + 1, y.count - 1) 13 | 14 | let yLow = y[xLow] 15 | let yHigh = y[xHigh] 16 | let deltaX = xNew - xLow 17 | let yNew = yLow * (1 - deltaX) + deltaX * yHigh 18 | 19 | return yNew 20 | } 21 | 22 | /// A simple Euler integrator that can be used to sample from our diffusion models. 23 | /// 24 | /// The method ``step()`` performs one Euler step from `x_t` to `x_t_prev`. 25 | class SimpleEulerSampler { 26 | 27 | let sigmas: MLXArray 28 | 29 | public init(configuration: DiffusionConfiguration) { 30 | let betas: MLXArray 31 | 32 | // compute the noise schedule 33 | switch configuration.betaSchedule { 34 | case .linear: 35 | betas = MLXArray.linspace( 36 | configuration.betaStart, configuration.betaEnd, count: configuration.trainSteps) 37 | case .scaledLinear: 38 | betas = MLXArray.linspace( 39 | sqrt(configuration.betaStart), sqrt(configuration.betaEnd), 40 | count: configuration.trainSteps 41 | ).square() 42 | } 43 | 44 | let alphas = 1 - betas 45 | let alphasCumprod = cumprod(alphas) 46 | 47 | self.sigmas = concatenated([ 48 | MLXArray.zeros([1]), ((1 - alphasCumprod) / alphasCumprod).sqrt(), 49 | ]) 50 | } 51 | 52 | public var maxTime: Int { 53 | sigmas.count - 1 54 | } 55 | 56 | public func samplePrior(shape: [Int], dType: DType = .float32, key: MLXArray? = nil) -> MLXArray 57 | { 58 | let noise = MLXRandom.normal(shape, key: key) 59 | return (noise * sigmas[-1] * (sigmas[-1].square() + 1).rsqrt()).asType(dType) 60 | } 61 | 62 | public func addNoise(x: MLXArray, t: MLXArray, key: MLXArray? 
= nil) -> MLXArray { 63 | let noise = MLXRandom.normal(x.shape, key: key) 64 | let s = sigmas(t) 65 | return (x + noise * s) * (s.square() + 1).rsqrt() 66 | } 67 | 68 | public func sigmas(_ t: MLXArray) -> MLXArray { 69 | interpolate(y: sigmas, xNew: t) 70 | } 71 | 72 | public func timeSteps(steps: Int, start: Int? = nil, dType: DType = .float32) -> [( 73 | MLXArray, MLXArray 74 | )] { 75 | let start = start ?? (sigmas.count - 1) 76 | precondition(0 < start) 77 | precondition(start <= sigmas.count - 1) 78 | let steps = MLX.linspace(start, 0, count: steps + 1).asType(dType) 79 | 80 | return Array(zip(steps, steps[1...])) 81 | } 82 | 83 | open func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray) -> MLXArray { 84 | let dtype = epsPred.dtype 85 | let sigma = sigmas(t).asType(dtype) 86 | let sigmaPrev = sigmas(tPrev).asType(dtype) 87 | 88 | let dt = sigmaPrev - sigma 89 | var xtPrev = (sigma.square() + 1).sqrt() * xt + epsPred * dt 90 | xtPrev = xtPrev * (sigmaPrev.square() + 1).rsqrt() 91 | 92 | return xtPrev 93 | } 94 | } 95 | 96 | class SimpleEulerAncestralSampler: SimpleEulerSampler { 97 | 98 | open override func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray) 99 | -> MLXArray 100 | { 101 | let dtype = epsPred.dtype 102 | let sigma = sigmas(t).asType(dtype) 103 | let sigmaPrev = sigmas(tPrev).asType(dtype) 104 | 105 | let sigma2 = sigma.square() 106 | let sigmaPrev2 = sigmaPrev.square() 107 | let sigmaUp = (sigmaPrev2 * (sigma2 - sigmaPrev2) / sigma2).sqrt() 108 | let sigmaDown = (sigmaPrev2 - sigmaUp ** 2).sqrt() 109 | 110 | let dt = sigmaDown - sigma 111 | var xtPrev = (sigma2 + 1).sqrt() * xt + epsPred * dt 112 | let noise = MLXRandom.normal(xtPrev.shape).asType(xtPrev.dtype) 113 | xtPrev = xtPrev + noise * sigmaUp 114 | xtPrev = xtPrev * (sigmaPrev2 + 1).rsqrt() 115 | 116 | return xtPrev 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /Data/lora/wikisql.py: 
# Copyright © 2023 Apple Inc.

"""
Code to preprocess the WikiSQL dataset adapted from
https://github.com/salesforce/WikiSQL and
https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
"""


import json
import os


def load():
    """
    Load all three splits of the WikiSQL dataset.

    Returns a generator yielding WikiSQL instances in (train, dev, test) order.
    """
    return (WikiSQL(dn) for dn in ["train", "dev", "test"])


class WikiSQL:
    """One split of WikiSQL, exposed as an indexable sequence of text examples.

    Each item is a string of the form
    "<table description>\nQ: <question>\nA: <sql>".
    """

    def __init__(self, dataset, save_dir="/tmp"):
        valid_sets = ("train", "dev", "test")
        if dataset not in valid_sets:
            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
        data_dir = os.path.join(save_dir, "wikisql")
        self._maybe_download(data_dir)

        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))

    def _maybe_download(self, data_dir):
        # Fetch and unpack the upstream archive only on first use.
        if not os.path.exists(data_dir):
            import io
            import tarfile
            from urllib import request

            url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
            r = request.urlopen(url)
            with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
                # SECURITY NOTE: extractall() trusts member paths in the
                # downloaded archive; on Python 3.12+ consider
                # tf.extractall(data_dir, filter="data") to reject traversal.
                tf.extractall(data_dir)

    def _parse_tables(self, tables):
        # Map table id -> column names, column types and a prompt-ready header.
        self._tables = {}
        with open(tables) as f:
            for line in f:
                table = json.loads(line)
                self._tables[table["id"]] = {
                    "columns": table["header"],
                    "types": table["types"],
                    "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
                }

    def _parse_queries(self, queries):
        # Build one "desc / Q / A" training string per query.
        self._queries = []
        with open(queries) as f:
            for line in f:
                query = json.loads(line)
                table = self._tables[query["table_id"]]
                question = query["question"]
                answer = self.query_to_text(
                    query["sql"], query["table_id"], table["columns"], table["types"]
                )
                self._queries.append(
                    f"{table['desc']}\nQ: {question}\nA: {answer}"
                )

    def query_to_text(self, query, table, columns, types):
        """Render a structured WikiSQL query dict as a SQL SELECT string.

        `query` holds column selector `sel`, aggregation index `agg` and the
        condition triples `conds` (column index, operator index, value);
        text-typed condition values are single-quoted.
        """
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]
        column = columns[query["sel"]]
        aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
        sql = f"SELECT {aggregation}{column} FROM {table}"

        conditions = query["conds"]
        if conditions:
            cs = []
            for i, o, v in conditions:
                column = columns[i]
                op = condition_ops[o]

                if types[i] == "text":
                    value = f"'{v}'"
                else:
                    value = v
                cs.append(f"{column} {op} {value}")

            sql += " WHERE " + " AND ".join(cs)

        return sql

    def __getitem__(self, idx):
        return self._queries[idx]

    def __len__(self):
        return len(self._queries)


if __name__ == "__main__":
    datanames = ["train", "dev", "test"]
    sizes = [56355, 8421, 15878]
    for dataname, size in zip(datanames, sizes):
        # Bug fix: this line was a bare expression, so the size check never
        # actually ran; it must be an assert to verify the split sizes.
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."

    # Write the sets to jsonl
    import json

    train, dev, test = load()
    datasets = [
        (train, "train", 1000),
        (dev, "valid", 100),
        (test, "test", 100),
    ]
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for _, t in zip(range(size), dataset):
                # Strip the leading/trailing markup, since the tokenizer adds
                # them.  NOTE(review): t[3:-4] looks like it removes literal
                # sentinel tokens that were lost in this copy of the comment —
                # confirm against the tokenizer's special tokens.
                json.dump({"text": t[3:-4]}, fid)
                fid.write("\n")
import Foundation
@preconcurrency import Hub
import MLX
import MLXNN
import MLXRandom
import Tokenizers

/// Simple error type carrying a message, thrown by the embedder loading routines.
struct EmbedderError: Error {
    let message: String
}

/// Resolve the local directory that holds the model's `config.json` and
/// `*.safetensors` weight files.
///
/// A hub id is downloaded (or restored from the snapshot cache); a local
/// directory is returned as-is. If the repo is missing on the server or the
/// network is offline, this falls back to the locally cached model directory.
func prepareModelDirectory(
    hub: HubApi, configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void
) async throws -> URL {
    do {
        switch configuration.id {
        case .id(let id):
            // download the model weights
            let repo = Hub.Repo(id: id)
            let modelFiles = ["*.safetensors", "config.json"]
            return try await hub.snapshot(
                from: repo, matching: modelFiles, progressHandler: progressHandler)

        case .directory(let directory):
            return directory
        }
    } catch Hub.HubClientError.authorizationRequired {
        // an authorizationRequired means (typically) that the named repo doesn't exist
        // on the server so retry with local only configuration
        return configuration.modelDirectory(hub: hub)
    } catch {
        let nserror = error as NSError
        if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet {
            // Error Domain=NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline."
            // fall back to the local directory
            return configuration.modelDirectory(hub: hub)
        } else {
            throw error
        }
    }
}

/// Load and return the model and tokenizer
public func load(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> (EmbeddingModel, Tokenizer) {
    let modelDirectory = try await prepareModelDirectory(
        hub: hub, configuration: configuration, progressHandler: progressHandler)
    let model = try loadSynchronous(modelDirectory: modelDirectory)
    let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)

    return (model, tokenizer)
}

/// Instantiate the model described by `config.json` in `modelDirectory` and
/// load its weights from every `*.safetensors` file found there.
///
/// Order matters here: weights are sanitized, then the model is (optionally)
/// quantized, and only then are the parameters applied and evaluated.
func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
    // create the model (no weights loaded)
    let configurationURL = modelDirectory.appending(component: "config.json")
    let baseConfig = try JSONDecoder().decode(
        BaseConfiguration.self, from: Data(contentsOf: configurationURL))

    let model = try baseConfig.modelType.createModel(configuration: configurationURL)

    // load the weights -- merge all safetensors files into one dictionary
    var weights = [String: MLXArray]()
    let enumerator = FileManager.default.enumerator(
        at: modelDirectory, includingPropertiesForKeys: nil)!
    for case let url as URL in enumerator {
        if url.pathExtension == "safetensors" {
            let w = try loadArrays(url: url)
            for (key, value) in w {
                weights[key] = value
            }
        }
    }

    // per-model cleanup
    weights = model.sanitize(weights: weights)

    // quantize if needed -- the predicate only quantizes layers that have a
    // matching ".scales" entry in the converted weight files
    if let quantization = baseConfig.quantization {
        quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
            path, module in
            weights["\(path).scales"] != nil
        }
    }

    // apply the loaded weights -- verify: .all throws if any are missing
    let parameters = ModuleParameters.unflattened(weights)
    try model.update(parameters: parameters, verify: [.all])

    eval(model)

    return model
}

/// Load and return the model and tokenizer wrapped in a ``ModelContainer`` (provides
/// thread safe access).
public func loadModelContainer(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> ModelContainer {
    let modelDirectory = try await prepareModelDirectory(
        hub: hub, configuration: configuration, progressHandler: progressHandler)
    return try await ModelContainer(
        hub: hub, modelDirectory: modelDirectory, configuration: configuration)
}
2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLX 6 | import MLXNN 7 | import MLXOptimizers 8 | import MLXRandom 9 | 10 | #if swift(>=5.10) 11 | extension MLX.DeviceType: @retroactive ExpressibleByArgument { 12 | public init?(argument: String) { 13 | self.init(rawValue: argument) 14 | } 15 | } 16 | #else 17 | extension MLX.DeviceType: ExpressibleByArgument { 18 | public init?(argument: String) { 19 | self.init(rawValue: argument) 20 | } 21 | } 22 | #endif 23 | 24 | @main 25 | struct Train: AsyncParsableCommand { 26 | 27 | @Option var epochs = 20 28 | @Option var batchSize = 8 29 | 30 | @Option var m: Float = 0.25 31 | @Option var b: Float = 7 32 | 33 | @Flag var compile = false 34 | 35 | @Option var device = DeviceType.cpu 36 | 37 | func run() async throws { 38 | Device.setDefault(device: Device(device)) 39 | 40 | // A very simple model that implements the equation 41 | // for a linear function: y = mx + b. This can be trained 42 | // to match data -- in this case an unknown (to the model) 43 | // linear function. 44 | // 45 | // This is a nice example because most people know how 46 | // linear functions work and we can see how the slope 47 | // and intercept converge. 48 | class LinearFunctionModel: Module, UnaryLayer { 49 | let m = MLXRandom.uniform(low: -5.0, high: 5.0) 50 | let b = MLXRandom.uniform(low: -5.0, high: 5.0) 51 | 52 | func callAsFunction(_ x: MLXArray) -> MLXArray { 53 | m * x + b 54 | } 55 | } 56 | 57 | // measure the distance from the prediction (model(x)) and the 58 | // ground truth (y). 
this gives feedback on how close the 59 | // prediction is from matching the truth 60 | func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray { 61 | mseLoss(predictions: model(x), targets: y, reduction: .mean) 62 | } 63 | 64 | let model = LinearFunctionModel() 65 | eval(model.parameters()) 66 | 67 | let lg = valueAndGrad(model: model, loss) 68 | 69 | // the optimizer will use the gradients update the model parameters 70 | let optimizer = SGD(learningRate: 1e-1) 71 | 72 | // the function to train our model against -- it doesn't have 73 | // to be linear, but matching what the model models is easy 74 | // to understand 75 | func f(_ x: MLXArray) -> MLXArray { 76 | // these are the target parameters 77 | let m = self.m 78 | let b = self.b 79 | 80 | // our actual function 81 | return m * x + b 82 | } 83 | 84 | func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { 85 | let (loss, grads) = lg(model, x, y) 86 | optimizer.update(model: model, gradients: grads) 87 | return loss 88 | } 89 | 90 | let resolvedStep = 91 | self.compile 92 | ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step 93 | 94 | for _ in 0 ..< epochs { 95 | // we expect that the parameters will approach the targets 96 | print("target: b = \(b), m = \(m)") 97 | print("parameters: \(model.parameters())") 98 | 99 | // generate random training data along with the ground truth. 100 | // notice that the shape is [B, 1] where B is the batch 101 | // dimension -- this allows us to train on several samples simultaneously 102 | // 103 | // note: a very large batch size will take longer to converge because 104 | // the gradient will be representing too many samples down into 105 | // a single float parameter. 106 | let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1]) 107 | let y = f(x) 108 | eval(x, y) 109 | 110 | // compute the loss and gradients. 
use the optimizer 111 | // to adjust the parameters closer to the target 112 | let loss = resolvedStep(x, y) 113 | 114 | eval(model, optimizer) 115 | 116 | // we should see this converge toward 0 117 | print("loss: \(loss)") 118 | } 119 | 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /Libraries/StableDiffusion/Tokenizer.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | 5 | struct Bigram: Hashable { 6 | let a: String 7 | let b: String 8 | 9 | init(_ s: String) { 10 | let pieces = s.split(separator: " ") 11 | precondition(pieces.count == 2, "BPEPair expected two pieces for '\(s)'") 12 | self.a = String(pieces[0]) 13 | self.b = String(pieces[1]) 14 | } 15 | 16 | init(_ a: String, _ b: String) { 17 | self.a = a 18 | self.b = b 19 | } 20 | 21 | init(_ v: (String, String)) { 22 | self.a = v.0 23 | self.b = v.1 24 | } 25 | } 26 | 27 | /// A CLIP tokenizer. 28 | /// 29 | /// Ported from: 30 | /// 31 | /// - https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py 32 | /// - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py 33 | /// 34 | /// Ideally this would be a tokenizer from `swift-transformers` but this is too special purpose to be representable in 35 | /// what exists there (at time of writing). 
class CLIPTokenizer {

    // splits text into: the two special tokens, common English contractions,
    // runs of letters, single digits, and runs of other non-space symbols
    let pattern =
        #/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/#
    // bigram -> merge priority (its position in the merges list); lower merges first
    let bpeRanks: [Bigram: Int]
    let vocabulary: [String: Int]

    let bos = "<|startoftext|>"
    let eos = "<|endoftext|>"

    let bosToken: Int
    let eosToken: Int

    // memoized results of bpe(text:) -- pre-seeded in init so the special
    // tokens are never split
    var cache = [String: [String]]()

    init(merges: [String], vocabulary: [String: Int]) {
        // each merge line is "a b"; its index in the merges array is its rank
        self.bpeRanks = Dictionary(
            uniqueKeysWithValues:
                merges
                .map { Bigram($0) }
                .enumerated()
                .map { ($0.element, $0.offset) })

        self.vocabulary = vocabulary
        self.cache[bos] = [bos]
        self.cache[eos] = [eos]
        self.bosToken = vocabulary[bos]!
        self.eosToken = vocabulary[eos]!
    }

    /// Apply the byte-pair merges to a single pre-tokenized piece of text and
    /// return its sub-word pieces (keys into ``vocabulary``).
    func bpe(text: String) -> [String] {
        if let result = cache[text] {
            return result
        }

        precondition(!text.isEmpty)

        // start from individual characters
        var unigrams = text.dropLast().map { String($0) } + ["\(text.last!)"]
        var uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })

        // In every iteration try to merge the two most likely bigrams. If none
        // was merged we are done

        while !uniqueBigrams.isEmpty {
            // pick the bigram with the lowest rank (= highest merge
            // priority); bigrams not in the merges file rank Int.max
            let (bigram, _) =
                uniqueBigrams
                .map { ($0, bpeRanks[$0] ?? Int.max) }
                .min { $0.1 < $1.1 }!

            if bpeRanks[bigram] == nil {
                break
            }

            // merge every non-overlapping occurrence of `bigram`, scanning
            // left to right; `skip` marks that the next unigram was consumed
            var newUnigrams = [String]()
            var skip = false

            for (a, b) in zip(unigrams, unigrams.dropFirst()) {
                if skip {
                    skip = false
                    continue
                }

                if Bigram(a, b) == bigram {
                    newUnigrams.append(a + b)
                    skip = true
                } else {
                    newUnigrams.append(a)
                }
            }

            // the zip above stops one element short -- keep the final
            // unigram unless it was just consumed by a merge
            if !skip, let last = unigrams.last {
                newUnigrams.append(last)
            }

            unigrams = newUnigrams
            uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })
        }

        cache[text] = unigrams

        return unigrams
    }

    /// Tokenize `text` into CLIP token ids, wrapped in bos/eos tokens.
    public func tokenize(text: String) -> [Int32] {
        // Lower case cleanup and split according to self.pat. Hugging Face does
        // a much more thorough job here but this should suffice for 95% of
        // cases.

        let clean = text.lowercased().replacing(#/\s+/#, with: " ")
        let tokens = clean.matches(of: pattern).map { $0.description }

        // Split the tokens according to the byte-pair merge file
        let bpeTokens = tokens.flatMap { bpe(text: String($0)) }

        // Map to token ids and return
        // note: pieces missing from the vocabulary are silently dropped by compactMap
        let result = [bosToken] + bpeTokens.compactMap { vocabulary[$0] } + [eosToken]

        return result.map { Int32($0) }
    }
}
private class ModelTypeRegistry: @unchecked Sendable {

    // Note: using NSLock as we have very small (just dictionary get/set)
    // critical sections and expect no contention. this allows the methods
    // to remain synchronous.
    private let lock = NSLock()

    /// Shared creator for all BERT-compatible architectures -- bert, roberta,
    /// xlm-roberta and distilbert all decode the same `BertConfiguration`
    /// and build a `BertModel`. Factored out to remove four identical
    /// copies of this closure from the `creators` table.
    private static let createBert: @Sendable (URL) throws -> EmbeddingModel = { url in
        let configuration = try JSONDecoder().decode(
            BertConfiguration.self, from: Data(contentsOf: url))
        return BertModel(configuration)
    }

    /// model_type (from config.json) -> closure that decodes the
    /// configuration at the given URL and instantiates the unweighted model
    private var creators: [String: @Sendable (URL) throws -> EmbeddingModel] = [
        "bert": ModelTypeRegistry.createBert,
        "roberta": ModelTypeRegistry.createBert,
        "xlm-roberta": ModelTypeRegistry.createBert,
        "distilbert": ModelTypeRegistry.createBert,
        "nomic_bert": { url in
            let configuration = try JSONDecoder().decode(
                NomicBertConfiguration.self, from: Data(contentsOf: url))
            return NomicBertModel(configuration)
        },
    ]

    /// Register (or replace) the creator for a model type. Thread safe.
    public func registerModelType(
        _ type: String, creator: @Sendable @escaping (URL) throws -> EmbeddingModel
    ) {
        lock.withLock {
            creators[type] = creator
        }
    }

    /// Create the model for `rawValue` (the `model_type` from config.json).
    ///
    /// - Throws: `EmbedderError` when no creator is registered for the type.
    public func createModel(configuration: URL, rawValue: String) throws -> EmbeddingModel {
        let creator = lock.withLock {
            creators[rawValue]
        }
        guard let creator else {
            throw EmbedderError(message: "Unsupported model type.")
        }
        return try creator(configuration)
    }

}
2 | 3 | import Foundation 4 | import Hub 5 | import Tokenizers 6 | 7 | public enum ModelFactoryError: LocalizedError { 8 | case unsupportedModelType(String) 9 | case unsupportedProcessorType(String) 10 | 11 | public var errorDescription: String? { 12 | switch self { 13 | case .unsupportedModelType(let type): "Unsupported model type: \(type)" 14 | case .unsupportedProcessorType(let type): "Unsupported processor type: \(type)" 15 | } 16 | } 17 | } 18 | 19 | /// Context of types that work together to provide a ``LanguageModel``. 20 | /// 21 | /// A ``ModelContext`` is created by ``ModelFactory/load(hub:configuration:progressHandler:)``. 22 | /// This contains the following: 23 | /// 24 | /// - ``ModelConfiguration`` -- identifier for the model 25 | /// - ``LanguageModel`` -- the model itself, see ``generate(input:parameters:context:didGenerate:)`` 26 | /// - ``UserInputProcessor`` -- can convert ``UserInput`` into ``LMInput`` 27 | /// - `Tokenizer` -- the tokenizer used by ``UserInputProcessor`` 28 | /// 29 | /// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and 30 | /// ``ModelContainer``. 
public struct ModelContext {
    public var configuration: ModelConfiguration
    public var model: any LanguageModel
    public var processor: any UserInputProcessor
    public var tokenizer: Tokenizer

    public init(
        configuration: ModelConfiguration, model: any LanguageModel,
        processor: any UserInputProcessor, tokenizer: any Tokenizer
    ) {
        self.configuration = configuration
        self.model = model
        self.processor = processor
        self.tokenizer = tokenizer
    }
}

/// Factory that can load a ``ModelContext`` (or thread-safe ``ModelContainer``)
/// from a ``ModelConfiguration``.
public protocol ModelFactory: Sendable {

    var modelRegistry: AbstractModelRegistry { get }

    func _load(
        hub: HubApi, configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> ModelContext

    func _loadContainer(
        hub: HubApi, configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> ModelContainer

}

extension ModelFactory {

    /// Resolve a model identifier, e.g. "mlx-community/Llama-3.2-3B-Instruct-4bit", into
    /// a ``ModelConfiguration``.
    ///
    /// This will either create a new (mostly unconfigured) ``ModelConfiguration`` or
    /// return a registered instance that matches the id.
    ///
    /// - Note: If the id doesn't exist in the registry, this will return a new instance.
    ///   If you want to check whether the configuration is in the model registry, use
    ///   ``contains(id:)``.
    public func configuration(id: String) -> ModelConfiguration {
        modelRegistry.configuration(id: id)
    }

    /// Returns true if ``modelRegistry`` contains a model with the id. Otherwise, false.
    public func contains(id: String) -> Bool {
        modelRegistry.contains(id: id)
    }

}
88 | /// 89 | /// This method returns a ``ModelContext``. See also 90 | /// ``loadContainer(hub:configuration:progressHandler:)`` for a method that 91 | /// returns a ``ModelContainer``. 92 | public func load( 93 | hub: HubApi = HubApi(), configuration: ModelConfiguration, 94 | progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } 95 | ) async throws -> ModelContext { 96 | try await _load(hub: hub, configuration: configuration, progressHandler: progressHandler) 97 | } 98 | 99 | /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContainer``. 100 | public func loadContainer( 101 | hub: HubApi = HubApi(), configuration: ModelConfiguration, 102 | progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } 103 | ) async throws -> ModelContainer { 104 | try await _loadContainer( 105 | hub: hub, configuration: configuration, progressHandler: progressHandler) 106 | } 107 | 108 | public func _loadContainer( 109 | hub: HubApi = HubApi(), configuration: ModelConfiguration, 110 | progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } 111 | ) async throws -> ModelContainer { 112 | let context = try await _load( 113 | hub: hub, configuration: configuration, progressHandler: progressHandler) 114 | return ModelContainer(context: context) 115 | } 116 | 117 | } 118 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/ContentView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import MLX 4 | import MLXMNIST 5 | import MLXNN 6 | import MLXOptimizers 7 | import MLXRandom 8 | import SwiftUI 9 | 10 | struct TrainingView: View { 11 | 12 | @Binding var trainer: ModelState 13 | 14 | var body: some View { 15 | VStack { 16 | Spacer() 17 | 18 | ScrollView(.vertical) { 19 | ForEach(trainer.messages, id: \.self) { 20 | Text($0) 21 | } 22 | } 23 | 24 | HStack { 25 | Spacer() 26 | switch trainer.state { 27 | case .untrained: 28 | Button("Train") { 29 | Task { 30 | try! await trainer.train() 31 | } 32 | } 33 | case .trained(let model), .predict(let model): 34 | Button("Draw a digit") { 35 | trainer.state = .predict(model) 36 | } 37 | } 38 | 39 | Spacer() 40 | } 41 | Spacer() 42 | } 43 | .padding() 44 | } 45 | } 46 | 47 | struct ContentView: View { 48 | // the training loop 49 | @State var trainer = ModelState() 50 | 51 | var body: some View { 52 | switch trainer.state { 53 | case .untrained, .trained: 54 | TrainingView(trainer: $trainer) 55 | case .predict(let model): 56 | PredictionView(model: model) 57 | } 58 | } 59 | } 60 | 61 | @MainActor 62 | @Observable 63 | class ModelState { 64 | 65 | enum State { 66 | case untrained 67 | case trained(LeNetContainer) 68 | case predict(LeNetContainer) 69 | } 70 | 71 | var state: State = .untrained 72 | var messages = [String]() 73 | 74 | func train() async throws { 75 | let model = LeNetContainer() 76 | try await model.train(output: self) 77 | self.state = .trained(model) 78 | } 79 | } 80 | 81 | actor LeNetContainer { 82 | 83 | private let model = LeNet() 84 | 85 | let mnistImageSize: CGSize = CGSize(width: 28, height: 28) 86 | 87 | func train(output: ModelState) async throws { 88 | // Note: this is pretty close to the code in `mnist-tool`, just 89 | // wrapped in an Observable to make it easy to display in SwiftUI 90 | 91 | // download & load the training data 92 | let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) 93 | try await download(into: url) 94 | let data = try load(from: 
url) 95 | 96 | let trainImages = data[.init(.training, .images)]! 97 | let trainLabels = data[.init(.training, .labels)]! 98 | let testImages = data[.init(.test, .images)]! 99 | let testLabels = data[.init(.test, .labels)]! 100 | 101 | eval(model.parameters()) 102 | 103 | // the training loop 104 | let lg = valueAndGrad(model: model, loss) 105 | let optimizer = SGD(learningRate: 0.1) 106 | 107 | // using a consistent random seed so it behaves the same way each time 108 | MLXRandom.seed(0) 109 | var generator: RandomNumberGenerator = SplitMix64(seed: 0) 110 | 111 | for e in 0 ..< 10 { 112 | let start = Date.timeIntervalSinceReferenceDate 113 | 114 | for (x, y) in iterateBatches( 115 | batchSize: 256, x: trainImages, y: trainLabels, using: &generator) 116 | { 117 | // loss and gradients 118 | let (_, grads) = lg(model, x, y) 119 | 120 | // use SGD to update the weights 121 | optimizer.update(model: model, gradients: grads) 122 | 123 | // eval the parameters so the next iteration is independent 124 | eval(model, optimizer) 125 | } 126 | 127 | let accuracy = eval(model: model, x: testImages, y: testLabels) 128 | 129 | let end = Date.timeIntervalSinceReferenceDate 130 | 131 | // add to messages -- triggers display 132 | let accuracyItem = accuracy.item(Float.self) 133 | await MainActor.run { 134 | output.messages.append( 135 | """ 136 | Epoch \(e): test accuracy \(accuracyItem.formatted()) 137 | Time: \((end - start).formatted()) 138 | 139 | """ 140 | ) 141 | } 142 | } 143 | } 144 | 145 | func evaluate(image: CGImage) -> Int? 
{ 146 | let pixelData = image.grayscaleImage(with: mnistImageSize)?.pixelData() 147 | if let pixelData { 148 | let x = pixelData.reshaped([1, 28, 28, 1]).asType(.float32) / 255.0 149 | return argMax(model(x)).item() 150 | } else { 151 | return nil 152 | } 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /Libraries/MLXLMCommon/Tokenizer.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import Hub 5 | import Tokenizers 6 | 7 | struct TokenizerError: Error { 8 | let message: String 9 | } 10 | 11 | public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer 12 | { 13 | let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( 14 | configuration: configuration, hub: hub) 15 | 16 | return try PreTrainedTokenizer( 17 | tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) 18 | } 19 | 20 | public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> ( 21 | Config, Config 22 | ) { 23 | // from AutoTokenizer.from() -- this lets us override parts of the configuration 24 | let config: LanguageModelConfigurationFromHub 25 | 26 | switch configuration.id { 27 | case .id(let id): 28 | do { 29 | // the load can fail (async when we try to use it) 30 | let loaded = LanguageModelConfigurationFromHub( 31 | modelName: configuration.tokenizerId ?? 
id, hubApi: hub) 32 | _ = try await loaded.tokenizerConfig 33 | config = loaded 34 | } catch { 35 | let nserror = error as NSError 36 | if nserror.domain == NSURLErrorDomain 37 | && nserror.code == NSURLErrorNotConnectedToInternet 38 | { 39 | // Internet connection appears to be offline -- fall back to loading from 40 | // the local directory 41 | config = LanguageModelConfigurationFromHub( 42 | modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub) 43 | } else { 44 | throw error 45 | } 46 | } 47 | case .directory(let directory): 48 | config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub) 49 | } 50 | 51 | guard var tokenizerConfig = try await config.tokenizerConfig else { 52 | throw TokenizerError(message: "missing config") 53 | } 54 | let tokenizerData = try await config.tokenizerData 55 | 56 | tokenizerConfig = updateTokenizerConfig(tokenizerConfig) 57 | 58 | return (tokenizerConfig, tokenizerData) 59 | } 60 | 61 | private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config { 62 | // workaround: replacement tokenizers for unhandled values in swift-transform 63 | if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue, 64 | let replacement = replacementTokenizers[tokenizerClass] 65 | { 66 | var dictionary = tokenizerConfig.dictionary 67 | dictionary["tokenizer_class"] = replacement 68 | return Config(dictionary) 69 | } 70 | 71 | return tokenizerConfig 72 | } 73 | 74 | public class TokenizerReplacementRegistry: @unchecked Sendable { 75 | 76 | // Note: using NSLock as we have very small (just dictionary get/set) 77 | // critical sections and expect no contention. this allows the methods 78 | // to remain synchronous. 
79 | private let lock = NSLock() 80 | 81 | /// overrides for TokenizerModel/knownTokenizers 82 | private var replacementTokenizers = [ 83 | "InternLM2Tokenizer": "PreTrainedTokenizer", 84 | "Qwen2Tokenizer": "PreTrainedTokenizer", 85 | "CohereTokenizer": "PreTrainedTokenizer", 86 | ] 87 | 88 | public subscript(key: String) -> String? { 89 | get { 90 | lock.withLock { 91 | replacementTokenizers[key] 92 | } 93 | } 94 | set { 95 | lock.withLock { 96 | replacementTokenizers[key] = newValue 97 | } 98 | } 99 | } 100 | } 101 | 102 | public let replacementTokenizers = TokenizerReplacementRegistry() 103 | 104 | public protocol StreamingDetokenizer: IteratorProtocol { 105 | 106 | mutating func append(token: Int) 107 | 108 | } 109 | 110 | public struct NaiveStreamingDetokenizer: StreamingDetokenizer { 111 | let tokenizer: Tokenizer 112 | 113 | var segmentTokens = [Int]() 114 | var segment = "" 115 | 116 | public init(tokenizer: Tokenizer) { 117 | self.tokenizer = tokenizer 118 | } 119 | 120 | mutating public func append(token: Int) { 121 | segmentTokens.append(token) 122 | } 123 | 124 | mutating func startNewSegment() { 125 | let lastToken = segmentTokens.last 126 | segmentTokens.removeAll() 127 | if let lastToken { 128 | segmentTokens.append(lastToken) 129 | segment = tokenizer.decode(tokens: segmentTokens) 130 | } else { 131 | segment = "" 132 | } 133 | } 134 | 135 | public mutating func next() -> String? { 136 | let newSegment = tokenizer.decode(tokens: segmentTokens) 137 | let new = newSegment.suffix(newSegment.count - segment.count) 138 | 139 | if new.hasSuffix("\n") { 140 | startNewSegment() 141 | } else { 142 | self.segment = newSegment 143 | } 144 | 145 | return String(new) 146 | } 147 | 148 | } 149 | -------------------------------------------------------------------------------- /Libraries/StableDiffusion/Clip.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import Foundation 4 | import MLX 5 | import MLXNN 6 | 7 | // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py 8 | 9 | struct CLIPOutput { 10 | /// The lastHiddenState indexed at the EOS token and possibly projected if 11 | /// the model has a projection layer 12 | public var pooledOutput: MLXArray 13 | 14 | /// The full sequence output of the transformer after the final layernorm 15 | public var lastHiddenState: MLXArray 16 | 17 | /// A list of hidden states corresponding to the outputs of the transformer layers 18 | public var hiddenStates: [MLXArray] 19 | } 20 | 21 | /// The transformer encoder layer from CLIP 22 | class CLIPEncoderLayer: Module { 23 | 24 | @ModuleInfo(key: "layer_norm1") var layerNorm1: LayerNorm 25 | @ModuleInfo(key: "layer_norm2") var layerNorm2: LayerNorm 26 | 27 | let attention: MultiHeadAttention 28 | 29 | @ModuleInfo var linear1: Linear 30 | @ModuleInfo var linear2: Linear 31 | 32 | let activation: (MLXArray) -> MLXArray 33 | 34 | init(modelDimensions: Int, numHeads: Int, activation: @escaping (MLXArray) -> MLXArray) { 35 | self._layerNorm1.wrappedValue = LayerNorm(dimensions: modelDimensions) 36 | self._layerNorm2.wrappedValue = LayerNorm(dimensions: modelDimensions) 37 | 38 | self.attention = MultiHeadAttention( 39 | dimensions: modelDimensions, numHeads: numHeads, bias: true) 40 | 41 | self.linear1 = Linear(modelDimensions, 4 * modelDimensions) 42 | self.linear2 = Linear(4 * modelDimensions, modelDimensions) 43 | 44 | self.activation = activation 45 | } 46 | 47 | func callAsFunction(_ x: MLXArray, attentionMask: MLXArray? 
= nil) -> MLXArray { 48 | var y = layerNorm1(x) 49 | y = attention(y, keys: y, values: y, mask: attentionMask) 50 | var x = y + x 51 | 52 | y = layerNorm2(x) 53 | y = linear1(y) 54 | y = activation(y) 55 | y = linear2(y) 56 | x = y + x 57 | 58 | return x 59 | } 60 | } 61 | 62 | /// Implements the text encoder transformer from CLIP 63 | class CLIPTextModel: Module { 64 | 65 | @ModuleInfo(key: "token_embedding") var tokenEmbedding: Embedding 66 | @ModuleInfo(key: "position_embedding") var positionEmbedding: Embedding 67 | 68 | let layers: [CLIPEncoderLayer] 69 | 70 | @ModuleInfo(key: "final_layer_norm") var finalLayerNorm: LayerNorm 71 | 72 | @ModuleInfo(key: "text_projection") var textProjection: Linear? 73 | 74 | init(configuration: CLIPTextModelConfiguration) { 75 | self._tokenEmbedding.wrappedValue = Embedding( 76 | embeddingCount: configuration.vocabularySize, dimensions: configuration.modelDimensions) 77 | self._positionEmbedding.wrappedValue = Embedding( 78 | embeddingCount: configuration.maxLength, dimensions: configuration.modelDimensions) 79 | 80 | self.layers = (0 ..< configuration.numLayers) 81 | .map { _ in 82 | CLIPEncoderLayer( 83 | modelDimensions: configuration.modelDimensions, 84 | numHeads: configuration.numHeads, 85 | activation: configuration.hiddenActivation.activation) 86 | } 87 | 88 | self._finalLayerNorm.wrappedValue = LayerNorm(dimensions: configuration.modelDimensions) 89 | 90 | if let projectionDimensions = configuration.projectionDimensions { 91 | self._textProjection.wrappedValue = Linear( 92 | configuration.modelDimensions, projectionDimensions, bias: false) 93 | } else { 94 | self._textProjection.wrappedValue = nil 95 | } 96 | } 97 | 98 | func mask(_ N: Int, _ dType: DType) -> MLXArray { 99 | let indices = MLXArray(0 ..< Int32(N)) 100 | var mask = indices[0..., .newAxis] .< indices[.newAxis] 101 | mask = mask.asType(dType) * (dType == .float16 ? 
-6e4 : -1e9) 102 | return mask 103 | } 104 | 105 | func callAsFunction(_ x: MLXArray) -> CLIPOutput { 106 | var x = x 107 | let (_, N) = x.shape2 108 | let eosTokens = x.argMax(axis: -1) 109 | 110 | // compute the embeddings 111 | x = tokenEmbedding(x) 112 | x = x + positionEmbedding.weight[.. LoRALinearLayers { 94 | // TODO: modify as needed 95 | model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } 96 | } 97 | 98 | public init(_ args: YourModelConfiguration) { 99 | self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) 100 | self.model = YourModelInner(args) 101 | } 102 | 103 | public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { 104 | // TODO: modify as needed 105 | let out = model(inputs, cache: cache) 106 | return model.embedTokens.asLinear(out) 107 | } 108 | } 109 | ``` 110 | 111 | ## Register the Model 112 | 113 | In [LLMModelFactory.swift](LLMModelFactory.swift) register the model type itself 114 | (this is independent of the model id): 115 | 116 | ```swift 117 | public class ModelTypeRegistry: @unchecked Sendable { 118 | ... 119 | private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ 120 | "yourModel": create(YourModelConfiguration.self, YourModel.init), 121 | ``` 122 | 123 | Add a constant for the model in the `ModelRegistry` (not strictly required but useful 124 | for callers to refer to it in code): 125 | 126 | ```swift 127 | public class ModelRegistry: @unchecked Sendable { 128 | ... 129 | static public let yourModel_4bit = ModelConfiguration( 130 | id: "mlx-community/YourModel-4bit", 131 | defaultPrompt: "What is the gravity on Mars and the moon?" 132 | ) 133 | ``` 134 | 135 | and finally add it to the all list -- this will let users find the model 136 | configuration by id: 137 | 138 | ```swift 139 | private static func all() -> [ModelConfiguration] { 140 | [ 141 | codeLlama13b4bit, 142 | ... 
import Foundation
import MLX
import MLXFast
import MLXNN
import MLXRandom

// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/switch_layers.py

// MARK: - SwitchGLU

/// Mixture-of-experts gated MLP ("switch" GLU).
///
/// Applies per-expert `gate`/`up`/`down` projections (each a ``SwitchLinear``)
/// and combines them as `down(activation(gate(x)) * up(x))`, where `indices`
/// selects which expert's weights are used for each position.
class SwitchGLU: Module {
    @ModuleInfo(key: "gate_proj") var gateProj: SwitchLinear
    @ModuleInfo(key: "up_proj") var upProj: SwitchLinear
    @ModuleInfo(key: "down_proj") var downProj: SwitchLinear

    // Width of the input/output features of the block.
    let inputDims: Int
    // Width of the intermediate (expanded) representation.
    let hiddenDims: Int
    // Number of experts; each has its own weight slice in the SwitchLinear layers.
    let numExperts: Int
    // Elementwise nonlinearity applied to the gate projection (default: silu).
    let activation: (MLXArray) -> MLXArray

    /// Create the three per-expert projections.
    ///
    /// - Parameters:
    ///   - inputDims: input feature dimension
    ///   - hiddenDims: hidden feature dimension of the gated MLP
    ///   - numExperts: number of experts
    ///   - activation: gate nonlinearity, defaults to `MLXNN.silu`
    ///   - bias: whether the projections carry bias terms
    init(
        inputDims: Int,
        hiddenDims: Int,
        numExperts: Int,
        activation: @escaping (MLXArray) -> MLXArray = MLXNN.silu,
        bias: Bool = false
    ) {
        self.inputDims = inputDims
        self.hiddenDims = hiddenDims
        self.numExperts = numExperts
        self.activation = activation

        self._gateProj.wrappedValue = SwitchLinear(
            inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias)
        self._upProj.wrappedValue = SwitchLinear(
            inputDims: inputDims, outputDims: hiddenDims, numExperts: numExperts, bias: bias)
        self._downProj.wrappedValue = SwitchLinear(
            inputDims: hiddenDims, outputDims: inputDims, numExperts: numExperts, bias: bias)

        super.init()
    }

    /// Apply the gated MLP using the experts selected by `indices`.
    ///
    /// - Parameters:
    ///   - x: input activations; assumed (..., inputDims) — TODO confirm against callers
    ///   - indices: expert indices routed to each position
    func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
        // Insert two singleton axes so the input lines up with the per-expert
        // weights inside gatherMatmul; the matching axis is squeezed back out below.
        let x = MLX.expandedDimensions(x, axes: [-2, -3])

        let xUp = upProj(x, indices)
        let xGate = gateProj(x, indices)
        // SwiGLU-style combine: activation(gate) * up, then project back down.
        let xDown = downProj(activation(xGate) * xUp, indices)

        return MLX.squeezed(xDown, axis: -2)
    }
}

/// A linear layer with one weight matrix per expert.
///
/// `weight` has shape [numExperts, outputDims, inputDims]; the expert used for
/// each input position is chosen by the `indices` argument of
/// ``callAsFunction(_:_:)`` via `MLX.gatherMatmul`.
class SwitchLinear: Module, Quantizable {
    @ModuleInfo(key: "weight") var weight: MLXArray
    @ModuleInfo(key: "bias") var bias: MLXArray?

    // Input feature dimension (last axis of weight).
    let inputDims: Int
    // Output feature dimension (middle axis of weight).
    let outputDims: Int
    // Number of experts (leading axis of weight and bias).
    let numExperts: Int

    /// Create with uniformly initialized weights (scale 1/sqrt(inputDims)).
    ///
    /// - Parameters:
    ///   - inputDims: input feature dimension
    ///   - outputDims: output feature dimension
    ///   - numExperts: number of experts
    ///   - bias: whether to allocate a [numExperts, outputDims] bias
    init(inputDims: Int, outputDims: Int, numExperts: Int, bias: Bool = true) {
        self.inputDims = inputDims
        self.outputDims = outputDims
        self.numExperts = numExperts

        let scale = sqrt(1.0 / Float(inputDims))
        self._weight.wrappedValue = MLXRandom.uniform(
            low: -scale,
            high: scale,
            [numExperts, outputDims, inputDims]
        )

        if bias {
            self._bias.wrappedValue = MLXArray.zeros([numExperts, outputDims])
        }

        super.init()
    }

    /// Initializer meant for subclasses to provide weight and bias arrays directly.
    ///
    /// This is used e.g. by ``QuantizedSwitchLinear`` to provide quantized weights and biases
    /// rather than have ``SwitchLinear`` compute them.
    init(
        inputDims: Int, outputDims: Int, numExperts: Int,
        weight: MLXArray, bias: MLXArray? = nil
    ) {
        self.inputDims = inputDims
        self.outputDims = outputDims
        self.numExperts = numExperts

        self._weight.wrappedValue = weight
        self._bias.wrappedValue = bias
    }

    /// Matrix-multiply `x` by the expert weights selected by `indices`.
    func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
        // Transpose the last two axes so each expert slice is [inputDims, outputDims].
        let weightT = self.weight.swappedAxes(-1, -2)
        var result = MLX.gatherMatmul(x, weightT, rhsIndices: indices)

        if let bias = self.bias {
            // Gather the per-expert bias rows and broadcast them over the
            // singleton axis inserted by the caller (see SwitchGLU).
            result = result + MLX.expandedDimensions(bias[indices], axis: -2)
        }

        return result
    }

    /// Quantizable conformance: produce a quantized drop-in replacement.
    func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module {
        QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits)
    }
}

/// Quantized variant of ``SwitchLinear``.
///
/// Holds the quantized weight plus the `scales`/`biases` produced by
/// `MLX.quantized` and evaluates with `MLX.gatherQuantizedMatmul`.
class QuantizedSwitchLinear: SwitchLinear {
    @ModuleInfo(key: "scales") var scales: MLXArray
    @ModuleInfo(key: "biases") var biases: MLXArray

    // Quantization group size used when the weights were quantized.
    let groupSize: Int
    // Bit width of the quantized weights.
    let bits: Int

    /// Quantize the weights of an existing ``SwitchLinear``.
    ///
    /// - Parameters:
    ///   - other: the full-precision layer to quantize
    ///   - groupSize: quantization group size
    ///   - bits: quantization bit width
    init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) {
        self.groupSize = groupSize
        self.bits = bits

        let (quantizedWeight, scales, biases) = MLX.quantized(
            other.weight, groupSize: groupSize, bits: bits)

        self._scales.wrappedValue = scales
        self._biases.wrappedValue = biases

        super.init(
            inputDims: other.inputDims, outputDims: other.outputDims, numExperts: other.numExperts,
            weight: quantizedWeight, bias: other.bias)

        // Freeze so the quantized parameters are excluded from training updates.
        self.freeze()
    }

    /// Quantized expert-selected matmul; mirrors the base class semantics.
    override func callAsFunction(_ x: MLXArray, _ indices: MLXArray) -> MLXArray {
        // transpose: true matches the swappedAxes(-1, -2) in the base class.
        var result = MLX.gatherQuantizedMatmul(
            x,
            self.weight,
            scales: self.scales,
            biases: self.biases,
            rhsIndices: indices,
            transpose: true,
            groupSize: self.groupSize,
            bits: self.bits
        )

        if let bias = self.bias {
            result = result + MLX.expandedDimensions(bias[indices], axis: -2)
        }

        return result
    }
}
extension ModelConfiguration {
    // MARK: - Built-in embedding model configurations
    //
    // These are registered into `ModelConfiguration.registry` by `bootstrap()`
    // so they can be looked up by id via `configuration(id:)`.
    public static let bge_micro = ModelConfiguration(id: "TaylorAI/bge-micro-v2")
    public static let gte_tiny = ModelConfiguration(id: "TaylorAI/gte-tiny")
    public static let minilm_l6 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L6-v2")
    public static let snowflake_xs = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-xs")
    public static let minilm_l12 = ModelConfiguration(id: "sentence-transformers/all-MiniLM-L12-v2")
    public static let bge_small = ModelConfiguration(id: "BAAI/bge-small-en-v1.5")
    public static let multilingual_e5_small = ModelConfiguration(
        id: "intfloat/multilingual-e5-small")
    public static let bge_base = ModelConfiguration(id: "BAAI/bge-base-en-v1.5")
    public static let nomic_text_v1 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1")
    public static let nomic_text_v1_5 = ModelConfiguration(id: "nomic-ai/nomic-embed-text-v1.5")
    public static let bge_large = ModelConfiguration(id: "BAAI/bge-large-en-v1.5")
    public static let snowflake_lg = ModelConfiguration(id: "Snowflake/snowflake-arctic-embed-l")
    public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")
    public static let mixedbread_large = ModelConfiguration(
        id: "mixedbread-ai/mxbai-embed-large-v1")

    // Tracks the one-time registration of the built-in configurations above.
    private enum BootstrapState: Sendable {
        case idle
        case bootstrapping
        case bootstrapped
    }

    // Guarded by @MainActor, as are all registry accessors in this type.
    @MainActor
    static private var bootstrapState = BootstrapState.idle

    /// Idempotently register the built-in configurations.
    ///
    /// Called by the registry accessors before any lookup. The intermediate
    /// `.bootstrapping` state is important: `register(configurations:)` itself
    /// calls `bootstrap()`, so marking the state before calling it prevents
    /// infinite recursion during this initial registration.
    @MainActor
    static func bootstrap() {
        switch bootstrapState {
        case .idle:
            bootstrapState = .bootstrapping
            register(configurations: [
                bge_micro,
                gte_tiny,
                minilm_l6,
                snowflake_xs,
                minilm_l12,
                bge_small,
                multilingual_e5_small,
                bge_base,
                nomic_text_v1,
                nomic_text_v1_5,
                bge_large,
                snowflake_lg,
                bge_m3,
                mixedbread_large,
            ])
            bootstrapState = .bootstrapped

        case .bootstrapping:
            // Re-entered from register(configurations:) -- nothing to do.
            break

        case .bootstrapped:
            // Already registered.
            break
        }
    }
}
/// Conversion utilities for moving between `MLXArray`, `CGImage` and files.
///
/// `data` is stored as an (H, W, 3) UInt8 array (see `init(image:maximumEdge:)`,
/// which slices the alpha channel off after drawing).
public struct Image {

    public let data: MLXArray

    /// Create an Image from a MLXArray with ndim == 3
    public init(_ data: MLXArray) {
        precondition(data.ndim == 3)
        self.data = data
    }

    /// Create an Image by loading from a file.
    ///
    /// - Parameters:
    ///   - url: location of the image file
    ///   - maximumEdge: optional maximum length for the longest image edge
    /// - Throws: `ImageError.unableToOpen` if the file cannot be read as an image
    public init(url: URL, maximumEdge: Int? = nil) throws {
        guard let source = CGImageSourceCreateWithURL(url as CFURL, nil),
            let image = CGImageSourceCreateImageAtIndex(source, 0, nil)
        else {
            throw ImageError.unableToOpen
        }

        // Bug fix: `maximumEdge` was previously dropped here, so images loaded
        // from a URL were never scaled down.  Forward it to the designated path.
        self.init(image: image, maximumEdge: maximumEdge)
    }

    /// Create an image from a CGImage, optionally aspect-fitting it inside
    /// `maximumEdge` and then coercing both edges to multiples of 64.
    public init(image: CGImage, maximumEdge: Int? = nil) {
        // ensure the sizes are multiples of 64 -- this doesn't worry about
        // the aspect ratio

        var width = image.width
        var height = image.height

        if let maximumEdge {
            // Scale `edge` by the ratio maximumEdge / maxEdge.
            func scale(_ edge: Int, _ maxEdge: Int) -> Int {
                Int(round(Float(maximumEdge) / Float(maxEdge) * Float(edge)))
            }

            // aspect fit inside the given maximum
            if width >= height {
                width = scale(width, image.width)
                height = scale(height, image.width)
            } else {
                width = scale(width, image.height)
                height = scale(height, image.height)
            }
        }

        // size must be multiples of 64 -- coerce without regard to aspect ratio
        width = width - width % 64
        height = height - height % 64

        // Draw into an RGBA8 raster (alpha ignored via noneSkipLast).
        var raster = Data(count: width * 4 * height)
        raster.withUnsafeMutableBytes { ptr in
            let cs = CGColorSpace(name: CGColorSpace.sRGB)!
            let context = CGContext(
                data: ptr.baseAddress, width: width, height: height, bitsPerComponent: 8,
                bytesPerRow: width * 4, space: cs,
                bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
                    | CGBitmapInfo.byteOrder32Big.rawValue)!

            context.draw(
                image, in: CGRect(origin: .zero, size: .init(width: width, height: height)))
        }

        // Drop the (skipped) alpha channel: keep (H, W, 3).
        self.data = MLXArray(raster, [height, width, 4], type: UInt8.self)[0..., 0..., ..<3]
    }

    /// Convert the image data to a CGImage.
    public func asCGImage() -> CGImage {
        var raster = data

        // we need 4 bytes per pixel
        if data.dim(-1) == 3 {
            raster = padded(raster, widths: [0, 0, [0, 1]])
        }

        // Keeps the backing Data alive for the lifetime of the CGImage; released
        // via the context's releaseCallback below.
        class DataHolder {
            var data: Data
            init(_ data: Data) {
                self.data = data
            }
        }

        let holder = DataHolder(raster.asData(access: .copy).data)

        let payload = Unmanaged.passRetained(holder).toOpaque()
        func release(payload: UnsafeMutableRawPointer?, data: UnsafeMutableRawPointer?) {
            // Balances the passRetained above.
            Unmanaged<DataHolder>.fromOpaque(payload!).release()
        }

        return holder.data.withUnsafeMutableBytes { ptr in
            let (H, W, C) = raster.shape3
            let cs = CGColorSpace(name: CGColorSpace.sRGB)!

            let context = CGContext(
                data: ptr.baseAddress, width: W, height: H, bitsPerComponent: 8, bytesPerRow: W * C,
                space: cs,
                bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
                    | CGBitmapInfo.byteOrder32Big.rawValue, releaseCallback: release,
                releaseInfo: payload)!
            return context.makeImage()!
        }
    }

    /// Convert the image data to a CIImage.
    public func asCIImage() -> CIImage {
        // we need 4 bytes per pixel; pad alpha with 255 (opaque)
        var raster = data
        if data.dim(-1) == 3 {
            raster = padded(raster, widths: [0, 0, [0, 1]], value: MLXArray(255))
        }

        let arrayData = raster.asData()
        let (H, W, _) = raster.shape3
        let cs = CGColorSpace(name: CGColorSpace.sRGB)!

        return CIImage(
            bitmapData: arrayData.data, bytesPerRow: W * 4, size: .init(width: W, height: H),
            format: .RGBA8, colorSpace: cs)
    }

    /// Save the image.
    ///
    /// The container format is inferred from the url's path extension,
    /// falling back to PNG.
    /// - Throws: `ImageError.failedToSave` if the destination cannot be finalized
    public func save(url: URL) throws {
        let uti = UTType(filenameExtension: url.pathExtension) ?? UTType.png

        let destination = CGImageDestinationCreateWithURL(
            url as CFURL, uti.identifier as CFString, 1, nil)!
        CGImageDestinationAddImage(destination, asCGImage(), nil)
        if !CGImageDestinationFinalize(destination) {
            throw ImageError.failedToSave
        }
    }
}
invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 
// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

// Package manifest for the mlx-swift-examples libraries.  Each library below
// has a matching target under Libraries/ and the example apps/tools in this
// repository depend on them.
let package = Package(
    name: "mlx-libraries",
    platforms: [.macOS(.v14), .iOS(.v16)],
    products: [
        .library(
            name: "MLXLLM",
            targets: ["MLXLLM"]),
        .library(
            name: "MLXVLM",
            targets: ["MLXVLM"]),
        .library(
            name: "MLXLMCommon",
            targets: ["MLXLMCommon"]),
        .library(
            name: "MLXMNIST",
            targets: ["MLXMNIST"]),
        .library(
            name: "MLXEmbedders",
            targets: ["MLXEmbedders"]),
        .library(
            name: "StableDiffusion",
            targets: ["StableDiffusion"]),
    ],
    // External dependencies are pinned with .upToNextMinor to avoid picking up
    // breaking minor releases of fast-moving packages.
    dependencies: [
        .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.21.2")),
        .package(
            url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.17")
        ),
        .package(
            url: "https://github.com/apple/swift-async-algorithms", .upToNextMinor(from: "1.0.0")),
        .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"),  // Only needed by MLXMNIST
    ],
    targets: [
        // LLM support; layers shared with MLXVLM live in MLXLMCommon.
        .target(
            name: "MLXLLM",
            dependencies: [
                "MLXLMCommon",
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/MLXLLM",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // Vision-language model support.
        .target(
            name: "MLXVLM",
            dependencies: [
                "MLXLMCommon",
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/MLXVLM",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // Code shared between MLXLLM and MLXVLM.
        .target(
            name: "MLXLMCommon",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/MLXLMCommon",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // Text embedding models.
        // NOTE(review): this is the only target without StrictConcurrency
        // enabled -- presumably not yet audited for it; confirm before adding.
        .target(
            name: "MLXEmbedders",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "MLXLinalg", package: "mlx-swift"),
            ],
            path: "Libraries/Embedders",
            exclude: [
                "README.md"
            ]
        ),
        // MNIST training example library (Gzip is used to decompress the dataset).
        .target(
            name: "MLXMNIST",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "Gzip", package: "GzipSwift"),
            ],
            path: "Libraries/MLXMNIST",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        // Stable Diffusion image generation.
        .target(
            name: "StableDiffusion",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/StableDiffusion",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
    ]
)

// The docc plugin is only added when building documentation (locally via
// MLX_SWIFT_BUILD_DOC or on Swift Package Index via SPI_GENERATE_DOCS) so
// normal builds do not pay for the extra dependency.
if Context.environment["MLX_SWIFT_BUILD_DOC"] == "1"
    || Context.environment["SPI_GENERATE_DOCS"] == "1"
{
    // docc builder
    package.dependencies.append(
        .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.3.0")
    )
}