├── Tests
├── MLXLMTests
│ ├── README.md
│ ├── StreamlinedTests.swift
│ ├── BaseConfigurationTests.swift
│ ├── ToolTests.swift
│ └── UserInputTests.swift
└── mlx-libraries-Package.xctestplan
├── support
├── test.jpg
└── generate-run-all-llms.sh
├── Applications
├── VLMEval
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── VLMEvalApp.swift
│ ├── VLMEval.entitlements
│ └── README.md
├── LLMEval
│ ├── AssetCatalog.xcassets
│ │ ├── Contents.json
│ │ └── AccentColor.colorset
│ │ │ └── Contents.json
│ ├── AppIcon.icon
│ │ ├── Assets
│ │ │ └── bubbles3.png
│ │ └── icon.json
│ ├── LLMEvalApp.swift
│ ├── LLMEval.entitlements
│ ├── Models
│ │ ├── ToolDefinitions.swift
│ │ └── PresetPrompts.swift
│ ├── ViewModels
│ │ └── DeviceStat.swift
│ ├── Services
│ │ ├── FormatUtilities.swift
│ │ └── ToolExecutor.swift
│ ├── Views
│ │ ├── MetricCard.swift
│ │ ├── OutputView.swift
│ │ ├── LoadingOverlayView.swift
│ │ ├── PromptInputView.swift
│ │ ├── HeaderView.swift
│ │ ├── MetricsView.swift
│ │ ├── ContentView.swift
│ │ └── PresetPromptsSheet.swift
│ └── README.md
├── MNISTTrainer
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── MNISTTrainer-Info.plist
│ ├── MNISTTrainerApp.swift
│ ├── MNISTTrainer.entitlements
│ ├── README.md
│ └── PredictionView.swift
├── LoRATrainingExample
│ ├── Assets.xcassets
│ │ ├── Contents.json
│ │ ├── AccentColor.colorset
│ │ │ └── Contents.json
│ │ └── AppIcon.appiconset
│ │ │ └── Contents.json
│ ├── Preview Content
│ │ └── Preview Assets.xcassets
│ │ │ └── Contents.json
│ ├── LoRATrainingExampleApp.swift
│ ├── LoRATrainingExample.entitlements
│ └── README.md
├── MLXChatExample
│ ├── Support
│ │ ├── Assets.xcassets
│ │ │ ├── Contents.json
│ │ │ ├── AccentColor.colorset
│ │ │ │ └── Contents.json
│ │ │ └── AppIcon.appiconset
│ │ │ │ └── Contents.json
│ │ ├── Preview Content
│ │ │ └── Preview Assets.xcassets
│ │ │ │ └── Contents.json
│ │ ├── HubApi+default.swift
│ │ └── SampleData.swift
│ ├── MLXChatExampleApp.swift
│ ├── Views
│ │ ├── Toolbar
│ │ │ ├── GenerationInfoView.swift
│ │ │ ├── ErrorView.swift
│ │ │ ├── DownloadProgressView.swift
│ │ │ └── ChatToolbarView.swift
│ │ ├── ConversationView.swift
│ │ ├── PromptField.swift
│ │ ├── MessageView.swift
│ │ └── MediaPreviewView.swift
│ ├── MLXChatExample.entitlements
│ ├── Models
│ │ ├── LMModel.swift
│ │ └── Message.swift
│ ├── ChatView.swift
│ └── README.md
└── StableDiffusionExample
│ ├── Assets.xcassets
│ ├── Contents.json
│ ├── AccentColor.colorset
│ │ └── Contents.json
│ └── AppIcon.appiconset
│ │ └── Contents.json
│ ├── Preview Content
│ └── Preview Assets.xcassets
│ │ └── Contents.json
│ ├── StableDiffusionExampleApp.swift
│ ├── StableDiffusionExample.entitlements
│ └── README.md
├── .spi.yml
├── .swift-format
├── mlx-swift-examples.xcodeproj
├── project.xcworkspace
│ ├── contents.xcworkspacedata
│ └── xcshareddata
│ │ ├── IDEWorkspaceChecks.plist
│ │ └── swiftpm
│ │ └── Package.resolved
└── xcshareddata
│ └── xcschemes
│ ├── LLMEval.xcscheme
│ ├── VLMEval.xcscheme
│ ├── ExampleLLM.xcscheme
│ ├── StableDiffusionExample.xcscheme
│ └── embedder-tool.xcscheme
├── Tools
├── embedder-tool
│ ├── CommandError.swift
│ ├── ArgumentSupport.swift
│ ├── Diagnostics.swift
│ ├── ListCommand.swift
│ ├── CorpusArguments.swift
│ ├── PoolingArguments.swift
│ ├── EmbedderCommand.swift
│ ├── MemoryArguments.swift
│ ├── DemoCommand.swift
│ ├── ModelArguments.swift
│ ├── EmbedderRuntime+Embedding.swift
│ ├── PoolingSupport.swift
│ └── VectorOperations.swift
├── LinearModelTraining
│ ├── README.md
│ └── LinearModelTraining.swift
├── llm-tool
│ ├── Arguments.swift
│ └── ListCommands.swift
├── mnist-tool
│ ├── README.md
│ └── MNISTTool.swift
├── ExampleLLM
│ ├── main.swift
│ └── README.md
├── image-tool
│ └── Arguments.swift
└── Tutorial
│ └── Tutorial.swift
├── .pre-commit-config.yaml
├── Configuration
└── Build.xcconfig
├── Libraries
├── MLXMNIST
│ ├── README.md
│ ├── Random.swift
│ ├── MNIST.swift
│ └── Files.swift
└── StableDiffusion
│ ├── README.md
│ ├── Sampler.swift
│ └── Tokenizer.swift
├── .github
├── pull_request_template.md
├── ISSUE_TEMPLATE
│ └── bug_report.md
└── workflows
│ └── pull_request.yml
├── ACKNOWLEDGMENTS.md
├── LICENSE
├── CONTRIBUTING.md
├── mlx-run
├── Package.resolved
├── .gitignore
├── Package.swift
├── Data
└── lora
│ └── wikisql.py
└── README.md
/Tests/MLXLMTests/README.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/support/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-explore/mlx-swift-examples/HEAD/support/test.jpg
--------------------------------------------------------------------------------
/Applications/VLMEval/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/LLMEval/AssetCatalog.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Support/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/.spi.yml:
--------------------------------------------------------------------------------
1 | version: 1
2 | builder:
3 | configs:
4 | - documentation_targets: [MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST, MLXEmbedders, StableDiffusion]
5 |
--------------------------------------------------------------------------------
/.swift-format:
--------------------------------------------------------------------------------
1 | {
2 | "version": 1,
3 | "indentation": {
4 | "spaces": 4
5 | },
6 | "spacesAroundRangeFormationOperators": true,
7 | }
8 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/LLMEval/AppIcon.icon/Assets/bubbles3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-explore/mlx-swift-examples/HEAD/Applications/LLMEval/AppIcon.icon/Assets/bubbles3.png
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Support/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 |
2 |
4 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/MNISTTrainer-Info.plist:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Support/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/MNISTTrainerApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct MNISTTrainerApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/CommandError.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | struct CommandError: LocalizedError {
4 | let message: String
5 |
6 | init(_ message: String) {
7 | self.message = message
8 | }
9 |
10 | var errorDescription: String? { message }
11 | }
12 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/LoRATrainingExampleApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct LoRATrainingExampleApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |
3 | - repo: local
4 | hooks:
5 | - id: swift-format
6 | name: swift-format
7 | language: system
8 | entry: swift-format format --in-place --configuration .swift-format --recursive .
9 | require_serial: true
10 | types: [swift]
11 |
--------------------------------------------------------------------------------
/Applications/LLMEval/LLMEvalApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct LLMEvalApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | .environment(DeviceStat())
11 | }
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/StableDiffusionExampleApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct StableDiffusionExampleApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/Applications/VLMEval/VLMEvalApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | @main
6 | struct VLMEvalApp: App {
7 | var body: some Scene {
8 | WindowGroup {
9 | ContentView()
10 | .environment(DeviceStat())
11 | }
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | IDEDidComputeMac32BitWarning
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/Tools/LinearModelTraining/README.md:
--------------------------------------------------------------------------------
1 | # LinearModelTraining
2 |
3 | A command line tool that creates a Model that represents:
4 |
5 | f(x) = mx + b
6 |
7 | and trains it against an unknown linear function. Very
8 | simple but illustrates:
9 |
10 | - a very simple model with parameters
11 | - a loss function
12 | - the gradient
13 | - use of an optimizer
14 | - the training loop
15 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/MLXChatExampleApp.swift:
--------------------------------------------------------------------------------
1 | //
2 | // MLXChatExampleApp.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
10 | @main
11 | struct MLXChatExampleApp: App {
12 | var body: some Scene {
13 | WindowGroup {
14 | ChatView(viewModel: ChatViewModel(mlxService: MLXService()))
15 | }
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/MNISTTrainer.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.security.app-sandbox
6 |
7 | com.apple.security.files.user-selected.read-only
8 |
9 | com.apple.security.network.client
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/Applications/LLMEval/AssetCatalog.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "color" : {
5 | "color-space" : "srgb",
6 | "components" : {
7 | "alpha" : "1.000",
8 | "blue" : "1.000",
9 | "green" : "0.689",
10 | "red" : "0.307"
11 | }
12 | },
13 | "idiom" : "universal"
14 | }
15 | ],
16 | "info" : {
17 | "author" : "xcode",
18 | "version" : 1
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/ArgumentSupport.swift:
--------------------------------------------------------------------------------
1 | import ArgumentParser
2 | import Foundation
3 |
4 | extension URL: @retroactive ExpressibleByArgument {
5 | public init?(argument: String) {
6 | let expanded = NSString(string: argument).expandingTildeInPath
7 |
8 | if argument.contains("://"), let remote = URL(string: argument), remote.scheme != nil {
9 | self = remote
10 | return
11 | }
12 |
13 | self.init(fileURLWithPath: expanded)
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/Configuration/Build.xcconfig:
--------------------------------------------------------------------------------
1 | // The `DISAMBIGUATOR` configuration is to make it easier to build
2 | // and run a sample code project. Once you set your project's development team,
3 | // you'll have a unique bundle identifier. This is because the bundle identifier
4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this
5 | // approach in your own projects—it's only useful for example projects because
6 | // they are frequently downloaded and don't have a development team set.
7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM}
8 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // GenerationInfoView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
10 | struct GenerationInfoView: View {
11 | let tokensPerSecond: Double
12 |
13 | var body: some View {
14 | Text("\(tokensPerSecond, format: .number.precision(.fractionLength(2))) tokens/s")
15 | }
16 | }
17 |
18 | #Preview {
19 | GenerationInfoView(tokensPerSecond: 58.5834)
20 | }
21 |
--------------------------------------------------------------------------------
/Applications/LLMEval/LLMEval.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.files.user-selected.read-only
10 |
11 | com.apple.security.network.client
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/README.md:
--------------------------------------------------------------------------------
1 | # MNIST
2 |
3 | This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.
4 |
5 | It provides code to:
6 |
7 | - Download the MNIST test/train data
8 | - Build the LeNet
9 | - Some functions to shuffle and batch the data
10 |
11 | See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
12 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/LoRATrainingExample.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.files.user-selected.read-only
10 |
11 | com.apple.security.network.client
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/StableDiffusionExample.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.files.user-selected.read-only
10 |
11 | com.apple.security.network.client
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Tests/mlx-libraries-Package.xctestplan:
--------------------------------------------------------------------------------
1 | {
2 | "configurations" : [
3 | {
4 | "id" : "CCE21325-5E68-419C-8EB1-B868EF7688F7",
5 | "name" : "Test Scheme Action",
6 | "options" : {
7 |
8 | }
9 | }
10 | ],
11 | "defaultOptions" : {
12 |
13 | },
14 | "testTargets" : [
15 | {
16 | "target" : {
17 | "containerPath" : "container:mlx-swift-examples.xcodeproj",
18 | "identifier" : "C3208E6D2DB19451006AE6CA",
19 | "name" : "MLXLMTests"
20 | }
21 | }
22 | ],
23 | "version" : 1
24 | }
25 |
--------------------------------------------------------------------------------
/Applications/VLMEval/VLMEval.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.device.usb
10 |
11 | com.apple.security.files.user-selected.read-only
12 |
13 | com.apple.security.network.client
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/support/generate-run-all-llms.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | echo "#!/bin/sh"
4 | echo "# NOTE: GENERATED BY generate-run-all-llms.sh -- DO NOT MODIFY BY HAND"
5 |
6 | ./mlx-run llm-tool list llms | \
7 | awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s\n", $0}' | \
8 | awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}'
9 |
10 | ./mlx-run llm-tool list vlms | \
11 | awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s --resize 512 --image support/test.jpg\n", $0}' | \
12 | awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}'
13 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/MLXChatExample.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.files.downloads.read-write
10 |
11 | com.apple.security.files.user-selected.read-only
12 |
13 | com.apple.security.network.client
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Proposed changes
2 |
3 | Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
4 |
5 | ## Checklist
6 |
7 | Put an `x` in the boxes that apply.
8 |
9 | - [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
10 | - [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
11 | - [ ] I have added tests that prove my fix is effective or that my feature works
12 | - [ ] I have updated the necessary documentation (if needed)
13 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/Diagnostics.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | enum DiagnosticKind {
4 | case info
5 | case warning
6 | case error
7 |
8 | fileprivate var prefix: String {
9 | switch self {
10 | case .info:
11 | return ""
12 | case .warning:
13 | return "warning: "
14 | case .error:
15 | return "error: "
16 | }
17 | }
18 | }
19 |
20 | func writeDiagnostic(_ message: String, kind: DiagnosticKind = .info) {
21 | guard let data = (kind.prefix + message + "\n").data(using: .utf8) else {
22 | return
23 | }
24 | FileHandle.standardError.write(data)
25 | }
26 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Models/ToolDefinitions.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 |
5 | // MARK: - Weather Tool
6 |
7 | struct WeatherInput: Codable {
8 | let location: String
9 | let unit: String?
10 | }
11 |
12 | struct WeatherOutput: Codable {
13 | let temperature: Double
14 | let conditions: String
15 | }
16 |
17 | // MARK: - Add Tool
18 |
19 | struct AddInput: Codable {
20 | let first: Int
21 | let second: Int
22 | }
23 |
24 | struct AddOutput: Codable {
25 | let result: Int
26 | }
27 |
28 | // MARK: - Time Tool
29 |
30 | struct EmptyInput: Codable {}
31 |
32 | struct TimeOutput: Codable {
33 | let time: String
34 | }
35 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/README.md:
--------------------------------------------------------------------------------
1 | # MNISTTrainer
2 |
3 | This is an example of model training that works on both macOS and iOS.
4 | The example will download the MNIST training data, create a LeNet, and train
5 | it. It will show the epoch time and test accuracy as it trains.
6 |
7 | You will need to set the Team on the MNISTTrainer target in order to build and
8 | run on iOS.
9 |
10 | Some notes about the setup:
11 |
12 | - This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
13 | - The website it connects to uses http rather than https so it has an "App Transport Security Settings" entry in the Info.plist
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report about an issue you've encountered
4 | title: "[BUG] "
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 |
15 | Include code snippet
16 | ```swift
17 |
18 | ```
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Desktop (please complete the following information):**
24 | - OS Version: [e.g. macOS 14.1.2]
25 | - Device: [e.g. iPhone 16, M2 MacBook Pro]
26 | - Version: [e.g. 0.29.1]
27 |
28 | **Additional context**
29 | Add any other context about the problem here.
30 |
--------------------------------------------------------------------------------
/ACKNOWLEDGMENTS.md:
--------------------------------------------------------------------------------
1 | # Individual Contributors
2 |
3 | > If you wish to be acknowledged for your contributions, please list your name
4 | > with a short description of your contribution(s) below. For example:
5 | > - Jane Smith: Added the `foo` and `bar` ops.
6 |
7 | MLX Swift was developed with contributions from the following individuals:
8 |
9 | - [John Mai](https://github.com/johnmai-dev): Added support for multiple models (Qwen2, Starcoder2, InternLM2, Qwen3, Qwen3 MoE, GLM-4, MiMo, BitNet, SmolLM3, LFM2, Baichuan-M1).
10 |
11 |
12 |
13 |
14 |
15 |
16 | SOFTWARE.
17 |
18 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/Toolbar/ErrorView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ErrorView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
10 | struct ErrorView: View {
11 | let errorMessage: String
12 |
13 | @State private var isShowingError = false
14 |
15 | var body: some View {
16 | Button {
17 | isShowingError = true
18 | } label: {
19 | Image(systemName: "exclamationmark.triangle")
20 | .foregroundStyle(.red)
21 | }
22 | .popover(isPresented: $isShowingError, arrowEdge: .bottom) {
23 | Text(errorMessage)
24 | .padding()
25 | }
26 | }
27 | }
28 |
29 | #Preview {
30 | ErrorView(errorMessage: "Something went wrong!")
31 | }
32 |
--------------------------------------------------------------------------------
/Applications/VLMEval/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "appearances" : [
10 | {
11 | "appearance" : "luminosity",
12 | "value" : "dark"
13 | }
14 | ],
15 | "idiom" : "universal",
16 | "platform" : "ios",
17 | "size" : "1024x1024"
18 | },
19 | {
20 | "appearances" : [
21 | {
22 | "appearance" : "luminosity",
23 | "value" : "tinted"
24 | }
25 | ],
26 | "idiom" : "universal",
27 | "platform" : "ios",
28 | "size" : "1024x1024"
29 | }
30 | ],
31 | "info" : {
32 | "author" : "xcode",
33 | "version" : 1
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Support/HubApi+default.swift:
--------------------------------------------------------------------------------
1 | //
2 | // HubApi+default.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import Foundation
9 | @preconcurrency import Hub
10 |
11 | /// Extension providing a default HubApi instance for downloading model files
12 | extension HubApi {
13 | /// Default HubApi instance configured to download models to the user's Downloads directory
14 | /// under a 'huggingface' subdirectory.
15 | #if os(macOS)
16 | static let `default` = HubApi(
17 | downloadBase: URL.downloadsDirectory.appending(path: "huggingface")
18 | )
19 | #else
20 | static let `default` = HubApi(
21 | downloadBase: URL.cachesDirectory.appending(path: "huggingface")
22 | )
23 | #endif
24 | }
25 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Support/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "appearances" : [
10 | {
11 | "appearance" : "luminosity",
12 | "value" : "dark"
13 | }
14 | ],
15 | "idiom" : "universal",
16 | "platform" : "ios",
17 | "size" : "1024x1024"
18 | },
19 | {
20 | "appearances" : [
21 | {
22 | "appearance" : "luminosity",
23 | "value" : "tinted"
24 | }
25 | ],
26 | "idiom" : "universal",
27 | "platform" : "ios",
28 | "size" : "1024x1024"
29 | }
30 | ],
31 | "info" : {
32 | "author" : "xcode",
33 | "version" : 1
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/Tools/llm-tool/Arguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 |
/// Extension to allow URL command line arguments.
///
/// Arguments containing "://" are parsed as scheme URLs (e.g. "https://...");
/// anything else is treated as a file system path. The `@retroactive`
/// attribute (Swift 5.10+) acknowledges conforming a type we don't own to a
/// protocol we don't own; the `#else` branch keeps older compilers working.
#if swift(>=5.10)
    extension URL: @retroactive ExpressibleByArgument {
        public init?(argument: String) {
            if argument.contains("://") {
                self.init(string: argument)
            } else {
                self.init(filePath: argument)
            }
        }
    }
#else
    extension URL: ExpressibleByArgument {
        public init?(argument: String) {
            if argument.contains("://") {
                self.init(string: argument)
            } else {
                self.init(filePath: argument)
            }
        }
    }
#endif
28 |
--------------------------------------------------------------------------------
/Applications/LLMEval/ViewModels/DeviceStat.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 |
/// Periodically samples GPU statistics so views can display live usage.
///
/// A repeating 2-second `Timer` takes a fresh `GPU.snapshot()` and publishes
/// the delta against the snapshot captured at construction time.
@Observable
final class DeviceStat: @unchecked Sendable {

    // Latest GPU usage delta; main-actor isolated because it drives SwiftUI.
    @MainActor
    var gpuUsage = GPU.snapshot()

    // Baseline captured at construction; deltas are measured against this.
    private let initialGPUSnapshot = GPU.snapshot()
    private var timer: Timer?

    init() {
        // NOTE(review): scheduledTimer registers on the *current* run loop --
        // this assumes DeviceStat is created on the main thread; confirm callers.
        timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
            self?.updateGPUUsages()
        }
    }

    deinit {
        timer?.invalidate()
    }

    // Computes the usage delta and hops to the main queue to publish it,
    // since `gpuUsage` is @MainActor-isolated.
    private func updateGPUUsages() {
        let gpuSnapshotDelta = initialGPUSnapshot.delta(GPU.snapshot())
        DispatchQueue.main.async { [weak self] in
            self?.gpuUsage = gpuSnapshotDelta
        }
    }

}
33 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Services/FormatUtilities.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 |
/// Utility functions for formatting values
enum FormatUtilities {

    /// Formats a byte count into a human-readable string with appropriate units.
    ///
    /// Uses binary units (1 KB = 1024 bytes) and picks the largest unit whose
    /// value is at least 1.
    ///
    /// - Parameter bytes: The number of bytes to format
    /// - Returns: A formatted string (e.g., "2.50 GB", "128 MB", "512 KB", "37 B")
    static func formatMemory(_ bytes: Int) -> String {
        let kb = Double(bytes) / 1024
        let mb = kb / 1024
        let gb = mb / 1024

        if gb >= 1 {
            return String(format: "%.2f GB", gb)
        } else if mb >= 1 {
            return String(format: "%.0f MB", mb)
        } else if kb >= 1 {
            return String(format: "%.0f KB", kb)
        } else if bytes > 0 {
            // Previously any non-zero value under 1 KB displayed as "0 KB",
            // hiding information; show the exact byte count instead.
            return "\(bytes) B"
        } else {
            return "0 KB"
        }
    }
}
27 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/MetricCard.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import SwiftUI
4 |
/// A small rounded card showing a single metric: an icon and caption above a value.
struct MetricCard: View {
    /// SF Symbol name displayed next to the title.
    let icon: String
    /// Short caption describing the metric.
    let title: String
    /// Formatted metric value (rendered with monospaced digits to avoid jitter).
    let value: String

    var body: some View {
        VStack(spacing: 8) {
            HStack(spacing: 4) {
                Image(systemName: icon)
                    .font(.caption)
                    .foregroundStyle(.secondary)
                Text(title)
                    .font(.caption)
                    .foregroundStyle(.secondary)
            }
            Text(value)
                .font(.title3)
                .fontWeight(.semibold)
                .monospacedDigit()
        }
        .frame(maxWidth: .infinity)
        .padding(.vertical, 12)
        .padding(.horizontal, 8)
        .background(Color.gray.opacity(0.1))
        // clipShape replaces the deprecated cornerRadius(_:) modifier.
        .clipShape(RoundedRectangle(cornerRadius: 8))
    }
}
32 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/README.md:
--------------------------------------------------------------------------------
1 | # LoRATrainingExample
2 |
3 | Example application that:
4 |
5 | - downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from huggingface
6 | - loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded)
7 | - adds LoRA adapters and trains the model
- lets you evaluate a prompt against the model
9 |
10 | This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and
11 | you can read more about LoRA there.
12 |
13 | This evaluates the LoRA adapted model rather than a fused model. This doesn't persist
14 | the LoRA weights or the fused model -- it will retrain it each time the program is launched.
15 |
16 | ### Troubleshooting
17 |
18 | The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of
memory to load and train -- this may require ~6G of physical RAM.
20 |
21 |
22 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/ConversationView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ConversationView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
/// Displays the chat conversation as a scrollable list of messages.
struct ConversationView: View {
    /// Array of messages to display in the conversation
    let messages: [Message]

    var body: some View {
        ScrollView {
            // LazyVStack builds message rows only as they scroll into view.
            LazyVStack(spacing: 12) {
                ForEach(messages) { message in
                    MessageView(message)
                        .padding(.horizontal, 12)
                }
            }
        }
        .padding(.vertical, 8)
        // Keep the newest message visible: anchor scrolling to the bottom
        // whenever the content size changes (e.g. while a reply streams in).
        .defaultScrollAnchor(.bottom, for: .sizeChanges)
    }
}

#Preview {
    // Display sample conversation in preview
    ConversationView(messages: SampleData.conversation)
}
33 |
--------------------------------------------------------------------------------
/Applications/LLMEval/AppIcon.icon/icon.json:
--------------------------------------------------------------------------------
1 | {
2 | "fill" : {
3 | "automatic-gradient" : "srgb:0.66275,0.79216,0.87451,1.00000",
4 | "orientation" : {
5 | "start" : {
6 | "x" : 0.5,
7 | "y" : 0
8 | },
9 | "stop" : {
10 | "x" : 0.5,
11 | "y" : 0.7
12 | }
13 | }
14 | },
15 | "groups" : [
16 | {
17 | "layers" : [
18 | {
19 | "image-name" : "bubbles3.png",
20 | "name" : "bubbles3",
21 | "position" : {
22 | "scale" : 1.35,
23 | "translation-in-points" : [
24 | 0,
25 | 0
26 | ]
27 | }
28 | }
29 | ],
30 | "shadow" : {
31 | "kind" : "neutral",
32 | "opacity" : 0.5
33 | },
34 | "translucency" : {
35 | "enabled" : true,
36 | "value" : 0.5
37 | }
38 | }
39 | ],
40 | "supported-platforms" : {
41 | "circles" : [
42 | "watchOS"
43 | ],
44 | "squares" : "shared"
45 | }
46 | }
--------------------------------------------------------------------------------
/Tools/embedder-tool/ListCommand.swift:
--------------------------------------------------------------------------------
1 | import ArgumentParser
2 | import Foundation
3 | import MLXEmbedders
4 |
/// Subcommand that prints the names of registered embedder model configurations.
struct ListCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "list",
        abstract: "List registered embedder model configurations"
    )

    @Flag(name: .long, help: "Include models registered from local directories")
    var includeDirectories = false

    func run() async throws {
        // ModelConfiguration.models is accessed on the main actor; copy it out
        // and sort case-insensitively by name for stable output.
        let models = await MainActor.run { Array(ModelConfiguration.models) }
            .sorted { $0.name.localizedCaseInsensitiveCompare($1.name) == .orderedAscending }

        for configuration in models {
            switch configuration.id {
            case .id(let identifier):
                print(identifier)
            case .directory(let url):
                // Directory-backed registrations are only listed when requested.
                if includeDirectories {
                    print(url.path)
                }
            }
        }
    }
}
30 |
--------------------------------------------------------------------------------
/Tools/mnist-tool/README.md:
--------------------------------------------------------------------------------
1 | # mnist-tool
2 |
3 | See the [MNIST README.md](../../Libraries/MNIST/README.md).
4 |
5 | ### Building
6 |
7 | `mnist-tool` has no dependencies outside of the package dependencies
8 | represented in xcode.
9 |
10 | When you run the tool it will download the test/train datasets and
11 | store them in a specified directory (see run arguments -- default is /tmp).
12 |
13 | Simply build the project in xcode.
14 |
15 | ### Running (Xcode)
16 |
17 | To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
18 |
19 | ```
20 | --data /tmp
21 | ```
22 |
23 | Then cmd-r to run.
24 |
25 | ### Running (CommandLine)
26 |
27 | Use the `mlx-run` script to run the command line tools:
28 |
29 | ```
30 | ./mlx-run mnist-tool --data /tmp
31 | ```
32 |
33 | By default this will find and run the tools built in _Release_ configuration. Specify `--debug`
34 | to find and run the tool built in _Debug_ configuration.
35 |
36 | See also:
37 |
38 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting)
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ml-explore
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MLX Swift Examples
2 |
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 |
6 | ## Pull Requests
7 |
8 | 1. Fork and submit pull requests to the repo.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. Every PR should have passing tests (if any) and at least one review.
11 | 4. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
12 | If needed you may need to `brew install swift-format`.
13 |
14 | You can also run the formatters manually as follows:
15 |
16 | ```
17 | swift-format format --in-place --recursive Libraries Tools Applications
18 | ```
19 |
20 | or run `pre-commit run --all-files` to check all files in the repo.
21 |
22 | ## Issues
23 |
24 | We use GitHub issues to track public bugs. Please ensure your description is
25 | clear and has sufficient instructions to be able to reproduce the issue.
26 |
27 | ## License
28 |
29 | By contributing to MLX Swift Examples, you agree that your contributions will be licensed
30 | under the LICENSE file in the root directory of this source tree.
31 |
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/Random.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
5 | // From https://github.com/apple/swift/blob/cb0fb1ea051631219c0b944b84c78571448d58c2/benchmark/utils/TestsUtils.swift#L254
6 | //
7 | // This is just a seedable RandomNumberGenerator for shuffle()
8 |
9 | // This is a fixed-increment version of Java 8's SplittableRandom generator.
10 | // It is a very fast generator passing BigCrush, with 64 bits of state.
11 | // See http://dx.doi.org/10.1145/2714064.2660195 and
12 | // http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html
13 | //
14 | // Derived from public domain C implementation by Sebastiano Vigna
15 | // See http://xoshiro.di.unimi.it/splitmix64.c
/// Seedable pseudo-random generator (SplitMix64) used for reproducible shuffles.
///
/// This is the fixed-increment variant of Java 8's SplittableRandom: 64 bits of
/// state, very fast, and passes BigCrush. Derived from Sebastiano Vigna's
/// public-domain C implementation (http://xoshiro.di.unimi.it/splitmix64.c).
public struct SplitMix64: RandomNumberGenerator, Sendable {
    private var state: UInt64

