├── .gitignore ├── Package.resolved ├── Package.swift ├── README.md └── Sources ├── CLI.swift └── Kit ├── CoreML+Extensions.swift ├── KVCacheProcessor.swift ├── LogitProcessor.swift ├── ModelPipeline.swift ├── MultiArrayStore.swift ├── PipelineInferenceConfiguration.swift ├── Swift+Extensions.swift ├── TextGenerator.swift └── Transformers+Extensions.swift /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.build 3 | /Packages 4 | xcuserdata/ 5 | DerivedData/ 6 | .swiftpm/configuration/registries.json 7 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata 8 | .netrc 9 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "jinja", 5 | "kind" : "remoteSourceControl", 6 | "location" : "https://github.com/maiqingqiang/Jinja", 7 | "state" : { 8 | "revision" : "b435eb62b0d3d5f34167ec70a128355486981712", 9 | "version" : "1.0.5" 10 | } 11 | }, 12 | { 13 | "identity" : "swift-argument-parser", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/apple/swift-argument-parser", 16 | "state" : { 17 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c", 18 | "version" : "1.5.0" 19 | } 20 | }, 21 | { 22 | "identity" : "swift-transformers", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/huggingface/swift-transformers", 25 | "state" : { 26 | "revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0", 27 | "version" : "0.1.13" 28 | } 29 | } 30 | ], 31 | "version" : 2 32 | } 33 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 5.9 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | 4 | import PackageDescription 5 | 6 | let package = Package( 7 | name: "LLMCLI", 8 | platforms: [ 9 | .macOS(.v14), 10 | ], 11 | dependencies: [ 12 | .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"), 13 | .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"), 14 | ], 15 | targets: [ 16 | // Targets are the basic building blocks of a package, defining a module or a test suite. 17 | // Targets can depend on other targets in this package and products from dependencies. 18 | .executableTarget( 19 | name: "LLMCLI", 20 | dependencies: [ 21 | .product(name: "ArgumentParser", package: "swift-argument-parser"), 22 | .product(name: "Transformers", package: "swift-transformers"), 23 | ], 24 | path: "Sources" 25 | ), 26 | ] 27 | ) 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoreML LLM CLI 2 | CLI to demonstrate running a large language model (LLM) on Apple Neural Engine. 3 | 4 | Download a CoreML-compatible Llama 2 7B model (~4GB), load it, and generate text: 5 | ```shell 6 | $ swift run LLMCLI --repo-id smpanaro/Llama-2-7b-coreml 7 | ``` 8 | 9 | **Note:** macOS 14 (Sonoma) is required. 10 | 11 | ## Performance 12 | I'm curious to see the performance of this model on different M-series chips. If you don't see your chip below or if you get significantly different timings, please run the following command and open an issue so I can add it! 
13 | 
14 | To download + measure:
15 | ```shell
16 | $ swift run -c release LLMCLI --repo-id smpanaro/Llama-2-7b-coreml --max-new-tokens 80
17 | ```
18 | 
19 | |Variant|1st Load Time|2nd+ Load Time|Tokens/Sec |ANE Power|
20 | |--     |--           |--            |--         |-        |
21 | |M1 Max |113s         |8.1s          |7.02 ± 0.11|4.2W     |
22 | |M3 Max |-            |0.8s          |13.92 ± 0.5|8W       |
23 | 
24 | For M2 and M4, consider the times for Model Version 1 as a lower bound:
25 | 
26 | <details><summary>Model Version 1 (a76a14d)</summary>
27 | 
28 | |Variant|1st Load Time|2nd+ Load Time|Tokens/Sec |ANE Power|
29 | |--     |--           |--            |--         |-        |
30 | |M1 Max |77s          |7.5s          |4.97 ± 0.11|3.6W     |
31 | |M2     |-            |22.9s         |5.51 ± 0.56|4.5-7.2W |
32 | |M2 Pro |71s          |4.2s          |6.76 ± 0.09|-        |
33 | |M3     |64s          |5.47s         |7.12 ± 0.16|5.6W     |
34 | |M3 Max |-            |-             |7.6        |5.5W     |
35 | |M4 iPad|66s          |-             |7.76 ± 0.36|-        |
36 | 
37 | </details>
38 | 
39 | ## Inference Optimizations
40 | This CLI implements a couple of optimizations that may be useful or interesting even for small and medium-sized models.
41 | 
42 | ### CVPixelBuffer/IOSurface MLMultiArrays
43 | All float16 model input and output arrays are created with IOSurface-backed CVPixelBuffers ([docs](https://developer.apple.com/documentation/coreml/mlmultiarray/3882834-init)). This avoids unnecessary copies that can occur when data moves between the CPU and the ANE, especially for large arrays like the KV cache.
44 | 
45 | ### 20% Faster Convolutions With Reshaping
46 | As recommended in Apple's [Deploying Transformers on the Apple Neural Engine](https://machinelearning.apple.com/research/neural-engine-transformers), this model uses a 4D tensor layout: (Batch, Channels, 1, Sequence). We process 64 tokens at a time, so most tensors are (B,C,1,64). It turns out that the convolutions in the MLP are 50% faster when the tensor is (B,C,8,8). This seems to hold across hardware generations (M1, A14, A16, A17 Pro), so we can leverage it for an extra speedup.
47 | 
48 | Unfortunately, attention requires a (B,C,1,S)-shaped tensor, so we cannot simply run the whole model in (B,C,8,8) for the full 50% speedup. Instead we reshape to (B,C,1,64) before[^1] the QKV projections and back to (B,C,8,8) just before the attention out projection. This seems to minimize the cost of the reshapes and yields a ~20% overall speedup.
49 | 
50 | [^1]: Yes, before. Doing fewer reshapes (1 instead of 3) is faster than running these smaller convolutions in (8,8).
51 | 
52 | ### Model Chunking
53 | The model is split into multiple smaller CoreML model chunks:
54 | - 1 for the embedding + attention mask + RoPE cos/sin
55 | - N for the blocks (3 [blocks](https://github.com/Lightning-AI/litgpt/blob/221b7ef54161272162aa9b036f1ef3674f3160a4/litgpt/model.py#L139) per chunk)
56 | - 1 for the LM (language modeling) head
57 | 
58 | This allows for faster model loading and enables async KV cache manipulation (see below).
59 | 
60 | ### Async KV Cache Updates
61 | This model uses an ANE-compatible KV cache that needs to be shifted periodically. Thanks to the model chunking, this does not need to happen synchronously in each chunk. It only needs to be updated before that chunk's next prediction (~1 full forward pass in the future).
62 | 
63 | To take advantage of this, the new sections of the KV cache are returned as model outputs and a separate CoreML model combines them with the prior cache values asynchronously, as sketched below. For Llama this saves ~1-2ms per chunk (~20ms overall).
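In code, this boils down to keeping one in-flight `Task` per chunk: the cache merge is submitted right after a chunk's prediction completes and awaited just before that same chunk runs again. The real implementation is `KVCacheProcessor` in `Sources/Kit/KVCacheProcessor.swift`; the sketch below is a simplified, illustrative version (the `AsyncCacheUpdater` name and its signatures are hypothetical, not the repo's actual API).

```swift
import CoreML

/// Illustrative sketch only. See Sources/Kit/KVCacheProcessor.swift for the real implementation.
final class AsyncCacheUpdater {
    private let cacheProcessor: MLModel                   // e.g. cache-processor.mlmodelc
    private var pending: [Int: Task<Void, Error>] = [:]   // in-flight merges, keyed by chunk index

    init(cacheProcessor: MLModel) {
        self.cacheProcessor = cacheProcessor
    }

    /// Start merging a chunk's new KV entries into its old cache without blocking token generation.
    func submit(inputs: MLFeatureProvider, options: MLPredictionOptions, forChunk index: Int) {
        pending[index] = Task {
            // The output backings in `options` alias the old cache buffers, so the
            // merged cache is written directly into the arrays this chunk will read
            // on its next forward pass.
            _ = try await self.cacheProcessor.prediction(from: inputs, options: options)
        }
    }

    /// Await the in-flight merge immediately before the same chunk predicts again.
    func wait(forChunk index: Int) async throws {
        try await pending[index]?.value
        pending[index] = nil
    }
}
```

Because the cache processor's output backings point at the chunk's own KV cache input arrays, the merge lands in place and the next forward pass can consume it with no extra copy.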
64 | 65 | ``` 66 | ┌──────────────┐ ┌───────────────┐ 67 | │ Old KV Cache │ │ Hidden States │ 68 | │ (Length 448) │ │ (Length 64) │ 69 | └──────────────┘ └───────────────┘ 70 | ↘ ↙ 71 | ┌───────────┐ 72 | │Chunk Model│ 73 | └───────────┘ 74 | ↙ ↘ 75 | ┌──────────────┐ ┌─────────────────┐ 76 | │ New KV Cache │ │New Hidden States│ 77 | │ (Length 64) │ │ (Length 64) │ 78 | └──────────────┘ └─────────────────┘ 79 | 80 | ┌───────────────────────────────────────────┐ 81 | │Async after the chunk prediction completes:│ 82 | │ │ 83 | │ ┌──────────────┐ ┌──────────────┐ │ 84 | │ │ Old KV Cache │ │ New KV Cache │ │ 85 | │ │ (Length 448) │ │ (Length 64) │ │ 86 | │ └──────────────┘ └──────────────┘ │ 87 | │ ↘ ↙ │ 88 | │ ┌──────────────────┐ │ 89 | │ │Cache Update Model│ │ 90 | │ └──────────────────┘ │ 91 | │ ↓ │ 92 | │ ┌────────────────┐ │ 93 | │ │Updated KV Cache│ │ 94 | │ │ (Length 448) │ │ 95 | │ └────────────────┘ │ 96 | └───────────────────────────────────────────┘ 97 | ``` 98 | 99 | This pattern would probably work for other non-critical path computations too. 100 | -------------------------------------------------------------------------------- /Sources/CLI.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import ArgumentParser 3 | import Tokenizers 4 | import Hub 5 | import CoreML 6 | 7 | @main 8 | struct CLI: AsyncParsableCommand { 9 | @Option(help: "Huggingface repo ID. e.g. smpanaro/Llama-2-7b-CoreML") 10 | var repoID: String? = nil 11 | 12 | @Option( 13 | help: "The directory containing the model's mlmodelc files.", 14 | completion: .file(), transform: URL.init(fileURLWithPath:)) 15 | var localModelDirectory: URL? 16 | 17 | @Option(help: "The model filename prefix, to differentiate when there are multiple models in a folder.") 18 | var localModelPrefix: String? 19 | 20 | @Option(help: "KV cache processor model filename, located in the model directory.") 21 | var cacheProcessorModelName: String = "cache-processor.mlmodelc" 22 | 23 | @Option(help: "Logit processor model filename, located in the model directory.") 24 | var logitProcessorModelName: String = "logit-processor.mlmodelc" 25 | 26 | @Option(help: "Tokenizer name on huggingface.") 27 | var tokenizerName: String? 28 | 29 | @Argument(help: "Input text.") 30 | var inputText: String = "Several species of shrub of the genus Coffea produce the berries from which coffee is extracted. The two main species commercially cultivated are Coffea canephora" 31 | 32 | @Option(help: "Maximum number of new tokens to generate.") 33 | var maxNewTokens: Int = 60 34 | 35 | @Option(help: "Print verbose logs for debugging.") 36 | var verbose: Bool = false 37 | 38 | mutating func run() async throws { 39 | var modelDirectory = localModelDirectory 40 | if let repoID { 41 | modelDirectory = try await downloadModel(repoID: repoID) 42 | } 43 | 44 | guard let modelDirectory else { 45 | print("Either --repoID or --localModelDirectory must be provided.") 46 | return 47 | } 48 | 49 | guard let tokenizerName = tokenizerName ?? inferTokenizer() else { 50 | print("Unable to infer tokenizer. Please provide --tokenizerName.") 51 | return 52 | } 53 | 54 | let pipeline = try ModelPipeline.from( 55 | folder: modelDirectory, 56 | modelPrefix: localModelPrefix, 57 | cacheProcessorModelName: cacheProcessorModelName, 58 | logitProcessorModelName: logitProcessorModelName 59 | // For debugging. 
60 | // primaryCompute: .cpuOnly 61 | // chunkLimit: 1 62 | ) 63 | print(pipeline) 64 | 65 | let tokenizer = try await AutoTokenizer.from(cached: tokenizerName, hubApi: .init(hfToken: HubApi.defaultToken())) 66 | if verbose { print("Tokenizer \(tokenizerName) loaded.") } 67 | 68 | let generator = TextGenerator(pipeline: pipeline, tokenizer: tokenizer) 69 | try await generator.generate(text: inputText, maxNewTokens: maxNewTokens) 70 | } 71 | 72 | func inferTokenizer() -> String? { 73 | switch (repoID ?? localModelPrefix ?? localModelDirectory?.lastPathComponent)?.lowercased() ?? "" { 74 | case let str where str.contains("llama-2-7b"): 75 | return "pcuenq/Llama-2-7b-chat-coreml" 76 | case let str where str.contains( "llama-3.2-1b"): 77 | return "meta-llama/Llama-3.2-1B" 78 | case let str where str.contains( "llama-3.2-3b"): 79 | return "meta-llama/Llama-3.2-3B" 80 | default: 81 | return nil 82 | } 83 | } 84 | 85 | /// Download a model and return the local directory URL. 86 | func downloadModel(repoID: String) async throws -> URL { 87 | let hub = HubApi(hfToken: HubApi.defaultToken()) 88 | let repo = Hub.Repo(id: repoID, type: .models) 89 | 90 | let mlmodelcs = ["Llama*.mlmodelc/*", "logit*", "cache*"] 91 | let filenames = try await hub.getFilenames(from: repo, matching: mlmodelcs) 92 | 93 | let localURL = hub.localRepoLocation(repo) 94 | let localFileURLs = filenames.map { 95 | localURL.appending(component: $0) 96 | } 97 | let anyNotExists = localFileURLs.filter { 98 | !FileManager.default.fileExists(atPath: $0.path(percentEncoded: false)) 99 | }.count > 0 100 | 101 | // swift-transformers doesn't offer a way to know if we're up to date. 102 | let newestTimestamp = localFileURLs.filter { 103 | FileManager.default.fileExists(atPath: $0.path(percentEncoded: false)) 104 | }.compactMap { 105 | let attrs = try! FileManager.default.attributesOfItem(atPath: $0.path(percentEncoded: false)) 106 | return attrs[.modificationDate] as? Date 107 | }.max() ?? Date.distantFuture 108 | let lastUploadDate = Date(timeIntervalSince1970: 1723688450) 109 | let isStale = repoID == "smpanaro/Llama-2-7b-coreml" && newestTimestamp < lastUploadDate 110 | 111 | // I would rather not delete things automatically. 112 | if isStale { 113 | print("⚠️ You have an old model downloaded. 
Please move the following directory to the Trash and try again:") 114 | print(localURL.path()) 115 | throw CLIError.staleFiles 116 | } 117 | 118 | let needsDownload = anyNotExists || isStale 119 | guard needsDownload else { return localURL } 120 | 121 | print("Downloading from \(repoID)...") 122 | if filenames.count == 0 { 123 | throw CLIError.noModelFilesFound 124 | } 125 | 126 | let downloadDir = try await hub.snapshot(from: repo, matching: mlmodelcs) { progress in 127 | let percent = progress.fractionCompleted * 100 128 | if !progress.isFinished { 129 | print("\(percent.formatted(.number.precision(.fractionLength(0))))%", terminator: "\r") 130 | fflush(stdout) 131 | } 132 | } 133 | print("Done.") 134 | print("Downloaded to \(downloadDir.path())") 135 | return downloadDir 136 | } 137 | } 138 | 139 | enum CLIError: Error { 140 | case noModelFilesFound 141 | case staleFiles 142 | } 143 | -------------------------------------------------------------------------------- /Sources/Kit/CoreML+Extensions.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | 4 | extension MLComputeUnits { 5 | var debugName: String { 6 | switch self { 7 | case .all: return "All" 8 | case .cpuAndGPU: return "CPU+GPU" 9 | case .cpuAndNeuralEngine: return "CPU+ANE" 10 | case .cpuOnly: return "CPU" 11 | @unknown default: return "Unknown" 12 | } 13 | } 14 | } 15 | 16 | extension MLMultiArray { 17 | class func emptyIOSurfaceArray(shape: [Int]) -> MLMultiArray? { 18 | guard 19 | shape.count > 0, 20 | let width = shape.last 21 | else { return nil } 22 | let height = shape[0.. [String] { 50 | var ignored = [String]() 51 | outputBackings.forEach { name, backing in 52 | guard let multiArrayValue = outputs.featureValue(for: name)?.multiArrayValue, 53 | multiArrayValue.dataType == .float16 54 | else { return } 55 | 56 | guard let outputPixelBuffer = multiArrayValue.pixelBuffer 57 | else { preconditionFailure("output array is not pixel buffer backed: \(name)") } 58 | 59 | guard let backingPixelBuffer = (backing as? MLMultiArray)?.pixelBuffer 60 | else { preconditionFailure("output backing is not pixel buffer backed: \(name)")} 61 | 62 | if outputPixelBuffer !== backingPixelBuffer { 63 | ignored.append(name) 64 | } 65 | } 66 | 67 | return ignored 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /Sources/Kit/KVCacheProcessor.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | import OSLog 4 | 5 | /// Asynchronously use a chunk's KV cache output to prepare 6 | /// its KV cache inputs for the next predicition. 7 | class KVCacheProcessor { 8 | var chunkTasks: [Task?] 
9 | let processorModel: MLModel 10 | 11 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "KVCacheProcessor") 12 | 13 | init(chunkCount: Int, processorModel: MLModel) { 14 | self.chunkTasks = Array(repeating: nil, count: chunkCount) 15 | self.processorModel = processorModel 16 | } 17 | 18 | convenience init(pipeline: ModelPipeline, processorModel: MLModel) { 19 | self.init(chunkCount: pipeline.chunks.count, processorModel: processorModel) 20 | } 21 | 22 | func submit(inputs: MLFeatureProvider, outputs: MLFeatureProvider, forChunk chunkIndex: Int) { 23 | let pair = ChunkInputsOutputs(inputs: inputs, outputs: outputs) 24 | chunkTasks[chunkIndex] = Task { 25 | try await withThrowingTaskGroup(of: Int.self) { group in 26 | for blockIndex in 0.. Int { 65 | let maxIndex = featureNames.filter { 66 | $0.contains("k_cache") 67 | }.compactMap { 68 | $0.components(separatedBy: "_").last 69 | }.compactMap { 70 | Int($0) 71 | }.max() 72 | guard let maxIndex else { return 0 } 73 | return maxIndex + 1 74 | } 75 | } 76 | 77 | struct ChunkInputsOutputs { 78 | let inputs: MLFeatureProvider 79 | let outputs: MLFeatureProvider 80 | 81 | func cacheProcessorInputs(forBlock blockIndex: Int) throws -> MLFeatureProvider { 82 | let inputs = [ 83 | "old_k_cache": inputs.featureValue(for: "k_cache_\(blockIndex)")!.multiArrayValue!, 84 | "new_k_cache": outputs.featureValue(for: "new_k_cache_\(blockIndex)")!.multiArrayValue!, 85 | "old_v_cache": inputs.featureValue(for: "v_cache_\(blockIndex)")!.multiArrayValue!, 86 | "new_v_cache": outputs.featureValue(for: "new_v_cache_\(blockIndex)")!.multiArrayValue!, 87 | ] 88 | return try MLDictionaryFeatureProvider(dictionary: inputs) 89 | } 90 | 91 | func cacheProcessorOptions(forBlock blockIndex: Int) -> MLPredictionOptions { 92 | let opts = MLPredictionOptions() 93 | opts.outputBackings = [ 94 | "updated_k_cache": inputs.featureValue(for: "k_cache_\(blockIndex)")!.multiArrayValue!, 95 | "updated_v_cache": inputs.featureValue(for: "v_cache_\(blockIndex)")!.multiArrayValue!, 96 | ] 97 | return opts 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /Sources/Kit/LogitProcessor.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | import OSLog 4 | 5 | /// Pick the next token from the provided logits. 6 | class LogitProcessor { 7 | let model: DeferredModel 8 | 9 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "LogitProcessor") 10 | 11 | init(model: DeferredModel) { 12 | self.model = model 13 | } 14 | 15 | func load() { 16 | self.model.load() 17 | } 18 | 19 | func unload() { 20 | self.model.unload() 21 | } 22 | 23 | // func multinomial(logits: MLMultiArray) async throws -> Int { 24 | // } 25 | 26 | func argmax(logits: [MLMultiArray], index: Int? = nil) async throws -> Int { 27 | let state = signposter.beginInterval("Sample", "Argmax") 28 | defer { signposter.endInterval("Sample", state) } 29 | 30 | let outputs = try await predict(logits: logits) 31 | let argmaxArray = MLShapedArray(outputs.featureValue(for: "argmax")!.multiArrayValue!) 32 | let prediction = argmaxArray[0, index ?? (argmaxArray.shape[1] - 1)].scalar ?? -1 33 | return Int(prediction) 34 | } 35 | 36 | fileprivate func predict(logits: [MLMultiArray]) async throws -> MLFeatureProvider { 37 | let keysValues = logits.enumerated().map { (index, array) in 38 | return (logits.count == 1 ? 
"logits" : "logits_\(index)", array) 39 | } 40 | let inputs = try MLDictionaryFeatureProvider( 41 | dictionary: Dictionary(keysValues, uniquingKeysWith: { (first, _) in first }) 42 | ) 43 | return try await model.model!.prediction(from: inputs) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /Sources/Kit/ModelPipeline.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | import OSLog 4 | 5 | /// ModelPipeline manages the forward pass for an LLM that 6 | /// has been split across many MLModels. 7 | class ModelPipeline { 8 | let chunks: [PipelineChunk] 9 | var inferenceConfiguration: PipelineInferenceConfiguration? // Can't retrieve this until models are loaded. 10 | let cacheProcessorModel: DeferredModel 11 | let logitProcessor: LogitProcessor 12 | 13 | var loadableProcessors: [Loadable] { 14 | [cacheProcessorModel, logitProcessor] 15 | } 16 | 17 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "ModelPipeline") 18 | 19 | init(chunks: [PipelineChunk], cacheProcessor: DeferredModel, logitProcessor: LogitProcessor) { 20 | self.chunks = chunks 21 | precondition(chunks.count > 0) 22 | self.cacheProcessorModel = cacheProcessor 23 | self.logitProcessor = logitProcessor 24 | } 25 | 26 | /// Load the pipeline gradually to minimize resource usage 27 | /// during the initial load and model compilation/specialization. 28 | fileprivate func prewarm() { 29 | print("Compiling models: ", terminator: "") 30 | fflush(stdout) 31 | 32 | // No need for signposts, should be fast. 33 | loadableProcessors.forEach { 34 | $0.load() 35 | $0.unload() 36 | } 37 | 38 | chunks.enumerated().forEach { (i, chunk) in 39 | signposter.withIntervalSignpost("Prepare", "Warm Chunk \(i)") { 40 | chunk.load() 41 | chunk.unload() 42 | } 43 | print("*", terminator: "") 44 | fflush(stdout) 45 | } 46 | print() 47 | } 48 | 49 | func load() throws { 50 | prewarm() 51 | print("Loading models : ", terminator: "") 52 | fflush(stdout) 53 | 54 | loadableProcessors.forEach { 55 | $0.load() 56 | } 57 | 58 | chunks.enumerated().forEach { (i, chunk) in 59 | signposter.withIntervalSignpost("Prepare", "Load Chunk \(i)") { 60 | chunk.load() 61 | } 62 | print("*", terminator: "") 63 | fflush(stdout) 64 | } 65 | print() 66 | 67 | inferenceConfiguration = .init(from: chunks.compactMap { $0.model }) 68 | if inferenceConfiguration == nil { 69 | // Unable to infer the correct model parameters from the model inputs. 70 | // We won't be able to predict. 71 | throw PipelineError.unsupportedInferenceConfiguration 72 | } 73 | } 74 | 75 | func predict(tokens initialTokens: [Int], maxNewTokens: Int) throws -> AsyncThrowingStream { 76 | guard let inferenceConfiguration else { 77 | throw PipelineError.unsupportedInferenceConfiguration 78 | } 79 | 80 | let arrayStore = MultiArrayStore(for: self) 81 | let kvCacheProcessor = KVCacheProcessor(pipeline: self, processorModel: cacheProcessorModel.model!) 82 | 83 | return AsyncThrowingStream { continuation in 84 | let inputLength = inferenceConfiguration.inputLength 85 | var promptChunks = stride(from: 0, to: initialTokens.count, by: inputLength).map { 86 | Array(initialTokens[$0..? = nil 94 | while tokens.count < maxTokens { 95 | let timer = CodeTimer() 96 | let tokenSignpostState = self.signposter.beginInterval("Predict", id: self.signposter.makeSignpostID(), "Token #\(tokens.count)") 97 | 98 | // Do a forward pass of the model. 
99 | var logitChunks: [MLMultiArray] = [] 100 | for (i, chunk) in self.chunks.enumerated() { 101 | let model = chunk.model! 102 | 103 | // Wait for the KV cache to be asynchronously updated. 104 | try await kvCacheProcessor.wait(forChunk: i) 105 | 106 | let inputs = try arrayStore.featureProvider(forChunk: i, model: model, tokens: tokens) 107 | 108 | let options = MLPredictionOptions() 109 | options.outputBackings = arrayStore.outputBackings(forChunk: i, model: model) 110 | 111 | let predictState = self.signposter.beginInterval("Predict", id: self.signposter.makeSignpostID(), "Chunk \(i)") 112 | let outputs = try await model.prediction(from: inputs, options: options) 113 | self.signposter.endInterval("Predict", predictState) 114 | arrayStore.update(outputs: outputs, forChunk: i) // Using output backings should make this ~a noop. 115 | 116 | // Update the cache (async) each time a full set of input tokens is processed. 117 | // e.g. if the model takes in 64 input tokens at once, update when tokens 118 | // is length 64, 128, 192, etc. 119 | if tokens.count % inputLength == 0 { 120 | kvCacheProcessor.submit(inputs: inputs, outputs: outputs, forChunk: i) 121 | } 122 | 123 | // Sometimes logits are returned in chunks along the vocab dimension. 124 | let logitFeatures = outputs.featureNames 125 | .filter { $0.starts(with: "logit") } 126 | .sorted { ($0.trailingNumberSuffix() ?? -1) < ($1.trailingNumberSuffix() ?? -1) } 127 | logitChunks = logitFeatures.map { outputs.featureValue(for: $0)!.multiArrayValue! } 128 | } 129 | 130 | if !promptChunks.isEmpty { 131 | for token in promptChunks.removeFirst() { 132 | tokens.append(token) 133 | continuation.yield(Prediction(newToken: token, allTokens: tokens, latency: nil, promptLatency: nil)) 134 | } 135 | promptLatency = promptTimer.elapsed() 136 | } else { 137 | let newTokenIndex = tokens.isEmpty ? 0 : (tokens.count - 1) % inputLength 138 | let newToken = try await self.logitProcessor.argmax(logits: logitChunks, index: newTokenIndex) 139 | tokens.append(newToken) 140 | continuation.yield(Prediction(newToken: newToken, allTokens: tokens, latency: timer.elapsed(), promptLatency: promptLatency)) 141 | promptLatency = nil 142 | } 143 | 144 | self.signposter.endInterval("Predict", tokenSignpostState, "\(tokens.last!)") 145 | } 146 | 147 | continuation.finish() 148 | } 149 | } 150 | } 151 | 152 | 153 | extension ModelPipeline: CustomDebugStringConvertible { 154 | var debugDescription: String { 155 | let fileName = chunks.first?.fileInfo.displayModelName ?? "" 156 | return"\(Self.self) \(fileName) (\(chunks.count) chunks)" 157 | } 158 | } 159 | 160 | extension ModelPipeline { 161 | /// Creates a pipeline from the mlmodelc files in the given folder. 162 | /// Model files should follow the format: `${MODEL_PREFIX}_chunk${CHUNK_NUMBER}.mlmodelc` 163 | /// Does not load the model. 164 | class func from( 165 | folder: URL, 166 | modelPrefix: String?, 167 | cacheProcessorModelName: String, 168 | logitProcessorModelName: String, 169 | primaryCompute: MLComputeUnits = .cpuAndNeuralEngine, 170 | chunkLimit: Int? 
= nil 171 | ) throws -> ModelPipeline { 172 | let manager = FileManager.default 173 | let contents = try manager.contentsOfDirectory(atPath: folder.path(percentEncoded: false)) 174 | 175 | let chunkFiles = contents 176 | .compactMap { ChunkFileInfo(url: folder.appending(path: $0)) } 177 | .filter { $0.url.pathExtension == "mlmodelc" } 178 | .filter { 179 | if let modelPrefix { $0.modelPrefix.hasPrefix(modelPrefix) } 180 | else { true } 181 | } 182 | .sorted(by: { $0.chunkNumber < $1.chunkNumber }) 183 | 184 | let uniquePrefixes = Set(chunkFiles.map { $0.modelPrefix }) 185 | if uniquePrefixes.count > 1 { 186 | throw PipelineError.ambiguousModelPath(possiblePrefixes: Array(uniquePrefixes)) 187 | } 188 | 189 | let chunks = chunkFiles.enumerated() 190 | .filter { (i, _) in 191 | // Allow limiting the number of chunks for debugging. 192 | if i == 0 || i == chunkFiles.count - 1 { return true } 193 | if let chunkLimit { return i <= chunkLimit } 194 | return true 195 | } 196 | .map { (i, chunkFile) in 197 | let config = MLModelConfiguration() 198 | // The first chunk has operations that cannot run on ANE. 199 | config.computeUnits = i == 0 ? .cpuOnly : primaryCompute 200 | config.modelDisplayName = "Chunk \(chunkFile.chunkNumber)" 201 | return PipelineChunk(fileInfo: chunkFile, configuration: config) 202 | } 203 | 204 | if chunks.count == 0 { 205 | throw PipelineError.modelChunksNotFound 206 | } 207 | 208 | // Model for updating KV caches. 209 | let cacheProcessorURL = folder.appending(component: cacheProcessorModelName) 210 | let cacheModelConfig = MLModelConfiguration() 211 | cacheModelConfig.computeUnits = primaryCompute 212 | cacheModelConfig.modelDisplayName = "Cache Processor" 213 | let cacheProcessor = DeferredModel(url: cacheProcessorURL, configuration: cacheModelConfig) 214 | 215 | // Model for choosing the next token. 216 | let logitProcessorURL = folder.appending(component: logitProcessorModelName) 217 | let logitProcessorModelConfig = MLModelConfiguration() 218 | logitProcessorModelConfig.computeUnits = primaryCompute 219 | logitProcessorModelConfig.modelDisplayName = "Logit Processor" 220 | let logitProcessor = LogitProcessor(model: DeferredModel(url: logitProcessorURL, configuration: logitProcessorModelConfig)) 221 | 222 | return ModelPipeline(chunks: chunks, cacheProcessor: cacheProcessor, logitProcessor: logitProcessor) 223 | } 224 | } 225 | 226 | class PipelineChunk { 227 | let fileInfo: ChunkFileInfo 228 | let configuration: MLModelConfiguration 229 | var model: MLModel? 230 | 231 | convenience init?(url: URL, configuration: MLModelConfiguration) { 232 | guard let fileInfo = ChunkFileInfo(url: url) else { return nil } 233 | self.init(fileInfo: fileInfo, configuration: configuration) 234 | } 235 | 236 | init(fileInfo: ChunkFileInfo, configuration: MLModelConfiguration) { 237 | self.fileInfo = fileInfo 238 | self.configuration = configuration 239 | } 240 | 241 | func load() { 242 | guard model == nil else { return } 243 | model = try! MLModel(contentsOf: fileInfo.url, configuration: configuration) 244 | } 245 | 246 | func unload() { 247 | model = nil 248 | } 249 | } 250 | 251 | extension PipelineChunk: CustomDebugStringConvertible { 252 | var debugDescription: String { 253 | return "\(Self.self) \(fileInfo.chunkNumber) (\(configuration.computeUnits.debugName))" 254 | } 255 | } 256 | 257 | class DeferredModel { 258 | let url: URL 259 | let configuration: MLModelConfiguration 260 | 261 | var model: MLModel? 
262 | 263 | init(url: URL, configuration: MLModelConfiguration) { 264 | self.url = url 265 | self.configuration = configuration 266 | } 267 | 268 | func load() { 269 | model = try! MLModel(contentsOf: url, configuration: configuration) 270 | } 271 | 272 | func unload() { 273 | model = nil 274 | } 275 | } 276 | 277 | struct ChunkFileInfo { 278 | let url: URL 279 | let fileName: String 280 | let modelPrefix: String 281 | let chunkNumber: Int 282 | 283 | init?(url: URL) { 284 | self.url = url 285 | self.fileName = url.lastPathComponent 286 | 287 | var split = url.deletingPathExtension().lastPathComponent.split(separator: "_") 288 | guard 289 | let chunkString = split.popLast(), 290 | let chunkNumber = Int(chunkString.replacingOccurrences(of: "chunk", with: "")) 291 | else { 292 | return nil 293 | } 294 | 295 | self.modelPrefix = split.joined(separator: "_") + "_" 296 | self.chunkNumber = chunkNumber 297 | } 298 | 299 | var displayModelName: String { 300 | // Drop the last _ 301 | String(modelPrefix.prefix(upTo: modelPrefix.index(before: modelPrefix.endIndex))) 302 | } 303 | } 304 | 305 | enum PipelineError: Error { 306 | case unsupportedInferenceConfiguration 307 | case cacheProcessorNotFound 308 | case modelChunksNotFound 309 | case ambiguousModelPath(possiblePrefixes: [String]) 310 | case notImplementedError 311 | } 312 | 313 | struct Prediction { 314 | let newToken: Int 315 | let allTokens: [Int] 316 | let latency: Measurement? 317 | let promptLatency: Measurement? 318 | } 319 | 320 | struct CodeTimer { 321 | let start = CFAbsoluteTimeGetCurrent() 322 | 323 | func elapsed() -> Measurement { 324 | let seconds = CFAbsoluteTimeGetCurrent() - start 325 | return Measurement(value: seconds, unit: .seconds) 326 | } 327 | } 328 | 329 | protocol Loadable { 330 | func load() 331 | func unload() 332 | } 333 | 334 | extension LogitProcessor: Loadable {} 335 | extension DeferredModel: Loadable {} 336 | extension PipelineChunk: Loadable {} 337 | 338 | extension String { 339 | func trailingNumberSuffix() -> Int? { 340 | guard let number = self.split(separator: "_").last else { return nil } 341 | return Int(number) 342 | } 343 | } 344 | -------------------------------------------------------------------------------- /Sources/Kit/MultiArrayStore.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | import OSLog 4 | 5 | /// Store and make accessible MLMultiArrays that will be reused 6 | /// in either subsequent chunks or subsequent forward passes. 7 | class MultiArrayStore { 8 | // Arrays that are used as inputs/outputs for multiple chunks. 9 | // Names must be unique across the whole pipeline. 10 | let sharedArrays: [String: MLMultiArray] 11 | // Arrays that are only used with one chunk. 12 | // Names must only be unique within a chunk. e.g. KV cache inputs/outputs. 13 | let chunkArrays: [[String: MLMultiArray]] 14 | 15 | // Map from output names to a different key name to use for the output 16 | // backing. Useful when the input is consumed and the output is the same shape. 
17 | var outputBackingMapping: [String: String] 18 | 19 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "MultiArrayStore") 20 | 21 | init(sharedArrays: [String : MLMultiArray], 22 | chunkArrays: [[String : MLMultiArray]], 23 | outputBackingMapping: [String: String] = [:] 24 | ) { 25 | self.sharedArrays = sharedArrays 26 | self.chunkArrays = chunkArrays 27 | self.outputBackingMapping = outputBackingMapping 28 | } 29 | 30 | func inputFeatures(forChunk chunkIndex: Int, model: MLModel) -> [String: MLMultiArray] { 31 | Dictionary(uniqueKeysWithValues: model.modelDescription.inputDescriptionsByName.keys.compactMap { name in 32 | self.array(named: name, for: chunkIndex).map { arr in (name, arr) } 33 | }) 34 | } 35 | 36 | func outputBackings(forChunk chunkIndex: Int, model: MLModel) -> [String: MLMultiArray] { 37 | Dictionary(uniqueKeysWithValues: model.modelDescription.outputDescriptionsByName.keys.compactMap { name in 38 | let backingName = outputBackingMapping[name] ?? name 39 | return self.array(named: backingName, for: chunkIndex).map { arr in (name, arr) } 40 | }) 41 | } 42 | 43 | func update(outputs: MLFeatureProvider, forChunk chunkIndex: Int) { 44 | let state = signposter.beginInterval("Update Outputs", "Chunk \(chunkIndex)") 45 | defer { signposter.endInterval("Update Outputs", state) } 46 | 47 | outputs.featureNames.forEach { name in 48 | // Only float16 buffers are cached. 49 | guard let multiArrayValue = outputs.featureValue(for: name)?.multiArrayValue, 50 | multiArrayValue.dataType == .float16 51 | else { return } 52 | 53 | guard let pixelBuffer = multiArrayValue.pixelBuffer 54 | else { preconditionFailure("output array is not pixel buffer backed: \(name)") } 55 | 56 | let cacheName = outputBackingMapping[name] ?? name 57 | guard let cachePixelBuffer = self.array(named: cacheName, for: chunkIndex)?.pixelBuffer 58 | else { preconditionFailure("output array is not present in the cache: \(name)") } 59 | 60 | // TODO: Gracefully degrade by manually copying the array. 61 | precondition(pixelBuffer === cachePixelBuffer, "output backing not used: \(name)") 62 | } 63 | } 64 | 65 | fileprivate func array(named name: String, for chunkIndex: Int) -> MLMultiArray? { 66 | if let sharedArr = sharedArrays[name] { 67 | return sharedArr 68 | } else if let chunkArr = chunkArrays[chunkIndex][name] { 69 | return chunkArr 70 | } 71 | return nil 72 | } 73 | } 74 | 75 | extension MultiArrayStore { 76 | convenience init(for pipeline: ModelPipeline) { 77 | var sharedArrays = [String: MLMultiArray]() 78 | var chunkArrays = [[String: MLMultiArray]]() 79 | 80 | // For each output key, reuse the input array named by the value. 81 | let outputBackingMapping = ["new_x": "x"] 82 | 83 | // KV caches, cos + sin for RoPE, attention mask 84 | pipeline.chunks.forEach { chunk in 85 | let model = chunk.model! 
86 | var floatShapes = [String: [Int]]() 87 | 88 | let modelDescription = model.modelDescription 89 | let allDescriptions = modelDescription.inputDescriptionsByName.merging( 90 | modelDescription.outputDescriptionsByName, uniquingKeysWith: { a,b in a }) 91 | 92 | allDescriptions.forEach { (name, desc) in 93 | guard let constraint = desc.multiArrayConstraint else { return } 94 | if constraint.dataType != .float16 { return } 95 | if outputBackingMapping.keys.contains(name) { return } 96 | floatShapes[name] = constraint.shape.map { $0.intValue } 97 | } 98 | 99 | let isShared = { (key: String) -> Bool in !key.contains("cache") } 100 | let sharedShapes = floatShapes.filter { isShared($0.key) } 101 | let chunkShapes = floatShapes.filter { !isShared($0.key) } 102 | 103 | // TODO: Sort so K caches are allocated close to other K caches. Ditto for V. 104 | // See if it has a performance impact. 105 | 106 | sharedShapes.forEach { (name, shape) in 107 | if sharedArrays[name] == nil { 108 | sharedArrays[name] = MLMultiArray.emptyIOSurfaceArray(shape: shape)! 109 | } 110 | } 111 | 112 | var currChunkArrays = [String:MLMultiArray]() 113 | chunkShapes.forEach { (name, shape) in 114 | currChunkArrays[name] = MLMultiArray.emptyIOSurfaceArray(shape: shape)! 115 | } 116 | chunkArrays.append(currChunkArrays) 117 | } 118 | 119 | self.init(sharedArrays: sharedArrays, chunkArrays: chunkArrays, outputBackingMapping: outputBackingMapping) 120 | } 121 | } 122 | 123 | 124 | extension MultiArrayStore { 125 | func featureProvider(forChunk chunkIndex: Int, model: MLModel, tokens: [Int]) throws -> MLFeatureProvider { 126 | let state = signposter.beginInterval("Prepare Features", "Chunk \(chunkIndex)") 127 | defer { signposter.endInterval("Prepare Features", state) } 128 | 129 | var cacheFeatures = inputFeatures(forChunk: chunkIndex, model: model) 130 | 131 | // Input IDs is populated in concert with KV cache updates. 132 | // 133 | // 0 1 2 3 4 134 | // _ _ _ _ _ _ _ _ _ _ _ _ _ 135 | // ┌─────────────────────────┬────────┐ 136 | // └───────────Cache─────────┴─Inputs─┘ 137 | // 138 | // 139 | // 5 _ _ _ _ 140 | // _ _ _ _ _ _ _ _ 0 1 2 3 4 141 | // ┌─────────────────────────┬────────┐ 142 | // └───────────Cache─────────┴─Inputs─┘ 143 | // 144 | // 145 | // 5 6 _ _ _ 146 | // _ _ _ _ _ _ _ _ 0 1 2 3 4 147 | // ┌─────────────────────────┬────────┐ 148 | // └───────────Cache─────────┴─Inputs─┘ 149 | // 150 | // ┌ ─ ─ ─ ─ ─ ─ ┐ 151 | // Same for 7-9 152 | // └ ─ ─ ─ ─ ─ ─ ┘ 153 | // 154 | // 5 6 7 8 9 155 | // _ _ _ _ _ _ _ _ 0 1 2 3 4 156 | // ┌─────────────────────────┬────────┐ 157 | // └───────────Cache─────────┴─Inputs─┘ 158 | // 159 | // ◀━━━━━━━━━━━━━Shift!━━━━━━━━━━━━━━━ 160 | // 161 | // A _ _ _ _ 162 | // _ _ _ 0 1 2 3 4 5 6 7 8 9 163 | // ┌─────────────────────────┬────────┐ 164 | // └───────────Cache─────────┴─Inputs─┘ 165 | // 166 | // ┌ ─ ─ ─ ─ ─ 167 | // Repeat │ 168 | // └ ─ ─ ─ ─ ─ 169 | var padCount = 0 170 | let inputDescriptions = model.modelDescription.inputDescriptionsByName 171 | if let inputIDsConstraint = inputDescriptions["input_ids"]?.multiArrayConstraint { 172 | let inputShape = inputIDsConstraint.shape.map { $0.intValue } 173 | let inputLength = inputShape.last! 174 | 175 | // For inputLength 64, tokens.count -> suffixLength: 176 | // 0 -> 0, 5 -> 5, 64 -> 64, 65 -> 1, 128 -> 64 177 | let suffixLength = tokens.isEmpty ? 
0 : (tokens.count - 1) % inputLength + 1 178 | let inputTokens = tokens.suffix(suffixLength).map { Int32($0) } 179 | padCount = max(0, inputLength - inputTokens.count) 180 | let paddedInputTokens = inputTokens + Array(repeating: Int32(0), count: padCount) 181 | let inputIDs = MLShapedArray(scalars: paddedInputTokens, shape: inputShape) 182 | cacheFeatures["input_ids"] = MLMultiArray(inputIDs) 183 | } 184 | 185 | if inputDescriptions["full_sequence_length"]?.multiArrayConstraint != nil { 186 | // For example: [Cache0 Cache1][Input2 Input3 Pad0] 187 | // The causal mask prevents Pad0 from impacting 188 | // the InputN token predictions. We must include it 189 | // here so RoPE is correct for the InputN tokens. 190 | let fullSequenceLength = MLShapedArray(repeating: Int32(tokens.count + padCount), shape: [1]) 191 | cacheFeatures["full_sequence_length"] = MLMultiArray(fullSequenceLength) 192 | } 193 | 194 | return try MLDictionaryFeatureProvider(dictionary: cacheFeatures) 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /Sources/Kit/PipelineInferenceConfiguration.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import CoreML 3 | 4 | struct PipelineInferenceConfiguration { 5 | let vocabSize: Int 6 | let inputLength: Int // query length 7 | let contextLength: Int // key + value length 8 | 9 | var cacheLength: Int { 10 | contextLength - inputLength 11 | } 12 | } 13 | 14 | extension PipelineInferenceConfiguration { 15 | init?(from models: [MLModel]) { 16 | guard models.count > 2 else { return nil } 17 | 18 | guard let first = models.first, let last = models.last 19 | else { return nil } 20 | let innerModels = models[1.. Element { 5 | return self.reduce(0, +) 6 | } 7 | 8 | func mean() -> Element { 9 | return self.sum() / Element(self.count) 10 | } 11 | 12 | func stdev() -> Element { 13 | let mean = self.mean() 14 | let v = self.reduce(0, { $0 + ($1-mean)*($1-mean) }) 15 | return sqrt(v / (Element(self.count) - 1)) 16 | } 17 | } 18 | 19 | extension AsyncThrowingStream { 20 | /// Initialize a stream that runs its body closure asynchronously. 21 | public init( 22 | _ elementType: Element.Type = Element.self, 23 | bufferingPolicy limit: AsyncThrowingStream.Continuation.BufferingPolicy = .unbounded, 24 | _ build: @Sendable @escaping (AsyncThrowingStream.Continuation) async throws -> Void 25 | ) where Failure == Error { 26 | self = AsyncThrowingStream(elementType, bufferingPolicy: limit) { continuation in 27 | let task = Task { 28 | try await build(continuation) 29 | } 30 | 31 | continuation.onTermination = { _ in 32 | task.cancel() 33 | } 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /Sources/Kit/TextGenerator.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Tokenizers 3 | 4 | /// TextGenerator uses an LLM `ModelPipeline` to generate text. 5 | class TextGenerator { 6 | let pipeline: ModelPipeline 7 | let tokenizer: Tokenizer 8 | // var strategy: Strategy = .argmax // etc. 
9 | 10 | init(pipeline: ModelPipeline, tokenizer: Tokenizer) { 11 | self.pipeline = pipeline 12 | self.tokenizer = tokenizer 13 | } 14 | 15 | func generate(text: String, maxNewTokens: Int) async throws { 16 | 17 | let loadTimer = CodeTimer() 18 | try pipeline.load() 19 | let loadDuration = loadTimer.elapsed() 20 | 21 | let tokens = tokenizer.encode(text: text) 22 | 23 | var predictions = [Prediction]() 24 | tokens.forEach { print($0, terminator: " ") } 25 | fflush(stdout) 26 | 27 | for try await prediction in try pipeline.predict(tokens: tokens, maxNewTokens: maxNewTokens) { 28 | predictions.append(prediction) 29 | print(prediction.newToken, terminator: " ") 30 | fflush(stdout) 31 | } 32 | print("\n") 33 | 34 | print(tokenizer.decode(tokens: predictions.last?.allTokens ?? tokens)) 35 | print() 36 | 37 | print("Compile + Load: \(loadDuration.converted(to: .seconds).value.formatted(.number.precision(.fractionLength(2)))) sec") 38 | let numberFormat = FloatingPointFormatStyle.number.precision(.fractionLength(2)) 39 | 40 | let promptLatencies = predictions.compactMap { $0.promptLatency?.converted(to: .milliseconds).value } 41 | if !promptLatencies.isEmpty { 42 | let averagePrompt = Measurement(value: promptLatencies.mean(), unit: UnitDuration.milliseconds) 43 | print("Prompt : \(averagePrompt.value.formatted(numberFormat)) ms") 44 | } 45 | 46 | print("Generate :", terminator: " ") 47 | 48 | let latencies = predictions.compactMap { $0.latency?.converted(to: .milliseconds).value } 49 | let average = Measurement(value: latencies.mean(), unit: UnitDuration.milliseconds) 50 | let stdev = Measurement(value: latencies.stdev(), unit: UnitDuration.milliseconds) 51 | print("\(average.value.formatted(numberFormat)) +/- \(stdev.value.formatted(numberFormat)) ms / token") 52 | 53 | let throughputs = predictions.compactMap { $0.latency?.converted(to: .seconds).value }.map { 1 / $0 } 54 | let averageThroughput = Measurement(value: throughputs.mean(), unit: UnitDuration.seconds) 55 | let stdevThroughput = Measurement(value: throughputs.stdev(), unit: UnitDuration.seconds) 56 | print(" \(averageThroughput.value.formatted(numberFormat)) +/- \(stdevThroughput.value.formatted(numberFormat)) token / sec") 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /Sources/Kit/Transformers+Extensions.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Tokenizers 3 | import Hub 4 | 5 | extension AutoTokenizer { 6 | /// Load the tokenizer from a cache when offline. 7 | /// Workaround until https://github.com/huggingface/swift-transformers/issues/39 8 | static func from( 9 | cached model: String, 10 | hubApi: HubApi = .shared 11 | ) async throws -> Tokenizer { 12 | // Network first since there is no cache invalidation. 13 | do { 14 | let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi) 15 | guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig } 16 | let tokenizerData = try await config.tokenizerData 17 | 18 | return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) 19 | } catch { 20 | let modelFolder = hubApi.localRepoLocation(.init(id: model)) 21 | return try await from(modelFolder: modelFolder, hubApi: hubApi) 22 | } 23 | } 24 | } 25 | 26 | // Not public, so copied over. 
27 | enum TokenizerError: Error {
28 |     case missingConfig
29 | }
30 | 
31 | extension HubApi {
32 |     static var defaultTokenLocation: URL {
33 |         FileManager.default.homeDirectoryForCurrentUser.appending(path: ".cache/huggingface/token")
34 |     }
35 | 
36 |     static func defaultToken() -> String? {
37 |         if let envToken = ProcessInfo.processInfo.environment["HF_TOKEN"] {
38 |             return envToken
39 |         }
40 |         return try? String(contentsOf: defaultTokenLocation, encoding: .utf8)
41 |     }
42 | }
43 | 
--------------------------------------------------------------------------------