├── .github ├── FUNDING.yml ├── copilot-instructions.md └── workflows │ ├── docc.yml │ ├── test.yml │ └── update-dependencies.yml ├── .gitignore ├── .gitmodules ├── .swiftpm └── xcode │ └── package.xcworkspace │ └── contents.xcworkspacedata ├── Example ├── LocalLLMClientExample.xcodeproj │ ├── project.pbxproj │ └── project.xcworkspace │ │ └── contents.xcworkspacedata ├── LocalLLMClientExample │ ├── AI.swift │ ├── App.swift │ ├── Assets.xcassets │ │ ├── AccentColor.colorset │ │ │ └── Contents.json │ │ ├── AppIcon.appiconset │ │ │ └── Contents.json │ │ └── Contents.json │ ├── BottomBar.swift │ ├── ChatView.swift │ ├── ChatViewModel.swift │ ├── Configuration │ │ └── Build.xcconfig │ ├── Downloader.swift │ ├── Image+.swift │ └── LocalLLMClientExample.entitlements ├── Package.swift └── README.md ├── LICENSE ├── Package.swift ├── README.md ├── Sources ├── LocalLLMCLI │ └── command.swift ├── LocalLLMClient │ ├── Async+.swift │ ├── Docs.docc │ │ └── index.md │ ├── LLMClient.swift │ ├── LLMError.swift │ └── LLMInput.swift ├── LocalLLMClientLlama │ ├── Batch.swift │ ├── Context.swift │ ├── Decoder.swift │ ├── Generator.swift │ ├── LlamaAutoMessageDecoder.swift │ ├── LlamaChatMessageDecoder.swift │ ├── LlamaClient.swift │ ├── Logger.swift │ ├── Model.swift │ ├── Multimodal.swift │ ├── Parameter.swift │ ├── Resources │ │ └── Grammars │ │ │ └── json.gbnf │ ├── Sampler.swift │ ├── Token.swift │ ├── Utility.swift │ └── stb_image.swift ├── LocalLLMClientLlamaC │ ├── clip-impl.h │ ├── clip.cpp │ ├── ggml-cpp.h │ ├── include │ │ ├── clip.h │ │ ├── ggml-alloc.h │ │ ├── ggml-backend.h │ │ ├── ggml-cpu.h │ │ ├── ggml-opt.h │ │ ├── ggml.h │ │ ├── gguf.h │ │ ├── llama.h │ │ ├── mtmd-helper.h │ │ └── mtmd.h │ ├── miniaudio │ │ └── miniaudio.h │ ├── mtmd-audio.cpp │ ├── mtmd-audio.h │ ├── mtmd-helper.cpp │ ├── mtmd.cpp │ └── stb │ │ └── stb_image.h ├── LocalLLMClientMLX │ ├── Context.swift │ ├── MLXClient.swift │ ├── Parameter.swift │ └── Utility.swift └── LocalLLMClientUtility │ ├── Downloader.swift │ ├── FileDownloader.swift │ ├── Globs.swift │ ├── HuggingFaceAPI.swift │ ├── Lock.swift │ └── URL+.swift ├── Tests ├── LocalLLMClientLlamaTests │ ├── ContextTests.swift │ ├── LlamaChatMessageDecoderTests.swift │ ├── LocalLLMClientTests.swift │ └── ModelTests.swift ├── LocalLLMClientMLXTests │ ├── LocalLLMClientTests.swift │ └── ModelTests.swift └── LocalLLMClientUtilityTests │ ├── DownloaderTests.swift │ ├── FileDownloaderTests.swift │ ├── FilesMetadataTests.swift │ ├── HuggingFaceAPITests.swift │ ├── MockURLProtocol.swift │ └── URLExtensionTests.swift └── scripts ├── get_llama_version.sh ├── run_mlx.sh └── update_dependencies.sh /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: tattn 4 | patreon: tattn 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | thanks_dev: # Replace with a single thanks.dev username 15 | 
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 16 | -------------------------------------------------------------------------------- /.github/copilot-instructions.md: -------------------------------------------------------------------------------- 1 | We follow Apple's Swift API Design Guidelines for naming conventions and code structure. 2 | We use Swift Testing (import Testing) to write tests. Don't use XCTest. Swift Testing is already contained in the current Xcode version. 3 | We follow SOLID principles to ensure our code is modular and maintainable. 4 | For running tests on Apple platforms like macOS and iOS, we use xcodebuild like `xcodebuild test -scheme LocalLLMClient-Package -destination 'platform=macOS' -only-testing:/`. Don't use swift test like `swift test` for Apple platforms, as it does not support the same features and capabilities as xcodebuild. 5 | For running tests on other platforms, we use swift test like `swift test`. Don't use xcodebuild for non-Apple platforms, as it is not compatible with them. 6 | -------------------------------------------------------------------------------- /.github/workflows/docc.yml: -------------------------------------------------------------------------------- 1 | name: Generate and Deploy DocC 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | pages: write 11 | id-token: write 12 | 13 | concurrency: 14 | group: "pages" 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | generate-docc: 19 | runs-on: macos-15 20 | env: 21 | DEVELOPER_DIR: "/Applications/Xcode_16.4.app/Contents/Developer" 22 | steps: 23 | - uses: actions/checkout@v4 24 | with: 25 | submodules: recursive 26 | 27 | - name: Setup Pages 28 | uses: actions/configure-pages@v4 29 | 30 | - name: Build DocC 31 | run: | 32 | BUILD_DOCC=1 swift package --allow-writing-to-directory \ 33 | ./docs generate-documentation --output-path ./docs \ 34 | --enable-experimental-combined-documentation \ 35 | --enable-experimental-external-link-support \ 36 | --hosting-base-path "LocalLLMClient" \ 37 | --transform-for-static-hosting \ 38 | --target LocalLLMClient --target LocalLLMClientMLX \ 39 | --target LocalLLMClientUtility --target LocalLLMClientLlama 40 | 41 | - name: Upload artifact 42 | uses: actions/upload-pages-artifact@v3 43 | with: 44 | path: "./docs" 45 | 46 | deploy: 47 | environment: 48 | name: github-pages 49 | url: ${{ steps.deployment.outputs.page_url }} 50 | runs-on: ubuntu-latest 51 | needs: generate-docc 52 | steps: 53 | - name: Deploy to GitHub Pages 54 | id: deployment 55 | uses: actions/deploy-pages@v4 56 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: ["main"] 7 | paths: 8 | - "Sources/**" 9 | - "Tests/**" 10 | - "Package.swift" 11 | - "Example/**" 12 | pull_request: 13 | types: [opened, reopened, synchronize, ready_for_review] 14 | branches: ["main"] 15 | paths: 16 | - "Sources/**" 17 | - "Tests/**" 18 | - "Package.swift" 19 | - "Example/**" 20 | 21 | jobs: 22 | test: 23 | name: ${{ matrix.test-type }} Tests 24 | runs-on: macos-15 25 | strategy: 26 | matrix: 27 | test-type: [Llama, MLX] 28 | env: 29 | DEVELOPER_DIR: "/Applications/Xcode_16.4.app/Contents/Developer" 30 | TEST_RUNNER_GITHUB_MODEL_CACHE: "${{ github.workspace }}/model_cache" 31 | steps: 32 | - uses: 
actions/checkout@v4 33 | with: 34 | submodules: recursive 35 | 36 | - name: Cache model file 37 | id: cache-model 38 | uses: actions/cache@v4 39 | with: 40 | path: model_cache 41 | key: model_cache 42 | 43 | - name: Run ${{ matrix.test-type }} tests 44 | run: TEST_RUNNER_GITHUB_ACTIONS_TEST="${{ matrix.test-type }}" xcodebuild test -scheme LocalLLMClient-Package -destination 'platform=macOS' 45 | 46 | build-example: 47 | runs-on: macos-15 48 | needs: test 49 | env: 50 | DEVELOPER_DIR: "/Applications/Xcode_16.4.app/Contents/Developer" 51 | steps: 52 | - uses: actions/checkout@v4 53 | with: 54 | submodules: recursive 55 | 56 | - name: Build Example app for macOS 57 | run: | 58 | cd Example 59 | xcodebuild build -project LocalLLMClientExample.xcodeproj -scheme LocalLLMClientExample -destination 'platform=macOS' CODE_SIGN_IDENTITY="-" 60 | 61 | build-ubuntu-x86_64: 62 | runs-on: ubuntu-latest 63 | needs: test 64 | env: 65 | GITHUB_MODEL_CACHE: "${{ github.workspace }}/model_cache" 66 | steps: 67 | - uses: actions/checkout@v4 68 | with: 69 | submodules: recursive 70 | 71 | - name: Cache model file 72 | id: cache-model 73 | uses: actions/cache@v4 74 | with: 75 | path: model_cache 76 | key: model_cache 77 | 78 | - name: Install dependencies 79 | run: | 80 | sudo apt-get update 81 | sudo apt-get install -y libcurl4-openssl-dev 82 | 83 | - name: Install Swift 84 | shell: bash 85 | run: | 86 | curl -O "https://download.swift.org/swiftly/linux/swiftly-$(uname -m).tar.gz" && \ 87 | tar zxf "swiftly-$(uname -m).tar.gz" && \ 88 | ./swiftly init --assume-yes --no-modify-profile --skip-install --quiet-shell-followup && \ 89 | . ${SWIFTLY_HOME_DIR:-~/.local/share/swiftly}/env.sh && \ 90 | hash -r 91 | swiftly install 6.1 92 | echo "SWIFTLY_HOME_DIR=${SWIFTLY_HOME_DIR}" >>"${GITHUB_ENV}" 93 | echo "SWIFTLY_BIN_DIR=${SWIFTLY_BIN_DIR}" >>"${GITHUB_ENV}" 94 | echo "${SWIFTLY_BIN_DIR}" >>"${GITHUB_PATH}" 95 | 96 | - name: Download and extract llama.cpp binaries 97 | run: | 98 | LLAMA_VERSION=$(./scripts/get_llama_version.sh) 99 | mkdir -p ${{ github.workspace }}/lib 100 | 101 | # Download and extract llama.cpp binaries 102 | LLAMA_URL="https://github.com/ggml-org/llama.cpp/releases/download/${LLAMA_VERSION}/llama-${LLAMA_VERSION}-bin-ubuntu-x64.zip" 103 | echo "Downloading llama.cpp binaries from: $LLAMA_URL" 104 | curl -L $LLAMA_URL -o llama-bin.zip 105 | unzip -j llama-bin.zip "*.so" -d "${{ github.workspace }}/lib" 106 | ls -la ${{ github.workspace }}/lib 107 | 108 | - name: Build package 109 | run: LDFLAGS="-L${{ github.workspace }}/lib" swift build 110 | # LD_LIBRARY_PATH="$(pwd)/lib" ./.build/debug/localllm -m "${{ github.workspace }}/model_cache/huggingface/models/ggml-org/SmolVLM-256M-Instruct-GGUF/SmolVLM-256M-Instruct-Q8_0.gguf" "Hello" 111 | 112 | - name: Run tests 113 | run: LDFLAGS="-L${{ github.workspace }}/lib" LD_LIBRARY_PATH="${{ github.workspace }}/lib" swift test 114 | if: false # TODO: 115 | -------------------------------------------------------------------------------- /.github/workflows/update-dependencies.yml: -------------------------------------------------------------------------------- 1 | name: Update Dependencies 2 | 3 | on: 4 | # schedule: 5 | # - cron: '0 11 * * *' 6 | workflow_dispatch: 7 | 8 | jobs: 9 | update-dependencies: 10 | runs-on: macos-15 11 | env: 12 | DEVELOPER_DIR: "/Applications/Xcode_16.4.app/Contents/Developer" 13 | permissions: 14 | contents: write 15 | pull-requests: write 16 | steps: 17 | - name: Checkout repository 18 | uses: actions/checkout@v4 19 | with: 20 
| submodules: recursive 21 | 22 | - name: Get current datetime 23 | run: echo "DATE=$(date '+%Y%m%d%H%M%S')" >> $GITHUB_ENV 24 | 25 | - name: Run update_dependencies 26 | run: ./scripts/update_dependencies.sh 27 | env: 28 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 29 | 30 | - name: Check for changes 31 | id: git-check 32 | run: | 33 | git diff --exit-code || echo "changes=true" >> $GITHUB_OUTPUT 34 | 35 | - name: Create Pull Request 36 | if: steps.git-check.outputs.changes == 'true' 37 | uses: peter-evans/create-pull-request@v7 38 | with: 39 | token: ${{ secrets.GITHUB_TOKEN }} 40 | commit-message: 'auto: update llama.cpp dependency' 41 | title: '[auto] Update llama.cpp to latest version' 42 | body: | 43 | This PR updates the llama.cpp dependency to the latest version. 44 | 45 | - Automatically generated by the Update Dependencies workflow 46 | - Date: ${{ env.DATE }} 47 | branch: update-llama-dependency-${{ env.DATE }} 48 | draft: always-true 49 | base: main 50 | labels: | 51 | bot -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### https://raw.github.com/github/gitignore/cdd9e946da421758c6f42c427c7bc65c8326155d/Global/macOS.gitignore 2 | 3 | # General 4 | .DS_Store 5 | .AppleDouble 6 | .LSOverride 7 | 8 | # Icon must end with two \r 9 | Icon 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # Files that might appear in the root of a volume 15 | .DocumentRevisions-V100 16 | .fseventsd 17 | .Spotlight-V100 18 | .TemporaryItems 19 | .Trashes 20 | .VolumeIcon.icns 21 | .com.apple.timemachine.donotpresent 22 | 23 | # Directories potentially created on remote AFP share 24 | .AppleDB 25 | .AppleDesktop 26 | Network Trash Folder 27 | Temporary Items 28 | .apdisk 29 | 30 | 31 | ### https://raw.github.com/github/gitignore/cdd9e946da421758c6f42c427c7bc65c8326155d/Swift.gitignore 32 | 33 | # Xcode 34 | # 35 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 36 | 37 | ## User settings 38 | xcuserdata/ 39 | 40 | ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) 41 | *.xcscmblueprint 42 | *.xccheckout 43 | 44 | ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) 45 | build/ 46 | DerivedData/ 47 | *.moved-aside 48 | *.pbxuser 49 | !default.pbxuser 50 | *.mode1v3 51 | !default.mode1v3 52 | *.mode2v3 53 | !default.mode2v3 54 | *.perspectivev3 55 | !default.perspectivev3 56 | 57 | ## Obj-C/Swift specific 58 | *.hmap 59 | 60 | ## App packaging 61 | *.ipa 62 | *.dSYM.zip 63 | *.dSYM 64 | 65 | ## Playgrounds 66 | timeline.xctimeline 67 | playground.xcworkspace 68 | 69 | # Swift Package Manager 70 | # 71 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 72 | # Packages/ 73 | # Package.pins 74 | # Package.resolved 75 | # *.xcodeproj 76 | # 77 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 78 | # hence it is not needed unless you have added a package configuration file to your project 79 | # .swiftpm 80 | 81 | .build/ 82 | 83 | # CocoaPods 84 | # 85 | # We recommend against adding the Pods directory to your .gitignore. 
However 86 | # you should judge for yourself, the pros and cons are mentioned at: 87 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 88 | # 89 | # Pods/ 90 | # 91 | # Add this line if you want to avoid checking in source code from the Xcode workspace 92 | # *.xcworkspace 93 | 94 | # Carthage 95 | # 96 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 97 | # Carthage/Checkouts 98 | 99 | Carthage/Build/ 100 | 101 | # Accio dependency management 102 | Dependencies/ 103 | .accio/ 104 | 105 | # fastlane 106 | # 107 | # It is recommended to not store the screenshots in the git repo. 108 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 109 | # For more information about the recommended setup visit: 110 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 111 | 112 | fastlane/report.xml 113 | fastlane/Preview.html 114 | fastlane/screenshots/**/*.png 115 | fastlane/test_output 116 | 117 | # Code Injection 118 | # 119 | # After new code Injection tools there's a generated folder /iOSInjectionProject 120 | # https://github.com/johnno1962/injectionforxcode 121 | 122 | iOSInjectionProject/ 123 | 124 | dev 125 | Tests/LocalLLMClientLlamaTests/dev 126 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Sources/LlamaClientExperimentalC/exclude/llama.cpp"] 2 | path = Sources/LocalLLMClientLlamaC/exclude/llama.cpp 3 | url = https://github.com/ggml-org/llama.cpp.git 4 | -------------------------------------------------------------------------------- /.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/AI.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import LocalLLMClient 3 | import LocalLLMClientMLX 4 | import LocalLLMClientLlama 5 | #if canImport(UIKit) 6 | import UIKit 7 | #endif 8 | 9 | enum LLMModel: Sendable, CaseIterable, Identifiable { 10 | case qwen3 11 | case qwen3_4b 12 | case qwen2_5VL_3b 13 | case gemma3 14 | case gemma3_4b 15 | case mobileVLM_3b 16 | 17 | var name: String { 18 | switch self { 19 | case .qwen3: "MLX / Qwen3 1.7B" 20 | case .qwen3_4b: "MLX / Qwen3 4B" 21 | case .qwen2_5VL_3b: "MLX / Qwen2.5VL 3B" 22 | case .gemma3: "llama.cpp / Gemma3 1B" 23 | case .gemma3_4b: "llama.cpp / Gemma3 4B" 24 | case .mobileVLM_3b: "llama.cpp / MobileVLM 3B" 25 | } 26 | } 27 | 28 | var id: String { 29 | switch self { 30 | case .qwen3: "mlx-community/Qwen3-1.7B-4bit" 31 | case .qwen3_4b: "mlx-community/Qwen3-4B-4bit" 32 | case .qwen2_5VL_3b: "mlx-community/Qwen2.5-VL-3B-Instruct-abliterated-4bit" 33 | case .gemma3: "lmstudio-community/gemma-3-1B-it-qat-GGUF" 34 | case .gemma3_4b: "lmstudio-community/gemma-3-4B-it-qat-GGUF" 35 | case .mobileVLM_3b: "Blombert/MobileVLM-3B-GGUF" 36 | } 37 | } 38 | 39 | var filename: String? 
{ 40 | switch self { 41 | case .qwen3, .qwen3_4b, .qwen2_5VL_3b: nil 42 | case .gemma3: "gemma-3-1B-it-QAT-Q4_0.gguf" 43 | case .gemma3_4b: "gemma-3-4B-it-QAT-Q4_0.gguf" 44 | case .mobileVLM_3b: "ggml-MobileVLM-3B-q5_k_s.gguf" 45 | } 46 | } 47 | 48 | var clipFilename: String? { 49 | switch self { 50 | case .qwen3, .qwen3_4b, .qwen2_5VL_3b, .gemma3: nil 51 | #if os(macOS) 52 | case .gemma3_4b: "mmproj-model-f16.gguf" 53 | #elseif os(iOS) 54 | case .gemma3_4b: nil 55 | #endif 56 | case .mobileVLM_3b: "mmproj-model-f16.gguf" 57 | } 58 | } 59 | 60 | var isMLX: Bool { 61 | switch self { 62 | case .qwen3, .qwen3_4b, .qwen2_5VL_3b: true 63 | case .gemma3, .gemma3_4b, .mobileVLM_3b: false 64 | } 65 | } 66 | 67 | var supportsVision: Bool { 68 | switch self { 69 | case .qwen3, .qwen3_4b, .gemma3: false 70 | #if os(macOS) 71 | case .gemma3_4b: true 72 | #elseif os(iOS) 73 | case .gemma3_4b: false 74 | #endif 75 | case .qwen2_5VL_3b, .mobileVLM_3b: true 76 | } 77 | } 78 | } 79 | 80 | @Observable @MainActor 81 | final class AI { 82 | var model = LLMModel.qwen3 83 | private(set) var isLoading = false 84 | private(set) var downloadProgress: Double = 0 85 | 86 | private var client: AnyLLMClient? 87 | 88 | func loadLLM() async { 89 | isLoading = true 90 | defer { isLoading = false } 91 | 92 | // Release memory first if a previous model was loaded 93 | client = nil 94 | 95 | do { 96 | let downloader = Downloader(model: model) 97 | if downloader.isDownloaded { 98 | downloadProgress = 1 99 | } else { 100 | downloadProgress = 0 101 | try await downloader.download { @MainActor [weak self] progress in 102 | self?.downloadProgress = progress 103 | } 104 | } 105 | 106 | #if os(iOS) 107 | while downloadProgress < 1 || UIApplication.shared.applicationState != .active { 108 | try await Task.sleep(for: .seconds(2)) 109 | } 110 | #endif 111 | 112 | if model.isMLX { 113 | client = try await AnyLLMClient(LocalLLMClient.mlx(url: downloader.url)) 114 | } else { 115 | client = try await AnyLLMClient(LocalLLMClient.llama(url: downloader.url, mmprojURL: downloader.clipURL, verbose: true)) 116 | } 117 | } catch { 118 | print("Failed to load LLM: \(error)") 119 | } 120 | } 121 | 122 | func ask(_ messages: [LLMInput.Message]) async throws -> AsyncThrowingStream { 123 | guard let client else { 124 | throw LLMError.failedToLoad(reason: "LLM not loaded") 125 | } 126 | return try await client.textStream(from: .chat(messages)) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/App.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | 3 | @main 4 | struct ExampleApp: App { 5 | var body: some Scene { 6 | WindowGroup { 7 | RootView() 8 | .environment(AI()) 9 | } 10 | } 11 | } 12 | 13 | struct RootView: View { 14 | @Environment(AI.self) private var ai 15 | 16 | var body: some View { 17 | NavigationStack { 18 | ChatView() 19 | } 20 | .disabled(ai.isLoading) 21 | .overlay { 22 | if ai.isLoading { 23 | ZStack { 24 | Color.black.opacity(0.5) 25 | .ignoresSafeArea() 26 | 27 | Group { 28 | if ai.downloadProgress < 1 { 29 | ProgressView("Downloading LLM...", value: ai.downloadProgress) 30 | } else { 31 | ProgressView("Loading LLM...") 32 | } 33 | } 34 | .padding() 35 | .background(.regularMaterial, in: RoundedRectangle(cornerRadius: 12)) 36 | .padding() 37 | } 38 | } 39 | } 40 | #if !targetEnvironment(simulator) 41 | .onChange(of: ai.model, initial: true) { _, _ in 42 | Task { 43 | await ai.loadLLM() 44 | } 
45 | } 46 | #endif 47 | } 48 | } 49 | 50 | #Preview { 51 | RootView() 52 | .environment(AI()) 53 | } 54 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "appearances" : [ 10 | { 11 | "appearance" : "luminosity", 12 | "value" : "dark" 13 | } 14 | ], 15 | "idiom" : "universal", 16 | "platform" : "ios", 17 | "size" : "1024x1024" 18 | }, 19 | { 20 | "appearances" : [ 21 | { 22 | "appearance" : "luminosity", 23 | "value" : "tinted" 24 | } 25 | ], 26 | "idiom" : "universal", 27 | "platform" : "ios", 28 | "size" : "1024x1024" 29 | }, 30 | { 31 | "idiom" : "mac", 32 | "scale" : "1x", 33 | "size" : "16x16" 34 | }, 35 | { 36 | "idiom" : "mac", 37 | "scale" : "2x", 38 | "size" : "16x16" 39 | }, 40 | { 41 | "idiom" : "mac", 42 | "scale" : "1x", 43 | "size" : "32x32" 44 | }, 45 | { 46 | "idiom" : "mac", 47 | "scale" : "2x", 48 | "size" : "32x32" 49 | }, 50 | { 51 | "idiom" : "mac", 52 | "scale" : "1x", 53 | "size" : "128x128" 54 | }, 55 | { 56 | "idiom" : "mac", 57 | "scale" : "2x", 58 | "size" : "128x128" 59 | }, 60 | { 61 | "idiom" : "mac", 62 | "scale" : "1x", 63 | "size" : "256x256" 64 | }, 65 | { 66 | "idiom" : "mac", 67 | "scale" : "2x", 68 | "size" : "256x256" 69 | }, 70 | { 71 | "idiom" : "mac", 72 | "scale" : "1x", 73 | "size" : "512x512" 74 | }, 75 | { 76 | "idiom" : "mac", 77 | "scale" : "2x", 78 | "size" : "512x512" 79 | } 80 | ], 81 | "info" : { 82 | "author" : "xcode", 83 | "version" : 1 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/BottomBar.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | import PhotosUI 3 | import LocalLLMClient 4 | 5 | struct BottomBar: View { 6 | @Binding var text: String 7 | @Binding var images: [ChatMessage.Image] 8 | let isGenerating: Bool 9 | let onSubmit: (String) -> Void 10 | let onCancel: () -> Void 11 | 12 | @State private var pickedItem: PhotosPickerItem? 
13 | @Environment(AI.self) private var ai 14 | 15 | var body: some View { 16 | HStack { 17 | modelMenu 18 | imagePicker 19 | .disabled(!ai.model.supportsVision) 20 | 21 | TextField("Hello", text: $text) 22 | .textFieldStyle(.roundedBorder) 23 | .submitLabel(.send) 24 | .onSubmit{ 25 | onSubmit(text) 26 | } 27 | .disabled(isGenerating) 28 | 29 | if isGenerating { 30 | Button { 31 | onCancel() 32 | } label: { 33 | Image(systemName: "xmark") 34 | .foregroundStyle(.red) 35 | } 36 | } else if !text.isEmpty { 37 | Button { 38 | onSubmit(text) 39 | } label: { 40 | Image(systemName: "arrow.up") 41 | .foregroundStyle(text.isEmpty ? .gray : .accentColor) 42 | } 43 | .buttonBorderShape(.circle) 44 | .keyboardShortcut(.defaultAction) 45 | } 46 | } 47 | .safeAreaInset(edge: .top) { 48 | if !images.isEmpty { 49 | ScrollView(.horizontal) { 50 | HStack { 51 | ForEach(images) { image in 52 | Image(llm: image.value) 53 | .resizable() 54 | .aspectRatio(1, contentMode: .fill) 55 | .cornerRadius(8) 56 | .contextMenu { 57 | Button { 58 | images.removeAll { $0.id == image.id } 59 | } label: { 60 | Text("Remove") 61 | } 62 | } 63 | } 64 | } 65 | .frame(height: 60) 66 | } 67 | } 68 | } 69 | .animation(.default, value: text.isEmpty) 70 | .animation(.default, value: images.count) 71 | } 72 | 73 | @ViewBuilder 74 | private var modelMenu: some View { 75 | #if os(macOS) 76 | @Bindable var ai = ai 77 | Picker(selection: $ai.model) { 78 | ForEach(LLMModel.allCases) { model in 79 | Group { 80 | if model.supportsVision { 81 | Text("\(model.name) [VLM]") 82 | } else { 83 | Text(model.name) 84 | } 85 | } 86 | .tag(model) 87 | } 88 | } label: { 89 | Image(systemName: "brain.head.profile") 90 | } 91 | .pickerStyle(.menu) 92 | .labelsHidden() 93 | .fixedSize(horizontal: true, vertical: false) 94 | #elseif os(iOS) 95 | Menu { 96 | ForEach(LLMModel.allCases) { model in 97 | Button { 98 | ai.model = model 99 | } label: { 100 | if model.supportsVision { 101 | Text("\(model.name) [VLM]") 102 | } else { 103 | Text(model.name) 104 | } 105 | if ai.model == model { 106 | Image(systemName: "checkmark") 107 | } 108 | } 109 | } 110 | } label: { 111 | Image(systemName: "brain.head.profile") 112 | } 113 | .menuStyle(.button) 114 | #endif 115 | } 116 | 117 | @ViewBuilder 118 | private var imagePicker: some View { 119 | PhotosPicker( 120 | selection: $pickedItem, 121 | matching: .images, 122 | preferredItemEncoding: .compatible 123 | ) { 124 | Image(systemName: "photo") 125 | } 126 | .onChange(of: pickedItem) { _, item in 127 | guard let item else { return } 128 | pickedItem = nil 129 | Task { 130 | let data = try await item.loadTransferable(type: Data.self) 131 | guard let data, let image = LLMInputImage(data: data) else { return } 132 | images.append(.init(value: image)) 133 | } 134 | } 135 | } 136 | } 137 | 138 | #Preview(traits: .sizeThatFitsLayout) { 139 | @Previewable @State var text = "" 140 | @Previewable @State var images: [ChatMessage.Image] = [ 141 | .preview, .preview2 142 | ] 143 | 144 | BottomBar(text: $text, images: $images, isGenerating: false, onSubmit: { _ in }, onCancel: {}) 145 | .environment(AI()) 146 | } 147 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/ChatView.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | import LocalLLMClient 3 | import LocalLLMClientMLX 4 | 5 | struct ChatView: View { 6 | @State var viewModel = ChatViewModel() 7 | @State private var position = ScrollPosition(idType: 
ChatMessage.ID.self) 8 | 9 | @Environment(AI.self) private var ai 10 | 11 | var body: some View { 12 | VStack { 13 | messageList 14 | 15 | BottomBar( 16 | text: $viewModel.inputText, 17 | images: $viewModel.inputImages, 18 | isGenerating: viewModel.isGenerating 19 | ) { _ in 20 | viewModel.sendMessage(to: ai) 21 | } onCancel: { 22 | viewModel.cancelGeneration() 23 | } 24 | .padding([.horizontal, .bottom]) 25 | } 26 | .navigationTitle("Chat") 27 | .toolbar { 28 | ToolbarItem { 29 | Menu { 30 | Button("Clear Chat") { 31 | viewModel.clearMessages() 32 | } 33 | } label: { 34 | Image(systemName: "ellipsis.circle") 35 | } 36 | } 37 | } 38 | .onChange(of: ai.model) { _, _ in 39 | viewModel.clearMessages() 40 | } 41 | } 42 | 43 | @ViewBuilder 44 | private var messageList: some View { 45 | ScrollView { 46 | LazyVStack(spacing: 12) { 47 | ForEach(viewModel.messages) { message in 48 | ChatBubbleView(message: message) 49 | .id(message.id) 50 | } 51 | } 52 | .scrollTargetLayout() 53 | .padding(.horizontal) 54 | } 55 | .onChange(of: viewModel.messages) { _, _ in 56 | withAnimation { 57 | position.scrollTo(edge: .bottom) 58 | } 59 | } 60 | .scrollPosition($position) 61 | } 62 | } 63 | 64 | struct ChatBubbleView: View { 65 | let message: ChatMessage 66 | 67 | var body: some View { 68 | let isUser = message.role == .user 69 | 70 | VStack(alignment: isUser ? .trailing : .leading) { 71 | LazyVGrid(columns: [.init(.adaptive(minimum: 100))], alignment: .leading) { 72 | ForEach(message.images) { image in 73 | Image(llm: image.value) 74 | .resizable() 75 | .scaledToFit() 76 | .cornerRadius(16) 77 | } 78 | .scaleEffect(x: isUser ? -1 : 1) 79 | } 80 | .scaleEffect(x: isUser ? -1 : 1) 81 | 82 | Text(message.text) 83 | .padding(12) 84 | .background(isUser ? Color.accentColor : .gray.opacity(0.2)) 85 | .foregroundColor(isUser ? .white : .primary) 86 | .cornerRadius(16) 87 | } 88 | .padding(isUser ? .leading : .trailing, 50) 89 | .frame(maxWidth: .infinity, alignment: isUser ? .trailing : .leading) 90 | } 91 | } 92 | 93 | #Preview("Text") { 94 | NavigationStack { 95 | ChatView(viewModel: .init(messages: [ 96 | .init(role: .user, text: "Hello"), 97 | .init(role: .assistant, text: "Hi! How can I help you?"), 98 | .init(role: .user, text: "Hello", images: [.preview, .preview2]), 99 | ])) 100 | } 101 | .environment(AI()) 102 | } 103 | 104 | extension ChatMessage.Image { 105 | static let preview = try! Self.init(value: LLMInputImage(data: .init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.jpeg")!))!) 106 | static let preview2 = try! Self.init(value: LLMInputImage(data: .init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")!))!) 
107 | } 108 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/ChatViewModel.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import LocalLLMClient 3 | 4 | struct ChatMessage: Identifiable, Equatable, Sendable { 5 | var id = UUID() 6 | var role: Role 7 | var text: String 8 | var images: [Image] = [] 9 | 10 | enum Role: Sendable { 11 | case system 12 | case user 13 | case assistant 14 | } 15 | 16 | struct Image: Identifiable, Equatable, @unchecked Sendable { 17 | var id = UUID() 18 | var value: LLMInputImage 19 | } 20 | } 21 | 22 | @Observable @MainActor 23 | final class ChatViewModel { 24 | var inputText = "" 25 | var inputImages: [ChatMessage.Image] = [] 26 | private(set) var messages: [ChatMessage] = [] 27 | private var generateTask: Task? 28 | 29 | init(messages: [ChatMessage] = []) { 30 | self.messages = messages 31 | } 32 | 33 | var isGenerating: Bool { 34 | generateTask != nil 35 | } 36 | 37 | func sendMessage(to ai: AI) { 38 | guard !inputText.isEmpty, !isGenerating else { return } 39 | 40 | messages.append(ChatMessage(role: .user, text: inputText, images: inputImages)) 41 | let newMessages = messages 42 | messages.append(ChatMessage(role: .assistant, text: "")) 43 | 44 | let currentInput = (inputText, inputImages) 45 | inputText = "" 46 | inputImages = [] 47 | 48 | generateTask = Task { 49 | do { 50 | var response = "" 51 | for try await token in try await ai.ask(newMessages.llmMessages()) { 52 | response += token 53 | messages[messages.count - 1].text = response 54 | } 55 | } catch { 56 | messages[messages.count - 1].text = "Error: \(error.localizedDescription)" 57 | (inputText, inputImages) = currentInput 58 | } 59 | 60 | generateTask = nil 61 | } 62 | } 63 | 64 | func cancelGeneration() { 65 | generateTask?.cancel() 66 | generateTask = nil 67 | } 68 | 69 | func clearMessages() { 70 | messages.removeAll() 71 | } 72 | } 73 | 74 | extension [ChatMessage] { 75 | func llmMessages() -> [LLMInput.Message] { 76 | map { message in 77 | var role: LLMInput.Message.Role { 78 | switch message.role { 79 | case .user: return .user 80 | case .assistant: return .assistant 81 | case .system: return .system 82 | } 83 | } 84 | let attachments: [LLMAttachment] = message.images.map { image in 85 | .image(image.value) 86 | } 87 | return LLMInput.Message(role: role, content: message.text, attachments: attachments) 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/Configuration/Build.xcconfig: -------------------------------------------------------------------------------- 1 | DISAMBIGUATOR=${DEVELOPMENT_TEAM} 2 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/Downloader.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | import LocalLLMClientUtility 3 | 4 | struct Downloader: Sendable { 5 | init(model: LLMModel) { 6 | self.model = model 7 | let globs: Globs = switch model { 8 | case .qwen3, .qwen3_4b, .qwen2_5VL_3b: .mlx 9 | case .gemma3, .gemma3_4b, .mobileVLM_3b: .init( 10 | (model.filename.map { [$0] } ?? []) + (model.clipFilename.map { [$0] } ?? 
[]) 11 | )} 12 | #if os(macOS) 13 | downloader = FileDownloader(source: .huggingFace(id: model.id, globs: globs)) 14 | #elseif os(iOS) 15 | downloader = FileDownloader( 16 | source: .huggingFace(id: model.id, globs: globs), 17 | configuration: .background(withIdentifier: "localllmclient.downloader.\(model.id)") 18 | ) 19 | #endif 20 | // try? downloader.removeMetadata() // use it if you update the models 21 | } 22 | 23 | private let model: LLMModel 24 | private let downloader: FileDownloader 25 | 26 | var url: URL { 27 | downloader.destination.appending(component: model.filename ?? "") 28 | } 29 | 30 | var clipURL: URL? { 31 | model.clipFilename.map { downloader.destination.appending(component: $0) } 32 | } 33 | 34 | var isDownloaded: Bool { 35 | downloader.isDownloaded 36 | } 37 | 38 | func download(progressHandler: @escaping @Sendable (Double) async -> Void) async throws { 39 | try await downloader.download { progress in 40 | await progressHandler(progress) 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/Image+.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | import LocalLLMClient 3 | 4 | extension Image { 5 | init(llm image: LLMInputImage) { 6 | #if os(macOS) 7 | self.init(nsImage: image) 8 | #elseif os(iOS) 9 | self.init(uiImage: image) 10 | #endif 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Example/LocalLLMClientExample/LocalLLMClientExample.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.security.app-sandbox 6 | 7 | com.apple.security.files.user-selected.read-only 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /Example/Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 6.1 2 | 3 | import PackageDescription 4 | 5 | let package = Package( 6 | name: "example", 7 | products: [], 8 | targets: [] 9 | ) 10 | 11 | -------------------------------------------------------------------------------- /Example/README.md: -------------------------------------------------------------------------------- 1 | # Example App 2 | 3 | This example demonstrates how to use the LocalLLMClient to integrate on-device LLMs into an iOS / macOS app. 4 | 5 | 6 | 7 | 8 | 9 | 10 |
*(screenshots: example on iOS / example on macOS)*
11 | 12 | ## Requirements 13 | 14 | - iOS 18.0+ / macOS 15.0+ 15 | - Xcode 16.3+ 16 | - [Recommended]: M1 Mac or newer, or recent iPhone Pro models 17 | 18 | ## Usage 19 | 20 | To run the example app: 21 | 22 | 1. Clone the repository: 23 | ```bash 24 | git clone --recursive https://github.com/tattn/LocalLLMClient 25 | ``` 26 | If you already cloned the repository without `--recursive`, run: 27 | ```bash 28 | git submodule update --init --recursive 29 | ``` 30 | 2. Open `LocalLLMClientExample.xcodeproj` in Xcode 31 | 3. Build and run the app on your device, not a simulator 32 | 33 | *Note: The app requires a physical device* 34 | 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Tatsuya Tanaka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 6.0 2 | 3 | import PackageDescription 4 | 5 | let llamaVersion = "b5631" 6 | 7 | // MARK: - Package Dependencies 8 | 9 | var packageDependencies: [Package.Dependency] = [ 10 | .package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")), 11 | .package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0")), 12 | ] 13 | 14 | #if os(iOS) || os(macOS) 15 | packageDependencies.append(contentsOf: [ 16 | .package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.21")), 17 | .package(url: "https://github.com/ml-explore/mlx-swift-examples", branch: "main"), 18 | .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.4.0") 19 | ]) 20 | #endif 21 | 22 | // MARK: - Package Products 23 | 24 | var packageProducts: [Product] = [ 25 | .library(name: "LocalLLMClient", targets: ["LocalLLMClient"]) 26 | ] 27 | 28 | #if os(iOS) || os(macOS) 29 | packageProducts.append(contentsOf: [ 30 | .executable(name: "localllm", targets: ["LocalLLMCLI"]), 31 | .library(name: "LocalLLMClientLlama", targets: ["LocalLLMClientLlama"]), 32 | .library(name: "LocalLLMClientMLX", targets: ["LocalLLMClientMLX"]), 33 | .library(name: "LocalLLMClientUtility", targets: ["LocalLLMClientUtility"]) 34 | ]) 35 | #elseif os(Linux) 36 | packageProducts.append(contentsOf: [ 37 | .executable(name: "localllm", targets: ["LocalLLMCLI"]), 38 | .library(name: "LocalLLMClientLlama", targets: ["LocalLLMClientLlama"]), 39 | ]) 40 | #endif 41 | 42 | // MARK: - Package Targets 43 | 44 | var packageTargets: [Target] = [ 45 | .target(name: "LocalLLMClient") 46 | ] 47 | 48 | #if os(iOS) || os(macOS) 49 | packageTargets.append(contentsOf: [ 50 | .executableTarget( 51 | name: "LocalLLMCLI", 52 | dependencies: [ 53 | "LocalLLMClientLlama", 54 | "LocalLLMClientMLX", 55 | "LocalLLMClientUtility", 56 | .product(name: "ArgumentParser", package: "swift-argument-parser"), 57 | ], 58 | linkerSettings: [ 59 | .unsafeFlags(["-rpath", "@executable_path"]) 60 | ] 61 | ), 62 | 63 | .target( 64 | name: "LocalLLMClientLlama", 65 | dependencies: [ 66 | "LocalLLMClient", 67 | "LocalLLMClientLlamaC", 68 | .product(name: "Jinja", package: "Jinja") 69 | ], 70 | resources: [.process("Resources")], 71 | swiftSettings: Context.environment["BUILD_DOCC"] == nil ? 
[] : [ 72 | .define("BUILD_DOCC") 73 | ] 74 | ), 75 | .testTarget( 76 | name: "LocalLLMClientLlamaTests", 77 | dependencies: ["LocalLLMClientLlama", "LocalLLMClientUtility"] 78 | ), 79 | 80 | .target( 81 | name: "LocalLLMClientMLX", 82 | dependencies: [ 83 | "LocalLLMClient", 84 | .product(name: "MLXLLM", package: "mlx-swift-examples"), 85 | .product(name: "MLXVLM", package: "mlx-swift-examples"), 86 | ], 87 | ), 88 | .testTarget( 89 | name: "LocalLLMClientMLXTests", 90 | dependencies: ["LocalLLMClientMLX", "LocalLLMClientUtility"] 91 | ), 92 | 93 | .binaryTarget( 94 | name: "LocalLLMClientLlamaFramework", 95 | url: 96 | "https://github.com/ggml-org/llama.cpp/releases/download/\(llamaVersion)/llama-\(llamaVersion)-xcframework.zip", 97 | checksum: "ba16de7a2d90050db99cbf521a2ab391c2218404caa5c983670bab86c707dd9f" 98 | ), 99 | .target( 100 | name: "LocalLLMClientLlamaC", 101 | dependencies: ["LocalLLMClientLlamaFramework"], 102 | exclude: ["exclude"], 103 | cSettings: [ 104 | .unsafeFlags(["-w"]) 105 | ], 106 | ), 107 | 108 | .target( 109 | name: "LocalLLMClientUtility", 110 | dependencies: [ 111 | .product(name: "MLXLMCommon", package: "mlx-swift-examples"), 112 | ], 113 | ), 114 | .testTarget( 115 | name: "LocalLLMClientUtilityTests", 116 | dependencies: ["LocalLLMClientUtility"] 117 | ) 118 | ]) 119 | #elseif os(Linux) 120 | packageTargets.append(contentsOf: [ 121 | .executableTarget( 122 | name: "LocalLLMCLI", 123 | dependencies: [ 124 | "LocalLLMClientLlama", 125 | "LocalLLMClientUtility", 126 | .product(name: "ArgumentParser", package: "swift-argument-parser"), 127 | ], 128 | linkerSettings: [ 129 | .unsafeFlags([ 130 | Context.environment["LDFLAGS", default: ""], 131 | ]) 132 | ] 133 | ), 134 | 135 | .target( 136 | name: "LocalLLMClientLlama", 137 | dependencies: [ 138 | "LocalLLMClient", 139 | "LocalLLMClientLlamaC", 140 | .product(name: "Jinja", package: "Jinja") 141 | ], 142 | resources: [.process("Resources")] 143 | ), 144 | .testTarget( 145 | name: "LocalLLMClientLlamaTests", 146 | dependencies: ["LocalLLMClientLlama"], 147 | linkerSettings: [ 148 | .unsafeFlags([ 149 | Context.environment["LDFLAGS", default: ""], 150 | ]) 151 | ] 152 | ), 153 | 154 | .target( 155 | name: "LocalLLMClientLlamaC", 156 | exclude: ["exclude"], 157 | cSettings: [ 158 | .unsafeFlags(["-w"]) 159 | ], 160 | linkerSettings: [ 161 | .unsafeFlags([ 162 | "-lggml-base", "-lggml", "-lllama", "-lmtmd" 163 | ]) 164 | ] 165 | ), 166 | 167 | .target( 168 | name: "LocalLLMClientUtility" 169 | ), 170 | .testTarget( 171 | name: "LocalLLMClientUtilityTests", 172 | dependencies: ["LocalLLMClientUtility"] 173 | ) 174 | ]) 175 | #endif 176 | 177 | // MARK: - Package Definition 178 | 179 | let package = Package( 180 | name: "LocalLLMClient", 181 | platforms: [.iOS(.v16), .macOS(.v14)], 182 | products: packageProducts, 183 | dependencies: packageDependencies, 184 | targets: packageTargets, 185 | cxxLanguageStandard: .cxx20 186 | ) 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LocalLLMClient 2 | 3 | [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) 4 | [![CI](https://github.com/tattn/LocalLLMClient/actions/workflows/test.yml/badge.svg)](https://github.com/tattn/LocalLLMClient/actions/workflows/test.yml) 5 | 
[![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Ftattn%2FLocalLLMClient%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/tattn/LocalLLMClient) 6 | [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Ftattn%2FLocalLLMClient%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/tattn/LocalLLMClient) 7 | 8 | 9 | A Swift package to interact with local Large Language Models (LLMs) on Apple platforms. 10 | 11 | 12 | 13 | 14 | 15 | 16 |
*(screenshots: example on iOS / example on macOS)*
17 | 18 |
19 | **Demo / Multimodal** 20 | 21 | | MobileVLM-3B (llama.cpp) | Qwen2.5 VL 3B (MLX) | 22 | |:-:|:-:| 23 | | *(demo video)* | *(demo video)* |
28 | 29 | [Example app](https://github.com/tattn/LocalLLMClient/tree/main/Example) 30 | 31 | > [!IMPORTANT] 32 | > This project is still experimental. The API is subject to change. 33 | 34 | ## Features 35 | 36 | - Support for GGUF / MLX models 37 | - Support for iOS and macOS 38 | - Streaming API 39 | - Multimodal (experimental) 40 | 41 | ## Installation 42 | 43 | Add the following dependency to your `Package.swift` file: 44 | 45 | ```swift 46 | dependencies: [ 47 | .package(url: "https://github.com/tattn/LocalLLMClient.git", branch: "main") 48 | ] 49 | ``` 50 | 51 | ## Usage 52 | 53 | The API documentation is available [here](https://tattn.github.io/LocalLLMClient/documentation/). 54 | 55 | ### Basic Usage 56 | 57 |
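If you just want something end to end first, here is a minimal sketch assembled from the MLX example below (same `FileDownloader` and `LocalLLMClient.mlx` calls, default generation parameters); the sections that follow show the full llama.cpp and MLX variants, including streaming.

```swift
import LocalLLMClient
import LocalLLMClientMLX
import LocalLLMClientUtility

// Download the model files from Hugging Face (progress is reported from 0.0 to 1.0)
let downloader = FileDownloader(
    source: .huggingFace(id: "mlx-community/Qwen3-1.7B-4bit", globs: .mlx)
)
try await downloader.download { print("Progress: \($0)") }

// Load the downloaded model and generate a complete response in one call
let client = try await LocalLLMClient.mlx(url: downloader.destination)
print(try await client.generateText(from: "Tell me a story about a cat."))
```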
58 | Using with llama.cpp 59 | 60 | ```swift 61 | import LocalLLMClient 62 | import LocalLLMClientLlama 63 | import LocalLLMClientUtility 64 | 65 | // Download model from Hugging Face (Gemma 3) 66 | let ggufName = "gemma-3-4B-it-QAT-Q4_0.gguf" 67 | let downloader = FileDownloader(source: .huggingFace( 68 | id: "lmstudio-community/gemma-3-4B-it-qat-GGUF", 69 | globs: [ggufName] 70 | )) 71 | 72 | try await downloader.download { print("Progress: \($0)") } 73 | 74 | // Initialize a client with the downloaded model 75 | let modelURL = downloader.destination.appending(component: ggufName) 76 | let client = try await LocalLLMClient.llama(url: modelURL, parameter: .init( 77 | context: 4096, // Context size 78 | temperature: 0.7, // Randomness (0.0〜1.0) 79 | topK: 40, // Top-K sampling 80 | topP: 0.9, // Top-P (nucleus) sampling 81 | options: .init(responseFormat: .json) // Response format 82 | )) 83 | 84 | let prompt = """ 85 | Create the beginning of a synopsis for an epic story with a cat as the main character. 86 | Format it in JSON, as shown below. 87 | { 88 | "title": "", 89 | "content": "<content>", 90 | } 91 | """ 92 | 93 | // Generate text 94 | let input = LLMInput.chat([ 95 | .system("You are a helpful assistant."), 96 | .user(prompt) 97 | ]) 98 | 99 | for try await text in try await client.textStream(from: input) { 100 | print(text, terminator: "") 101 | } 102 | ``` 103 | </details> 104 | 105 | <details> 106 | <summary>Using with Apple MLX</summary> 107 | 108 | ```swift 109 | import LocalLLMClient 110 | import LocalLLMClientMLX 111 | import LocalLLMClientUtility 112 | 113 | // Download model from Hugging Face 114 | let downloader = FileDownloader( 115 | source: .huggingFace(id: "mlx-community/Qwen3-1.7B-4bit", globs: .mlx) 116 | ) 117 | try await downloader.download { print("Progress: \($0)") } 118 | 119 | // Initialize a client with the downloaded model 120 | let client = try await LocalLLMClient.mlx(url: downloader.destination, parameter: .init( 121 | temperature: 0.7, // Randomness (0.0 to 1.0) 122 | topP: 0.9 // Top-P (nucleus) sampling 123 | )) 124 | 125 | // Generate text 126 | let input = LLMInput.chat([ 127 | .system("You are a helpful assistant."), 128 | .user("Tell me a story about a cat.") 129 | ]) 130 | 131 | for try await text in try await client.textStream(from: input) { 132 | print(text, terminator: "") 133 | } 134 | ``` 135 | </details> 136 | 137 | ### Multimodal for Image 138 | 139 | LocalLLMClient supports multimodal models like LLaVA for processing images along with text prompts. 
140 | 141 | <details open> 142 | <summary>Using with llama.cpp</summary> 143 | 144 | ```swift 145 | import LocalLLMClient 146 | import LocalLLMClientLlama 147 | import LocalLLMClientUtility 148 | 149 | // Download model from Hugging Face (Gemma 3) 150 | let model = "gemma-3-4b-it-Q8_0.gguf" 151 | let mmproj = "mmproj-model-f16.gguf" 152 | 153 | let downloader = FileDownloader( 154 | source: .huggingFace(id: "ggml-org/gemma-3-4b-it-GGUF", globs: [model, mmproj]), 155 | ) 156 | try await downloader.download { print("Download: \($0)") } 157 | 158 | // Initialize a client with the downloaded model 159 | let client = try await LocalLLMClient.llama( 160 | url: downloader.destination.appending(component: model), 161 | mmprojURL: downloader.destination.appending(component: mmproj) 162 | ) 163 | 164 | let input = LLMInput.chat([ 165 | .user("What's in this image?", attachments: [.image(.init(resource: .yourImage))]), 166 | ]) 167 | 168 | // Generate text without streaming 169 | print(try await client.generateText(from: input)) 170 | ``` 171 | </details> 172 | 173 | <details> 174 | <summary>Using with Apple MLX</summary> 175 | 176 | ```swift 177 | import LocalLLMClient 178 | import LocalLLMClientMLX 179 | import LocalLLMClientUtility 180 | 181 | // Download model from Hugging Face (Qwen2.5 VL) 182 | let downloader = FileDownloader(source: .huggingFace( 183 | id: "mlx-community/Qwen2.5-VL-3B-Instruct-abliterated-4bit", 184 | globs: .mlx 185 | )) 186 | try await downloader.download { print("Progress: \($0)") } 187 | 188 | let client = try await LocalLLMClient.mlx(url: downloader.destination) 189 | 190 | let input = LLMInput.chat([ 191 | .user("What's in this image?", attachments: [.image(.init(resource: .yourImage))]), 192 | ]) 193 | 194 | // Generate text without streaming 195 | print(try await client.generateText(from: input)) 196 | ``` 197 | </details> 198 | 199 | ### Utility 200 | 201 | - `FileDownloader`: A utility to download models with progress tracking. 202 | 203 | ### CLI tool 204 | 205 | You can use LocalLLMClient directly from the terminal using the command line tool: 206 | 207 | ```bash 208 | # Run using llama.cpp 209 | swift run localllm --model /path/to/your/model.gguf "Your prompt here" 210 | 211 | # Run using MLX 212 | ./scripts/run_mlx.sh --model https://huggingface.co/mlx-community/Qwen3-1.7B-4bit "Your prompt here" 213 | ``` 214 | 215 | ## Tested models 216 | 217 | - LLaMA 3 218 | - Gemma 3 / 2 219 | - Qwen 3 / 2 220 | - Phi 4 221 | 222 | 223 | > [Models compatible with llama.cpp backend](https://github.com/ggml-org/llama.cpp?tab=readme-ov-file#text-only) 224 | > [Models compatible with MLX backend](https://github.com/ml-explore/mlx-swift-examples/blob/main/Libraries/MLXLLM/Documentation.docc/Documentation.md) 225 | 226 | *If you have a model that works, please open an issue or PR to add it to the list.* 227 | 228 | ## Requirements 229 | 230 | - iOS 16.0+ / macOS 14.0+ 231 | - Xcode 16.0+ 232 | 233 | ## Acknowledgements 234 | 235 | This package uses [llama.cpp](https://github.com/ggml-org/llama.cpp) and [Apple's MLX](https://opensource.apple.com/projects/mlx/) for model inference. 
236 | 237 | --- 238 | 239 | [Support this project :heart:](https://github.com/sponsors/tattn) 240 | -------------------------------------------------------------------------------- /Sources/LocalLLMCLI/command.swift: -------------------------------------------------------------------------------- 1 | #if os(Linux) 2 | // Workaround: https://github.com/swiftlang/swift/issues/77866 3 | @preconcurrency import var Glibc.stdout 4 | #endif 5 | import ArgumentParser 6 | import Foundation 7 | #if canImport(FoundationNetworking) 8 | import FoundationNetworking 9 | #endif 10 | import LocalLLMClient 11 | import LocalLLMClientLlama 12 | #if canImport(LocalLLMClientMLX) 13 | import LocalLLMClientMLX 14 | #endif 15 | #if canImport(LocalLLMClientUtility) 16 | import LocalLLMClientUtility 17 | #endif 18 | 19 | @main 20 | struct LocalLLMCommand: AsyncParsableCommand { 21 | nonisolated(unsafe) static var configuration = CommandConfiguration( 22 | commandName: "localllm", 23 | abstract: "A command line tool for interacting with local LLMs", 24 | discussion: """ 25 | Run LLM models directly from your command line. 26 | """ 27 | ) 28 | 29 | @Option(name: [.short, .long], help: "Path to the model file") 30 | var model: String 31 | 32 | @Option(name: [.short, .long], help: "Backend to use: \(Backend.allCases.map(\.rawValue).joined(separator: ", "))") 33 | var backend: String = Backend.llama.rawValue 34 | 35 | @Option(name: [.short, .long], help: "Temperature for sampling") 36 | var temperature: Float = 0.8 37 | 38 | @Option(name: [.customShort("p"), .long], help: "Top-p for sampling") 39 | var topP: Float = 0.9 40 | 41 | @Option(name: [.customShort("k"), .long], help: "Top-k for sampling") 42 | var topK: Int = 40 43 | 44 | @Option(name: [.long], help: "Path to the mmproj") 45 | var mmproj: String? 46 | @Option(name: [.customLong("image")], help: "Path to the image file") 47 | var imageURL: String? 48 | 49 | @Flag(name: [.customShort("v"), .long], help: "Show verbose output") 50 | var verbose: Bool = false 51 | 52 | @Argument(help: "The prompt to send to the model") 53 | var prompt: String 54 | 55 | enum Backend: String, CaseIterable { 56 | case llama 57 | case mlx 58 | } 59 | 60 | func run() async throws { 61 | let backend = Backend(rawValue: backend) ?? 
.llama 62 | log("Loading model from: \(model) with backend: \(backend.rawValue)") 63 | 64 | let modelURL = try await getModel(for: model, backend: backend) 65 | 66 | // Initialize client 67 | let client: any LLMClient 68 | switch backend { 69 | case .llama: 70 | client = try await LocalLLMClient.llama( 71 | url: modelURL, 72 | mmprojURL: mmproj.asyncMap { try await getModel(for: $0, backend: backend) }, 73 | parameter: .init( 74 | temperature: temperature, 75 | topK: topK, 76 | topP: topP, 77 | ), 78 | verbose: verbose 79 | ) 80 | case .mlx: 81 | #if canImport(LocalLLMClientMLX) 82 | client = try await LocalLLMClient.mlx( 83 | url: modelURL, 84 | parameter: .init( 85 | temperature: temperature, 86 | topP: topP, 87 | ) 88 | ) 89 | #else 90 | throw LocalLLMCommandError.invalidModel("MLX backend is not supported on this platform.") 91 | #endif 92 | } 93 | 94 | var attachments: [LLMAttachment] = [] 95 | if let imageURL { 96 | attachments.append(.image(LLMInputImage(data: try Data(contentsOf: URL(filePath: imageURL)))!)) 97 | } 98 | 99 | log("Generating response for prompt: \"\(prompt)\"") 100 | log("---") 101 | 102 | let input = LLMInput( 103 | .chat([.user(prompt, attachments: attachments)]), 104 | ) 105 | 106 | // Generate response 107 | for try await token in try await client.textStream(from: input) { 108 | print(token, terminator: "") 109 | fflush(stdout) 110 | } 111 | 112 | log("\n---") 113 | log("Generation complete.") 114 | } 115 | 116 | private func getModel(for model: String, backend: Backend) async throws -> URL { 117 | return if model.hasPrefix("/") { 118 | URL(filePath: model) 119 | } else if model.hasPrefix("https://"), let url = URL(string: model) { 120 | try await downloadModel(from: url, backend: backend) 121 | } else { 122 | throw LocalLLMCommandError.invalidModel(model) 123 | } 124 | } 125 | 126 | private func downloadModel(from url: URL, backend: Backend) async throws -> URL { 127 | #if canImport(LocalLLMClientUtility) 128 | log("Downloading model from Hugging Face: \(model)") 129 | 130 | let globs: Globs = switch backend { 131 | case .llama: .init(["*\(url.lastPathComponent)"]) 132 | case .mlx: .mlx 133 | } 134 | 135 | let downloader = FileDownloader(source: .huggingFace( 136 | id: url.pathComponents[1...2].joined(separator: "/"), 137 | globs: globs 138 | )) 139 | try await downloader.download { progress in 140 | log("Downloading model: \(progress)") 141 | } 142 | return switch backend { 143 | case .llama: downloader.destination.appendingPathComponent(url.lastPathComponent) 144 | case .mlx: downloader.destination 145 | } 146 | #else 147 | throw LocalLLMCommandError.invalidModel("Downloading models is not supported on this platform.") 148 | #endif 149 | } 150 | 151 | private func log(_ message: String) { 152 | if verbose { 153 | print(message) 154 | } 155 | } 156 | } 157 | 158 | enum LocalLLMCommandError: Error { 159 | case invalidModel(String) 160 | } 161 | -------------------------------------------------------------------------------- /Sources/LocalLLMClient/Async+.swift: -------------------------------------------------------------------------------- 1 | package extension Array { 2 | /// Maps each element of the array asynchronously using the provided transform function. 3 | /// 4 | /// This method preserves the order of elements in the resulting array, applying the 5 | /// transformation function to each element in sequence. 6 | /// 7 | /// - Parameter transform: A function that transforms an element of the array asynchronously. 
8 | /// - Returns: An array containing the transformed elements. 9 | /// - Throws: Rethrows any errors thrown by the transform function. 10 | func asyncMap<T>(_ transform: (Element) async throws -> T) async rethrows -> [T] { 11 | var results: [T] = [] 12 | results.reserveCapacity(count) 13 | for element in self { 14 | results.append(try await transform(element)) 15 | } 16 | return results 17 | } 18 | } 19 | 20 | package extension Optional { 21 | /// Maps the wrapped value asynchronously using the provided transform function if it exists. 22 | /// 23 | /// - Parameter transform: A function that transforms the wrapped value asynchronously. 24 | /// - Returns: The transformed value if the original optional contains a value, otherwise nil. 25 | /// - Throws: Rethrows any errors thrown by the transform function. 26 | func asyncMap<T>(_ transform: (Wrapped) async throws -> T) async rethrows -> T? { 27 | if let value = self { 28 | return try await transform(value) 29 | } else { 30 | return nil 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /Sources/LocalLLMClient/Docs.docc/index.md: -------------------------------------------------------------------------------- 1 | # ``LocalLLMClient`` 2 | 3 | A Swift package to interact with local Large Language Models (LLMs) on Apple platforms. 4 | 5 | ## Example 6 | 7 | ```swift 8 | import LocalLLMClient 9 | import LocalLLMClientMLX 10 | import LocalLLMClientUtility 11 | 12 | // Download model from Hugging Face 13 | let downloader = FileDownloader( 14 | source: .huggingFace(id: "mlx-community/Qwen3-1.7B-4bit", globs: .mlx) 15 | ) 16 | let modelURL = try await downloader.download { print("Progress: \($0)") } 17 | 18 | // Initialize a client with the downloaded model 19 | let client = try await LocalLLMClient.mlx(url: modelURL) 20 | 21 | // Generate text 22 | let prompt = "Tell me a story about a cat" 23 | let text = try await client.generateText(from: prompt) 24 | print(text) 25 | ``` 26 | 27 | ```swift 28 | // Streaming text 29 | for try await text in try await client.textStream(from: prompt) { 30 | print(text, terminator: "") 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /Sources/LocalLLMClient/LLMClient.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// A protocol representing a client for LLM 4 | public protocol LLMClient: Sendable { 5 | associatedtype TextGenerator: AsyncSequence & Sendable where Self.TextGenerator.Element == String 6 | 7 | /// Processes the provided input and returns the complete generated text 8 | /// - Parameter input: The input to process 9 | /// - Returns: The complete generated text 10 | /// - Throws: An error if text generation fails 11 | func generateText(from input: LLMInput) async throws -> String 12 | 13 | /// Processes the provided input and returns a stream of text tokens asynchronously 14 | /// - Parameter input: The input to process 15 | /// - Returns: An asynchronous sequence that emits text tokens 16 | /// - Throws: An error if text generation fails 17 | func textStream(from input: LLMInput) async throws -> TextGenerator 18 | } 19 | 20 | public extension LLMClient { 21 | /// Processes the provided input and returns the complete generated text 22 | /// - Parameter input: The input to process 23 | /// - Returns: The complete generated text 24 | /// - Throws: An error if text generation fails 25 | func generateText(from input: LLMInput) 
async throws -> String { 26 | var finalResult = "" 27 | for try await token in try await textStream(from: input) as TextGenerator { 28 | finalResult += token 29 | } 30 | return finalResult 31 | } 32 | 33 | /// Convenience method to generate text from a plain string input 34 | /// - Parameter input: The plain text input string 35 | /// - Returns: The complete generated text as a String 36 | /// - Throws: An error if text generation fails 37 | func generateText(from input: String) async throws -> String { 38 | try await generateText(from: .init(.plain(input))) 39 | } 40 | 41 | /// Convenience method to stream text from a plain string input 42 | /// - Parameter input: The plain text input string 43 | /// - Returns: An asynchronous sequence that emits text tokens 44 | /// - Throws: An error if text generation fails 45 | func textStream(from input: String) async throws -> TextGenerator { 46 | try await textStream(from: .init(.plain(input))) 47 | } 48 | } 49 | 50 | /// Namespace for local LLM client implementations 51 | public enum LocalLLMClient {} 52 | 53 | /// A type-erased wrapper around any LLMClient implementation 54 | public struct AnyLLMClient: LLMClient { 55 | /// The underlying LLM client 56 | public let client: any LLMClient 57 | 58 | /// Creates a new type-erased wrapper around an LLMClient 59 | /// - Parameter client: The LLM client to wrap 60 | public init(_ client: any LLMClient) { 61 | self.client = client 62 | } 63 | 64 | /// Processes the provided input and returns a stream of text tokens asynchronously 65 | /// - Parameter input: The input to process 66 | /// - Returns: An AsyncThrowingStream that emits text tokens 67 | /// - Throws: An error if text generation fails 68 | public func textStream(from input: LLMInput) async throws -> AsyncThrowingStream<String, Swift.Error> { 69 | AsyncThrowingStream { continuation in 70 | let task = Task { 71 | do { 72 | for try await text in try await client.textStream(from: input) { 73 | continuation.yield(text as! String) 74 | } 75 | continuation.finish() 76 | } catch { 77 | continuation.finish(throwing: error) 78 | } 79 | } 80 | continuation.onTermination = { _ in 81 | task.cancel() 82 | } 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /Sources/LocalLLMClient/LLMError.swift: -------------------------------------------------------------------------------- 1 | /// Errors that can occur when interacting with the Local LLM Client. 2 | public enum LLMError: Swift.Error { 3 | /// Indicates that the model failed to load. 4 | /// - Parameter reason: A description of why the model loading failed. 5 | case failedToLoad(reason: String) 6 | 7 | /// Indicates that an invalid parameter was provided to the LLM. 8 | /// - Parameter reason: A description of the invalid parameter. 9 | case invalidParameter(reason: String) 10 | 11 | /// Indicates that the LLM response could not be decoded. 12 | case failedToDecode(reason: String) 13 | 14 | /// Indicates that vision features are not supported by the current model or configuration. 15 | case visionUnsupported 16 | } 17 | -------------------------------------------------------------------------------- /Sources/LocalLLMClient/LLMInput.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// A structure representing various types of inputs for LLMs. 
4 | /// 5 | /// `LLMInput` encapsulates different formats of input data that can be provided to a language model: 6 | /// - Plain text input 7 | /// - Custom chat template format 8 | /// - Chat messages with defined roles (system, user, assistant) 9 | /// 10 | /// Example usage: 11 | /// ```swift 12 | /// // Plain text input 13 | /// let plainInput = LLMInput.plain("Hello, how can I help you?") 14 | /// 15 | /// // Chat template input 16 | /// let templateInput = LLMInput.chatTemplate([ 17 | /// .init(value: ["role": "user", "content": "Hello"]) 18 | /// ]) 19 | /// 20 | /// // Chat messages input 21 | /// let chatInput = LLMInput.chat([ 22 | /// .system("You are a helpful assistant"), 23 | /// .user("Tell me about Swift") 24 | /// ]) 25 | /// ``` 26 | public struct LLMInput: Sendable { 27 | /// Initializes an input with the specified value. 28 | /// 29 | /// - Parameter value: The input value 30 | public init(_ value: Input) { 31 | self.value = value 32 | } 33 | 34 | /// Creates a plain text input. 35 | /// 36 | /// - Parameter value: The text string to use as input. 37 | /// - Returns: An `LLMInput` instance with plain text input. 38 | public static func plain(_ value: String) -> LLMInput { 39 | .init(.plain(value)) 40 | } 41 | 42 | /// Creates a custom chat template input. 43 | /// 44 | /// - Parameter messages: An array of chat template messages. 45 | /// - Returns: An `LLMInput` instance with chat template input. 46 | public static func chatTemplate(_ messages: [ChatTemplateMessage]) -> LLMInput { 47 | .init(.chatTemplate(messages)) 48 | } 49 | 50 | /// Creates a chat input with role-based messages. 51 | /// 52 | /// - Parameter messages: An array of messages with defined roles. 53 | /// - Returns: An `LLMInput` instance with chat messages input. 54 | public static func chat(_ messages: [Message]) -> LLMInput { 55 | .init(.chat(messages)) 56 | } 57 | 58 | /// The underlying input value. 59 | public var value: Input 60 | 61 | /// Enumeration representing the different types of inputs that can be provided to a language model. 62 | public enum Input: Sendable { 63 | /// Plain text input, e.g., "hello" 64 | case plain(String) 65 | 66 | /// Chat template input format with structured messages. 67 | /// Example: [.init(value: ["role": "user", "content": "hello", "type": "text"])] 68 | case chatTemplate(_ messages: [ChatTemplateMessage]) 69 | 70 | /// Role-based chat messages. 71 | /// Example: [.init(role: .user, content: "hello")] 72 | case chat([Message]) 73 | } 74 | } 75 | 76 | /// Enables creating an `LLMInput` directly from a string literal. 77 | extension LLMInput: ExpressibleByStringLiteral { 78 | public init(stringLiteral value: String) { 79 | self.value = .plain(value) 80 | } 81 | } 82 | 83 | /// Represents different types of attachments that can be included with messages. 84 | public enum LLMAttachment: @unchecked Sendable, Hashable { 85 | /// An image attachment. 86 | case image(LLMInputImage) 87 | } 88 | 89 | public extension LLMInput { 90 | /// A structure representing a message in chat template format. 91 | /// 92 | /// Chat template messages are structured as key-value pairs that can be used 93 | /// by language models that expect input in a specific format. 94 | struct ChatTemplateMessage: Sendable { 95 | /// Initializes a chat template message. 96 | /// 97 | /// - Parameters: 98 | /// - value: A dictionary of key-value pairs representing the message structure. 99 | /// - attachments: Optional attachments to include with the message. 
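///
/// Example (editor's sketch; `image` stands for any `LLMInputImage` you already have):
/// ```swift
/// let message = LLMInput.ChatTemplateMessage(
///     value: ["role": "user", "content": "Describe this image"],
///     attachments: [.image(image)]
/// )
/// ```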
100 | public init( 101 | value: [String: any Sendable], 102 | attachments: [LLMAttachment] = [] 103 | ) { 104 | self.value = value 105 | self.attachments = attachments 106 | } 107 | 108 | /// The key-value pairs representing the message structure. 109 | public var value: [String: any Sendable] 110 | 111 | /// Attachments associated with this message. 112 | public var attachments: [LLMAttachment] 113 | } 114 | 115 | /// A structure representing a role-based message in a conversation. 116 | /// 117 | /// Each message has a role (system, user, assistant, or custom), content text, 118 | /// and optional attachments such as images. 119 | struct Message: Sendable, Hashable { 120 | /// Initializes a message with the specified role and content. 121 | /// 122 | /// - Parameters: 123 | /// - role: The role of the message sender. 124 | /// - content: The text content of the message. 125 | /// - attachments: Optional attachments to include with the message. 126 | public init( 127 | role: Role, 128 | content: String, 129 | attachments: [LLMAttachment] = [] 130 | ) { 131 | self.role = role 132 | self.content = content 133 | self.attachments = attachments 134 | } 135 | 136 | /// Creates a system message. 137 | /// 138 | /// System messages provide instructions or context to the language model. 139 | /// 140 | /// - Parameter content: The text content of the system message. 141 | /// - Returns: A new `Message` instance with system role. 142 | public static func system(_ content: String) -> Message { 143 | .init(role: .system, content: content) 144 | } 145 | 146 | /// Creates a user message. 147 | /// 148 | /// User messages represent input from the end-user. 149 | /// 150 | /// - Parameters: 151 | /// - content: The text content of the user message. 152 | /// - attachments: Optional attachments to include with the message. 153 | /// - Returns: A new `Message` instance with user role. 154 | public static func user(_ content: String, attachments: [LLMAttachment] = []) -> Message { 155 | .init(role: .user, content: content, attachments: attachments) 156 | } 157 | 158 | /// Creates an assistant message. 159 | /// 160 | /// Assistant messages represent responses from the language model. 161 | /// 162 | /// - Parameters: 163 | /// - content: The text content of the assistant message. 164 | /// - attachments: Optional attachments to include with the message. 165 | /// - Returns: A new `Message` instance with assistant role. 166 | public static func assistant(_ content: String, attachments: [LLMAttachment] = []) -> Message { 167 | .init(role: .assistant, content: content, attachments: attachments) 168 | } 169 | 170 | /// The role of the message sender. 171 | public var role: Role 172 | 173 | /// The text content of the message. 174 | public var content: String 175 | 176 | /// Attachments associated with this message. 177 | public var attachments: [LLMAttachment] 178 | 179 | /// Enumeration representing the role of a message sender in a conversation. 180 | public enum Role: Sendable, Hashable { 181 | /// System role, typically used for instructions or context. 182 | case system 183 | 184 | /// User role, representing end-user input. 185 | case user 186 | 187 | /// Assistant role, representing language model responses. 188 | case assistant 189 | 190 | /// Custom role with a specified name. 191 | case custom(String) 192 | 193 | /// The string representation of the role. 
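/// For example (sketch), `Role.user.rawValue` is `"user"` and `Role.custom("tool").rawValue` is `"tool"`.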
194 | public var rawValue: String { 195 | switch self { 196 | case .system: "system" 197 | case .user: "user" 198 | case .assistant: "assistant" 199 | case .custom(let value): value 200 | } 201 | } 202 | } 203 | } 204 | } 205 | 206 | #if os(macOS) 207 | import class CoreImage.CIImage 208 | @preconcurrency import class AppKit.NSImage 209 | @preconcurrency import class AppKit.NSBitmapImageRep 210 | /// On macOS, represents an image that can be used as input to a language model. 211 | public typealias LLMInputImage = NSImage 212 | 213 | /// Converts an image to PNG data. 214 | /// 215 | /// - Parameter image: The image to convert. 216 | /// - Returns: PNG data representation of the image. 217 | /// - Throws: `LLMError.failedToLoad` if the conversion fails. 218 | package func llmInputImageToData(_ image: LLMInputImage) throws(LLMError) -> Data { 219 | guard let cgImage = image.cgImage(forProposedRect: nil, context: nil, hints: nil) 220 | else { throw LLMError.failedToLoad(reason: "Failed to load image") } 221 | let imageRep = NSBitmapImageRep(cgImage: cgImage) 222 | imageRep.size = image.size 223 | guard let result = imageRep.representation(using: .png, properties: [:]) else { 224 | throw LLMError.failedToLoad(reason: "Failed to convert image to PNG") 225 | } 226 | return result 227 | } 228 | 229 | /// Converts an image to a CIImage. 230 | /// 231 | /// - Parameter image: The image to convert. 232 | /// - Returns: CIImage representation of the image. 233 | /// - Throws: `LLMError.failedToLoad` if the conversion fails. 234 | package func llmInputImageToCIImage(_ image: LLMInputImage) throws(LLMError) -> CIImage { 235 | guard let imageData = image.tiffRepresentation, let ciImage = CIImage(data: imageData) else { 236 | throw LLMError.failedToLoad(reason: "Failed to load image") 237 | } 238 | return ciImage 239 | } 240 | #elseif os(iOS) 241 | import class CoreImage.CIImage 242 | @preconcurrency import class UIKit.UIImage 243 | /// On iOS, represents an image that can be used as input to a language model. 244 | public typealias LLMInputImage = UIImage 245 | 246 | /// Converts an image to PNG data. 247 | /// 248 | /// - Parameter image: The image to convert. 249 | /// - Returns: PNG data representation of the image. 250 | /// - Throws: `LLMError.failedToLoad` if the conversion fails. 251 | package func llmInputImageToData(_ image: LLMInputImage) throws(LLMError) -> Data { 252 | guard let data = image.pngData() else { 253 | throw LLMError.failedToLoad(reason: "Failed to convert image to PNG") 254 | } 255 | return data 256 | } 257 | 258 | /// Converts an image to a CIImage. 259 | /// 260 | /// - Parameter image: The image to convert. 261 | /// - Returns: CIImage representation of the image. 262 | /// - Throws: `LLMError.failedToLoad` if the conversion fails. 263 | package func llmInputImageToCIImage(_ image: LLMInputImage) throws(LLMError) -> CIImage { 264 | guard let ciImage = CIImage(image: image) else { 265 | throw LLMError.failedToLoad(reason: "Failed to load image") 266 | } 267 | return ciImage 268 | } 269 | #else 270 | public struct LLMInputImage: Sendable, Equatable, Hashable { 271 | package let data: Data? 272 | 273 | /// Initializes an empty image. 274 | public init() { 275 | data = nil 276 | } 277 | 278 | /// Initializes an image with data. 279 | public init?(data: Data) { 280 | self.data = data 281 | } 282 | } 283 | 284 | /// Converts an image to data. 285 | /// 286 | /// - Parameter image: The image to convert. 287 | /// - Returns: Data representation of the image. 
288 | /// - Throws: `LLMError.failedToLoad` if the conversion fails. 289 | package func llmInputImageToData(_ image: LLMInputImage) throws(LLMError) -> Data { 290 | guard let image = image.data else { 291 | throw LLMError.failedToLoad(reason: "data is nil") 292 | } 293 | return image 294 | } 295 | #endif 296 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Batch.swift: -------------------------------------------------------------------------------- 1 | #if BUILD_DOCC 2 | @preconcurrency @_implementationOnly import llama 3 | #elseif canImport(llama) 4 | @preconcurrency private import llama 5 | #else 6 | @preconcurrency import LocalLLMClientLlamaC 7 | #endif 8 | 9 | package extension llama_batch { 10 | mutating func clear() { 11 | n_tokens = 0 12 | } 13 | 14 | mutating func add(id: llama_token, pos: llama_pos, seq_ids: [llama_seq_id], logits: Bool) { 15 | self.token[Int(n_tokens)] = id 16 | self.pos[Int(n_tokens)] = pos 17 | self.n_seq_id[Int(n_tokens)] = Int32(seq_ids.count) 18 | 19 | for i in 0..<seq_ids.count { 20 | self.seq_id[Int(n_tokens)]![Int(i)] = seq_ids[i] 21 | } 22 | 23 | self.logits[Int(n_tokens)] = logits ? 1 : 0 24 | 25 | self.n_tokens += 1 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Context.swift: -------------------------------------------------------------------------------- 1 | #if BUILD_DOCC 2 | @preconcurrency @_implementationOnly import llama 3 | #elseif canImport(llama) 4 | @preconcurrency private import llama 5 | #else 6 | @preconcurrency import LocalLLMClientLlamaC 7 | #endif 8 | import Foundation 9 | import LocalLLMClient 10 | 11 | public final class Context: @unchecked Sendable { 12 | let parameter: LlamaClient.Parameter 13 | package let context: OpaquePointer 14 | package var batch: llama_batch 15 | var sampling: Sampler 16 | let grammer: Sampler? 17 | let cursorPointer: UnsafeMutableBufferPointer<llama_token_data> 18 | let model: Model 19 | let extraEOSTokens: Set<String> 20 | private var promptCaches: [(chunk: MessageChunk, lastPosition: llama_pos)] = [] 21 | 22 | package var vocab: OpaquePointer { 23 | model.vocab 24 | } 25 | 26 | package var numberOfBatch: Int32 { 27 | Int32(llama_n_batch(context)) 28 | } 29 | 30 | package var position: Int32 { 31 | llama_kv_self_seq_pos_max(context, 0) + 1 32 | } 33 | 34 | public init(url: URL, parameter: LlamaClient.Parameter = .default) throws(LLMError) { 35 | initializeLlama() 36 | 37 | var ctx_params = llama_context_default_params() 38 | ctx_params.n_ctx = UInt32(parameter.context) 39 | ctx_params.n_threads = Int32(parameter.numberOfThreads ?? max(1, min(8, ProcessInfo.processInfo.processorCount - 2))) 40 | ctx_params.n_threads_batch = ctx_params.n_threads 41 | 42 | self.parameter = parameter 43 | self.model = try Model(url: url) 44 | self.context = try model.makeAndAllocateContext(with: ctx_params) 45 | batch = llama_batch_init(Int32(parameter.batch), 0, 1) 46 | extraEOSTokens = parameter.options.extraEOSTokens 47 | 48 | // https://github.com/ggml-org/llama.cpp/blob/master/common/sampling.cpp 49 | sampling = llama_sampler_chain_init(llama_sampler_chain_default_params()) 50 | let minKeep = 0 51 | let penaltyFreq: Float = 0 52 | let penaltyPresent: Float = 0 53 | llama_sampler_chain_add(sampling, llama_sampler_init_temp(parameter.temperature)) 54 | llama_sampler_chain_add(sampling, llama_sampler_init_dist(parameter.seed.map(UInt32.init) ?? 
LLAMA_DEFAULT_SEED)) 55 | llama_sampler_chain_add(sampling, llama_sampler_init_top_k(Int32(parameter.topK))) 56 | llama_sampler_chain_add(sampling, llama_sampler_init_top_p(parameter.topP, minKeep)) 57 | llama_sampler_chain_add(sampling, llama_sampler_init_min_p(1 - parameter.topP, 1)) 58 | llama_sampler_chain_add(sampling, llama_sampler_init_typical(parameter.typicalP, minKeep)) 59 | llama_sampler_chain_add(sampling, llama_sampler_init_penalties(Int32(parameter.penaltyLastN), parameter.penaltyRepeat, penaltyFreq, penaltyPresent)) 60 | 61 | cursorPointer = .allocate(capacity: Int(llama_vocab_n_tokens(model.vocab))) 62 | 63 | if let format = parameter.options.responseFormat { 64 | switch format { 65 | case .json: 66 | do { 67 | let template = try String(contentsOf: Bundle.module.url(forResource: "json", withExtension: "gbnf")!, encoding: .utf8) 68 | grammer = llama_sampler_init_grammar(model.vocab, template, "root") 69 | } catch { 70 | throw .failedToLoad(reason: "Failed to load grammar template") 71 | } 72 | case let .grammar(grammar, root): 73 | grammer = llama_sampler_init_grammar(model.vocab, grammar, root) 74 | } 75 | llama_sampler_chain_add(sampling, grammer) 76 | } else { 77 | grammer = nil 78 | } 79 | } 80 | 81 | deinit { 82 | cursorPointer.deallocate() 83 | llama_sampler_free(sampling) 84 | llama_batch_free(batch) 85 | llama_free(context) 86 | } 87 | 88 | public func clear() { 89 | llama_kv_self_clear(context) 90 | } 91 | 92 | func addCache(for chunk: MessageChunk, position: llama_pos) { 93 | let endIndex = promptCaches.endIndex - 1 94 | switch (chunk, promptCaches.last?.chunk) { 95 | case let (.text(chunkText), .text(cacheText)): 96 | promptCaches[endIndex] = (chunk: .text(cacheText + chunkText), lastPosition: position) 97 | case let (.image(chunkImages), .image(cacheImages)): 98 | promptCaches[endIndex] = (chunk: .image(cacheImages + chunkImages), lastPosition: position) 99 | case let (.video(chunkVideos), .video(cacheVideos)): 100 | promptCaches[endIndex] = (chunk: .video(cacheVideos + chunkVideos), lastPosition: position) 101 | default: 102 | promptCaches.append((chunk: chunk, lastPosition: position)) 103 | } 104 | } 105 | 106 | func removeCachedChunks(_ chunks: inout [MessageChunk]) { 107 | guard let (lastCacheIndex, newChunk) = lastCacheIndex(of: chunks) else { 108 | return 109 | } 110 | chunks = Array(chunks[(lastCacheIndex + 1)...]) 111 | if let newChunk { 112 | chunks.append(newChunk) 113 | } 114 | if promptCaches[lastCacheIndex].lastPosition < position { 115 | assert(llama_kv_self_seq_rm(context, 0, promptCaches[lastCacheIndex].lastPosition, position)) 116 | } 117 | if promptCaches.count > lastCacheIndex { 118 | promptCaches.removeSubrange((lastCacheIndex + 1)...) 119 | } 120 | } 121 | 122 | func lastCacheIndex(of chunks: [MessageChunk]) -> (index: Int, remaining: MessageChunk?)? 
{ 123 | for (index, (chunk, cache)) in zip(chunks, promptCaches).enumerated() { 124 | switch (chunk, cache.chunk) { 125 | case let (.text(chunkText), .text(cacheText)) where chunkText.hasPrefix(cacheText): 126 | if chunkText == cacheText { 127 | return (index, nil) 128 | } else { 129 | return (index, .text(String(chunkText.dropFirst(cacheText.count)))) 130 | } 131 | case let (.image(chunkImages), .image(cacheImages)) where chunkImages == cacheImages: 132 | return (index, nil) 133 | case let (.video(chunkVideos), .video(cacheVideos)) where chunkVideos == cacheVideos: 134 | return (index, nil) 135 | default: 136 | break 137 | } 138 | } 139 | return nil 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Decoder.swift: -------------------------------------------------------------------------------- 1 | #if BUILD_DOCC 2 | @preconcurrency @_implementationOnly import llama 3 | #elseif canImport(llama) 4 | @preconcurrency private import llama 5 | #else 6 | @preconcurrency import LocalLLMClientLlamaC 7 | #endif 8 | import LocalLLMClient 9 | 10 | public extension Context { 11 | @discardableResult 12 | func decode() throws(LLMError) -> Int32 { 13 | let numberOfTokens = batch.n_tokens 14 | guard batch.n_tokens > 0 else { 15 | return 0 // no data to decode 16 | } 17 | 18 | batch.logits[Int(batch.n_tokens) - 1] = 1 19 | 20 | guard llama_decode(context, batch) == 0 else { 21 | throw .failedToDecode(reason: "batch decode failed") 22 | } 23 | 24 | batch.clear() 25 | 26 | return numberOfTokens 27 | } 28 | 29 | func decode(text: String) throws(LLMError) { 30 | let position = position 31 | let tokens = [llama_token](text, addBos: false, special: true, vocab: vocab) 32 | for (index, token) in tokens.enumerated() { 33 | batch.add(id: token, pos: llama_pos(index) + position, seq_ids: [0], logits: false) 34 | } 35 | try decode() 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Generator.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | #if BUILD_DOCC 3 | @preconcurrency @_implementationOnly import llama 4 | #elseif canImport(llama) 5 | @preconcurrency private import llama 6 | #else 7 | @preconcurrency import LocalLLMClientLlamaC 8 | #endif 9 | import LocalLLMClient 10 | 11 | public struct Generator: AsyncSequence, Sendable { 12 | public init(context: Context) { 13 | self.context = context 14 | } 15 | 16 | let context: Context 17 | 18 | public func makeAsyncIterator() -> TokenGenerator { 19 | TokenGenerator(context: context) 20 | } 21 | } 22 | 23 | public struct TokenGenerator: AsyncIteratorProtocol { 24 | init(context: Context) { 25 | self.context = context 26 | } 27 | 28 | private let context: Context 29 | private var temporaryInvalidCharacters: [CChar] = [] 30 | private var currentResult = "" 31 | 32 | mutating public func next() async throws -> String? { 33 | if Task.isCancelled { 34 | updatePromptCache() 35 | return nil 36 | } 37 | 38 | try context.decode() 39 | 40 | let newTokenId = context.sampling.sample(context: context, index: -1) 41 | 42 | if llama_vocab_is_eog(context.vocab, newTokenId) || context.position >= context.parameter.context { 43 | if temporaryInvalidCharacters.isEmpty { 44 | updatePromptCache() 45 | return nil 46 | } else { 47 | let newToken = makeString() ?? 
"" 48 | temporaryInvalidCharacters.removeAll() 49 | return newToken 50 | } 51 | } 52 | 53 | temporaryInvalidCharacters.append(contentsOf: newTokenId.piece(vocab: context.vocab, special: true)) 54 | 55 | let newToken: String 56 | if let token = makeString() { 57 | temporaryInvalidCharacters.removeAll() 58 | newToken = token 59 | } else if (1 ..< temporaryInvalidCharacters.count).contains(where: { String(utf8String: Array(temporaryInvalidCharacters.suffix($0)) + [0]) != nil }) { 60 | let token = makeString() ?? "" 61 | temporaryInvalidCharacters.removeAll() 62 | newToken = token 63 | } else { 64 | newToken = "" 65 | } 66 | 67 | if context.extraEOSTokens.contains(newToken) { 68 | temporaryInvalidCharacters.removeAll() 69 | updatePromptCache() 70 | return nil 71 | } 72 | 73 | context.batch.add(id: newTokenId, pos: context.position, seq_ids: [0], logits: true) 74 | 75 | return newToken 76 | } 77 | 78 | private mutating func makeString() -> String? { 79 | guard let text = String(utf8String: temporaryInvalidCharacters + [0]) else { 80 | return nil 81 | } 82 | currentResult += text 83 | return text 84 | } 85 | 86 | private func updatePromptCache() { 87 | context.addCache(for: .text(currentResult), position: context.position) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/LlamaAutoMessageDecoder.swift: -------------------------------------------------------------------------------- 1 | import LocalLLMClient 2 | import Jinja 3 | 4 | enum ChatTemplate { 5 | case `default` 6 | case gemma3 7 | case qwen2_5_VL 8 | case llama3_2V // llama4 9 | case phi4 10 | 11 | var decoder: any LlamaChatMessageDecoder { 12 | switch self { 13 | case .default: LlamaChatMLMessageDecoder() 14 | case .gemma3: LlamaCustomMessageDecoder(tokenImageRegex: "<start_of_image>") 15 | case .qwen2_5_VL: LlamaQwen2VLMessageDecoder() 16 | case .llama3_2V: LlamaLlama3_2VMessageDecoder() 17 | case .phi4: LlamaChatMLMessageDecoder() 18 | } 19 | } 20 | } 21 | 22 | public struct LlamaAutoMessageDecoder: LlamaChatMessageDecoder { 23 | var chatTemplate: ChatTemplate = .default 24 | 25 | public init(chatTemplate: String) { 26 | guard let template = try? 
Template(chatTemplate) else { 27 | return 28 | } 29 | 30 | let contentMarker = "$$TEXT$$" 31 | let image = LLMInputImage() 32 | let candidateTemplates: [ChatTemplate] = [.gemma3, .qwen2_5_VL, .llama3_2V, .phi4] 33 | 34 | do { 35 | let messages = [ 36 | LLMInput.Message(role: .user, content: contentMarker, attachments: [.image(image)]), 37 | ] 38 | 39 | for candidate in candidateTemplates { 40 | let value = candidate.decoder.templateValue(from: messages).map(\.value) 41 | do { 42 | // Pick the template that can extract image chunks 43 | let rendered = try template.render(["messages": value]) 44 | let chunks = try candidate.decoder.extractChunks(prompt: rendered, imageChunks: [[image]]) 45 | if chunks.hasVisionItems() { 46 | self.chatTemplate = candidate 47 | return 48 | } 49 | } catch { 50 | } 51 | } 52 | } 53 | do { 54 | let messages = [ 55 | LLMInput.Message(role: .system, content: contentMarker), 56 | LLMInput.Message(role: .user, content: contentMarker, attachments: [.image(image)]), 57 | LLMInput.Message(role: .assistant, content: contentMarker), 58 | ] 59 | var maxLength = 0 60 | 61 | for candidate in candidateTemplates { 62 | let value = candidate.decoder.templateValue(from: messages).map(\.value) 63 | do { 64 | // Pick the template that can render more characters 65 | let rendered = try template.render(["messages": value]) 66 | if maxLength <= rendered.count { 67 | maxLength = rendered.count 68 | self.chatTemplate = candidate 69 | } 70 | } catch { 71 | } 72 | } 73 | } 74 | } 75 | 76 | public func templateValue(from messages: [LLMInput.Message]) -> [LLMInput.ChatTemplateMessage] { 77 | chatTemplate.decoder.templateValue(from: messages) 78 | } 79 | 80 | public func applyTemplate(_ messages: [LLMInput.ChatTemplateMessage], chatTemplate: String, additionalContext: [String: Any]?) throws(LLMError) -> String { 81 | try self.chatTemplate.decoder.applyTemplate(messages, chatTemplate: chatTemplate, additionalContext: additionalContext) 82 | } 83 | 84 | public func extractChunks(prompt: String, imageChunks: [[LLMInputImage]]) throws -> [MessageChunk] { 85 | try chatTemplate.decoder.extractChunks(prompt: prompt, imageChunks: imageChunks) 86 | } 87 | 88 | public func decode(_ messages: [LLMInput.ChatTemplateMessage], context: Context, multimodal: MultimodalContext?) throws { 89 | try chatTemplate.decoder.decode(messages, context: context, multimodal: multimodal) 90 | } 91 | } 92 | 93 | private extension [MessageChunk] { 94 | func hasVisionItems() -> Bool { 95 | contains { chunk in 96 | switch chunk { 97 | case .text: false 98 | case .image, .video: true 99 | } 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/LlamaChatMessageDecoder.swift: -------------------------------------------------------------------------------- 1 | import LocalLLMClient 2 | import Jinja 3 | 4 | public enum MessageChunk: Equatable, Hashable { 5 | case text(String) 6 | case image([LLMInputImage]) 7 | case video([LLMInputImage]) // Placeholder for future video support 8 | } 9 | 10 | public protocol LlamaChatMessageDecoder: Sendable { 11 | func templateValue(from messages: [LLMInput.Message]) -> [LLMInput.ChatTemplateMessage] 12 | func applyTemplate(_ messages: [LLMInput.ChatTemplateMessage], chatTemplate: String, additionalContext: [String: Any]?) 
throws(LLMError) -> String 13 | func extractChunks(prompt: String, imageChunks: [[LLMInputImage]]) throws -> [MessageChunk] 14 | func decode(_ messages: [LLMInput.ChatTemplateMessage], context: Context, multimodal: MultimodalContext?) throws 15 | } 16 | 17 | public extension LlamaChatMessageDecoder { 18 | func templateValue(from messages: [LLMInput.Message]) -> [LLMInput.ChatTemplateMessage] { 19 | messages.map { message in 20 | LLMInput.ChatTemplateMessage( 21 | value: [ 22 | "role": message.role.rawValue, 23 | "content": (0..<message.attachments.images().count).map { _ in 24 | ["type": "image"] 25 | } + [["type": "text", "text": message.content]], 26 | ], 27 | attachments: message.attachments 28 | ) 29 | } 30 | } 31 | 32 | func applyTemplate( 33 | _ messages: [LLMInput.ChatTemplateMessage], 34 | chatTemplate: String, 35 | additionalContext: [String: Any]? = nil 36 | ) throws(LLMError) -> String { 37 | do { 38 | let template = try Template(chatTemplate) 39 | 40 | var templateContext: [String: Any] = [ 41 | "messages": messages.map(\.value), 42 | "add_generation_prompt": true, 43 | ] 44 | 45 | if let additionalContext { 46 | templateContext.merge(additionalContext) { _, new in new } 47 | } 48 | 49 | return try template.render(templateContext) 50 | } catch { 51 | throw .invalidParameter(reason: "Failed to apply template: \(error.localizedDescription)") 52 | } 53 | } 54 | 55 | func extractChunks(prompt: String, imageChunks: [[LLMInputImage]]) throws -> [MessageChunk] { 56 | [.text(prompt)] 57 | } 58 | 59 | func decode(_ messages: [LLMInput.ChatTemplateMessage], context: Context, multimodal: MultimodalContext?) throws { 60 | let specialTokens: [String: String] = [ 61 | "bos_token": String(utf8String: llama_vocab_get_text(context.model.vocab, max(0, llama_vocab_bos(context.model.vocab)))) ?? "", 62 | "eos_token": String(utf8String: llama_vocab_get_text(context.model.vocab, max(0, llama_vocab_eos(context.model.vocab)))) ?? "", 63 | "unk_token": String(utf8String: llama_vocab_get_text(context.model.vocab, 0)) ?? "", 64 | "sep_token": String(utf8String: llama_vocab_get_text(context.model.vocab, max(0, llama_vocab_sep(context.model.vocab)))) ?? "", 65 | "pad_token": String(utf8String: llama_vocab_get_text(context.model.vocab, max(0, llama_vocab_pad(context.model.vocab)))) ?? "", 66 | "cls_token": String(utf8String: llama_vocab_get_text(context.model.vocab, max(0, llama_vocab_bos(context.model.vocab)))) ?? 
"", 67 | "mask_token": "" 68 | ] 69 | 70 | let prompt = try applyTemplate(messages, chatTemplate: context.model.chatTemplate, additionalContext: specialTokens) 71 | let imagesChunks = messages.imageChunks() 72 | var chunks = try extractChunks(prompt: prompt, imageChunks: imagesChunks) 73 | context.removeCachedChunks(&chunks) 74 | 75 | for chunk in chunks { 76 | switch chunk { 77 | case .text(let text): 78 | try context.decode(text: text) 79 | case .image(let images): 80 | guard let multimodal else { throw LLMError.failedToDecode(reason: "no mmproj file") } 81 | let bitmap = try multimodal.chunks(images: images) 82 | try context.decode(bitmap: bitmap, with: multimodal) 83 | case .video: 84 | // Video not supported in this decoder yet 85 | break 86 | } 87 | 88 | context.addCache(for: chunk, position: context.position) 89 | } 90 | } 91 | } 92 | 93 | public struct LlamaCustomMessageDecoder: LlamaChatMessageDecoder { 94 | public init( 95 | tokenImageRegex: String = "<start_of_image>" 96 | ) { 97 | self.tokenImageRegex = tokenImageRegex 98 | } 99 | 100 | public let tokenImageRegex: String 101 | 102 | public func extractChunks(prompt: String, imageChunks: [[LLMInputImage]]) throws -> [MessageChunk] { 103 | let pattern = try Regex<Substring>(tokenImageRegex) 104 | var chunks: [MessageChunk] = [] 105 | var lastIndex = prompt.startIndex 106 | var imageIndex = 0 107 | 108 | for match in prompt.matches(of: pattern) { 109 | if lastIndex < match.range.lowerBound { 110 | let prefix = prompt[lastIndex..<match.range.lowerBound] 111 | chunks.append(.text(String(prefix))) 112 | } 113 | 114 | if imageIndex < imageChunks.count { 115 | chunks.append(.image(imageChunks[imageIndex])) 116 | imageIndex += 1 117 | } 118 | 119 | lastIndex = match.range.upperBound 120 | } 121 | 122 | if lastIndex < prompt.endIndex { 123 | let suffix = prompt[lastIndex..<prompt.endIndex] 124 | chunks.append(.text(String(suffix))) 125 | } 126 | 127 | return chunks 128 | } 129 | } 130 | 131 | public struct LlamaQwen2VLMessageDecoder: LlamaChatMessageDecoder { 132 | public func extractChunks(prompt: String, imageChunks: [[LLMInputImage]]) throws -> [MessageChunk] { 133 | let pattern = /(?<image><\|image_pad\|>)|(?<video><\|video_pad\|>)/ 134 | var chunks = [MessageChunk]() 135 | var lastIndex = prompt.startIndex 136 | var imageIndex = 0 137 | 138 | for match in prompt.matches(of: pattern) { 139 | if lastIndex < match.range.lowerBound { 140 | let prefix = prompt[lastIndex..<match.range.lowerBound] 141 | chunks.append(.text(String(prefix))) 142 | } 143 | 144 | if let _ = match.output.image { 145 | guard imageIndex < imageChunks.count else { 146 | throw LLMError.failedToDecode(reason: "Not enough image chunks") 147 | } 148 | chunks.append(.image(imageChunks[imageIndex])) 149 | imageIndex += 1 150 | } else if let _ = match.output.video { 151 | // TODO: Handle video - add placeholder for now 152 | chunks.append(.video([])) 153 | } 154 | 155 | lastIndex = match.range.upperBound 156 | } 157 | 158 | if lastIndex < prompt.endIndex { 159 | let suffix = prompt[lastIndex..<prompt.endIndex] 160 | chunks.append(.text(String(suffix))) 161 | } 162 | 163 | return chunks 164 | } 165 | } 166 | 167 | public struct LlamaLlama3_2VMessageDecoder: LlamaChatMessageDecoder { 168 | public func templateValue(from messages: [LLMInput.Message]) -> [LLMInput.ChatTemplateMessage] { 169 | messages.map { message in 170 | switch message.role { 171 | case .system, .assistant, .custom: 172 | LLMInput.ChatTemplateMessage( 173 | value: ["role": message.role.rawValue, 
"content": message.content], 174 | attachments: message.attachments 175 | ) 176 | case .user: 177 | LLMInput.ChatTemplateMessage( 178 | value: [ 179 | "role": message.role.rawValue, 180 | "content": [["type": "text", "text": message.content]] + (0..<message.attachments.images().count).map { _ in 181 | ["type": "image"] 182 | }, 183 | ], 184 | attachments: message.attachments 185 | ) 186 | } 187 | } 188 | } 189 | 190 | public func extractChunks(prompt: String, imageChunks: [[LLMInputImage]]) throws -> [MessageChunk] { 191 | let decoder = LlamaCustomMessageDecoder(tokenImageRegex: #"<\|image\|>"#) 192 | return try decoder.extractChunks(prompt: prompt, imageChunks: imageChunks) 193 | } 194 | } 195 | 196 | public struct LlamaChatMLMessageDecoder: LlamaChatMessageDecoder { 197 | public func templateValue(from messages: [LLMInput.Message]) -> [LLMInput.ChatTemplateMessage] { 198 | messages.map { message in 199 | LLMInput.ChatTemplateMessage( 200 | value: ["role": message.role.rawValue, "content": message.content], 201 | attachments: message.attachments 202 | ) 203 | } 204 | } 205 | } 206 | 207 | // MARK: - Utilities 208 | 209 | private extension [LLMInput.ChatTemplateMessage] { 210 | func imageChunks() -> [[LLMInputImage]] { 211 | compactMap { message in 212 | let images = message.attachments.images() 213 | return images.isEmpty ? nil : images 214 | } 215 | } 216 | } 217 | 218 | private extension [LLMAttachment] { 219 | func images() -> [LLMInputImage] { 220 | return compactMap { attachment -> LLMInputImage? in 221 | if case let .image(image) = attachment { 222 | return image 223 | } 224 | return nil 225 | } 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/LlamaClient.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import LocalLLMClient 3 | 4 | /// A client for interacting with the Llama models. 5 | /// 6 | /// This class provides methods for generating text streams from various inputs, 7 | /// and handles the communication with the underlying Llama model. 8 | public final class LlamaClient: LLMClient { 9 | private let context: Context 10 | private let multimodal: MultimodalContext? 11 | private let messageDecoder: any LlamaChatMessageDecoder 12 | 13 | /// Initializes a new Llama client. 14 | /// 15 | /// - Parameters: 16 | /// - url: The URL of the Llama model file. 17 | /// - mmprojURL: The URL of the multimodal projector file (optional). 18 | /// - parameter: The parameters for the Llama model. 19 | /// - messageDecoder: The message decoder to use for chat messages (optional). 20 | /// - verbose: A Boolean value indicating whether to enable verbose logging. 21 | /// - Throws: An error if the client fails to initialize. 22 | public init( 23 | url: URL, 24 | mmprojURL: URL?, 25 | parameter: Parameter, 26 | messageDecoder: (any LlamaChatMessageDecoder)?, 27 | verbose: Bool 28 | ) throws { 29 | context = try Context(url: url, parameter: parameter) 30 | if let mmprojURL { 31 | multimodal = try MultimodalContext(url: mmprojURL, context: context, parameter: parameter, verbose: verbose) 32 | } else { 33 | multimodal = nil 34 | } 35 | self.messageDecoder = messageDecoder ?? LlamaAutoMessageDecoder(chatTemplate: context.model.chatTemplate) 36 | } 37 | 38 | /// Generates a text stream from the given input. 39 | /// 40 | /// - Parameter input: The input to generate text from. 41 | /// - Returns: A generator that produces text as it's generated by the model. 
42 | /// - Throws: An `LLMError.failedToDecode` error if the input cannot be decoded. 43 | public func textStream(from input: LLMInput) throws -> Generator { 44 | do { 45 | switch input.value { 46 | case .plain(let text): 47 | context.clear() 48 | try context.decode(text: text) 49 | case .chatTemplate(let messages): 50 | try messageDecoder.decode(messages, context: context, multimodal: multimodal) 51 | case .chat(let messages): 52 | let value = messageDecoder.templateValue(from: messages) 53 | try messageDecoder.decode(value, context: context, multimodal: multimodal) 54 | } 55 | } catch { 56 | throw LLMError.failedToDecode(reason: error.localizedDescription) 57 | } 58 | 59 | return Generator(context: context) 60 | } 61 | } 62 | 63 | public extension LocalLLMClient { 64 | /// Creates a new Llama client. 65 | /// 66 | /// This is a factory method for creating `LlamaClient` instances. 67 | /// 68 | /// - Parameters: 69 | /// - url: The URL of the Llama model file. 70 | /// - mmprojURL: The URL of the multimodal projector file (optional). 71 | /// - parameter: The parameters for the Llama model. Defaults to `.default`. 72 | /// - messageDecoder: The message decoder to use for chat messages (optional). 73 | /// - verbose: A Boolean value indicating whether to enable verbose logging. Defaults to `false`. 74 | /// - Returns: A new `LlamaClient` instance. 75 | /// - Throws: An error if the client fails to initialize. 76 | static func llama( 77 | url: URL, 78 | mmprojURL: URL? = nil, 79 | parameter: LlamaClient.Parameter = .default, 80 | messageDecoder: (any LlamaChatMessageDecoder)? = nil, 81 | verbose: Bool = false 82 | ) async throws -> LlamaClient { 83 | setLlamaVerbose(verbose) 84 | return try LlamaClient( 85 | url: url, 86 | mmprojURL: mmprojURL, 87 | parameter: parameter, 88 | messageDecoder: messageDecoder, 89 | verbose: verbose 90 | ) 91 | } 92 | } 93 | 94 | #if DEBUG 95 | extension LlamaClient { 96 | var _context: Context { 97 | context 98 | } 99 | 100 | var _multimodal: MultimodalContext? 
{ 101 | multimodal 102 | } 103 | } 104 | #endif 105 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Logger.swift: -------------------------------------------------------------------------------- 1 | #if canImport(OSLog) 2 | import OSLog 3 | #else 4 | import Foundation 5 | 6 | // Fallback for platforms without OSLog 7 | 8 | package struct Logger { 9 | let subsystem: String 10 | let category: String 11 | 12 | func log(_ message: String) { 13 | print("[\(subsystem).\(category)] \(message)") 14 | } 15 | 16 | func debug(_ message: String) { 17 | print("[DEBUG] [\(subsystem).\(category)] \(message)") 18 | } 19 | 20 | func info(_ message: String) { 21 | print("[INFO] [\(subsystem).\(category)] \(message)") 22 | } 23 | 24 | func warning(_ message: String) { 25 | print("[WARNING] [\(subsystem).\(category)] \(message)") 26 | } 27 | 28 | func fault(_ message: String) { 29 | print("[FAULT] [\(subsystem).\(category)] \(message)") 30 | } 31 | } 32 | #endif 33 | 34 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Model.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import LocalLLMClient 3 | import Jinja 4 | 5 | final class Model { 6 | let model: OpaquePointer 7 | let chatTemplate: String 8 | 9 | var vocab: OpaquePointer { 10 | llama_model_get_vocab(model) 11 | } 12 | 13 | init(url: URL) throws(LLMError) { 14 | var model_params = llama_model_default_params() 15 | #if targetEnvironment(simulator) 16 | model_params.n_gpu_layers = 0 17 | #endif 18 | model_params.use_mmap = true 19 | 20 | guard let model = llama_model_load_from_file(url.path(), model_params) else { 21 | throw .failedToLoad(reason: "Failed to load model from file") 22 | } 23 | 24 | self.model = model 25 | 26 | let chatTemplate = getString(capacity: 2048) { buffer, length in 27 | // LLM_KV_TOKENIZER_CHAT_TEMPLATE 28 | llama_model_meta_val_str(model, "tokenizer.chat_template", buffer, length) 29 | } 30 | 31 | // If the template is empty, it uses Gemma3-styled template as default 32 | self.chatTemplate = chatTemplate.isEmpty ? 
#"{{ bos_token }} {%- if messages[0]['role'] == 'system' -%} {%- if messages[0]['content'] is string -%} {%- set first_user_prefix = messages[0]['content'] + ' ' -%} {%- else -%} {%- set first_user_prefix = messages[0]['content'][0]['text'] + ' ' -%} {%- endif -%} {%- set loop_messages = messages[1:] -%} {%- else -%} {%- set first_user_prefix = "" -%} {%- set loop_messages = messages -%} {%- endif -%} {%- for message in loop_messages -%} {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} {%- endif -%} {%- if (message['role'] == 'assistant') -%} {%- set role = "model" -%} {%- else -%} {%- set role = message['role'] -%} {%- endif -%} {{ (first_user_prefix if loop.first else "") }} {%- if message['content'] is string -%} {{ message['content'] | trim }} {%- elif message['content'] is iterable -%} {%- for item in message['content'] -%} {%- if item['type'] == 'image' -%} {{ '<start_of_image>' }} {%- elif item['type'] == 'text' -%} {{ item['text'] | trim }} {%- endif -%} {%- endfor -%} {%- else -%} {{ raise_exception("Invalid content type") }} {%- endif -%} {%- endfor -%}"# : chatTemplate 33 | } 34 | 35 | deinit { 36 | llama_model_free(model) 37 | } 38 | 39 | func makeAndAllocateContext(with ctx_params: llama_context_params) throws(LLMError) -> OpaquePointer { 40 | guard let context = llama_init_from_model(model, ctx_params) else { 41 | throw .invalidParameter(reason: "Failed to create context") 42 | } 43 | return context 44 | } 45 | 46 | func tokenizerConfigs() -> [String: Any] { 47 | let numberOfConfigs = llama_model_meta_count(model) 48 | return (0..<numberOfConfigs).reduce(into: [:]) { partialResult, i in 49 | let key = getString(capacity: 64) { buffer, length in 50 | llama_model_meta_key_by_index(model, i, buffer, length) 51 | } 52 | let value = getString(capacity: 2048) { buffer, length in 53 | llama_model_meta_val_str_by_index(model, i, buffer, length) 54 | } 55 | partialResult[key] = value 56 | } 57 | } 58 | } 59 | 60 | private func getString(capacity: Int = 1024, getter: (UnsafeMutablePointer<CChar>?, Int) -> Int32) -> String { 61 | String(unsafeUninitializedCapacity: capacity) { buffer in 62 | buffer.withMemoryRebound(to: CChar.self) { buffer in 63 | let length = Int(getter(buffer.baseAddress, capacity)) 64 | return max(0, length) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Multimodal.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import LocalLLMClient 3 | @_exported import LocalLLMClientLlamaC 4 | 5 | public class MultimodalContext: @unchecked Sendable { 6 | package let multimodalContext: OpaquePointer 7 | package let verbose: Bool 8 | 9 | package init(url: URL, context: Context, parameter: LlamaClient.Parameter, verbose: Bool = false) throws(LLMError) { 10 | var mparams = mtmd_context_params_default() 11 | mparams.use_gpu = true 12 | mparams.print_timings = verbose 13 | if let numberOfThreads = parameter.numberOfThreads { 14 | mparams.n_threads = Int32(numberOfThreads) 15 | } 16 | mparams.verbosity = verbose ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_CONT; 17 | guard let multimodalContext = mtmd_init_from_file(url.path(), context.model.model, mparams) else { 18 | throw .failedToLoad(reason: "Failed to load the mmproj file") 19 | } 20 | self.multimodalContext = multimodalContext 21 | self.verbose = verbose 22 | } 23 | 24 | deinit { 25 | mtmd_free(multimodalContext) 26 | } 27 | 28 | package func chunks(images: [LLMInputImage]) throws(LLMError) -> MultimodalChunks { 29 | var bitmaps: [OpaquePointer?] = try images.map { image throws(LLMError) in 30 | let data = try llmInputImageToData(image) 31 | let (bytes, width, height) = imageDataToRGBBytes(imageData: data)! 32 | guard let bitmap = mtmd_bitmap_init(UInt32(width), UInt32(height), bytes) else { 33 | throw .failedToLoad(reason: "Failed to create bitmap") 34 | } 35 | return bitmap 36 | } 37 | defer { 38 | bitmaps.forEach(mtmd_bitmap_free) 39 | } 40 | 41 | let chunks = mtmd_input_chunks_init()! 42 | 43 | let textStorage = " \(MTMD_DEFAULT_IMAGE_MARKER) " // spaces for the workaround of tokenizer 44 | var text = textStorage.withCString { 45 | mtmd_input_text(text: $0, add_special: false, parse_special: true) 46 | } 47 | 48 | guard mtmd_tokenize(multimodalContext, chunks, &text, &bitmaps, bitmaps.count) == 0 else { 49 | throw .failedToLoad(reason: "Failed to tokenize bitmap") 50 | } 51 | 52 | return MultimodalChunks(chunks: chunks) 53 | } 54 | } 55 | 56 | package final class MultimodalChunks: @unchecked Sendable { 57 | package let chunks: OpaquePointer 58 | 59 | public init(chunks: OpaquePointer) { 60 | self.chunks = chunks 61 | } 62 | 63 | deinit { 64 | mtmd_input_chunks_free(chunks) 65 | } 66 | } 67 | 68 | package extension Context { 69 | func decode(bitmap: MultimodalChunks, with multimodal: MultimodalContext) throws(LLMError) { 70 | var newPosition: Int32 = 0 71 | let chunk = mtmd_input_chunks_get(bitmap.chunks, 1) // 1: <space><img><space> 72 | 73 | let imageTokens = mtmd_input_chunk_get_tokens_image(chunk) 74 | 75 | if multimodal.verbose { 76 | llamaLog(level: .debug, message: "encoding image or slice...\n") 77 | } 78 | 79 | guard mtmd_encode(multimodal.multimodalContext, imageTokens) == 0 else { 80 | throw .failedToDecode(reason: "Failed to encode image") 81 | } 82 | 83 | let embd = mtmd_get_output_embd(multimodal.multimodalContext); 84 | guard mtmd_helper_decode_image_chunk( 85 | multimodal.multimodalContext, 86 | context, 87 | chunk, 88 | embd, 89 | position, 90 | 0, // seq_id 91 | Int32(parameter.batch), 92 | &newPosition) == 0 else { 93 | throw .failedToDecode(reason: "Failed to decode image") 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Parameter.swift: -------------------------------------------------------------------------------- 1 | public extension LlamaClient { 2 | /// Defines the parameters for the Llama client and model. 3 | /// 4 | /// These parameters control various aspects of the text generation process, 5 | /// such as the context size, sampling methods, and penalty settings. 6 | struct Parameter: Sendable { 7 | /// Initializes a new set of parameters for the Llama client. 8 | /// 9 | /// - Parameters: 10 | /// - context: The size of the context window in tokens. Default is `2048`. 11 | /// - seed: The random seed for generation. `nil` means a random seed will be used. Default is `nil`. 12 | /// - numberOfThreads: The number of threads to use for generation. `nil` means the optimal number of threads will be chosen. Default is `nil`. 
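///       (As a sketch of the current behavior, `Context.init` resolves `nil` to `max(1, min(8, ProcessInfo.processInfo.processorCount - 2))` threads.)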
13 | /// - batch: The batch size for prompt processing. Default is `512`. 14 | /// - temperature: Controls randomness in sampling. Lower values make the model more deterministic. Default is `0.8`. 15 | /// - topK: Limits sampling to the K most likely tokens. Default is `40`. 16 | /// - topP: Limits sampling to a cumulative probability. Default is `0.95`. 17 | /// - typicalP: Limits sampling based on typical probability. Default is `1`. 18 | /// - penaltyLastN: The number of recent tokens to consider for penalty. Default is `64`. 19 | /// - penaltyRepeat: The penalty factor for repeating tokens. Default is `1.1`. 20 | /// - options: Additional options for the Llama client. 21 | public init( 22 | context: Int = 2048, 23 | seed: Int? = nil, 24 | numberOfThreads: Int? = nil, 25 | batch: Int = 512, 26 | temperature: Float = 0.8, 27 | topK: Int = 40, 28 | topP: Float = 0.95, 29 | typicalP: Float = 1, 30 | penaltyLastN: Int = 64, 31 | penaltyRepeat: Float = 1.1, 32 | options: Options = .init() 33 | ) { 34 | self.context = context 35 | self.seed = seed 36 | self.numberOfThreads = numberOfThreads 37 | self.batch = batch 38 | self.temperature = temperature 39 | self.topK = topK 40 | self.topP = topP 41 | self.typicalP = typicalP 42 | self.penaltyLastN = penaltyLastN 43 | self.penaltyRepeat = penaltyRepeat 44 | self.options = options 45 | } 46 | 47 | /// The size of the context window in tokens. 48 | public var context: Int 49 | /// The random seed for generation. `nil` means a random seed will be used. 50 | public var seed: Int? 51 | /// The number of threads to use for generation. `nil` means the optimal number of threads will be chosen. 52 | public var numberOfThreads: Int? 53 | /// The batch size for prompt processing. 54 | public var batch: Int 55 | /// Controls randomness in sampling. Lower values make the model more deterministic. 56 | public var temperature: Float 57 | /// Limits sampling to the K most likely tokens. 58 | public var topK: Int 59 | /// Limits sampling to a cumulative probability. 60 | public var topP: Float 61 | /// Limits sampling based on typical probability. 62 | public var typicalP: Float 63 | /// The number of recent tokens to consider for penalty. 64 | public var penaltyLastN: Int 65 | /// The penalty factor for repeating tokens. 66 | public var penaltyRepeat: Float 67 | 68 | /// Additional options for the Llama client. 69 | public var options: Options 70 | 71 | /// Default parameter settings. 72 | public static let `default` = Parameter() 73 | } 74 | 75 | /// Defines additional, less commonly used options for the Llama client. 76 | struct Options: Sendable { 77 | /// Initializes a new set of options for the Llama client. 78 | /// 79 | /// - Parameters: 80 | /// - responseFormat: Specifies the desired format for the model's response, such as JSON or a custom grammar. `nil` means no specific format is enforced. Default is `nil`. 81 | /// - extraEOSTokens: A set of additional strings that, when encountered, will be treated as end-of-sequence tokens by the model. Default is an empty set. 82 | public init( 83 | responseFormat: ResponseFormat? = nil, 84 | extraEOSTokens: Set<String> = [] 85 | ) { 86 | self.responseFormat = responseFormat 87 | self.extraEOSTokens = extraEOSTokens 88 | } 89 | 90 | /// Specifies the desired format for the model's response (e.g., JSON, custom grammar). 91 | public var responseFormat: ResponseFormat? 92 | /// Additional strings to be treated as end-of-sequence tokens. 
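/// For example (editor's sketch), `Options(extraEOSTokens: ["<end_of_turn>"])` makes generation stop as soon as the model emits that exact string; the marker shown here is only an illustration and depends on the model.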
93 | public var extraEOSTokens: Set<String> 94 | } 95 | 96 | /// Specifies the desired format for the model's response. 97 | enum ResponseFormat: Sendable { 98 | /// Constrains the model's output to a specific grammar defined in GBNF (GGML BNF) format. 99 | /// - Parameters: 100 | /// - gbnf: The grammar definition in GBNF format. 101 | /// - root: The name of the root rule in the GBNF grammar. Defaults to "root". 102 | case grammar(gbnf: String, root: String = "root") 103 | /// Constrains the model's output to valid JSON format. 104 | case json 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Resources/Grammars/json.gbnf: -------------------------------------------------------------------------------- 1 | root ::= object 2 | value ::= object | array | string | number | ("true" | "false" | "null") ws 3 | 4 | object ::= 5 | "{" ws ( 6 | string ":" ws value 7 | ("," ws string ":" ws value)* 8 | )? "}" ws 9 | 10 | array ::= 11 | "[" ws ( 12 | value 13 | ("," ws value)* 14 | )? "]" ws 15 | 16 | string ::= 17 | "\"" ( 18 | [^"\\\x7F\x00-\x1F] | 19 | "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes 20 | )* "\"" ws 21 | 22 | number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws 23 | 24 | # Optional space: by convention, applied in this grammar after literal chars when allowed 25 | ws ::= | " " | "\n" [ \t]{0,20} 26 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Sampler.swift: -------------------------------------------------------------------------------- 1 | #if BUILD_DOCC 2 | @preconcurrency @_implementationOnly import llama 3 | #elseif canImport(llama) 4 | @preconcurrency private import llama 5 | #else 6 | @preconcurrency import LocalLLMClientLlamaC 7 | #endif 8 | import Foundation 9 | 10 | typealias Sampler = UnsafeMutablePointer<llama_sampler> 11 | 12 | package extension Sampler { 13 | func sample(context: Context, index: Int32) -> llama_token { 14 | let logits = UncheckedSendable(llama_get_logits_ith(context.context, Int32(index))!) 
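// Editor's note: the code below copies every vocabulary entry and its current logit
// into `cursorPointer` in parallel, then lets the optional grammar sampler and the
// sampler chain filter the candidates and pick the next token.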
15 | 16 | DispatchQueue.concurrentPerform(iterations: context.cursorPointer.count) { tokenID in 17 | context.cursorPointer[tokenID] = llama_token_data( 18 | id: Int32(tokenID), logit: logits.value[tokenID], p: 0.0 19 | ) 20 | } 21 | 22 | var tokenDataArray = llama_token_data_array( 23 | data: context.cursorPointer.baseAddress, 24 | size: context.cursorPointer.count, 25 | selected: -1, 26 | sorted: false 27 | ) 28 | 29 | if let grammer = context.grammer { 30 | llama_sampler_apply(grammer, &tokenDataArray) 31 | } 32 | 33 | llama_sampler_apply(self, &tokenDataArray) 34 | assert(tokenDataArray.selected != -1) 35 | 36 | let token = tokenDataArray.data[Int(tokenDataArray.selected)].id 37 | llama_sampler_accept(self, token) 38 | return token 39 | } 40 | } 41 | 42 | private struct UncheckedSendable<Value>: @unchecked Sendable { 43 | var value: Value 44 | 45 | init(_ value: Value) { 46 | self.value = value 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Token.swift: -------------------------------------------------------------------------------- 1 | #if BUILD_DOCC 2 | @preconcurrency @_implementationOnly import llama 3 | #elseif canImport(llama) 4 | @preconcurrency private import llama 5 | #else 6 | @preconcurrency import LocalLLMClientLlamaC 7 | #endif 8 | 9 | package extension [llama_token] { 10 | init(_ text: String, addBos: Bool, special: Bool, vocab: OpaquePointer) { 11 | let utf8Count = text.utf8.count 12 | let n_tokens = utf8Count + (addBos ? 1 : 0) + 1 13 | self.init(unsafeUninitializedCapacity: n_tokens) { buffer, initializedCount in 14 | let count = llama_tokenize(vocab, text, Int32(utf8Count), buffer.baseAddress, Int32(n_tokens), addBos, special) 15 | initializedCount = Int(count) 16 | } 17 | } 18 | } 19 | 20 | package extension llama_token { 21 | func piece(vocab: OpaquePointer, special: Bool) -> [CChar] { 22 | var result = [CChar](repeating: 0, count: 8) 23 | let nTokens = llama_token_to_piece(vocab, self, &result, 8, 0, special) 24 | if nTokens < 0 { 25 | result = [CChar](repeating: 0, count: Int(-nTokens)) 26 | llama_token_to_piece(vocab, self, &result, -nTokens, 0, special) 27 | return result 28 | } else { 29 | return Array(result[0..<Int(nTokens)]) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/Utility.swift: -------------------------------------------------------------------------------- 1 | #if BUILD_DOCC 2 | @preconcurrency @_implementationOnly import llama 3 | #elseif canImport(llama) 4 | @preconcurrency private import llama 5 | #else 6 | @preconcurrency import LocalLLMClientLlamaC 7 | #endif 8 | import Foundation 9 | #if canImport(OSLog) 10 | import OSLog 11 | #endif 12 | 13 | // MARK: - Global State 14 | 15 | nonisolated(unsafe) private var isLlamaInitialized = false 16 | nonisolated(unsafe) private var isCustomLogEnabled = false 17 | nonisolated(unsafe) private var llamaLogCallback: ((LlamaLogLevel, String) -> Void)? 
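A brief usage sketch (illustrative only, not part of Utility.swift): these globals back the public lifecycle and logging functions defined below, which a client typically calls once at startup.

initializeLlama()        // idempotent; only the first call initializes the llama backend
setLlamaVerbose(false)   // or setLlamaLog { level, message in ... } to install a custom log sink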
18 | 19 | // MARK: - Life Cycle 20 | 21 | public func initializeLlama() { 22 | guard !isLlamaInitialized else { return } 23 | isLlamaInitialized = true 24 | #if os(Linux) 25 | ggml_backend_load_all_from_path(ProcessInfo.processInfo.environment["LD_LIBRARY_PATH"]) 26 | #endif 27 | 28 | llama_backend_init() 29 | 30 | if !isCustomLogEnabled { 31 | #if DEBUG 32 | setLlamaVerbose(true) 33 | #else 34 | setLlamaVerbose(false) 35 | #endif 36 | } 37 | } 38 | 39 | public func shutdownLlama() { 40 | guard isLlamaInitialized else { return } 41 | llama_backend_free() 42 | } 43 | 44 | // MARK: - Logging 45 | 46 | package extension Logger { 47 | static let localllm = Logger(subsystem: "com.github.tattn.LocalLLMClient", category: "localllm") 48 | } 49 | 50 | public func setLlamaLog(callback: ((LlamaLogLevel, String) -> Void)?) { 51 | llamaLogCallback = callback 52 | isCustomLogEnabled = true 53 | 54 | llama_log_set({ level, text, _ in 55 | guard let llamaLogCallback else { return } 56 | let level = LlamaLogLevel(rawValue: level.rawValue) ?? .none 57 | llamaLogCallback(level, text.map(String.init(cString:)) ?? "") 58 | }, nil) 59 | } 60 | 61 | public func setLlamaVerbose(_ verbose: Bool) { 62 | setLlamaLog(callback: verbose ? { level, message in 63 | llamaLog(level: level, message: message) 64 | } : nil) 65 | } 66 | 67 | package func llamaLog(level: LlamaLogLevel, message: String) { 68 | switch level { 69 | case .none: 70 | Logger.localllm.log("\(message)") 71 | case .debug, .continue: 72 | Logger.localllm.debug("\(message)") 73 | case .info: 74 | Logger.localllm.info("\(message)") 75 | case .warn: 76 | Logger.localllm.warning("\(message)") 77 | case .error: 78 | Logger.localllm.fault("\(message)") 79 | } 80 | } 81 | 82 | public enum LlamaLogLevel: ggml_log_level.RawValue, Sendable { 83 | case none, debug, info, warn, error, `continue` 84 | } 85 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlama/stb_image.swift: -------------------------------------------------------------------------------- 1 | #if canImport(CoreImage) 2 | // // Alternative to stb_image.h 3 | import Accelerate 4 | import CoreImage 5 | 6 | @_silgen_name("stbi_load_from_memory") 7 | func stbi_load_from_memory(_ buffer: UnsafePointer<UInt8>, _ len: UInt64, _ x: UnsafeMutablePointer<Int32>, _ y: UnsafeMutablePointer<Int32>, _ comp: UnsafeMutablePointer<Int32>, _ req_comp: Int32) -> UnsafeMutableRawPointer? { 8 | assert(req_comp == 3, "Only RGB format is supported") 9 | 10 | let data = Data(bytes: buffer, count: Int(len)) 11 | guard let (rgbBytes, width, height) = imageDataToRGBBytes(imageData: data) else { 12 | print("Failed to convert image data to RGB bytes") 13 | return nil 14 | } 15 | 16 | x.pointee = Int32(width) 17 | y.pointee = Int32(height) 18 | 19 | return rgbBytes 20 | } 21 | 22 | @_silgen_name("stbi_load") 23 | func stbi_load(_ filename: UnsafePointer<CChar>, _ x: UnsafeMutablePointer<Int32>, _ y: UnsafeMutablePointer<Int32>, _ comp: UnsafeMutablePointer<Int32>, _ req_comp: Int32) -> UnsafeMutableRawPointer? { 24 | assert(req_comp == 3, "Only RGB format is supported") 25 | 26 | guard let url = URL(string: String(cString: filename)), 27 | let imageData = try? 
Data(contentsOf: url), 28 | let (rgbBytes, width, height) = imageDataToRGBBytes(imageData: imageData) else { 29 | print("Failed to convert image data to RGB bytes") 30 | return nil 31 | } 32 | 33 | x.pointee = Int32(width) 34 | y.pointee = Int32(height) 35 | 36 | return rgbBytes 37 | } 38 | 39 | @_silgen_name("stbi_image_free") 40 | func stbi_image_free(_ buffer: UnsafeMutableRawPointer) { 41 | buffer.assumingMemoryBound(to: UInt8.self).deallocate() 42 | } 43 | 44 | package func imageDataToRGBBytes( 45 | imageData: Data 46 | ) -> (bytes: UnsafeMutableRawPointer, width: Int, height: Int)? { 47 | let context = CIContext() 48 | let image = CIImage(data: imageData)! 49 | guard let cgImage = context.createCGImage(image, from: image.extent) else { 50 | return nil 51 | } 52 | 53 | var format = vImage_CGImageFormat( 54 | bitsPerComponent: 8, 55 | bitsPerPixel: 8 * 3, 56 | colorSpace: CGColorSpace(name: CGColorSpace.displayP3)!, 57 | bitmapInfo: .init(rawValue: CGImageAlphaInfo.none.rawValue))! 58 | 59 | guard let buffer = try? vImage.PixelBuffer( 60 | cgImage: cgImage, 61 | cgImageFormat: &format, 62 | pixelFormat: vImage.Interleaved8x3.self) else { 63 | return nil 64 | } 65 | 66 | let width = cgImage.width 67 | let height = cgImage.height 68 | 69 | let result = UnsafeMutableRawBufferPointer.allocate( 70 | byteCount: width * height * 3, 71 | alignment: MemoryLayout<UInt8>.alignment 72 | ) 73 | buffer.array.copyBytes(to: result) 74 | 75 | return (result.baseAddress!, width, height) 76 | } 77 | #else 78 | import Foundation 79 | 80 | @_silgen_name("stbi_load_from_memory") 81 | func stbi_load_from_memory(_ buffer: UnsafePointer<UInt8>, _ len: Int32, _ x: UnsafeMutablePointer<Int32>, _ y: UnsafeMutablePointer<Int32>, _ comp: UnsafeMutablePointer<Int32>, _ req_comp: Int32) -> UnsafeMutablePointer<UInt8>? 82 | 83 | package func imageDataToRGBBytes( 84 | imageData: Data 85 | ) -> (bytes: UnsafeMutableRawPointer, width: Int, height: Int)? { 86 | var width: Int32 = 0 87 | var height: Int32 = 0 88 | var comp: Int32 = 0 89 | return imageData.withUnsafeBytes { rawBufferPointer -> ((UnsafeMutableRawPointer, Int, Int)?) 
in 90 | guard let baseAddress = rawBufferPointer.baseAddress else { return nil } 91 | let pointer = baseAddress.assumingMemoryBound(to: UInt8.self) 92 | return stbi_load_from_memory(pointer, Int32(imageData.count), &width, &height, &comp, 3).map { bytes in 93 | (UnsafeMutableRawPointer(bytes), Int(width), Int(height)) 94 | } 95 | } 96 | } 97 | #endif 98 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/clip-impl.h: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/tools/mtmd/clip-impl.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/clip.cpp: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/tools/mtmd/clip.cpp -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/ggml-cpp.h: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/ggml/include/ggml-cpp.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/clip.h: -------------------------------------------------------------------------------- 1 | ../exclude/llama.cpp/tools/mtmd/clip.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/ggml-alloc.h: -------------------------------------------------------------------------------- 1 | #if __has_include(<llama/ggml-alloc.h>) 2 | #include <llama/ggml-alloc.h> 3 | #else 4 | // For DocC 5 | #include "../exclude/llama.cpp/ggml/include/ggml-alloc.h" 6 | #endif 7 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/ggml-backend.h: -------------------------------------------------------------------------------- 1 | #if __has_include(<llama/ggml-backend.h>) 2 | #include <llama/ggml-backend.h> 3 | #else 4 | // For DocC 5 | #include "../exclude/llama.cpp/ggml/include/ggml-backend.h" 6 | #endif 7 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/ggml-cpu.h: -------------------------------------------------------------------------------- 1 | #if __has_include(<llama/ggml-cpu.h>) 2 | #include <llama/ggml-cpu.h> 3 | #else 4 | // For DocC 5 | #include "../exclude/llama.cpp/ggml/include/ggml-cpu.h" 6 | #endif 7 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/ggml-opt.h: -------------------------------------------------------------------------------- 1 | // This header is used in DocC builds. 
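// Like the sibling wrappers, this header prefers the installed llama framework header and only
// falls back to the vendored llama.cpp copy (e.g. for DocC builds) when it is unavailable.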
2 | #if __has_include(<llama/ggml-opt.h>) 3 | #include <llama/ggml-opt.h> 4 | #else 5 | // For DocC 6 | #include "../exclude/llama.cpp/ggml/include/ggml-opt.h" 7 | #endif 8 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/ggml.h: -------------------------------------------------------------------------------- 1 | #if __has_include(<llama/ggml.h>) 2 | #include <llama/ggml.h> 3 | #else 4 | // For DocC 5 | #include "../exclude/llama.cpp/ggml/include/ggml.h" 6 | #endif 7 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/gguf.h: -------------------------------------------------------------------------------- 1 | #if __has_include(<llama/gguf.h>) 2 | #include <llama/gguf.h> 3 | #else 4 | // For DocC 5 | #include "../exclude/llama.cpp/ggml/include/gguf.h" 6 | #endif 7 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/llama.h: -------------------------------------------------------------------------------- 1 | #define LLAVA_LOG_OFF 2 | #if __has_include(<llama/llama.h>) 3 | #include <llama/llama.h> 4 | #else 5 | // For DocC 6 | #include "../exclude/llama.cpp/include/llama.h" 7 | #endif 8 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/mtmd-helper.h: -------------------------------------------------------------------------------- 1 | ../exclude/llama.cpp/tools/mtmd/mtmd-helper.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/include/mtmd.h: -------------------------------------------------------------------------------- 1 | ../exclude/llama.cpp/tools/mtmd/mtmd.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/miniaudio/miniaudio.h: -------------------------------------------------------------------------------- 1 | ../exclude/llama.cpp/vendor/miniaudio/miniaudio.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/mtmd-audio.cpp: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/tools/mtmd/mtmd-audio.cpp -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/mtmd-audio.h: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/tools/mtmd/mtmd-audio.h -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/mtmd-helper.cpp: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/tools/mtmd/mtmd-helper.cpp -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/mtmd.cpp: -------------------------------------------------------------------------------- 1 | exclude/llama.cpp/tools/mtmd/mtmd.cpp -------------------------------------------------------------------------------- /Sources/LocalLLMClientLlamaC/stb/stb_image.h: -------------------------------------------------------------------------------- 1 | #ifdef __linux__ 2 | #include "../exclude/llama.cpp/vendor/stb/stb_image.h" 3 | #else 4 | // Implemented by stb_image.swift 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | extern const unsigned char *stbi_load(char 
const *filename, int *x, int *y, 11 | int *comp, int req_comp); 12 | extern void stbi_image_free(const unsigned char *retval_from_stbi_load); 13 | extern const unsigned char *stbi_load_from_memory(void const *buffer, 14 | size_t len, int *x, int *y, 15 | int *comp, int req_comp); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | 21 | #endif // __linux__ 22 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientMLX/Context.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLXVLM 3 | import LocalLLMClient 4 | import MLX 5 | import MLXLLM 6 | import MLXLMCommon 7 | import MLXRandom 8 | import Tokenizers 9 | 10 | public final class Context: Sendable { 11 | let modelContainer: ModelContainer 12 | let supportsVision: Bool 13 | 14 | public init(url: URL, parameter: MLXClient.Parameter) async throws(LLMError) { 15 | initializeMLX() 16 | 17 | MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000)) 18 | 19 | let configuration = ModelConfiguration(directory: url, extraEOSTokens: parameter.options.extraEOSTokens) 20 | 21 | let (model, tokenizer) = try await Self.loadModel( 22 | url: url, configuration: configuration 23 | ) 24 | let (processor, supportsVision) = Self.makeProcessor( 25 | url: url, configuration: configuration, tokenizer: tokenizer 26 | ) 27 | 28 | let context = ModelContext( 29 | configuration: configuration, 30 | model: model, 31 | processor: processor, 32 | tokenizer: tokenizer 33 | ) 34 | modelContainer = ModelContainer(context: context) 35 | self.supportsVision = supportsVision 36 | } 37 | 38 | private static func loadModel( 39 | url: URL, configuration: ModelConfiguration 40 | ) async throws(LLMError) -> (any LanguageModel, any Tokenizer) { 41 | do { 42 | let configurationURL = url.appending(component: "config.json") 43 | let baseConfiguration = try JSONDecoder().decode( 44 | BaseConfiguration.self, from: Data(contentsOf: configurationURL) 45 | ) 46 | let model: any LanguageModel 47 | do { 48 | model = try VLMTypeRegistry.shared.createModel( 49 | configuration: configurationURL, 50 | modelType: baseConfiguration.modelType 51 | ) 52 | } catch { 53 | model = try LLMTypeRegistry.shared.createModel( 54 | configuration: configurationURL, 55 | modelType: baseConfiguration.modelType 56 | ) 57 | } 58 | 59 | try loadWeights(modelDirectory: url, model: model, perLayerQuantization: baseConfiguration.perLayerQuantization) 60 | 61 | let tokenizer = try await loadTokenizer(configuration: configuration, hub: .shared) 62 | return (model, tokenizer) 63 | } catch { 64 | throw .failedToLoad(reason: error.localizedDescription) 65 | } 66 | } 67 | 68 | private static func makeProcessor( 69 | url: URL, configuration: ModelConfiguration, tokenizer: any Tokenizer, 70 | ) -> (any UserInputProcessor, Bool) { 71 | do { 72 | let processorConfiguration = url.appending( 73 | component: "preprocessor_config.json" 74 | ) 75 | let baseProcessorConfig = try JSONDecoder().decode( 76 | BaseProcessorConfiguration.self, 77 | from: Data(contentsOf: processorConfiguration) 78 | ) 79 | 80 | return (try VLMProcessorTypeRegistry.shared.createModel( 81 | configuration: processorConfiguration, 82 | processorType: baseProcessorConfig.processorClass, 83 | tokenizer: tokenizer 84 | ), true) 85 | } catch { 86 | return (LLMUserInputProcessor( 87 | tokenizer: tokenizer, 88 | configuration: configuration, 89 | messageGenerator: DefaultMessageGenerator() 90 | ), false) 91 | } 92 | } 93 | } 94 | 95 | 
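For orientation, a minimal usage sketch of this type (normally it is constructed for you by MLXClient; `modelDirectoryURL` is a placeholder for a local MLX model directory):

let context = try await Context(url: modelDirectoryURL, parameter: .default)
print("Vision supported:", context.supportsVision)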
private struct LLMUserInputProcessor: UserInputProcessor { 96 | let tokenizer: Tokenizer 97 | let configuration: ModelConfiguration 98 | let messageGenerator: MessageGenerator 99 | 100 | init( 101 | tokenizer: any Tokenizer, configuration: ModelConfiguration, 102 | messageGenerator: MessageGenerator 103 | ) { 104 | self.tokenizer = tokenizer 105 | self.configuration = configuration 106 | self.messageGenerator = messageGenerator 107 | } 108 | 109 | func prepare(input: UserInput) throws -> LMInput { 110 | let messages = messageGenerator.generate(from: input) 111 | do { 112 | let promptTokens = try tokenizer.applyChatTemplate( 113 | messages: messages, tools: input.tools, additionalContext: input.additionalContext 114 | ) 115 | return LMInput(tokens: MLXArray(promptTokens)) 116 | } catch { 117 | let prompt = messages 118 | .compactMap { $0["content"] as? String } 119 | .joined(separator: "\n\n") 120 | let promptTokens = tokenizer.encode(text: prompt) 121 | return LMInput(tokens: MLXArray(promptTokens)) 122 | } 123 | } 124 | } 125 | 126 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientMLX/MLXClient.swift: -------------------------------------------------------------------------------- 1 | import LocalLLMClient 2 | import MLX 3 | import MLXLMCommon 4 | import Foundation 5 | 6 | /// A client for interacting with MLX models. 7 | /// 8 | /// This actor-based class provides methods for generating text streams from various inputs, 9 | /// and handles the communication with the underlying MLX model via the `MLX` and `MLXLMCommon` frameworks. 10 | public final actor MLXClient: LLMClient { 11 | private let context: Context 12 | private let parameter: MLXClient.Parameter 13 | 14 | /// Initializes a new MLX client. 15 | /// 16 | /// - Parameters: 17 | /// - url: The URL of the MLX model directory. This directory should contain the model weights, tokenizer configuration, and any other necessary model files. 18 | /// - parameter: The parameters for the MLX model. Defaults to `.default`. 19 | /// - Throws: An error if the client fails to initialize, for example, if the model files cannot be loaded. 20 | nonisolated public init(url: URL, parameter: Parameter = .default) async throws { 21 | context = try await Context(url: url, parameter: parameter) 22 | self.parameter = parameter 23 | } 24 | 25 | /// Generates a text stream from the given input. 26 | /// 27 | /// This function processes the input, whether it's plain text, a chat template, or structured chat messages, 28 | /// and prepares it for the MLX model. It then generates text asynchronously. 29 | /// 30 | /// - Parameter input: The input to generate text from. This can be plain text, a chat template, or an array of chat messages. 31 | /// - Returns: An `AsyncStream<String>` that yields text chunks as they are generated by the model. 32 | /// - Throws: An `LLMError.visionUnsupported` error if the input contains images and the loaded model does not support vision. 33 | /// It can also throw errors related to model processing or input preparation. 34 | public func textStream(from input: LLMInput) async throws -> AsyncStream<String> { 35 | let chat: [Chat.Message] = switch input.value { 36 | case .plain(let text): 37 | [.user(text)] 38 | case .chatTemplate(let messages): 39 | messages.map { 40 | Chat.Message( 41 | role: .init(rawValue: $0.value["role"] as? String ?? "") ?? .user, 42 | content: $0.value["content"] as? String ?? 
"", 43 | images: $0.attachments.images 44 | ) 45 | } 46 | case .chat(let messages): 47 | messages.map { 48 | Chat.Message( 49 | role: .init(rawValue: $0.role.rawValue) ?? .user, 50 | content: $0.content, 51 | images: $0.attachments.images 52 | ) 53 | } 54 | } 55 | 56 | var userInput = UserInput(chat: chat, additionalContext: ["enable_thinking": false]) // TODO: public API 57 | userInput.processing.resize = .init(width: 448, height: 448) 58 | 59 | if chat.contains(where: { !$0.images.isEmpty }), !context.supportsVision { 60 | throw LLMError.visionUnsupported 61 | } 62 | let modelContainer = context.modelContainer 63 | 64 | return try await modelContainer.perform { [userInput] context in 65 | let lmInput = try await context.processor.prepare(input: userInput) 66 | let stream = try MLXLMCommon.generate( 67 | input: lmInput, 68 | parameters: parameter.parameters, 69 | context: context 70 | ) 71 | 72 | return .init { continuation in 73 | let task = Task { 74 | for await generated in stream { 75 | continuation.yield(generated.chunk ?? "") 76 | } 77 | continuation.finish() 78 | } 79 | continuation.onTermination = { _ in 80 | task.cancel() 81 | } 82 | } 83 | } 84 | } 85 | } 86 | 87 | private extension [LLMAttachment] { 88 | var images: [UserInput.Image] { 89 | compactMap { 90 | switch $0 { 91 | case let .image(image): 92 | return try? UserInput.Image.ciImage(llmInputImageToCIImage(image)) 93 | } 94 | } 95 | } 96 | } 97 | 98 | public extension LocalLLMClient { 99 | /// Creates a new MLX client. 100 | /// 101 | /// This is a factory method for creating `MLXClient` instances. 102 | /// 103 | /// - Parameters: 104 | /// - url: The URL of the MLX model directory. This directory should contain the model weights, tokenizer configuration, and any other necessary model files. 105 | /// - parameter: The parameters for the MLX model. Defaults to `.default`. 106 | /// - Returns: A new `MLXClient` instance. 107 | /// - Throws: An error if the client fails to initialize, for example, if the model files cannot be loaded. 108 | static func mlx(url: URL, parameter: MLXClient.Parameter = .default) async throws -> MLXClient { 109 | try await MLXClient(url: url, parameter: parameter) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientMLX/Parameter.swift: -------------------------------------------------------------------------------- 1 | import MLXLMCommon 2 | 3 | public extension MLXClient { 4 | /// Defines the parameters for the MLX client and model. 5 | /// 6 | /// These parameters control various aspects of the text generation process, 7 | /// such as token limits, sampling methods, and repetition penalties, 8 | /// largely by configuring the underlying `MLXLMCommon.GenerateParameters`. 9 | struct Parameter: Sendable { 10 | /// Initializes a new set of parameters for the MLX client. 11 | /// 12 | /// - Parameters: 13 | /// - maxTokens: The maximum number of tokens to generate. If `nil`, generation continues until an end-of-sequence token is produced. Default is `nil`. 14 | /// - temperature: Controls the randomness of the generated text. Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused. Default is `0.6`. 15 | /// - topP: Restricts token selection to a cumulative probability distribution. Only tokens whose cumulative probability is less than or equal to `topP` are considered. Default is `1.0`. 16 | /// - repetitionPenalty: The penalty applied to repeated tokens. 
A value of `nil` or `1.0` means no penalty. Higher values discourage repetition. Default is `nil`. 17 | /// - repetitionContextSize: The number of recent tokens to consider for the repetition penalty. Default is `20`. 18 | /// - options: Additional, less commonly used options for the MLX client. 19 | public init( 20 | maxTokens: Int? = nil, 21 | temperature: Float = 0.6, 22 | topP: Float = 1.0, 23 | repetitionPenalty: Float? = nil, 24 | repetitionContextSize: Int = 20, 25 | options: Options = .init(), 26 | ) { 27 | parameters = .init( 28 | maxTokens: maxTokens, 29 | temperature: temperature, 30 | topP: topP, 31 | repetitionPenalty: repetitionPenalty, 32 | repetitionContextSize: repetitionContextSize 33 | ) 34 | self.options = options 35 | } 36 | 37 | /// The core generation parameters passed to the `MLXLMCommon` framework. 38 | public var parameters: GenerateParameters 39 | /// Additional, less commonly used options for the MLX client. 40 | public var options: Options 41 | 42 | /// Provides a default set of parameters. 43 | public static let `default` = Parameter() 44 | } 45 | 46 | /// Defines additional, less commonly used options for the MLX client. 47 | struct Options: Sendable { 48 | /// Initializes a new set of options for the MLX client. 49 | /// 50 | /// - Parameters: 51 | /// - extraEOSTokens: A set of additional strings that, when encountered, will be treated as end-of-sequence tokens by the model. Default is an empty set. 52 | public init( 53 | extraEOSTokens: Set<String> = [] 54 | ) { 55 | self.extraEOSTokens = extraEOSTokens 56 | } 57 | 58 | /// Additional strings to be treated as end-of-sequence tokens by the model. 59 | public var extraEOSTokens: Set<String> 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientMLX/Utility.swift: -------------------------------------------------------------------------------- 1 | import MLX 2 | 3 | // MARK: - Global State 4 | 5 | nonisolated(unsafe) private var isMLXInitialized = false 6 | 7 | // MARK: - Life Cycle 8 | 9 | public func initializeMLX() { 10 | guard !isMLXInitialized else { return } 11 | isMLXInitialized = true 12 | MLX.GPU.set(cacheLimit: 20 * 1024 * 1024) 13 | } 14 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientUtility/Downloader.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | #if canImport(FoundationNetworking) 3 | import FoundationNetworking 4 | #endif 5 | 6 | final class Downloader { 7 | private(set) var downloaders: [ChildDownloader] = [] 8 | let progress = Progress(totalUnitCount: 0) 9 | #if os(Linux) 10 | private var observer: Task<Void, Never>? 11 | #else 12 | private var observer: NSKeyValueObservation? 
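// Platform note: on Linux (no KVO) setObserver polls Progress.fractionCompleted from a Task,
// while on Apple platforms it installs an NSKeyValueObservation, as implemented below.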
13 | #endif 14 | 15 | var isDownloading: Bool { 16 | downloaders.contains(where: \.isDownloading) 17 | } 18 | 19 | var isDownloaded: Bool { 20 | downloaders.allSatisfy(\.isDownloaded) 21 | } 22 | 23 | init() {} 24 | 25 | #if os(Linux) 26 | deinit { 27 | observer?.cancel() 28 | } 29 | #endif 30 | 31 | func add(_ downloader: ChildDownloader) { 32 | downloaders.append(downloader) 33 | progress.addChild(downloader.progress, withPendingUnitCount: 1) 34 | progress.totalUnitCount += 1 35 | } 36 | 37 | func setObserver(_ action: @Sendable @escaping (Progress) async -> Void) { 38 | #if os(Linux) 39 | observer?.cancel() 40 | observer = Task { [progress] in 41 | var fractionCompleted = progress.fractionCompleted 42 | while !Task.isCancelled { 43 | if fractionCompleted != progress.fractionCompleted { 44 | fractionCompleted = progress.fractionCompleted 45 | await action(progress) 46 | } 47 | try? await Task.sleep(for: .seconds(1)) 48 | } 49 | } 50 | #else 51 | observer = progress.observe(\.fractionCompleted, options: [.initial, .new]) { progress, change in 52 | Task { 53 | await action(progress) 54 | } 55 | } 56 | #endif 57 | } 58 | 59 | func download() { 60 | guard !downloaders.isEmpty else { 61 | // Notify that download is complete 62 | progress.totalUnitCount = 1 63 | progress.completedUnitCount = 1 64 | return 65 | } 66 | for downloader in downloaders { 67 | downloader.download() 68 | } 69 | } 70 | 71 | func waitForDownloads() async { 72 | while isDownloading && progress.fractionCompleted < 1.0 { 73 | try? await Task.sleep(for: .seconds(1)) 74 | } 75 | } 76 | } 77 | 78 | extension Downloader { 79 | final class ChildDownloader: Sendable { 80 | private let url: URL 81 | private let destinationURL: URL 82 | private let session: URLSession 83 | private let delegate = Delegate() 84 | 85 | var progress: Progress { 86 | delegate.progress 87 | } 88 | 89 | var isDownloading: Bool { 90 | delegate.isDownloading.withLock(\.self) 91 | } 92 | 93 | var isDownloaded: Bool { 94 | FileManager.default.fileExists(atPath: destinationURL.path) 95 | } 96 | 97 | public init(url: URL, destinationURL: URL, configuration: URLSessionConfiguration = .default) { 98 | self.url = url 99 | self.destinationURL = destinationURL 100 | session = URLSession(configuration: configuration, delegate: delegate, delegateQueue: nil) 101 | 102 | #if !os(Linux) 103 | Task { 104 | for task in await session.allTasks { 105 | if task.taskDescription == destinationURL.absoluteString { 106 | download(existingTask: task) 107 | } else { 108 | task.cancel() 109 | } 110 | } 111 | } 112 | #endif 113 | } 114 | 115 | public func download(existingTask: URLSessionTask? = nil) { 116 | guard !isDownloading else { return } 117 | delegate.isDownloading.withLock { $0 = true } 118 | 119 | try? FileManager.default.createDirectory(at: destinationURL.deletingLastPathComponent(), withIntermediateDirectories: true) 120 | let task = existingTask ?? 
session.downloadTask(with: url) 121 | task.taskDescription = destinationURL.absoluteString 122 | task.priority = URLSessionTask.highPriority 123 | task.resume() 124 | } 125 | } 126 | } 127 | 128 | extension Downloader.ChildDownloader { 129 | final class Delegate: NSObject, URLSessionDownloadDelegate { 130 | let progress = Progress(totalUnitCount: 1) 131 | let isDownloading = Locked(false) 132 | 133 | func urlSession( 134 | _ session: URLSession, downloadTask: URLSessionDownloadTask, 135 | didFinishDownloadingTo location: URL 136 | ) { 137 | #if DEBUG 138 | print("Download finished to location: \(location.path)") 139 | #endif 140 | 141 | // Move the downloaded file to the permanent location 142 | guard let taskDescription = downloadTask.taskDescription, 143 | let destinationURL = URL(string: taskDescription) else { 144 | return 145 | } 146 | try? FileManager.default.removeItem(at: destinationURL) 147 | do { 148 | try FileManager.default.createDirectory( 149 | at: destinationURL.deletingLastPathComponent(), 150 | withIntermediateDirectories: true 151 | ) 152 | try FileManager.default.moveItem(at: location, to: destinationURL) 153 | } catch { 154 | print("The URLSessionTask may be old. The app container was already invalid: \(error.localizedDescription)") 155 | } 156 | } 157 | 158 | func urlSession( 159 | _ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error? 160 | ) { 161 | #if DEBUG 162 | if let error { 163 | print("Download failed with error: \(error.localizedDescription)") 164 | } 165 | #endif 166 | isDownloading.withLock { $0 = false } 167 | } 168 | 169 | func urlSession( 170 | _ session: URLSession, downloadTask: URLSessionDownloadTask, 171 | didWriteData bytesWritten: Int64, 172 | totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64 173 | ) { 174 | if bytesWritten == totalBytesWritten { 175 | progress.totalUnitCount = totalBytesExpectedToWrite 176 | } 177 | progress.completedUnitCount = totalBytesWritten 178 | } 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientUtility/FileDownloader.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// A protocol defining the requirements for an entity that can download files. 4 | /// 5 | /// Types conforming to `FileDownloadable` are expected to manage the source and destination 6 | /// of downloadable files, check their download status, and handle associated metadata. 7 | public protocol FileDownloadable: Sendable { 8 | /// The source from which the file(s) are downloaded (e.g., a specific Hugging Face repository). 9 | var source: FileDownloader.Source { get } 10 | /// The local URL where the downloaded file(s) are or will be stored. 11 | var destination: URL { get } 12 | /// A Boolean value indicating whether the file(s) from the source have been successfully downloaded to the destination. 13 | var isDownloaded: Bool { get } 14 | 15 | /// Removes any metadata associated with the downloaded files. 16 | /// 17 | /// This is useful for clearing up stored information about the files, potentially 18 | /// forcing a re-download or re-check of metadata in the future. 19 | /// 20 | /// - Throws: An error if the metadata cannot be removed, for example, due to file permission issues or if the metadata file doesn't exist. 21 | func removeMetadata() throws 22 | } 23 | 24 | /// A struct that implements the `FileDownloadable` protocol to manage file downloads. 
25 | /// 26 | /// This struct provides a concrete implementation for downloading files, particularly from 27 | /// Hugging Face Hub, using the `HubApi`. It handles metadata storage and progress reporting. 28 | public struct FileDownloader: FileDownloadable { 29 | public let source: Source 30 | private let rootDestination: URL 31 | private let downloadConfiguration: DownloadConfiguration 32 | 33 | /// The default root directory where downloaded files are stored. 34 | /// This is typically a subdirectory within the application's support directory, named "LocalLLM". 35 | public static let defaultRootDestination = URL.defaultRootDirectory 36 | 37 | public var destination: URL { 38 | source.destination(for: rootDestination) 39 | } 40 | 41 | public var isDownloaded: Bool { 42 | source.isDownloaded(for: rootDestination) 43 | } 44 | 45 | /// Specifies the source from which files are to be downloaded. 46 | public enum Source: Sendable, Equatable { 47 | /// Represents a source from Hugging Face Hub. 48 | /// 49 | /// - Parameters: 50 | /// - id: The repository identifier on Hugging Face (e.g., "ml-explore/mlx-swift-examples"). 51 | /// - globs: A set of glob patterns to filter which files are downloaded from the repository. 52 | case huggingFace(id: String, globs: Globs) 53 | 54 | func destination(for rootDestination: URL) -> URL { 55 | switch self { 56 | case let .huggingFace(id, _): 57 | let client = HuggingFaceAPI(repo: .init(id: id)) 58 | return client.getLocalRepoLocation(downloadBase: rootDestination) 59 | } 60 | } 61 | 62 | func isDownloaded(for destination: URL) -> Bool { 63 | guard FileManager.default.fileExists(atPath: destination.path), 64 | let meta = try? FilesMetadata.load(from: destination) else { 65 | return false 66 | } 67 | 68 | let fileURLs = FileManager.default.enumerator(at: destination, includingPropertiesForKeys: nil)? 69 | .compactMap { $0 as? URL } ?? [] 70 | return meta.files.allSatisfy { file in 71 | fileURLs.contains { url in 72 | file.name == url.lastPathComponent 73 | } 74 | } 75 | } 76 | 77 | func downloadFiles(to rootDestination: URL, configuration: HuggingFaceAPI.DownloadConfiguration = .default, onProgress: @Sendable @escaping (Double) async -> Void) async throws { 78 | switch self { 79 | case let .huggingFace(id, globs): 80 | let client = HuggingFaceAPI(repo: .init(id: id)) 81 | try await client.downloadSnapshot(to: rootDestination, matching: globs, configuration: configuration) { progress in 82 | Task { [progress] in 83 | await onProgress(progress.fractionCompleted) 84 | } 85 | } 86 | } 87 | } 88 | 89 | @discardableResult 90 | func saveMetadata(to destination: URL) async throws -> FilesMetadata { 91 | switch self { 92 | case let .huggingFace(id, globs): 93 | let client = HuggingFaceAPI(repo: .init(id: id)) 94 | let filenames = try await client.getFilenames(matching: globs) 95 | let metadata = FilesMetadata(files: filenames.map { FilesMetadata.FileMetadata(name: $0) }) 96 | try metadata.save(to: destination) 97 | return metadata 98 | } 99 | } 100 | 101 | func removeMetadata(from destination: URL) throws { 102 | try FileManager.default.removeItem(at: destination.appendingPathComponent(FilesMetadata.filename)) 103 | } 104 | } 105 | 106 | public struct DownloadConfiguration: Sendable { 107 | public var identifier: String? 108 | public var protocolClasses: [AnyClass]? 
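// `protocolClasses` allows injecting custom URLProtocol implementations (e.g. the test
// suite's MockURLProtocol) into the URLSession configuration used for downloads.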
109 | 110 | /// Initializes a new download configuration 111 | public static let `default` = DownloadConfiguration(identifier: nil) 112 | 113 | /// Creates a new download configuration for background downloads 114 | public static func background(withIdentifier identifier: String) -> DownloadConfiguration { 115 | DownloadConfiguration(identifier: identifier) 116 | } 117 | 118 | func makeHuggingFaceConfiguration() -> HuggingFaceAPI.DownloadConfiguration { 119 | var result: HuggingFaceAPI.DownloadConfiguration = if let identifier { 120 | .background(withIdentifier: identifier) 121 | } else { 122 | .default 123 | } 124 | result.protocolClasses = protocolClasses 125 | return result 126 | } 127 | } 128 | 129 | /// Initializes a new file downloader. 130 | /// 131 | /// - Parameters: 132 | /// - source: The source from which to download the file(s), e.g., a Hugging Face repository. 133 | /// - destination: The root URL where the downloaded files should be stored. Defaults to `defaultRootDestination`. 134 | public init(source: Source, destination: URL = defaultRootDestination) { 135 | self.source = source 136 | self.rootDestination = destination 137 | self.downloadConfiguration = .default 138 | } 139 | 140 | /// Initializes a new file downloader with a custom download configuration. 141 | /// 142 | /// - Parameters: 143 | /// - source: The source from which to download the file(s), e.g., a Hugging Face repository. 144 | /// - destination: The root URL where the downloaded files should be stored. Defaults to `defaultRootDestination`. 145 | /// - configuration: The download configuration to use, which can include background download settings. 146 | public init(source: Source, destination: URL = defaultRootDestination, configuration: DownloadConfiguration) { 147 | self.source = source 148 | self.rootDestination = destination 149 | self.downloadConfiguration = configuration 150 | } 151 | 152 | /// Starts the download of the file(s) from the specified source. 153 | /// 154 | /// If the files are already downloaded, this method completes immediately, calling the progress handler with `1.0`. 155 | /// It handles saving metadata and then uses `HubApi` to perform the actual download, reporting progress via the `onProgress` closure. 156 | /// 157 | /// - Parameter onProgress: An asynchronous closure that is called with the download progress (a `Double` between 0.0 and 1.0). Defaults to an empty closure. 158 | /// - Throws: An error if saving metadata fails or if the `HubApi` encounters an issue during the download. 
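An illustrative call site, using a repository id and glob pattern along the lines of the package's own tests (values are examples only):

let downloader = FileDownloader(
    source: .huggingFace(id: "ggml-org/SmolVLM-256M-Instruct-GGUF", globs: ["*.gguf"]),
    destination: FileDownloader.defaultRootDestination
)
try await downloader.download { progress in
    print("Download progress: \(progress)")
}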
159 | public func download(onProgress: @Sendable @escaping (Double) async -> Void = { _ in }) async throws { 160 | let destination = source.destination(for: rootDestination) 161 | guard !source.isDownloaded(for: destination) else { 162 | await onProgress(1.0) 163 | return 164 | } 165 | try await source.saveMetadata(to: destination) 166 | try await source.downloadFiles( 167 | to: rootDestination, 168 | configuration: downloadConfiguration.makeHuggingFaceConfiguration(), 169 | onProgress: onProgress 170 | ) 171 | } 172 | } 173 | 174 | struct FilesMetadata: Codable, Sendable { 175 | static let filename = ".filesmeta" 176 | 177 | let files: [FileMetadata] 178 | 179 | struct FileMetadata: Codable, Sendable { 180 | let name: String 181 | } 182 | 183 | static func load(from url: URL) throws -> FilesMetadata { 184 | let data = try Data(contentsOf: url.appendingPathComponent(filename)) 185 | return try JSONDecoder().decode(FilesMetadata.self, from: data) 186 | } 187 | 188 | func save(to url: URL) throws { 189 | try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true) 190 | let data = try JSONEncoder().encode(self) 191 | try data.write(to: url.appendingPathComponent(Self.filename)) 192 | } 193 | } 194 | 195 | public extension FileDownloadable { 196 | func removeMetadata() throws { 197 | try source.removeMetadata(from: destination) 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientUtility/Globs.swift: -------------------------------------------------------------------------------- 1 | /// A struct representing a collection of glob patterns used to filter files. 2 | public struct Globs: Sendable, Equatable { 3 | public let rawValue: [String] 4 | 5 | /// Initializes a new set of glob patterns. 6 | /// 7 | /// - Parameter globs: An array of strings, where each string is a glob pattern (e.g., "*.json", "model.*.gguf"). 8 | public init(_ globs: [String]) { 9 | self.rawValue = globs 10 | } 11 | 12 | /// Default glob patterns for MLX models, typically including "*.safetensors" and "*.json". 13 | public static let mlx = Globs(["*.safetensors", "*.json"]) 14 | } 15 | 16 | extension Globs: ExpressibleByArrayLiteral { 17 | /// Initializes a new set of glob patterns from an array literal 18 | /// - Parameter elements: Array of strings representing glob patterns 19 | public init(arrayLiteral elements: String...) { 20 | self.init(elements) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientUtility/HuggingFaceAPI.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | #if canImport(FoundationNetworking) 3 | import FoundationNetworking 4 | #endif 5 | 6 | /// Represents the Hugging Face API client 7 | public struct HuggingFaceAPI: Sendable { 8 | /// API endpoint for Hugging Face Hub 9 | private let endpoint = URL(string: "https://huggingface.co")! 10 | 11 | /// Authentication token for accessing Hugging Face Hub 12 | public let hfToken: String? 13 | 14 | /// Repository reference for Hugging Face 15 | public let repo: Repo 16 | 17 | /// Initializes a new Hugging Face API client 18 | /// - Parameters: 19 | /// - repo: Repository information 20 | /// - token: Authentication token (optional) 21 | public init(repo: Repo, token: String? 
= nil) { 22 | self.repo = repo 23 | self.hfToken = token 24 | } 25 | 26 | /// Repository type for Hugging Face repositories 27 | public enum RepoType: String, Sendable { 28 | case models 29 | case datasets 30 | case spaces 31 | } 32 | 33 | /// Repository information for Hugging Face 34 | public struct Repo: Equatable, Sendable { 35 | /// Repository identifier, such as "meta-llama/Meta-Llama-3-8B" 36 | public let id: String 37 | 38 | /// Repository type, defaults to models 39 | public let type: RepoType 40 | 41 | /// Creates a new repository reference 42 | /// - Parameters: 43 | /// - id: Repository identifier 44 | /// - type: Repository type, defaults to `.models` 45 | public init(id: String, type: RepoType = .models) { 46 | self.id = id 47 | self.type = type 48 | } 49 | } 50 | 51 | public struct DownloadConfiguration: Sendable { 52 | public var identifier: String? 53 | public var protocolClasses: [AnyClass]? 54 | 55 | /// Initializes a new download configuration 56 | public static let `default` = DownloadConfiguration(identifier: nil) 57 | 58 | /// Creates a new download configuration for background downloads 59 | public static func background(withIdentifier identifier: String) -> DownloadConfiguration { 60 | DownloadConfiguration(identifier: identifier) 61 | } 62 | 63 | func makeURLSessionConfiguration() -> URLSessionConfiguration { 64 | let config: URLSessionConfiguration 65 | #if os(iOS) || os(macOS) 66 | if let identifier { 67 | config = URLSessionConfiguration.background(withIdentifier: identifier) 68 | config.isDiscretionary = true 69 | config.sessionSendsLaunchEvents = true 70 | } else { 71 | config = .default 72 | } 73 | #else 74 | config = .default 75 | #endif 76 | config.protocolClasses = protocolClasses 77 | return config 78 | } 79 | } 80 | 81 | /// Get the local directory location for a repository 82 | /// - Parameters: 83 | /// - downloadBase: The base directory for downloads 84 | /// - Returns: The local URL for the repository 85 | public func getLocalRepoLocation(downloadBase: URL) -> URL { 86 | downloadBase 87 | .appending(component: "huggingface") 88 | .appending(component: repo.type.rawValue) 89 | .appending(component: repo.id) 90 | } 91 | 92 | /// Retrieves filenames from a Hugging Face repository that match the given glob patterns 93 | /// - Parameters: 94 | /// - globs: Array of glob patterns to match files (e.g., "*.json") 95 | /// - revision: The repository revision (branch, tag, or commit hash), defaults to "main" 96 | /// - Returns: Array of matching filenames 97 | public func getFilenames( 98 | matching globs: Globs, 99 | revision: String = "main", 100 | configuration: URLSessionConfiguration = .default 101 | ) async throws -> [String] { 102 | // Read repo info and only parse "siblings" (files in the repository) 103 | let (data, _) = try await get( 104 | for: endpoint.appending(path: "api/\(repo.type.rawValue)/\(repo.id)/revision/\(revision)"), 105 | configuration: configuration 106 | ) 107 | 108 | // Decode the JSON response 109 | let response = try JSONDecoder().decode(SiblingsResponse.self, from: data) 110 | let filenames = response.siblings.map(\.rfilename) 111 | 112 | // If no globs are provided, return all filenames 113 | guard !globs.rawValue.isEmpty else { return filenames } 114 | 115 | // Filter filenames based on glob patterns 116 | var selected: Set<String> = [] 117 | for glob in globs.rawValue { 118 | selected = selected.union(filenames.matching(glob: glob)) 119 | } 120 | 121 | return Array(selected) 122 | } 123 | 124 | /// Downloads files from a 
Hugging Face repository that match the given glob patterns 125 | /// - Parameters: 126 | /// - downloadBase: The base directory for downloads 127 | /// - globs: Array of glob patterns to match files (e.g., "*.json") 128 | /// - revision: The repository revision (branch, tag, or commit hash), defaults to "main" 129 | /// - configuration: URLSession configuration to use for the download, defaults to .default 130 | /// - progressHandler: Closure to report download progress 131 | /// - Returns: The local URL where files were downloaded 132 | @discardableResult 133 | public func downloadSnapshot( 134 | to downloadBase: URL, 135 | matching globs: Globs, 136 | revision: String = "main", 137 | configuration: DownloadConfiguration = .default, 138 | progressHandler: @Sendable @escaping (Progress) async -> Void = { _ in } 139 | ) async throws -> URL { 140 | let destination = getLocalRepoLocation(downloadBase: downloadBase) 141 | 142 | // Create the directory structure 143 | try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true) 144 | 145 | // Get filenames to download 146 | let filenames = try await getFilenames(matching: globs, revision: revision, configuration: configuration.makeURLSessionConfiguration()) 147 | 148 | let downloader = Downloader() 149 | for filename in filenames { 150 | let type = repo.type == .models ? "" : "\(repo.type.rawValue)/" 151 | downloader.add(.init( 152 | url: endpoint.appending(path: "\(type)\(repo.id)/resolve/\(revision)/\(filename)"), 153 | destinationURL: destination.appendingPathComponent(filename), 154 | configuration: { 155 | if let identifier = configuration.identifier { 156 | var configuration = configuration 157 | configuration.identifier = "\(identifier)_\(filename)" 158 | return configuration.makeURLSessionConfiguration() 159 | } else { 160 | return configuration.makeURLSessionConfiguration() 161 | } 162 | }() 163 | )) 164 | } 165 | downloader.setObserver { progress in 166 | await progressHandler(progress) 167 | } 168 | 169 | downloader.download() 170 | await downloader.waitForDownloads() 171 | await progressHandler(downloader.progress) 172 | 173 | return destination 174 | } 175 | 176 | /// Gets metadata for a file in a Hugging Face repository 177 | /// - Parameters: 178 | /// - url: The URL of the file 179 | /// - Returns: The file metadata 180 | public func getFileMetadata(url: URL) async throws -> FileMetadata { 181 | let (_, response) = try await get(for: url) 182 | let location = response.statusCode == 302 ? response.value(forHTTPHeaderField: "Location") : response.url?.absoluteString 183 | 184 | return FileMetadata( 185 | commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"), 186 | etag: normalizeEtag(response.value(forHTTPHeaderField: "ETag")), 187 | location: location ?? url.absoluteString, 188 | size: Int(response.value(forHTTPHeaderField: "Content-Length") ?? "") 189 | ) 190 | } 191 | 192 | private func normalizeEtag(_ etag: String?) -> String? 
{ 193 | guard let etag else { return nil } 194 | return etag.trimmingPrefix("W/").trimmingCharacters(in: CharacterSet(charactersIn: "\"")) 195 | } 196 | 197 | /// Gets metadata for files in a Hugging Face repository that match the given glob patterns 198 | /// - Parameters: 199 | /// - globs: Array of glob patterns to match files (e.g., "*.json") 200 | /// - revision: The repository revision (branch, tag, or commit hash), defaults to "main" 201 | /// - Returns: Array of file metadata 202 | public func getFileMetadata(matching globs: Globs, revision: String = "main") async throws -> [FileMetadata] { 203 | let files = try await getFilenames(matching: globs, revision: revision) 204 | let baseURL = URL(string: "\(endpoint)/\(repo.type.rawValue)/\(repo.id)/resolve/\(revision)")! 205 | 206 | var metadata: [FileMetadata] = [] 207 | for file in files { 208 | let fileURL = baseURL.appendingPathComponent(file) 209 | try await metadata.append(getFileMetadata(url: fileURL)) 210 | } 211 | 212 | return metadata 213 | } 214 | 215 | /// Data structure containing information about a file versioned on the Hub 216 | public struct FileMetadata { 217 | /// The commit hash related to the file 218 | public let commitHash: String? 219 | 220 | /// Etag of the file on the server 221 | public let etag: String? 222 | 223 | /// Location where to download the file. Can be a Hub url or not (CDN). 224 | public let location: String 225 | 226 | /// Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer. 227 | public let size: Int? 228 | } 229 | } 230 | 231 | // MARK: - API Helpers 232 | 233 | extension HuggingFaceAPI { 234 | private struct SiblingsResponse: Codable { 235 | let siblings: [Sibling] 236 | 237 | /// Model data for parsed filenames 238 | struct Sibling: Codable { 239 | let rfilename: String 240 | } 241 | } 242 | 243 | /// Performs an HTTP GET request to the specified URL 244 | /// - Parameter url: The URL to request 245 | /// - Returns: Tuple containing the response data and HTTP response 246 | private func get(for url: URL, configuration: URLSessionConfiguration = .default) async throws -> (Data, HTTPURLResponse) { 247 | var request = URLRequest(url: url) 248 | if let hfToken { 249 | request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") 250 | } 251 | 252 | var configuration = configuration 253 | if configuration.identifier != nil { 254 | let foregroundConfiguration = URLSessionConfiguration.default 255 | foregroundConfiguration.protocolClasses = configuration.protocolClasses 256 | configuration = foregroundConfiguration 257 | } 258 | 259 | let (data, response) = try await URLSession(configuration: configuration).data(for: request) 260 | guard let httpResponse = response as? 
HTTPURLResponse else { 261 | throw URLError(.badServerResponse) 262 | } 263 | 264 | switch httpResponse.statusCode { 265 | case 200..<400: 266 | return (data, httpResponse) 267 | case 401, 403: 268 | throw URLError(.userAuthenticationRequired) 269 | case 404: 270 | throw URLError(.fileDoesNotExist) 271 | default: 272 | throw URLError(.badServerResponse) 273 | } 274 | } 275 | } 276 | 277 | private extension [String] { 278 | /// Filters the array to only include strings that match the specified glob pattern 279 | /// - Parameter glob: The glob pattern to match against 280 | /// - Returns: Array of strings that match the glob pattern 281 | func matching(glob: String) -> [String] { 282 | filter { fnmatch(glob, $0, 0) == 0 } 283 | } 284 | } 285 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientUtility/Lock.swift: -------------------------------------------------------------------------------- 1 | #if canImport(os) 2 | import os 3 | #else 4 | import Glibc 5 | #endif 6 | 7 | #if canImport(os) 8 | package typealias Lock = OSAllocatedUnfairLock 9 | #else 10 | package final class Lock: @unchecked Sendable { 11 | @usableFromInline 12 | let mutex: UnsafeMutablePointer<pthread_mutex_t> = UnsafeMutablePointer.allocate(capacity: 1) 13 | 14 | package init() { 15 | let err = pthread_mutex_init(self.mutex, nil) 16 | precondition(err == 0) 17 | } 18 | 19 | deinit { 20 | let err = pthread_mutex_destroy(self.mutex) 21 | precondition(err == 0) 22 | mutex.deallocate() 23 | } 24 | 25 | @usableFromInline 26 | func lock() { 27 | let err = pthread_mutex_lock(self.mutex) 28 | precondition(err == 0) 29 | } 30 | 31 | @usableFromInline 32 | func unlock() { 33 | let err = pthread_mutex_unlock(self.mutex) 34 | precondition(err == 0) 35 | } 36 | 37 | @inlinable 38 | package func withLock<T>(_ body: () throws -> T) rethrows -> T { 39 | self.lock() 40 | defer { 41 | self.unlock() 42 | } 43 | return try body() 44 | } 45 | } 46 | #endif 47 | 48 | package final class Locked<Value: ~Copyable> { 49 | @usableFromInline let lock = Lock() 50 | 51 | @usableFromInline var value: Value 52 | package init(_ value: consuming sending Value) { 53 | self.value = value 54 | } 55 | } 56 | 57 | extension Locked where Value: ~Copyable { 58 | @discardableResult @inlinable 59 | package borrowing func withLock<Result: ~Copyable, E: Error>(_ block: (inout sending Value) throws(E) -> sending Result) throws(E) -> sending Result { 60 | lock.lock() 61 | defer { lock.unlock() } 62 | return try block(&value) 63 | } 64 | } 65 | 66 | extension Locked: @unchecked Sendable where Value: ~Copyable { 67 | } 68 | 69 | extension Locked where Value: Sendable { 70 | package func exchange(_ newValue: Value) -> Value { 71 | withLock { 72 | let old = $0 73 | $0 = newValue 74 | return old 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /Sources/LocalLLMClientUtility/URL+.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | extension URL { 4 | #if os(macOS) || os(Linux) 5 | static let defaultRootDirectory = FileManager.default.homeDirectoryForCurrentUser.appending(path: ".localllmclient").excludedFromBackup 6 | #else 7 | static let defaultRootDirectory = URL.documentsDirectory.appending(path: ".localllmclient").excludedFromBackup 8 | #endif 9 | 10 | var excludedFromBackup: URL { 11 | #if os(Linux) 12 | return self 13 | #else 14 | var url = self 15 | var resourceValues = URLResourceValues() 16 | 
resourceValues.isExcludedFromBackup = true 17 | try? url.setResourceValues(resourceValues) 18 | return url 19 | #endif 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientLlamaTests/ContextTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | import Foundation 3 | import LocalLLMClient 4 | @testable import LocalLLMClientLlama 5 | 6 | extension ModelTests { 7 | struct ContextTests {} 8 | } 9 | 10 | extension ModelTests.ContextTests { 11 | 12 | @Test 13 | func verifyContext() async throws { 14 | try await verifyContext(withText: "Hello, world!") 15 | } 16 | 17 | @Test 18 | func verifyContextMultibytes() async throws { 19 | try await verifyContext(withText: "こんにちは, 世界!") 20 | } 21 | 22 | private func verifyContext(withText text: String) async throws { 23 | let client = try await LocalLLMClient.llama() 24 | let context = client._context 25 | let textTokens = [llama_token](text, addBos: false, special: true, vocab: context.vocab) 26 | var expectedPosition = textTokens.count 27 | 28 | #expect(context.position == 0) 29 | try context.decode(text: text) 30 | #expect(context.position == expectedPosition) 31 | 32 | _ = context.sampling.sample(context: context, index: context.position - 1) 33 | var token = context.sampling.sample(context: context, index: -1) 34 | var pieces = token.piece(vocab: context.vocab, special: true) 35 | while String(utf8String: pieces + [0]) == nil { 36 | context.batch.add(id: token, pos: context.position, seq_ids: [0], logits: true) 37 | 38 | #expect(context.position == expectedPosition) 39 | try context.decode() 40 | expectedPosition += 1 41 | #expect(context.position == expectedPosition) 42 | 43 | token = context.sampling.sample(context: context, index: -1) 44 | pieces.append(contentsOf: token.piece(vocab: context.vocab, special: true)) 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientLlamaTests/LocalLLMClientTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | import Foundation 3 | #if canImport(FoundationNetworking) 4 | import FoundationNetworking 5 | #endif 6 | import LocalLLMClient 7 | import LocalLLMClientLlama 8 | 9 | private let prompt = "<|im_start|>user\nWhat is the answer to one plus two?<|im_end|>\n<|im_start|>assistant\n" 10 | 11 | extension ModelTests { 12 | struct LocalLLMClientLlamaTests {} 13 | } 14 | 15 | extension ModelTests.LocalLLMClientLlamaTests { 16 | @Test 17 | func simpleStream() async throws { 18 | var result = "" 19 | 20 | for try await text in try await LocalLLMClient.llama().textStream(from: prompt) { 21 | print(text, terminator: "") 22 | result += text 23 | } 24 | 25 | #expect(!result.isEmpty) 26 | } 27 | 28 | @Test 29 | func image() async throws { 30 | let stream = try await LocalLLMClient.llama().textStream(from: LLMInput( 31 | .chat([.user("<|test_img|>What is in this image?", attachments: [ 32 | .image(.init(data: Data(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.jpeg")!))!) 33 | ])]), 34 | )) 35 | 36 | var result = "" 37 | for try await text in stream { 38 | print(text, terminator: "") 39 | result += text 40 | } 41 | 42 | #expect(!result.isEmpty) 43 | } 44 | 45 | @Test @MainActor 46 | func cancel() async throws { 47 | var counter = 0 48 | var breaked = false 49 | 50 | var task: Task<Void, Error>? 
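// The stream cancels its own task after the first chunk; the assertions below then verify
// that exactly one chunk was received and that the loop exited once cancellation propagated.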
51 | task = Task { 52 | for try await _ in try await LocalLLMClient.llama().textStream(from: prompt) { 53 | counter += 1 54 | task?.cancel() 55 | } 56 | breaked = true 57 | } 58 | 59 | try await Task.sleep(for: .seconds(5)) 60 | task!.cancel() 61 | try? await task!.value 62 | 63 | #expect(counter == 1) 64 | #expect(breaked) 65 | } 66 | 67 | @Test 68 | func json() async throws { 69 | var result = "" 70 | 71 | for _ in 0...2 { 72 | do { 73 | let input = LLMInput.chat([ 74 | .system("You are a helpful assistant."), 75 | .user("What is the answer to one plus two?\nRespond in JSON.\n\n{ \"answer\": \"<answer>\" }\n") 76 | ]) 77 | for try await text in try await LocalLLMClient.llama(parameter: .init( 78 | temperature: 1.0, 79 | penaltyRepeat: 1.3, 80 | options: .init(responseFormat: .json) 81 | )).textStream(from: input) { 82 | print(text, terminator: "") 83 | result += text 84 | } 85 | 86 | try JSONSerialization.jsonObject(with: Data(result.utf8), options: []) 87 | return 88 | } catch { 89 | print(error) 90 | } 91 | } 92 | 93 | Issue.record() 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientLlamaTests/ModelTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | import Foundation 3 | import LocalLLMClient 4 | @testable import LocalLLMClientLlama 5 | #if canImport(LocalLLMClientUtility) 6 | import LocalLLMClientUtility 7 | #endif 8 | 9 | let disabledTests = ![nil, "Llama"].contains(ProcessInfo.processInfo.environment["GITHUB_ACTIONS_TEST"]) 10 | 11 | extension LocalLLMClient { 12 | static let model = "SmolVLM-256M-Instruct-Q8_0.gguf" 13 | static let clip = "mmproj-SmolVLM-256M-Instruct-Q8_0.gguf" 14 | 15 | static func llama( 16 | parameter: LlamaClient.Parameter? = nil, 17 | messageDecoder: any LlamaChatMessageDecoder = LlamaCustomMessageDecoder(tokenImageRegex: #"<\|test_img\|>"#) 18 | ) async throws -> LlamaClient { 19 | let url = try await downloadModel() 20 | return try await LocalLLMClient.llama( 21 | url: url.appending(component: model), 22 | mmprojURL: url.appending(component: clip), 23 | parameter: parameter ?? .init(context: 512), 24 | messageDecoder: messageDecoder, 25 | verbose: true 26 | ) 27 | } 28 | 29 | static func downloadModel() async throws -> URL { 30 | #if canImport(LocalLLMClientUtility) 31 | let downloader = FileDownloader( 32 | source: .huggingFace(id: "ggml-org/SmolVLM-256M-Instruct-GGUF", globs: [model, clip]), 33 | destination: ProcessInfo.processInfo.environment["GITHUB_MODEL_CACHE"].map { URL(filePath: $0) } ?? 
FileDownloader.defaultRootDestination 34 | ) 35 | try await downloader.download { print("Download: \($0)") } 36 | return downloader.destination 37 | #else 38 | return URL(filePath: ProcessInfo.processInfo.environment["GITHUB_MODEL_CACHE", default: "~/.localllmclient"]) 39 | .appending(component: "huggingface/models/ggml-org/SmolVLM-256M-Instruct-GGUF") 40 | #endif 41 | } 42 | } 43 | 44 | @Suite(.serialized, .timeLimit(.minutes(5)), .disabled(if: disabledTests)) 45 | actor ModelTests { 46 | nonisolated(unsafe) private static var initialized = false 47 | 48 | init() async throws { 49 | if !Self.initialized && !disabledTests { 50 | let url = try await LocalLLMClient.downloadModel() 51 | let path = url.appending(component: LocalLLMClient.model).path 52 | if !FileManager.default.fileExists(atPath: path) { 53 | throw LLMError.failedToLoad(reason: "Model file not found at \(path)") 54 | } 55 | Self.initialized = true 56 | } 57 | } 58 | 59 | @Test 60 | func validateChatTemplate() async throws { 61 | let client = try await LocalLLMClient.llama() 62 | #expect(client._context.model.chatTemplate == """ 63 | <|im_start|>{% for message in messages %}{{message[\'role\'] | capitalize}}{% if message[\'content\'][0][\'type\'] == \'image\' %}{{\':\'}}{% else %}{{\': \'}}{% endif %}{% for line in message[\'content\'] %}{% if line[\'type\'] == \'text\' %}{{line[\'text\']}}{% elif line[\'type\'] == \'image\' %}{{ \'<image>\' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ \'Assistant:\' }}{% endif %} 64 | """) 65 | } 66 | 67 | @Test 68 | func validateRenderedTemplate() async throws { 69 | let client = try await LocalLLMClient.llama() 70 | let decoder = LlamaAutoMessageDecoder(chatTemplate: client._context.model.chatTemplate) 71 | let messages: [LLMInput.Message] = [ 72 | .system("You are a helpful assistant."), 73 | .user("What is the answer to one plus two?"), 74 | .assistant("The answer is 3."), 75 | ] 76 | let value = decoder.templateValue(from: messages) 77 | let template = try decoder.applyTemplate(value, chatTemplate: client._context.model.chatTemplate) 78 | #expect(decoder.chatTemplate == .qwen2_5_VL) 79 | #expect(template == "<|im_start|>System: You are a helpful assistant.<end_of_utterance>\nUser: What is the answer to one plus two?<end_of_utterance>\nAssistant: The answer is 3.<end_of_utterance>\nAssistant:") 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientMLXTests/LocalLLMClientTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | import Foundation 3 | import LocalLLMClient 4 | import LocalLLMClientMLX 5 | 6 | let prompt = "What is the answer to one plus two?" 7 | 8 | extension ModelTests { 9 | struct LocalLLMClientMLXTests {} 10 | } 11 | 12 | extension ModelTests.LocalLLMClientMLXTests { 13 | @Test(.timeLimit(.minutes(5))) 14 | func simpleStream() async throws { 15 | var result = "" 16 | for try await text in try await LocalLLMClient.mlx().textStream(from: prompt) { 17 | print(text, terminator: "") 18 | result += text 19 | } 20 | 21 | #expect(!result.isEmpty) 22 | } 23 | 24 | @Test(.timeLimit(.minutes(5))) 25 | func image() async throws { 26 | let stream = try await LocalLLMClient.mlx().textStream(from: LLMInput( 27 | .chat([.user("What is in this image?", attachments: [ 28 | .image(.init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.jpeg")!)!) 
29 | ])]) 30 | )) 31 | var result = "" 32 | for try await text in stream { 33 | print(text, terminator: "") 34 | result += text 35 | } 36 | 37 | #expect(!result.isEmpty) 38 | } 39 | 40 | @Test(.timeLimit(.minutes(5))) @MainActor 41 | func cancel() async throws { 42 | var counter = 0 43 | var breaked = false 44 | 45 | var task: Task<Void, Error>? 46 | task = Task { 47 | for try await _ in try await LocalLLMClient.mlx().textStream(from: prompt) { 48 | counter += 1 49 | task?.cancel() 50 | } 51 | breaked = true 52 | } 53 | 54 | try await Task.sleep(for: .seconds(2)) 55 | task!.cancel() 56 | try? await task!.value 57 | 58 | #expect(counter == 1) 59 | #expect(breaked) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientMLXTests/ModelTests.swift: -------------------------------------------------------------------------------- 1 | import Testing 2 | import Foundation 3 | import LocalLLMClient 4 | import LocalLLMClientMLX 5 | import LocalLLMClientUtility 6 | 7 | private let disabledTests = ![nil, "MLX"].contains(ProcessInfo.processInfo.environment["GITHUB_ACTIONS_TEST"]) 8 | 9 | extension LocalLLMClient { 10 | static func mlx() async throws -> MLXClient { 11 | try await LocalLLMClient.mlx(url: downloadModel(), parameter: .init(maxTokens: 256)) 12 | } 13 | 14 | static func downloadModel() async throws -> URL { 15 | let downloader = FileDownloader( 16 | source: .huggingFace(id: "mlx-community/SmolVLM2-256M-Video-Instruct-mlx", globs: .mlx), 17 | destination: ProcessInfo.processInfo.environment["GITHUB_MODEL_CACHE"].map { URL(filePath: $0) } ?? FileDownloader.defaultRootDestination 18 | ) 19 | try await downloader.download { print("Download: \($0)") } 20 | return downloader.destination 21 | } 22 | } 23 | 24 | @Suite(.serialized, .disabled(if: disabledTests)) 25 | actor ModelTests { 26 | nonisolated(unsafe) private static var initialized = false 27 | 28 | init() async throws { 29 | if !Self.initialized && !disabledTests { 30 | _ = try await LocalLLMClient.downloadModel() 31 | Self.initialized = true 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientUtilityTests/FileDownloaderTests.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Testing 3 | #if canImport(Hub) 4 | import Hub 5 | #endif 6 | @testable import LocalLLMClientUtility 7 | 8 | struct FileDownloaderTests { 9 | @Test 10 | func testFileDownloaderInitialization() async throws { 11 | let testDirectory = FileManager.default.temporaryDirectory.appendingPathComponent("FileDownloaderTests") 12 | 13 | // Initialize a FileDownloader with Hugging Face source 14 | let globs: Globs = ["*.safetensors", "*.json"] 15 | let source = FileDownloader.Source.huggingFace(id: "test-org/test-repo", globs: globs) 16 | let downloader = FileDownloader(source: source, destination: testDirectory) 17 | 18 | // Verify the initialization 19 | #expect(downloader.source == source) 20 | #expect(downloader.destination.pathComponents[(downloader.destination.pathComponents.count - 5)...] 
== ["FileDownloaderTests", "huggingface", "models", "test-org", "test-repo"]) 21 | } 22 | 23 | #if canImport(Hub) 24 | @Test 25 | func checkCompatibilityWithHuggingFaceAPI() async throws { 26 | let testDirectory = FileManager.default.temporaryDirectory.appendingPathComponent("FileDownloaderTests") 27 | 28 | // Initialize a FileDownloader with Hugging Face source 29 | let globs: Globs = ["*.safetensors", "*.json"] 30 | let source = FileDownloader.Source.huggingFace(id: "test-org/test-repo", globs: globs) 31 | let downloader = FileDownloader(source: source, destination: testDirectory) 32 | 33 | let hub = HubApi( 34 | downloadBase: testDirectory.appending(component: "huggingface"), 35 | useOfflineMode: false 36 | ) 37 | let hubDestination = hub.localRepoLocation(Hub.Repo(id: "test-org/test-repo")) 38 | #expect(downloader.destination == hubDestination) 39 | } 40 | #endif 41 | } 42 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientUtilityTests/FilesMetadataTests.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Testing 3 | @testable import LocalLLMClientUtility 4 | 5 | struct FilesMetadataTests { 6 | @Test 7 | func testFilesMetadataSaveAndLoad() throws { 8 | // Create a temporary directory for testing 9 | let testDirectory = FileManager.default.temporaryDirectory.appendingPathComponent("FilesMetadataTests") 10 | try? FileManager.default.createDirectory(at: testDirectory, withIntermediateDirectories: true) 11 | 12 | // Create test metadata 13 | let fileMetadata1 = FilesMetadata.FileMetadata(name: "test1.json") 14 | let fileMetadata2 = FilesMetadata.FileMetadata(name: "test2.safetensors") 15 | let metadata = FilesMetadata(files: [fileMetadata1, fileMetadata2]) 16 | 17 | // Save metadata to the test directory 18 | try metadata.save(to: testDirectory) 19 | defer { 20 | try? 
FileManager.default.removeItem(at: testDirectory) 21 | } 22 | 23 | // Verify metadata file was created 24 | let metadataFilePath = testDirectory.appendingPathComponent(FilesMetadata.filename) 25 | #expect(FileManager.default.fileExists(atPath: metadataFilePath.path)) 26 | 27 | // Load metadata from the test directory 28 | let loadedMetadata = try FilesMetadata.load(from: testDirectory) 29 | 30 | // Verify loaded metadata matches what was saved 31 | #expect(loadedMetadata.files.count == 2) 32 | #expect(loadedMetadata.files[0].name == "test1.json") 33 | #expect(loadedMetadata.files[1].name == "test2.safetensors") 34 | } 35 | 36 | @Test 37 | func testHuggingFaceGlobsMLXDefault() { 38 | let mlxGlobs = Globs.mlx 39 | 40 | #expect(mlxGlobs.rawValue.count == 2) 41 | #expect(mlxGlobs.rawValue.contains("*.safetensors")) 42 | #expect(mlxGlobs.rawValue.contains("*.json")) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientUtilityTests/HuggingFaceAPITests.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Testing 3 | #if canImport(FoundationNetworking) 4 | import FoundationNetworking 5 | #endif 6 | @testable import LocalLLMClientUtility 7 | 8 | struct HuggingFaceAPITests { 9 | 10 | /// Creates a mock session configuration with MockURLProtocol 11 | private func mockSessionConfiguration() -> HuggingFaceAPI.DownloadConfiguration { 12 | var configuration = HuggingFaceAPI.DownloadConfiguration.default 13 | configuration.protocolClasses = [MockURLProtocol.self] 14 | return configuration 15 | } 16 | 17 | /// Creates a background session configuration with MockURLProtocol 18 | private func mockBackgroundSessionConfiguration() -> HuggingFaceAPI.DownloadConfiguration { 19 | var configuration = HuggingFaceAPI.DownloadConfiguration.background(withIdentifier: "com.\(#function)") 20 | configuration.protocolClasses = [MockURLProtocol.self] 21 | return configuration 22 | } 23 | 24 | /// Helper function to create a temporary directory for test file downloads 25 | private func createTemporaryDirectory() -> URL { 26 | let tempDirURL = FileManager.default.temporaryDirectory.appendingPathComponent( 27 | "HuggingFaceAPITests_\(UUID().uuidString)", 28 | isDirectory: true 29 | ) 30 | try? FileManager.default.createDirectory( 31 | at: tempDirURL, 32 | withIntermediateDirectories: true 33 | ) 34 | return tempDirURL 35 | } 36 | 37 | /// Cleanup temporary files after tests 38 | private func cleanupTemporaryDirectory(_ url: URL) { 39 | try? FileManager.default.removeItem(at: url) 40 | } 41 | 42 | @Test(.disabled(if: canImportFoundationNetworking)) 43 | func testDownloadSnapshotWithDefaultConfiguration() async throws { 44 | // Setup the mock responses 45 | let repoInfo = """ 46 | { 47 | "siblings": [ 48 | {"rfilename": "test1.bin"}, 49 | {"rfilename": "test2.bin"} 50 | ] 51 | } 52 | """ 53 | let repoInfoData = repoInfo.data(using: .utf8)! 54 | let testFile1Data = "Test file 1 content".data(using: .utf8)! 55 | let testFile2Data = "Test file 2 content".data(using: .utf8)! 56 | 57 | let apiURL = URL(string: "https://huggingface.co/api/models/\(#function)/revision/main")! 58 | let file1URL = URL(string: "https://huggingface.co/\(#function)/resolve/main/test1.bin")! 59 | let file2URL = URL(string: "https://huggingface.co/\(#function)/resolve/main/test2.bin")! 
60 | 61 | MockURLProtocol.setResponse(for: apiURL, with: repoInfoData) 62 | MockURLProtocol.setResponse(for: file1URL, with: testFile1Data) 63 | MockURLProtocol.setResponse(for: file2URL, with: testFile2Data) 64 | 65 | defer { 66 | MockURLProtocol.removeResponse(for: apiURL) 67 | MockURLProtocol.removeResponse(for: file1URL) 68 | MockURLProtocol.removeResponse(for: file2URL) 69 | } 70 | 71 | // Create a temporary directory for downloads 72 | let tempDir = createTemporaryDirectory() 73 | defer { cleanupTemporaryDirectory(tempDir) } 74 | 75 | // Setup the HuggingFaceAPI client 76 | let repo = HuggingFaceAPI.Repo(id: #function) 77 | let api = HuggingFaceAPI(repo: repo) 78 | 79 | var progressFractionCompleted: Double = 0 80 | 81 | // Download with default configuration 82 | let destination = try await api.downloadSnapshot( 83 | to: tempDir, 84 | matching: Globs(["*.bin"]), 85 | configuration: mockSessionConfiguration() 86 | ) { progress in 87 | await MainActor.run { 88 | progressFractionCompleted = progress.fractionCompleted 89 | } 90 | } 91 | 92 | // Verify the files were downloaded 93 | let file1Path = destination.appendingPathComponent("test1.bin") 94 | let file2Path = destination.appendingPathComponent("test2.bin") 95 | 96 | #expect(FileManager.default.fileExists(atPath: file1Path.path)) 97 | #expect(FileManager.default.fileExists(atPath: file2Path.path)) 98 | 99 | let file1Content = try String(contentsOf: file1Path, encoding: .utf8) 100 | let file2Content = try String(contentsOf: file2Path, encoding: .utf8) 101 | 102 | #expect(file1Content == "Test file 1 content") 103 | #expect(file2Content == "Test file 2 content") 104 | #expect(progressFractionCompleted == 1.0) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientUtilityTests/MockURLProtocol.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | #if canImport(FoundationNetworking) 3 | import FoundationNetworking 4 | #endif 5 | import LocalLLMClientUtility 6 | 7 | /// Mock URLProtocol for testing download functionality without actual network requests 8 | final class MockURLProtocol: URLProtocol { 9 | /// Dictionary to store mock responses by URL 10 | static let mockResponses: Locked<[URL: (data: Data, response: HTTPURLResponse, error: Error?)]> = .init([:]) 11 | 12 | /// Storage for downloaded files 13 | static let downloadedFiles: Locked<[URL: URL]> = .init([:]) 14 | 15 | /// Registers a mock response for a specific URL 16 | static func setResponse(for url: URL, with data: Data, statusCode: Int = 200, error: Error? = nil) { 17 | let response = HTTPURLResponse( 18 | url: url, 19 | statusCode: statusCode, 20 | httpVersion: "HTTP/1.1", 21 | headerFields: ["Content-Length": "\(data.count)"] 22 | )! 
23 | mockResponses.withLock { 24 | $0[url] = (data, response, error) 25 | } 26 | } 27 | 28 | /// Unregisters a mock response for a specific URL 29 | /// - Parameter url: The URL for which to remove the mock response 30 | static func removeResponse(for url: URL) { 31 | mockResponses.withLock { $0.removeValue(forKey: url) } 32 | downloadedFiles.withLock { $0.removeValue(forKey: url) } 33 | } 34 | 35 | // MARK: - URLProtocol methods 36 | 37 | override class func canInit(with request: URLRequest) -> Bool { 38 | return true 39 | } 40 | 41 | override class func canonicalRequest(for request: URLRequest) -> URLRequest { 42 | return request 43 | } 44 | 45 | override func startLoading() { 46 | guard let url = request.url, let client else { 47 | client?.urlProtocolDidFinishLoading(self) 48 | return 49 | } 50 | 51 | guard let mockData = MockURLProtocol.mockResponses.withLock({ $0[url] }) else { 52 | let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorUnsupportedURL, userInfo: nil) 53 | client.urlProtocol(self, didFailWithError: error) 54 | return 55 | } 56 | 57 | // Send mock response data 58 | client.urlProtocol(self, didReceive: mockData.response, cacheStoragePolicy: .notAllowed) 59 | 60 | if let error = mockData.error { 61 | client.urlProtocol(self, didFailWithError: error) 62 | return 63 | } 64 | 65 | // For download tasks, we need to create a temporary file 66 | let tempFileURL = FileManager.default.temporaryDirectory 67 | .appendingPathComponent(UUID().uuidString, isDirectory: false) 68 | try? mockData.data.write(to: tempFileURL) 69 | MockURLProtocol.downloadedFiles.withLock { 70 | $0[url] = tempFileURL 71 | } 72 | 73 | // Report download progress 74 | let totalBytes = mockData.data.count 75 | let chunkSize = max(totalBytes / 10, 1) // at least one byte per chunk so tiny payloads cannot loop forever 76 | 77 | var offset = 0 78 | while offset < totalBytes { 79 | // Create a chunk of data to simulate progressive loading 80 | let currentChunkSize = min(chunkSize, totalBytes - offset) 81 | let startIndex = offset 82 | let endIndex = offset + currentChunkSize 83 | let chunkData = mockData.data[startIndex..<endIndex] 84 | client.urlProtocol(self, didLoad: chunkData) 85 | offset += currentChunkSize 86 | } 87 | 88 | client.urlProtocolDidFinishLoading(self) 89 | } 90 | 91 | override func stopLoading() { 92 | // No action needed 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /Tests/LocalLLMClientUtilityTests/URLExtensionTests.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Testing 3 | @testable import LocalLLMClientUtility 4 | 5 | struct URLExtensionTests { 6 | @Test 7 | func testDefaultRootDirectory() { 8 | let url = URL.defaultRootDirectory 9 | 10 | #if os(macOS) 11 | let homeDir = FileManager.default.homeDirectoryForCurrentUser 12 | #expect(url.path.hasPrefix(homeDir.path)) 13 | #expect(url.path.contains("/.localllmclient")) 14 | #elseif os(iOS) 15 | let docsDir = URL.documentsDirectory 16 | #expect(url.path.hasPrefix(docsDir.path)) 17 | #expect(url.path.contains("/.localllmclient")) 18 | #endif 19 | } 20 | 21 | #if !os(Linux) 22 | @Test 23 | func testExcludedFromBackup() throws { 24 | let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent("testExcludedFromBackup") 25 | try FileManager.default.createDirectory(at: tempURL, withIntermediateDirectories: true) 26 | 27 | // Apply the excludedFromBackup property 28 | let excludedURL = tempURL.excludedFromBackup 29 | 30 | // Check if the resource value is set correctly 31 | let
resourceValues = try excludedURL.resourceValues(forKeys: [.isExcludedFromBackupKey]) 32 | let isExcluded = resourceValues.isExcludedFromBackup ?? false 33 | 34 | #expect(isExcluded, "URL should be excluded from backup") 35 | } 36 | #endif 37 | } 38 | -------------------------------------------------------------------------------- /scripts/get_llama_version.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # get_llama_version.sh 4 | # This script retrieves the current llama.cpp version from Package.swift 5 | 6 | set -e 7 | 8 | if [ "$#" -eq 0 ]; then 9 | # If no arguments provided, use parent directory of the script 10 | PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" 11 | else 12 | # Use the provided path as the project root 13 | PROJECT_ROOT="$1" 14 | fi 15 | 16 | PACKAGE_FILE="$PROJECT_ROOT/Package.swift" 17 | 18 | # Get the current version from Package.swift using grep and sed 19 | # This pattern specifically looks for the let llamaVersion = "b5486" format 20 | CURRENT_VERSION=$(grep -E "let llamaVersion = \"[a-zA-Z0-9]+\"" "$PACKAGE_FILE" | sed -E 's/.*"([a-zA-Z0-9]+)".*/\1/') 21 | 22 | # Print the version to stdout 23 | echo "$CURRENT_VERSION" 24 | -------------------------------------------------------------------------------- /scripts/run_mlx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DERIVED_DATA_PATH="./DerivedData" 4 | CONFIGURATION="Debug" 5 | BINARY_PATH="${DERIVED_DATA_PATH}/Build/Products/${CONFIGURATION}/localllm" 6 | 7 | # Check if the binary exists and is executable 8 | if [ ! -x "${BINARY_PATH}" ]; then 9 | echo "Building localllm..." 10 | xcodebuild -scheme localllm -configuration ${CONFIGURATION} -derivedDataPath "${DERIVED_DATA_PATH}" \ 11 | -destination 'platform=macOS,arch=arm64' build -quiet 12 | fi 13 | 14 | # Run 15 | "${BINARY_PATH}" --backend mlx "$@" -------------------------------------------------------------------------------- /scripts/update_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" 6 | PACKAGE_FILE="$PROJECT_ROOT/Package.swift" 7 | 8 | cd "$PROJECT_ROOT/scripts" 9 | 10 | echo "Fetching latest release from ggml-org/llama.cpp..." 11 | 12 | if [ -n "$GITHUB_TOKEN" ]; then 13 | LATEST_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" "https://api.github.com/repos/ggml-org/llama.cpp/releases/latest" | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/') 14 | else 15 | LATEST_TAG=$(curl -s "https://api.github.com/repos/ggml-org/llama.cpp/releases/latest" | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/') 16 | fi 17 | 18 | if [ -z "$LATEST_TAG" ]; then 19 | echo "Error: Could not fetch the latest release tag" 20 | exit 1 21 | fi 22 | 23 | echo "Latest release tag: $LATEST_TAG" 24 | 25 | CURRENT_VERSION=$("$PROJECT_ROOT/scripts/get_llama_version.sh") 26 | 27 | echo "Current version: $CURRENT_VERSION" 28 | 29 | if [ "$CURRENT_VERSION" = "$LATEST_TAG" ]; then 30 | echo "Already using the latest version. No update needed." 31 | exit 0 32 | fi 33 | 34 | TEMP_DIR=$(mktemp -d) 35 | cd "$TEMP_DIR" 36 | 37 | XC_FRAMEWORK_FILE="llama-${LATEST_TAG}-xcframework.zip" 38 | XC_FRAMEWORK_URL="https://github.com/ggml-org/llama.cpp/releases/download/${LATEST_TAG}/${XC_FRAMEWORK_FILE}" 39 | 40 | echo "Downloading $XC_FRAMEWORK_URL..." 
41 | curl -L --fail -o "$XC_FRAMEWORK_FILE" "$XC_FRAMEWORK_URL" # --fail aborts on HTTP errors instead of checksumming an error page 42 | 43 | echo "Computing checksum..." 44 | CHECKSUM=$(swift package compute-checksum "$XC_FRAMEWORK_FILE") 45 | 46 | echo "New checksum: $CHECKSUM" 47 | 48 | # Update Package.swift - version 49 | sed -i '' "s/let llamaVersion = \"$CURRENT_VERSION\"/let llamaVersion = \"$LATEST_TAG\"/" "$PACKAGE_FILE" 50 | 51 | # Update Package.swift - checksum 52 | sed -i '' "s/checksum: \"[a-f0-9]*\"/checksum: \"$CHECKSUM\"/" "$PACKAGE_FILE" 53 | 54 | # Clean up 55 | cd "$PROJECT_ROOT" 56 | rm -rf "$TEMP_DIR" 57 | 58 | echo "Package.swift has been updated to use llama.cpp version $LATEST_TAG" 59 | 60 | echo "Updating git submodules..." 61 | git fetch --tags 62 | git -C "$PROJECT_ROOT/Sources/LocalLLMClientLlamaC/exclude/llama.cpp" checkout "tags/$LATEST_TAG" 63 | echo "All submodules have been updated." --------------------------------------------------------------------------------
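
A minimal usage sketch for the maintenance scripts above (an illustration, assuming the commands are run from the repository root; the GITHUB_TOKEN value is a placeholder, and the variable is optional and only raises the GitHub API rate limit):

# Print the llama.cpp tag currently pinned in Package.swift
./scripts/get_llama_version.sh

# Optional: authenticate to avoid GitHub API rate limits (placeholder value)
export GITHUB_TOKEN="<your token>"

# Fetch the latest llama.cpp release, update Package.swift, and check out the matching submodule tag
./scripts/update_dependencies.sh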