    /// Creates a generator whose output sequence is fully determined by `seed`.
    public init(seed: UInt64) {
        state = seed
    }

    /// Advances the state by the golden-ratio increment, then returns a
    /// bit-scrambled (xor-shift / multiply) mix of it.
    public mutating func next() -> UInt64 {
        state &+= 0x9e37_79b9_7f4a_7c15
        let raw = state
        let mixed1 = (raw ^ (raw &>> 30)) &* 0xbf58_476d_1ce4_e5b9
        let mixed2 = (mixed1 ^ (mixed1 &>> 27)) &* 0x94d0_49bb_1331_11eb
        return mixed2 ^ (mixed2 &>> 31)
    }
}
31 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/CorpusArguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 |
/// Command line arguments describing a corpus of documents to index.
struct CorpusArguments: ParsableArguments {

    @Option(name: .shortAndLong, help: "Directory containing documents to index.")
    var directory: URL = URL(
        fileURLWithPath: FileManager.default.currentDirectoryPath, isDirectory: true)

    @Option(
        name: [.customShort("e"), .long], parsing: .upToNextOption,
        help: "File extensions to include (without dots).")
    var extensions: [String] = ["txt", "md"]

    @Flag(name: .shortAndLong, help: "Recursively scan subdirectories.")
    var recursive = false

    @Option(name: .long, help: "Limit the number of documents to load.")
    var limit: Int?

    /// Extensions normalized for comparison: whitespace trimmed, a single
    /// leading dot removed, then lowercased (so ".TXT" matches "txt").
    var normalizedExtensions: [String] {
        extensions.map { value in
            let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines)
            let noDot = trimmed.hasPrefix(".") ? String(trimmed.dropFirst()) : trimmed
            return noDot.lowercased()
        }
    }

    /// The target directory with "." / ".." path components resolved.
    var directoryURL: URL { directory.standardizedFileURL }
}
33 |
--------------------------------------------------------------------------------
/Tools/ExampleLLM/main.swift:
--------------------------------------------------------------------------------
import MLXLMCommon

// Demonstrates the simplified LLM API: load a model once, then use it for
// one-shot, streaming, and multi-turn chat evaluation.
let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit")

let prompt = "What are three things to see in Paris?"

// MARK: - one-shot print out full response
print(
    """
    ================
    \(prompt)

    """)
print(try await ChatSession(model).respond(to: prompt))

// MARK: - one-shot streaming output
print(
    """
    ================
    \(prompt)

    """)
for try await item in ChatSession(model).streamResponse(to: prompt) {
    print(item, terminator: "")
}
print()

// MARK: - conversation with follow-on questions
// Reusing a single ChatSession keeps conversation history between questions,
// which is what lets the model answer the final "What city?" question.
let session = ChatSession(model)

let questions = [
    "What are two things to see in San Francisco?",
    "How about a great place to eat?",
    "What city are we talking about? I forgot!",
]

for question in questions {
    print(
        """
        ================
        \(question)

        """)

    for try await item in session.streamResponse(to: question) {
        print(item, terminator: "")
    }
    print()
}
50 |
--------------------------------------------------------------------------------
/mlx-run:
--------------------------------------------------------------------------------
#!/bin/sh

# Wrapper to help run command line tools -- this will find the build directory
# and set the DYLD_FRAMEWORK_PATH so that command line tools that link frameworks
# can be run.
#
# Example:
# ./mlx-run --debug llm-tool --help

if [ "$#" -lt 1 ]; then
    echo "usage: mlx-run [--debug/--release] arguments"
    exit 1
fi

# Default to the Release build; --release / --debug select explicitly.
# Note: POSIX sh uses '=' (not '==') for string comparison in test.
CONFIGURATION=Release
if [ "$1" = "--release" ]; then
    CONFIGURATION=Release
    shift
fi
if [ "$1" = "--debug" ]; then
    CONFIGURATION=Debug
    shift
fi
if [ "$1" = "--list" ]; then
    xcodebuild -list
    exit 0
fi

COMMAND="$1"
shift

# Ask xcodebuild where the products for this scheme/configuration were built.
BUILD_DIR=$(xcodebuild -configuration "$CONFIGURATION" -showBuildSettings -scheme "$COMMAND" | grep 'BUILT_PRODUCTS_DIR = /' | sed -e 's/^[^=]*= //g')

# App bundles are launched in the background, then we exit successfully --
# previously the script fell through and reported an error even after the
# app had been launched.
if [ -d "$BUILD_DIR/$COMMAND.app" ]; then
    "$BUILD_DIR/$COMMAND.app/Contents/MacOS/$COMMAND" "$@" &
    exit 0
fi

if [ -f "$BUILD_DIR/$COMMAND" ]; then
    # Frameworks built by SwiftPM land in PackageFrameworks; make sure the
    # dynamic linker can find them.
    export DYLD_FRAMEWORK_PATH="$BUILD_DIR/PackageFrameworks:$BUILD_DIR"
    exec "$BUILD_DIR/$COMMAND" "$@"
else
    echo "$BUILD_DIR/$COMMAND does not exist -- check build configuration ($CONFIGURATION)"
    exit 1
fi
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/Toolbar/DownloadProgressView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // DownloadProgressView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
/// Toolbar button that reveals model download progress in a popover.
struct DownloadProgressView: View {
    /// Progress of the ongoing model download.
    let progress: Progress

    // Controls visibility of the progress popover.
    @State private var isShowingDownload = false

    var body: some View {
        Button {
            isShowingDownload = true
        } label: {
            Image(systemName: "arrow.down.square")
                .foregroundStyle(.tint)
        }
        .popover(isPresented: $isShowingDownload, arrowEdge: .bottom) {
            VStack {
                // Fraction bar with byte counts (additional description) on the
                // left and percent complete (description) on the right.
                ProgressView(value: progress.fractionCompleted) {
                    HStack {
                        Text(progress.localizedAdditionalDescription)
                            .bold()
                        Spacer()
                        Text(progress.localizedDescription)
                    }
                }

                Text("The model is downloading")
                    .padding(.horizontal, 32)
            }
            .padding()
        }
    }
}

