├── Tests ├── MLXLMTests │ ├── README.md │ ├── StreamlinedTests.swift │ ├── BaseConfigurationTests.swift │ ├── ToolTests.swift │ └── UserInputTests.swift └── mlx-libraries-Package.xctestplan ├── support ├── test.jpg └── generate-run-all-llms.sh ├── Applications ├── VLMEval │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── VLMEvalApp.swift │ ├── VLMEval.entitlements │ └── README.md ├── LLMEval │ ├── AssetCatalog.xcassets │ │ ├── Contents.json │ │ └── AccentColor.colorset │ │ │ └── Contents.json │ ├── AppIcon.icon │ │ ├── Assets │ │ │ └── bubbles3.png │ │ └── icon.json │ ├── LLMEvalApp.swift │ ├── LLMEval.entitlements │ ├── Models │ │ ├── ToolDefinitions.swift │ │ └── PresetPrompts.swift │ ├── ViewModels │ │ └── DeviceStat.swift │ ├── Services │ │ ├── FormatUtilities.swift │ │ └── ToolExecutor.swift │ ├── Views │ │ ├── MetricCard.swift │ │ ├── OutputView.swift │ │ ├── LoadingOverlayView.swift │ │ ├── PromptInputView.swift │ │ ├── HeaderView.swift │ │ ├── MetricsView.swift │ │ ├── ContentView.swift │ │ └── PresetPromptsSheet.swift │ └── README.md ├── MNISTTrainer │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── MNISTTrainer-Info.plist │ ├── MNISTTrainerApp.swift │ ├── MNISTTrainer.entitlements │ ├── README.md │ └── PredictionView.swift ├── LoRATrainingExample │ ├── Assets.xcassets │ │ ├── Contents.json │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ └── AppIcon.appiconset │ │ │ └── Contents.json │ ├── Preview Content │ │ └── Preview Assets.xcassets │ │ │ └── Contents.json │ ├── LoRATrainingExampleApp.swift │ ├── LoRATrainingExample.entitlements │ └── README.md ├── MLXChatExample │ ├── Support │ │ ├── Assets.xcassets │ │ │ ├── Contents.json │ │ │ ├── AccentColor.colorset │ │ │ │ └── Contents.json │ │ │ └── AppIcon.appiconset │ │ │ │ └── Contents.json │ │ ├── Preview Content │ │ │ └── Preview Assets.xcassets │ │ │ │ └── Contents.json │ │ ├── HubApi+default.swift │ │ └── SampleData.swift │ ├── MLXChatExampleApp.swift │ ├── Views │ │ ├── Toolbar │ │ │ ├── GenerationInfoView.swift │ │ │ ├── ErrorView.swift │ │ │ ├── DownloadProgressView.swift │ │ │ └── ChatToolbarView.swift │ │ ├── ConversationView.swift │ │ ├── PromptField.swift │ │ ├── MessageView.swift │ │ └── MediaPreviewView.swift │ ├── MLXChatExample.entitlements │ ├── Models │ │ ├── LMModel.swift │ │ └── Message.swift │ ├── ChatView.swift │ └── README.md └── StableDiffusionExample │ ├── Assets.xcassets │ ├── Contents.json │ ├── AccentColor.colorset │ │ └── Contents.json │ └── AppIcon.appiconset │ │ └── Contents.json │ ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json │ ├── StableDiffusionExampleApp.swift │ ├── StableDiffusionExample.entitlements │ └── README.md ├── .spi.yml ├── .swift-format ├── mlx-swift-examples.xcodeproj ├── project.xcworkspace │ ├── contents.xcworkspacedata │ └── xcshareddata │ │ ├── IDEWorkspaceChecks.plist │ │ └── swiftpm │ │ └── Package.resolved └── xcshareddata │ └── xcschemes │ ├── LLMEval.xcscheme │ ├── VLMEval.xcscheme │ ├── ExampleLLM.xcscheme │ ├── StableDiffusionExample.xcscheme │ └── embedder-tool.xcscheme ├── Tools ├── embedder-tool │ ├── CommandError.swift │ ├── ArgumentSupport.swift │ ├── Diagnostics.swift │ ├── 
ListCommand.swift │ ├── CorpusArguments.swift │ ├── PoolingArguments.swift │ ├── EmbedderCommand.swift │ ├── MemoryArguments.swift │ ├── DemoCommand.swift │ ├── ModelArguments.swift │ ├── EmbedderRuntime+Embedding.swift │ ├── PoolingSupport.swift │ └── VectorOperations.swift ├── LinearModelTraining │ ├── README.md │ └── LinearModelTraining.swift ├── llm-tool │ ├── Arguments.swift │ └── ListCommands.swift ├── mnist-tool │ ├── README.md │ └── MNISTTool.swift ├── ExampleLLM │ ├── main.swift │ └── README.md ├── image-tool │ └── Arguments.swift └── Tutorial │ └── Tutorial.swift ├── .pre-commit-config.yaml ├── Configuration └── Build.xcconfig ├── Libraries ├── MLXMNIST │ ├── README.md │ ├── Random.swift │ ├── MNIST.swift │ └── Files.swift └── StableDiffusion │ ├── README.md │ ├── Sampler.swift │ └── Tokenizer.swift ├── .github ├── pull_request_template.md ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ └── pull_request.yml ├── ACKNOWLEDGMENTS.md ├── LICENSE ├── CONTRIBUTING.md ├── mlx-run ├── Package.resolved ├── .gitignore ├── Package.swift ├── Data └── lora │ └── wikisql.py └── README.md /Tests/MLXLMTests/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /support/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-swift-examples/HEAD/support/test.jpg -------------------------------------------------------------------------------- /Applications/VLMEval/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/LLMEval/AssetCatalog.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Support/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.spi.yml: -------------------------------------------------------------------------------- 1 | version: 1 2 | builder: 3 | configs: 4 | - documentation_targets: 
[MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST, MLXEmbedders, StableDiffusion] 5 | -------------------------------------------------------------------------------- /.swift-format: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "indentation": { 4 | "spaces": 4 5 | }, 6 | "spacesAroundRangeFormationOperators": true, 7 | } 8 | -------------------------------------------------------------------------------- /Applications/VLMEval/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/LLMEval/AppIcon.icon/Assets/bubbles3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-explore/mlx-swift-examples/HEAD/Applications/LLMEval/AppIcon.icon/Assets/bubbles3.png -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Support/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/MNISTTrainer-Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Applications/VLMEval/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | 
"info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Support/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/MNISTTrainerApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct MNISTTrainerApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Tools/embedder-tool/CommandError.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | struct CommandError: LocalizedError { 4 | let message: String 5 | 6 | init(_ message: String) { 7 | self.message = message 8 | } 9 | 10 | var errorDescription: String? { message } 11 | } 12 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/LoRATrainingExampleApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct LoRATrainingExampleApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: local 4 | hooks: 5 | - id: swift-format 6 | name: swift-format 7 | language: system 8 | entry: swift-format format --in-place --configuration .swift-format --recursive . 9 | require_serial: true 10 | types: [swift] 11 | -------------------------------------------------------------------------------- /Applications/LLMEval/LLMEvalApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct LLMEvalApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | .environment(DeviceStat()) 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/StableDiffusionExampleApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct StableDiffusionExampleApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Applications/VLMEval/VLMEvalApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | @main 6 | struct VLMEvalApp: App { 7 | var body: some Scene { 8 | WindowGroup { 9 | ContentView() 10 | .environment(DeviceStat()) 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /Tools/LinearModelTraining/README.md: -------------------------------------------------------------------------------- 1 | # LinearModelTraining 2 | 3 | A command line tool that creates a Model that represents: 4 | 5 | f(x) = mx + b 6 | 7 | and trains it against an unknown linear function. Very 8 | simple but illustrates: 9 | 10 | - a very simple model with parameters 11 | - a loss function 12 | - the gradient 13 | - use of an optimizer 14 | - the training loop 15 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/MLXChatExampleApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MLXChatExampleApp.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 
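// App entry point: presents ChatView driven by a ChatViewModel backed by MLXService.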
6 | // 7 | 8 | import SwiftUI 9 | 10 | @main 11 | struct MLXChatExampleApp: App { 12 | var body: some Scene { 13 | WindowGroup { 14 | ChatView(viewModel: ChatViewModel(mlxService: MLXService())) 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/MNISTTrainer.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.security.app-sandbox 6 | 7 | com.apple.security.files.user-selected.read-only 8 | 9 | com.apple.security.network.client 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /Applications/LLMEval/AssetCatalog.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "color-space" : "srgb", 6 | "components" : { 7 | "alpha" : "1.000", 8 | "blue" : "1.000", 9 | "green" : "0.689", 10 | "red" : "0.307" 11 | } 12 | }, 13 | "idiom" : "universal" 14 | } 15 | ], 16 | "info" : { 17 | "author" : "xcode", 18 | "version" : 1 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /Tools/embedder-tool/ArgumentSupport.swift: -------------------------------------------------------------------------------- 1 | import ArgumentParser 2 | import Foundation 3 | 4 | extension URL: @retroactive ExpressibleByArgument { 5 | public init?(argument: String) { 6 | let expanded = NSString(string: argument).expandingTildeInPath 7 | 8 | if argument.contains("://"), let remote = URL(string: argument), remote.scheme != nil { 9 | self = remote 10 | return 11 | } 12 | 13 | self.init(fileURLWithPath: expanded) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /Configuration/Build.xcconfig: -------------------------------------------------------------------------------- 1 | // The `DISAMBIGUATOR` configuration is to make it easier to build 2 | // and run a sample code project. Once you set your project's development team, 3 | // you'll have a unique bundle identifier. This is because the bundle identifier 4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this 5 | // approach in your own projects—it's only useful for example projects because 6 | // they are frequently downloaded and don't have a development team set. 7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM} 8 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/Toolbar/GenerationInfoView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GenerationInfoView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 
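// Toolbar label showing generation throughput in tokens per second.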
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct GenerationInfoView: View { 11 | let tokensPerSecond: Double 12 | 13 | var body: some View { 14 | Text("\(tokensPerSecond, format: .number.precision(.fractionLength(2))) tokens/s") 15 | } 16 | } 17 | 18 | #Preview { 19 | GenerationInfoView(tokensPerSecond: 58.5834) 20 | } 21 | -------------------------------------------------------------------------------- /Applications/LLMEval/LLMEval.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.files.user-selected.read-only 10 | 11 | com.apple.security.network.client 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /Libraries/MLXMNIST/README.md: -------------------------------------------------------------------------------- 1 | # MNIST 2 | 3 | This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP. 4 | 5 | It provides code to: 6 | 7 | - Download the MNIST test/train data 8 | - Build the LeNet 9 | - Some functions to shuffle and batch the data 10 | 11 | See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there. 12 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/LoRATrainingExample.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.files.user-selected.read-only 10 | 11 | com.apple.security.network.client 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/StableDiffusionExample.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.files.user-selected.read-only 10 | 11 | com.apple.security.network.client 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /Tests/mlx-libraries-Package.xctestplan: -------------------------------------------------------------------------------- 1 | { 2 | "configurations" : [ 3 | { 4 | "id" : "CCE21325-5E68-419C-8EB1-B868EF7688F7", 5 | "name" : "Test Scheme Action", 6 | "options" : { 7 | 8 | } 9 | } 10 | ], 11 | "defaultOptions" : { 12 | 13 | }, 14 | "testTargets" : [ 15 | { 16 | "target" : { 17 | "containerPath" : "container:mlx-swift-examples.xcodeproj", 18 | "identifier" : "C3208E6D2DB19451006AE6CA", 19 | "name" : "MLXLMTests" 20 | } 21 | } 22 | ], 23 | "version" : 1 24 | } 25 | -------------------------------------------------------------------------------- /Applications/VLMEval/VLMEval.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.device.usb 10 | 11 | com.apple.security.files.user-selected.read-only 12 | 13 | com.apple.security.network.client 14 | 15 | 16 | 17 | 
-------------------------------------------------------------------------------- /support/generate-run-all-llms.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "#!/bin/sh" 4 | echo "# NOTE: GENERATED BY generate-run-all-llms.sh -- DO NOT MODIFY BY HAND" 5 | 6 | ./mlx-run llm-tool list llms | \ 7 | awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s\n", $0}' | \ 8 | awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}' 9 | 10 | ./mlx-run llm-tool list vlms | \ 11 | awk '{printf "./mlx-run llm-tool eval --download ~/Downloads/huggingface --model %s --resize 512 --image support/test.jpg\n", $0}' | \ 12 | awk '{printf "echo\necho ======\necho '\''%s'\''\n%s\n", $0, $0}' 13 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/MLXChatExample.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.files.downloads.read-write 10 | 11 | com.apple.security.files.user-selected.read-only 12 | 13 | com.apple.security.network.client 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Proposed changes 2 | 3 | Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #. 4 | 5 | ## Checklist 6 | 7 | Put an `x` in the boxes that apply. 8 | 9 | - [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document 10 | - [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes 11 | - [ ] I have added tests that prove my fix is effective or that my feature works 12 | - [ ] I have updated the necessary documentation (if needed) 13 | -------------------------------------------------------------------------------- /Tools/embedder-tool/Diagnostics.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | enum DiagnosticKind { 4 | case info 5 | case warning 6 | case error 7 | 8 | fileprivate var prefix: String { 9 | switch self { 10 | case .info: 11 | return "" 12 | case .warning: 13 | return "warning: " 14 | case .error: 15 | return "error: " 16 | } 17 | } 18 | } 19 | 20 | func writeDiagnostic(_ message: String, kind: DiagnosticKind = .info) { 21 | guard let data = (kind.prefix + message + "\n").data(using: .utf8) else { 22 | return 23 | } 24 | FileHandle.standardError.write(data) 25 | } 26 | -------------------------------------------------------------------------------- /Applications/LLMEval/Models/ToolDefinitions.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | 5 | // MARK: - Weather Tool 6 | 7 | struct WeatherInput: Codable { 8 | let location: String 9 | let unit: String? 
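// Optional: a tool call may omit the unit.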
10 | } 11 | 12 | struct WeatherOutput: Codable { 13 | let temperature: Double 14 | let conditions: String 15 | } 16 | 17 | // MARK: - Add Tool 18 | 19 | struct AddInput: Codable { 20 | let first: Int 21 | let second: Int 22 | } 23 | 24 | struct AddOutput: Codable { 25 | let result: Int 26 | } 27 | 28 | // MARK: - Time Tool 29 | 30 | struct EmptyInput: Codable {} 31 | 32 | struct TimeOutput: Codable { 33 | let time: String 34 | } 35 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/README.md: -------------------------------------------------------------------------------- 1 | # MNISTTrainer 2 | 3 | This is an example of model training that works on both macOS and iOS. 4 | The example will download the MNIST training data, create a LeNet, and train 5 | it. It will show the epoch time and test accuracy as it trains. 6 | 7 | You will need to set the Team on the MNISTTrainer target in order to build and 8 | run on iOS. 9 | 10 | Some notes about the setup: 11 | 12 | - This will download test data over the network, so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" capability set in the App Sandbox 13 | - The website it connects to uses http rather than https, so it has an "App Transport Security Settings" entry in the Info.plist 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report about an issue you've encountered 4 | title: "[BUG] " 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | 15 | Include a code snippet 16 | ```swift 17 | 18 | ``` 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Desktop (please complete the following information):** 24 | - OS Version: [e.g. macOS 14.1.2] 25 | - Device: [e.g. iPhone 16, M2 MacBook Pro] 26 | - Version: [e.g. 0.29.1] 27 | 28 | **Additional context** 29 | Add any other context about the problem here. 30 | -------------------------------------------------------------------------------- /ACKNOWLEDGMENTS.md: -------------------------------------------------------------------------------- 1 | # Individual Contributors 2 | 3 | > If you wish to be acknowledged for your contributions, please list your name 4 | > with a short description of your contribution(s) below. For example: 5 | > - Jane Smith: Added the `foo` and `bar` ops. 6 | 7 | MLX Swift was developed with contributions from the following individuals: 8 | 9 | - [John Mai](https://github.com/johnmai-dev): Added support for multiple models (Qwen2, Starcoder2, InternLM2, Qwen3, Qwen3 MoE, GLM-4, MiMo, BitNet, SmolLM3, LFM2, Baichuan-M1). 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/Toolbar/ErrorView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ErrorView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 
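// Toolbar button that surfaces the most recent error message in a popover.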
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ErrorView: View { 11 | let errorMessage: String 12 | 13 | @State private var isShowingError = false 14 | 15 | var body: some View { 16 | Button { 17 | isShowingError = true 18 | } label: { 19 | Image(systemName: "exclamationmark.triangle") 20 | .foregroundStyle(.red) 21 | } 22 | .popover(isPresented: $isShowingError, arrowEdge: .bottom) { 23 | Text(errorMessage) 24 | .padding() 25 | } 26 | } 27 | } 28 | 29 | #Preview { 30 | ErrorView(errorMessage: "Something went wrong!") 31 | } 32 | -------------------------------------------------------------------------------- /Applications/VLMEval/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "appearances" : [ 10 | { 11 | "appearance" : "luminosity", 12 | "value" : "dark" 13 | } 14 | ], 15 | "idiom" : "universal", 16 | "platform" : "ios", 17 | "size" : "1024x1024" 18 | }, 19 | { 20 | "appearances" : [ 21 | { 22 | "appearance" : "luminosity", 23 | "value" : "tinted" 24 | } 25 | ], 26 | "idiom" : "universal", 27 | "platform" : "ios", 28 | "size" : "1024x1024" 29 | } 30 | ], 31 | "info" : { 32 | "author" : "xcode", 33 | "version" : 1 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Support/HubApi+default.swift: -------------------------------------------------------------------------------- 1 | // 2 | // HubApi+default.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 6 | // 7 | 8 | import Foundation 9 | @preconcurrency import Hub 10 | 11 | /// Extension providing a default HubApi instance for downloading model files 12 | extension HubApi { 13 | /// Default HubApi instance configured to download models to the user's Downloads directory 14 | /// under a 'huggingface' subdirectory. 15 | #if os(macOS) 16 | static let `default` = HubApi( 17 | downloadBase: URL.downloadsDirectory.appending(path: "huggingface") 18 | ) 19 | #else 20 | static let `default` = HubApi( 21 | downloadBase: URL.cachesDirectory.appending(path: "huggingface") 22 | ) 23 | #endif 24 | } 25 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Support/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "appearances" : [ 10 | { 11 | "appearance" : "luminosity", 12 | "value" : "dark" 13 | } 14 | ], 15 | "idiom" : "universal", 16 | "platform" : "ios", 17 | "size" : "1024x1024" 18 | }, 19 | { 20 | "appearances" : [ 21 | { 22 | "appearance" : "luminosity", 23 | "value" : "tinted" 24 | } 25 | ], 26 | "idiom" : "universal", 27 | "platform" : "ios", 28 | "size" : "1024x1024" 29 | } 30 | ], 31 | "info" : { 32 | "author" : "xcode", 33 | "version" : 1 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /Tools/llm-tool/Arguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | 6 | /// Extension to allow URL command line arguments. 
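/// Arguments containing "://" are parsed as remote URLs; all other values are treated as file paths.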
7 | #if swift(>=5.10) 8 | extension URL: @retroactive ExpressibleByArgument { 9 | public init?(argument: String) { 10 | if argument.contains("://") { 11 | self.init(string: argument) 12 | } else { 13 | self.init(filePath: argument) 14 | } 15 | } 16 | } 17 | #else 18 | extension URL: ExpressibleByArgument { 19 | public init?(argument: String) { 20 | if argument.contains("://") { 21 | self.init(string: argument) 22 | } else { 23 | self.init(filePath: argument) 24 | } 25 | } 26 | } 27 | #endif 28 | -------------------------------------------------------------------------------- /Applications/LLMEval/ViewModels/DeviceStat.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | import MLX 5 | 6 | @Observable 7 | final class DeviceStat: @unchecked Sendable { 8 | 9 | @MainActor 10 | var gpuUsage = GPU.snapshot() 11 | 12 | private let initialGPUSnapshot = GPU.snapshot() 13 | private var timer: Timer? 14 | 15 | init() { 16 | timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in 17 | self?.updateGPUUsages() 18 | } 19 | } 20 | 21 | deinit { 22 | timer?.invalidate() 23 | } 24 | 25 | private func updateGPUUsages() { 26 | let gpuSnapshotDelta = initialGPUSnapshot.delta(GPU.snapshot()) 27 | DispatchQueue.main.async { [weak self] in 28 | self?.gpuUsage = gpuSnapshotDelta 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /Applications/LLMEval/Services/FormatUtilities.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | 5 | /// Utility functions for formatting values 6 | enum FormatUtilities { 7 | 8 | /// Formats a byte count into a human-readable string with appropriate units 9 | /// - Parameter bytes: The number of bytes to format 10 | /// - Returns: A formatted string (e.g., "2.5 GB", "128 MB", "512 KB") 11 | static func formatMemory(_ bytes: Int) -> String { 12 | let kb = Double(bytes) / 1024 13 | let mb = kb / 1024 14 | let gb = mb / 1024 15 | 16 | if gb >= 1 { 17 | return String(format: "%.2f GB", gb) 18 | } else if mb >= 1 { 19 | return String(format: "%.0f MB", mb) 20 | } else if kb >= 1 { 21 | return String(format: "%.0f KB", kb) 22 | } else { 23 | return "0 KB" 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/MetricCard.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
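// A compact rounded card showing an icon, a caption-style title, and a monospaced metric value.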
2 | 3 | import SwiftUI 4 | 5 | struct MetricCard: View { 6 | let icon: String 7 | let title: String 8 | let value: String 9 | 10 | var body: some View { 11 | VStack(spacing: 8) { 12 | HStack(spacing: 4) { 13 | Image(systemName: icon) 14 | .font(.caption) 15 | .foregroundStyle(.secondary) 16 | Text(title) 17 | .font(.caption) 18 | .foregroundStyle(.secondary) 19 | } 20 | Text(value) 21 | .font(.title3) 22 | .fontWeight(.semibold) 23 | .monospacedDigit() 24 | } 25 | .frame(maxWidth: .infinity) 26 | .padding(.vertical, 12) 27 | .padding(.horizontal, 8) 28 | .background(Color.gray.opacity(0.1)) 29 | .cornerRadius(8) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/README.md: -------------------------------------------------------------------------------- 1 | # LoRATrainingExample 2 | 3 | Example application that: 4 | 5 | - downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from huggingface 6 | - loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded) 7 | - adds LoRA adapters and trains the model 8 | - lets you evaluate a prompt against the model 9 | 10 | This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and 11 | you can read more about LoRA there. 12 | 13 | This evaluates the LoRA-adapted model rather than a fused model. It doesn't persist 14 | the LoRA weights or the fused model -- it will retrain the adapters each time the program is launched. 15 | 16 | ### Troubleshooting 17 | 18 | The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of 19 | memory to load and train -- this may require ~6G of physical RAM. 20 | 21 | 22 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/ConversationView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ConversationView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 6 | // 7 | 8 | import SwiftUI 9 | 10 | /// Displays the chat conversation as a scrollable list of messages. 
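/// New content keeps the list pinned to the bottom via defaultScrollAnchor(.bottom, for: .sizeChanges).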
11 | struct ConversationView: View { 12 | /// Array of messages to display in the conversation 13 | let messages: [Message] 14 | 15 | var body: some View { 16 | ScrollView { 17 | LazyVStack(spacing: 12) { 18 | ForEach(messages) { message in 19 | MessageView(message) 20 | .padding(.horizontal, 12) 21 | } 22 | } 23 | } 24 | .padding(.vertical, 8) 25 | .defaultScrollAnchor(.bottom, for: .sizeChanges) 26 | } 27 | } 28 | 29 | #Preview { 30 | // Display sample conversation in preview 31 | ConversationView(messages: SampleData.conversation) 32 | } 33 | -------------------------------------------------------------------------------- /Applications/LLMEval/AppIcon.icon/icon.json: -------------------------------------------------------------------------------- 1 | { 2 | "fill" : { 3 | "automatic-gradient" : "srgb:0.66275,0.79216,0.87451,1.00000", 4 | "orientation" : { 5 | "start" : { 6 | "x" : 0.5, 7 | "y" : 0 8 | }, 9 | "stop" : { 10 | "x" : 0.5, 11 | "y" : 0.7 12 | } 13 | } 14 | }, 15 | "groups" : [ 16 | { 17 | "layers" : [ 18 | { 19 | "image-name" : "bubbles3.png", 20 | "name" : "bubbles3", 21 | "position" : { 22 | "scale" : 1.35, 23 | "translation-in-points" : [ 24 | 0, 25 | 0 26 | ] 27 | } 28 | } 29 | ], 30 | "shadow" : { 31 | "kind" : "neutral", 32 | "opacity" : 0.5 33 | }, 34 | "translucency" : { 35 | "enabled" : true, 36 | "value" : 0.5 37 | } 38 | } 39 | ], 40 | "supported-platforms" : { 41 | "circles" : [ 42 | "watchOS" 43 | ], 44 | "squares" : "shared" 45 | } 46 | } -------------------------------------------------------------------------------- /Tools/embedder-tool/ListCommand.swift: -------------------------------------------------------------------------------- 1 | import ArgumentParser 2 | import Foundation 3 | import MLXEmbedders 4 | 5 | struct ListCommand: AsyncParsableCommand { 6 | static let configuration = CommandConfiguration( 7 | commandName: "list", 8 | abstract: "List registered embedder model configurations" 9 | ) 10 | 11 | @Flag(name: .long, help: "Include models registered from local directories") 12 | var includeDirectories = false 13 | 14 | func run() async throws { 15 | let models = await MainActor.run { Array(ModelConfiguration.models) } 16 | .sorted { $0.name.localizedCaseInsensitiveCompare($1.name) == .orderedAscending } 17 | 18 | for configuration in models { 19 | switch configuration.id { 20 | case .id(let identifier): 21 | print(identifier) 22 | case .directory(let url): 23 | if includeDirectories { 24 | print(url.path) 25 | } 26 | } 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /Tools/mnist-tool/README.md: -------------------------------------------------------------------------------- 1 | # mnist-tool 2 | 3 | See the [MNIST README.md](../../Libraries/MLXMNIST/README.md). 4 | 5 | ### Building 6 | 7 | `mnist-tool` has no dependencies outside of the package dependencies 8 | represented in Xcode. 9 | 10 | When you run the tool it will download the test/train datasets and 11 | store them in a specified directory (see run arguments -- default is /tmp). 12 | 13 | Simply build the project in Xcode. 14 | 15 | ### Running (Xcode) 16 | 17 | To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example: 18 | 19 | ``` 20 | --data /tmp 21 | ``` 22 | 23 | Then cmd-r to run. 
24 | 25 | ### Running (CommandLine) 26 | 27 | Use the `mlx-run` script to run the command line tools: 28 | 29 | ``` 30 | ./mlx-run mnist-tool --data /tmp 31 | ``` 32 | 33 | By default this will find and run the tools built in _Release_ configuration. Specify `--debug` 34 | to find and run the tool built in _Debug_ configuration. 35 | 36 | See also: 37 | 38 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ml-explore 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLX Swift Examples 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | 1. Fork and submit pull requests to the repo. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. Every PR should have passing tests (if any) and at least one review. 11 | 4. For code formatting, install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. 12 | You may also need to `brew install swift-format`. 13 | 14 | You can also run the formatters manually as follows: 15 | 16 | ``` 17 | swift-format format --in-place --recursive Libraries Tools Applications 18 | ``` 19 | 20 | or run `pre-commit run --all-files` to check all files in the repo. 21 | 22 | ## Issues 23 | 24 | We use GitHub issues to track public bugs. Please ensure your description is 25 | clear and has sufficient instructions to be able to reproduce the issue. 26 | 27 | ## License 28 | 29 | By contributing to MLX Swift Examples, you agree that your contributions will be licensed 30 | under the LICENSE file in the root directory of this source tree. 31 | -------------------------------------------------------------------------------- /Libraries/MLXMNIST/Random.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import Foundation 4 | 5 | // From https://github.com/apple/swift/blob/cb0fb1ea051631219c0b944b84c78571448d58c2/benchmark/utils/TestsUtils.swift#L254 6 | // 7 | // This is just a seedable RandomNumberGenerator for shuffle() 8 | 9 | // This is a fixed-increment version of Java 8's SplittableRandom generator. 10 | // It is a very fast generator passing BigCrush, with 64 bits of state. 11 | // See http://dx.doi.org/10.1145/2714064.2660195 and 12 | // http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html 13 | // 14 | // Derived from public domain C implementation by Sebastiano Vigna 15 | // See http://xoshiro.di.unimi.it/splitmix64.c 16 | public struct SplitMix64: RandomNumberGenerator, Sendable { 17 | private var state: UInt64 18 | 19 | public init(seed: UInt64) { 20 | self.state = seed 21 | } 22 | 23 | public mutating func next() -> UInt64 { 24 | self.state &+= 0x9e37_79b9_7f4a_7c15 25 | var z: UInt64 = self.state 26 | z = (z ^ (z &>> 30)) &* 0xbf58_476d_1ce4_e5b9 27 | z = (z ^ (z &>> 27)) &* 0x94d0_49bb_1331_11eb 28 | return z ^ (z &>> 31) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /Tools/embedder-tool/CorpusArguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | 6 | struct CorpusArguments: ParsableArguments { 7 | 8 | @Option(name: .shortAndLong, help: "Directory containing documents to index.") 9 | var directory: URL = URL( 10 | fileURLWithPath: FileManager.default.currentDirectoryPath, isDirectory: true) 11 | 12 | @Option( 13 | name: [.customShort("e"), .long], parsing: .upToNextOption, 14 | help: "File extensions to include (without dots).") 15 | var extensions: [String] = ["txt", "md"] 16 | 17 | @Flag(name: .shortAndLong, help: "Recursively scan subdirectories.") 18 | var recursive = false 19 | 20 | @Option(name: .long, help: "Limit the number of documents to load.") 21 | var limit: Int? 22 | 23 | var normalizedExtensions: [String] { 24 | extensions.map { value in 25 | let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) 26 | let noDot = trimmed.hasPrefix(".") ? String(trimmed.dropFirst()) : trimmed 27 | return noDot.lowercased() 28 | } 29 | } 30 | 31 | var directoryURL: URL { directory.standardizedFileURL } 32 | } 33 | -------------------------------------------------------------------------------- /Tools/ExampleLLM/main.swift: -------------------------------------------------------------------------------- 1 | import MLXLMCommon 2 | 3 | let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") 4 | 5 | let prompt = "What are three things to see in Paris?" 6 | 7 | // MARK: - one-shot print out full response 8 | print( 9 | """ 10 | ================ 11 | \(prompt) 12 | 13 | """) 14 | print(try await ChatSession(model).respond(to: prompt)) 15 | 16 | // MARK: - one-shot streaming output 17 | print( 18 | """ 19 | ================ 20 | \(prompt) 21 | 22 | """) 23 | for try await item in ChatSession(model).streamResponse(to: prompt) { 24 | print(item, terminator: "") 25 | } 26 | print() 27 | 28 | // MARK: - conversation with follow-on questions 29 | let session = ChatSession(model) 30 | 31 | let questions = [ 32 | "What are two things to see in San Francisco?", 33 | "How about a great place to eat?", 34 | "What city are we talking about? 
I forgot!", 35 | ] 36 | 37 | for question in questions { 38 | print( 39 | """ 40 | ================ 41 | \(question) 42 | 43 | """) 44 | 45 | for try await item in session.streamResponse(to: question) { 46 | print(item, terminator: "") 47 | } 48 | print() 49 | } 50 | -------------------------------------------------------------------------------- /mlx-run: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Wrapper to help run command line tools -- this will find the build directory 4 | # and set the DYLD_FRAMEWORK_PATH so that command line tools that link frameworks 5 | # can be run. 6 | # 7 | # Example: 8 | # ./mlx-run --debug llm-tool --help 9 | 10 | if [ "$#" -lt 1 ]; then 11 | echo "usage: mlx-run [--debug/--release] arguments" 12 | exit 1 13 | fi 14 | 15 | CONFIGURATION=Release 16 | if [ "$1" == "--release" ]; then 17 | CONFIGURATION=Release 18 | shift 19 | fi 20 | if [ "$1" == "--debug" ]; then 21 | CONFIGURATION=Debug 22 | shift 23 | fi 24 | if [ "$1" == "--list" ]; then 25 | xcodebuild -list 26 | exit 0 27 | fi 28 | 29 | COMMAND="$1" 30 | shift 31 | 32 | BUILD_DIR=`xcodebuild -configuration $CONFIGURATION -showBuildSettings -scheme $COMMAND | grep 'BUILT_PRODUCTS_DIR = /' | sed -e 's/^[^=]*= //g'` 33 | 34 | if [ -d "$BUILD_DIR/$COMMAND.app" ]; then 35 | exec $BUILD_DIR/$COMMAND.app/Contents/MacOS/$COMMAND "$@" & 36 | fi 37 | 38 | if [ -f "$BUILD_DIR/$COMMAND" ]; then 39 | export DYLD_FRAMEWORK_PATH=$BUILD_DIR/PackageFrameworks:$BUILD_DIR 40 | exec "$BUILD_DIR/$COMMAND" "$@" 41 | else 42 | echo "$BUILD_DIR/$COMMAND does not exist -- check build configuration ($CONFIGURATION)" 43 | exit 1 44 | fi 45 | 46 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/Toolbar/DownloadProgressView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DownloadProgressView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct DownloadProgressView: View { 11 | let progress: Progress 12 | 13 | @State private var isShowingDownload = false 14 | 15 | var body: some View { 16 | Button { 17 | isShowingDownload = true 18 | } label: { 19 | Image(systemName: "arrow.down.square") 20 | .foregroundStyle(.tint) 21 | } 22 | .popover(isPresented: $isShowingDownload, arrowEdge: .bottom) { 23 | VStack { 24 | ProgressView(value: progress.fractionCompleted) { 25 | HStack { 26 | Text(progress.localizedAdditionalDescription) 27 | .bold() 28 | Spacer() 29 | Text(progress.localizedDescription) 30 | } 31 | } 32 | 33 | Text("The model is downloading") 34 | .padding(.horizontal, 32) 35 | } 36 | .padding() 37 | } 38 | } 39 | } 40 | 41 | #Preview { 42 | DownloadProgressView(progress: Progress(totalUnitCount: 6)) 43 | } 44 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/LoRATrainingExample/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/Assets.xcassets/AppIcon.appiconset/Contents.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /Applications/StableDiffusionExample/README.md: -------------------------------------------------------------------------------- 1 | # StableDiffusionExample 2 | 3 | An example application that runs the StableDiffusion example code. 4 | 5 | See also [image-tool](../../Tools/image-tool) for a command line example. 6 | 7 | This example application accepts a prompt and uses the StableDiffusion example 8 | library to render an image using: 9 | 10 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo) 11 | 12 | Please refer to that model for license and other information. 13 | 14 | If you are interested in adjusting the generated images, look in 15 | [ContentView.swift](ContentView.swift) at this method: 16 | 17 | ```swift 18 | func generate(prompt: String, negativePrompt: String, showProgress: Bool) async 19 | ``` 20 | 21 | ### Troubleshooting 22 | 23 | Stable Diffusion can run in less than 4G of available memory (typically a 24 | device or computer with 6G of memory or more) in a constrained mode -- it will 25 | load and unload parts of the model as it runs and it can only perform one step 26 | of diffusion. This is configured automatically; see `modelFactory.conserveMemory` 27 | in [ContentView.swift](ContentView.swift). 28 | 29 | On a device or computer with more memory the model will be kept resident and 30 | images can be regenerated much more efficiently. 31 | 32 | If the program exits while generating the image it may have exceeded the available 33 | memory. 
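As a rough sketch of that memory gating (illustrative only: `modelFactory` stands in for the factory configured in [ContentView.swift](ContentView.swift), and the 8G threshold is an assumption, not the app's actual cutoff):

```swift
import Foundation

// Decide whether to use the constrained load/unload mode described above.
// ProcessInfo is standard Foundation API; the threshold and `modelFactory` are assumptions.
let physicalMemory = ProcessInfo.processInfo.physicalMemory
let conserveMemory = physicalMemory < 8 * 1024 * 1024 * 1024
// modelFactory.conserveMemory = conserveMemory
```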
34 | -------------------------------------------------------------------------------- /Tools/ExampleLLM/README.md: -------------------------------------------------------------------------------- 1 | # ExampleLLM 2 | 3 | An example that uses the simplified APIs to load and evaluate an LLM in only a few lines of 4 | code: 5 | 6 | ```swift 7 | let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") 8 | let session = ChatSession(model) 9 | print(try await session.respond(to: "What are two things to see in San Francisco?")) 10 | print(try await session.respond(to: "How about a great place to eat?")) 11 | ``` 12 | 13 | See various READMEs: 14 | 15 | - [MLXLMCommon](https://github.com/ml-explore/mlx-swift-lm/Libraries/MLXLMCommon/README.md) -- common LM code 16 | - [MLXLLM](https://github.com/ml-explore/mlx-swift-lm/Libraries/MLXLLM/README.md) -- large language models 17 | - [MLXVLM](https://github.com/ml-explore/mlx-swift-lm/Libraries/MLXVLM/README.md) -- vision language models 18 | 19 | ### Building 20 | 21 | Build the `ExampleLLM` scheme in Xcode. 22 | 23 | ### Running: Xcode 24 | 25 | Just press cmd-r to run! 26 | 27 | ### Running: Command Line 28 | 29 | Use the `mlx-run` script to run the command line tools: 30 | 31 | ``` 32 | ./mlx-run ExampleLLM 33 | ``` 34 | 35 | Note: `mlx-run` is a shell script that uses Xcode command line tools to 36 | locate the built binaries. It is equivalent to running from Xcode itself. 37 | 38 | By default this will find and run the tools built in _Release_ configuration. Specify `--debug` 39 | to find and run the tool built in _Debug_ configuration. 40 | 41 | -------------------------------------------------------------------------------- /Tools/embedder-tool/PoolingArguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import ArgumentParser 4 | import MLXEmbedders 5 | 6 | struct PoolingArguments: ParsableArguments { 7 | 8 | @Option( 9 | name: .long, help: "Pooling strategy used to collapse token embeddings (default: mean).") 10 | var strategy: Pooling.Strategy? 11 | 12 | @Flag( 13 | name: .long, inversion: .prefixedNo, 14 | help: 15 | "Normalize pooled embeddings to unit length (default: true). Use --no-normalize to disable." 16 | ) 17 | var normalize = true 18 | 19 | @Flag(name: .long, help: "Apply layer normalization before pooling.") 20 | var layerNorm = false 21 | } 22 | 23 | extension PoolingArguments { 24 | var strategyOverride: Pooling.Strategy? { 25 | strategy ?? .mean 26 | } 27 | } 28 | 29 | extension Pooling.Strategy: @retroactive CaseIterable { 30 | public static var allCases: [Pooling.Strategy] { 31 | [.mean, .cls, .first, .last, .max, .none] 32 | } 33 | } 34 | 35 | extension Pooling.Strategy: @retroactive ExpressibleByArgument { 36 | public init?(argument: String) { 37 | switch argument.lowercased() { 38 | case "mean": self = .mean 39 | case "cls": self = .cls 40 | case "first": self = .first 41 | case "last": self = .last 42 | case "max": self = .max 43 | case "none": self = .none 44 | default: return nil 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /Tools/llm-tool/ListCommands.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
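// Subcommands that print the model configurations registered with LLMRegistry and VLMRegistry.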
2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLXLLM 6 | import MLXVLM 7 | 8 | struct ListCommands: AsyncParsableCommand { 9 | 10 | static let configuration = CommandConfiguration( 11 | commandName: "list", 12 | abstract: "List registered model configurations", 13 | subcommands: [ 14 | ListLLMCommand.self, ListVLMCommand.self, 15 | ] 16 | ) 17 | } 18 | 19 | struct ListLLMCommand: AsyncParsableCommand { 20 | 21 | static let configuration = CommandConfiguration( 22 | commandName: "llms", 23 | abstract: "List registered LLM model configurations" 24 | ) 25 | 26 | func run() async throws { 27 | for configuration in LLMRegistry.shared.models { 28 | switch configuration.id { 29 | case .id(let id): print(id) 30 | case .directory: break 31 | } 32 | } 33 | } 34 | } 35 | 36 | struct ListVLMCommand: AsyncParsableCommand { 37 | 38 | static let configuration = CommandConfiguration( 39 | commandName: "vlms", 40 | abstract: "List registered VLM model configurations" 41 | ) 42 | 43 | func run() async throws { 44 | for configuration in VLMRegistry.shared.models { 45 | switch configuration.id { 46 | case .id(let id): print(id) 47 | case .directory: break 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/Toolbar/ChatToolbarView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ChatToolbarView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 6 | // 7 | 8 | import SwiftUI 9 | 10 | /// Toolbar view for the chat interface that displays error messages, download progress, 11 | /// generation statistics, and model selection controls. 12 | struct ChatToolbarView: View { 13 | /// View model containing the chat state and controls 14 | @Bindable var vm: ChatViewModel 15 | 16 | var body: some View { 17 | // Display error message if present 18 | if let errorMessage = vm.errorMessage { 19 | ErrorView(errorMessage: errorMessage) 20 | } 21 | 22 | // Show download progress for model loading 23 | if let progress = vm.modelDownloadProgress, !progress.isFinished { 24 | DownloadProgressView(progress: progress) 25 | } 26 | 27 | // Button to clear chat history, displays generation statistics 28 | Button { 29 | vm.clear([.chat, .meta]) 30 | } label: { 31 | GenerationInfoView( 32 | tokensPerSecond: vm.tokensPerSecond 33 | ) 34 | } 35 | 36 | // Model selection picker 37 | Picker("Model", selection: $vm.selectedModel) { 38 | ForEach(MLXService.availableModels) { model in 39 | Text(model.displayName) 40 | .tag(model) 41 | } 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/PromptField.swift: -------------------------------------------------------------------------------- 1 | // 2 | // PromptField.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct PromptField: View { 11 | @Binding var prompt: String 12 | @State private var task: Task<Void, Never>? 13 | 14 | let sendButtonAction: () async -> Void 15 | let mediaButtonAction: (() -> Void)?
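    // A nil `mediaButtonAction` hides the attachment button; ChatView passes nil
    // for text-only (non-vision) models.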
16 | 17 | var body: some View { 18 | HStack { 19 | if let mediaButtonAction { 20 | Button(action: mediaButtonAction) { 21 | Image(systemName: "photo.badge.plus") 22 | } 23 | } 24 | 25 | TextField("Prompt", text: $prompt) 26 | .textFieldStyle(.roundedBorder) 27 | 28 | Button { 29 | if isRunning { 30 | task?.cancel() 31 | removeTask() 32 | } else { 33 | task = Task { 34 | await sendButtonAction() 35 | removeTask() 36 | } 37 | } 38 | } label: { 39 | Image(systemName: isRunning ? "stop.circle.fill" : "paperplane.fill") 40 | } 41 | .keyboardShortcut(isRunning ? .cancelAction : .defaultAction) 42 | } 43 | } 44 | 45 | private var isRunning: Bool { 46 | task != nil && !(task!.isCancelled) 47 | } 48 | 49 | private func removeTask() { 50 | task = nil 51 | } 52 | } 53 | 54 | #Preview { 55 | PromptField(prompt: .constant("")) { 56 | } mediaButtonAction: { 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Models/LMModel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // LMModel.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 6 | // 7 | 8 | import MLXLMCommon 9 | 10 | /// Represents a language model configuration with its associated properties and type. 11 | /// Can represent either a large language model (LLM) or a vision-language model (VLM). 12 | struct LMModel { 13 | /// Name of the model 14 | let name: String 15 | 16 | /// Configuration settings for model initialization 17 | let configuration: ModelConfiguration 18 | 19 | /// Type of the model (language or vision-language) 20 | let type: ModelType 21 | 22 | /// Defines the type of language model 23 | enum ModelType { 24 | /// Large language model (text-only) 25 | case llm 26 | /// Vision-language model (supports images and text) 27 | case vlm 28 | } 29 | } 30 | 31 | // MARK: - Helpers 32 | 33 | extension LMModel { 34 | /// Display name with additional "(Vision)" suffix for vision models 35 | var displayName: String { 36 | if isVisionModel { 37 | "\(name) (Vision)" 38 | } else { 39 | name 40 | } 41 | } 42 | 43 | /// Whether the model is a large language model 44 | var isLanguageModel: Bool { 45 | type == .llm 46 | } 47 | 48 | /// Whether the model is a vision-language model 49 | var isVisionModel: Bool { 50 | type == .vlm 51 | } 52 | } 53 | 54 | extension LMModel: Identifiable, Hashable { 55 | var id: String { 56 | name 57 | } 58 | 59 | func hash(into hasher: inout Hasher) { 60 | hasher.combine(name) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /Applications/VLMEval/README.md: -------------------------------------------------------------------------------- 1 | # VLMEval 2 | 3 | An example that: 4 | 5 | - downloads a vision language model (SmolVLM2) 6 | - processes an image or a video with a prompt 7 | 8 | You will need to set the Team on the VLMEval target in order to build and run on macOS. 
9 | 10 | Some notes about the setup: 11 | 12 | - This downloads models from Hugging Face so VLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox 13 | - VLM models are large so this uses significant memory 14 | - The example can process images and videos, providing detailed analysis 15 | 16 | ### Image Processing 17 | 18 | The example application uses the SmolVLM2 model by default, see [ContentView.swift](ContentView.swift): 19 | 20 | ```swift 21 | self.modelContainer = try await VLMModelFactory.shared.loadContainer( 22 | configuration: VLMRegistry.smolvlm) 23 | ``` 24 | 25 | The application: 26 | 1. Downloads a sample image 27 | 2. Processes it through the vision language model 28 | 3. Describes the image based on the prompt, providing detailed analysis of the content, objects, colors, and composition. 29 | 30 | ### Troubleshooting 31 | 32 | If the program crashes with a very deep stack trace, you may need to build 33 | in Release configuration. This seems to depend on the size of the model. 34 | 35 | There are a couple of options: 36 | 37 | - Build Release 38 | - Force the model evaluation to run on the main thread, e.g. using @MainActor 39 | - Build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),` 40 | 41 | ### Performance 42 | 43 | You may find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable". 44 | -------------------------------------------------------------------------------- /Tools/embedder-tool/EmbedderCommand.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | 6 | /// A protocol for commands that need to load and use an embedder model. 7 | /// 8 | /// This protocol centralizes the logic for loading an `EmbedderRuntime` 9 | /// and managing memory statistics, reducing boilerplate in individual commands. 10 | protocol EmbedderCommand: AsyncParsableCommand { 11 | /// The model arguments, captured via `@OptionGroup`. 12 | var model: ModelArguments { get } 13 | 14 | /// The pooling arguments, captured via `@OptionGroup`. 15 | var pooling: PoolingArguments { get } 16 | 17 | /// The memory management arguments, captured via `@OptionGroup`. 18 | var memory: MemoryArguments { get set } 19 | 20 | /// The core logic of the command, which receives a fully initialized `EmbedderRuntime`. 21 | /// - Parameter runtime: The loaded and configured embedder runtime. 22 | mutating func run(runtime: EmbedderRuntime) async throws 23 | } 24 | 25 | extension EmbedderCommand { 26 | /// The main entry point for the command. 27 | /// 28 | /// This default implementation handles the loading of the embedder runtime 29 | /// and memory reporting, then calls the command's specific `run(runtime:)` method.
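    ///
    /// A minimal conforming command might look like this (hypothetical sketch;
    /// `EmbedCommand` is illustrative and not part of the tool):
    ///
    /// ```swift
    /// struct EmbedCommand: EmbedderCommand {
    ///     @OptionGroup var model: ModelArguments
    ///     @OptionGroup var pooling: PoolingArguments
    ///     @OptionGroup var memory: MemoryArguments
    ///
    ///     mutating func run(runtime: EmbedderRuntime) async throws {
    ///         // use the loaded runtime, e.g. runtime.embed(texts: ["hello"])
    ///     }
    /// }
    /// ```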
30 | mutating func run() async throws { 31 | var memory = self.memory 32 | let capturedModel = model 33 | let capturedPooling = pooling 34 | 35 | let runtime = try await memory.start { 36 | try await EmbedderTool.loadRuntime(model: capturedModel, pooling: capturedPooling) 37 | } 38 | 39 | defer { 40 | memory.reportMemoryStatistics() 41 | self.memory = memory 42 | } 43 | 44 | try await run(runtime: runtime) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/OutputView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import MarkdownUI 4 | import SwiftUI 5 | 6 | struct OutputView: View { 7 | let output: String 8 | let displayStyle: ContentView.DisplayStyle 9 | let wasTruncated: Bool 10 | 11 | var body: some View { 12 | ScrollView(.vertical) { 13 | ScrollViewReader { sp in 14 | VStack(alignment: .leading, spacing: 12) { 15 | Group { 16 | if displayStyle == .plain { 17 | Text(output) 18 | .textSelection(.enabled) 19 | } else { 20 | Markdown(output) 21 | .textSelection(.enabled) 22 | } 23 | } 24 | 25 | // Warning banner when output is truncated 26 | if wasTruncated && !output.isEmpty { 27 | HStack(spacing: 8) { 28 | Image(systemName: "exclamationmark.triangle.fill") 29 | .foregroundStyle(.orange) 30 | Text("Output truncated: Maximum token limit reached") 31 | .font(.caption) 32 | .foregroundStyle(.secondary) 33 | } 34 | .padding(8) 35 | .background(.orange.opacity(0.1), in: RoundedRectangle(cornerRadius: 6)) 36 | } 37 | } 38 | .onChange(of: output) { _, _ in 39 | sp.scrollTo("bottom") 40 | } 41 | 42 | Spacer() 43 | .frame(width: 1, height: 1) 44 | .id("bottom") 45 | } 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "gzipswift", 5 | "kind" : "remoteSourceControl", 6 | "location" : "https://github.com/1024jp/GzipSwift", 7 | "state" : { 8 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05", 9 | "version" : "6.0.1" 10 | } 11 | }, 12 | { 13 | "identity" : "mlx-swift", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/ml-explore/mlx-swift", 16 | "state" : { 17 | "revision" : "072b684acaae80b6a463abab3a103732f33774bf", 18 | "version" : "0.29.1" 19 | } 20 | }, 21 | { 22 | "identity" : "swift-collections", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/apple/swift-collections.git", 25 | "state" : { 26 | "revision" : "c1805596154bb3a265fd91b8ac0c4433b4348fb0", 27 | "version" : "1.2.0" 28 | } 29 | }, 30 | { 31 | "identity" : "swift-jinja", 32 | "kind" : "remoteSourceControl", 33 | "location" : "https://github.com/huggingface/swift-jinja.git", 34 | "state" : { 35 | "revision" : "c1ef5963ba4a97a589b9c9583ff4ee3352a86d23", 36 | "version" : "2.1.0" 37 | } 38 | }, 39 | { 40 | "identity" : "swift-numerics", 41 | "kind" : "remoteSourceControl", 42 | "location" : "https://github.com/apple/swift-numerics", 43 | "state" : { 44 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 45 | "version" : "1.0.2" 46 | } 47 | }, 48 | { 49 | "identity" : "swift-transformers", 50 | "kind" : "remoteSourceControl", 51 | "location" : "https://github.com/huggingface/swift-transformers", 52 | "state" : { 53 | "revision" : "94610577e4af9bbc267060af1e25e977604dd796", 54 | "version" : "1.1.1" 55 | } 56 | } 57 | ], 58 
| "version" : 2 59 | } 60 | -------------------------------------------------------------------------------- /Libraries/StableDiffusion/README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion 2 | 3 | Stable Diffusion in MLX. The implementation was ported from Hugging Face's 4 | [diffusers](https://huggingface.co/docs/diffusers/index) and 5 | [mlx-examples/stable_diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion). 6 | Model weights are downloaded directly from the Hugging Face hub. The implementation currently 7 | supports the following models: 8 | 9 | - [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo) 10 | - [stabilitiai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) 11 | 12 | ## Usage 13 | 14 | See [StableDiffusionExample](../../Applications/StableDiffusionExample) and 15 | [image-tool](../../Tools/image-tool) for examples of using this code. 16 | 17 | The basic sequence is: 18 | 19 | - download & load the model 20 | - generate latents 21 | - evaluate the latents one by one 22 | - decode the last latent generated 23 | - you have an image! 24 | 25 | ```swift 26 | let configuration = StableDiffusionConfiguration.presetSDXLTurbo 27 | 28 | let generator = try configuration.textToImageGenerator( 29 | configuration: model.loadConfiguration) 30 | 31 | generator.ensureLoaded() 32 | 33 | // Generate the latents, which are the iterations for generating 34 | // the output image. This is just generating the evaluation graph 35 | let parameters = generate.evaluateParameters(configuration: configuration) 36 | let latents = generator.generateLatents(parameters: parameters) 37 | 38 | // evaluate the latents (evalue the graph) and keep the last value generated 39 | var lastXt: MLXArray? 40 | for xt in latents { 41 | eval(xt) 42 | lastXt = xt 43 | } 44 | 45 | // decode the final latent into an image 46 | if let lastXt { 47 | var raster = decoder(lastXt[0]) 48 | raster = (image * 255).asType(.uint8).squeezed() 49 | eval(raster) 50 | 51 | // turn it into a CGImage 52 | let image = Image(raster).asCGImage() 53 | 54 | // or write it out 55 | try Image(raster).save(url: url) 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/LoadingOverlayView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | struct LoadingOverlayView: View { 6 | let modelInfo: String 7 | let downloadProgress: Double? 8 | let progressDescription: String? 9 | 10 | init(modelInfo: String, downloadProgress: Double? = nil, progressDescription: String? 
= nil) { 11 | self.modelInfo = modelInfo 12 | self.downloadProgress = downloadProgress 13 | self.progressDescription = progressDescription 14 | } 15 | 16 | var body: some View { 17 | ZStack { 18 | Color.black.opacity(0.4) 19 | .ignoresSafeArea() 20 | 21 | VStack(spacing: 16) { 22 | if let progress = downloadProgress, progress < 1.0 { 23 | ProgressView(value: progress) 24 | .progressViewStyle(.linear) 25 | .frame(width: 200) 26 | } else { 27 | ProgressView() 28 | .scaleEffect(1.5) 29 | .progressViewStyle(.circular) 30 | } 31 | 32 | Text(modelInfo) 33 | .font(.headline) 34 | .foregroundStyle(.primary) 35 | 36 | if let description = progressDescription { 37 | Text(description) 38 | .font(.subheadline) 39 | .foregroundStyle(.secondary) 40 | .monospacedDigit() 41 | } 42 | 43 | Text( 44 | "Models are large and may take a couple of minutes to download on first use. They are cached locally for faster loading in the future." 45 | ) 46 | .font(.subheadline) 47 | .foregroundStyle(.secondary) 48 | .multilineTextAlignment(.center) 49 | .frame(maxWidth: 300) 50 | } 51 | .padding(32) 52 | .background(.regularMaterial, in: RoundedRectangle(cornerRadius: 16)) 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /Applications/LLMEval/Models/PresetPrompts.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | 5 | struct PresetPrompt: Identifiable { 6 | let id = UUID() 7 | let prompt: String 8 | let enableTools: Bool 9 | let enableThinking: Bool 10 | let isLongPrompt: Bool 11 | 12 | init( 13 | _ prompt: String, enableTools: Bool = false, enableThinking: Bool = false, 14 | isLongPrompt: Bool = false 15 | ) { 16 | self.prompt = prompt 17 | self.enableTools = enableTools 18 | self.enableThinking = enableThinking 19 | self.isLongPrompt = isLongPrompt 20 | } 21 | } 22 | 23 | struct PresetPrompts { 24 | // Helper to load prompts from markdown files 25 | private static func loadPrompt(named fileName: String) -> String { 26 | guard let url = Bundle.main.url(forResource: fileName, withExtension: "md"), 27 | let content = try? String(contentsOf: url, encoding: .utf8) 28 | else { 29 | return "Could not load \(fileName).md. Please ensure it is included in the app bundle." 
30 | } 31 | return content 32 | } 33 | 34 | static let all: [PresetPrompt] = [ 35 | PresetPrompt("Why is the sky blue?"), 36 | PresetPrompt("What would a medieval knight's Yelp review of a dragon's lair look like?"), 37 | PresetPrompt("Explain why socks disappear in the dryer from the dryer's perspective."), 38 | 39 | PresetPrompt( 40 | "Write a breaking news report about cats discovering they can vote.", 41 | enableThinking: true), 42 | PresetPrompt( 43 | "Write a performance review for the person whose job is to make sure Mondays feel terrible.", 44 | enableThinking: true), 45 | 46 | PresetPrompt("What's the weather in Paris?", enableTools: true), 47 | PresetPrompt("What is the current time?", enableTools: true), 48 | 49 | PresetPrompt(loadPrompt(named: "LongPrompt"), enableThinking: true, isLongPrompt: true), 50 | PresetPrompt(loadPrompt(named: "CarKeysStory"), isLongPrompt: true), 51 | ] 52 | } 53 | -------------------------------------------------------------------------------- /Tools/embedder-tool/MemoryArguments.swift: -------------------------------------------------------------------------------- 1 | import ArgumentParser 2 | import MLX 3 | 4 | /// Argument package for adjusting and reporting GPU memory usage. 5 | struct MemoryArguments: ParsableArguments, Sendable { 6 | 7 | @Flag(name: .long, help: "Show GPU memory stats before exit.") 8 | var memoryStats = false 9 | 10 | @Option(name: .long, help: "Maximum GPU cache size in megabytes.") 11 | var cacheSize: Int? 12 | 13 | @Option(name: .long, help: "Maximum GPU memory size in megabytes.") 14 | var memorySize: Int? 15 | 16 | private(set) var startMemory: GPU.Snapshot? 17 | 18 | mutating func start<L>(_ operation: @Sendable () async throws -> L) async throws -> L { 19 | applyLimits() 20 | let result = try await operation() 21 | startMemory = GPU.snapshot() 22 | return result 23 | } 24 | 25 | mutating func start() { 26 | applyLimits() 27 | startMemory = GPU.snapshot() 28 | } 29 | 30 | func reportCurrent() { 31 | guard memoryStats else { return } 32 | let memory = GPU.snapshot() 33 | print(memory.description) 34 | } 35 | 36 | func reportMemoryStatistics() { 37 | guard memoryStats, let startMemory else { return } 38 | 39 | let endMemory = GPU.snapshot() 40 | 41 | print("=======") 42 | print("GPU memory limit: \(GPU.memoryLimit / 1024)K") 43 | print("GPU cache limit: \(GPU.cacheLimit / 1024)K") 44 | print("") 45 | print("=======") 46 | print("Starting snapshot") 47 | print(startMemory.description) 48 | print("") 49 | print("=======") 50 | print("Ending snapshot") 51 | print(endMemory.description) 52 | print("") 53 | print("=======") 54 | print("Delta") 55 | print(startMemory.delta(endMemory).description) 56 | } 57 | 58 | private func applyLimits() { 59 | if let cacheSize { 60 | GPU.set(cacheLimit: cacheSize * 1024 * 1024) 61 | } 62 | 63 | if let memorySize { 64 | GPU.set(memoryLimit: memorySize * 1024 * 1024) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Applications/LLMEval/README.md: -------------------------------------------------------------------------------- 1 | # LLMEval 2 | 3 | An example that: 4 | 5 | - downloads a Hugging Face model and tokenizer 6 | - evaluates a prompt 7 | - displays the output as it generates text 8 | 9 | You will need to set the Team on the LLMEval target in order to build and run on iOS.
10 | 11 | Some notes about the setup: 12 | 13 | - this downloads models from Hugging Face so LLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox 14 | - LLM models are large so this uses the Increased Memory Limit entitlement on iOS to allow ... increased memory limits for devices that have more memory 15 | - `MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)` is used to limit the buffer cache size 16 | 17 | ### Trying Different Models 18 | 19 | The example app uses an 8 billion parameter quantized Qwen3 model by default, see [LLMEvaluator.swift](ViewModels/LLMEvaluator.swift#L52): 20 | 21 | ``` 22 | var modelConfiguration = LLMRegistry.qwen3_8b_4bit 23 | ``` 24 | 25 | There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78) 26 | and you can load any weights from Hugging Face where there 27 | is a model architecture defined and you have enough 28 | memory. 29 | 30 | For example: 31 | ``` 32 | /// phi4bit is one of the smaller models so will fit on more devices 33 | var modelConfiguration = LLMRegistry.phi4bit 34 | ``` 35 | 36 | ### Troubleshooting 37 | 38 | If the program crashes with a very deep stack trace, you may need to build 39 | in Release configuration. This seems to depend on the size of the model. 40 | 41 | There are a couple of options: 42 | 43 | - Build Release 44 | - Force the model evaluation to run on the main thread, e.g. using @MainActor 45 | - Build `Cmlx` with optimizations by modifying `mlx/Package.swift` and adding `.unsafeOptions(["-O3"]),` around line 87 46 | 47 | See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3 48 | 49 | ### Performance 50 | 51 | Different models have different performance characteristics. For example, Gemma 2B may outperform Phi-2 in terms of tokens / second. 52 | 53 | You may also find that running outside the debugger boosts performance. You can do this in Xcode by pressing cmd-opt-r and unchecking "Debug Executable". 54 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Support/SampleData.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SampleData.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 6 | // 7 | 8 | @MainActor 9 | struct SampleData { 10 | static let conversation: [Message] = [ 11 | .system("You are a helpful assistant specializing in SwiftUI development."), 12 | .user("I need help building a weather app in SwiftUI. Where should I start?"), 13 | .assistant( 14 | "I'll help you create a weather app! Let's break it down into steps. First, we'll need to design the main view to display current weather conditions. Would you like to start with that?" 15 | ), 16 | .user("Yes, that sounds good. What components should I use for the main view?"), 17 | .assistant( 18 | "For the main weather view, I recommend using a VStack as the container. You can include:\n\n1. An Image view for the weather icon\n2. Text views for temperature and conditions\n3. HStack for additional metrics like humidity and wind speed" 19 | ), 20 | .user("How do I make the UI look modern and polished?"), 21 | .assistant( 22 | "To create a modern UI, try these techniques:\n\n- Use SF Symbols for weather icons\n- Add subtle gradients with .background()\n- Include padding and spacing for better layout\n- Implement dark mode support\n\nWould you like to see some example code?"
23 | ), 24 | .user("Yes, please show me an example for the main weather view."), 25 | .assistant( 26 | "Here's a basic example:\n\nVStack(spacing: 20) {\n Image(systemName: \"sun.max.fill\")\n .symbolRenderingMode(.multicolor)\n .font(.system(size: 64))\n \n Text(\"72°\")\n .font(.largeTitle)\n .bold()\n \n Text(\"Sunny\")\n .font(.title2)\n}\n.padding()\n.background(.ultraThinMaterial)\n.clipShape(RoundedRectangle(cornerRadius: 20))" 27 | ), 28 | .user("That looks great! How would I add animations to this?"), 29 | .assistant( 30 | "We can add smooth animations using SwiftUI's animation modifiers. For example:\n\n1. Use withAnimation for state changes\n2. Add .animation() modifier to views\n3. Implement transitions\n\nWould you like to see how to animate the weather changes?" 31 | ), 32 | ] 33 | } 34 | -------------------------------------------------------------------------------- /Tests/MLXLMTests/StreamlinedTests.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | import MLX 5 | import MLXLLM 6 | import MLXLMCommon 7 | import MLXNN 8 | import MLXOptimizers 9 | import Tokenizers 10 | import XCTest 11 | 12 | /// Tests for the streamlined API 13 | public class StreamlinedTests: XCTestCase { 14 | 15 | /// for tests we don't download a model but do execute one 16 | func model() -> LanguageModel { 17 | let config = LlamaConfiguration( 18 | hiddenSize: 64, hiddenLayers: 16, intermediateSize: 64, attentionHeads: 8, 19 | rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 8) 20 | let model = LlamaModel(config) 21 | quantize(model: model, groupSize: 64, bits: 4) 22 | return model 23 | } 24 | 25 | /// This is equivalent to: 26 | /// 27 | /// ```swift 28 | /// let model = LLMModelFactory.load("test") 29 | /// ``` 30 | func modelContainer() -> ModelContainer { 31 | let context = ModelContext( 32 | configuration: .init(id: "test", extraEOSTokens: ["EOS"]), model: model(), 33 | processor: TestUserInputProcessor(), tokenizer: TestTokenizer()) 34 | return ModelContainer(context: context) 35 | } 36 | 37 | func testOneShot() async throws { 38 | let model = modelContainer() 39 | let result = try await ChatSession(model).respond(to: "Tell me about things") 40 | print(result) 41 | } 42 | 43 | func testOneShotStream() async throws { 44 | let model = modelContainer() 45 | for try await token in ChatSession(model).streamResponse(to: "Tell me about things") { 46 | print(token, terminator: "") 47 | } 48 | } 49 | 50 | func testChat() async throws { 51 | let model = modelContainer() 52 | let session = ChatSession(model) 53 | 54 | print(try await session.respond(to: "what color is the sky?")) 55 | print(try await session.respond(to: "why is that?")) 56 | print(try await session.respond(to: "describe this image", image: .ciImage(CIImage.red))) 57 | } 58 | } 59 | 60 | private struct TestUserInputProcessor: UserInputProcessor { 61 | func prepare(input: UserInput) throws -> LMInput { 62 | LMInput(tokens: MLXRandom.randInt(0 ..< 1000, [100])) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/ChatView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ChatView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 
6 | // 7 | 8 | import AVFoundation 9 | import AVKit 10 | import SwiftUI 11 | 12 | /// Main chat interface view that manages the conversation UI and user interactions. 13 | /// Displays messages, handles media attachments, and provides input controls. 14 | struct ChatView: View { 15 | /// View model that manages the chat state and business logic 16 | @Bindable private var vm: ChatViewModel 17 | 18 | /// Initializes the chat view with a view model 19 | /// - Parameter viewModel: The view model to manage chat state 20 | init(viewModel: ChatViewModel) { 21 | self.vm = viewModel 22 | } 23 | 24 | var body: some View { 25 | NavigationStack { 26 | VStack(spacing: 0) { 27 | // Display conversation history 28 | ConversationView(messages: vm.messages) 29 | 30 | Divider() 31 | 32 | // Show media previews if attachments are present 33 | if !vm.mediaSelection.isEmpty { 34 | MediaPreviewsView(mediaSelection: vm.mediaSelection) 35 | } 36 | 37 | // Input field with send and media attachment buttons 38 | PromptField( 39 | prompt: $vm.prompt, 40 | sendButtonAction: vm.generate, 41 | // Only show media button for vision-capable models 42 | mediaButtonAction: vm.selectedModel.isVisionModel 43 | ? { 44 | vm.mediaSelection.isShowing = true 45 | } : nil 46 | ) 47 | .padding() 48 | } 49 | .navigationTitle("MLX Chat Example") 50 | .toolbar { 51 | ChatToolbarView(vm: vm) 52 | } 53 | // Handle media file selection 54 | .fileImporter( 55 | isPresented: $vm.mediaSelection.isShowing, 56 | allowedContentTypes: [.image, .movie], 57 | onCompletion: vm.addMedia 58 | ) 59 | } 60 | } 61 | } 62 | 63 | #Preview { 64 | ChatView(viewModel: ChatViewModel(mlxService: MLXService())) 65 | } 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | # 3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 4 | 5 | ## User settings 6 | xcuserdata/ 7 | 8 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) 9 | *.xcscmblueprint 10 | *.xccheckout 11 | 12 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) 13 | build/ 14 | DerivedData/ 15 | *.moved-aside 16 | *.pbxuser 17 | !default.pbxuser 18 | *.mode1v3 19 | !default.mode1v3 20 | *.mode2v3 21 | !default.mode2v3 22 | *.perspectivev3 23 | !default.perspectivev3 24 | 25 | ## Obj-C/Swift specific 26 | *.hmap 27 | 28 | ## App packaging 29 | *.ipa 30 | *.dSYM.zip 31 | *.dSYM 32 | 33 | ## Playgrounds 34 | timeline.xctimeline 35 | playground.xcworkspace 36 | 37 | # Swift Package Manager 38 | # 39 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 40 | Packages/ 41 | Package.pins 42 | Package.resolved 43 | # *.xcodeproj 44 | # 45 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 46 | # hence it is not needed unless you have added a package configuration file to your project 47 | .swiftpm 48 | 49 | .build/ 50 | 51 | # CocoaPods 52 | # 53 | # We recommend against adding the Pods directory to your .gitignore. 
However 54 | # you should judge for yourself, the pros and cons are mentioned at: 55 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 56 | # 57 | # Pods/ 58 | # 59 | # Add this line if you want to avoid checking in source code from the Xcode workspace 60 | # *.xcworkspace 61 | 62 | # Carthage 63 | # 64 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 65 | # Carthage/Checkouts 66 | 67 | Carthage/Build/ 68 | 69 | # Accio dependency management 70 | Dependencies/ 71 | .accio/ 72 | 73 | # fastlane 74 | # 75 | # It is recommended to not store the screenshots in the git repo. 76 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 77 | # For more information about the recommended setup visit: 78 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 79 | 80 | fastlane/report.xml 81 | fastlane/Preview.html 82 | fastlane/screenshots/**/*.png 83 | fastlane/test_output 84 | 85 | # Code Injection 86 | # 87 | # After new code Injection tools there's a generated folder /iOSInjectionProject 88 | # https://github.com/johnno1962/injectionforxcode 89 | 90 | iOSInjectionProject/ 91 | 92 | # OS 93 | .DS_Store 94 | 95 | .idea 96 | .vscode 97 | 98 | -------------------------------------------------------------------------------- /Tests/MLXLMTests/BaseConfigurationTests.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | import MLXLMCommon 5 | import XCTest 6 | 7 | public class BaseConfigurationTests: XCTestCase { 8 | 9 | func testQuantization() throws { 10 | let json = 11 | """ 12 | { 13 | "model_type": "Test", 14 | "quantization": { 15 | "group_size": 128, 16 | "bits": 4 17 | } 18 | } 19 | """ 20 | 21 | let config = try JSONDecoder().decode( 22 | BaseConfiguration.self, from: json.data(using: .utf8)!) 23 | 24 | XCTAssertEqual(config.quantization, .init(groupSize: 128, bits: 4)) 25 | XCTAssertEqual( 26 | config.perLayerQuantization?.quantization(layer: "x"), .init(groupSize: 128, bits: 4)) 27 | } 28 | 29 | func testHeterogenousQuantization() throws { 30 | // from https://huggingface.co/mlx-community/Qwen3-1.7B-4bit-AWQ/blob/main/config.json#L20 31 | let json = 32 | """ 33 | { 34 | "model_type": "Test", 35 | "quantization": { 36 | "group_size": 64, 37 | "bits": 4, 38 | "model.embed_tokens": { 39 | "group_size": 32, 40 | "bits": 4 41 | }, 42 | "model.layers.0.self_attn.q_norm": false, 43 | "true_layer": true 44 | } 45 | } 46 | """ 47 | 48 | let config = try JSONDecoder().decode( 49 | BaseConfiguration.self, from: json.data(using: .utf8)!) 
50 | 51 | XCTAssertEqual(config.quantization, .init(groupSize: 64, bits: 4)) 52 | 53 | // a random layer -- no specific configuration gets default 54 | XCTAssertEqual( 55 | config.perLayerQuantization?.quantization(layer: "x"), 56 | .init(groupSize: 64, bits: 4)) 57 | 58 | // layer with an override 59 | XCTAssertEqual( 60 | config.perLayerQuantization?.quantization(layer: "model.embed_tokens"), 61 | .init(groupSize: 32, bits: 4)) 62 | 63 | // layer with an override -- not quant 64 | XCTAssertNil( 65 | config.perLayerQuantization?.quantization(layer: "model.layers.0.self_attn.q_norm")) 66 | 67 | // layer with an override -- true, use the default 68 | XCTAssertEqual( 69 | config.perLayerQuantization?.quantization(layer: "true_layer"), 70 | .init(groupSize: 64, bits: 4)) 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /Libraries/MLXMNIST/MNIST.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import MLX 5 | import MLXNN 6 | 7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py 8 | 9 | public class LeNet: Module, UnaryLayer { 10 | 11 | @ModuleInfo var conv1: Conv2d 12 | @ModuleInfo var conv2: Conv2d 13 | @ModuleInfo var pool1: MaxPool2d 14 | @ModuleInfo var pool2: MaxPool2d 15 | @ModuleInfo var fc1: Linear 16 | @ModuleInfo var fc2: Linear 17 | @ModuleInfo var fc3: Linear 18 | 19 | override public init() { 20 | conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2) 21 | conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0) 22 | pool1 = MaxPool2d(kernelSize: 2, stride: 2) 23 | pool2 = MaxPool2d(kernelSize: 2, stride: 2) 24 | fc1 = Linear(16 * 5 * 5, 120) 25 | fc2 = Linear(120, 84) 26 | fc3 = Linear(84, 10) 27 | } 28 | 29 | public func callAsFunction(_ x: MLXArray) -> MLXArray { 30 | var x = x 31 | x = pool1(tanh(conv1(x))) 32 | x = pool2(tanh(conv2(x))) 33 | x = flattened(x, start: 1) 34 | x = tanh(fc1(x)) 35 | x = tanh(fc2(x)) 36 | x = fc3(x) 37 | return x 38 | } 39 | } 40 | 41 | public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray { 42 | crossEntropy(logits: model(x), targets: y, reduction: .mean) 43 | } 44 | 45 | public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray { 46 | mean(argMax(model(x), axis: 1) .== y) 47 | } 48 | 49 | private struct BatchSequence: Sequence, IteratorProtocol { 50 | 51 | let batchSize: Int 52 | let x: MLXArray 53 | let y: MLXArray 54 | 55 | let indexes: MLXArray 56 | var index = 0 57 | 58 | init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator) 59 | { 60 | self.batchSize = batchSize 61 | self.x = x 62 | self.y = y 63 | self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator)) 64 | } 65 | 66 | mutating func next() -> (MLXArray, MLXArray)? 
{ 67 | guard index < y.size else { return nil } 68 | 69 | let range = index ..< Swift.min(index + batchSize, y.size) 70 | index += batchSize 71 | let ids = indexes[range] 72 | return (x[ids], y[ids]) 73 | } 74 | } 75 | 76 | public func iterateBatches( 77 | batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator 78 | ) -> some Sequence<(MLXArray, MLXArray)> { 79 | BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator) 80 | } 81 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 5.9 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | 4 | import PackageDescription 5 | 6 | let package = Package( 7 | name: "mlx-libraries", 8 | platforms: [.macOS(.v14), .iOS(.v16)], 9 | products: [ 10 | .library( 11 | name: "MLXMNIST", 12 | targets: ["MLXMNIST"]), 13 | .library( 14 | name: "StableDiffusion", 15 | targets: ["StableDiffusion"]), 16 | ], 17 | dependencies: [ 18 | .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.29.1")), 19 | .package( 20 | url: "https://github.com/huggingface/swift-transformers", 21 | .upToNextMinor(from: "1.1.0") 22 | ), 23 | .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"), // Only needed by MLXMNIST 24 | ], 25 | targets: [ 26 | .target( 27 | name: "MLXMNIST", 28 | dependencies: [ 29 | .product(name: "MLX", package: "mlx-swift"), 30 | .product(name: "MLXFast", package: "mlx-swift"), 31 | .product(name: "MLXNN", package: "mlx-swift"), 32 | .product(name: "MLXOptimizers", package: "mlx-swift"), 33 | .product(name: "MLXRandom", package: "mlx-swift"), 34 | .product(name: "Transformers", package: "swift-transformers"), 35 | .product(name: "Gzip", package: "GzipSwift"), 36 | ], 37 | path: "Libraries/MLXMNIST", 38 | exclude: [ 39 | "README.md" 40 | ], 41 | swiftSettings: [ 42 | .enableExperimentalFeature("StrictConcurrency") 43 | ] 44 | ), 45 | .target( 46 | name: "StableDiffusion", 47 | dependencies: [ 48 | .product(name: "MLX", package: "mlx-swift"), 49 | .product(name: "MLXNN", package: "mlx-swift"), 50 | .product(name: "MLXRandom", package: "mlx-swift"), 51 | .product(name: "Transformers", package: "swift-transformers"), 52 | ], 53 | path: "Libraries/StableDiffusion", 54 | exclude: [ 55 | "README.md" 56 | ], 57 | swiftSettings: [ 58 | .enableExperimentalFeature("StrictConcurrency") 59 | ] 60 | ), 61 | ] 62 | ) 63 | 64 | if Context.environment["MLX_SWIFT_BUILD_DOC"] == "1" 65 | || Context.environment["SPI_GENERATE_DOCS"] == "1" 66 | { 67 | // docc builder 68 | package.dependencies.append( 69 | .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.3.0") 70 | ) 71 | } 72 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/PromptInputView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
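// Prompt entry area with an expandable text field, access to preset prompts,
// and generate/stop controls.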
2 | 3 | import SwiftUI 4 | 5 | struct PromptInputView: View { 6 | @Bindable var llm: LLMEvaluator 7 | @Binding var isPromptExpanded: Bool 8 | @Binding var showingPresetPrompts: Bool 9 | 10 | let onGenerate: () -> Void 11 | let onCancel: () -> Void 12 | 13 | var body: some View { 14 | VStack(alignment: .leading, spacing: 8) { 15 | // Prompt header with expand/collapse chevron 16 | HStack { 17 | Text("Prompt") 18 | .font(.caption) 19 | .foregroundStyle(.secondary) 20 | 21 | Spacer() 22 | 23 | Button { 24 | withAnimation(.easeInOut(duration: 0.2)) { 25 | isPromptExpanded.toggle() 26 | } 27 | } label: { 28 | Image(systemName: isPromptExpanded ? "chevron.down" : "chevron.up") 29 | .font(.caption) 30 | .foregroundStyle(.secondary) 31 | } 32 | .buttonStyle(.plain) 33 | .help(isPromptExpanded ? "Collapse prompt area" : "Expand prompt area") 34 | } 35 | 36 | // Prompt text field with dynamic sizing 37 | TextField("Enter your prompt...", text: $llm.prompt, axis: .vertical) 38 | .textFieldStyle(.roundedBorder) 39 | .lineLimit(isPromptExpanded ? 15 ... 50 : 1 ... 3) 40 | .frame(height: isPromptExpanded ? 400 : nil) 41 | .onSubmit(onGenerate) 42 | .disabled(llm.running || llm.isLoading) 43 | 44 | // Action buttons 45 | HStack(spacing: 12) { 46 | Button { 47 | showingPresetPrompts = true 48 | } label: { 49 | Label("Example Prompts", systemImage: "list.bullet") 50 | } 51 | .disabled(llm.running || llm.isLoading) 52 | 53 | Spacer() 54 | 55 | Button { 56 | if llm.running { 57 | onCancel() 58 | } else { 59 | onGenerate() 60 | } 61 | } label: { 62 | Label( 63 | llm.running ? "Stop" : "Generate", 64 | systemImage: llm.running ? "stop.circle" : "play.fill" 65 | ) 66 | } 67 | .buttonStyle(.borderedProminent) 68 | .keyboardShortcut(.return, modifiers: .command) 69 | .disabled((llm.prompt.isEmpty && !llm.running) || llm.isLoading) 70 | } 71 | } 72 | .padding(.vertical, 8) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /Applications/LLMEval/Services/ToolExecutor.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Foundation 4 | import MLXLMCommon 5 | 6 | public typealias ToolSpec = [String: Any] 7 | 8 | /// Manages tool definitions and execution for LLM function calling 9 | @MainActor 10 | class ToolExecutor { 11 | 12 | // MARK: - Tool Definitions 13 | 14 | let currentWeatherTool = Tool( 15 | name: "get_current_weather", 16 | description: "Get the current weather in a given location", 17 | parameters: [ 18 | .required( 19 | "location", type: .string, description: "The city and state, e.g. San Francisco, CA" 20 | ), 21 | .optional( 22 | "unit", 23 | type: .string, 24 | description: "The unit of temperature", 25 | extraProperties: [ 26 | "enum": ["celsius", "fahrenheit"], 27 | "default": "celsius", 28 | ] 29 | ), 30 | ] 31 | ) { input in 32 | let range = input.unit == "celsius" ? (min: -20.0, max: 40.0) : (min: 0, max: 100) 33 | let temperature = Double.random(in: range.min ... range.max) 34 | let conditions = ["Sunny", "Cloudy", "Rainy", "Snowy", "Windy", "Stormy"].randomElement()! 
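        // Demo-only behavior: fabricate a random reading rather than calling a
        // real weather service; the model only sees the structured WeatherOutput.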
35 | return WeatherOutput(temperature: temperature, conditions: conditions) 36 | } 37 | 38 | let addTool = Tool( 39 | name: "add_two_numbers", 40 | description: "Add two numbers together", 41 | parameters: [ 42 | .required("first", type: .int, description: "The first number to add"), 43 | .required("second", type: .int, description: "The second number to add"), 44 | ] 45 | ) { input in 46 | AddOutput(result: input.first + input.second) 47 | } 48 | 49 | let timeTool = Tool( 50 | name: "get_time", 51 | description: "Get the current time", 52 | parameters: [] 53 | ) { _ in 54 | TimeOutput(time: Date.now.formatted()) 55 | } 56 | 57 | // MARK: - Tool Execution 58 | 59 | /// Returns all available tool schemas 60 | var allToolSchemas: [ToolSpec] { 61 | [currentWeatherTool.schema, addTool.schema, timeTool.schema] 62 | } 63 | 64 | /// Executes a tool call and returns the result 65 | func execute(_ toolCall: ToolCall) async throws -> String { 66 | switch toolCall.function.name { 67 | case currentWeatherTool.name: 68 | return try await toolCall.execute(with: currentWeatherTool).toolResult 69 | case addTool.name: 70 | return try await toolCall.execute(with: addTool).toolResult 71 | case timeTool.name: 72 | return try await toolCall.execute(with: timeTool).toolResult 73 | default: 74 | return "Unknown tool: \(toolCall.function.name)" 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Models/Message.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Message.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 6 | // 7 | 8 | import Foundation 9 | 10 | /// Represents a chat message in the conversation. 11 | /// Messages can contain text content and optional media attachments (images and videos). 
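///
/// For example, `.user("Describe this image", images: [url])` builds a user
/// message with one attached image (see the factory methods below).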
12 | @Observable 13 | class Message: Identifiable { 14 | /// Unique identifier for the message 15 | let id: UUID 16 | 17 | /// The role of the message sender (user, assistant, or system) 18 | let role: Role 19 | 20 | /// The text content of the message 21 | var content: String 22 | 23 | /// Array of image URLs attached to the message 24 | var images: [URL] 25 | 26 | /// Array of video URLs attached to the message 27 | var videos: [URL] 28 | 29 | /// Timestamp when the message was created 30 | let timestamp: Date 31 | 32 | /// Creates a new message with the specified role, content, and optional media attachments 33 | /// - Parameters: 34 | /// - role: The role of the message sender 35 | /// - content: The text content of the message 36 | /// - images: Optional array of image URLs 37 | /// - videos: Optional array of video URLs 38 | init(role: Role, content: String, images: [URL] = [], videos: [URL] = []) { 39 | self.id = UUID() 40 | self.role = role 41 | self.content = content 42 | self.images = images 43 | self.videos = videos 44 | self.timestamp = .now 45 | } 46 | 47 | /// Defines the role of the message sender in the conversation 48 | enum Role { 49 | /// Message from the user 50 | case user 51 | /// Message from the AI assistant 52 | case assistant 53 | /// System message providing context or instructions 54 | case system 55 | } 56 | } 57 | 58 | /// Convenience methods for creating different types of messages 59 | extension Message { 60 | /// Creates a user message with optional media attachments 61 | /// - Parameters: 62 | /// - content: The text content of the message 63 | /// - images: Optional array of image URLs 64 | /// - videos: Optional array of video URLs 65 | /// - Returns: A new Message instance with user role 66 | static func user(_ content: String, images: [URL] = [], videos: [URL] = []) -> Message { 67 | Message(role: .user, content: content, images: images, videos: videos) 68 | } 69 | 70 | /// Creates an assistant message 71 | /// - Parameter content: The text content of the message 72 | /// - Returns: A new Message instance with assistant role 73 | static func assistant(_ content: String) -> Message { 74 | Message(role: .assistant, content: content) 75 | } 76 | 77 | /// Creates a system message 78 | /// - Parameter content: The text content of the message 79 | /// - Returns: A new Message instance with system role 80 | static func system(_ content: String) -> Message { 81 | Message(role: .system, content: content) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /Tools/image-tool/Arguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLX 6 | 7 | #if swift(>=5.10) 8 | /// Extension to allow URL command line arguments. 9 | extension URL: @retroactive ExpressibleByArgument { 10 | public init?(argument: String) { 11 | if argument.contains("://") { 12 | self.init(string: argument) 13 | } else { 14 | self.init(filePath: argument) 15 | } 16 | } 17 | } 18 | #else 19 | /// Extension to allow URL command line arguments. 20 | extension URL: ExpressibleByArgument { 21 | public init?(argument: String) { 22 | if argument.contains("://") { 23 | self.init(string: argument) 24 | } else { 25 | self.init(filePath: argument) 26 | } 27 | } 28 | } 29 | #endif 30 | 31 | /// Argument package for adjusting and reporting memory use. 
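///
/// Typical use from a command (a sketch; `loadModel()` stands in for whatever
/// loading work the command performs):
///
/// ```swift
/// @OptionGroup var memory: MemoryArguments
///
/// mutating func run() async throws {
///     let model = try await memory.start { try await loadModel() }
///     defer { memory.reportMemoryStatistics() }
///     // ... use the model ...
/// }
/// ```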
32 | struct MemoryArguments: ParsableArguments, Sendable { 33 | 34 | @Flag(name: .long, help: "Show memory stats") 35 | var memoryStats = false 36 | 37 | @Option(name: .long, help: "Maximum cache size in M") 38 | var cacheSize = 1024 39 | 40 | @Option(name: .long, help: "Maximum memory size in M") 41 | var memorySize: Int? 42 | 43 | var startMemory: GPU.Snapshot? 44 | 45 | mutating func start<L>(_ load: () async throws -> L) async throws -> L { 46 | GPU.set(cacheLimit: cacheSize * 1024 * 1024) 47 | 48 | if let memorySize { 49 | GPU.set(memoryLimit: memorySize * 1024 * 1024) 50 | } 51 | 52 | let result = try await load() 53 | startMemory = GPU.snapshot() 54 | 55 | return result 56 | } 57 | 58 | mutating func start() { 59 | GPU.set(cacheLimit: cacheSize * 1024 * 1024) 60 | 61 | if let memorySize { 62 | GPU.set(memoryLimit: memorySize * 1024 * 1024) 63 | } 64 | 65 | startMemory = GPU.snapshot() 66 | } 67 | 68 | func reportCurrent() { 69 | if memoryStats { 70 | let memory = GPU.snapshot() 71 | print(memory.description) 72 | } 73 | } 74 | 75 | func reportMemoryStatistics() { 76 | if memoryStats, let startMemory { 77 | let endMemory = GPU.snapshot() 78 | 79 | print("=======") 80 | print("Memory size: \(GPU.memoryLimit / 1024)K") 81 | print("Cache size: \(GPU.cacheLimit / 1024)K") 82 | 83 | print("") 84 | print("=======") 85 | print("Starting memory") 86 | print(startMemory.description) 87 | 88 | print("") 89 | print("=======") 90 | print("Ending memory") 91 | print(endMemory.description) 92 | 93 | print("") 94 | print("=======") 95 | print("Growth") 96 | print(startMemory.delta(endMemory).description) 97 | 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /Tools/embedder-tool/DemoCommand.swift: -------------------------------------------------------------------------------- 1 | import ArgumentParser 2 | import Foundation 3 | 4 | struct DemoCommand: AsyncParsableCommand { 5 | static let configuration = CommandConfiguration( 6 | commandName: "demo", 7 | abstract: "Run a demo using sample repository documentation" 8 | ) 9 | 10 | @OptionGroup var memory: MemoryArguments 11 | 12 | @Flag(name: .long, help: "Keep the generated demo index file instead of removing it") 13 | var keepIndex = false 14 | 15 | @Argument(help: "Optional queries to run after indexing. Defaults to three sample queries.") 16 | var queries: [String] = [] 17 | 18 | mutating func run() async throws { 19 | var memory = self.memory 20 | memory.start() 21 | defer { 22 | memory.reportMemoryStatistics() 23 | self.memory = memory 24 | } 25 | 26 | print("Embedder Tool Demo") 27 | 28 | let indexURL = try makeTemporaryIndexURL() 29 | defer { 30 | if !keepIndex { 31 | do { 32 | try FileManager.default.removeItem(at: indexURL) 33 | } catch { 34 | if FileManager.default.fileExists(atPath: indexURL.path) { 35 | let message = 36 | "Failed to remove temporary index file at \(indexURL.path): \(error.localizedDescription). Please remove it manually." 37 | writeDiagnostic(message, kind: .warning) 38 | } 39 | } 40 | } 41 | } 42 | 43 | try await buildIndex(at: indexURL) 44 | let queriesToRun = queries.isEmpty ?
defaultQueries : queries 45 | try await runSampleQueries(using: indexURL, queries: queriesToRun) 46 | } 47 | 48 | private func makeTemporaryIndexURL() throws -> URL { 49 | let directory = FileManager.default.temporaryDirectory 50 | return directory.appendingPathComponent("embedder-demo-\(UUID().uuidString).json") 51 | } 52 | 53 | private func buildIndex(at url: URL) async throws { 54 | var indexCommand = IndexCommand() 55 | indexCommand.corpus.directory = URL(fileURLWithPath: "Libraries") 56 | indexCommand.corpus.extensions = ["md"] 57 | indexCommand.corpus.recursive = true 58 | indexCommand.corpus.limit = 8 59 | indexCommand.output = url 60 | indexCommand.batchSize = 4 61 | indexCommand.pooling.normalize = true 62 | 63 | try await indexCommand.run() 64 | } 65 | 66 | private func runSampleQueries(using indexURL: URL, queries: [String]) async throws { 67 | for query in queries { 68 | var searchCommand = SearchCommand() 69 | searchCommand.index = indexURL 70 | searchCommand.query = query 71 | searchCommand.top = 2 72 | searchCommand.pooling.normalize = true 73 | 74 | try await searchCommand.run() 75 | } 76 | } 77 | 78 | private var defaultQueries: [String] { 79 | [ 80 | "How do I use embedding models?", 81 | "Training language models", 82 | "Vision language models", 83 | ] 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/LLMEval.xcscheme: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/HeaderView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import SwiftUI 4 | 5 | struct HeaderView: View { 6 | @Bindable var llm: LLMEvaluator 7 | @Binding var selectedDisplayStyle: ContentView.DisplayStyle 8 | 9 | var body: some View { 10 | VStack(alignment: .leading, spacing: 12) { 11 | // Model info with status 12 | HStack { 13 | VStack(alignment: .leading, spacing: 4) { 14 | Text("Model") 15 | .font(.caption) 16 | .foregroundStyle(.secondary) 17 | 18 | Text(llm.modelInfo) 19 | .font(.headline) 20 | .lineLimit(1) 21 | } 22 | 23 | Spacer() 24 | 25 | if llm.running { 26 | HStack(spacing: 8) { 27 | ProgressView() 28 | .controlSize(.small) 29 | Text("Generating...") 30 | .font(.subheadline) 31 | .foregroundStyle(.secondary) 32 | } 33 | } 34 | } 35 | 36 | // Controls row 37 | HStack(spacing: 16) { 38 | HStack(spacing: 24) { 39 | Toggle("Tools", isOn: $llm.includeWeatherTool) 40 | .toggleStyle(.switch) 41 | .fixedSize() 42 | .help("Enable function calling with weather, math, and time tools") 43 | 44 | Toggle("Thinking", isOn: $llm.enableThinking) 45 | .toggleStyle(.switch) 46 | .fixedSize() 47 | .help("Enable thinking mode (supported by Qwen3)") 48 | 49 | // Max tokens slider 50 | VStack(alignment: .leading, spacing: 4) { 51 | Text("Max Tokens: \(llm.maxTokens)") 52 | .font(.caption) 53 | .foregroundStyle(.secondary) 54 | 55 | Slider( 56 | value: Binding( 57 | get: { log2(Double(llm.maxTokens)) }, 58 | set: { llm.maxTokens = Int(pow(2, $0)) } 59 | ), 60 | in: 10 ...
15, // 2^10 (1024) to 2^15 (32768) 61 | step: 1 62 | ) 63 | .frame(width: 120) 64 | .help("Maximum number of tokens to generate (1024-32768)") 65 | } 66 | } 67 | 68 | Spacer() 69 | 70 | Picker("Display", selection: $selectedDisplayStyle) { 71 | ForEach(ContentView.DisplayStyle.allCases, id: \.self) { option in 72 | Text(option.rawValue.capitalized) 73 | .tag(option) 74 | } 75 | } 76 | .pickerStyle(.segmented) 77 | .labelsHidden() 78 | .frame(maxWidth: 180) 79 | } 80 | } 81 | .padding(.bottom, 12) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/VLMEval.xcscheme: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/ExampleLLM.xcscheme: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Libraries/MLXMNIST/Files.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import Gzip 5 | import MLX 6 | 7 | // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py 8 | 9 | public enum Use: String, Hashable, Sendable { 10 | case test 11 | case training 12 | } 13 | 14 | public enum DataKind: String, Hashable, Sendable { 15 | case images 16 | case labels 17 | } 18 | 19 | public struct FileKind: Hashable, CustomStringConvertible, Sendable { 20 | let use: Use 21 | let data: DataKind 22 | 23 | public init(_ use: Use, _ data: DataKind) { 24 | self.use = use 25 | self.data = data 26 | } 27 | 28 | public var description: String { 29 | "\(use.rawValue)-\(data.rawValue)" 30 | } 31 | } 32 | 33 | struct LoadInfo: Sendable { 34 | let name: String 35 | let offset: Int 36 | let convert: @Sendable (MLXArray) -> MLXArray 37 | } 38 | 39 | let baseURL = URL(string: "https://raw.githubusercontent.com/fgnt/mnist/master/")! 40 | 41 | private let files = [ 42 | FileKind(.training, .images): LoadInfo( 43 | name: "train-images-idx3-ubyte.gz", 44 | offset: 16, 45 | convert: { 46 | $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0 47 | }), 48 | FileKind(.test, .images): LoadInfo( 49 | name: "t10k-images-idx3-ubyte.gz", 50 | offset: 16, 51 | convert: { 52 | $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0 53 | }), 54 | FileKind(.training, .labels): LoadInfo( 55 | name: "train-labels-idx1-ubyte.gz", 56 | offset: 8, 57 | convert: { 58 | $0.asType(.uint32) 59 | }), 60 | FileKind(.test, .labels): LoadInfo( 61 | name: "t10k-labels-idx1-ubyte.gz", 62 | offset: 8, 63 | convert: { 64 | $0.asType(.uint32) 65 | }), 66 | ] 67 | 68 | public func download(into: URL) async throws { 69 | for (_, info) in files { 70 | let fileURL = into.appending(component: info.name) 71 | if !FileManager.default.fileExists(atPath: fileURL.path()) { 72 | print("Download: \(info.name)") 73 | let url = baseURL.appending(component: info.name) 74 | let (data, response) = try await URLSession.shared.data(from: url) 75 | 76 | guard let httpResponse = response as?
HTTPURLResponse else { 77 | fatalError("Unable to download \(url), not an http response: \(response)") 78 | } 79 | guard httpResponse.statusCode == 200 else { 80 | fatalError("Unable to download \(url): \(httpResponse)") 81 | } 82 | 83 | try data.write(to: fileURL) 84 | } 85 | } 86 | } 87 | 88 | public func load(from: URL) throws -> [FileKind: MLXArray] { 89 | var result = [FileKind: MLXArray]() 90 | 91 | for (key, info) in files { 92 | let fileURL = from.appending(component: info.name) 93 | let data = try Data(contentsOf: fileURL).gunzipped() 94 | 95 | let array = MLXArray( 96 | data.dropFirst(info.offset), [data.count - info.offset], type: UInt8.self) 97 | 98 | result[key] = info.convert(array) 99 | } 100 | 101 | return result 102 | } 103 | -------------------------------------------------------------------------------- /Tools/embedder-tool/ModelArguments.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | import Hub 6 | import MLXEmbedders 7 | 8 | struct ModelArguments: ParsableArguments { 9 | 10 | @Option( 11 | name: .long, 12 | help: "Name of the embedder model configuration or absolute path to a local directory.") 13 | var model: String? 14 | 15 | @Option(name: .long, help: "Directory used for downloading model assets from the Hub.") 16 | var download: URL? 17 | 18 | @MainActor 19 | func configuration(default defaultConfiguration: ModelConfiguration) -> ModelConfiguration { 20 | guard let model else { 21 | return defaultConfiguration 22 | } 23 | 24 | if let localConfiguration = resolveLocalModelPath(model) { 25 | return localConfiguration 26 | } 27 | 28 | return ModelConfiguration.configuration(id: model) 29 | } 30 | 31 | var downloadURL: URL? { 32 | download?.standardizedFileURL 33 | } 34 | } 35 | 36 | struct LoadedEmbedderModel { 37 | let configuration: ModelConfiguration 38 | let container: ModelContainer 39 | } 40 | 41 | extension ModelArguments { 42 | 43 | func load(default defaultConfiguration: ModelConfiguration) async throws -> LoadedEmbedderModel 44 | { 45 | let configuration = await configuration(default: defaultConfiguration) 46 | let hub = makeHub() 47 | 48 | print("Loading model \(configuration.name)...") 49 | 50 | let container = try await MLXEmbedders.loadModelContainer( 51 | hub: hub, 52 | configuration: configuration, 53 | progressHandler: { progress in 54 | let percentage = Int(progress.fractionCompleted * 100) 55 | let previousPercentage = Int((progress.fractionCompleted - 0.01) * 100) 56 | 57 | if percentage % 10 == 0 && percentage != previousPercentage { 58 | print("Downloading model: \(percentage)%") 59 | } 60 | } 61 | ) 62 | 63 | return LoadedEmbedderModel(configuration: configuration, container: container) 64 | } 65 | 66 | private func makeHub() -> HubApi { 67 | if let downloadURL { 68 | return HubApi(downloadBase: downloadURL) 69 | } 70 | 71 | return HubApi() 72 | } 73 | } 74 | 75 | extension ModelArguments { 76 | private func resolveLocalModelPath(_ value: String) -> ModelConfiguration? 
{ 77 | let expanded = NSString(string: value).expandingTildeInPath 78 | let candidate = URL(fileURLWithPath: expanded, isDirectory: true) 79 | var isDirectory: ObjCBool = false 80 | if FileManager.default.fileExists(atPath: candidate.path, isDirectory: &isDirectory), 81 | isDirectory.boolValue 82 | { 83 | return ModelConfiguration(directory: candidate.standardizedFileURL) 84 | } 85 | 86 | if let url = URL(string: value), url.isFileURL { 87 | var isDir: ObjCBool = false 88 | if FileManager.default.fileExists(atPath: url.path, isDirectory: &isDir), 89 | isDir.boolValue 90 | { 91 | return ModelConfiguration(directory: url.standardizedFileURL) 92 | } 93 | } 94 | 95 | return nil 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/StableDiffusionExample.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 43 | 45 | 51 | 52 | 53 | 54 | 60 | 62 | 68 | 69 | 70 | 71 | 73 | 74 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /Tools/embedder-tool/EmbedderRuntime+Embedding.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | import MLXEmbedders 4 | import Tokenizers 5 | 6 | public struct RuntimeEmbeddingResult { 7 | public let embeddings: [(index: Int, vector: [Float])] 8 | public let skippedIndices: [Int] 9 | public let fallbackDescription: String? 10 | 11 | public init( 12 | embeddings: [(index: Int, vector: [Float])], 13 | skippedIndices: [Int], 14 | fallbackDescription: String? 15 | ) { 16 | self.embeddings = embeddings 17 | self.skippedIndices = skippedIndices 18 | self.fallbackDescription = fallbackDescription 19 | } 20 | } 21 | 22 | extension EmbedderRuntime { 23 | func embed(texts: [String]) async throws -> RuntimeEmbeddingResult { 24 | guard !texts.isEmpty else { 25 | return RuntimeEmbeddingResult( 26 | embeddings: [], skippedIndices: [], fallbackDescription: nil) 27 | } 28 | 29 | return try await container.perform { model, tokenizer, pooler in 30 | var skippedIndices: [Int] = [] 31 | 32 | let encoded = texts.enumerated().compactMap { index, text -> (Int, [Int])? in 33 | let tokens = tokenizer.encode(text: text, addSpecialTokens: true) 34 | guard !tokens.isEmpty else { 35 | skippedIndices.append(index) 36 | return nil 37 | } 38 | return (index, tokens) 39 | } 40 | 41 | guard !encoded.isEmpty else { 42 | return RuntimeEmbeddingResult( 43 | embeddings: [], 44 | skippedIndices: skippedIndices, 45 | fallbackDescription: nil 46 | ) 47 | } 48 | 49 | guard let padToken = tokenizer.eosTokenId else { 50 | throw CommandError("Could not determine a padding token from the tokenizer.") 51 | } 52 | let maxLength = encoded.map { $0.1.count }.max() ?? 
0 53 | 54 | let padded = stacked( 55 | encoded.map { _, tokens in 56 | MLXArray(tokens + Array(repeating: padToken, count: maxLength - tokens.count)) 57 | }) 58 | let mask = (padded .!= padToken) 59 | let tokenTypes = MLXArray.zeros(like: padded) 60 | 61 | let outputs = model( 62 | padded, 63 | positionIds: nil, 64 | tokenTypeIds: tokenTypes, 65 | attentionMask: mask 66 | ) 67 | 68 | let poolingModule = resolvedPooler(for: pooler) 69 | let pooled = poolingModule( 70 | outputs, 71 | mask: mask, 72 | normalize: self.normalize, 73 | applyLayerNorm: self.applyLayerNorm 74 | ) 75 | pooled.eval() 76 | 77 | let extraction = try extractVectors(from: pooled, expectedCount: encoded.count) 78 | 79 | let embeddings = zip(encoded.map { $0.0 }, extraction.vectors).map { index, vector in 80 | (index: index, vector: vector) 81 | } 82 | 83 | return RuntimeEmbeddingResult( 84 | embeddings: embeddings, 85 | skippedIndices: skippedIndices, 86 | fallbackDescription: extraction.fallbackDescription 87 | ) 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /Tools/Tutorial/Tutorial.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import MLX 5 | 6 | /// mlx-swift tutorial based on: 7 | /// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp 8 | @main 9 | struct Tutorial { 10 | 11 | static func scalarBasics() { 12 | // create a scalar array 13 | let x = MLXArray(1.0) 14 | 15 | // the datatype is .float32 16 | let dtype = x.dtype 17 | assert(dtype == .float32) 18 | 19 | // get the value 20 | let s = x.item(Float.self) 21 | assert(s == 1.0) 22 | 23 | // reading the value with a different type is a fatal error 24 | // let i = x.item(Int.self) 25 | 26 | // scalars have a size of 1 27 | let size = x.size 28 | assert(size == 1) 29 | 30 | // scalars have 0 dimensions 31 | let ndim = x.ndim 32 | assert(ndim == 0) 33 | 34 | // scalar shapes are empty arrays 35 | let shape = x.shape 36 | assert(shape == []) 37 | } 38 | 39 | static func arrayBasics() { 40 | // make a multidimensional array. 41 | // 42 | // Note: the argument is a [Double] array literal, which is not 43 | // a supported type, but we can explicitly convert it to [Float] 44 | // when we create the MLXArray. 45 | let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2]) 46 | 47 | // mlx is row-major by default so the first row of this array 48 | // is [1.0, 2.0] and the second row is [3.0, 4.0] 49 | print(x[0]) 50 | print(x[1]) 51 | 52 | // make an array of shape [2, 2] filled with ones 53 | let y = MLXArray.ones([2, 2]) 54 | 55 | // pointwise add x and y 56 | let z = x + y 57 | 58 | // mlx is lazy by default. At this point `z` only 59 | // has a shape and a type but no actual data 60 | assert(z.dtype == .float32) 61 | assert(z.shape == [2, 2]) 62 | 63 | // To actually run the computation you must evaluate `z`. 64 | // Under the hood, mlx records operations in a graph. 65 | // The variable `z` is a node in the graph which points to its operation 66 | // and inputs. When `eval` is called on an array (or arrays), the array and 67 | // all of its dependencies are recursively evaluated to produce the result. 68 | // Once an array is evaluated, it has data and is detached from its inputs. 69 | 70 | // Note: this is being called for demonstration purposes -- all reads 71 | // ensure the array is evaluated. 
72 | z.eval() 73 | 74 | // this implicitly evaluates z before converting to a description 75 | print(z) 76 | } 77 | 78 | static func automaticDifferentiation() { 79 | func fn(_ x: MLXArray) -> MLXArray { 80 | x.square() 81 | } 82 | 83 | let gradFn = grad(fn) 84 | 85 | let x = MLXArray(1.5) 86 | let dfdx = gradFn(x) 87 | print(dfdx) 88 | 89 | assert(dfdx.item() == Float(2 * 1.5)) 90 | 91 | let df2dx2 = grad(grad(fn))(x) 92 | print(df2dx2) 93 | 94 | assert(df2dx2.item() == Float(2)) 95 | } 96 | 97 | static func main() { 98 | scalarBasics() 99 | arrayBasics() 100 | automaticDifferentiation() 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /Tests/MLXLMTests/ToolTests.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLXLMCommon 3 | import Testing 4 | 5 | struct ToolTests { 6 | @Test("Test Weather Tool Schema Generation") 7 | func testWeatherToolSchemaGeneration() throws { 8 | struct WeatherInput: Codable { 9 | let location: String 10 | let unit: String? 11 | } 12 | 13 | struct WeatherOutput: Codable { 14 | let temperature: Double 15 | let conditions: String 16 | } 17 | 18 | let tool = Tool( 19 | name: "get_current_weather", 20 | description: "Get the current weather in a given location", 21 | parameters: [ 22 | .required( 23 | "location", type: .string, description: "The city, e.g. Istanbul" 24 | ), 25 | .optional( 26 | "unit", 27 | type: .string, 28 | description: "The unit of temperature", 29 | extraProperties: [ 30 | "enum": ["celsius", "fahrenheit"] 31 | ] 32 | ), 33 | ] 34 | ) { input in 35 | WeatherOutput(temperature: 14.0, conditions: "Sunny") 36 | } 37 | 38 | let actual = tool.schema as NSDictionary 39 | 40 | let expected: NSDictionary = [ 41 | "type": "function", 42 | "function": [ 43 | "name": "get_current_weather", 44 | "description": "Get the current weather in a given location", 45 | "parameters": [ 46 | "type": "object", 47 | "properties": [ 48 | "location": [ 49 | "type": "string", 50 | "description": "The city, e.g. 
Istanbul", 51 | ], 52 | "unit": [ 53 | "type": "string", 54 | "description": "The unit of temperature", 55 | "enum": ["celsius", "fahrenheit"], 56 | ], 57 | ], 58 | "required": ["location"], 59 | ], 60 | ], 61 | ] 62 | 63 | #expect(actual == expected) 64 | } 65 | 66 | @Test("Test Tool Call Detection in Generated Text") 67 | func testToolCallDetection() throws { 68 | let processor = ToolCallProcessor() 69 | let chunks: [String] = [ 70 | "", "{", "\"", "name", "\"", ":", " ", "\"", "get", "_", "current", 71 | "_", "weather", "\"", ",", " ", "\"", "arguments", "\"", ":", " ", "{", "\"", 72 | "location", "\"", ":", " ", "\"", "San", " Francisco", "\"", ",", " ", "\"", "unit", 73 | "\"", ":", " ", "\"", "celsius", "\"", "}", "}", "", 74 | ] 75 | 76 | for chunk in chunks { 77 | let result = processor.processChunk(chunk) 78 | #expect(result == nil) 79 | } 80 | 81 | #expect(processor.toolCalls.count == 1) 82 | let toolCall = try #require(processor.toolCalls.first) 83 | 84 | #expect(toolCall.function.name == "get_current_weather") 85 | #expect(toolCall.function.arguments["location"] == .string("San Francisco")) 86 | #expect(toolCall.function.arguments["unit"] == .string("celsius")) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/MessageView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MessageView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 20.04.2025. 6 | // 7 | 8 | import AVKit 9 | import SwiftUI 10 | 11 | /// A view that displays a single message in the chat interface. 12 | /// Supports different message roles (user, assistant, system) and media attachments. 13 | struct MessageView: View { 14 | /// The message to be displayed 15 | let message: Message 16 | 17 | /// Creates a message view 18 | /// - Parameter message: The message model to display 19 | init(_ message: Message) { 20 | self.message = message 21 | } 22 | 23 | var body: some View { 24 | switch message.role { 25 | case .user: 26 | // User messages are right-aligned with blue background 27 | HStack { 28 | Spacer() 29 | VStack(alignment: .trailing, spacing: 8) { 30 | // Display first image if present 31 | if let firstImage = message.images.first { 32 | AsyncImage(url: firstImage) { image in 33 | image 34 | .resizable() 35 | .aspectRatio(contentMode: .fill) 36 | } placeholder: { 37 | ProgressView() 38 | } 39 | .frame(maxWidth: 250, maxHeight: 200) 40 | .clipShape(.rect(cornerRadius: 12)) 41 | } 42 | 43 | // Display first video if present 44 | if let firstVideo = message.videos.first { 45 | VideoPlayer(player: AVPlayer(url: firstVideo)) 46 | .frame(width: 250, height: 340) 47 | .clipShape(.rect(cornerRadius: 12)) 48 | } 49 | 50 | // Message content with tinted background. 51 | // LocalizedStringKey used to trigger default handling of markdown content. 52 | Text(LocalizedStringKey(message.content)) 53 | .padding(.vertical, 8) 54 | .padding(.horizontal, 12) 55 | .background(.tint, in: .rect(cornerRadius: 16)) 56 | .textSelection(.enabled) 57 | } 58 | } 59 | 60 | case .assistant: 61 | // Assistant messages are left-aligned without background 62 | // LocalizedStringKey used to trigger default handling of markdown content. 
63 | HStack { 64 | Text(LocalizedStringKey(message.content)) 65 | .textSelection(.enabled) 66 | 67 | Spacer() 68 | } 69 | 70 | case .system: 71 | // System messages are centered with computer icon 72 | Label(message.content, systemImage: "desktopcomputer") 73 | .font(.headline) 74 | .foregroundColor(.secondary) 75 | .frame(maxWidth: .infinity, alignment: .center) 76 | } 77 | } 78 | } 79 | 80 | #Preview { 81 | VStack(spacing: 20) { 82 | MessageView(.system("You are a helpful assistant.")) 83 | 84 | MessageView( 85 | .user( 86 | "Here's a photo", 87 | images: [URL(string: "https://picsum.photos/200")!] 88 | ) 89 | ) 90 | 91 | MessageView(.assistant("I see your photo!")) 92 | } 93 | .padding() 94 | } 95 | -------------------------------------------------------------------------------- /Tools/mnist-tool/MNISTTool.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLX 6 | import MLXMNIST 7 | import MLXNN 8 | import MLXOptimizers 9 | 10 | @main 11 | struct MNISTTool: AsyncParsableCommand { 12 | static let configuration = CommandConfiguration( 13 | abstract: "Command line tool for training mnist models", 14 | subcommands: [Train.self], 15 | defaultSubcommand: Train.self) 16 | } 17 | 18 | #if swift(>=5.10) 19 | extension MLX.DeviceType: @retroactive ExpressibleByArgument { 20 | public init?(argument: String) { 21 | self.init(rawValue: argument) 22 | } 23 | } 24 | #else 25 | extension MLX.DeviceType: ExpressibleByArgument { 26 | public init?(argument: String) { 27 | self.init(rawValue: argument) 28 | } 29 | } 30 | #endif 31 | 32 | struct Train: AsyncParsableCommand { 33 | 34 | @Option(name: .long, help: "Directory with the training data") 35 | var data: String 36 | 37 | @Option(name: .long, help: "The PRNG seed") 38 | var seed: UInt64 = 0 39 | 40 | @Option var batchSize = 256 41 | @Option var epochs = 20 42 | @Option var learningRate: Float = 1e-1 43 | 44 | @Option var device = DeviceType.gpu 45 | 46 | @Flag var compile = false 47 | 48 | func run() async throws { 49 | Device.setDefault(device: Device(device)) 50 | 51 | MLXRandom.seed(seed) 52 | var generator: RandomNumberGenerator = SplitMix64(seed: seed) 53 | 54 | // load the data 55 | let url = URL(filePath: data) 56 | 57 | try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true) 58 | try await download(into: url) 59 | 60 | let data = try load(from: url) 61 | 62 | let trainImages = data[.init(.training, .images)]! 63 | let trainLabels = data[.init(.training, .labels)]! 64 | let testImages = data[.init(.test, .images)]! 65 | let testLabels = data[.init(.test, .labels)]! 66 | 67 | // create the model 68 | let model = LeNet() 69 | eval(model.parameters()) 70 | 71 | let lg = valueAndGrad(model: model, loss) 72 | let optimizer = SGD(learningRate: learningRate) 73 | 74 | func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { 75 | let (loss, grads) = lg(model, x, y) 76 | optimizer.update(model: model, gradients: grads) 77 | return loss 78 | } 79 | 80 | let resolvedStep = 81 | compile 82 | ? 
MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step 83 | 84 | for e in 0 ..< epochs { 85 | let start = Date.timeIntervalSinceReferenceDate 86 | 87 | for (x, y) in iterateBatches( 88 | batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator) 89 | { 90 | _ = resolvedStep(x, y) 91 | 92 | // eval the parameters so the next iteration is independent 93 | eval(model, optimizer) 94 | } 95 | 96 | let accuracy = eval(model: model, x: testImages, y: testLabels) 97 | 98 | let end = Date.timeIntervalSinceReferenceDate 99 | 100 | print( 101 | """ 102 | Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted()) 103 | Time: \((end - start).formatted()) 104 | 105 | """ 106 | ) 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /Tools/embedder-tool/PoolingSupport.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | import MLXEmbedders 4 | 5 | enum PoolingError: LocalizedError { 6 | case unsupportedShape([Int]) 7 | case vectorCountMismatch(expected: Int, received: Int) 8 | 9 | var errorDescription: String? { 10 | switch self { 11 | case .unsupportedShape(let shape): 12 | return "Pooling produced unsupported shape: \(shape)" 13 | case .vectorCountMismatch(let expected, let received): 14 | return "Pooling produced \(received) vectors but expected \(expected)" 15 | } 16 | } 17 | } 18 | 19 | struct PoolingExtraction { 20 | let vectors: [[Float]] 21 | let fallbackDescription: String? 22 | } 23 | 24 | extension EmbedderRuntime { 25 | func resolvedPooler(for pooler: Pooling) -> Pooling { 26 | guard let override = strategyOverride else { 27 | return pooler 28 | } 29 | 30 | if pooler.strategy == override { 31 | return pooler 32 | } 33 | 34 | if let dimension = pooler.dimension { 35 | return Pooling(strategy: override, dimension: dimension) 36 | } else { 37 | return Pooling(strategy: override) 38 | } 39 | } 40 | 41 | func extractVectors(from array: MLXArray, expectedCount: Int) throws -> PoolingExtraction { 42 | let shape = array.shape 43 | 44 | switch shape.count { 45 | case 2: 46 | let vectors = array.map { $0.asArray(Float.self) } 47 | guard vectors.count == expectedCount else { 48 | throw PoolingError.vectorCountMismatch( 49 | expected: expectedCount, received: vectors.count) 50 | } 51 | return PoolingExtraction(vectors: vectors, fallbackDescription: nil) 52 | 53 | case 3: 54 | let reduced = mean(array, axis: 1) 55 | reduced.eval() 56 | let vectors = reduced.map { $0.asArray(Float.self) } 57 | guard vectors.count == expectedCount else { 58 | throw PoolingError.vectorCountMismatch( 59 | expected: expectedCount, received: vectors.count) 60 | } 61 | 62 | let effectiveStrategy = strategyOverride ?? baseStrategy 63 | let description: String 64 | if effectiveStrategy == .none { 65 | description = 66 | "Pooling strategy 'none' returned sequence embeddings; falling back to mean over tokens." 67 | } else { 68 | description = 69 | "Pooling returned sequence embeddings; falling back to mean over tokens." 
70 | } 71 | return PoolingExtraction(vectors: vectors, fallbackDescription: description) 72 | 73 | default: 74 | throw PoolingError.unsupportedShape(shape) 75 | } 76 | } 77 | } 78 | 79 | extension Pooling.Strategy { 80 | var cliDescription: String { 81 | switch self { 82 | case .mean: return "mean" 83 | case .cls: return "cls" 84 | case .first: return "first" 85 | case .last: return "last" 86 | case .max: return "max" 87 | case .none: return "none" 88 | } 89 | } 90 | } 91 | 92 | extension EmbedderRuntime { 93 | var poolingDescription: String { 94 | if let override = strategyOverride { 95 | return "override (\(override.cliDescription))" 96 | } else { 97 | return "model default (\(baseStrategy.cliDescription))" 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/MetricsView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import MLX 4 | import SwiftUI 5 | 6 | struct MetricsView: View { 7 | let tokensPerSecond: Double 8 | let timeToFirstToken: Double 9 | let promptLength: Int 10 | let totalTokens: Int 11 | let totalTime: Double 12 | let memoryUsed: Int 13 | let cacheMemory: Int 14 | let peakMemory: Int 15 | 16 | @State private var showMemoryDetails = false 17 | 18 | var body: some View { 19 | VStack(spacing: 12) { 20 | // Top row 21 | HStack(spacing: 12) { 22 | MetricCard( 23 | icon: "speedometer", 24 | title: "Tokens/sec", 25 | value: String(format: "%.1f", tokensPerSecond) 26 | ) 27 | MetricCard( 28 | icon: "timer", 29 | title: "Time to First Token", 30 | value: String(format: "%.0fms", timeToFirstToken) 31 | ) 32 | MetricCard( 33 | icon: "text.alignleft", 34 | title: "Prompt Length", 35 | value: "\(promptLength)" 36 | ) 37 | } 38 | 39 | // Bottom row 40 | HStack(spacing: 12) { 41 | MetricCard( 42 | icon: "number", 43 | title: "Total Tokens", 44 | value: "\(totalTokens)" 45 | ) 46 | MetricCard( 47 | icon: "hourglass", 48 | title: "Total Time", 49 | value: String(format: "%.1fs", totalTime) 50 | ) 51 | ZStack(alignment: .topTrailing) { 52 | MetricCard( 53 | icon: "memorychip", 54 | title: "Memory", 55 | value: FormatUtilities.formatMemory(memoryUsed) 56 | ) 57 | Button(action: { 58 | #if os(iOS) 59 | showMemoryDetails = true 60 | #endif 61 | }) { 62 | Image(systemName: "info.circle.fill") 63 | .font(.caption) 64 | .foregroundStyle(.secondary) 65 | .frame(width: 44, height: 44) 66 | .contentShape(Rectangle()) 67 | } 68 | .buttonStyle(.plain) 69 | .help( 70 | """ 71 | Active Memory: \(FormatUtilities.formatMemory(memoryUsed))/\(FormatUtilities.formatMemory(GPU.memoryLimit)) 72 | Cache Memory: \(FormatUtilities.formatMemory(cacheMemory))/\(FormatUtilities.formatMemory(GPU.cacheLimit)) 73 | Peak Memory: \(FormatUtilities.formatMemory(peakMemory)) 74 | """ 75 | ) 76 | } 77 | } 78 | } 79 | .padding(.top, 8) 80 | .alert("Memory Details", isPresented: $showMemoryDetails) { 81 | Button("OK", role: .cancel) {} 82 | } message: { 83 | Text( 84 | """ 85 | Active Memory: \(FormatUtilities.formatMemory(memoryUsed))/\(FormatUtilities.formatMemory(GPU.memoryLimit)) 86 | Cache Memory: \(FormatUtilities.formatMemory(cacheMemory))/\(FormatUtilities.formatMemory(GPU.cacheLimit)) 87 | Peak Memory: \(FormatUtilities.formatMemory(peakMemory)) 88 | """) 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /Tools/embedder-tool/VectorOperations.swift: 
-------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 2 | 3 | import Accelerate 4 | import Foundation 5 | 6 | enum VectorOperations { 7 | static func cosineSimilarity(_ lhs: [Float], _ rhs: [Float]) -> Float { 8 | guard lhs.count == rhs.count else { return 0 } 9 | guard !lhs.isEmpty else { return 0 } 10 | 11 | var dot: Float = 0 12 | var lhsNormSquared: Float = 0 13 | var rhsNormSquared: Float = 0 14 | 15 | vDSP_dotpr(lhs, 1, rhs, 1, &dot, vDSP_Length(lhs.count)) 16 | vDSP_svesq(lhs, 1, &lhsNormSquared, vDSP_Length(lhs.count)) 17 | vDSP_svesq(rhs, 1, &rhsNormSquared, vDSP_Length(rhs.count)) 18 | 19 | let denominator = sqrt(lhsNormSquared * rhsNormSquared) 20 | guard denominator > 1e-9 else { return 0 } 21 | return dot / denominator 22 | } 23 | 24 | static func l2Norm(_ vector: [Float]) -> Float { 25 | guard !vector.isEmpty else { return 0 } 26 | var sumSquares: Float = 0 27 | vDSP_svesq(vector, 1, &sumSquares, vDSP_Length(vector.count)) 28 | return sqrt(sumSquares) 29 | } 30 | 31 | static func normalize(_ vector: [Float]) -> [Float] { 32 | guard !vector.isEmpty else { return [] } 33 | 34 | let sanitized = sanitize(vector) 35 | 36 | var sumSquares: Float = 0 37 | vDSP_svesq(sanitized, 1, &sumSquares, vDSP_Length(sanitized.count)) 38 | 39 | guard sumSquares.isFinite else { return [] } 40 | guard sumSquares > 1e-9 else { return sanitized } 41 | 42 | var divisor = sqrt(sumSquares) 43 | var normalized = [Float](repeating: 0, count: sanitized.count) 44 | vDSP_vsdiv(sanitized, 1, &divisor, &normalized, 1, vDSP_Length(sanitized.count)) 45 | 46 | return normalized 47 | } 48 | 49 | static func sanitize(_ vector: [Float]) -> [Float] { 50 | guard !vector.isEmpty else { return [] } 51 | 52 | var sanitized = vector 53 | for index in sanitized.indices where !sanitized[index].isFinite { 54 | sanitized[index] = 0 55 | } 56 | return sanitized 57 | } 58 | 59 | static func dotProduct(_ lhs: [Float], _ rhs: [Float]) -> Float { 60 | guard lhs.count == rhs.count else { return 0 } 61 | guard !lhs.isEmpty else { return 0 } 62 | 63 | var result: Float = 0 64 | vDSP_dotpr(lhs, 1, rhs, 1, &result, vDSP_Length(lhs.count)) 65 | return result 66 | } 67 | 68 | static func normalizeBatch(_ vectors: [[Float]]) -> [[Float]] { 69 | vectors.map { normalize($0) } 70 | } 71 | 72 | static func batchCosineSimilarity(query: [Float], documents: [[Float]]) -> [Float] { 73 | documents.map { cosineSimilarity(query, $0) } 74 | } 75 | 76 | static func batchDotProduct(query: [Float], documents: [[Float]]) -> [Float] { 77 | documents.map { dotProduct(query, $0) } 78 | } 79 | } 80 | 81 | extension VectorOperations { 82 | static func hasNonFiniteValues(_ vector: [Float]) -> Bool { 83 | vector.contains(where: { !$0.isFinite }) 84 | } 85 | 86 | static func statistics(_ vector: [Float]) -> (mean: Float, min: Float, max: Float, norm: Float) 87 | { 88 | guard !vector.isEmpty else { return (0, 0, 0, 0) } 89 | 90 | var mean: Float = 0 91 | var min: Float = 0 92 | var max: Float = 0 93 | 94 | vDSP_meanv(vector, 1, &mean, vDSP_Length(vector.count)) 95 | vDSP_minv(vector, 1, &min, vDSP_Length(vector.count)) 96 | vDSP_maxv(vector, 1, &max, vDSP_Length(vector.count)) 97 | 98 | return (mean, min, max, l2Norm(vector)) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: pull_request 4 | 5 | 
permissions: 6 | contents: read 7 | 8 | jobs: 9 | lint: 10 | if: github.repository == 'ml-explore/mlx-swift-examples' 11 | runs-on: ubuntu-22.04 12 | container: 13 | image: swift:6.2-rhel-ubi9 14 | steps: 15 | - uses: actions/checkout@v6 16 | with: 17 | submodules: recursive 18 | 19 | - name: Setup uv 20 | uses: astral-sh/setup-uv@v6 21 | with: 22 | activate-environment: true 23 | 24 | - name: Setup pre-commit 25 | shell: sh 26 | run: | 27 | uv pip install pre-commit 28 | 29 | - name: Get swift-format tag 30 | id: swift-format 31 | shell: sh 32 | run: | 33 | cd /tmp 34 | LATEST_TAG=$(curl -s https://api.github.com/repos/swiftlang/swift-format/releases/latest | \ 35 | grep '"tag_name":' | \ 36 | sed -E 's/.*"([^"]+)".*/\1/') 37 | echo "swift-format $LATEST_TAG" 38 | echo "SWIFT_FORMAT_VERSION=$LATEST_TAG" >> $GITHUB_OUTPUT 39 | 40 | - name: Cache swift-format build 41 | uses: actions/cache@v4 42 | id: cache-swift-format 43 | with: 44 | path: /tmp/swift-format/.build 45 | key: ${{ runner.os }}-swift-format-build-${{ steps.swift-format.outputs.SWIFT_FORMAT_VERSION }} 46 | 47 | - name: Build swift-format 48 | if: steps.cache-swift-format.outputs.cache-hit != 'true' 49 | shell: sh 50 | run: | 51 | cd /tmp 52 | git clone --branch ${{ steps.swift-format.outputs.SWIFT_FORMAT_VERSION }} --depth 1 https://github.com/swiftlang/swift-format.git 53 | cd swift-format 54 | swift build -c release 55 | 56 | - name: Link swift-format to /usr/local/bin 57 | shell: sh 58 | run: | 59 | cd /tmp/swift-format 60 | ln -s "$(swift build --show-bin-path -c release)/swift-format" /usr/local/bin/swift-format 61 | 62 | - name: Configure safe directory for git 63 | shell: sh 64 | run: | 65 | git config --global --add safe.directory "$GITHUB_WORKSPACE" 66 | 67 | - name: Run style checks 68 | shell: sh 69 | run: | 70 | pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1) 71 | 72 | 73 | mac_build_and_test: 74 | needs: lint 75 | if: github.repository == 'ml-explore/mlx-swift-examples' 76 | runs-on: [self-hosted, macos] 77 | steps: 78 | - uses: actions/checkout@v6 79 | with: 80 | submodules: recursive 81 | 82 | - name: Verify MetalToolchain installed 83 | shell: bash 84 | run: xcodebuild -showComponent MetalToolchain 85 | 86 | - name: Build Package (Xcode, macOS) 87 | shell: sh 88 | run: | 89 | xcodebuild -version 90 | xcrun --show-sdk-build-version 91 | swift --version 92 | rm -rf ~/Library/Developer/Xcode/DerivedData/* 93 | xcodebuild build-for-testing -scheme mlx-libraries-Package -destination 'platform=macOS' 94 | 95 | - name: Build tools (Xcode, macOS) 96 | shell: sh 97 | run: | 98 | xcodebuild -version 99 | xcrun --show-sdk-build-version 100 | swift --version 101 | find . -name Package.resolved -exec rm {} \; 102 | xcodebuild -scheme llm-tool 103 | xcodebuild -scheme image-tool 104 | xcodebuild -scheme mnist-tool 105 | -------------------------------------------------------------------------------- /Applications/MLXChatExample/Views/MediaPreviewView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MediaPreviewView.swift 3 | // MLXChatExample 4 | // 5 | // Created by İbrahim Çetin on 21.04.2025. 6 | // 7 | 8 | import AVFoundation 9 | import AVKit 10 | import SwiftUI 11 | 12 | /// A view that displays a horizontal scrollable list of media previews (images and videos). 
13 | struct MediaPreviewsView: View { 14 | /// The media selection containing arrays of image and video URLs 15 | let mediaSelection: MediaSelection 16 | 17 | var body: some View { 18 | ScrollView(.horizontal) { 19 | HStack(spacing: 8) { 20 | // Display image previews 21 | ForEach(mediaSelection.images, id: \.self) { imageURL in 22 | MediaPreviewView( 23 | mediaURL: imageURL, 24 | type: .image, 25 | onRemove: { 26 | mediaSelection.images.removeAll(where: { $0 == imageURL }) 27 | } 28 | ) 29 | } 30 | 31 | // Display video previews 32 | ForEach(mediaSelection.videos, id: \.self) { videoURL in 33 | MediaPreviewView( 34 | mediaURL: videoURL, 35 | type: .video, 36 | onRemove: { 37 | mediaSelection.videos.removeAll(where: { $0 == videoURL }) 38 | } 39 | ) 40 | } 41 | } 42 | .padding(.horizontal) 43 | } 44 | .padding(.top) 45 | } 46 | } 47 | 48 | /// A view that displays a single media item (image or video) with a remove button. 49 | struct MediaPreviewView: View { 50 | /// URL of the media file to display 51 | let mediaURL: URL 52 | /// Type of media (image or video) 53 | let type: MediaPreviewType 54 | /// Callback to handle removal of the media item 55 | let onRemove: () -> Void 56 | 57 | var body: some View { 58 | ZStack(alignment: .topTrailing) { 59 | switch type { 60 | case .image: 61 | // Display image with loading placeholder 62 | AsyncImage(url: mediaURL) { image in 63 | image 64 | .resizable() 65 | .scaledToFit() 66 | .frame(height: 100) 67 | .clipShape(RoundedRectangle(cornerRadius: 8)) 68 | } placeholder: { 69 | ProgressView() 70 | .frame(width: 150, height: 100) 71 | } 72 | case .video: 73 | // Display video player 74 | VideoPlayer(player: AVPlayer(url: mediaURL)) 75 | .frame(width: 150, height: 100) 76 | .clipShape(RoundedRectangle(cornerRadius: 8)) 77 | } 78 | 79 | RemoveButton(action: onRemove) 80 | } 81 | } 82 | } 83 | 84 | /// A button for removing media items from the preview. 85 | struct RemoveButton: View { 86 | /// Action to perform when the remove button is tapped 87 | let action: () -> Void 88 | 89 | var body: some View { 90 | Button(action: action) { 91 | Image(systemName: "xmark.circle.fill") 92 | .foregroundStyle(.secondary) 93 | .imageScale(.large) 94 | } 95 | .buttonStyle(.plain) 96 | .padding(4) 97 | } 98 | } 99 | 100 | extension MediaPreviewView { 101 | /// Defines the type of media that can be displayed in the preview 102 | enum MediaPreviewType { 103 | /// An image file 104 | case image 105 | /// A video file 106 | case video 107 | } 108 | } 109 | 110 | #Preview("Remove Button") { 111 | RemoveButton {} 112 | } 113 | -------------------------------------------------------------------------------- /Applications/MNISTTrainer/PredictionView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // PredictionView.swift 3 | // MNISTTrainer 4 | // 5 | // Created by Rounak Jain on 3/9/24. 6 | // 7 | 8 | import MLX 9 | import MLXMNIST 10 | import MLXNN 11 | import SwiftUI 12 | 13 | struct Canvas: View { 14 | 15 | @Binding var path: Path 16 | @State var lastPoint: CGPoint? 
17 | 18 | var body: some View { 19 | path 20 | .stroke(.white, lineWidth: 10) 21 | .background(.black) 22 | .gesture( 23 | DragGesture(minimumDistance: 0.05) 24 | .onChanged { touch in 25 | add(point: touch.location) 26 | } 27 | .onEnded { touch in 28 | lastPoint = nil 29 | } 30 | ) 31 | } 32 | 33 | func add(point: CGPoint) { 34 | var newPath = path 35 | if let lastPoint { 36 | newPath.move(to: lastPoint) 37 | newPath.addLine(to: point) 38 | } else { 39 | newPath.move(to: point) 40 | } 41 | self.path = newPath 42 | lastPoint = point 43 | } 44 | } 45 | 46 | extension Path { 47 | mutating func center(to newMidPoint: CGPoint) { 48 | let middleX = boundingRect.midX 49 | let middleY = boundingRect.midY 50 | self = offsetBy(dx: newMidPoint.x - middleX, dy: newMidPoint.y - middleY) 51 | } 52 | } 53 | 54 | struct PredictionView: View { 55 | @State var path: Path = Path() 56 | @State var prediction: Int? 57 | let model: LeNetContainer 58 | let canvasSize = 150.0 59 | let mnistImageSize: CGSize = CGSize(width: 28, height: 28) 60 | 61 | var body: some View { 62 | VStack { 63 | if let prediction { 64 | Text("You've drawn a \(prediction)") 65 | } else { 66 | Text("Draw a digit") 67 | } 68 | Canvas(path: $path) 69 | .frame(width: canvasSize, height: canvasSize) 70 | HStack { 71 | Button("Predict") { 72 | path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2)) 73 | predict() 74 | } 75 | Button("Clear") { 76 | path = Path() 77 | prediction = nil 78 | } 79 | } 80 | } 81 | } 82 | 83 | @MainActor 84 | func predict() { 85 | let imageRenderer = ImageRenderer( 86 | content: Canvas(path: $path).frame(width: 150, height: 150)) 87 | 88 | if let image = imageRenderer.cgImage { 89 | Task { 90 | self.prediction = await model.evaluate(image: image) 91 | } 92 | } 93 | } 94 | } 95 | 96 | extension CGImage { 97 | func grayscaleImage(with newSize: CGSize) -> CGImage? { 98 | let colorSpace = CGColorSpaceCreateDeviceGray() 99 | let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue) 100 | 101 | guard 102 | let context = CGContext( 103 | data: nil, 104 | width: Int(newSize.width), 105 | height: Int(newSize.height), 106 | bitsPerComponent: 8, 107 | bytesPerRow: Int(newSize.width), 108 | space: colorSpace, 109 | bitmapInfo: bitmapInfo.rawValue) 110 | else { 111 | return nil 112 | } 113 | context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.height)) 114 | return context.makeImage() 115 | } 116 | 117 | func pixelData() -> MLXArray { 118 | guard let data = self.dataProvider?.data else { 119 | return [] 120 | } 121 | let bytePtr = CFDataGetBytePtr(data) 122 | let count = CFDataGetLength(data) 123 | return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count)) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/ContentView.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc.
2 | 3 | import AsyncAlgorithms 4 | import MLX 5 | import MLXLLM 6 | import MLXLMCommon 7 | import Metal 8 | import SwiftUI 9 | import Tokenizers 10 | 11 | struct ContentView: View { 12 | @Environment(DeviceStat.self) private var deviceStat 13 | 14 | @State var llm = LLMEvaluator() 15 | 16 | enum DisplayStyle: String, CaseIterable, Identifiable { 17 | case plain, markdown 18 | var id: Self { self } 19 | } 20 | 21 | @State private var selectedDisplayStyle = DisplayStyle.markdown 22 | @State private var showingPresetPrompts = false 23 | @State private var isPromptExpanded = false 24 | 25 | var body: some View { 26 | VStack(alignment: .leading, spacing: 0) { 27 | // Header Section 28 | HeaderView( 29 | llm: llm, 30 | selectedDisplayStyle: $selectedDisplayStyle 31 | ) 32 | 33 | Divider() 34 | .padding(.bottom, 12) 35 | 36 | // Output display 37 | OutputView( 38 | output: llm.output, 39 | displayStyle: selectedDisplayStyle, 40 | wasTruncated: llm.wasTruncated 41 | ) 42 | 43 | // Prompt input section 44 | PromptInputView( 45 | llm: llm, 46 | isPromptExpanded: $isPromptExpanded, 47 | showingPresetPrompts: $showingPresetPrompts, 48 | onGenerate: generate, 49 | onCancel: cancel 50 | ) 51 | 52 | // Performance Metrics Panel 53 | MetricsView( 54 | tokensPerSecond: llm.tokensPerSecond, 55 | timeToFirstToken: llm.timeToFirstToken, 56 | promptLength: llm.promptLength, 57 | totalTokens: llm.totalTokens, 58 | totalTime: llm.totalTime, 59 | memoryUsed: deviceStat.gpuUsage.activeMemory, 60 | cacheMemory: deviceStat.gpuUsage.cacheMemory, 61 | peakMemory: deviceStat.gpuUsage.peakMemory 62 | ) 63 | } 64 | #if os(visionOS) 65 | .padding(40) 66 | #else 67 | .padding() 68 | #endif 69 | .toolbar { 70 | ToolbarItem(placement: .primaryAction) { 71 | Button { 72 | Task { 73 | copyToClipboard(llm.output) 74 | } 75 | } label: { 76 | Label("Copy Output", systemImage: "doc.on.doc.fill") 77 | } 78 | .disabled(llm.output == "") 79 | .labelStyle(.titleAndIcon) 80 | } 81 | 82 | } 83 | .task { 84 | do { 85 | // pre-load the weights on launch to speed up the first generation 86 | _ = try await llm.load() 87 | } catch { 88 | llm.output = "Failed: \(error)" 89 | } 90 | } 91 | .sheet(isPresented: $showingPresetPrompts) { 92 | PresetPromptsSheet(isPresented: $showingPresetPrompts) { preset in 93 | llm.prompt = preset.prompt 94 | llm.includeWeatherTool = preset.enableTools 95 | llm.enableThinking = preset.enableThinking 96 | } 97 | } 98 | .overlay { 99 | if llm.isLoading { 100 | LoadingOverlayView( 101 | modelInfo: llm.modelInfo, 102 | downloadProgress: llm.downloadProgress, 103 | progressDescription: llm.totalSize 104 | ) 105 | } 106 | } 107 | } 108 | 109 | private func generate() { 110 | llm.generate() 111 | } 112 | 113 | private func cancel() { 114 | llm.cancelGeneration() 115 | } 116 | 117 | private func copyToClipboard(_ string: String) { 118 | #if os(macOS) 119 | NSPasteboard.general.clearContents() 120 | NSPasteboard.general.setString(string, forType: .string) 121 | #else 122 | UIPasteboard.general.string = string 123 | #endif 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /Applications/LLMEval/Views/PresetPromptsSheet.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2025 Apple Inc. 
2 | 3 | import SwiftUI 4 | 5 | struct PresetPromptsSheet: View { 6 | @Binding var isPresented: Bool 7 | let onSelect: (PresetPrompt) -> Void 8 | 9 | var body: some View { 10 | NavigationStack { 11 | List { 12 | ForEach(PresetPrompts.all) { preset in 13 | Button { 14 | onSelect(preset) 15 | isPresented = false 16 | } label: { 17 | HStack(alignment: .center, spacing: 12) { 18 | // Show the actual prompt (first 2 lines) 19 | // Clean up whitespace for better preview 20 | let cleanedPrompt = preset.prompt 21 | .trimmingCharacters(in: .whitespacesAndNewlines) 22 | .replacingOccurrences( 23 | of: #"\n\n+"#, with: " ", options: .regularExpression 24 | ) 25 | .replacingOccurrences( 26 | of: #"\s+"#, with: " ", options: .regularExpression) 27 | 28 | Text(cleanedPrompt) 29 | .multilineTextAlignment(.leading) 30 | .lineLimit(2) 31 | .font(.body) 32 | .foregroundStyle(.primary) 33 | .frame(maxWidth: .infinity, alignment: .leading) 34 | 35 | // Show indicators if present 36 | if preset.enableThinking || preset.enableTools || preset.isLongPrompt { 37 | HStack(spacing: 6) { 38 | if preset.enableThinking { 39 | BadgeView(icon: "brain", text: "Thinking", color: .purple) 40 | } 41 | if preset.enableTools { 42 | BadgeView(icon: "hammer.fill", text: "Tools", color: .blue) 43 | } 44 | if preset.isLongPrompt { 45 | BadgeView( 46 | icon: "doc.text.fill", text: "Long", color: .orange) 47 | } 48 | } 49 | } 50 | } 51 | .padding(.vertical, 8) 52 | #if os(macOS) 53 | .frame(maxWidth: .infinity, alignment: .leading) 54 | .contentShape(Rectangle()) 55 | #endif 56 | } 57 | .buttonStyle(.plain) 58 | #if os(macOS) 59 | .listRowInsets(EdgeInsets(top: 8, leading: 12, bottom: 8, trailing: 12)) 60 | #endif 61 | } 62 | } 63 | #if os(macOS) 64 | .listStyle(.inset) 65 | #endif 66 | .navigationTitle("Example Prompts") 67 | #if !os(macOS) 68 | .navigationBarTitleDisplayMode(.inline) 69 | #endif 70 | .toolbar { 71 | ToolbarItem(placement: .cancellationAction) { 72 | Button("Close") { 73 | isPresented = false 74 | } 75 | } 76 | } 77 | } 78 | #if os(macOS) 79 | .frame(minWidth: 600, minHeight: 500) 80 | #endif 81 | } 82 | } 83 | 84 | // Badge component 85 | private struct BadgeView: View { 86 | let icon: String 87 | let text: String 88 | let color: Color 89 | 90 | var body: some View { 91 | Label(text, systemImage: icon) 92 | .font(.caption2) 93 | .fontWeight(.medium) 94 | .foregroundStyle(.white) 95 | .padding(.horizontal, 8) 96 | .padding(.vertical, 4) 97 | .background(color.gradient, in: Capsule()) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "0dd305f09ba278569849d2b91cb2ef7a95956604f60ef1c3013f0e7b4b61d1a7", 3 | "pins" : [ 4 | { 5 | "identity" : "gzipswift", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/1024jp/GzipSwift", 8 | "state" : { 9 | "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05", 10 | "version" : "6.0.1" 11 | } 12 | }, 13 | { 14 | "identity" : "mlx-swift", 15 | "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/ml-explore/mlx-swift", 17 | "state" : { 18 | "revision" : "072b684acaae80b6a463abab3a103732f33774bf", 19 | "version" : "0.29.1" 20 | } 21 | }, 22 | { 23 | "identity" : "mlx-swift-lm", 24 | "kind" : "remoteSourceControl", 25 | "location" : "https://github.com/ml-explore/mlx-swift-lm", 26 | "state" : { 27 | 
"revision" : "01852971866b17889f7aff27500cc62d9148b0df", 28 | "version" : "2.29.2" 29 | } 30 | }, 31 | { 32 | "identity" : "networkimage", 33 | "kind" : "remoteSourceControl", 34 | "location" : "https://github.com/gonzalezreal/NetworkImage", 35 | "state" : { 36 | "revision" : "2849f5323265386e200484b0d0f896e73c3411b9", 37 | "version" : "6.0.1" 38 | } 39 | }, 40 | { 41 | "identity" : "progress.swift", 42 | "kind" : "remoteSourceControl", 43 | "location" : "https://github.com/jkandzi/Progress.swift", 44 | "state" : { 45 | "revision" : "fed6598735d7982058690acf8f52a0a5fdaeb3e0", 46 | "version" : "0.4.0" 47 | } 48 | }, 49 | { 50 | "identity" : "swift-argument-parser", 51 | "kind" : "remoteSourceControl", 52 | "location" : "https://github.com/apple/swift-argument-parser.git", 53 | "state" : { 54 | "revision" : "cdd0ef3755280949551dc26dee5de9ddeda89f54", 55 | "version" : "1.6.2" 56 | } 57 | }, 58 | { 59 | "identity" : "swift-async-algorithms", 60 | "kind" : "remoteSourceControl", 61 | "location" : "https://github.com/apple/swift-async-algorithms.git", 62 | "state" : { 63 | "revision" : "042e1c4d9d19748c9c228f8d4ebc97bb1e339b0b", 64 | "version" : "1.0.4" 65 | } 66 | }, 67 | { 68 | "identity" : "swift-cmark", 69 | "kind" : "remoteSourceControl", 70 | "location" : "https://github.com/swiftlang/swift-cmark", 71 | "state" : { 72 | "revision" : "b97d09472e847a416629f026eceae0e2afcfad65", 73 | "version" : "0.7.0" 74 | } 75 | }, 76 | { 77 | "identity" : "swift-collections", 78 | "kind" : "remoteSourceControl", 79 | "location" : "https://github.com/apple/swift-collections.git", 80 | "state" : { 81 | "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e", 82 | "version" : "1.3.0" 83 | } 84 | }, 85 | { 86 | "identity" : "swift-jinja", 87 | "kind" : "remoteSourceControl", 88 | "location" : "https://github.com/huggingface/swift-jinja.git", 89 | "state" : { 90 | "revision" : "c1ef5963ba4a97a589b9c9583ff4ee3352a86d23", 91 | "version" : "2.1.0" 92 | } 93 | }, 94 | { 95 | "identity" : "swift-markdown-ui", 96 | "kind" : "remoteSourceControl", 97 | "location" : "https://github.com/gonzalezreal/swift-markdown-ui", 98 | "state" : { 99 | "revision" : "5f613358148239d0292c0cef674a3c2314737f9e", 100 | "version" : "2.4.1" 101 | } 102 | }, 103 | { 104 | "identity" : "swift-numerics", 105 | "kind" : "remoteSourceControl", 106 | "location" : "https://github.com/apple/swift-numerics", 107 | "state" : { 108 | "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", 109 | "version" : "1.1.1" 110 | } 111 | }, 112 | { 113 | "identity" : "swift-transformers", 114 | "kind" : "remoteSourceControl", 115 | "location" : "https://github.com/huggingface/swift-transformers", 116 | "state" : { 117 | "revision" : "94610577e4af9bbc267060af1e25e977604dd796", 118 | "version" : "1.1.1" 119 | } 120 | } 121 | ], 122 | "version" : 3 123 | } 124 | -------------------------------------------------------------------------------- /Libraries/StableDiffusion/Sampler.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 2 | 3 | import Foundation 4 | import MLX 5 | 6 | // port of https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/sampler.py 7 | 8 | /// Interpolate the function defined by `(0 ..< y.count) y)` at positions `xNew`. 
9 | func interpolate(y: MLXArray, xNew: MLXArray) -> MLXArray { 10 | let xLow = xNew.asType(.int32) 11 | let xHigh = minimum(xLow + 1, y.count - 1) 12 | 13 | let yLow = y[xLow] 14 | let yHigh = y[xHigh] 15 | let deltaX = xNew - xLow 16 | let yNew = yLow * (1 - deltaX) + deltaX * yHigh 17 | 18 | return yNew 19 | } 20 | 21 | /// A simple Euler integrator that can be used to sample from our diffusion models. 22 | /// 23 | /// The method ``step()`` performs one Euler step from `x_t` to `x_t_prev`. 24 | class SimpleEulerSampler { 25 | 26 | let sigmas: MLXArray 27 | 28 | public init(configuration: DiffusionConfiguration) { 29 | let betas: MLXArray 30 | 31 | // compute the noise schedule 32 | switch configuration.betaSchedule { 33 | case .linear: 34 | betas = MLXArray.linspace( 35 | configuration.betaStart, configuration.betaEnd, count: configuration.trainSteps) 36 | case .scaledLinear: 37 | betas = MLXArray.linspace( 38 | sqrt(configuration.betaStart), sqrt(configuration.betaEnd), 39 | count: configuration.trainSteps 40 | ).square() 41 | } 42 | 43 | let alphas = 1 - betas 44 | let alphasCumprod = cumprod(alphas) 45 | 46 | self.sigmas = concatenated([ 47 | MLXArray.zeros([1]), ((1 - alphasCumprod) / alphasCumprod).sqrt(), 48 | ]) 49 | } 50 | 51 | public var maxTime: Int { 52 | sigmas.count - 1 53 | } 54 | 55 | public func samplePrior(shape: [Int], dType: DType = .float32, key: MLXArray? = nil) -> MLXArray 56 | { 57 | let noise = MLXRandom.normal(shape, key: key) 58 | return (noise * sigmas[-1] * (sigmas[-1].square() + 1).rsqrt()).asType(dType) 59 | } 60 | 61 | public func addNoise(x: MLXArray, t: MLXArray, key: MLXArray? = nil) -> MLXArray { 62 | let noise = MLXRandom.normal(x.shape, key: key) 63 | let s = sigmas(t) 64 | return (x + noise * s) * (s.square() + 1).rsqrt() 65 | } 66 | 67 | public func sigmas(_ t: MLXArray) -> MLXArray { 68 | interpolate(y: sigmas, xNew: t) 69 | } 70 | 71 | public func timeSteps(steps: Int, start: Int? = nil, dType: DType = .float32) -> [( 72 | MLXArray, MLXArray 73 | )] { 74 | let start = start ?? 
(sigmas.count - 1) 75 | precondition(0 < start) 76 | precondition(start <= sigmas.count - 1) 77 | let steps = MLX.linspace(start, 0, count: steps + 1).asType(dType) 78 | 79 | return Array(zip(steps, steps[1...])) 80 | } 81 | 82 | open func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray) -> MLXArray { 83 | let dtype = epsPred.dtype 84 | let sigma = sigmas(t).asType(dtype) 85 | let sigmaPrev = sigmas(tPrev).asType(dtype) 86 | 87 | let dt = sigmaPrev - sigma 88 | var xtPrev = (sigma.square() + 1).sqrt() * xt + epsPred * dt 89 | xtPrev = xtPrev * (sigmaPrev.square() + 1).rsqrt() 90 | 91 | return xtPrev 92 | } 93 | } 94 | 95 | class SimpleEulerAncestralSampler: SimpleEulerSampler { 96 | 97 | open override func step(epsPred: MLXArray, xt: MLXArray, t: MLXArray, tPrev: MLXArray) 98 | -> MLXArray 99 | { 100 | let dtype = epsPred.dtype 101 | let sigma = sigmas(t).asType(dtype) 102 | let sigmaPrev = sigmas(tPrev).asType(dtype) 103 | 104 | let sigma2 = sigma.square() 105 | let sigmaPrev2 = sigmaPrev.square() 106 | let sigmaUp = (sigmaPrev2 * (sigma2 - sigmaPrev2) / sigma2).sqrt() 107 | let sigmaDown = (sigmaPrev2 - sigmaUp ** 2).sqrt() 108 | 109 | let dt = sigmaDown - sigma 110 | var xtPrev = (sigma2 + 1).sqrt() * xt + epsPred * dt 111 | let noise = MLXRandom.normal(xtPrev.shape).asType(xtPrev.dtype) 112 | xtPrev = xtPrev + noise * sigmaUp 113 | xtPrev = xtPrev * (sigmaPrev2 + 1).rsqrt() 114 | 115 | return xtPrev 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/embedder-tool.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 45 | 47 | 53 | 54 | 55 | 56 | 59 | 60 | 63 | 64 | 67 | 68 | 71 | 72 | 75 | 76 | 79 | 80 | 81 | 82 | 88 | 90 | 96 | 97 | 98 | 99 | 101 | 102 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /Data/lora/wikisql.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Apple Inc. 2 | 3 | """ 4 | Code to preprocess the WikiSQL dataset adapted from 5 | https://github.com/salesforce/WikiSQL and 6 | https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb . 7 | """ 8 | 9 | 10 | import json 11 | import os 12 | 13 | 14 | def load(): 15 | """ 16 | Load all three splits of the WikiSQL dataset. 
17 | """ 18 | return (WikiSQL(dn) for dn in ["train", "dev", "test"]) 19 | 20 | 21 | class WikiSQL: 22 | def __init__(self, dataset, save_dir="/tmp"): 23 | valid_sets = ("train", "dev", "test") 24 | if dataset not in valid_sets: 25 | raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}") 26 | data_dir = os.path.join(save_dir, "wikisql") 27 | self._maybe_download(data_dir) 28 | 29 | self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl")) 30 | self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl")) 31 | 32 | def _maybe_download(self, data_dir): 33 | if not os.path.exists(data_dir): 34 | import io 35 | import tarfile 36 | from urllib import request 37 | 38 | url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2" 39 | r = request.urlopen(url) 40 | with tarfile.open(fileobj=io.BytesIO(r.read())) as tf: 41 | tf.extractall(data_dir) 42 | 43 | def _parse_tables(self, tables): 44 | self._tables = {} 45 | with open(tables) as f: 46 | for line in f: 47 | table = json.loads(line) 48 | self._tables[table["id"]] = { 49 | "columns": table["header"], 50 | "types": table["types"], 51 | "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}", 52 | } 53 | 54 | def _parse_queries(self, queries): 55 | self._queries = [] 56 | with open(queries) as f: 57 | for line in f: 58 | query = json.loads(line) 59 | table = self._tables[query["table_id"]] 60 | question = query["question"] 61 | answer = self.query_to_text( 62 | query["sql"], query["table_id"], table["columns"], table["types"] 63 | ) 64 | self._queries.append( 65 | f"{table['desc']}\nQ: {question}\nA: {answer}" 66 | ) 67 | 68 | def query_to_text(self, query, table, columns, types): 69 | aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"] 70 | condition_ops = ["=", ">", "<", "OP"] 71 | column = columns[query["sel"]] 72 | aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else "" 73 | sql = f"SELECT {aggregation}{column} FROM {table}" 74 | 75 | conditions = query["conds"] 76 | if conditions: 77 | cs = [] 78 | for i, o, v in conditions: 79 | column = columns[i] 80 | op = condition_ops[o] 81 | 82 | if types[i] == "text": 83 | value = f"'{v}'" 84 | else: 85 | value = v 86 | cs.append(f"{column} {op} {value}") 87 | 88 | sql += " WHERE " + " AND ".join(cs) 89 | 90 | return sql 91 | 92 | def __getitem__(self, idx): 93 | return self._queries[idx] 94 | 95 | def __len__(self): 96 | return len(self._queries) 97 | 98 | 99 | if __name__ == "__main__": 100 | datanames = ["train", "dev", "test"] 101 | sizes = [56355, 8421, 15878] 102 | for dataname, size in zip(datanames, sizes): 103 | len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size." 104 | 105 | # Write the sets to jsonl 106 | import json 107 | 108 | train, dev, test = load() 109 | datasets = [ 110 | (train, "train", 1000), 111 | (dev, "valid", 100), 112 | (test, "test", 100), 113 | ] 114 | for dataset, name, size in datasets: 115 | with open(f"data/{name}.jsonl", "w") as fid: 116 | for e, t in zip(range(size), dataset): 117 | # Strip the , since the tokenizer adds them 118 | json.dump({"text": t[3:-4]}, fid) 119 | fid.write("\n") 120 | -------------------------------------------------------------------------------- /Tools/LinearModelTraining/LinearModelTraining.swift: -------------------------------------------------------------------------------- 1 | // Copyright © 2024 Apple Inc. 
2 | 3 | import ArgumentParser 4 | import Foundation 5 | import MLX 6 | import MLXNN 7 | import MLXOptimizers 8 | 9 | #if swift(>=5.10) 10 | extension MLX.DeviceType: @retroactive ExpressibleByArgument { 11 | public init?(argument: String) { 12 | self.init(rawValue: argument) 13 | } 14 | } 15 | #else 16 | extension MLX.DeviceType: ExpressibleByArgument { 17 | public init?(argument: String) { 18 | self.init(rawValue: argument) 19 | } 20 | } 21 | #endif 22 | 23 | @main 24 | struct Train: AsyncParsableCommand { 25 | 26 | @Option var epochs = 20 27 | @Option var batchSize = 8 28 | 29 | @Option var m: Float = 0.25 30 | @Option var b: Float = 7 31 | 32 | @Flag var compile = false 33 | 34 | @Option var device = DeviceType.cpu 35 | 36 | func run() async throws { 37 | Device.setDefault(device: Device(device)) 38 | 39 | // A very simple model that implements the equation 40 | // for a linear function: y = mx + b. This can be trained 41 | // to match data – in this case, an unknown (to the model) 42 | // linear function. 43 | // 44 | // This is a nice example because most people know how 45 | // linear functions work and we can see how the slope 46 | // and intercept converge. 47 | class LinearFunctionModel: Module, UnaryLayer { 48 | let m = MLXRandom.uniform(low: -5.0, high: 5.0) 49 | let b = MLXRandom.uniform(low: -5.0, high: 5.0) 50 | 51 | func callAsFunction(_ x: MLXArray) -> MLXArray { 52 | m * x + b 53 | } 54 | } 55 | 56 | // Measure the distance between the prediction (model(x)) and the 57 | // ground truth (y). This gives feedback on how close the 58 | // prediction is to matching the truth. 59 | func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray { 60 | mseLoss(predictions: model(x), targets: y, reduction: .mean) 61 | } 62 | 63 | let model = LinearFunctionModel() 64 | eval(model.parameters()) 65 | 66 | let lg = valueAndGrad(model: model, loss) 67 | 68 | // The optimizer will use the gradients to update the model parameters 69 | let optimizer = SGD(learningRate: 1e-1) 70 | 71 | // The function to train our model against. It doesn't have 72 | // to be linear, but matching the form the model implements 73 | // makes the example easy to understand. 74 | func f(_ x: MLXArray) -> MLXArray { 75 | // These are the target parameters 76 | let m = self.m 77 | let b = self.b 78 | 79 | // Our actual function 80 | return m * x + b 81 | } 82 | 83 | func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { 84 | let (loss, grads) = lg(model, x, y) 85 | optimizer.update(model: model, gradients: grads) 86 | return loss 87 | } 88 | 89 | let resolvedStep = 90 | self.compile 91 | ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step 92 | 93 | for _ in 0 ..< epochs { 94 | // We expect that the parameters will approach the targets 95 | print("target: b = \(b), m = \(m)") 96 | print("parameters: \(model.parameters())") 97 | 98 | // Generate random training data along with the ground truth. 99 | // Notice that the shape is [B, 1] where B is the batch 100 | // dimension. This allows us to train on several samples simultaneously. 101 | // 102 | // Note: a very large batch size will take longer to converge because 103 | // too many samples are averaged down into a single gradient 104 | // update for each parameter. 105 | let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1]) 106 | let y = f(x) 107 | eval(x, y) 108 | 109 | // Compute the loss and gradients. Use the optimizer 110 | // to adjust the parameters closer to the target.
111 |             let loss = resolvedStep(x, y)
112 | 
113 |             eval(model, optimizer)
114 | 
115 |             // We should see this converge toward 0
116 |             print("loss: \(loss)")
117 |         }
118 | 
119 |     }
120 | }
121 | 
--------------------------------------------------------------------------------
/Libraries/StableDiffusion/Tokenizer.swift:
--------------------------------------------------------------------------------
1 | // Copyright © 2024 Apple Inc.
2 | 
3 | import Foundation
4 | 
5 | struct Bigram: Hashable {
6 |     let a: String
7 |     let b: String
8 | 
9 |     init(_ s: String) {
10 |         let pieces = s.split(separator: " ")
11 |         precondition(pieces.count == 2, "Bigram expected two pieces for '\(s)'")
12 |         self.a = String(pieces[0])
13 |         self.b = String(pieces[1])
14 |     }
15 | 
16 |     init(_ a: String, _ b: String) {
17 |         self.a = a
18 |         self.b = b
19 |     }
20 | 
21 |     init(_ v: (String, String)) {
22 |         self.a = v.0
23 |         self.b = v.1
24 |     }
25 | }
26 | 
27 | /// A CLIP tokenizer.
28 | ///
29 | /// Ported from:
30 | ///
31 | /// - https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
32 | /// - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
33 | ///
34 | /// Ideally this would be a tokenizer from `swift-transformers` but this is too special purpose to be representable in
35 | /// what exists there (at time of writing).
36 | class CLIPTokenizer {
37 | 
38 |     let pattern =
39 |         #/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/#
40 |     let bpeRanks: [Bigram: Int]
41 |     let vocabulary: [String: Int]
42 | 
43 |     let bos = "<|startoftext|>"
44 |     let eos = "<|endoftext|>"
45 | 
46 |     let bosToken: Int
47 |     let eosToken: Int
48 | 
49 |     var cache = [String: [String]]()
50 | 
51 |     init(merges: [String], vocabulary: [String: Int]) {
52 |         self.bpeRanks = Dictionary(
53 |             uniqueKeysWithValues:
54 |                 merges
55 |                 .map { Bigram($0) }
56 |                 .enumerated()
57 |                 .map { ($0.element, $0.offset) })
58 | 
59 |         self.vocabulary = vocabulary
60 |         self.cache[bos] = [bos]
61 |         self.cache[eos] = [eos]
62 |         self.bosToken = vocabulary[bos]!
63 |         self.eosToken = vocabulary[eos]!
64 |     }
65 | 
66 |     func bpe(text: String) -> [String] {
67 |         if let result = cache[text] {
68 |             return result
69 |         }
70 | 
71 |         precondition(!text.isEmpty)
72 | 
73 |         var unigrams = text.dropLast().map { String($0) } + ["\(text.last!)</w>"]  // the final character carries CLIP's </w> end-of-word marker
74 |         var uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })
75 | 
76 |         // In every iteration try to merge the most likely (lowest-rank) bigram.
77 |         // If nothing can be merged we are done.
78 | 
79 |         while !uniqueBigrams.isEmpty {
80 |             let (bigram, _) =
81 |                 uniqueBigrams
82 |                 .map { ($0, bpeRanks[$0] ?? Int.max) }
83 |                 .min { $0.1 < $1.1 }!
84 | 
85 |             if bpeRanks[bigram] == nil {
86 |                 break
87 |             }
88 | 
89 |             var newUnigrams = [String]()
90 |             var skip = false
91 | 
92 |             for (a, b) in zip(unigrams, unigrams.dropFirst()) {
93 |                 if skip {
94 |                     skip = false
95 |                     continue
96 |                 }
97 | 
98 |                 if Bigram(a, b) == bigram {
99 |                     newUnigrams.append(a + b)
100 |                     skip = true
101 |                 } else {
102 |                     newUnigrams.append(a)
103 |                 }
104 |             }
105 | 
106 |             if !skip, let last = unigrams.last {
107 |                 newUnigrams.append(last)
108 |             }
109 | 
110 |             unigrams = newUnigrams
111 |             uniqueBigrams = Set(zip(unigrams, unigrams.dropFirst()).map { Bigram($0) })
112 |         }
113 | 
114 |         cache[text] = unigrams
115 | 
116 |         return unigrams
117 |     }
118 | 
119 |     public func tokenize(text: String) -> [Int32] {
120 |         // Lower case cleanup and split according to self.pattern. Hugging Face does
121 |         // a much more thorough job here but this should suffice for 95% of
122 |         // cases.
123 | 
124 |         let clean = text.lowercased().replacing(#/\s+/#, with: " ")
125 |         let tokens = clean.matches(of: pattern).map { $0.description }
126 | 
127 |         // Split the tokens according to the byte-pair merge file
128 |         let bpeTokens = tokens.flatMap { bpe(text: String($0)) }
129 | 
130 |         // Map to token ids and return
131 |         let result = [bosToken] + bpeTokens.compactMap { vocabulary[$0] } + [eosToken]
132 | 
133 |         return result.map { Int32($0) }
134 |     }
135 | }
136 | 
--------------------------------------------------------------------------------
/Applications/MLXChatExample/README.md:
--------------------------------------------------------------------------------
1 | # MLX Chat Example
2 | 
3 | A lightweight chat application demonstrating MLX integration for iOS and macOS. Built with SwiftUI, this example project shows how to implement both Large Language Models (LLMs) and Vision Language Models (VLMs) using MLX.
4 | 
5 | MLX Chat Example Screenshot
6 | 
7 | ## Features
8 | 
9 | - 🤖 LLM and VLM support with real-time text generation
10 | - 📱 Cross-platform (iOS and macOS)
11 | - 🖼️ Image and video input for vision models
12 | - 💾 Efficient model caching and memory management
13 | - ⚡️ Async/await based generation with cancellation support
14 | - 🎨 Modern SwiftUI interface
15 | - 📝 Comprehensive documentation and comments
16 | 
17 | ## Requirements
18 | 
19 | - iOS 17.0+ / macOS 14.0+
20 | - Xcode 15.0+
21 | - Swift 5.9+
22 | 
23 | ## Dependencies
24 | 
25 | - [MLX](https://github.com/ml-explore/mlx-swift): Core machine learning operations
26 | - [MLXLMCommon](https://github.com/ml-explore/mlx-swift-lm/tree/main/Libraries/MLXLMCommon): Common language model utilities
27 | - [MLXLLM](https://github.com/ml-explore/mlx-swift-lm/tree/main/Libraries/MLXLLM): Large language model support
28 | - [MLXVLM](https://github.com/ml-explore/mlx-swift-lm/tree/main/Libraries/MLXVLM): Vision-language model support
29 | 
30 | ## Project Structure
31 | 
32 | ```
33 | MLXChatExample/
34 | ├── Views/          # SwiftUI views
35 | ├── Models/         # Data models
36 | ├── ViewModels/     # Business logic
37 | ├── Services/       # MLX integration
38 | └── Support/        # Utilities
39 | ```
40 | 
41 | ## Technical Overview
42 | 
43 | The project follows MVVM architecture with clear separation between UI and business logic. The core functionality is split into two main components:
44 | 
45 | ### MLXService
46 | 
47 | Core service handling all model operations:
48 | - Model loading and caching with memory management
49 | - Async text generation with streaming support
50 | - GPU memory optimization
51 | - Model state management
52 | - Handles both LLM and VLM model types
53 | 
54 | ### ChatViewModel
55 | 
56 | Business logic coordinator:
57 | - Manages chat state and message history
58 | - Handles generation lifecycle and cancellation (see the sketch at the end of this README)
59 | - Coordinates media attachments for vision models
60 | - Performance metrics and error handling
61 | - Provides a clean interface between UI and ML service
62 | 
63 | ### Architecture Highlights
64 | 
65 | - Complete separation of UI and business logic
66 | - SwiftUI views with async/await integration
67 | - Modular design for easy extension
68 | 
69 | ### Documentation
70 | 
71 | The codebase is thoroughly documented with:
72 | - Detailed class and method documentation
73 | - Clear inline comments explaining complex logic
74 | - DocC documentation format
75 | 
76 | ### Markdown Support
77 | 
78 | This sample app renders markdown content using SwiftUI's native `Text` view by passing the content as a `LocalizedStringKey`:
79 | 
80 | ```swift
81 | Text(LocalizedStringKey(message.content))
82 | ```
83 | 
84 | #### Limitations and Alternatives
85 | 
86 | The default SwiftUI markdown rendering only supports standard markdown syntax. It does not support advanced features like tables and task lists that are available in GitHub Flavored Markdown (GFM).
87 | 
88 | For more comprehensive markdown support:
89 | 
90 | - **GitHub Flavored Markdown**: Consider using the [swift-markdown-ui](https://github.com/gonzalezreal/swift-markdown-ui) library. However, be aware that this library currently has an [unresolved issue with text selection](https://github.com/gonzalezreal/swift-markdown-ui/issues/264), which is why it wasn't used in this example.
91 | 
92 | - **Enhanced Text Selection**: If you're satisfied with standard markdown but want better text selection capabilities on iOS (instead of only being able to select and copy the entire content block), consider combining:
93 |   - [SelectableText](https://github.com/kevinhermawan/SelectableText) for improved selection functionality
94 |   - [MarkdownToAttributedString](https://github.com/madebywindmill/MarkdownToAttributedString) for markdown formatting
95 | 
96 | > More discussion on this can be found in [issue #297](https://github.com/ml-explore/mlx-swift-examples/issues/297)
97 | 
98 | ## Getting Started
99 | 
100 | 1. Clone the repository
101 | 2. Install dependencies
102 | 3. Open in Xcode
103 | 4. Build and run
104 | 
105 | ## Contributing
106 | 
107 | This is an example project demonstrating MLX capabilities. Feel free to use it as a reference for your own projects.
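## Generation Lifecycle Sketch

To make the cancellation pattern described under ChatViewModel concrete, here is a minimal, hypothetical sketch. `TokenStreamingService` and its `generate(prompt:)` method are stand-ins invented for illustration; the actual API lives in `Services/` and may differ.

```swift
import Foundation

// Stand-in protocol for the app's MLX-backed service (illustrative only).
protocol TokenStreamingService {
    func generate(prompt: String) -> AsyncStream<String>
}

@MainActor
final class GenerationController {
    private var generationTask: Task<Void, Never>?
    private(set) var output = ""

    func start(prompt: String, service: TokenStreamingService) {
        // Cancel any in-flight generation before starting a new one.
        generationTask?.cancel()
        output = ""
        generationTask = Task {
            for await token in service.generate(prompt: prompt) {
                if Task.isCancelled { break }
                // Append each streamed token so the UI can observe partial output.
                output += token
            }
        }
    }

    func cancel() {
        generationTask?.cancel()
        generationTask = nil
    }
}
```

The key design point is owning a single `Task` handle: starting a new generation first cancels the previous one, and the stream loop checks `Task.isCancelled` so a cancelled generation stops promptly.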
108 | 
109 | ## Acknowledgments
110 | 
111 | - MLX team for the core framework
112 | 
--------------------------------------------------------------------------------
/Tests/MLXLMTests/UserInputTests.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 | import MLXLMCommon
4 | import MLXVLM
5 | import XCTest
6 | 
7 | func assertEqual(
8 |     _ v1: Any, _ v2: Any, path: [String] = [], file: StaticString = #filePath, line: UInt = #line
9 | ) {
10 |     switch (v1, v2) {
11 |     case let (v1, v2) as (String, String):
12 |         XCTAssertEqual(v1, v2, file: file, line: line)
13 | 
14 |     case let (v1, v2) as ([Any], [Any]):
15 |         XCTAssertEqual(
16 |             v1.count, v2.count, "Arrays not equal size at \(path)", file: file, line: line)
17 | 
18 |         for (index, (v1v, v2v)) in zip(v1, v2).enumerated() {
19 |             assertEqual(v1v, v2v, path: path + [index.description], file: file, line: line)
20 |         }
21 | 
22 |     case let (v1, v2) as ([String: Any], [String: Any]):
23 |         XCTAssertEqual(
24 |             v1.keys.sorted(), v2.keys.sorted(),
25 |             "\(String(describing: v1.keys.sorted())) and \(String(describing: v2.keys.sorted())) not equal at \(path)",
26 |             file: file, line: line)
27 | 
28 |         for (k, v1v) in v1 {
29 |             if let v2v = v2[k] {
30 |                 assertEqual(v1v, v2v, path: path + [k], file: file, line: line)
31 |             } else {
32 |                 XCTFail("Missing value for \(k) at \(path)", file: file, line: line)
33 |             }
34 |         }
35 |     default:
36 |         XCTFail(
37 |             "Unable to compare \(String(describing: v1)) and \(String(describing: v2)) at \(path)",
38 |             file: file, line: line)
39 |     }
40 | }
41 | 
42 | public class UserInputTests: XCTestCase {
43 | 
44 |     public func testStandardConversion() {
45 |         let chat: [Chat.Message] = [
46 |             .system("You are a useful agent."),
47 |             .user("Tell me a story."),
48 |         ]
49 | 
50 |         let messages = DefaultMessageGenerator().generate(messages: chat)
51 | 
52 |         let expected = [
53 |             [
54 |                 "role": "system",
55 |                 "content": "You are a useful agent.",
56 |             ],
57 |             [
58 |                 "role": "user",
59 |                 "content": "Tell me a story.",
60 |             ],
61 |         ]
62 | 
63 |         XCTAssertEqual(expected, messages as? [[String: String]])
64 |     }
65 | 
66 |     public func testQwen2ConversionText() {
67 |         let chat: [Chat.Message] = [
68 |             .system("You are a useful agent."),
69 |             .user("Tell me a story."),
70 |         ]
71 | 
72 |         let messages = Qwen2VLMessageGenerator().generate(messages: chat)
73 | 
74 |         let expected = [
75 |             [
76 |                 "role": "system",
77 |                 "content": [
78 |                     [
79 |                         "type": "text",
80 |                         "text": "You are a useful agent.",
81 |                     ]
82 |                 ],
83 |             ],
84 |             [
85 |                 "role": "user",
86 |                 "content": [
87 |                     [
88 |                         "type": "text",
89 |                         "text": "Tell me a story.",
90 |                     ]
91 |                 ],
92 |             ],
93 |         ]
94 | 
95 |         assertEqual(expected, messages)
96 |     }
97 | 
98 |     public func testQwen2ConversionImage() {
99 |         let chat: [Chat.Message] = [
100 |             .system("You are a useful agent."),
101 |             .user(
102 |                 "What is this?",
103 |                 images: [
104 |                     .url(
105 |                         URL(
106 |                             string: "https://opensource.apple.com/images/projects/mlx.f5c59d8b.png")!
107 |                     )
108 |                 ]),
109 |         ]
110 | 
111 |         let messages = Qwen2VLMessageGenerator().generate(messages: chat)
112 | 
113 |         let expected = [
114 |             [
115 |                 "role": "system",
116 |                 "content": [
117 |                     [
118 |                         "type": "text",
119 |                         "text": "You are a useful agent.",
120 |                     ]
121 |                 ],
122 |             ],
123 |             [
124 |                 "role": "user",
125 |                 "content": [
126 |                     [
127 |                         "type": "text",
128 |                         "text": "What is this?",
129 |                     ],
130 |                     [
131 |                         "type": "image"
132 |                     ],
133 |                 ],
134 |             ],
135 |         ]
136 | 
137 |         assertEqual(expected, messages)
138 | 
139 |         let userInput = UserInput(chat: chat)
140 |         XCTAssertEqual(userInput.images.count, 1)
141 |     }
142 | 
143 | }
144 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MLX Swift Examples
2 | 
3 | Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs. The language model
4 | examples use models implemented in [MLX Swift LM](https://github.com/ml-explore/mlx-swift-lm).
5 | 
6 | - [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
7 |   both iOS and macOS that downloads MNIST training data and trains a
8 |   [LeNet](https://en.wikipedia.org/wiki/LeNet).
9 | 
10 | - [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
11 |   and macOS that downloads an LLM and tokenizer from Hugging Face and
12 |   generates text from a given prompt.
13 | 
14 | - [VLMEval](Applications/VLMEval/README.md): An example that runs on iOS, macOS, and visionOS that downloads a VLM and tokenizer from Hugging Face and
15 |   analyzes a given image, describing it in text.
16 | 
17 | - [MLXChatExample](Applications/MLXChatExample/README.md): An example chat app that runs on both iOS and macOS that supports LLMs and VLMs.
18 | 
19 | - [LoRATrainingExample](Applications/LoRATrainingExample/README.md): An example that runs on macOS that downloads an LLM and fine-tunes it using LoRA (Low-Rank Adaptation) with training data.
20 | 
21 | - [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
22 |   trains a simple linear model.
23 | 
24 | - [StableDiffusionExample](Applications/StableDiffusionExample/README.md): An
25 |   example that runs on both iOS and macOS that downloads a stable diffusion model
26 |   from Hugging Face and generates an image from a given prompt.
27 | 
28 | - [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
29 |   using a variety of LLMs available on the Hugging Face hub.
30 | 
31 | - [ExampleLLM](Tools/ExampleLLM/README.md): A command line tool using the simplified API to interact with LLMs.
32 | 
33 | - [image-tool](Tools/image-tool/README.md): A command line tool for generating images
34 |   using a stable diffusion model from Hugging Face.
35 | 
36 | - [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training
37 |   a LeNet on MNIST.
38 | 
39 | > [!IMPORTANT]
40 | > `MLXLMCommon`, `MLXLLM`, `MLXVLM` and `MLXEmbedders` have moved to a new repository
41 | > containing _only_ reusable libraries: [mlx-swift-lm](https://github.com/ml-explore/mlx-swift-lm).
42 | 
43 | Previous URLs and tags will continue to work, but going forward all updates to these
44 | libraries will be done in the other repository. Previous tags _are_ supported in
45 | the new repository.
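To depend on the relocated libraries directly, a `Package.swift` entry might look like the sketch below. Note this is an assumption for illustration: the `package:` identity is presumed to match the repository name, so check the `mlx-swift-lm` README for the canonical form.

```swift
// In Package.swift dependencies (sketch):
.package(url: "https://github.com/ml-explore/mlx-swift-lm/", branch: "main"),

// And in a target's dependencies (product names as documented below):
.product(name: "MLXLLM", package: "mlx-swift-lm"),
```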
46 | 
47 | > [!TIP]
48 | > Contributors who wish to edit both `mlx-swift-examples` and `mlx-swift-lm` can
49 | > use [this technique in Xcode](https://developer.apple.com/documentation/xcode/editing-a-package-dependency-as-a-local-package).
50 | 
51 | 
52 | # Reusable Libraries
53 | 
54 | LLM and VLM implementations are available in [MLX Swift LM](https://github.com/ml-explore/mlx-swift-lm):
55 | 
56 | - [MLXLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon) -- common API for LLM and VLM
57 | - [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxllm) -- large language model example implementations
58 | - [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxvlm) -- vision language model example implementations
59 | - [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxembedders) -- example implementations of popular encoder / embedding models
60 | 
61 | MLX Swift Examples also contains a few reusable libraries that can be imported with this code in your `Package.swift` or by referencing the URL in Xcode:
62 | 
63 | ```swift
64 | .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main"),
65 | ```
66 | 
67 | Then add one or more libraries to the target as a dependency:
68 | 
69 | ```swift
70 | .target(
71 |     name: "YourTargetName",
72 |     dependencies: [
73 |         .product(name: "StableDiffusion", package: "mlx-libraries")
74 |     ]),
75 | ```
76 | 
77 | - [StableDiffusion](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/stablediffusion) -- SDXL Turbo and Stable Diffusion model example implementations
78 | - [MLXMNIST](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxmnist) -- MNIST implementation for all your digit recognition needs
79 | 
80 | ## Running
81 | 
82 | The application and command line tool examples can be run from Xcode or from
83 | the command line:
84 | 
85 | ```
86 | ./mlx-run llm-tool --prompt "swift programming language"
87 | ```
88 | 
89 | Note: `mlx-run` is a shell script that uses the Xcode command line tools to
90 | locate the built binaries. It is equivalent to running from Xcode itself.
91 | 
92 | See also:
93 | 
94 | - [MLX troubleshooting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/troubleshooting)
95 | 
--------------------------------------------------------------------------------