#Preview {
    DownloadProgressView(progress: Progress(totalUnitCount: 6))
}
44 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/LoRATrainingExample/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "idiom" : "mac",
10 | "scale" : "1x",
11 | "size" : "16x16"
12 | },
13 | {
14 | "idiom" : "mac",
15 | "scale" : "2x",
16 | "size" : "16x16"
17 | },
18 | {
19 | "idiom" : "mac",
20 | "scale" : "1x",
21 | "size" : "32x32"
22 | },
23 | {
24 | "idiom" : "mac",
25 | "scale" : "2x",
26 | "size" : "32x32"
27 | },
28 | {
29 | "idiom" : "mac",
30 | "scale" : "1x",
31 | "size" : "128x128"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "2x",
36 | "size" : "128x128"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "1x",
41 | "size" : "256x256"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "2x",
46 | "size" : "256x256"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "1x",
51 | "size" : "512x512"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "2x",
56 | "size" : "512x512"
57 | }
58 | ],
59 | "info" : {
60 | "author" : "xcode",
61 | "version" : 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/Applications/StableDiffusionExample/README.md:
--------------------------------------------------------------------------------
1 | # StableDiffusionExample
2 |
3 | An example application that runs the StableDiffusion example code.
4 |
5 | See also [image-tool](../../Tools/image-tool) for a command line example.
6 |
This example application accepts a prompt and uses the StableDiffusion example
8 | library to render an image using:
9 |
10 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
11 |
12 | Please refer to that model for license and other information.
13 |
14 | If you are interested in adjusting the generated images, look in
15 | [ContentView.swift](ContentView.swift) at this method:
16 |
17 | ```swift
18 | func generate(prompt: String, negativePrompt: String, showProgress: Bool) async
19 | ```
20 |
21 | ### Troubleshooting
22 |
Stable diffusion can run in less than 4G of available memory (typically a
24 | device or computer with 6G of memory or more) in a constrained mode -- it will
25 | load and unload parts of the model as it runs and it can only perform one step
26 | of diffusion. This is configured automatically, see `modelFactory.conserveMemory`
27 | in [ContentView.swift](ContentView.swift).
28 |
29 | On a device or computer with more memory the model will be kept resident and
30 | images can be regenerated much more efficiently.
31 |
32 | If the program exits while generating the image it may have exceeded the available
33 | memory.
34 |
--------------------------------------------------------------------------------
/Tools/ExampleLLM/README.md:
--------------------------------------------------------------------------------
1 | # ExampleLLM
2 |
3 | An example that uses the simplified APIs to load and evaluate an LLM in only a few lines of
4 | code:
5 |
6 | ```swift
7 | let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit")
8 | let session = ChatSession(model)
print(try await session.respond(to: "What are two things to see in San Francisco?"))
print(try await session.respond(to: "How about a great place to eat?"))
11 | ```
12 |
13 | See various READMEs:
14 |
15 | - [MLXLMCommon](https://github.com/ml-explore/mlx-swift-lm/Libraries/MLXLMCommon/README.md) -- common LM code
16 | - [MLXLLM](https://github.com/ml-explore/mlx-swift-lm/Libraries/MLXLLM/README.md) -- large language models
17 | - [MLXVLM](https://github.com/ml-explore/mlx-swift-lm/Libraries/MLXVLM/README.md) -- vision language models
18 |
19 | ### Building
20 |
21 | Build the `ExampleLLM` scheme in Xcode.
22 |
23 | ### Running: Xcode
24 |
25 | Just press cmd-r to run!
26 |
27 | ### Running: Command Line
28 |
29 | Use the `mlx-run` script to run the command line tools:
30 |
31 | ```
32 | ./mlx-run ExampleLLM
33 | ```
34 |
Note: `mlx-run` is a shell script that uses Xcode command line tools to
36 | locate the built binaries. It is equivalent to running from Xcode itself.
37 |
38 | By default this will find and run the tools built in _Release_ configuration. Specify `--debug`
39 | to find and run the tool built in _Debug_ configuration.
40 |
41 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/PoolingArguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import ArgumentParser
4 | import MLXEmbedders
5 |
/// Command line arguments controlling how token embeddings are pooled.
struct PoolingArguments: ParsableArguments {

    @Option(
        name: .long, help: "Pooling strategy used to collapse token embeddings (default: mean).")
    var strategy: Pooling.Strategy?

    @Flag(
        name: .long, inversion: .prefixedNo,
        help:
            "Normalize pooled embeddings to unit length (default: true). Use --no-normalize to disable."
    )
    var normalize = true

    @Flag(name: .long, help: "Apply layer normalization before pooling.")
    var layerNorm = false
}

extension PoolingArguments {
    /// The strategy to use, falling back to `.mean` when none was supplied.
    var strategyOverride: Pooling.Strategy? {
        strategy ?? .mean
    }
}

// Enumerates the strategies accepted on the command line.
extension Pooling.Strategy: @retroactive CaseIterable {
    public static var allCases: [Pooling.Strategy] {
        [.mean, .cls, .first, .last, .max, .none]
    }
}

// Parses a strategy from its (case-insensitive) command line spelling.
extension Pooling.Strategy: @retroactive ExpressibleByArgument {
    public init?(argument: String) {
        switch argument.lowercased() {
        case "mean": self = .mean
        case "cls": self = .cls
        case "first": self = .first
        case "last": self = .last
        case "max": self = .max
        case "none": self = .none
        default: return nil
        }
    }
}
48 |
--------------------------------------------------------------------------------
/Tools/llm-tool/ListCommands.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLXLLM
6 | import MLXVLM
7 |
/// Parent command grouping the `llms` and `vlms` listing subcommands.
struct ListCommands: AsyncParsableCommand {

    static let configuration = CommandConfiguration(
        commandName: "list",
        abstract: "list registered model configurations",
        subcommands: [
            ListLLMCommand.self, ListVLMCommand.self,
        ]
    )
}

/// Prints the hub identifier of every registered LLM configuration.
struct ListLLMCommand: AsyncParsableCommand {

    static let configuration = CommandConfiguration(
        commandName: "llms",
        abstract: "List registered LLM model configurations"
    )

    func run() async throws {
        for configuration in LLMRegistry.shared.models {
            switch configuration.id {
            case .id(let id): print(id)
            // Directory-backed configurations have no hub id; skip them.
            case .directory: break
            }
        }
    }
}

/// Prints the hub identifier of every registered VLM configuration.
struct ListVLMCommand: AsyncParsableCommand {

    static let configuration = CommandConfiguration(
        commandName: "vlms",
        abstract: "List registered VLM model configurations"
    )

    func run() async throws {
        for configuration in VLMRegistry.shared.models {
            switch configuration.id {
            case .id(let id): print(id)
            // Directory-backed configurations have no hub id; skip them.
            case .directory: break
            }
        }
    }
}
52 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ChatToolbarView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
/// Toolbar view for the chat interface that displays error messages, download progress,
/// generation statistics, and model selection controls.
struct ChatToolbarView: View {
    /// View model containing the chat state and controls
    @Bindable var vm: ChatViewModel

    var body: some View {
        // Display error message if present
        if let errorMessage = vm.errorMessage {
            ErrorView(errorMessage: errorMessage)
        }

        // Show download progress for model loading (hidden once finished)
        if let progress = vm.modelDownloadProgress, !progress.isFinished {
            DownloadProgressView(progress: progress)
        }

        // Button to clear chat history and metadata; its label doubles as the
        // tokens-per-second statistics display
        Button {
            vm.clear([.chat, .meta])
        } label: {
            GenerationInfoView(
                tokensPerSecond: vm.tokensPerSecond
            )
        }

        // Model selection picker
        Picker("Model", selection: $vm.selectedModel) {
            ForEach(MLXService.availableModels) { model in
                Text(model.displayName)
                    .tag(model)
            }
        }
    }
}
45 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/PromptField.swift:
--------------------------------------------------------------------------------
1 | //
2 | // PromptField.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
8 | import SwiftUI
9 |
/// Text input row with an optional media button and a combined send/stop button.
///
/// While a send action is running, the send button becomes a stop button that
/// cancels the in-flight task.
struct PromptField: View {
    /// The prompt text being edited.
    @Binding var prompt: String
    // `Task` requires explicit generic parameters when used as a stored
    // property type; the send action returns Void and never throws.
    @State private var task: Task<Void, Never>?

    /// Invoked on send; runs inside `task` so the user can cancel it.
    let sendButtonAction: () async -> Void
    /// Optional action for the media (photo) button; the button is hidden when nil.
    let mediaButtonAction: (() -> Void)?

    var body: some View {
        HStack {
            if let mediaButtonAction {
                Button(action: mediaButtonAction) {
                    Image(systemName: "photo.badge.plus")
                }
            }

            TextField("Prompt", text: $prompt)
                .textFieldStyle(.roundedBorder)

            Button {
                if isRunning {
                    task?.cancel()
                    removeTask()
                } else {
                    task = Task {
                        await sendButtonAction()
                        removeTask()
                    }
                }
            } label: {
                Image(systemName: isRunning ? "stop.circle.fill" : "paperplane.fill")
            }
            .keyboardShortcut(isRunning ? .cancelAction : .defaultAction)
        }
    }

    /// True while a send task exists and has not been cancelled.
    private var isRunning: Bool {
        // Optional chaining avoids the previous force unwrap; nil -> false.
        task?.isCancelled == false
    }

    /// Clears the stored task once it finishes or is cancelled.
    private func removeTask() {
        task = nil
    }
}
53 |
54 | #Preview {
55 | PromptField(prompt: .constant("")) {
56 | } mediaButtonAction: {
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Models/LMModel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // LMModel.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import MLXLMCommon
9 |
/// Represents a language model configuration with its associated properties and type.
/// Can represent either a large language model (LLM) or a vision-language model (VLM).
struct LMModel {
    /// Name of the model; also used as its stable identity elsewhere in this file.
    let name: String

    /// Configuration settings for model initialization
    let configuration: ModelConfiguration

    /// Type of the model (language or vision-language)
    let type: ModelType

    /// Defines the type of language model
    enum ModelType {
        /// Large language model (text-only)
        case llm
        /// Vision-language model (supports images and text)
        case vlm
    }
}
30 |
31 | // MARK: - Helpers
32 |
extension LMModel {
    /// Name shown in the UI; vision models carry a "(Vision)" suffix.
    var displayName: String {
        isVisionModel ? "\(name) (Vision)" : name
    }

    /// True for text-only language models.
    var isLanguageModel: Bool { type == .llm }

    /// True for vision-language models.
    var isVisionModel: Bool { type == .vlm }
}
53 |
extension LMModel: Identifiable, Hashable {
    /// Models are identified by name alone.
    var id: String { name }

    /// Hashes only the name, mirroring `id`. This stays consistent with the
    /// synthesized `==` because equal models necessarily share a name.
    func hash(into hasher: inout Hasher) {
        hasher.combine(name)
    }
}
63 |
--------------------------------------------------------------------------------
/Applications/VLMEval/README.md:
--------------------------------------------------------------------------------
1 | # VLMEval
2 |
3 | An example that:
4 |
5 | - downloads a vision language model (SmolVLM2)
6 | - processes an image or a video with a prompt
7 |
8 | You will need to set the Team on the VLMEval target in order to build and run on macOS.
9 |
10 | Some notes about the setup:
11 |
12 | - This downloads models from hugging face so VLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
13 | - VLM models are large so this uses significant memory
14 | - The example can process images and videos, and provides detailed analysis
15 |
16 | ### Image Processing
17 |
18 | The example application uses SmolVLM2 model by default, see [ContentView.swift](ContentView.swift):
19 |
20 | ```swift
21 | self.modelContainer = try await VLMModelFactory.shared.loadContainer(
22 | configuration: VLMRegistry.smolvlm)
23 | ```
24 |
25 | The application:
26 | 1. Downloads a sample image
27 | 2. Processes it through the vision language model
28 | 3. Describes the images based on the prompt, providing detailed analysis of the content, objects, colors, and composition.
29 |
30 | ### Troubleshooting
31 |
32 | If the program crashes with a very deep stack trace you may need to build
33 | in Release configuration. This seems to depend on the size of the model.
34 |
35 | There are a couple options:
36 |
37 | - Build Release
38 | - Force the model evaluation to run on the main thread, e.g. using @MainActor
39 | - Build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),`
40 |
41 | ### Performance
42 |
43 | You may find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable".
44 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/EmbedderCommand.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 |
/// A protocol for commands that need to load and use an embedder model.
///
/// This protocol centralizes the logic for loading an `EmbedderRuntime`
/// and managing memory statistics, reducing boilerplate in individual commands.
/// Conformers implement `run(runtime:)`; the protocol extension supplies `run()`.
protocol EmbedderCommand: AsyncParsableCommand {
    /// The model arguments, captured via `@OptionGroup`.
    var model: ModelArguments { get }

    /// The pooling arguments, captured via `@OptionGroup`.
    var pooling: PoolingArguments { get }

    /// The memory management arguments, captured via `@OptionGroup`.
    /// Settable because the default `run()` writes back mutated state
    /// after reporting statistics.
    var memory: MemoryArguments { get set }

    /// The core logic of the command, which receives a fully initialized `EmbedderRuntime`.
    /// - Parameter runtime: The loaded and configured embedder runtime.
    mutating func run(runtime: EmbedderRuntime) async throws
}
24 |
extension EmbedderCommand {
    /// Default entry point: loads the embedder runtime, invokes the command's
    /// `run(runtime:)`, and reports memory statistics on the way out.
    mutating func run() async throws {
        // Copy argument groups into locals so the @Sendable closure below
        // does not capture `self`.
        var memoryArguments = self.memory
        let modelArguments = model
        let poolingArguments = pooling

        let runtime = try await memoryArguments.start {
            try await EmbedderTool.loadRuntime(
                model: modelArguments, pooling: poolingArguments)
        }

        // Deferred so statistics are reported and the mutated memory state
        // is written back even when run(runtime:) throws.
        defer {
            memoryArguments.reportMemoryStatistics()
            self.memory = memoryArguments
        }

        try await run(runtime: runtime)
    }
}
47 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/OutputView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import MarkdownUI
4 | import SwiftUI
5 |
/// Scrolling view of generated text, rendered as plain text or Markdown,
/// with an optional truncation warning and auto-scroll on new output.
struct OutputView: View {
    let output: String
    let displayStyle: ContentView.DisplayStyle
    let wasTruncated: Bool

    var body: some View {
        ScrollView(.vertical) {
            ScrollViewReader { proxy in
                VStack(alignment: .leading, spacing: 12) {
                    Group {
                        if displayStyle == .plain {
                            Text(output)
                                .textSelection(.enabled)
                        } else {
                            Markdown(output)
                                .textSelection(.enabled)
                        }
                    }

                    // Warning banner when output is truncated.
                    if wasTruncated && !output.isEmpty {
                        truncationBanner
                    }
                }
                // Keep the newest output visible as tokens stream in.
                .onChange(of: output) { _, _ in
                    proxy.scrollTo("bottom")
                }

                // Invisible scroll anchor at the end of the content.
                Spacer()
                    .frame(width: 1, height: 1)
                    .id("bottom")
            }
        }
    }

    /// Banner shown when generation stopped at the token limit.
    private var truncationBanner: some View {
        HStack(spacing: 8) {
            Image(systemName: "exclamationmark.triangle.fill")
                .foregroundStyle(.orange)
            Text("Output truncated: Maximum token limit reached")
                .font(.caption)
                .foregroundStyle(.secondary)
        }
        .padding(8)
        .background(.orange.opacity(0.1), in: RoundedRectangle(cornerRadius: 6))
    }
}
--------------------------------------------------------------------------------
/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "pins" : [
3 | {
4 | "identity" : "gzipswift",
5 | "kind" : "remoteSourceControl",
6 | "location" : "https://github.com/1024jp/GzipSwift",
7 | "state" : {
8 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05",
9 | "version" : "6.0.1"
10 | }
11 | },
12 | {
13 | "identity" : "mlx-swift",
14 | "kind" : "remoteSourceControl",
15 | "location" : "https://github.com/ml-explore/mlx-swift",
16 | "state" : {
17 | "revision" : "072b684acaae80b6a463abab3a103732f33774bf",
18 | "version" : "0.29.1"
19 | }
20 | },
21 | {
22 | "identity" : "swift-collections",
23 | "kind" : "remoteSourceControl",
24 | "location" : "https://github.com/apple/swift-collections.git",
25 | "state" : {
26 | "revision" : "c1805596154bb3a265fd91b8ac0c4433b4348fb0",
27 | "version" : "1.2.0"
28 | }
29 | },
30 | {
31 | "identity" : "swift-jinja",
32 | "kind" : "remoteSourceControl",
33 | "location" : "https://github.com/huggingface/swift-jinja.git",
34 | "state" : {
35 | "revision" : "c1ef5963ba4a97a589b9c9583ff4ee3352a86d23",
36 | "version" : "2.1.0"
37 | }
38 | },
39 | {
40 | "identity" : "swift-numerics",
41 | "kind" : "remoteSourceControl",
42 | "location" : "https://github.com/apple/swift-numerics",
43 | "state" : {
44 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
45 | "version" : "1.0.2"
46 | }
47 | },
48 | {
49 | "identity" : "swift-transformers",
50 | "kind" : "remoteSourceControl",
51 | "location" : "https://github.com/huggingface/swift-transformers",
52 | "state" : {
53 | "revision" : "94610577e4af9bbc267060af1e25e977604dd796",
54 | "version" : "1.1.1"
55 | }
56 | }
57 | ],
58 | "version" : 2
59 | }
60 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion
2 |
3 | Stable Diffusion in MLX. The implementation was ported from Hugging Face's
4 | [diffusers](https://huggingface.co/docs/diffusers/index) and
5 | [mlx-examples/stable_diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
6 | Model weights are downloaded directly from the Hugging Face hub. The implementation currently
7 | supports the following models:
8 |
9 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
10 | - [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
11 |
12 | ## Usage
13 |
14 | See [StableDiffusionExample](../../Applications/StableDiffusionExample) and
15 | [image-tool](../../Tools/image-tool) for examples of using this code.
16 |
17 | The basic sequence is:
18 |
19 | - download & load the model
20 | - generate latents
21 | - evaluate the latents one by one
22 | - decode the last latent generated
23 | - you have an image!
24 |
25 | ```swift
26 | let configuration = StableDiffusionConfiguration.presetSDXLTurbo
27 |
28 | let generator = try configuration.textToImageGenerator(
29 | configuration: model.loadConfiguration)
30 |
31 | generator.ensureLoaded()
32 |
33 | // Generate the latents, which are the iterations for generating
34 | // the output image. This is just generating the evaluation graph
35 | let parameters = generator.evaluateParameters(configuration: configuration)
36 | let latents = generator.generateLatents(parameters: parameters)
37 |
38 | // evaluate the latents (evaluate the graph) and keep the last value generated
39 | var lastXt: MLXArray?
40 | for xt in latents {
41 | eval(xt)
42 | lastXt = xt
43 | }
44 |
45 | // decode the final latent into an image
46 | if let lastXt {
47 | var raster = decoder(lastXt[0])
48 | raster = (raster * 255).asType(.uint8).squeezed()
49 | eval(raster)
50 |
51 | // turn it into a CGImage
52 | let image = Image(raster).asCGImage()
53 |
54 | // or write it out
55 | try Image(raster).save(url: url)
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/LoadingOverlayView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import SwiftUI
4 |
/// Dimming overlay shown while a model is loading or downloading.
struct LoadingOverlayView: View {
    let modelInfo: String
    let downloadProgress: Double?
    let progressDescription: String?

    /// - Parameters:
    ///   - modelInfo: Headline describing the model being loaded.
    ///   - downloadProgress: Fractional progress; a linear bar is shown while < 1.0.
    ///   - progressDescription: Optional secondary status line.
    init(modelInfo: String, downloadProgress: Double? = nil, progressDescription: String? = nil) {
        self.modelInfo = modelInfo
        self.downloadProgress = downloadProgress
        self.progressDescription = progressDescription
    }

    var body: some View {
        ZStack {
            // Dim whatever is behind the overlay.
            Color.black.opacity(0.4)
                .ignoresSafeArea()

            VStack(spacing: 16) {
                if let downloadProgress, downloadProgress < 1.0 {
                    // Determinate bar while a download is in flight.
                    ProgressView(value: downloadProgress)
                        .progressViewStyle(.linear)
                        .frame(width: 200)
                } else {
                    // Indeterminate spinner otherwise (e.g. loading from cache).
                    ProgressView()
                        .scaleEffect(1.5)
                        .progressViewStyle(.circular)
                }

                Text(modelInfo)
                    .font(.headline)
                    .foregroundStyle(.primary)

                if let progressDescription {
                    Text(progressDescription)
                        .font(.subheadline)
                        .foregroundStyle(.secondary)
                        .monospacedDigit()
                }

                Text(
                    "Models are large and may take a couple of minutes to download on first use. They are cached locally for faster loading in the future."
                )
                .font(.subheadline)
                .foregroundStyle(.secondary)
                .multilineTextAlignment(.center)
                .frame(maxWidth: 300)
            }
            .padding(32)
            .background(.regularMaterial, in: RoundedRectangle(cornerRadius: 16))
        }
    }
}
56 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Models/PresetPrompts.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 |
/// A canned prompt plus the feature flags it should be run with.
struct PresetPrompt: Identifiable {
    let id = UUID()
    /// The prompt text itself.
    let prompt: String
    /// Whether tool calling should be enabled for this prompt.
    let enableTools: Bool
    /// Whether "thinking" mode should be enabled for this prompt.
    let enableThinking: Bool
    /// Whether this is a long-form prompt loaded from a bundled file.
    let isLongPrompt: Bool

    init(
        _ prompt: String,
        enableTools: Bool = false,
        enableThinking: Bool = false,
        isLongPrompt: Bool = false
    ) {
        self.prompt = prompt
        self.enableTools = enableTools
        self.enableThinking = enableThinking
        self.isLongPrompt = isLongPrompt
    }
}
22 |
/// Catalog of preset prompts shown in the app's prompt picker.
struct PresetPrompts {
    /// Loads a bundled markdown file's contents, falling back to an
    /// explanatory message when the resource is missing or unreadable.
    private static func loadPrompt(named fileName: String) -> String {
        if let url = Bundle.main.url(forResource: fileName, withExtension: "md"),
            let content = try? String(contentsOf: url, encoding: .utf8)
        {
            return content
        }
        return "Could not load \(fileName).md. Please ensure it is included in the app bundle."
    }

    /// All presets: short prompts, thinking-mode prompts, tool-calling
    /// prompts, and long prompts loaded from bundled markdown files.
    static let all: [PresetPrompt] = [
        PresetPrompt("Why is the sky blue?"),
        PresetPrompt("What would a medieval knight's Yelp review of a dragon's lair look like?"),
        PresetPrompt("Explain why socks disappear in the dryer from the dryer's perspective."),

        PresetPrompt(
            "Write a breaking news report about cats discovering they can vote.",
            enableThinking: true),
        PresetPrompt(
            "Write a performance review for the person whose job is to make sure Mondays feel terrible.",
            enableThinking: true),

        PresetPrompt("What's the weather in Paris?", enableTools: true),
        PresetPrompt("What is the current time?", enableTools: true),

        PresetPrompt(loadPrompt(named: "LongPrompt"), enableThinking: true, isLongPrompt: true),
        PresetPrompt(loadPrompt(named: "CarKeysStory"), isLongPrompt: true),
    ]
}
53 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/MemoryArguments.swift:
--------------------------------------------------------------------------------
1 | import ArgumentParser
2 | import MLX
3 |
4 | /// Argument package for adjusting and reporting GPU memory usage.
/// Argument package for adjusting and reporting GPU memory usage.
struct MemoryArguments: ParsableArguments, Sendable {

    @Flag(name: .long, help: "Show GPU memory stats before exit.")
    var memoryStats = false

    @Option(name: .long, help: "Maximum GPU cache size in megabytes.")
    var cacheSize: Int?

    @Option(name: .long, help: "Maximum GPU memory size in megabytes.")
    var memorySize: Int?

    /// Snapshot captured when `start` ran; baseline for delta reporting.
    private(set) var startMemory: GPU.Snapshot?

    /// Applies the configured limits, runs `operation`, then records the
    /// post-load memory snapshot as the baseline.
    ///
    /// Fix: the original signature referenced an undeclared type `L`
    /// (`() async throws -> L) async throws -> L`), which does not compile.
    /// The result type must be a generic parameter.
    ///
    /// - Parameter operation: Work (typically model loading) to run after limits are applied.
    /// - Returns: The result of `operation`.
    mutating func start<R>(_ operation: @Sendable () async throws -> R) async throws -> R {
        applyLimits()
        let result = try await operation()
        startMemory = GPU.snapshot()
        return result
    }

    /// Applies the configured limits and records the baseline snapshot immediately.
    mutating func start() {
        applyLimits()
        startMemory = GPU.snapshot()
    }

    /// Prints the current GPU snapshot, if `--memory-stats` was passed.
    func reportCurrent() {
        guard memoryStats else { return }
        let memory = GPU.snapshot()
        print(memory.description)
    }

    /// Prints limits, the starting and ending snapshots, and their delta.
    /// No-op unless `--memory-stats` was passed and `start` was called.
    func reportMemoryStatistics() {
        guard memoryStats, let startMemory else { return }

        let endMemory = GPU.snapshot()

        print("=======")
        print("GPU memory limit: \(GPU.memoryLimit / 1024)K")
        print("GPU cache limit: \(GPU.cacheLimit / 1024)K")
        print("")
        print("=======")
        print("Starting snapshot")
        print(startMemory.description)
        print("")
        print("=======")
        print("Ending snapshot")
        print(endMemory.description)
        print("")
        print("=======")
        print("Delta")
        print(startMemory.delta(endMemory).description)
    }

    /// Applies `--cache-size` / `--memory-size` (megabytes) to the GPU, when provided.
    private func applyLimits() {
        if let cacheSize {
            GPU.set(cacheLimit: cacheSize * 1024 * 1024)
        }

        if let memorySize {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }
    }
}
68 |
--------------------------------------------------------------------------------
/Applications/LLMEval/README.md:
--------------------------------------------------------------------------------
1 | # LLMEval
2 |
3 | An example that:
4 |
5 | - downloads a huggingface model and tokenizer
6 | - evaluates a prompt
7 | - displays the output as it generates text
8 |
9 | You will need to set the Team on the LLMEval target in order to build and run on iOS.
10 |
11 | Some notes about the setup:
12 |
13 | - this downloads models from hugging face so LLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
14 | - LLM models are large so this uses the Increased Memory Limit entitlement on iOS to allow ... increased memory limits for devices that have more memory
15 | - `MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)` is used to limit the buffer cache size
16 |
17 | ### Trying Different Models
18 |
19 | The example app uses an 8 billion parameter quantized Qwen3 model by default, see [LLMEvaluator.swift](ViewModels/LLMEvaluator.swift#L52):
20 |
21 | ```
22 | var modelConfiguration = LLMRegistry.qwen3_8b_4bit
23 | ```
24 |
25 | There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78)
26 | and you can load any weights from Hugging Face where there
27 | is a model architecture defined and you have enough
28 | memory.
29 |
30 | For example:
31 | ```
32 | /// phi4bit is one of the smaller models so will fit on more devices
33 | var modelConfiguration = LLMRegistry.phi4bit
34 | ```
35 |
36 | ### Troubleshooting
37 |
38 | If the program crashes with a very deep stack trace, you may need to build
39 | in Release configuration. This seems to depend on the size of the model.
40 |
41 | There are a couple options:
42 |
43 | - Build Release
44 | - Force the model evaluation to run on the main thread, e.g. using @MainActor
45 | - Build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),` around line 87
46 |
47 | See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3
48 |
49 | ### Performance
50 |
51 | Different models have different performance characteristics. For example Gemma 2B may outperform Phi-2 in terms of tokens / second.
52 |
53 | You may also find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable".
54 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Support/SampleData.swift:
--------------------------------------------------------------------------------
1 | //
2 | // SampleData.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
/// Canned SwiftUI-development conversation used for previews and demos.
// NOTE(review): @MainActor presumably because `Message` or its factory
// methods are main-actor isolated — confirm against Message's declaration.
@MainActor
struct SampleData {
    // Alternating system/user/assistant messages forming a realistic chat.
    static let conversation: [Message] = [
        .system("You are a helpful assistant specializing in SwiftUI development."),
        .user("I need help building a weather app in SwiftUI. Where should I start?"),
        .assistant(
            "I'll help you create a weather app! Let's break it down into steps. First, we'll need to design the main view to display current weather conditions. Would you like to start with that?"
        ),
        .user("Yes, that sounds good. What components should I use for the main view?"),
        .assistant(
            "For the main weather view, I recommend using a VStack as the container. You can include:\n\n1. An Image view for the weather icon\n2. Text views for temperature and conditions\n3. HStack for additional metrics like humidity and wind speed"
        ),
        .user("How do I make the UI look modern and polished?"),
        .assistant(
            "To create a modern UI, try these techniques:\n\n- Use SF Symbols for weather icons\n- Add subtle gradients with .background()\n- Include padding and spacing for better layout\n- Implement dark mode support\n\nWould you like to see some example code?"
        ),
        .user("Yes, please show me an example for the main weather view."),
        .assistant(
            "Here's a basic example:\n\nVStack(spacing: 20) {\n    Image(systemName: \"sun.max.fill\")\n        .symbolRenderingMode(.multicolor)\n        .font(.system(size: 64))\n    \n    Text(\"72°\")\n        .font(.largeTitle)\n        .bold()\n    \n    Text(\"Sunny\")\n        .font(.title2)\n}\n.padding()\n.background(.ultraThinMaterial)\n.clipShape(RoundedRectangle(cornerRadius: 20))"
        ),
        .user("That looks great! How would I add animations to this?"),
        .assistant(
            "We can add smooth animations using SwiftUI's animation modifiers. For example:\n\n1. Use withAnimation for state changes\n2. Add .animation() modifier to views\n3. Implement transitions\n\nWould you like to see how to animate the weather changes?"
        ),
    ]
}
34 |
--------------------------------------------------------------------------------
/Tests/MLXLMTests/StreamlinedTests.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 | import MLXLLM
6 | import MLXLMCommon
7 | import MLXNN
8 | import MLXOptimizers
9 | import Tokenizers
10 | import XCTest
11 |
12 | /// Tests for the streamlined API
/// Tests for the streamlined API
public class StreamlinedTests: XCTestCase {

    /// Builds a tiny quantized Llama in memory so tests can execute a model
    /// without downloading one.
    func model() -> LanguageModel {
        let config = LlamaConfiguration(
            hiddenSize: 64, hiddenLayers: 16, intermediateSize: 64, attentionHeads: 8,
            rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 8)
        let llama = LlamaModel(config)
        quantize(model: llama, groupSize: 64, bits: 4)
        return llama
    }

    /// Wraps the test model in a container. Equivalent to:
    ///
    /// ```swift
    /// let model = LLMModelFactory.load("test")
    /// ```
    func modelContainer() -> ModelContainer {
        let context = ModelContext(
            configuration: .init(id: "test", extraEOSTokens: ["EOS"]), model: model(),
            processor: TestUserInputProcessor(), tokenizer: TestTokenizer())
        return ModelContainer(context: context)
    }

    /// Single prompt/response round trip.
    func testOneShot() async throws {
        let container = modelContainer()
        let result = try await ChatSession(container).respond(to: "Tell me about things")
        print(result)
    }

    /// Single prompt with a streamed (token-by-token) response.
    func testOneShotStream() async throws {
        let container = modelContainer()
        for try await token in ChatSession(container).streamResponse(to: "Tell me about things") {
            print(token, terminator: "")
        }
    }

    /// Multi-turn chat, including an image input on the final turn.
    func testChat() async throws {
        let session = ChatSession(modelContainer())

        print(try await session.respond(to: "what color is the sky?"))
        print(try await session.respond(to: "why is that?"))
        print(try await session.respond(to: "describe this image", image: .ciImage(CIImage.red)))
    }
}
59 |
/// Stub processor: ignores the prompt text and emits 100 random token ids.
private struct TestUserInputProcessor: UserInputProcessor {
    func prepare(input: UserInput) throws -> LMInput {
        // Random ids in [0, 1000); NOTE(review): these can exceed the test
        // model's vocabularySize of 100 — confirm that is intended/harmless.
        LMInput(tokens: MLXRandom.randInt(0 ..< 1000, [100]))
    }
}
65 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/ChatView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // ChatView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
8 | import AVFoundation
9 | import AVKit
10 | import SwiftUI
11 |
/// Main chat interface view that manages the conversation UI and user interactions.
/// Displays messages, handles media attachments, and provides input controls.
struct ChatView: View {
    /// View model that manages the chat state and business logic
    @Bindable private var vm: ChatViewModel

    /// Initializes the chat view with a view model
    /// - Parameter viewModel: The view model to manage chat state
    init(viewModel: ChatViewModel) {
        self.vm = viewModel
    }

    var body: some View {
        NavigationStack {
            // spacing: 0 keeps the divider flush against the conversation.
            VStack(spacing: 0) {
                // Display conversation history
                ConversationView(messages: vm.messages)

                Divider()

                // Show media previews if attachments are present
                if !vm.mediaSelection.isEmpty {
                    MediaPreviewsView(mediaSelection: vm.mediaSelection)
                }

                // Input field with send and media attachment buttons
                PromptField(
                    prompt: $vm.prompt,
                    sendButtonAction: vm.generate,
                    // Only show media button for vision-capable models
                    mediaButtonAction: vm.selectedModel.isVisionModel
                        ? {
                            vm.mediaSelection.isShowing = true
                        } : nil
                )
                .padding()
            }
            .navigationTitle("MLX Chat Example")
            .toolbar {
                ChatToolbarView(vm: vm)
            }
            // Handle media file selection; images and movies only, and the
            // chosen URL (or error) is forwarded to the view model.
            .fileImporter(
                isPresented: $vm.mediaSelection.isShowing,
                allowedContentTypes: [.image, .movie],
                onCompletion: vm.addMedia
            )
        }
    }
}
62 |
// Canvas preview wired to a fresh view model and MLX service.
#Preview {
    ChatView(viewModel: ChatViewModel(mlxService: MLXService()))
}
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Xcode
2 | #
3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
4 |
5 | ## User settings
6 | xcuserdata/
7 |
8 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
9 | *.xcscmblueprint
10 | *.xccheckout
11 |
12 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
13 | build/
14 | DerivedData/
15 | *.moved-aside
16 | *.pbxuser
17 | !default.pbxuser
18 | *.mode1v3
19 | !default.mode1v3
20 | *.mode2v3
21 | !default.mode2v3
22 | *.perspectivev3
23 | !default.perspectivev3
24 |
25 | ## Obj-C/Swift specific
26 | *.hmap
27 |
28 | ## App packaging
29 | *.ipa
30 | *.dSYM.zip
31 | *.dSYM
32 |
33 | ## Playgrounds
34 | timeline.xctimeline
35 | playground.xcworkspace
36 |
37 | # Swift Package Manager
38 | #
39 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies.
40 | Packages/
41 | Package.pins
42 | Package.resolved
43 | # *.xcodeproj
44 | #
45 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
46 | # hence it is not needed unless you have added a package configuration file to your project
47 | .swiftpm
48 |
49 | .build/
50 |
51 | # CocoaPods
52 | #
53 | # We recommend against adding the Pods directory to your .gitignore. However
54 | # you should judge for yourself, the pros and cons are mentioned at:
55 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control
56 | #
57 | # Pods/
58 | #
59 | # Add this line if you want to avoid checking in source code from the Xcode workspace
60 | # *.xcworkspace
61 |
62 | # Carthage
63 | #
64 | # Add this line if you want to avoid checking in source code from Carthage dependencies.
65 | # Carthage/Checkouts
66 |
67 | Carthage/Build/
68 |
69 | # Accio dependency management
70 | Dependencies/
71 | .accio/
72 |
73 | # fastlane
74 | #
75 | # It is recommended to not store the screenshots in the git repo.
76 | # Instead, use fastlane to re-generate the screenshots whenever they are needed.
77 | # For more information about the recommended setup visit:
78 | # https://docs.fastlane.tools/best-practices/source-control/#source-control
79 |
80 | fastlane/report.xml
81 | fastlane/Preview.html
82 | fastlane/screenshots/**/*.png
83 | fastlane/test_output
84 |
85 | # Code Injection
86 | #
87 | # After new code Injection tools there's a generated folder /iOSInjectionProject
88 | # https://github.com/johnno1962/injectionforxcode
89 |
90 | iOSInjectionProject/
91 |
92 | # OS
93 | .DS_Store
94 |
95 | .idea
96 | .vscode
97 |
98 |
--------------------------------------------------------------------------------
/Tests/MLXLMTests/BaseConfigurationTests.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 | import MLXLMCommon
5 | import XCTest
6 |
/// Tests for decoding `BaseConfiguration`, focusing on the `quantization`
/// block and its per-layer overrides.
public class BaseConfigurationTests: XCTestCase {

    /// A simple homogeneous quantization block applies to every layer.
    func testQuantization() throws {
        let json =
            """
            {
                "model_type": "Test",
                "quantization": {
                    "group_size": 128,
                    "bits": 4
                }
            }
            """

        let config = try JSONDecoder().decode(
            BaseConfiguration.self, from: json.data(using: .utf8)!)

        XCTAssertEqual(config.quantization, .init(groupSize: 128, bits: 4))
        // Any layer name falls back to the global quantization settings.
        XCTAssertEqual(
            config.perLayerQuantization?.quantization(layer: "x"), .init(groupSize: 128, bits: 4))
    }

    /// A heterogeneous block: global defaults plus per-layer overrides
    /// (a dict override, `false` to disable, `true` to use defaults).
    func testHeterogenousQuantization() throws {
        // from https://huggingface.co/mlx-community/Qwen3-1.7B-4bit-AWQ/blob/main/config.json#L20
        let json =
            """
            {
                "model_type": "Test",
                "quantization": {
                    "group_size": 64,
                    "bits": 4,
                    "model.embed_tokens": {
                        "group_size": 32,
                        "bits": 4
                    },
                    "model.layers.0.self_attn.q_norm": false,
                    "true_layer": true
                }
            }
            """

        let config = try JSONDecoder().decode(
            BaseConfiguration.self, from: json.data(using: .utf8)!)

        XCTAssertEqual(config.quantization, .init(groupSize: 64, bits: 4))

        // a random layer -- no specific configuration gets default
        XCTAssertEqual(
            config.perLayerQuantization?.quantization(layer: "x"),
            .init(groupSize: 64, bits: 4))

        // layer with an override
        XCTAssertEqual(
            config.perLayerQuantization?.quantization(layer: "model.embed_tokens"),
            .init(groupSize: 32, bits: 4))

        // layer with an override -- not quant
        XCTAssertNil(
            config.perLayerQuantization?.quantization(layer: "model.layers.0.self_attn.q_norm"))

        // layer with an override -- true, use the default
        XCTAssertEqual(
            config.perLayerQuantization?.quantization(layer: "true_layer"),
            .init(groupSize: 64, bits: 4))
    }

}
74 |
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/MNIST.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 | import MLXNN
6 |
7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
8 |
/// LeNet-5 style convolutional network for MNIST digit classification.
/// Layer sizing (fc1 = 16 * 5 * 5) implies 28x28 single-channel input —
/// conv1 (pad 2) preserves 28, pooling halves to 14, conv2 shrinks to 10,
/// pooling halves to 5.
public class LeNet: Module, UnaryLayer {

    @ModuleInfo var conv1: Conv2d
    @ModuleInfo var conv2: Conv2d
    @ModuleInfo var pool1: MaxPool2d
    @ModuleInfo var pool2: MaxPool2d
    @ModuleInfo var fc1: Linear
    @ModuleInfo var fc2: Linear
    @ModuleInfo var fc3: Linear

    override public init() {
        conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
        conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
        pool1 = MaxPool2d(kernelSize: 2, stride: 2)
        pool2 = MaxPool2d(kernelSize: 2, stride: 2)
        fc1 = Linear(16 * 5 * 5, 120)
        fc2 = Linear(120, 84)
        fc3 = Linear(84, 10)
    }

    /// Forward pass: two conv/tanh/pool stages, flatten, then a three-layer
    /// tanh MLP producing 10 class logits.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let stage1 = pool1(tanh(conv1(x)))
        let stage2 = pool2(tanh(conv2(stage1)))
        let flat = flattened(stage2, start: 1)
        return fc3(tanh(fc2(tanh(fc1(flat)))))
    }
}
40 |
/// Mean cross-entropy between the model's logits for `x` and labels `y`.
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
    return crossEntropy(logits: model(x), targets: y, reduction: .mean)
}
44 |
/// Classification accuracy: fraction of argmax predictions equal to `y`.
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
    let predictions = argMax(model(x), axis: 1)
    return mean(predictions .== y)
}
48 |
/// Iterates shuffled (x, y) mini-batches; the final batch may be short.
private struct BatchSequence: Sequence, IteratorProtocol {

    let batchSize: Int
    let x: MLXArray
    let y: MLXArray

    /// Shuffled sample indices, fixed at construction time.
    let indexes: MLXArray
    var index = 0

    init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator)
    {
        self.batchSize = batchSize
        self.x = x
        self.y = y
        self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator))
    }

    /// Returns the next (x, y) batch, or nil once all samples are consumed.
    mutating func next() -> (MLXArray, MLXArray)? {
        guard index < y.size else { return nil }

        // Clamp the final batch to the number of remaining samples.
        let upperBound = Swift.min(index + batchSize, y.size)
        let ids = indexes[index ..< upperBound]
        index = upperBound
        return (x[ids], y[ids])
    }
}
75 |
/// Builds a sequence of shuffled (x, y) mini-batches over the full dataset.
public func iterateBatches(
    batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator
) -> some Sequence<(MLXArray, MLXArray)> {
    return BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator)
}
81 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version: 5.9
2 | // The swift-tools-version declares the minimum version of Swift required to build this package.
3 |
4 | import PackageDescription
5 |
// Manifest for the example libraries: MLXMNIST (MNIST training) and
// StableDiffusion (image generation).
let package = Package(
    name: "mlx-libraries",
    platforms: [.macOS(.v14), .iOS(.v16)],
    products: [
        .library(
            name: "MLXMNIST",
            targets: ["MLXMNIST"]),
        .library(
            name: "StableDiffusion",
            targets: ["StableDiffusion"]),
    ],
    dependencies: [
        // Pinned with .upToNextMinor: mlx-swift and swift-transformers APIs
        // change frequently, so major/minor upgrades are taken deliberately.
        .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.29.1")),
        .package(
            url: "https://github.com/huggingface/swift-transformers",
            .upToNextMinor(from: "1.1.0")
        ),
        .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"),  // Only needed by MLXMNIST
    ],
    targets: [
        .target(
            name: "MLXMNIST",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "Gzip", package: "GzipSwift"),
            ],
            path: "Libraries/MLXMNIST",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
        .target(
            name: "StableDiffusion",
            dependencies: [
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
            ],
            path: "Libraries/StableDiffusion",
            exclude: [
                "README.md"
            ],
            swiftSettings: [
                .enableExperimentalFeature("StrictConcurrency")
            ]
        ),
    ]
)

// Add the DocC plugin only when building documentation (locally via
// MLX_SWIFT_BUILD_DOC or on Swift Package Index via SPI_GENERATE_DOCS),
// so normal builds do not pull the extra dependency.
if Context.environment["MLX_SWIFT_BUILD_DOC"] == "1"
    || Context.environment["SPI_GENERATE_DOCS"] == "1"
{
    // docc builder
    package.dependencies.append(
        .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.3.0")
    )
}
72 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/PromptInputView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import SwiftUI
4 |
/// Prompt entry area: a collapsible text field plus preset-prompt and
/// generate/stop controls.
///
/// Generation state lives in `llm`; this view only toggles local UI state
/// and forwards user actions through `onGenerate` / `onCancel`.
struct PromptInputView: View {
    @Bindable var llm: LLMEvaluator
    @Binding var isPromptExpanded: Bool
    @Binding var showingPresetPrompts: Bool

    /// Called when the user submits or taps Generate.
    let onGenerate: () -> Void
    /// Called when the user taps Stop during generation.
    let onCancel: () -> Void

    var body: some View {
        VStack(alignment: .leading, spacing: 8) {
            // Prompt header with expand/collapse chevron
            HStack {
                Text("Prompt")
                    .font(.caption)
                    .foregroundStyle(.secondary)

                Spacer()

                Button {
                    withAnimation(.easeInOut(duration: 0.2)) {
                        isPromptExpanded.toggle()
                    }
                } label: {
                    Image(systemName: isPromptExpanded ? "chevron.down" : "chevron.up")
                        .font(.caption)
                        .foregroundStyle(.secondary)
                }
                .buttonStyle(.plain)
                .help(isPromptExpanded ? "Collapse prompt area" : "Expand prompt area")
            }

            // Prompt text field with dynamic sizing
            // (fixed 400pt height when expanded, grows 1-3 lines when collapsed)
            TextField("Enter your prompt...", text: $llm.prompt, axis: .vertical)
                .textFieldStyle(.roundedBorder)
                .lineLimit(isPromptExpanded ? 15 ... 50 : 1 ... 3)
                .frame(height: isPromptExpanded ? 400 : nil)
                .onSubmit(onGenerate)
                .disabled(llm.running || llm.isLoading)

            // Action buttons
            HStack(spacing: 12) {
                Button {
                    showingPresetPrompts = true
                } label: {
                    Label("Example Prompts", systemImage: "list.bullet")
                }
                .disabled(llm.running || llm.isLoading)

                Spacer()

                // Single button doubles as Generate (idle) and Stop (running).
                Button {
                    if llm.running {
                        onCancel()
                    } else {
                        onGenerate()
                    }
                } label: {
                    Label(
                        llm.running ? "Stop" : "Generate",
                        systemImage: llm.running ? "stop.circle" : "play.fill"
                    )
                }
                .buttonStyle(.borderedProminent)
                .keyboardShortcut(.return, modifiers: .command)
                // Stop is always available while running; Generate requires text.
                .disabled((llm.prompt.isEmpty && !llm.running) || llm.isLoading)
            }
        }
        .padding(.vertical, 8)
    }
}
75 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Services/ToolExecutor.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Foundation
4 | import MLXLMCommon
5 |
/// Tool schema dictionary in the JSON-like shape passed to chat templates.
public typealias ToolSpec = [String: Any]

/// Manages tool definitions and execution for LLM function calling
@MainActor
class ToolExecutor {

    // MARK: - Tool Definitions

    /// Demo weather tool: returns a random temperature (range chosen by the
    /// requested unit) and a random condition string.
    /// NOTE(review): WeatherOutput / AddOutput / TimeOutput and the inferred
    /// tool input types are declared elsewhere in the app — confirm there.
    let currentWeatherTool = Tool(
        name: "get_current_weather",
        description: "Get the current weather in a given location",
        parameters: [
            .required(
                "location", type: .string, description: "The city and state, e.g. San Francisco, CA"
            ),
            .optional(
                "unit",
                type: .string,
                description: "The unit of temperature",
                extraProperties: [
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                ]
            ),
        ]
    ) { input in
        // Plausible simulated ranges: -20...40 for celsius, 0...100 otherwise.
        let range = input.unit == "celsius" ? (min: -20.0, max: 40.0) : (min: 0, max: 100)
        let temperature = Double.random(in: range.min ... range.max)
        let conditions = ["Sunny", "Cloudy", "Rainy", "Snowy", "Windy", "Stormy"].randomElement()!
        return WeatherOutput(temperature: temperature, conditions: conditions)
    }

    /// Adds two integers and returns the sum.
    let addTool = Tool(
        name: "add_two_numbers",
        description: "Add two numbers together",
        parameters: [
            .required("first", type: .int, description: "The first number to add"),
            .required("second", type: .int, description: "The second number to add"),
        ]
    ) { input in
        AddOutput(result: input.first + input.second)
    }

    /// Returns the current wall-clock time formatted for the current locale.
    let timeTool = Tool(
        name: "get_time",
        description: "Get the current time",
        parameters: []
    ) { _ in
        TimeOutput(time: Date.now.formatted())
    }

    // MARK: - Tool Execution

    /// Returns all available tool schemas
    var allToolSchemas: [ToolSpec] {
        [currentWeatherTool.schema, addTool.schema, timeTool.schema]
    }

    /// Executes a tool call and returns the result
    ///
    /// Dispatches on the called function's name; unknown names produce an
    /// error string rather than throwing, so the text is fed back to the model.
    func execute(_ toolCall: ToolCall) async throws -> String {
        switch toolCall.function.name {
        case currentWeatherTool.name:
            return try await toolCall.execute(with: currentWeatherTool).toolResult
        case addTool.name:
            return try await toolCall.execute(with: addTool).toolResult
        case timeTool.name:
            return try await toolCall.execute(with: timeTool).toolResult
        default:
            return "Unknown tool: \(toolCall.function.name)"
        }
    }
}
78 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Models/Message.swift:
--------------------------------------------------------------------------------
1 | //
2 | // Message.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
8 | import Foundation
9 |
/// Represents a chat message in the conversation.
/// Messages can contain text content and optional media attachments (images and videos).
///
/// `@Observable` lets SwiftUI views track changes to the mutable properties
/// (`content`, `images`, `videos`); `Identifiable` (via `id`) supports use
/// in `ForEach`/`List`.
@Observable
class Message: Identifiable {
    /// Unique identifier for the message
    let id: UUID

    /// The role of the message sender (user, assistant, or system)
    let role: Role

    /// The text content of the message (mutable so it can be updated in place)
    var content: String

    /// Array of image URLs attached to the message
    var images: [URL]

    /// Array of video URLs attached to the message
    var videos: [URL]

    /// Timestamp when the message was created
    let timestamp: Date

    /// Creates a new message with the specified role, content, and optional media attachments
    /// - Parameters:
    ///   - role: The role of the message sender
    ///   - content: The text content of the message
    ///   - images: Optional array of image URLs
    ///   - videos: Optional array of video URLs
    init(role: Role, content: String, images: [URL] = [], videos: [URL] = []) {
        self.id = UUID()
        self.role = role
        self.content = content
        self.images = images
        self.videos = videos
        self.timestamp = .now
    }

    /// Defines the role of the message sender in the conversation
    enum Role {
        /// Message from the user
        case user
        /// Message from the AI assistant
        case assistant
        /// System message providing context or instructions
        case system
    }
}
57 |
/// Factory helpers for constructing messages with a specific role.
extension Message {
    /// Builds a message authored by the user, optionally carrying media.
    /// - Parameters:
    ///   - content: The message text
    ///   - images: Image attachment URLs (defaults to none)
    ///   - videos: Video attachment URLs (defaults to none)
    /// - Returns: A user-role `Message`
    static func user(_ content: String, images: [URL] = [], videos: [URL] = []) -> Message {
        return Message(role: .user, content: content, images: images, videos: videos)
    }

    /// Builds a message authored by the assistant.
    /// - Parameter content: The message text
    /// - Returns: An assistant-role `Message`
    static func assistant(_ content: String) -> Message {
        return Message(role: .assistant, content: content)
    }

    /// Builds a system message carrying context or instructions.
    /// - Parameter content: The message text
    /// - Returns: A system-role `Message`
    static func system(_ content: String) -> Message {
        return Message(role: .system, content: content)
    }
}
84 |
--------------------------------------------------------------------------------
/Tools/image-tool/Arguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLX
6 |
#if swift(>=5.10)
/// Extension to allow URL command line arguments.
///
/// Strings containing "://" are parsed as full URLs; anything else is treated
/// as a filesystem path. `@retroactive` (available from Swift 5.10) marks the
/// deliberate conformance of a type we don't own to a protocol we don't own,
/// silencing the retroactive-conformance warning.
extension URL: @retroactive ExpressibleByArgument {
    public init?(argument: String) {
        if argument.contains("://") {
            self.init(string: argument)
        } else {
            self.init(filePath: argument)
        }
    }
}
#else
/// Extension to allow URL command line arguments.
/// Identical conformance for toolchains predating `@retroactive`.
extension URL: ExpressibleByArgument {
    public init?(argument: String) {
        if argument.contains("://") {
            self.init(string: argument)
        } else {
            self.init(filePath: argument)
        }
    }
}
#endif
30 |
31 | /// Argument package for adjusting and reporting memory use.
struct MemoryArguments: ParsableArguments, Sendable {

    @Flag(name: .long, help: "Show memory stats")
    var memoryStats = false

    @Option(name: .long, help: "Maximum cache size in M")
    var cacheSize = 1024

    @Option(name: .long, help: "Maximum memory size in M")
    var memorySize: Int?

    /// GPU snapshot taken when the command starts; used by
    /// `reportMemoryStatistics()` to compute growth.
    var startMemory: GPU.Snapshot?

    /// Applies the configured GPU limits, runs `load`, then records the
    /// baseline memory snapshot (taken after loading, so the final report
    /// shows run-time growth rather than model-load cost).
    ///
    /// Fix: the generic parameter `<L>` was missing from the declaration,
    /// which left `L` as an unresolved identifier in the signature.
    ///
    /// - Parameter load: async work (typically model loading) to run after
    ///   the limits are applied
    /// - Returns: whatever `load` produced
    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
        GPU.set(cacheLimit: cacheSize * 1024 * 1024)

        if let memorySize {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }

        let result = try await load()
        startMemory = GPU.snapshot()

        return result
    }

    /// Applies the configured GPU limits and records the baseline snapshot
    /// immediately (variant for callers with no load step).
    mutating func start() {
        GPU.set(cacheLimit: cacheSize * 1024 * 1024)

        if let memorySize {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }

        startMemory = GPU.snapshot()
    }

    /// Prints the current GPU memory snapshot when `--memory-stats` is set.
    func reportCurrent() {
        if memoryStats {
            let memory = GPU.snapshot()
            print(memory.description)
        }
    }

    /// Prints the configured limits, the starting and ending snapshots, and
    /// their delta — only when `--memory-stats` is set and `start()` ran.
    func reportMemoryStatistics() {
        if memoryStats, let startMemory {
            let endMemory = GPU.snapshot()

            print("=======")
            print("Memory size: \(GPU.memoryLimit / 1024)K")
            print("Cache size: \(GPU.cacheLimit / 1024)K")

            print("")
            print("=======")
            print("Starting memory")
            print(startMemory.description)

            print("")
            print("=======")
            print("Ending memory")
            print(endMemory.description)

            print("")
            print("=======")
            print("Growth")
            print(startMemory.delta(endMemory).description)

        }
    }
}
101 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/DemoCommand.swift:
--------------------------------------------------------------------------------
1 | import ArgumentParser
2 | import Foundation
3 |
/// Builds a temporary embedding index over repository markdown files and
/// runs sample (or user-supplied) queries against it.
struct DemoCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "demo",
        abstract: "Run a demo using sample repository documentation"
    )

    @OptionGroup var memory: MemoryArguments

    @Flag(name: .long, help: "Keep the generated demo index file instead of removing it")
    var keepIndex = false

    @Argument(help: "Optional queries to run after indexing. Defaults to three sample queries.")
    var queries: [String] = []

    mutating func run() async throws {
        // Work on a local copy so the deferred report sees the snapshot that
        // start() records, then write the copy back to the stored property.
        var memory = self.memory
        memory.start()
        defer {
            memory.reportMemoryStatistics()
            self.memory = memory
        }

        print("Embedder Tool Demo")

        let indexURL = try makeTemporaryIndexURL()
        defer {
            // Best-effort cleanup: only warn when the file still exists after
            // a failed removal (removing an already-absent file is harmless).
            if !keepIndex {
                do {
                    try FileManager.default.removeItem(at: indexURL)
                } catch {
                    if FileManager.default.fileExists(atPath: indexURL.path) {
                        let message =
                            "Failed to remove temporary index file at \(indexURL.path): \(error.localizedDescription). Please remove it manually."
                        writeDiagnostic(message, kind: .warning)
                    }
                }
            }
        }

        try await buildIndex(at: indexURL)
        let queriesToRun = queries.isEmpty ? defaultQueries : queries
        try await runSampleQueries(using: indexURL, queries: queriesToRun)
    }

    /// Returns a unique JSON file path in the system temporary directory.
    private func makeTemporaryIndexURL() throws -> URL {
        let directory = FileManager.default.temporaryDirectory
        return directory.appendingPathComponent("embedder-demo-\(UUID().uuidString).json")
    }

    /// Indexes up to 8 markdown files found recursively under `Libraries`
    /// into the given output file.
    private func buildIndex(at url: URL) async throws {
        var indexCommand = IndexCommand()
        indexCommand.corpus.directory = URL(fileURLWithPath: "Libraries")
        indexCommand.corpus.extensions = ["md"]
        indexCommand.corpus.recursive = true
        indexCommand.corpus.limit = 8
        indexCommand.output = url
        indexCommand.batchSize = 4
        indexCommand.pooling.normalize = true

        try await indexCommand.run()
    }

    /// Runs each query against the index, printing the top 2 matches per query.
    private func runSampleQueries(using indexURL: URL, queries: [String]) async throws {
        for query in queries {
            var searchCommand = SearchCommand()
            searchCommand.index = indexURL
            searchCommand.query = query
            searchCommand.top = 2
            searchCommand.pooling.normalize = true

            try await searchCommand.run()
        }
    }

    /// Queries used when none are supplied on the command line.
    private var defaultQueries: [String] {
        [
            "How do I use embedding models?",
            "Training language models",
            "Vision language models",
        ]
    }
}
86 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/LLMEval.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
8 |
9 |
15 |
21 |
22 |
23 |
24 |
25 |
31 |
32 |
42 |
44 |
50 |
51 |
52 |
53 |
59 |
61 |
67 |
68 |
69 |
70 |
72 |
73 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/HeaderView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import SwiftUI
4 |
5 | struct HeaderView: View {
6 | @Bindable var llm: LLMEvaluator
7 | @Binding var selectedDisplayStyle: ContentView.DisplayStyle
8 |
9 | var body: some View {
10 | VStack(alignment: .leading, spacing: 12) {
11 | // Model info with status
12 | HStack {
13 | VStack(alignment: .leading, spacing: 4) {
14 | Text("Model")
15 | .font(.caption)
16 | .foregroundStyle(.secondary)
17 |
18 | Text(llm.modelInfo)
19 | .font(.headline)
20 | .lineLimit(1)
21 | }
22 |
23 | Spacer()
24 |
25 | if llm.running {
26 | HStack(spacing: 8) {
27 | ProgressView()
28 | .controlSize(.small)
29 | Text("Generating...")
30 | .font(.subheadline)
31 | .foregroundStyle(.secondary)
32 | }
33 | }
34 | }
35 |
36 | // Controls row
37 | HStack(spacing: 16) {
38 | HStack(spacing: 24) {
39 | Toggle("Tools", isOn: $llm.includeWeatherTool)
40 | .toggleStyle(.switch)
41 | .fixedSize()
42 | .help("Enable function calling with weather, math, and time tools")
43 |
44 | Toggle("Thinking", isOn: $llm.enableThinking)
45 | .toggleStyle(.switch)
46 | .fixedSize()
47 | .help("Enable thinking mode (supported by Qwen3)")
48 |
49 | // Max tokens slider
50 | VStack(alignment: .leading, spacing: 4) {
51 | Text("Max Tokens: \(llm.maxTokens)")
52 | .font(.caption)
53 | .foregroundStyle(.secondary)
54 |
55 | Slider(
56 | value: Binding(
57 | get: { log2(Double(llm.maxTokens)) },
58 | set: { llm.maxTokens = Int(pow(2, $0)) }
59 | ),
60 | in: 10 ... 15, // 2^10 (1024) to 2^15 (32768)
61 | step: 1
62 | )
63 | .frame(width: 120)
64 | .help("Maximum number of tokens to generate (1024-32768)")
65 | }
66 | }
67 |
68 | Spacer()
69 |
70 | Picker("Display", selection: $selectedDisplayStyle) {
71 | ForEach(ContentView.DisplayStyle.allCases, id: \.self) { option in
72 | Text(option.rawValue.capitalized)
73 | .tag(option)
74 | }
75 | }
76 | .pickerStyle(.segmented)
77 | .labelsHidden()
78 | .frame(maxWidth: 180)
79 | }
80 | }
81 | .padding(.bottom, 12)
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/VLMEval.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
45 |
51 |
52 |
53 |
54 |
60 |
62 |
68 |
69 |
70 |
71 |
73 |
74 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/ExampleLLM.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
44 |
46 |
52 |
53 |
54 |
55 |
61 |
63 |
69 |
70 |
71 |
72 |
74 |
75 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/Libraries/MLXMNIST/Files.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import Gzip
5 | import MLX
6 |
7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py
8 |
/// Which dataset split a file belongs to.
public enum Use: String, Hashable, Sendable {
    /// 10k evaluation split ("t10k" files).
    case test
    /// 60k training split ("train" files).
    case training
}
13 |
/// What kind of data a file contains.
public enum DataKind: String, Hashable, Sendable {
    /// Pixel data (idx3 files).
    case images
    /// Class labels (idx1 files).
    case labels
}
18 |
/// Identifies one MNIST data file by split (test/training) and content
/// (images/labels); serves as the key of the file table and load results.
public struct FileKind: Hashable, CustomStringConvertible, Sendable {
    let use: Use
    let data: DataKind

    public init(_ use: Use, _ data: DataKind) {
        self.use = use
        self.data = data
    }

    /// For example "training-images" or "test-labels".
    public var description: String {
        return use.rawValue + "-" + data.rawValue
    }
}
32 |
/// Recipe for fetching and decoding one MNIST archive.
struct LoadInfo: Sendable {
    /// File name of the gzip archive (also the remote path component).
    let name: String
    /// Number of IDX header bytes to skip before the payload.
    let offset: Int
    /// Converts the raw byte array into its final typed/shaped form.
    let convert: @Sendable (MLXArray) -> MLXArray
}
38 |
// Mirror hosting the gzip-compressed MNIST IDX files.
let baseURL = URL(string: "https://raw.githubusercontent.com/fgnt/mnist/master/")!

// Download/decode recipe per dataset file. Image bytes are reshaped to
// [N, 28, 28, 1] and scaled to [0, 1] floats; label bytes become uint32
// class ids. Offsets skip the IDX headers (16 bytes images, 8 labels).
private let files = [
    FileKind(.training, .images): LoadInfo(
        name: "train-images-idx3-ubyte.gz",
        offset: 16,
        convert: {
            $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
        }),
    FileKind(.test, .images): LoadInfo(
        name: "t10k-images-idx3-ubyte.gz",
        offset: 16,
        convert: {
            $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
        }),
    FileKind(.training, .labels): LoadInfo(
        name: "train-labels-idx1-ubyte.gz",
        offset: 8,
        convert: {
            $0.asType(.uint32)
        }),
    FileKind(.test, .labels): LoadInfo(
        name: "t10k-labels-idx1-ubyte.gz",
        offset: 8,
        convert: {
            $0.asType(.uint32)
        }),
]
67 |
/// Downloads any missing MNIST archives into the given directory.
///
/// Files already present on disk are skipped. Fix: failures previously
/// crashed the process via `fatalError` even though the function is
/// `throws`; they now throw `URLError` so callers can recover.
///
/// - Parameter into: directory where the `.gz` archives are stored
/// - Throws: `URLError` on a non-HTTP or non-200 response, plus any error
///   from `URLSession` or writing the file
public func download(into: URL) async throws {
    for (_, info) in files {
        let fileURL = into.appending(component: info.name)

        // Skip files that already exist locally.
        guard !FileManager.default.fileExists(atPath: fileURL.path()) else {
            continue
        }

        print("Download: \(info.name)")
        let url = baseURL.appending(component: info.name)
        let (data, response) = try await URLSession.shared.data(from: url)

        guard let httpResponse = response as? HTTPURLResponse else {
            throw URLError(.badServerResponse)
        }
        guard httpResponse.statusCode == 200 else {
            throw URLError(.badServerResponse)
        }

        try data.write(to: fileURL)
    }
}
87 |
/// Loads and decodes every MNIST file from the given directory.
///
/// Each gzip archive is decompressed, its IDX header (`offset` bytes)
/// stripped, and the remaining bytes converted with the per-file `convert`
/// closure from the `files` table.
///
/// - Parameter from: directory containing the previously downloaded archives
/// - Returns: dictionary from file kind to its decoded array
public func load(from: URL) throws -> [FileKind: MLXArray] {
    var result = [FileKind: MLXArray]()

    for (key, info) in files {
        let fileURL = from.appending(component: info.name)
        let raw = try Data(contentsOf: fileURL).gunzipped()

        let payload = raw.dropFirst(info.offset)
        let array = MLXArray(payload, [raw.count - info.offset], type: UInt8.self)

        result[key] = info.convert(array)
    }

    return result
}
103 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/ModelArguments.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import Hub
6 | import MLXEmbedders
7 |
/// Command-line arguments selecting the embedder model and download location.
struct ModelArguments: ParsableArguments {

    @Option(
        name: .long,
        help: "Name of the embedder model configuration or absolute path to a local directory.")
    var model: String?

    @Option(name: .long, help: "Directory used for downloading model assets from the Hub.")
    var download: URL?

    /// Resolves the `--model` value to a configuration.
    ///
    /// Resolution order: no value → `defaultConfiguration`; an existing local
    /// directory (plain path or file:// URL) → directory-based configuration;
    /// anything else is treated as a registry identifier.
    @MainActor
    func configuration(default defaultConfiguration: ModelConfiguration) -> ModelConfiguration {
        guard let model else {
            return defaultConfiguration
        }

        if let localConfiguration = resolveLocalModelPath(model) {
            return localConfiguration
        }

        return ModelConfiguration.configuration(id: model)
    }

    /// The `--download` directory with `..`/symlink components standardized.
    var downloadURL: URL? {
        download?.standardizedFileURL
    }
}
35 |
/// Pairs a resolved model configuration with its loaded container.
struct LoadedEmbedderModel {
    /// Configuration the container was loaded from.
    let configuration: ModelConfiguration
    /// Loaded embedder model container ready to perform work.
    let container: ModelContainer
}
40 |
extension ModelArguments {

    /// Resolves the configuration, downloads model assets from the Hub if
    /// needed (printing coarse progress), and loads the model container.
    func load(default defaultConfiguration: ModelConfiguration) async throws -> LoadedEmbedderModel
    {
        let configuration = await configuration(default: defaultConfiguration)
        let hub = makeHub()

        print("Loading model \(configuration.name)...")

        let container = try await MLXEmbedders.loadModelContainer(
            hub: hub,
            configuration: configuration,
            progressHandler: { progress in
                // Print once at each 10% milestone.
                // NOTE(review): the dedupe compares against progress - 1%, so
                // it assumes callbacks arrive at roughly 1% granularity or
                // finer; coarser jumps could skip or repeat a line — confirm.
                let percentage = Int(progress.fractionCompleted * 100)
                let previousPercentage = Int((progress.fractionCompleted - 0.01) * 100)

                if percentage % 10 == 0 && percentage != previousPercentage {
                    print("Downloading model: \(percentage)%")
                }
            }
        )

        return LoadedEmbedderModel(configuration: configuration, container: container)
    }

    /// Hub client rooted at the `--download` directory when given, otherwise
    /// the default Hub location.
    private func makeHub() -> HubApi {
        if let downloadURL {
            return HubApi(downloadBase: downloadURL)
        }

        return HubApi()
    }
}
74 |
extension ModelArguments {
    /// Attempts to interpret `value` as a local model directory.
    ///
    /// Accepts either a plain (optionally tilde-prefixed) filesystem path or
    /// a file:// URL; returns nil when neither resolves to an existing
    /// directory.
    private func resolveLocalModelPath(_ value: String) -> ModelConfiguration? {
        // Plain path form, with ~ expansion.
        let expandedPath = NSString(string: value).expandingTildeInPath
        let pathCandidate = URL(fileURLWithPath: expandedPath, isDirectory: true)
        if directoryExists(atPath: pathCandidate.path) {
            return ModelConfiguration(directory: pathCandidate.standardizedFileURL)
        }

        // file:// URL form.
        if let url = URL(string: value), url.isFileURL, directoryExists(atPath: url.path) {
            return ModelConfiguration(directory: url.standardizedFileURL)
        }

        return nil
    }

    /// True when `path` exists and refers to a directory.
    private func directoryExists(atPath path: String) -> Bool {
        var isDirectory: ObjCBool = false
        let exists = FileManager.default.fileExists(atPath: path, isDirectory: &isDirectory)
        return exists && isDirectory.boolValue
    }
}
98 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/StableDiffusionExample.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
45 |
51 |
52 |
53 |
54 |
60 |
62 |
68 |
69 |
70 |
71 |
73 |
74 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/EmbedderRuntime+Embedding.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 | import MLXEmbedders
4 | import Tokenizers
5 |
/// Result of embedding a batch of texts.
public struct RuntimeEmbeddingResult {
    /// One pooled vector per successfully embedded input, tagged with that
    /// input's original index in the submitted batch.
    public let embeddings: [(index: Int, vector: [Float])]
    /// Indices of inputs that were skipped (e.g. tokenized to nothing).
    public let skippedIndices: [Int]
    /// Description of the fallback extraction path when one was used, nil
    /// otherwise (populated from `extractVectors` — see `embed(texts:)`).
    public let fallbackDescription: String?

    public init(
        embeddings: [(index: Int, vector: [Float])],
        skippedIndices: [Int],
        fallbackDescription: String?
    ) {
        self.embeddings = embeddings
        self.skippedIndices = skippedIndices
        self.fallbackDescription = fallbackDescription
    }
}
21 |
extension EmbedderRuntime {
    /// Tokenizes, pads, and embeds a batch of texts, returning one pooled
    /// vector per non-empty input.
    ///
    /// Texts that tokenize to nothing are skipped and reported via
    /// `skippedIndices`; embedding indices refer to positions in `texts`.
    func embed(texts: [String]) async throws -> RuntimeEmbeddingResult {
        guard !texts.isEmpty else {
            return RuntimeEmbeddingResult(
                embeddings: [], skippedIndices: [], fallbackDescription: nil)
        }

        return try await container.perform { model, tokenizer, pooler in
            var skippedIndices: [Int] = []

            // Keep (original index, tokens) pairs so outputs can be mapped
            // back to their position in `texts`.
            let encoded = texts.enumerated().compactMap { index, text -> (Int, [Int])? in
                let tokens = tokenizer.encode(text: text, addSpecialTokens: true)
                guard !tokens.isEmpty else {
                    skippedIndices.append(index)
                    return nil
                }
                return (index, tokens)
            }

            guard !encoded.isEmpty else {
                return RuntimeEmbeddingResult(
                    embeddings: [],
                    skippedIndices: skippedIndices,
                    fallbackDescription: nil
                )
            }

            // NOTE(review): the EOS token doubles as padding, and the mask
            // below treats every pad-valued position as padding — a genuine
            // EOS inside a sequence would also be masked out. Confirm this
            // is acceptable for the supported models.
            guard let padToken = tokenizer.eosTokenId else {
                throw CommandError("Could not determine a padding token from the tokenizer.")
            }
            let maxLength = encoded.map { $0.1.count }.max() ?? 0

            // Right-pad every sequence to the batch maximum and stack.
            let padded = stacked(
                encoded.map { _, tokens in
                    MLXArray(tokens + Array(repeating: padToken, count: maxLength - tokens.count))
                })
            let mask = (padded .!= padToken)
            let tokenTypes = MLXArray.zeros(like: padded)

            let outputs = model(
                padded,
                positionIds: nil,
                tokenTypeIds: tokenTypes,
                attentionMask: mask
            )

            // Pool per-token embeddings into one vector per sequence.
            let poolingModule = resolvedPooler(for: pooler)
            let pooled = poolingModule(
                outputs,
                mask: mask,
                normalize: self.normalize,
                applyLayerNorm: self.applyLayerNorm
            )
            // Force evaluation of the lazy graph before reading values out.
            pooled.eval()

            let extraction = try extractVectors(from: pooled, expectedCount: encoded.count)

            // Re-attach original input indices to the extracted vectors.
            let embeddings = zip(encoded.map { $0.0 }, extraction.vectors).map { index, vector in
                (index: index, vector: vector)
            }

            return RuntimeEmbeddingResult(
                embeddings: embeddings,
                skippedIndices: skippedIndices,
                fallbackDescription: extraction.fallbackDescription
            )
        }
    }
}
91 |
--------------------------------------------------------------------------------
/Tools/Tutorial/Tutorial.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 |
6 | /// mlx-swift tutorial based on:
7 | /// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp
@main
struct Tutorial {

    /// Demonstrates scalar MLXArray basics: dtype, item access, size,
    /// ndim, and shape of a 0-dimensional array.
    static func scalarBasics() {
        // create a scalar array
        let x = MLXArray(1.0)

        // the datatype is .float32
        let dtype = x.dtype
        assert(dtype == .float32)

        // get the value
        let s = x.item(Float.self)
        assert(s == 1.0)

        // reading the value with a different type is a fatal error
        // let i = x.item(Int.self)

        // scalars have a size of 1
        let size = x.size
        assert(size == 1)

        // scalars have 0 dimensions
        let ndim = x.ndim
        assert(ndim == 0)

        // scalar shapes are empty arrays
        let shape = x.shape
        assert(shape == [])
    }

    /// Demonstrates multidimensional arrays, pointwise arithmetic, and
    /// MLX's lazy evaluation model.
    static func arrayBasics() {
        // make a multidimensional array.
        //
        // Note: the argument is a [Double] array literal, which is not
        // a supported type, but we can explicitly convert it to [Float]
        // when we create the MLXArray.
        let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2])

        // mlx is row-major by default so the first row of this array
        // is [1.0, 2.0] and the second row is [3.0, 4.0]
        print(x[0])
        print(x[1])

        // make an array of shape [2, 2] filled with ones
        let y = MLXArray.ones([2, 2])

        // pointwise add x and y
        let z = x + y

        // mlx is lazy by default. At this point `z` only
        // has a shape and a type but no actual data
        assert(z.dtype == .float32)
        assert(z.shape == [2, 2])

        // To actually run the computation you must evaluate `z`.
        // Under the hood, mlx records operations in a graph.
        // The variable `z` is a node in the graph which points to its operation
        // and inputs. When `eval` is called on an array (or arrays), the array and
        // all of its dependencies are recursively evaluated to produce the result.
        // Once an array is evaluated, it has data and is detached from its inputs.

        // Note: this is being called for demonstration purposes -- all reads
        // ensure the array is evaluated.
        z.eval()

        // this implicitly evaluates z before converting to a description
        print(z)
    }

    /// Demonstrates `grad`: first and second derivatives of x^2 at x = 1.5.
    static func automaticDifferentiation() {
        func fn(_ x: MLXArray) -> MLXArray {
            x.square()
        }

        let gradFn = grad(fn)

        let x = MLXArray(1.5)
        let dfdx = gradFn(x)
        print(dfdx)

        // d/dx x^2 = 2x, so the gradient at 1.5 is 3.0
        assert(dfdx.item() == Float(2 * 1.5))

        let df2dx2 = grad(grad(fn))(x)
        print(df2dx2)

        // d^2/dx^2 x^2 = 2 everywhere
        assert(df2dx2.item() == Float(2))
    }

    /// Entry point: runs the three demonstrations in order.
    static func main() {
        scalarBasics()
        arrayBasics()
        automaticDifferentiation()
    }
}
103 |
--------------------------------------------------------------------------------
/Tests/MLXLMTests/ToolTests.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLXLMCommon
3 | import Testing
4 |
/// Tests for `Tool` JSON-schema generation and streaming tool-call parsing.
struct ToolTests {
    @Test("Test Weather Tool Schema Generation")
    func testWeatherToolSchemaGeneration() throws {
        struct WeatherInput: Codable {
            let location: String
            let unit: String?
        }

        struct WeatherOutput: Codable {
            let temperature: Double
            let conditions: String
        }

        // One required parameter plus one optional parameter that carries an
        // extra `enum` constraint through `extraProperties`.
        let weatherTool = Tool(
            name: "get_current_weather",
            description: "Get the current weather in a given location",
            parameters: [
                .required(
                    "location", type: .string, description: "The city, e.g. Istanbul"
                ),
                .optional(
                    "unit",
                    type: .string,
                    description: "The unit of temperature",
                    extraProperties: [
                        "enum": ["celsius", "fahrenheit"]
                    ]
                ),
            ]
        ) { input in
            WeatherOutput(temperature: 14.0, conditions: "Sunny")
        }

        // Bridge to NSDictionary so the nested heterogeneous values compare
        // by value rather than by reference.
        let generated = weatherTool.schema as NSDictionary

        let expected: NSDictionary = [
            "type": "function",
            "function": [
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": [
                    "type": "object",
                    "properties": [
                        "location": [
                            "type": "string",
                            "description": "The city, e.g. Istanbul",
                        ],
                        "unit": [
                            "type": "string",
                            "description": "The unit of temperature",
                            "enum": ["celsius", "fahrenheit"],
                        ],
                    ],
                    "required": ["location"],
                ],
            ],
        ]

        #expect(generated == expected)
    }

    @Test("Test Tool Call Detection in Generated Text")
    func testToolCallDetection() throws {
        let processor = ToolCallProcessor()

        // A complete tool call split into token-sized chunks, as it would
        // arrive from streaming generation.
        let chunks: [String] = [
            "", "{", "\"", "name", "\"", ":", " ", "\"", "get", "_", "current",
            "_", "weather", "\"", ",", " ", "\"", "arguments", "\"", ":", " ", "{", "\"",
            "location", "\"", ":", " ", "\"", "San", " Francisco", "\"", ",", " ", "\"", "unit",
            "\"", ":", " ", "\"", "celsius", "\"", "}", "}", "",
        ]

        // While a tool call is being parsed, nothing should be passed through.
        for chunk in chunks {
            #expect(processor.processChunk(chunk) == nil)
        }

        // The completed call is recorded with fully parsed arguments.
        #expect(processor.toolCalls.count == 1)
        let toolCall = try #require(processor.toolCalls.first)

        #expect(toolCall.function.name == "get_current_weather")
        #expect(toolCall.function.arguments["location"] == .string("San Francisco"))
        #expect(toolCall.function.arguments["unit"] == .string("celsius"))
    }
}
89 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/MessageView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // MessageView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 20.04.2025.
6 | //
7 |
8 | import AVKit
9 | import SwiftUI
10 |
11 | /// A view that displays a single message in the chat interface.
12 | /// Supports different message roles (user, assistant, system) and media attachments.
/// Renders one chat `Message`, choosing layout by role (user, assistant,
/// system) and including the first attached image/video for user messages.
struct MessageView: View {
    /// The message to be displayed
    let message: Message

    /// Creates a message view
    /// - Parameter message: The message model to display
    init(_ message: Message) {
        self.message = message
    }

    var body: some View {
        switch message.role {
        case .user:
            userMessage
        case .assistant:
            assistantMessage
        case .system:
            systemMessage
        }
    }

    /// User messages: right-aligned, tinted bubble, media above the text.
    private var userMessage: some View {
        HStack {
            Spacer()
            VStack(alignment: .trailing, spacing: 8) {
                // Only the first attached image is shown.
                if let imageURL = message.images.first {
                    AsyncImage(url: imageURL) { image in
                        image
                            .resizable()
                            .aspectRatio(contentMode: .fill)
                    } placeholder: {
                        ProgressView()
                    }
                    .frame(maxWidth: 250, maxHeight: 200)
                    .clipShape(.rect(cornerRadius: 12))
                }

                // Only the first attached video is shown.
                if let videoURL = message.videos.first {
                    VideoPlayer(player: AVPlayer(url: videoURL))
                        .frame(width: 250, height: 340)
                        .clipShape(.rect(cornerRadius: 12))
                }

                // LocalizedStringKey triggers SwiftUI's default markdown handling.
                Text(LocalizedStringKey(message.content))
                    .padding(.vertical, 8)
                    .padding(.horizontal, 12)
                    .background(.tint, in: .rect(cornerRadius: 16))
                    .textSelection(.enabled)
            }
        }
    }

    /// Assistant messages: left-aligned plain text, markdown-rendered.
    private var assistantMessage: some View {
        HStack {
            Text(LocalizedStringKey(message.content))
                .textSelection(.enabled)

            Spacer()
        }
    }

    /// System messages: centered label with a computer icon.
    private var systemMessage: some View {
        Label(message.content, systemImage: "desktopcomputer")
            .font(.headline)
            .foregroundColor(.secondary)
            .frame(maxWidth: .infinity, alignment: .center)
    }
}
79 |
// Xcode preview: one message of each role, including a user message with a
// remote image attachment.
#Preview {
    VStack(spacing: 20) {
        MessageView(.system("You are a helpful assistant."))

        MessageView(
            .user(
                "Here's a photo",
                images: [URL(string: "https://picsum.photos/200")!]
            )
        )

        MessageView(.assistant("I see your photo!"))
    }
    .padding()
}
95 |
--------------------------------------------------------------------------------
/Tools/mnist-tool/MNISTTool.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLX
6 | import MLXMNIST
7 | import MLXNN
8 | import MLXOptimizers
9 |
/// `mnist-tool` entry point. `train` is both the only subcommand and the
/// default, so `mnist-tool --data ...` trains directly.
@main
struct MNISTTool: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        abstract: "Command line tool for training mnist models",
        subcommands: [Train.self],
        defaultSubcommand: Train.self)
}
17 |
// Let `--device cpu|gpu` be parsed from the command line.
// Swift 5.10+ requires `@retroactive` when conforming a type from another
// module to a protocol from another module; older compilers reject the
// attribute, hence the conditional compilation.
#if swift(>=5.10)
    extension MLX.DeviceType: @retroactive ExpressibleByArgument {
        public init?(argument: String) {
            self.init(rawValue: argument)
        }
    }
#else
    extension MLX.DeviceType: ExpressibleByArgument {
        public init?(argument: String) {
            self.init(rawValue: argument)
        }
    }
#endif
31 |
/// Downloads MNIST into `--data` (if needed), trains LeNet with SGD, and
/// prints test-set accuracy and wall-clock time for each epoch.
struct Train: AsyncParsableCommand {

    @Option(name: .long, help: "Directory with the training data")
    var data: String

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    // Training hyperparameters.
    @Option var batchSize = 256
    @Option var epochs = 20
    @Option var learningRate: Float = 1e-1

    // Compute device (cpu or gpu) used for all MLX operations.
    @Option var device = DeviceType.gpu

    // If set, fuse the train step with MLX.compile for faster iterations.
    @Flag var compile = false

    func run() async throws {
        Device.setDefault(device: Device(device))

        // Seed both MLX's RNG and the batch-shuffling generator so runs
        // are reproducible for a given --seed.
        MLXRandom.seed(seed)
        var generator: RandomNumberGenerator = SplitMix64(seed: seed)

        // load the data
        let url = URL(filePath: data)

        try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true)
        try await download(into: url)

        // note: shadows the `data` option (the directory path) from here on
        let data = try load(from: url)

        let trainImages = data[.init(.training, .images)]!
        let trainLabels = data[.init(.training, .labels)]!
        let testImages = data[.init(.test, .images)]!
        let testLabels = data[.init(.test, .labels)]!

        // create the model
        let model = LeNet()
        eval(model.parameters())

        let lg = valueAndGrad(model: model, loss)
        let optimizer = SGD(learningRate: learningRate)

        // One SGD step: forward + backward, then apply the gradients.
        func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
            let (loss, grads) = lg(model, x, y)
            optimizer.update(model: model, gradients: grads)
            return loss
        }

        // When compiling, model and optimizer are passed as inputs/outputs so
        // their state updates are captured by the compiled graph.
        let resolvedStep =
            compile
            ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step

        for e in 0 ..< epochs {
            let start = Date.timeIntervalSinceReferenceDate

            for (x, y) in iterateBatches(
                batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
            {
                _ = resolvedStep(x, y)

                // eval the parameters so the next iteration is independent
                eval(model, optimizer)
            }

            let accuracy = eval(model: model, x: testImages, y: testLabels)

            let end = Date.timeIntervalSinceReferenceDate

            print(
                """
                Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
                Time: \((end - start).formatted())

                """
            )
        }
    }
}
110 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/PoolingSupport.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 | import MLXEmbedders
4 |
/// Errors surfaced while converting pooled embeddings into plain vectors.
enum PoolingError: LocalizedError {
    /// The pooled array had a rank this tool cannot flatten.
    case unsupportedShape([Int])
    /// The number of extracted vectors differed from the number of inputs.
    case vectorCountMismatch(expected: Int, received: Int)

    var errorDescription: String? {
        switch self {
        case let .unsupportedShape(dims):
            return "Pooling produced unsupported shape: \(dims)"
        case let .vectorCountMismatch(expected, received):
            return "Pooling produced \(received) vectors but expected \(expected)"
        }
    }
}
18 |
/// Result of flattening a pooled embedding array.
struct PoolingExtraction {
    // One embedding vector per input text.
    let vectors: [[Float]]
    // Human-readable note when a fallback reduction was applied, else nil.
    let fallbackDescription: String?
}
23 |
extension EmbedderRuntime {
    /// Returns `pooler` unchanged unless a CLI strategy override is active
    /// and differs, in which case a new `Pooling` is built with the override
    /// (preserving the original dimension when present).
    func resolvedPooler(for pooler: Pooling) -> Pooling {
        guard let override = strategyOverride, override != pooler.strategy else {
            return pooler
        }

        if let dimension = pooler.dimension {
            return Pooling(strategy: override, dimension: dimension)
        }
        return Pooling(strategy: override)
    }

    /// Converts a pooled embedding array into per-input `[Float]` vectors.
    ///
    /// Rank-2 arrays are extracted directly. Rank-3 arrays are reduced with a
    /// mean over axis 1 and a fallback note is attached. Any other rank, or a
    /// batch count other than `expectedCount`, throws a `PoolingError`.
    func extractVectors(from array: MLXArray, expectedCount: Int) throws -> PoolingExtraction {
        // Shared count validation for both ranks.
        func checked(_ vectors: [[Float]]) throws -> [[Float]] {
            guard vectors.count == expectedCount else {
                throw PoolingError.vectorCountMismatch(
                    expected: expectedCount, received: vectors.count)
            }
            return vectors
        }

        switch array.shape.count {
        case 2:
            let vectors = try checked(array.map { $0.asArray(Float.self) })
            return PoolingExtraction(vectors: vectors, fallbackDescription: nil)

        case 3:
            let reduced = mean(array, axis: 1)
            reduced.eval()
            let vectors = try checked(reduced.map { $0.asArray(Float.self) })

            let effectiveStrategy = strategyOverride ?? baseStrategy
            let description =
                effectiveStrategy == .none
                ? "Pooling strategy 'none' returned sequence embeddings; falling back to mean over tokens."
                : "Pooling returned sequence embeddings; falling back to mean over tokens."
            return PoolingExtraction(vectors: vectors, fallbackDescription: description)

        default:
            throw PoolingError.unsupportedShape(array.shape)
        }
    }
}
78 |
extension Pooling.Strategy {
    /// Short name used for this strategy on the command line.
    var cliDescription: String {
        switch self {
        case .mean: "mean"
        case .cls: "cls"
        case .first: "first"
        case .last: "last"
        case .max: "max"
        case .none: "none"
        }
    }
}
91 |
extension EmbedderRuntime {
    /// Describes which pooling strategy is active and where it came from
    /// (CLI override vs. the model's default).
    var poolingDescription: String {
        guard let override = strategyOverride else {
            return "model default (\(baseStrategy.cliDescription))"
        }
        return "override (\(override.cliDescription))"
    }
}
101 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/MetricsView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import MLX
4 | import SwiftUI
5 |
/// Two-row grid of generation metrics plus a memory readout.
///
/// The info button shows memory details as a tooltip via `.help` (macOS)
/// and, on iOS, presents the same text in an alert.
struct MetricsView: View {
    let tokensPerSecond: Double
    let timeToFirstToken: Double
    let promptLength: Int
    let totalTokens: Int
    let totalTime: Double
    let memoryUsed: Int
    let cacheMemory: Int
    let peakMemory: Int

    @State private var showMemoryDetails = false

    /// Shared text for the tooltip and the alert. Previously this literal was
    /// duplicated in both places; a single property keeps them in sync.
    private var memoryDetails: String {
        """
        Active Memory: \(FormatUtilities.formatMemory(memoryUsed))/\(FormatUtilities.formatMemory(GPU.memoryLimit))
        Cache Memory: \(FormatUtilities.formatMemory(cacheMemory))/\(FormatUtilities.formatMemory(GPU.cacheLimit))
        Peak Memory: \(FormatUtilities.formatMemory(peakMemory))
        """
    }

    var body: some View {
        VStack(spacing: 12) {
            // Top row
            HStack(spacing: 12) {
                MetricCard(
                    icon: "speedometer",
                    title: "Tokens/sec",
                    value: String(format: "%.1f", tokensPerSecond)
                )
                MetricCard(
                    icon: "timer",
                    title: "Time to First Token",
                    value: String(format: "%.0fms", timeToFirstToken)
                )
                MetricCard(
                    icon: "text.alignleft",
                    title: "Prompt Length",
                    value: "\(promptLength)"
                )
            }

            // Bottom row
            HStack(spacing: 12) {
                MetricCard(
                    icon: "number",
                    title: "Total Tokens",
                    value: "\(totalTokens)"
                )
                MetricCard(
                    icon: "hourglass",
                    title: "Total Time",
                    value: String(format: "%.1fs", totalTime)
                )
                ZStack(alignment: .topTrailing) {
                    MetricCard(
                        icon: "memorychip",
                        title: "Memory",
                        value: FormatUtilities.formatMemory(memoryUsed)
                    )
                    Button(action: {
                        // The alert is only useful on iOS; macOS uses .help below.
                        #if os(iOS)
                            showMemoryDetails = true
                        #endif
                    }) {
                        Image(systemName: "info.circle.fill")
                            .font(.caption)
                            .foregroundStyle(.secondary)
                            .frame(width: 44, height: 44)
                            .contentShape(Rectangle())
                    }
                    .buttonStyle(.plain)
                    .help(memoryDetails)
                }
            }
        }
        .padding(.top, 8)
        .alert("Memory Details", isPresented: $showMemoryDetails) {
            Button("OK", role: .cancel) {}
        } message: {
            Text(memoryDetails)
        }
    }
}
92 |
--------------------------------------------------------------------------------
/Tools/embedder-tool/VectorOperations.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import Accelerate
4 | import Foundation
5 |
/// Accelerate-backed helpers for comparing and normalizing embedding vectors.
enum VectorOperations {
    /// Cosine similarity of two equal-length vectors. Returns 0 for
    /// mismatched lengths, empty input, or a near-zero denominator.
    static func cosineSimilarity(_ lhs: [Float], _ rhs: [Float]) -> Float {
        guard lhs.count == rhs.count, !lhs.isEmpty else { return 0 }

        let length = vDSP_Length(lhs.count)
        var dot: Float = 0
        var lhsNormSquared: Float = 0
        var rhsNormSquared: Float = 0

        vDSP_dotpr(lhs, 1, rhs, 1, &dot, length)
        vDSP_svesq(lhs, 1, &lhsNormSquared, length)
        vDSP_svesq(rhs, 1, &rhsNormSquared, length)

        let denominator = sqrt(lhsNormSquared * rhsNormSquared)
        return denominator > 1e-9 ? dot / denominator : 0
    }

    /// Euclidean (L2) norm; 0 for an empty vector.
    static func l2Norm(_ vector: [Float]) -> Float {
        guard !vector.isEmpty else { return 0 }

        var sumSquares: Float = 0
        vDSP_svesq(vector, 1, &sumSquares, vDSP_Length(vector.count))
        return sqrt(sumSquares)
    }

    /// Scales `vector` to unit length after zeroing non-finite entries.
    /// Returns [] for empty input or a non-finite sum of squares, and the
    /// sanitized vector unchanged when its norm is effectively zero.
    static func normalize(_ vector: [Float]) -> [Float] {
        guard !vector.isEmpty else { return [] }

        let cleaned = sanitize(vector)
        let length = vDSP_Length(cleaned.count)

        var sumSquares: Float = 0
        vDSP_svesq(cleaned, 1, &sumSquares, length)

        guard sumSquares.isFinite else { return [] }
        guard sumSquares > 1e-9 else { return cleaned }

        var divisor = sqrt(sumSquares)
        var result = [Float](repeating: 0, count: cleaned.count)
        vDSP_vsdiv(cleaned, 1, &divisor, &result, 1, length)
        return result
    }

    /// Copy of `vector` with every NaN/Inf entry replaced by 0.
    static func sanitize(_ vector: [Float]) -> [Float] {
        guard !vector.isEmpty else { return [] }
        return vector.map { $0.isFinite ? $0 : 0 }
    }

    /// Dot product; 0 for mismatched lengths or empty input.
    static func dotProduct(_ lhs: [Float], _ rhs: [Float]) -> Float {
        guard lhs.count == rhs.count, !lhs.isEmpty else { return 0 }

        var result: Float = 0
        vDSP_dotpr(lhs, 1, rhs, 1, &result, vDSP_Length(lhs.count))
        return result
    }

    /// Normalizes each vector in the batch independently.
    static func normalizeBatch(_ vectors: [[Float]]) -> [[Float]] {
        vectors.map(normalize)
    }

    /// Cosine similarity of `query` against every document vector.
    static func batchCosineSimilarity(query: [Float], documents: [[Float]]) -> [Float] {
        documents.map { cosineSimilarity(query, $0) }
    }

    /// Dot product of `query` against every document vector.
    static func batchDotProduct(query: [Float], documents: [[Float]]) -> [Float] {
        documents.map { dotProduct(query, $0) }
    }
}
80 |
extension VectorOperations {
    /// True when the vector contains any NaN or infinite entry.
    static func hasNonFiniteValues(_ vector: [Float]) -> Bool {
        !vector.allSatisfy(\.isFinite)
    }

    /// Mean, minimum, maximum, and L2 norm of `vector`; all zeros when empty.
    static func statistics(_ vector: [Float]) -> (mean: Float, min: Float, max: Float, norm: Float)
    {
        guard !vector.isEmpty else { return (0, 0, 0, 0) }

        let length = vDSP_Length(vector.count)
        var mean: Float = 0
        var minimum: Float = 0
        var maximum: Float = 0

        vDSP_meanv(vector, 1, &mean, length)
        vDSP_minv(vector, 1, &minimum, length)
        vDSP_maxv(vector, 1, &maximum, length)

        return (mean, minimum, maximum, l2Norm(vector))
    }
}
101 |
--------------------------------------------------------------------------------
/.github/workflows/pull_request.yml:
--------------------------------------------------------------------------------
# PR CI: a Linux lint pass (pre-commit + swift-format) gating a macOS
# build-and-test pass on a self-hosted runner.
name: Build and Test

on: pull_request

permissions:
  contents: read

jobs:
  # Style checks run inside a Swift container so swift-format can be built
  # from source (no prebuilt Linux binary) and cached between runs.
  lint:
    if: github.repository == 'ml-explore/mlx-swift-examples'
    runs-on: ubuntu-22.04
    container:
      image: swift:6.2-rhel-ubi9
    steps:
      - uses: actions/checkout@v6
        with:
          submodules: recursive

      - name: Setup uv
        uses: astral-sh/setup-uv@v6
        with:
          activate-environment: true

      - name: Setup pre-commit
        shell: sh
        run: |
          uv pip install pre-commit

      # Resolve the latest swift-format release tag; exported as a step
      # output so it can key the build cache below.
      - name: Get swift-format tag
        id: swift-format
        shell: sh
        run: |
          cd /tmp
          LATEST_TAG=$(curl -s https://api.github.com/repos/swiftlang/swift-format/releases/latest | \
            grep '"tag_name":' | \
            sed -E 's/.*"([^"]+)".*/\1/')
          echo "swift-format $LATEST_TAG"
          echo "SWIFT_FORMAT_VERSION=$LATEST_TAG" >> $GITHUB_OUTPUT

      - name: Cache swift-format build
        uses: actions/cache@v4
        id: cache-swift-format
        with:
          path: /tmp/swift-format/.build
          key: ${{ runner.os }}-swift-format-build-${{ steps.swift-format.outputs.SWIFT_FORMAT_VERSION }}

      # Only build swift-format when the cache missed for this tag.
      - name: Build swift-format
        if: steps.cache-swift-format.outputs.cache-hit != 'true'
        shell: sh
        run: |
          cd /tmp
          git clone --branch ${{ steps.swift-format.outputs.SWIFT_FORMAT_VERSION }} --depth 1 https://github.com/swiftlang/swift-format.git
          cd swift-format
          swift build -c release

      - name: Link swift-format to /usr/local/bin
        shell: sh
        run: |
          cd /tmp/swift-format
          ln -s "$(swift build --show-bin-path -c release)/swift-format" /usr/local/bin/swift-format

      # git refuses to operate on a checkout owned by a different uid inside
      # the container unless it is marked safe.
      - name: Configure safe directory for git
        shell: sh
        run: |
          git config --global --add safe.directory "$GITHUB_WORKSPACE"

      - name: Run style checks
        shell: sh
        run: |
          pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1)


  mac_build_and_test:
    needs: lint
    if: github.repository == 'ml-explore/mlx-swift-examples'
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v6
        with:
          submodules: recursive

      - name: Verify MetalToolchain installed
        shell: bash
        run: xcodebuild -showComponent MetalToolchain

      # Clean DerivedData first so the build is not poisoned by stale state
      # on the self-hosted runner.
      - name: Build Package (Xcode, macOS)
        shell: sh
        run: |
          xcodebuild -version
          xcrun --show-sdk-build-version
          swift --version
          rm -rf ~/Library/Developer/Xcode/DerivedData/*
          xcodebuild build-for-testing -scheme mlx-libraries-Package -destination 'platform=macOS'

      # Remove Package.resolved files so each tool resolves fresh dependencies.
      - name: Build tools (Xcode, macOS)
        shell: sh
        run: |
          xcodebuild -version
          xcrun --show-sdk-build-version
          swift --version
          find . -name Package.resolved -exec rm {} \;
          xcodebuild -scheme llm-tool
          xcodebuild -scheme image-tool
          xcodebuild -scheme mnist-tool
--------------------------------------------------------------------------------
/Applications/MLXChatExample/Views/MediaPreviewView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // MediaPreviewView.swift
3 | // MLXChatExample
4 | //
5 | // Created by İbrahim Çetin on 21.04.2025.
6 | //
7 |
8 | import AVFoundation
9 | import AVKit
10 | import SwiftUI
11 |
12 | /// A view that displays a horizontal scrollable list of media previews (images and videos).
/// Horizontally scrolling strip of previews for the selected images and
/// videos, each with its own remove button.
struct MediaPreviewsView: View {
    /// The media selection containing arrays of image and video URLs
    let mediaSelection: MediaSelection

    var body: some View {
        ScrollView(.horizontal) {
            HStack(spacing: 8) {
                // One removable preview per selected image.
                ForEach(mediaSelection.images, id: \.self) { url in
                    MediaPreviewView(mediaURL: url, type: .image) {
                        mediaSelection.images.removeAll { $0 == url }
                    }
                }

                // One removable preview per selected video.
                ForEach(mediaSelection.videos, id: \.self) { url in
                    MediaPreviewView(mediaURL: url, type: .video) {
                        mediaSelection.videos.removeAll { $0 == url }
                    }
                }
            }
            .padding(.horizontal)
        }
        .padding(.top)
    }
}
47 |
48 | /// A view that displays a single media item (image or video) with a remove button.
/// Thumbnail for a single media item with a remove button overlaid in the
/// top-trailing corner.
struct MediaPreviewView: View {
    /// URL of the media file to display
    let mediaURL: URL
    /// Type of media (image or video)
    let type: MediaPreviewType
    /// Callback to handle removal of the media item
    let onRemove: () -> Void

    var body: some View {
        ZStack(alignment: .topTrailing) {
            preview
            RemoveButton(action: onRemove)
        }
    }

    /// The media content itself, as a 100pt-tall thumbnail.
    @ViewBuilder
    private var preview: some View {
        switch type {
        case .image:
            // Remote images load asynchronously behind a spinner.
            AsyncImage(url: mediaURL) { image in
                image
                    .resizable()
                    .scaledToFit()
                    .frame(height: 100)
                    .clipShape(RoundedRectangle(cornerRadius: 8))
            } placeholder: {
                ProgressView()
                    .frame(width: 150, height: 100)
            }
        case .video:
            // Videos get an inline player.
            VideoPlayer(player: AVPlayer(url: mediaURL))
                .frame(width: 150, height: 100)
                .clipShape(RoundedRectangle(cornerRadius: 8))
        }
    }
}
83 |
84 | /// A button for removing media items from the preview.
/// Small "x" button overlaid on media previews to remove the item.
struct RemoveButton: View {
    /// Action to perform when the remove button is tapped
    let action: () -> Void

    var body: some View {
        Button {
            action()
        } label: {
            Image(systemName: "xmark.circle.fill")
                .foregroundStyle(.secondary)
                .imageScale(.large)
        }
        .buttonStyle(.plain)
        .padding(4)
    }
}
99 |
extension MediaPreviewView {
    /// Defines the type of media that can be displayed in the preview.
    enum MediaPreviewType {
        /// A still image, loaded asynchronously from its URL
        case image
        /// A video file, shown in an inline player
        case video
    }
}
109 |
// Xcode preview for the standalone remove button.
#Preview("Remove Button") {
    RemoveButton {}
}
113 |
--------------------------------------------------------------------------------
/Applications/MNISTTrainer/PredictionView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // PredictionView.swift
3 | // MNISTTrainer
4 | //
5 | // Created by Rounak Jain on 3/9/24.
6 | //
7 |
8 | import MLX
9 | import MLXMNIST
10 | import MLXNN
11 | import SwiftUI
12 |
/// Minimal drawing surface: strokes the accumulated `path` in white on black
/// and appends line segments as the user drags.
struct Canvas: View {

    /// The drawing, owned by the parent so it can be cleared and rendered.
    @Binding var path: Path
    /// Last sampled drag location; connects consecutive samples into a stroke.
    @State var lastPoint: CGPoint?

    var body: some View {
        path
            .stroke(.white, lineWidth: 10)
            .background(.black)
            .gesture(
                DragGesture(minimumDistance: 0.05)
                    .onChanged { touch in
                        add(point: touch.location)
                    }
                    .onEnded { _ in
                        // Reset so the next drag starts a fresh segment
                        // instead of connecting to the previous stroke.
                        lastPoint = nil
                    }
            )
    }

    /// Extends the current stroke to `point`, starting a new subpath when
    /// there is no previous point.
    func add(point: CGPoint) {
        var newPath = path
        if let lastPoint {
            newPath.move(to: lastPoint)
            newPath.addLine(to: point)
        } else {
            newPath.move(to: point)
        }
        self.path = newPath
        lastPoint = point
    }
}
45 |
extension Path {
    /// Translates the path so its bounding-box center lands on `newMidPoint`.
    mutating func center(to newMidPoint: CGPoint) {
        let current = CGPoint(x: boundingRect.midX, y: boundingRect.midY)
        self = offsetBy(dx: newMidPoint.x - current.x, dy: newMidPoint.y - current.y)
    }
}
53 |
/// Lets the user draw a digit on a small canvas and runs the trained LeNet
/// on a rasterized copy of the drawing.
struct PredictionView: View {
    /// The user's drawing.
    @State var path: Path = Path()
    /// Most recent model prediction; nil until "Predict" is tapped.
    @State var prediction: Int?
    /// Container holding the trained model.
    let model: LeNetContainer
    /// Side length (points) of the square drawing canvas.
    let canvasSize = 150.0
    /// Input resolution of the MNIST model.
    let mnistImageSize: CGSize = CGSize(width: 28, height: 28)

    var body: some View {
        VStack {
            if let prediction {
                Text("You've drawn a \(prediction)")
            } else {
                Text("Draw a digit")
            }
            Canvas(path: $path)
                .frame(width: canvasSize, height: canvasSize)
            HStack {
                Button("Predict") {
                    // Center the drawing so off-center digits still classify.
                    path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2))
                    predict()
                }
                Button("Clear") {
                    path = Path()
                    prediction = nil
                }
            }
        }
    }

    /// Rasterizes the canvas and asks the model for a digit prediction.
    @MainActor
    func predict() {
        // Use canvasSize (previously a hard-coded 150) so the rendered image
        // always matches the on-screen canvas dimensions.
        let imageRenderer = ImageRenderer(
            content: Canvas(path: $path).frame(width: canvasSize, height: canvasSize))

        if let image = imageRenderer.cgImage {
            Task {
                self.prediction = await model.evaluate(image: image)
            }
        }
    }
}
95 |
extension CGImage {
    /// Redraws the image into an 8-bit, alpha-free grayscale context of
    /// `newSize` and returns the result, or nil if the context cannot be
    /// created.
    func grayscaleImage(with newSize: CGSize) -> CGImage? {
        let colorSpace = CGColorSpaceCreateDeviceGray()
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)

        guard
            let context = CGContext(
                data: nil,
                width: Int(newSize.width),
                height: Int(newSize.height),
                bitsPerComponent: 8,
                bytesPerRow: Int(newSize.width),
                space: colorSpace,
                bitmapInfo: bitmapInfo.rawValue)
        else {
            return nil
        }
        // Bug fix: the destination rect previously used newSize.width for its
        // height, which distorted the image for non-square target sizes.
        context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.height))
        return context.makeImage()
    }

    /// Copies the image's backing bytes into a flat MLXArray.
    /// Returns an empty array when the pixel data is unavailable.
    func pixelData() -> MLXArray {
        guard let data = self.dataProvider?.data else {
            return []
        }
        let bytePtr = CFDataGetBytePtr(data)
        let count = CFDataGetLength(data)
        return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count))
    }
}
126 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/ContentView.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import AsyncAlgorithms
4 | import MLX
5 | import MLXLLM
6 | import MLXLMCommon
7 | import Metal
8 | import SwiftUI
9 | import Tokenizers
10 |
/// Main LLMEval screen: header, streaming output, prompt input, and a
/// metrics panel, with a sheet for preset prompts and a loading overlay
/// while model weights download.
struct ContentView: View {
    @Environment(DeviceStat.self) private var deviceStat

    // Owns model loading and token generation state.
    @State var llm = LLMEvaluator()

    // How generated output is rendered in OutputView.
    enum DisplayStyle: String, CaseIterable, Identifiable {
        case plain, markdown
        var id: Self { self }
    }

    @State private var selectedDisplayStyle = DisplayStyle.markdown
    @State private var showingPresetPrompts = false
    @State private var isPromptExpanded = false

    var body: some View {
        VStack(alignment: .leading, spacing: 0) {
            // Header Section
            HeaderView(
                llm: llm,
                selectedDisplayStyle: $selectedDisplayStyle
            )

            Divider()
                .padding(.bottom, 12)

            // Output display
            OutputView(
                output: llm.output,
                displayStyle: selectedDisplayStyle,
                wasTruncated: llm.wasTruncated
            )

            // Prompt input section
            PromptInputView(
                llm: llm,
                isPromptExpanded: $isPromptExpanded,
                showingPresetPrompts: $showingPresetPrompts,
                onGenerate: generate,
                onCancel: cancel
            )

            // Performance Metrics Panel
            MetricsView(
                tokensPerSecond: llm.tokensPerSecond,
                timeToFirstToken: llm.timeToFirstToken,
                promptLength: llm.promptLength,
                totalTokens: llm.totalTokens,
                totalTime: llm.totalTime,
                memoryUsed: deviceStat.gpuUsage.activeMemory,
                cacheMemory: deviceStat.gpuUsage.cacheMemory,
                peakMemory: deviceStat.gpuUsage.peakMemory
            )
        }
        #if os(visionOS)
            .padding(40)
        #else
            .padding()
        #endif
        .toolbar {
            ToolbarItem(placement: .primaryAction) {
                Button {
                    Task {
                        copyToClipboard(llm.output)
                    }
                } label: {
                    Label("Copy Output", systemImage: "doc.on.doc.fill")
                }
                .disabled(llm.output == "")
                .labelStyle(.titleAndIcon)
            }

        }
        .task {
            do {
                // pre-load the weights on launch to speed up the first generation
                _ = try await llm.load()
            } catch {
                llm.output = "Failed: \(error)"
            }
        }
        .sheet(isPresented: $showingPresetPrompts) {
            // Selecting a preset fills the prompt and toggles its features.
            PresetPromptsSheet(isPresented: $showingPresetPrompts) { preset in
                llm.prompt = preset.prompt
                llm.includeWeatherTool = preset.enableTools
                llm.enableThinking = preset.enableThinking
            }
        }
        .overlay {
            // Shown while weights are downloading or loading.
            if llm.isLoading {
                LoadingOverlayView(
                    modelInfo: llm.modelInfo,
                    downloadProgress: llm.downloadProgress,
                    progressDescription: llm.totalSize
                )
            }
        }
    }

    private func generate() {
        llm.generate()
    }

    private func cancel() {
        llm.cancelGeneration()
    }

    /// Copies `string` to the platform pasteboard (AppKit or UIKit).
    private func copyToClipboard(_ string: String) {
        #if os(macOS)
            NSPasteboard.general.clearContents()
            NSPasteboard.general.setString(string, forType: .string)
        #else
            UIPasteboard.general.string = string
        #endif
    }
}
126 |
--------------------------------------------------------------------------------
/Applications/LLMEval/Views/PresetPromptsSheet.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2025 Apple Inc.
2 |
3 | import SwiftUI
4 |
/// Sheet listing example prompts; selecting one invokes `onSelect` and
/// dismisses the sheet.
struct PresetPromptsSheet: View {
    @Binding var isPresented: Bool
    /// Called with the chosen preset before the sheet is dismissed.
    let onSelect: (PresetPrompt) -> Void

    var body: some View {
        NavigationStack {
            List {
                ForEach(PresetPrompts.all) { preset in
                    Button {
                        onSelect(preset)
                        isPresented = false
                    } label: {
                        HStack(alignment: .center, spacing: 12) {
                            // Show the actual prompt (first 2 lines)
                            // Clean up whitespace for better preview
                            let cleanedPrompt = preset.prompt
                                .trimmingCharacters(in: .whitespacesAndNewlines)
                                .replacingOccurrences(
                                    of: #"\n\n+"#, with: " ", options: .regularExpression
                                )
                                .replacingOccurrences(
                                    of: #"\s+"#, with: " ", options: .regularExpression)

                            Text(cleanedPrompt)
                                .multilineTextAlignment(.leading)
                                .lineLimit(2)
                                .font(.body)
                                .foregroundStyle(.primary)
                                .frame(maxWidth: .infinity, alignment: .leading)

                            // Show indicators if present
                            if preset.enableThinking || preset.enableTools || preset.isLongPrompt {
                                HStack(spacing: 6) {
                                    if preset.enableThinking {
                                        BadgeView(icon: "brain", text: "Thinking", color: .purple)
                                    }
                                    if preset.enableTools {
                                        BadgeView(icon: "hammer.fill", text: "Tools", color: .blue)
                                    }
                                    if preset.isLongPrompt {
                                        BadgeView(
                                            icon: "doc.text.fill", text: "Long", color: .orange)
                                    }
                                }
                            }
                        }
                        .padding(.vertical, 8)
                        // macOS: make the whole row width tappable.
                        #if os(macOS)
                            .frame(maxWidth: .infinity, alignment: .leading)
                            .contentShape(Rectangle())
                        #endif
                    }
                    .buttonStyle(.plain)
                    #if os(macOS)
                        .listRowInsets(EdgeInsets(top: 8, leading: 12, bottom: 8, trailing: 12))
                    #endif
                }
            }
            #if os(macOS)
                .listStyle(.inset)
            #endif
            .navigationTitle("Example Prompts")
            #if !os(macOS)
                .navigationBarTitleDisplayMode(.inline)
            #endif
            .toolbar {
                ToolbarItem(placement: .cancellationAction) {
                    Button("Close") {
                        isPresented = false
                    }
                }
            }
        }
        // Give the sheet a sensible default size on macOS.
        #if os(macOS)
            .frame(minWidth: 600, minHeight: 500)
        #endif
    }
}
83 |
/// Small capsule badge with an SF Symbol icon and short label over a gradient
/// background, used to flag preset features (thinking / tools / long prompt).
private struct BadgeView: View {
    /// SF Symbol name shown before the text.
    let icon: String
    /// Short label displayed in the badge.
    let text: String
    /// Base color for the capsule's gradient background.
    let color: Color

    var body: some View {
        Label(text, systemImage: icon)
            .font(.caption2)
            .fontWeight(.medium)
            .foregroundStyle(.white)
            .padding(.horizontal, 8)
            .padding(.vertical, 4)
            .background(color.gradient, in: Capsule())
    }
}
100 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "originHash" : "0dd305f09ba278569849d2b91cb2ef7a95956604f60ef1c3013f0e7b4b61d1a7",
3 | "pins" : [
4 | {
5 | "identity" : "gzipswift",
6 | "kind" : "remoteSourceControl",
7 | "location" : "https://github.com/1024jp/GzipSwift",
8 | "state" : {
9 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05",
10 | "version" : "6.0.1"
11 | }
12 | },
13 | {
14 | "identity" : "mlx-swift",
15 | "kind" : "remoteSourceControl",
16 | "location" : "https://github.com/ml-explore/mlx-swift",
17 | "state" : {
18 | "revision" : "072b684acaae80b6a463abab3a103732f33774bf",
19 | "version" : "0.29.1"
20 | }
21 | },
22 | {
23 | "identity" : "mlx-swift-lm",
24 | "kind" : "remoteSourceControl",
25 | "location" : "https://github.com/ml-explore/mlx-swift-lm",
26 | "state" : {
27 | "revision" : "01852971866b17889f7aff27500cc62d9148b0df",
28 | "version" : "2.29.2"
29 | }
30 | },
31 | {
32 | "identity" : "networkimage",
33 | "kind" : "remoteSourceControl",
34 | "location" : "https://github.com/gonzalezreal/NetworkImage",
35 | "state" : {
36 | "revision" : "2849f5323265386e200484b0d0f896e73c3411b9",
37 | "version" : "6.0.1"
38 | }
39 | },
40 | {
41 | "identity" : "progress.swift",
42 | "kind" : "remoteSourceControl",
43 | "location" : "https://github.com/jkandzi/Progress.swift",
44 | "state" : {
45 | "revision" : "fed6598735d7982058690acf8f52a0a5fdaeb3e0",
46 | "version" : "0.4.0"
47 | }
48 | },
49 | {
50 | "identity" : "swift-argument-parser",
51 | "kind" : "remoteSourceControl",
52 | "location" : "https://github.com/apple/swift-argument-parser.git",
53 | "state" : {
54 | "revision" : "cdd0ef3755280949551dc26dee5de9ddeda89f54",
55 | "version" : "1.6.2"
56 | }
57 | },
58 | {
59 | "identity" : "swift-async-algorithms",
60 | "kind" : "remoteSourceControl",
61 | "location" : "https://github.com/apple/swift-async-algorithms.git",
62 | "state" : {
63 | "revision" : "042e1c4d9d19748c9c228f8d4ebc97bb1e339b0b",
64 | "version" : "1.0.4"
65 | }
66 | },
67 | {
68 | "identity" : "swift-cmark",
69 | "kind" : "remoteSourceControl",
70 | "location" : "https://github.com/swiftlang/swift-cmark",
71 | "state" : {
72 | "revision" : "b97d09472e847a416629f026eceae0e2afcfad65",
73 | "version" : "0.7.0"
74 | }
75 | },
76 | {
77 | "identity" : "swift-collections",
78 | "kind" : "remoteSourceControl",
79 | "location" : "https://github.com/apple/swift-collections.git",
80 | "state" : {
81 | "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
82 | "version" : "1.3.0"
83 | }
84 | },
85 | {
86 | "identity" : "swift-jinja",
87 | "kind" : "remoteSourceControl",
88 | "location" : "https://github.com/huggingface/swift-jinja.git",
89 | "state" : {
90 | "revision" : "c1ef5963ba4a97a589b9c9583ff4ee3352a86d23",
91 | "version" : "2.1.0"
92 | }
93 | },
94 | {
95 | "identity" : "swift-markdown-ui",
96 | "kind" : "remoteSourceControl",
97 | "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
98 | "state" : {
99 | "revision" : "5f613358148239d0292c0cef674a3c2314737f9e",
100 | "version" : "2.4.1"
101 | }
102 | },
103 | {
104 | "identity" : "swift-numerics",
105 | "kind" : "remoteSourceControl",
106 | "location" : "https://github.com/apple/swift-numerics",
107 | "state" : {
108 | "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
109 | "version" : "1.1.1"
110 | }
111 | },
112 | {
113 | "identity" : "swift-transformers",
114 | "kind" : "remoteSourceControl",
115 | "location" : "https://github.com/huggingface/swift-transformers",
116 | "state" : {
117 | "revision" : "94610577e4af9bbc267060af1e25e977604dd796",
118 | "version" : "1.1.1"
119 | }
120 | }
121 | ],
122 | "version" : 3
123 | }
124 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Sampler.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 | import MLX
5 |
6 | // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/sampler.py
7 |
/// Piecewise-linear interpolation of the function whose samples are
/// `y[0], y[1], ..., y[y.count - 1]`, evaluated at the (possibly fractional)
/// positions `xNew`.
func interpolate(y: MLXArray, xNew: MLXArray) -> MLXArray {
    // Integer indices bracketing each query position, clamped to the last index.
    let lo = xNew.asType(.int32)
    let hi = minimum(lo + 1, y.count - 1)

    // Blend the two neighboring samples by the fractional part of the position.
    let frac = xNew - lo
    return y[lo] * (1 - frac) + frac * y[hi]
}
20 |
/// A simple Euler integrator that can be used to sample from our diffusion models.
///
/// The method ``step()`` performs one Euler step from `x_t` to `x_t_prev`.
class SimpleEulerSampler {

    // Noise level per timestep, with a leading zero prepended so that
    // sigmas[0] == 0 (fully denoised) and sigmas[maxTime] is the largest level.
    let sigmas: MLXArray

    public init(configuration: DiffusionConfiguration) {
        let betas: MLXArray

        // compute the noise schedule
        switch configuration.betaSchedule {
        case .linear:
            betas = MLXArray.linspace(
                configuration.betaStart, configuration.betaEnd, count: configuration.trainSteps)
        case .scaledLinear:
            // Linear in sqrt(beta): interpolate the square roots, then square.
            betas = MLXArray.linspace(
                sqrt(configuration.betaStart), sqrt(configuration.betaEnd),
                count: configuration.trainSteps
            ).square()
        }

        let alphas = 1 - betas
        let alphasCumprod = cumprod(alphas)

        // sigma_t = sqrt((1 - prod(alpha)) / prod(alpha)), with sigma_0 = 0 prepended
        self.sigmas = concatenated([
            MLXArray.zeros([1]), ((1 - alphasCumprod) / alphasCumprod).sqrt(),
        ])
    }

    /// Largest valid time index into the sigma table.
    public var maxTime: Int {
        sigmas.count - 1
    }

    /// Draws the initial fully-noised latent from the prior distribution.
    public func samplePrior(shape: [Int], dType: DType = .float32, key: MLXArray? = nil) -> MLXArray
    {
        let noise = MLXRandom.normal(shape, key: key)
        // Scale by sigma_max / sqrt(sigma_max^2 + 1)
        return (noise * sigmas[-1] * (sigmas[-1].square() + 1).rsqrt()).asType(dType)
    }

    /// Adds noise corresponding to time `t` to a clean sample `x`.
    public func addNoise(x: MLXArray, t: MLXArray, key: MLXArray? = nil) -> MLXArray {
        let noise = MLXRandom.normal(x.shape, key: key)
        let s = sigmas(t)
        return (x + noise * s) * (s.square() + 1).rsqrt()
    }

    /// Noise level at (possibly fractional) time `t`, linearly interpolated
    /// from the ``sigmas`` table.
    public func sigmas(_ t: MLXArray) -> MLXArray {
        interpolate(y: sigmas, xNew: t)
    }

    /// Returns `(t, tPrev)` pairs stepping from `start` (default: the max time)
    /// down to 0 in `steps` equal increments.
    public func timeSteps(steps: Int, start: Int? = nil, dType: DType = .float32) -> [(
        MLXArray, MLXArray
    )] {
        let start = start ?? (sigmas.count - 1)
        precondition(0 < start)
        precondition(start <= sigmas.count - 1)
        let steps = MLX.linspace(start, 0, count: steps + 1).asType(dType)

        // Adjacent pairs: (steps[0], steps[1]), (steps[1], steps[2]), ...
        return Array(zip(steps, steps[1...]))
    }

    /// Performs one deterministic Euler step from `xt` at time `t` to `tPrev`.
    open func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray) -> MLXArray {
        let dtype = epsPred.dtype
        let sigma = sigmas(t).asType(dtype)
        let sigmaPrev = sigmas(tPrev).asType(dtype)

        // Euler update in the scaled coordinates, then rescale for sigmaPrev.
        let dt = sigmaPrev - sigma
        var xtPrev = (sigma.square() + 1).sqrt() * xt + epsPred * dt
        xtPrev = xtPrev * (sigmaPrev.square() + 1).rsqrt()

        return xtPrev
    }
}
94 |
/// Euler *ancestral* variant of ``SimpleEulerSampler``: each step re-injects
/// fresh gaussian noise in addition to the deterministic Euler update.
class SimpleEulerAncestralSampler: SimpleEulerSampler {

    open override func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray)
        -> MLXArray
    {
        let dtype = epsPred.dtype
        let sigma = sigmas(t).asType(dtype)
        let sigmaPrev = sigmas(tPrev).asType(dtype)

        // Split the noise-level decrease into a deterministic part (sigmaDown)
        // and a stochastic part (sigmaUp), with sigmaDown^2 + sigmaUp^2 = sigmaPrev^2.
        let sigma2 = sigma.square()
        let sigmaPrev2 = sigmaPrev.square()
        let sigmaUp = (sigmaPrev2 * (sigma2 - sigmaPrev2) / sigma2).sqrt()
        let sigmaDown = (sigmaPrev2 - sigmaUp ** 2).sqrt()

        // Deterministic Euler step toward sigmaDown ...
        let dt = sigmaDown - sigma
        var xtPrev = (sigma2 + 1).sqrt() * xt + epsPred * dt
        // ... then add fresh noise scaled by sigmaUp and rescale for sigmaPrev.
        let noise = MLXRandom.normal(xtPrev.shape).asType(xtPrev.dtype)
        xtPrev = xtPrev + noise * sigmaUp
        xtPrev = xtPrev * (sigmaPrev2 + 1).rsqrt()

        return xtPrev
    }
}
118 |
--------------------------------------------------------------------------------
/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/embedder-tool.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
45 |
47 |
53 |
54 |
55 |
56 |
59 |
60 |
63 |
64 |
67 |
68 |
71 |
72 |
75 |
76 |
79 |
80 |
81 |
82 |
88 |
90 |
96 |
97 |
98 |
99 |
101 |
102 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
/Data/lora/wikisql.py:
--------------------------------------------------------------------------------
1 | # Copyright © 2023 Apple Inc.
2 |
3 | """
4 | Code to preprocess the WikiSQL dataset adapted from
5 | https://github.com/salesforce/WikiSQL and
6 | https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
7 | """
8 |
9 |
10 | import json
11 | import os
12 |
13 |
def load():
    """Return the train, dev, and test splits of WikiSQL, lazily constructed."""
    return (WikiSQL(split) for split in ("train", "dev", "test"))
19 |
20 |
class WikiSQL:
    """One split of the WikiSQL dataset rendered as plain-text examples.

    Each example is the table description followed by the natural-language
    question and the gold SQL answer, wrapped in ``<s>``/``</s>`` sentinels.
    """

    def __init__(self, dataset: str, save_dir: str = "/tmp"):
        """Download (if needed) and parse the given split.

        Args:
            dataset: One of ``"train"``, ``"dev"``, or ``"test"``.
            save_dir: Directory under which the raw data is cached.

        Raises:
            ValueError: If `dataset` is not a valid split name.
        """
        valid_sets = ("train", "dev", "test")
        if dataset not in valid_sets:
            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
        data_dir = os.path.join(save_dir, "wikisql")
        self._maybe_download(data_dir)

        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))

    def _maybe_download(self, data_dir):
        # Fetch and unpack the raw WikiSQL archive on first use.
        if not os.path.exists(data_dir):
            import io
            import tarfile
            from urllib import request

            url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
            r = request.urlopen(url)
            # NOTE(review): extractall trusts the archive's member paths;
            # consider tarfile's `filter="data"` (Python 3.12+) for safety.
            with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
                tf.extractall(data_dir)

    def _parse_tables(self, tables):
        # Index table metadata by table id; "desc" is the human-readable
        # header prepended to every query on that table.
        self._tables = {}
        with open(tables) as f:
            for line in f:
                table = json.loads(line)
                self._tables[table["id"]] = {
                    "columns": table["header"],
                    "types": table["types"],
                    "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
                }

    def _parse_queries(self, queries):
        # Render each query as "<s>{table desc}\nQ: {question}\nA: {sql}</s>".
        # Bug fix: the <s>/</s> sentinels were missing here even though the
        # __main__ writer strips exactly those markers with t[3:-4]; restored
        # to match the upstream mlx-examples preprocessing.
        self._queries = []
        with open(queries) as f:
            for line in f:
                query = json.loads(line)
                table = self._tables[query["table_id"]]
                question = query["question"]
                answer = self.query_to_text(
                    query["sql"], query["table_id"], table["columns"], table["types"]
                )
                self._queries.append(
                    f"<s>{table['desc']}\nQ: {question}\nA: {answer}</s>"
                )

    def query_to_text(self, query, table, columns, types):
        """Convert a structured WikiSQL query dict into a SQL string.

        Args:
            query: Dict with "sel" (selected column index), "agg" (index into
                the aggregation ops) and "conds" ([column, op, value] triples).
            table: Table identifier used in the FROM clause.
            columns: Column names of the table.
            types: Column types; "text"-typed values are single-quoted.

        Returns:
            The rendered SQL query string.
        """
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]
        column = columns[query["sel"]]
        aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
        sql = f"SELECT {aggregation}{column} FROM {table}"

        conditions = query["conds"]
        if conditions:
            cs = []
            for i, o, v in conditions:
                column = columns[i]
                op = condition_ops[o]

                # Quote textual values; numeric values are emitted as-is.
                if types[i] == "text":
                    value = f"'{v}'"
                else:
                    value = v
                cs.append(f"{column} {op} {value}")

            sql += " WHERE " + " AND ".join(cs)

        return sql

    def __getitem__(self, idx):
        # Rendered text example at `idx`.
        return self._queries[idx]

    def __len__(self):
        return len(self._queries)
97 |
98 |
if __name__ == "__main__":
    # Sanity-check the expected number of examples in each split.
    datanames = ["train", "dev", "test"]
    sizes = [56355, 8421, 15878]
    for dataname, size in zip(datanames, sizes):
        # Bug fix: this check was previously a bare comparison expression
        # (a no-op tuple), so the size validation never actually ran.
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."

    # Write (a truncated subset of) the sets to jsonl for LoRA training.
    import json

    train, dev, test = load()
    datasets = [
        (train, "train", 1000),
        (dev, "valid", 100),
        (test, "test", 100),
    ]
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            # zip(range(size), ...) truncates the split to `size` examples.
            for e, t in zip(range(size), dataset):
                # Strip the <s>, </s> since the tokenizer adds them
                json.dump({"text": t[3:-4]}, fid)
                fid.write("\n")
120 |
--------------------------------------------------------------------------------
/Tools/LinearModelTraining/LinearModelTraining.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import ArgumentParser
4 | import Foundation
5 | import MLX
6 | import MLXNN
7 | import MLXOptimizers
8 |
// Let `DeviceType` be parsed directly as a command line argument value.
// Swift 5.10+ requires `@retroactive` when conforming a type from another
// module to a protocol; older compilers reject the attribute, hence the #if.
#if swift(>=5.10)
    extension MLX.DeviceType: @retroactive ExpressibleByArgument {
        public init?(argument: String) {
            self.init(rawValue: argument)
        }
    }
#else
    extension MLX.DeviceType: ExpressibleByArgument {
        public init?(argument: String) {
            self.init(rawValue: argument)
        }
    }
#endif
22 |
/// Trains a tiny linear model `y = m x + b` to match a known target linear
/// function, demonstrating the MLX loss / gradient / optimizer training loop.
@main
struct Train: AsyncParsableCommand {

    /// Number of training iterations.
    @Option var epochs = 20
    /// Samples per training step.
    @Option var batchSize = 8

    /// Target slope the model should learn.
    @Option var m: Float = 0.25
    /// Target intercept the model should learn.
    @Option var b: Float = 7

    /// If set, compile the training step with `MLX.compile`.
    @Flag var compile = false

    /// Device to run on.
    @Option var device = DeviceType.cpu

    func run() async throws {
        Device.setDefault(device: Device(device))

        // A very simple model that implements the equation
        // for a linear function: y = mx + b. This can be trained
        // to match data – in this case, an unknown (to the model)
        // linear function.
        //
        // This is a nice example because most people know how
        // linear functions work and we can see how the slope
        // and intercept converge.
        class LinearFunctionModel: Module, UnaryLayer {
            // Learnable parameters, randomly initialized.
            let m = MLXRandom.uniform(low: -5.0, high: 5.0)
            let b = MLXRandom.uniform(low: -5.0, high: 5.0)

            func callAsFunction(_ x: MLXArray) -> MLXArray {
                m * x + b
            }
        }

        // Measure the distance from the prediction (model(x)) and the
        // ground truth (y). This gives feedback on how close the
        // prediction is from matching the truth.
        func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray {
            mseLoss(predictions: model(x), targets: y, reduction: .mean)
        }

        let model = LinearFunctionModel()
        eval(model.parameters())

        // Produces both the loss and the gradients w.r.t. the model parameters.
        let lg = valueAndGrad(model: model, loss)

        // The optimizer will use the gradients to update the model parameters
        let optimizer = SGD(learningRate: 1e-1)

        // The function to train our model against. It doesn't have
        // to be linear, but matching what the model models is easy
        // to understand.
        func f(_ x: MLXArray) -> MLXArray {
            // These are the target parameters
            let m = self.m
            let b = self.b

            // Our actual function
            return m * x + b
        }

        // One optimization step: compute loss + grads, apply the update.
        func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
            let (loss, grads) = lg(model, x, y)
            optimizer.update(model: model, gradients: grads)
            return loss
        }

        // Optionally compile the step; inputs/outputs declare the mutated state.
        let resolvedStep =
            self.compile
            ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step

        for _ in 0 ..< epochs {
            // We expect that the parameters will approach the targets
            print("target: b = \(b), m = \(m)")
            print("parameters: \(model.parameters())")

            // Generate random training data along with the ground truth.
            // Notice that the shape is [B, 1] where B is the batch
            // dimension. This allows us to train on several samples simultaneously.
            //
            // Note: A very large batch size will take longer to converge because
            // the gradient will be representing too many samples down into
            // a single float parameter.
            let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1])
            let y = f(x)
            eval(x, y)

            // Compute the loss and gradients. Use the optimizer
            // to adjust the parameters closer to the target.
            let loss = resolvedStep(x, y)

            eval(model, optimizer)

            // We should see this converge toward 0
            print("loss: \(loss)")
        }

    }
}
121 |
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Tokenizer.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 |
3 | import Foundation
4 |
/// An ordered pair of BPE tokens, used as a key into the merge-rank table.
struct Bigram: Hashable {
    let a: String
    let b: String

    /// Parses a merge-file line of the form "tokenA tokenB".
    init(_ s: String) {
        let pieces = s.split(separator: " ")
        // Fixed message: it previously referred to "BPEPair", a stale name
        // for this type.
        precondition(pieces.count == 2, "Bigram expected two pieces for '\(s)'")
        self.a = String(pieces[0])
        self.b = String(pieces[1])
    }

    init(_ a: String, _ b: String) {
        self.a = a
        self.b = b
    }

    /// Convenience for building from a tuple, e.g. the output of `zip(...)`.
    init(_ v: (String, String)) {
        self.a = v.0
        self.b = v.1
    }
}
26 |
27 | /// A CLIP tokenizer.
28 | ///
29 | /// Ported from:
30 | ///
31 | /// - https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
32 | /// - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
33 | ///
34 | /// Ideally this would be a tokenizer from `swift-transformers` but this is too special purpose to be representable in
35 | /// what exists there (at time of writing).
class CLIPTokenizer {

    // Splits text into special-marker / contraction / word / digit /
    // punctuation tokens prior to BPE.
    let pattern =
        #/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/#
    // Merge priority per bigram: lower rank merges earlier.
    let bpeRanks: [Bigram: Int]
    // Token string -> token id.
    let vocabulary: [String: Int]

    let bos = "<|startoftext|>"
    let eos = "<|endoftext|>"

    let bosToken: Int
    let eosToken: Int

    // Memoizes bpe(text:) results; pre-seeded so the special tokens never split.
    var cache = [String: [String]]()

    /// - Parameters:
    ///   - merges: Lines of the BPE merge file, ordered by merge priority.
    ///   - vocabulary: Token-to-id map; must contain the bos/eos markers.
    init(merges: [String], vocabulary: [String: Int]) {
        self.bpeRanks = Dictionary(
            uniqueKeysWithValues:
                merges
                .map { Bigram($0) }
                .enumerated()
                .map { ($0.element, $0.offset) })

        self.vocabulary = vocabulary
        self.cache[bos] = [bos]
        self.cache[eos] = [eos]
        self.bosToken = vocabulary[bos]!
        self.eosToken = vocabulary[eos]!
    }

    /// Splits one pre-tokenized piece of text into BPE sub-tokens by
    /// repeatedly merging the best-ranked adjacent pair.
    func bpe(text: String) -> [String] {
        if let result = cache[text] {
            return result
        }

        precondition(!text.isEmpty)

        // Start from individual characters.
        var unigrams = text.dropLast().map { String($0) } + ["\(text.last!)"]
        var uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })

        // In every iteration try to merge the two most likely bigrams. If none
        // was merged we are done

        while !uniqueBigrams.isEmpty {
            // Best-ranked bigram present in the current segmentation.
            let (bigram, _) =
                uniqueBigrams
                .map { ($0, bpeRanks[$0] ?? Int.max) }
                .min { $0.1 < $1.1 }!

            // No remaining pair is in the merge table: segmentation is final.
            if bpeRanks[bigram] == nil {
                break
            }

            var newUnigrams = [String]()
            var skip = false

            for (a, b) in zip(unigrams, unigrams.dropFirst()) {
                // `skip` is set when the previous pair merged, meaning `a`
                // was already consumed as the second half of that merge.
                if skip {
                    skip = false
                    continue
                }

                if Bigram(a, b) == bigram {
                    newUnigrams.append(a + b)
                    skip = true
                } else {
                    newUnigrams.append(a)
                }
            }

            // The zip never visits the final unigram as `a`; append it unless
            // it was just consumed by a merge of the last pair.
            if !skip, let last = unigrams.last {
                newUnigrams.append(last)
            }

            unigrams = newUnigrams
            uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })
        }

        cache[text] = unigrams

        return unigrams
    }

    /// Tokenizes `text` into token ids, wrapped in bos/eos markers.
    public func tokenize(text: String) -> [Int32] {
        // Lower case cleanup and split according to self.pat. Hugging Face does
        // a much more thorough job here but this should suffice for 95% of
        // cases.

        let clean = text.lowercased().replacing(#/\s+/#, with: " ")
        let tokens = clean.matches(of: pattern).map { $0.description }

        // Split the tokens according to the byte-pair merge file
        let bpeTokens = tokens.flatMap { bpe(text: String($0)) }

        // Map to token ids and return. Sub-tokens missing from the vocabulary
        // are dropped by the compactMap.
        let result = [bosToken] + bpeTokens.compactMap { vocabulary[$0] } + [eosToken]

        return result.map { Int32($0) }
    }
}
136 |
--------------------------------------------------------------------------------
/Applications/MLXChatExample/README.md:
--------------------------------------------------------------------------------
1 | # MLX Chat Example
2 |
3 | A lightweight chat application demonstrating MLX integration for iOS and macOS. Built with SwiftUI, this example project shows how to implement both Large Language Models (LLMs) and Vision Language Models (VLMs) using MLX.
4 |
5 |
6 |
7 | ## Features
8 |
9 | - 🤖 LLM and VLM support with real-time text generation
10 | - 📱 Cross-platform (iOS and macOS)
11 | - 🖼️ Image and video input for vision models
12 | - 💾 Efficient model caching and memory management
13 | - ⚡️ Async/await based generation with cancellation support
14 | - 🎨 Modern SwiftUI interface
15 | - 📝 Comprehensive documentation and comments
16 |
17 | ## Requirements
18 |
19 | - iOS 17.0+ / macOS 14.0+
20 | - Xcode 15.0+
21 | - Swift 5.9+
22 |
23 | ## Dependencies
24 |
25 | - [MLX](https://github.com/ml-explore/mlx-swift): Core machine learning operations
26 | - [MLXLMCommon](https://github.com/ml-explore/mlx-swift-lm/tree/main/Libraries/MLXLMCommon): Common language model utilities
27 | - [MLXLLM](https://github.com/ml-explore/mlx-swift-lm/tree/main/Libraries/MLXLLM): Large language model support
28 | - [MLXVLM](https://github.com/ml-explore/mlx-swift-lm/tree/main/Libraries/MLXVLM): Vision-language model support
29 |
30 | ## Project Structure
31 |
32 | ```
33 | MLXChatExample/
34 | ├── Views/ # SwiftUI views
35 | ├── Models/ # Data models
36 | ├── ViewModels/ # Business logic
37 | ├── Services/ # MLX integration
38 | └── Support/ # Utilities
39 | ```
40 |
41 | ## Technical Overview
42 |
43 | The project follows MVVM architecture with clear separation between UI and business logic. The core functionality is split into two main components:
44 |
45 | ### MLXService
46 |
47 | Core service handling all model operations:
48 | - Model loading and caching with memory management
49 | - Async text generation with streaming support
50 | - GPU memory optimization
51 | - Model state management
52 | - Handles both LLM and VLM model types
53 |
54 | ### ChatViewModel
55 |
56 | Business logic coordinator:
57 | - Manages chat state and message history
58 | - Handles generation lifecycle and cancellation
59 | - Coordinates media attachments for vision models
60 | - Performance metrics and error handling
61 | - Provides clean interface between UI and ML service
62 |
63 | ### Architecture Highlights
64 |
65 | - Complete separation of UI and business logic
66 | - SwiftUI views with async/await integration
67 | - Modular design for easy extension
68 |
69 | ### Documentation
70 |
71 | The codebase is thoroughly documented with:
72 | - Detailed class and method documentation
73 | - Clear inline comments explaining complex logic
74 | - DocC documentation format
75 |
76 | ### Markdown Support
77 |
78 | This sample app renders markdown content using SwiftUI's native `Text` view by passing the content as a `LocalizedStringKey`:
79 |
80 | ```swift
81 | Text(LocalizedStringKey(message.content))
82 | ```
83 |
84 | #### Limitations and Alternatives
85 |
86 | The default SwiftUI markdown rendering only supports standard markdown syntax. It does not support advanced features like tables and task lists that are available in GitHub Flavored Markdown (GFM).
87 |
88 | For more comprehensive markdown support:
89 |
90 | - **GitHub Flavored Markdown**: Consider using the [swift-markdown-ui](https://github.com/gonzalezreal/swift-markdown-ui) library. However, be aware that this library currently has an [unresolved issue with text selection](https://github.com/gonzalezreal/swift-markdown-ui/issues/264), which is why it wasn't used in this example.
91 |
- **Enhanced Text Selection**: If you're satisfied with standard markdown but want better text selection capabilities on iOS (instead of only being able to select and copy the entire content block), consider combining:
93 | - [SelectableText](https://github.com/kevinhermawan/SelectableText) for improved selection functionality
94 | - [MarkdownToAttributedString](https://github.com/madebywindmill/MarkdownToAttributedString) for markdown formatting
95 |
96 | > More discussion on this can be found on [issue #297](https://github.com/ml-explore/mlx-swift-examples/issues/297)
97 |
98 | ## Getting Started
99 |
100 | 1. Clone the repository
101 | 2. Install dependencies
102 | 3. Open in Xcode
103 | 4. Build and run
104 |
105 | ## Contributing
106 |
107 | This is an example project demonstrating MLX capabilities. Feel free to use it as a reference for your own projects.
108 |
109 | ## Acknowledgments
110 |
111 | - MLX team for the core framework
112 |
--------------------------------------------------------------------------------
/Tests/MLXLMTests/UserInputTests.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 | import MLXLMCommon
4 | import MLXVLM
5 | import XCTest
6 |
/// Recursively compares two "plist-like" values (String, [Any], [String: Any])
/// and reports an XCTest failure, including the `path` at which they differ.
func assertEqual(
    _ v1: Any, _ v2: Any, path: [String] = [], file: StaticString = #filePath, line: UInt = #line
) {
    switch (v1, v2) {
    case let (v1, v2) as (String, String):
        XCTAssertEqual(v1, v2, file: file, line: line)

    case let (v1, v2) as ([Any], [Any]):
        XCTAssertEqual(
            v1.count, v2.count, "Arrays not equal size at \(path)", file: file, line: line)

        // Compare element-wise, extending the path with the index.
        for (index, (v1v, v2v)) in zip(v1, v2).enumerated() {
            assertEqual(v1v, v2v, path: path + [index.description], file: file, line: line)
        }

    case let (v1, v2) as ([String: Any], [String: Any]):
        XCTAssertEqual(
            v1.keys.sorted(), v2.keys.sorted(),
            "\(String(describing: v1.keys.sorted())) and \(String(describing: v2.keys.sorted())) not equal at \(path)",
            file: file, line: line)

        // Compare value-wise, extending the path with the key.
        for (k, v1v) in v1 {
            if let v2v = v2[k] {
                assertEqual(v1v, v2v, path: path + [k], file: file, line: line)
            } else {
                XCTFail("Missing value for \(k) at \(path)", file: file, line: line)
            }
        }
    default:
        // Unsupported or mismatched type combination.
        XCTFail(
            "Unable to compare \(String(describing: v1)) and \(String(describing: v2)) at \(path)",
            file: file, line: line)
    }
}
41 |
/// Tests for converting `Chat.Message` arrays into the message dictionaries
/// consumed by model chat templates.
public class UserInputTests: XCTestCase {

    /// The default generator produces flat role/content string pairs.
    public func testStandardConversion() {
        let chat: [Chat.Message] = [
            .system("You are a useful agent."),
            .user("Tell me a story."),
        ]

        let messages = DefaultMessageGenerator().generate(messages: chat)

        let expected = [
            [
                "role": "system",
                "content": "You are a useful agent.",
            ],
            [
                "role": "user",
                "content": "Tell me a story.",
            ],
        ]

        XCTAssertEqual(expected, messages as? [[String: String]])
    }

    /// The Qwen2-VL generator wraps each message's text in a typed
    /// content-fragment array.
    public func testQwen2ConversionText() {
        let chat: [Chat.Message] = [
            .system("You are a useful agent."),
            .user("Tell me a story."),
        ]

        let messages = Qwen2VLMessageGenerator().generate(messages: chat)

        let expected = [
            [
                "role": "system",
                "content": [
                    [
                        "type": "text",
                        "text": "You are a useful agent.",
                    ]
                ],
            ],
            [
                "role": "user",
                "content": [
                    [
                        "type": "text",
                        "text": "Tell me a story.",
                    ]
                ],
            ],
        ]

        assertEqual(expected, messages)
    }

    /// Attached images become additional "image" fragments on the user
    /// message, and are also surfaced through `UserInput.images`.
    public func testQwen2ConversionImage() {
        let chat: [Chat.Message] = [
            .system("You are a useful agent."),
            .user(
                "What is this?",
                images: [
                    .url(
                        URL(
                            string: "https://opensource.apple.com/images/projects/mlx.f5c59d8b.png")!
                    )
                ]),
        ]

        let messages = Qwen2VLMessageGenerator().generate(messages: chat)

        let expected = [
            [
                "role": "system",
                "content": [
                    [
                        "type": "text",
                        "text": "You are a useful agent.",
                    ]
                ],
            ],
            [
                "role": "user",
                "content": [
                    [
                        "type": "text",
                        "text": "What is this?",
                    ],
                    [
                        "type": "image"
                    ],
                ],
            ],
        ]

        assertEqual(expected, messages)

        let userInput = UserInput(chat: chat)
        XCTAssertEqual(userInput.images.count, 1)
    }

}
144 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLX Swift Examples
2 |
3 | Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs. The language model
4 | examples use models implemented in [MLX Swift LM](https://github.com/ml-explore/mlx-swift-lm).
5 |
6 | - [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
7 | both iOS and macOS that downloads MNIST training data and trains a
8 | [LeNet](https://en.wikipedia.org/wiki/LeNet).
9 |
10 | - [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
11 | and macOS that downloads an LLM and tokenizer from Hugging Face and
12 | generates text from a given prompt.
13 |
- [VLMEval](Applications/VLMEval/README.md): An example that runs on iOS, macOS and visionOS that downloads a VLM and tokenizer from Hugging Face and
  analyzes a given image, describing it in text.
16 |
17 | - [MLXChatExample](Applications/MLXChatExample/README.md): An example chat app that runs on both iOS and macOS that supports LLMs and VLMs.
18 |
19 | - [LoRATrainingExample](Applications/LoRATrainingExample/README.md): An example that runs on macOS that downloads an LLM and fine-tunes it using LoRA (Low-Rank Adaptation) with training data.
20 |
21 | - [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
22 | trains a simple linear model.
23 |
- [StableDiffusionExample](Applications/StableDiffusionExample/README.md): An
  example that runs on both iOS and macOS that downloads a stable diffusion model
  from Hugging Face and generates an image from a given prompt.
27 |
28 | - [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
29 | using a variety of LLMs available on the Hugging Face hub.
30 |
31 | - [ExampleLLM](Tools/ExampleLLM/README.md): A command line tool using the simplified API to interact with LLMs.
32 |
33 | - [image-tool](Tools/image-tool/README.md): A command line tool for generating images
34 | using a stable diffusion model from Hugging Face.
35 |
- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training
  a LeNet on MNIST.
38 |
39 | > [!IMPORTANT]
40 | > `MLXLMCommon`, `MLXLLM`, `MLXVLM` and `MLXEmbedders` have moved to a new repository
41 | > containing _only_ reusable libraries: [mlx-swift-lm](https://github.com/ml-explore/mlx-swift-lm).
42 |
43 | Previous URLs and tags will continue to work, but going forward all updates to these
44 | libraries will be done in the other repository. Previous tags _are_ supported in
45 | the new repository.
46 |
47 | > [!TIP]
48 | > Contributors that wish to edit both `mlx-swift-examples` and `mlx-swift-lm` can
49 | > use [this technique in Xcode](https://developer.apple.com/documentation/xcode/editing-a-package-dependency-as-a-local-package).
50 |
51 |
52 | # Reusable Libraries
53 |
54 | LLM and VLM implementations are available in [MLX Swift LM](https://github.com/ml-explore/mlx-swift-lm):
55 |
56 | - [MLXLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon) -- common API for LLM and VLM
57 | - [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxllm) -- large language model example implementations
58 | - [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxvlm) -- vision language model example implementations
59 | - [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxembedders) -- popular Encoders / Embedding models example implementations
60 |
61 | MLX Swift Examples also contains a few reusable libraries that can be imported with this code in your `Package.swift` or by referencing the URL in Xcode:
62 |
63 | ```swift
64 | .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main"),
65 | ```
66 |
67 | Then add one or more libraries to the target as a dependency:
68 |
69 | ```swift
70 | .target(
71 | name: "YourTargetName",
72 | dependencies: [
73 | .product(name: "StableDiffusion", package: "mlx-libraries")
74 | ]),
75 | ```
76 |
77 | - [StableDiffusion](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/stablediffusion) -- SDXL Turbo and Stable Diffusion model example implementations
78 | - [MLXMNIST](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxmnist) -- MNIST implementation for all your digit recognition needs
79 |
80 | ## Running
81 |
82 | The application and command line tool examples can be run from Xcode or from
83 | the command line:
84 |
85 | ```
86 | ./mlx-run llm-tool --prompt "swift programming language"
87 | ```
88 |
89 | Note: `mlx-run` is a shell script that uses Xcode command line tools to
90 | locate the built binaries. It is equivalent to running from Xcode itself.
91 |
92 | See also:
93 |
94 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting)
95 |
--------------------------------------------------------------------------------