├── .gitignore
├── Package.resolved
├── Package.swift
├── README.md
└── Sources
    ├── CLI.swift
    └── Kit
        ├── CoreML+Extensions.swift
        ├── KVCacheProcessor.swift
        ├── LogitProcessor.swift
        ├── ModelPipeline.swift
        ├── MultiArrayStore.swift
        ├── PipelineInferenceConfiguration.swift
        ├── Swift+Extensions.swift
        ├── TextGenerator.swift
        └── Transformers+Extensions.swift
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | /.build
3 | /Packages
4 | xcuserdata/
5 | DerivedData/
6 | .swiftpm/configuration/registries.json
7 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
8 | .netrc
9 |
--------------------------------------------------------------------------------
/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "pins" : [
3 | {
4 | "identity" : "jinja",
5 | "kind" : "remoteSourceControl",
6 | "location" : "https://github.com/maiqingqiang/Jinja",
7 | "state" : {
8 | "revision" : "b435eb62b0d3d5f34167ec70a128355486981712",
9 | "version" : "1.0.5"
10 | }
11 | },
12 | {
13 | "identity" : "swift-argument-parser",
14 | "kind" : "remoteSourceControl",
15 | "location" : "https://github.com/apple/swift-argument-parser",
16 | "state" : {
17 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c",
18 | "version" : "1.5.0"
19 | }
20 | },
21 | {
22 | "identity" : "swift-transformers",
23 | "kind" : "remoteSourceControl",
24 | "location" : "https://github.com/huggingface/swift-transformers",
25 | "state" : {
26 | "revision" : "4d25d20e49d2269aec1556231f8e278db7b2a4f0",
27 | "version" : "0.1.13"
28 | }
29 | }
30 | ],
31 | "version" : 2
32 | }
33 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
1 | // swift-tools-version: 5.9
2 | // The swift-tools-version declares the minimum version of Swift required to build this package.
3 |
4 | import PackageDescription
5 |
6 | let package = Package(
7 | name: "LLMCLI",
8 | platforms: [
9 | .macOS(.v14),
10 | ],
11 | dependencies: [
12 | .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"),
13 | .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
14 | ],
15 | targets: [
16 | // Targets are the basic building blocks of a package, defining a module or a test suite.
17 | // Targets can depend on other targets in this package and products from dependencies.
18 | .executableTarget(
19 | name: "LLMCLI",
20 | dependencies: [
21 | .product(name: "ArgumentParser", package: "swift-argument-parser"),
22 | .product(name: "Transformers", package: "swift-transformers"),
23 | ],
24 | path: "Sources"
25 | ),
26 | ]
27 | )
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoreML LLM CLI
2 | CLI to demonstrate running a large language model (LLM) on the Apple Neural Engine.
3 |
4 | Download a CoreML-compatible Llama 2 7B model (~4GB), load it, and generate text:
5 | ```shell
6 | $ swift run LLMCLI --repo-id smpanaro/Llama-2-7b-coreml
7 | ```
8 |
9 | **Note:** macOS 14 (Sonoma) is required.
10 |
11 | ## Performance
12 | I'm curious to see the performance of this model on different M-series chips. If you don't see your chip below or if you get significantly different timings, please run the following command and open an issue so I can add it!
13 |
14 | To download + measure:
15 | ```shell
16 | $ swift run -c release LLMCLI --repo-id smpanaro/Llama-2-7b-coreml --max-new-tokens 80
17 | ```
18 |
19 | |Variant|1st Load Time|2nd+ Load Time|Tokens/Sec |ANE Power|
20 | |-- |-- |-- |-- |- |
21 | |M1 Max |113s |8.1s |7.02 ± 0.11 |4.2W |
22 | |M3 Max |- |0.8s |13.92 ± 0.5 |8W |
23 |
24 | For M2 and M4, treat the Model Version 1 times below as a lower bound:
25 |
26 | Model Version 1 (a76a14d)
27 |
28 | |Variant|1st Load Time|2nd+ Load Time|Tokens/Sec |ANE Power|
29 | |-- |-- |-- |-- |- |
30 | |M1 Max |77s |7.5s |4.97 ± 0.11 |3.6W |
31 | |M2 |- |22.9s |5.51 ± 0.56 |4.5-7.2W |
32 | |M2 Pro |71s |4.2s |6.76 ± 0.09 |- |
33 | |M3 |64s |5.47s |7.12 ± 0.16 |5.6W |
34 | |M3 Max |- |- |7.6 |5.5W |
35 | |M4 iPad|66s |- |7.76 ± 0.36 |- |
36 |
37 |
38 |
39 | ## Inference Optimizations
40 | This CLI implements a couple of optimizations that may be useful or interesting even for medium-sized and smaller models.
41 |
42 | ### CVPixelBuffer/IOSurface MLMultiArrays
43 | All float16 model input and output arrays are created with IOSurface-backed CVPixelBuffers ([docs](https://developer.apple.com/documentation/coreml/mlmultiarray/3882834-init)). This avoids unnecessary copying that can occur when going between CPU and ANE, especially for large arrays like the KV cache.
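A minimal sketch of the idea, assuming a float16 array (the helper name and the example feature/shape are illustrative, not the exact code in `CoreML+Extensions.swift`):

```swift
import CoreML
import CoreVideo

/// Create an IOSurface-backed float16 MLMultiArray by folding the leading dims into the pixel buffer height.
func makeFloat16IOSurfaceArray(shape: [Int]) -> MLMultiArray? {
    guard let width = shape.last else { return nil }
    let height = shape.dropLast().reduce(1, *)

    let attributes = [kCVPixelBufferIOSurfacePropertiesKey as String: [String: Any]()] as CFDictionary
    var pixelBuffer: CVPixelBuffer?
    guard CVPixelBufferCreate(kCFAllocatorDefault, width, height,
                              kCVPixelFormatType_OneComponent16Half,
                              attributes, &pixelBuffer) == kCVReturnSuccess,
          let pixelBuffer
    else { return nil }

    // https://developer.apple.com/documentation/coreml/mlmultiarray/3882834-init
    return MLMultiArray(pixelBuffer: pixelBuffer, shape: shape.map { NSNumber(value: $0) })
}

// Reusing the same array as an output backing lets CoreML write results into it directly:
// let options = MLPredictionOptions()
// options.outputBackings["new_k_cache_0"] = makeFloat16IOSurfaceArray(shape: [1, 32, 448, 128])!
```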
44 |
45 | ### 20% Faster Convolutions With Reshaping
46 | As recommended in Apple's [Deploying Transformers on the Apple Neural Engine](https://machinelearning.apple.com/research/neural-engine-transformers), this model uses a 4D tensor layout: (Batch, Channels, 1, Sequence). We process 64 tokens at a time, so most tensors are (B,C,1,64). It turns out that the convolutions in the MLP are 50% faster when the tensor is (B,C,8,8). This seems to hold across hardware generations (M1, A14, A16, A17 Pro), so we can leverage it for an extra speedup.
47 |
48 | Unfortunately, attention requires a (B,C,1,S)-shaped tensor, so we cannot simply run the whole model in (B,C,8,8) for the full 50% speedup. Instead, we reshape to (B,C,1,64) before[^1] the QKV projections and back to (B,C,8,8) just before the attention out projection. This seems to minimize the cost of the reshapes and yields a ~20% overall speedup.
49 |
50 | [^1]: Yes, before. Doing fewer reshapes (1 instead of 3) is faster than running these smaller convolutions in (8,8).
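For intuition: per channel, both layouts hold the same 64 values (1×64 = 8×8), so switching between them only relabels the last two axes. A toy illustration with hypothetical sizes (unrelated to the converted model):

```swift
import CoreML

let (batch, channels, sequence) = (1, 8, 64)  // Hypothetical sizes.
let scalars = (0..<batch * channels * sequence).map { Float($0) }

// Same scalars, two layouts: (B, C, 1, S) for attention, (B, C, 8, 8) for the MLP convolutions.
let attentionLayout = MLShapedArray<Float>(scalars: scalars, shape: [batch, channels, 1, sequence])
let convLayout = MLShapedArray<Float>(scalars: scalars, shape: [batch, channels, 8, 8])

assert(attentionLayout.scalars == convLayout.scalars)  // Only the shape metadata differs.
```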
51 |
52 | ### Model Chunking
53 | The model is split into multiple smaller CoreML model chunks:
54 | - 1 for the embedding+attention mask+RoPE cos/sin
55 | - N for the blocks (3 [blocks](https://github.com/Lightning-AI/litgpt/blob/221b7ef54161272162aa9b036f1ef3674f3160a4/litgpt/model.py#L139) per chunk)
56 | - 1 for the LM (language modeling) head
57 |
58 | This allows for faster model loading and enables async KV cache manipulation (see below).
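Loading is then a matter of enumerating the `*_chunkN.mlmodelc` files and giving each chunk its own configuration. A simplified sketch of what `ModelPipeline.from(folder:...)` does (chunk-number sorting and error handling trimmed):

```swift
import CoreML

// Sketch: build the chunk models from a folder of compiled mlmodelc files.
func loadChunks(in folder: URL) throws -> [MLModel] {
    let chunkURLs = try FileManager.default
        .contentsOfDirectory(at: folder, includingPropertiesForKeys: nil)
        .filter { $0.pathExtension == "mlmodelc" && $0.lastPathComponent.contains("_chunk") }
        .sorted { $0.lastPathComponent < $1.lastPathComponent }  // The real code sorts by chunk number.

    return try chunkURLs.enumerated().map { index, url in
        let config = MLModelConfiguration()
        // The first chunk (embeddings, attention mask, RoPE) has ops that cannot run on the ANE.
        config.computeUnits = index == 0 ? .cpuOnly : .cpuAndNeuralEngine
        return try MLModel(contentsOf: url, configuration: config)
    }
}
```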
59 |
60 | ### Async KV Cache Updates
61 | This model uses an ANE-compatible KV cache that needs to be shifted periodically. Because the model is chunked, this does not have to happen synchronously inside each chunk's forward pass; the cache only needs to be updated before that chunk's next prediction (~1 full forward pass in the future).
62 |
63 | To take advantage of this, the new sections of the KV cache are returned from the model and a separate CoreML model combines them with the prior cache values asynchronously. For Llama this saves ~1-2ms per chunk (~20ms overall).
64 |
65 | ```
66 | ┌──────────────┐ ┌───────────────┐
67 | │ Old KV Cache │ │ Hidden States │
68 | │ (Length 448) │ │ (Length 64) │
69 | └──────────────┘ └───────────────┘
70 | ↘ ↙
71 | ┌───────────┐
72 | │Chunk Model│
73 | └───────────┘
74 | ↙ ↘
75 | ┌──────────────┐ ┌─────────────────┐
76 | │ New KV Cache │ │New Hidden States│
77 | │ (Length 64) │ │ (Length 64) │
78 | └──────────────┘ └─────────────────┘
79 |
80 | ┌───────────────────────────────────────────┐
81 | │Async after the chunk prediction completes:│
82 | │ │
83 | │ ┌──────────────┐ ┌──────────────┐ │
84 | │ │ Old KV Cache │ │ New KV Cache │ │
85 | │ │ (Length 448) │ │ (Length 64) │ │
86 | │ └──────────────┘ └──────────────┘ │
87 | │ ↘ ↙ │
88 | │ ┌──────────────────┐ │
89 | │ │Cache Update Model│ │
90 | │ └──────────────────┘ │
91 | │ ↓ │
92 | │ ┌────────────────┐ │
93 | │ │Updated KV Cache│ │
94 | │ │ (Length 448) │ │
95 | │ └────────────────┘ │
96 | └───────────────────────────────────────────┘
97 | ```
98 |
99 | This pattern would probably work for other non-critical-path computations too.
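Condensed into code, the pattern is: kick off the cache update in a `Task` right after a chunk's prediction returns, then await that task just before the same chunk's next prediction. A sketch along the lines of `KVCacheProcessor.swift`, trimmed to a single block's K cache (the real processor also updates the V cache for every block):

```swift
import CoreML

final class CacheUpdater {
    private var tasks: [Task<Void, Error>?]
    private let processor: MLModel  // The small cache-update CoreML model.

    init(chunkCount: Int, processor: MLModel) {
        self.tasks = Array(repeating: nil, count: chunkCount)
        self.processor = processor
    }

    /// Call right after a chunk's prediction completes.
    func submit(inputs: MLFeatureProvider, outputs: MLFeatureProvider, forChunk chunk: Int) {
        tasks[chunk] = Task {
            // Combine the old cache (length 448) with the newly returned entries (length 64).
            let features = try MLDictionaryFeatureProvider(dictionary: [
                "old_k_cache": inputs.featureValue(for: "k_cache_0")!.multiArrayValue!,
                "new_k_cache": outputs.featureValue(for: "new_k_cache_0")!.multiArrayValue!,
            ])
            let options = MLPredictionOptions()
            // Write the updated cache directly over the chunk's input array for the next pass.
            options.outputBackings = ["updated_k_cache": inputs.featureValue(for: "k_cache_0")!.multiArrayValue!]
            _ = try await self.processor.prediction(from: features, options: options)
        }
    }

    /// Call just before the same chunk's next prediction.
    func wait(forChunk chunk: Int) async throws {
        _ = try await tasks[chunk]?.value
    }
}
```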
100 |
--------------------------------------------------------------------------------
/Sources/CLI.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import ArgumentParser
3 | import Tokenizers
4 | import Hub
5 | import CoreML
6 |
7 | @main
8 | struct CLI: AsyncParsableCommand {
9 | @Option(help: "Huggingface repo ID. e.g. smpanaro/Llama-2-7b-CoreML")
10 | var repoID: String? = nil
11 |
12 | @Option(
13 | help: "The directory containing the model's mlmodelc files.",
14 | completion: .file(), transform: URL.init(fileURLWithPath:))
15 | var localModelDirectory: URL?
16 |
17 | @Option(help: "The model filename prefix, to differentiate when there are multiple models in a folder.")
18 | var localModelPrefix: String?
19 |
20 | @Option(help: "KV cache processor model filename, located in the model directory.")
21 | var cacheProcessorModelName: String = "cache-processor.mlmodelc"
22 |
23 | @Option(help: "Logit processor model filename, located in the model directory.")
24 | var logitProcessorModelName: String = "logit-processor.mlmodelc"
25 |
26 | @Option(help: "Tokenizer name on huggingface.")
27 | var tokenizerName: String?
28 |
29 | @Argument(help: "Input text.")
30 | var inputText: String = "Several species of shrub of the genus Coffea produce the berries from which coffee is extracted. The two main species commercially cultivated are Coffea canephora"
31 |
32 | @Option(help: "Maximum number of new tokens to generate.")
33 | var maxNewTokens: Int = 60
34 |
35 | @Option(help: "Print verbose logs for debugging.")
36 | var verbose: Bool = false
37 |
38 | mutating func run() async throws {
39 | var modelDirectory = localModelDirectory
40 | if let repoID {
41 | modelDirectory = try await downloadModel(repoID: repoID)
42 | }
43 |
44 | guard let modelDirectory else {
45 | print("Either --repoID or --localModelDirectory must be provided.")
46 | return
47 | }
48 |
49 | guard let tokenizerName = tokenizerName ?? inferTokenizer() else {
50 | print("Unable to infer tokenizer. Please provide --tokenizerName.")
51 | return
52 | }
53 |
54 | let pipeline = try ModelPipeline.from(
55 | folder: modelDirectory,
56 | modelPrefix: localModelPrefix,
57 | cacheProcessorModelName: cacheProcessorModelName,
58 | logitProcessorModelName: logitProcessorModelName
59 | // For debugging.
60 | // primaryCompute: .cpuOnly
61 | // chunkLimit: 1
62 | )
63 | print(pipeline)
64 |
65 | let tokenizer = try await AutoTokenizer.from(cached: tokenizerName, hubApi: .init(hfToken: HubApi.defaultToken()))
66 | if verbose { print("Tokenizer \(tokenizerName) loaded.") }
67 |
68 | let generator = TextGenerator(pipeline: pipeline, tokenizer: tokenizer)
69 | try await generator.generate(text: inputText, maxNewTokens: maxNewTokens)
70 | }
71 |
72 | func inferTokenizer() -> String? {
73 | switch (repoID ?? localModelPrefix ?? localModelDirectory?.lastPathComponent)?.lowercased() ?? "" {
74 | case let str where str.contains("llama-2-7b"):
75 | return "pcuenq/Llama-2-7b-chat-coreml"
76 | case let str where str.contains( "llama-3.2-1b"):
77 | return "meta-llama/Llama-3.2-1B"
78 | case let str where str.contains( "llama-3.2-3b"):
79 | return "meta-llama/Llama-3.2-3B"
80 | default:
81 | return nil
82 | }
83 | }
84 |
85 | /// Download a model and return the local directory URL.
86 | func downloadModel(repoID: String) async throws -> URL {
87 | let hub = HubApi(hfToken: HubApi.defaultToken())
88 | let repo = Hub.Repo(id: repoID, type: .models)
89 |
90 | let mlmodelcs = ["Llama*.mlmodelc/*", "logit*", "cache*"]
91 | let filenames = try await hub.getFilenames(from: repo, matching: mlmodelcs)
92 |
93 | let localURL = hub.localRepoLocation(repo)
94 | let localFileURLs = filenames.map {
95 | localURL.appending(component: $0)
96 | }
97 | let anyNotExists = localFileURLs.contains {
98 | !FileManager.default.fileExists(atPath: $0.path(percentEncoded: false))
99 | }
100 |
101 | // swift-transformers doesn't offer a way to know if we're up to date.
102 | let newestTimestamp = localFileURLs.filter {
103 | FileManager.default.fileExists(atPath: $0.path(percentEncoded: false))
104 | }.compactMap {
105 | let attrs = try! FileManager.default.attributesOfItem(atPath: $0.path(percentEncoded: false))
106 | return attrs[.modificationDate] as? Date
107 | }.max() ?? Date.distantFuture
108 | let lastUploadDate = Date(timeIntervalSince1970: 1723688450)
109 | let isStale = repoID == "smpanaro/Llama-2-7b-coreml" && newestTimestamp < lastUploadDate
110 |
111 | // I would rather not delete things automatically.
112 | if isStale {
113 | print("⚠️ You have an old model downloaded. Please move the following directory to the Trash and try again:")
114 | print(localURL.path())
115 | throw CLIError.staleFiles
116 | }
117 |
118 | let needsDownload = anyNotExists || isStale
119 | guard needsDownload else { return localURL }
120 |
121 | print("Downloading from \(repoID)...")
122 | if filenames.count == 0 {
123 | throw CLIError.noModelFilesFound
124 | }
125 |
126 | let downloadDir = try await hub.snapshot(from: repo, matching: mlmodelcs) { progress in
127 | let percent = progress.fractionCompleted * 100
128 | if !progress.isFinished {
129 | print("\(percent.formatted(.number.precision(.fractionLength(0))))%", terminator: "\r")
130 | fflush(stdout)
131 | }
132 | }
133 | print("Done.")
134 | print("Downloaded to \(downloadDir.path())")
135 | return downloadDir
136 | }
137 | }
138 |
139 | enum CLIError: Error {
140 | case noModelFilesFound
141 | case staleFiles
142 | }
143 |
--------------------------------------------------------------------------------
/Sources/Kit/CoreML+Extensions.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CoreML
3 |
4 | extension MLComputeUnits {
5 | var debugName: String {
6 | switch self {
7 | case .all: return "All"
8 | case .cpuAndGPU: return "CPU+GPU"
9 | case .cpuAndNeuralEngine: return "CPU+ANE"
10 | case .cpuOnly: return "CPU"
11 | @unknown default: return "Unknown"
12 | }
13 | }
14 | }
15 |
16 | extension MLMultiArray {
17 | class func emptyIOSurfaceArray(shape: [Int]) -> MLMultiArray? {
18 | guard
19 | shape.count > 0,
20 | let width = shape.last
21 | else { return nil }
22 | let height = shape[0..<(shape.count - 1)].reduce(1, *)
23 |
24 | let attributes = [kCVPixelBufferIOSurfacePropertiesKey as String: [String: Any]()] as CFDictionary
25 | var pixelBuffer: CVPixelBuffer?
26 | let status = CVPixelBufferCreate(kCFAllocatorDefault, width, height,
27 | kCVPixelFormatType_OneComponent16Half, attributes, &pixelBuffer)
28 | guard status == kCVReturnSuccess, let pixelBuffer else { return nil }
29 |
30 | return MLMultiArray(pixelBuffer: pixelBuffer, shape: shape.map { NSNumber(value: $0) })
31 | }
32 | }
33 |
47 | extension MLPredictionOptions {
48 | /// Names of any float16 outputs whose IOSurface-backed output backings were not used.
49 | func ignoredOutputBackings(in outputs: MLFeatureProvider) -> [String] {
50 | var ignored = [String]()
51 | outputBackings.forEach { name, backing in
52 | guard let multiArrayValue = outputs.featureValue(for: name)?.multiArrayValue,
53 | multiArrayValue.dataType == .float16
54 | else { return }
55 |
56 | guard let outputPixelBuffer = multiArrayValue.pixelBuffer
57 | else { preconditionFailure("output array is not pixel buffer backed: \(name)") }
58 |
59 | guard let backingPixelBuffer = (backing as? MLMultiArray)?.pixelBuffer
60 | else { preconditionFailure("output backing is not pixel buffer backed: \(name)")}
61 |
62 | if outputPixelBuffer !== backingPixelBuffer {
63 | ignored.append(name)
64 | }
65 | }
66 |
67 | return ignored
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/Sources/Kit/KVCacheProcessor.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CoreML
3 | import OSLog
4 |
5 | /// Asynchronously use a chunk's KV cache output to prepare
6 | /// its KV cache inputs for the next prediction.
7 | class KVCacheProcessor {
8 | var chunkTasks: [Task<Void, Error>?]
9 | let processorModel: MLModel
10 |
11 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "KVCacheProcessor")
12 |
13 | init(chunkCount: Int, processorModel: MLModel) {
14 | self.chunkTasks = Array(repeating: nil, count: chunkCount)
15 | self.processorModel = processorModel
16 | }
17 |
18 | convenience init(pipeline: ModelPipeline, processorModel: MLModel) {
19 | self.init(chunkCount: pipeline.chunks.count, processorModel: processorModel)
20 | }
21 |
22 | func submit(inputs: MLFeatureProvider, outputs: MLFeatureProvider, forChunk chunkIndex: Int) {
23 | let pair = ChunkInputsOutputs(inputs: inputs, outputs: outputs)
24 | chunkTasks[chunkIndex] = Task {
25 | try await withThrowingTaskGroup(of: Int.self) { group in
26 | for blockIndex in 0..<pair.inputs.blockCount() {
27 | group.addTask {
28 | let inputs = try pair.cacheProcessorInputs(forBlock: blockIndex)
29 | let options = pair.cacheProcessorOptions(forBlock: blockIndex)
30 | _ = try await self.processorModel.prediction(from: inputs, options: options)
31 | return blockIndex
32 | }
33 | }
34 | try await group.waitForAll()
35 | }
36 | }
37 | }
38 |
39 | /// Wait for the chunk's cache update (if any) to finish before using its inputs again.
40 | func wait(forChunk chunkIndex: Int) async throws {
41 | _ = try await chunkTasks[chunkIndex]?.value
42 | }
43 | }
44 |
62 | extension MLFeatureProvider {
63 | /// Infer the number of transformer blocks from the KV cache feature names.
64 | func blockCount() -> Int {
65 | let maxIndex = featureNames.filter {
66 | $0.contains("k_cache")
67 | }.compactMap {
68 | $0.components(separatedBy: "_").last
69 | }.compactMap {
70 | Int($0)
71 | }.max()
72 | guard let maxIndex else { return 0 }
73 | return maxIndex + 1
74 | }
75 | }
76 |
77 | struct ChunkInputsOutputs {
78 | let inputs: MLFeatureProvider
79 | let outputs: MLFeatureProvider
80 |
81 | func cacheProcessorInputs(forBlock blockIndex: Int) throws -> MLFeatureProvider {
82 | let inputs = [
83 | "old_k_cache": inputs.featureValue(for: "k_cache_\(blockIndex)")!.multiArrayValue!,
84 | "new_k_cache": outputs.featureValue(for: "new_k_cache_\(blockIndex)")!.multiArrayValue!,
85 | "old_v_cache": inputs.featureValue(for: "v_cache_\(blockIndex)")!.multiArrayValue!,
86 | "new_v_cache": outputs.featureValue(for: "new_v_cache_\(blockIndex)")!.multiArrayValue!,
87 | ]
88 | return try MLDictionaryFeatureProvider(dictionary: inputs)
89 | }
90 |
91 | func cacheProcessorOptions(forBlock blockIndex: Int) -> MLPredictionOptions {
92 | let opts = MLPredictionOptions()
93 | opts.outputBackings = [
94 | "updated_k_cache": inputs.featureValue(for: "k_cache_\(blockIndex)")!.multiArrayValue!,
95 | "updated_v_cache": inputs.featureValue(for: "v_cache_\(blockIndex)")!.multiArrayValue!,
96 | ]
97 | return opts
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/Sources/Kit/LogitProcessor.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CoreML
3 | import OSLog
4 |
5 | /// Pick the next token from the provided logits.
6 | class LogitProcessor {
7 | let model: DeferredModel
8 |
9 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "LogitProcessor")
10 |
11 | init(model: DeferredModel) {
12 | self.model = model
13 | }
14 |
15 | func load() {
16 | self.model.load()
17 | }
18 |
19 | func unload() {
20 | self.model.unload()
21 | }
22 |
23 | // func multinomial(logits: MLMultiArray) async throws -> Int {
24 | // }
25 |
26 | func argmax(logits: [MLMultiArray], index: Int? = nil) async throws -> Int {
27 | let state = signposter.beginInterval("Sample", "Argmax")
28 | defer { signposter.endInterval("Sample", state) }
29 |
30 | let outputs = try await predict(logits: logits)
31 | let argmaxArray = MLShapedArray<Int32>(outputs.featureValue(for: "argmax")!.multiArrayValue!)
32 | let prediction = argmaxArray[0, index ?? (argmaxArray.shape[1] - 1)].scalar ?? -1
33 | return Int(prediction)
34 | }
35 |
36 | fileprivate func predict(logits: [MLMultiArray]) async throws -> MLFeatureProvider {
37 | let keysValues = logits.enumerated().map { (index, array) in
38 | return (logits.count == 1 ? "logits" : "logits_\(index)", array)
39 | }
40 | let inputs = try MLDictionaryFeatureProvider(
41 | dictionary: Dictionary(keysValues, uniquingKeysWith: { (first, _) in first })
42 | )
43 | return try await model.model!.prediction(from: inputs)
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/Sources/Kit/ModelPipeline.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CoreML
3 | import OSLog
4 |
5 | /// ModelPipeline manages the forward pass for an LLM that
6 | /// has been split across many MLModels.
7 | class ModelPipeline {
8 | let chunks: [PipelineChunk]
9 | var inferenceConfiguration: PipelineInferenceConfiguration? // Can't retrieve this until models are loaded.
10 | let cacheProcessorModel: DeferredModel
11 | let logitProcessor: LogitProcessor
12 |
13 | var loadableProcessors: [Loadable] {
14 | [cacheProcessorModel, logitProcessor]
15 | }
16 |
17 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "ModelPipeline")
18 |
19 | init(chunks: [PipelineChunk], cacheProcessor: DeferredModel, logitProcessor: LogitProcessor) {
20 | self.chunks = chunks
21 | precondition(chunks.count > 0)
22 | self.cacheProcessorModel = cacheProcessor
23 | self.logitProcessor = logitProcessor
24 | }
25 |
26 | /// Load the pipeline gradually to minimize resource usage
27 | /// during the initial load and model compilation/specialization.
28 | fileprivate func prewarm() {
29 | print("Compiling models: ", terminator: "")
30 | fflush(stdout)
31 |
32 | // No need for signposts, should be fast.
33 | loadableProcessors.forEach {
34 | $0.load()
35 | $0.unload()
36 | }
37 |
38 | chunks.enumerated().forEach { (i, chunk) in
39 | signposter.withIntervalSignpost("Prepare", "Warm Chunk \(i)") {
40 | chunk.load()
41 | chunk.unload()
42 | }
43 | print("*", terminator: "")
44 | fflush(stdout)
45 | }
46 | print()
47 | }
48 |
49 | func load() throws {
50 | prewarm()
51 | print("Loading models : ", terminator: "")
52 | fflush(stdout)
53 |
54 | loadableProcessors.forEach {
55 | $0.load()
56 | }
57 |
58 | chunks.enumerated().forEach { (i, chunk) in
59 | signposter.withIntervalSignpost("Prepare", "Load Chunk \(i)") {
60 | chunk.load()
61 | }
62 | print("*", terminator: "")
63 | fflush(stdout)
64 | }
65 | print()
66 |
67 | inferenceConfiguration = .init(from: chunks.compactMap { $0.model })
68 | if inferenceConfiguration == nil {
69 | // Unable to infer the correct model parameters from the model inputs.
70 | // We won't be able to predict.
71 | throw PipelineError.unsupportedInferenceConfiguration
72 | }
73 | }
74 |
75 | func predict(tokens initialTokens: [Int], maxNewTokens: Int) throws -> AsyncThrowingStream<Prediction, Error> {
76 | guard let inferenceConfiguration else {
77 | throw PipelineError.unsupportedInferenceConfiguration
78 | }
79 |
80 | let arrayStore = MultiArrayStore(for: self)
81 | let kvCacheProcessor = KVCacheProcessor(pipeline: self, processorModel: cacheProcessorModel.model!)
82 |
83 | return AsyncThrowingStream { continuation in
84 | let inputLength = inferenceConfiguration.inputLength
85 | var promptChunks = stride(from: 0, to: initialTokens.count, by: inputLength).map {
86 | Array(initialTokens[$0..<min($0 + inputLength, initialTokens.count)])
87 | }
88 |
89 | var tokens = [Int]()
90 | let maxTokens = initialTokens.count + maxNewTokens
91 |
92 | let promptTimer = CodeTimer()
93 | var promptLatency: Measurement<UnitDuration>? = nil
94 | while tokens.count < maxTokens {
95 | let timer = CodeTimer()
96 | let tokenSignpostState = self.signposter.beginInterval("Predict", id: self.signposter.makeSignpostID(), "Token #\(tokens.count)")
97 |
98 | // Do a forward pass of the model.
99 | var logitChunks: [MLMultiArray] = []
100 | for (i, chunk) in self.chunks.enumerated() {
101 | let model = chunk.model!
102 |
103 | // Wait for the KV cache to be asynchronously updated.
104 | try await kvCacheProcessor.wait(forChunk: i)
105 |
106 | let inputs = try arrayStore.featureProvider(forChunk: i, model: model, tokens: tokens)
107 |
108 | let options = MLPredictionOptions()
109 | options.outputBackings = arrayStore.outputBackings(forChunk: i, model: model)
110 |
111 | let predictState = self.signposter.beginInterval("Predict", id: self.signposter.makeSignpostID(), "Chunk \(i)")
112 | let outputs = try await model.prediction(from: inputs, options: options)
113 | self.signposter.endInterval("Predict", predictState)
114 | arrayStore.update(outputs: outputs, forChunk: i) // Using output backings should make this ~a noop.
115 |
116 | // Update the cache (async) each time a full set of input tokens is processed.
117 | // e.g. if the model takes in 64 input tokens at once, update when tokens
118 | // is length 64, 128, 192, etc.
119 | if tokens.count % inputLength == 0 {
120 | kvCacheProcessor.submit(inputs: inputs, outputs: outputs, forChunk: i)
121 | }
122 |
123 | // Sometimes logits are returned in chunks along the vocab dimension.
124 | let logitFeatures = outputs.featureNames
125 | .filter { $0.starts(with: "logit") }
126 | .sorted { ($0.trailingNumberSuffix() ?? -1) < ($1.trailingNumberSuffix() ?? -1) }
127 | logitChunks = logitFeatures.map { outputs.featureValue(for: $0)!.multiArrayValue! }
128 | }
129 |
130 | if !promptChunks.isEmpty {
131 | for token in promptChunks.removeFirst() {
132 | tokens.append(token)
133 | continuation.yield(Prediction(newToken: token, allTokens: tokens, latency: nil, promptLatency: nil))
134 | }
135 | promptLatency = promptTimer.elapsed()
136 | } else {
137 | let newTokenIndex = tokens.isEmpty ? 0 : (tokens.count - 1) % inputLength
138 | let newToken = try await self.logitProcessor.argmax(logits: logitChunks, index: newTokenIndex)
139 | tokens.append(newToken)
140 | continuation.yield(Prediction(newToken: newToken, allTokens: tokens, latency: timer.elapsed(), promptLatency: promptLatency))
141 | promptLatency = nil
142 | }
143 |
144 | self.signposter.endInterval("Predict", tokenSignpostState, "\(tokens.last!)")
145 | }
146 |
147 | continuation.finish()
148 | }
149 | }
150 | }
151 |
152 |
153 | extension ModelPipeline: CustomDebugStringConvertible {
154 | var debugDescription: String {
155 | let fileName = chunks.first?.fileInfo.displayModelName ?? ""
156 | return"\(Self.self) \(fileName) (\(chunks.count) chunks)"
157 | }
158 | }
159 |
160 | extension ModelPipeline {
161 | /// Creates a pipeline from the mlmodelc files in the given folder.
162 | /// Model files should follow the format: `${MODEL_PREFIX}_chunk${CHUNK_NUMBER}.mlmodelc`
163 | /// Does not load the model.
164 | class func from(
165 | folder: URL,
166 | modelPrefix: String?,
167 | cacheProcessorModelName: String,
168 | logitProcessorModelName: String,
169 | primaryCompute: MLComputeUnits = .cpuAndNeuralEngine,
170 | chunkLimit: Int? = nil
171 | ) throws -> ModelPipeline {
172 | let manager = FileManager.default
173 | let contents = try manager.contentsOfDirectory(atPath: folder.path(percentEncoded: false))
174 |
175 | let chunkFiles = contents
176 | .compactMap { ChunkFileInfo(url: folder.appending(path: $0)) }
177 | .filter { $0.url.pathExtension == "mlmodelc" }
178 | .filter {
179 | if let modelPrefix { $0.modelPrefix.hasPrefix(modelPrefix) }
180 | else { true }
181 | }
182 | .sorted(by: { $0.chunkNumber < $1.chunkNumber })
183 |
184 | let uniquePrefixes = Set(chunkFiles.map { $0.modelPrefix })
185 | if uniquePrefixes.count > 1 {
186 | throw PipelineError.ambiguousModelPath(possiblePrefixes: Array(uniquePrefixes))
187 | }
188 |
189 | let chunks = chunkFiles.enumerated()
190 | .filter { (i, _) in
191 | // Allow limiting the number of chunks for debugging.
192 | if i == 0 || i == chunkFiles.count - 1 { return true }
193 | if let chunkLimit { return i <= chunkLimit }
194 | return true
195 | }
196 | .map { (i, chunkFile) in
197 | let config = MLModelConfiguration()
198 | // The first chunk has operations that cannot run on ANE.
199 | config.computeUnits = i == 0 ? .cpuOnly : primaryCompute
200 | config.modelDisplayName = "Chunk \(chunkFile.chunkNumber)"
201 | return PipelineChunk(fileInfo: chunkFile, configuration: config)
202 | }
203 |
204 | if chunks.count == 0 {
205 | throw PipelineError.modelChunksNotFound
206 | }
207 |
208 | // Model for updating KV caches.
209 | let cacheProcessorURL = folder.appending(component: cacheProcessorModelName)
210 | let cacheModelConfig = MLModelConfiguration()
211 | cacheModelConfig.computeUnits = primaryCompute
212 | cacheModelConfig.modelDisplayName = "Cache Processor"
213 | let cacheProcessor = DeferredModel(url: cacheProcessorURL, configuration: cacheModelConfig)
214 |
215 | // Model for choosing the next token.
216 | let logitProcessorURL = folder.appending(component: logitProcessorModelName)
217 | let logitProcessorModelConfig = MLModelConfiguration()
218 | logitProcessorModelConfig.computeUnits = primaryCompute
219 | logitProcessorModelConfig.modelDisplayName = "Logit Processor"
220 | let logitProcessor = LogitProcessor(model: DeferredModel(url: logitProcessorURL, configuration: logitProcessorModelConfig))
221 |
222 | return ModelPipeline(chunks: chunks, cacheProcessor: cacheProcessor, logitProcessor: logitProcessor)
223 | }
224 | }
225 |
226 | class PipelineChunk {
227 | let fileInfo: ChunkFileInfo
228 | let configuration: MLModelConfiguration
229 | var model: MLModel?
230 |
231 | convenience init?(url: URL, configuration: MLModelConfiguration) {
232 | guard let fileInfo = ChunkFileInfo(url: url) else { return nil }
233 | self.init(fileInfo: fileInfo, configuration: configuration)
234 | }
235 |
236 | init(fileInfo: ChunkFileInfo, configuration: MLModelConfiguration) {
237 | self.fileInfo = fileInfo
238 | self.configuration = configuration
239 | }
240 |
241 | func load() {
242 | guard model == nil else { return }
243 | model = try! MLModel(contentsOf: fileInfo.url, configuration: configuration)
244 | }
245 |
246 | func unload() {
247 | model = nil
248 | }
249 | }
250 |
251 | extension PipelineChunk: CustomDebugStringConvertible {
252 | var debugDescription: String {
253 | return "\(Self.self) \(fileInfo.chunkNumber) (\(configuration.computeUnits.debugName))"
254 | }
255 | }
256 |
257 | class DeferredModel {
258 | let url: URL
259 | let configuration: MLModelConfiguration
260 |
261 | var model: MLModel?
262 |
263 | init(url: URL, configuration: MLModelConfiguration) {
264 | self.url = url
265 | self.configuration = configuration
266 | }
267 |
268 | func load() {
269 | model = try! MLModel(contentsOf: url, configuration: configuration)
270 | }
271 |
272 | func unload() {
273 | model = nil
274 | }
275 | }
276 |
277 | struct ChunkFileInfo {
278 | let url: URL
279 | let fileName: String
280 | let modelPrefix: String
281 | let chunkNumber: Int
282 |
283 | init?(url: URL) {
284 | self.url = url
285 | self.fileName = url.lastPathComponent
286 |
287 | var split = url.deletingPathExtension().lastPathComponent.split(separator: "_")
288 | guard
289 | let chunkString = split.popLast(),
290 | let chunkNumber = Int(chunkString.replacingOccurrences(of: "chunk", with: ""))
291 | else {
292 | return nil
293 | }
294 |
295 | self.modelPrefix = split.joined(separator: "_") + "_"
296 | self.chunkNumber = chunkNumber
297 | }
298 |
299 | var displayModelName: String {
300 | // Drop the last _
301 | String(modelPrefix.dropLast())
302 | }
303 | }
304 |
305 | enum PipelineError: Error {
306 | case unsupportedInferenceConfiguration
307 | case cacheProcessorNotFound
308 | case modelChunksNotFound
309 | case ambiguousModelPath(possiblePrefixes: [String])
310 | case notImplementedError
311 | }
312 |
313 | struct Prediction {
314 | let newToken: Int
315 | let allTokens: [Int]
316 | let latency: Measurement<UnitDuration>?
317 | let promptLatency: Measurement<UnitDuration>?
318 | }
319 |
320 | struct CodeTimer {
321 | let start = CFAbsoluteTimeGetCurrent()
322 |
323 | func elapsed() -> Measurement<UnitDuration> {
324 | let seconds = CFAbsoluteTimeGetCurrent() - start
325 | return Measurement(value: seconds, unit: .seconds)
326 | }
327 | }
328 |
329 | protocol Loadable {
330 | func load()
331 | func unload()
332 | }
333 |
334 | extension LogitProcessor: Loadable {}
335 | extension DeferredModel: Loadable {}
336 | extension PipelineChunk: Loadable {}
337 |
338 | extension String {
339 | func trailingNumberSuffix() -> Int? {
340 | guard let number = self.split(separator: "_").last else { return nil }
341 | return Int(number)
342 | }
343 | }
344 |
--------------------------------------------------------------------------------
/Sources/Kit/MultiArrayStore.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CoreML
3 | import OSLog
4 |
5 | /// Store and make accessible MLMultiArrays that will be reused
6 | /// in either subsequent chunks or subsequent forward passes.
7 | class MultiArrayStore {
8 | // Arrays that are used as inputs/outputs for multiple chunks.
9 | // Names must be unique across the whole pipeline.
10 | let sharedArrays: [String: MLMultiArray]
11 | // Arrays that are only used with one chunk.
12 | // Names must only be unique within a chunk. e.g. KV cache inputs/outputs.
13 | let chunkArrays: [[String: MLMultiArray]]
14 |
15 | // Map from output names to a different key name to use for the output
16 | // backing. Useful when the input is consumed and the output is the same shape.
17 | var outputBackingMapping: [String: String]
18 |
19 | let signposter = OSSignposter(subsystem: "com.stephenpanaro.llm-cli", category: "MultiArrayStore")
20 |
21 | init(sharedArrays: [String : MLMultiArray],
22 | chunkArrays: [[String : MLMultiArray]],
23 | outputBackingMapping: [String: String] = [:]
24 | ) {
25 | self.sharedArrays = sharedArrays
26 | self.chunkArrays = chunkArrays
27 | self.outputBackingMapping = outputBackingMapping
28 | }
29 |
30 | func inputFeatures(forChunk chunkIndex: Int, model: MLModel) -> [String: MLMultiArray] {
31 | Dictionary(uniqueKeysWithValues: model.modelDescription.inputDescriptionsByName.keys.compactMap { name in
32 | self.array(named: name, for: chunkIndex).map { arr in (name, arr) }
33 | })
34 | }
35 |
36 | func outputBackings(forChunk chunkIndex: Int, model: MLModel) -> [String: MLMultiArray] {
37 | Dictionary(uniqueKeysWithValues: model.modelDescription.outputDescriptionsByName.keys.compactMap { name in
38 | let backingName = outputBackingMapping[name] ?? name
39 | return self.array(named: backingName, for: chunkIndex).map { arr in (name, arr) }
40 | })
41 | }
42 |
43 | func update(outputs: MLFeatureProvider, forChunk chunkIndex: Int) {
44 | let state = signposter.beginInterval("Update Outputs", "Chunk \(chunkIndex)")
45 | defer { signposter.endInterval("Update Outputs", state) }
46 |
47 | outputs.featureNames.forEach { name in
48 | // Only float16 buffers are cached.
49 | guard let multiArrayValue = outputs.featureValue(for: name)?.multiArrayValue,
50 | multiArrayValue.dataType == .float16
51 | else { return }
52 |
53 | guard let pixelBuffer = multiArrayValue.pixelBuffer
54 | else { preconditionFailure("output array is not pixel buffer backed: \(name)") }
55 |
56 | let cacheName = outputBackingMapping[name] ?? name
57 | guard let cachePixelBuffer = self.array(named: cacheName, for: chunkIndex)?.pixelBuffer
58 | else { preconditionFailure("output array is not present in the cache: \(name)") }
59 |
60 | // TODO: Gracefully degrade by manually copying the array.
61 | precondition(pixelBuffer === cachePixelBuffer, "output backing not used: \(name)")
62 | }
63 | }
64 |
65 | fileprivate func array(named name: String, for chunkIndex: Int) -> MLMultiArray? {
66 | if let sharedArr = sharedArrays[name] {
67 | return sharedArr
68 | } else if let chunkArr = chunkArrays[chunkIndex][name] {
69 | return chunkArr
70 | }
71 | return nil
72 | }
73 | }
74 |
75 | extension MultiArrayStore {
76 | convenience init(for pipeline: ModelPipeline) {
77 | var sharedArrays = [String: MLMultiArray]()
78 | var chunkArrays = [[String: MLMultiArray]]()
79 |
80 | // For each output key, reuse the input array named by the value.
81 | let outputBackingMapping = ["new_x": "x"]
82 |
83 | // KV caches, cos + sin for RoPE, attention mask
84 | pipeline.chunks.forEach { chunk in
85 | let model = chunk.model!
86 | var floatShapes = [String: [Int]]()
87 |
88 | let modelDescription = model.modelDescription
89 | let allDescriptions = modelDescription.inputDescriptionsByName.merging(
90 | modelDescription.outputDescriptionsByName, uniquingKeysWith: { a,b in a })
91 |
92 | allDescriptions.forEach { (name, desc) in
93 | guard let constraint = desc.multiArrayConstraint else { return }
94 | if constraint.dataType != .float16 { return }
95 | if outputBackingMapping.keys.contains(name) { return }
96 | floatShapes[name] = constraint.shape.map { $0.intValue }
97 | }
98 |
99 | let isShared = { (key: String) -> Bool in !key.contains("cache") }
100 | let sharedShapes = floatShapes.filter { isShared($0.key) }
101 | let chunkShapes = floatShapes.filter { !isShared($0.key) }
102 |
103 | // TODO: Sort so K caches are allocated close to other K caches. Ditto for V.
104 | // See if it has a performance impact.
105 |
106 | sharedShapes.forEach { (name, shape) in
107 | if sharedArrays[name] == nil {
108 | sharedArrays[name] = MLMultiArray.emptyIOSurfaceArray(shape: shape)!
109 | }
110 | }
111 |
112 | var currChunkArrays = [String:MLMultiArray]()
113 | chunkShapes.forEach { (name, shape) in
114 | currChunkArrays[name] = MLMultiArray.emptyIOSurfaceArray(shape: shape)!
115 | }
116 | chunkArrays.append(currChunkArrays)
117 | }
118 |
119 | self.init(sharedArrays: sharedArrays, chunkArrays: chunkArrays, outputBackingMapping: outputBackingMapping)
120 | }
121 | }
122 |
123 |
124 | extension MultiArrayStore {
125 | func featureProvider(forChunk chunkIndex: Int, model: MLModel, tokens: [Int]) throws -> MLFeatureProvider {
126 | let state = signposter.beginInterval("Prepare Features", "Chunk \(chunkIndex)")
127 | defer { signposter.endInterval("Prepare Features", state) }
128 |
129 | var cacheFeatures = inputFeatures(forChunk: chunkIndex, model: model)
130 |
131 | // Input IDs is populated in concert with KV cache updates.
132 | //
133 | // 0 1 2 3 4
134 | // _ _ _ _ _ _ _ _ _ _ _ _ _
135 | // ┌─────────────────────────┬────────┐
136 | // └───────────Cache─────────┴─Inputs─┘
137 | //
138 | //
139 | // 5 _ _ _ _
140 | // _ _ _ _ _ _ _ _ 0 1 2 3 4
141 | // ┌─────────────────────────┬────────┐
142 | // └───────────Cache─────────┴─Inputs─┘
143 | //
144 | //
145 | // 5 6 _ _ _
146 | // _ _ _ _ _ _ _ _ 0 1 2 3 4
147 | // ┌─────────────────────────┬────────┐
148 | // └───────────Cache─────────┴─Inputs─┘
149 | //
150 | // ┌ ─ ─ ─ ─ ─ ─ ┐
151 | // Same for 7-9
152 | // └ ─ ─ ─ ─ ─ ─ ┘
153 | //
154 | // 5 6 7 8 9
155 | // _ _ _ _ _ _ _ _ 0 1 2 3 4
156 | // ┌─────────────────────────┬────────┐
157 | // └───────────Cache─────────┴─Inputs─┘
158 | //
159 | // ◀━━━━━━━━━━━━━Shift!━━━━━━━━━━━━━━━
160 | //
161 | // A _ _ _ _
162 | // _ _ _ 0 1 2 3 4 5 6 7 8 9
163 | // ┌─────────────────────────┬────────┐
164 | // └───────────Cache─────────┴─Inputs─┘
165 | //
166 | // ┌ ─ ─ ─ ─ ─
167 | // Repeat │
168 | // └ ─ ─ ─ ─ ─
169 | var padCount = 0
170 | let inputDescriptions = model.modelDescription.inputDescriptionsByName
171 | if let inputIDsConstraint = inputDescriptions["input_ids"]?.multiArrayConstraint {
172 | let inputShape = inputIDsConstraint.shape.map { $0.intValue }
173 | let inputLength = inputShape.last!
174 |
175 | // For inputLength 64, tokens.count -> suffixLength:
176 | // 0 -> 0, 5 -> 5, 64 -> 64, 65 -> 1, 128 -> 64
177 | let suffixLength = tokens.isEmpty ? 0 : (tokens.count - 1) % inputLength + 1
178 | let inputTokens = tokens.suffix(suffixLength).map { Int32($0) }
179 | padCount = max(0, inputLength - inputTokens.count)
180 | let paddedInputTokens = inputTokens + Array(repeating: Int32(0), count: padCount)
181 | let inputIDs = MLShapedArray(scalars: paddedInputTokens, shape: inputShape)
182 | cacheFeatures["input_ids"] = MLMultiArray(inputIDs)
183 | }
184 |
185 | if inputDescriptions["full_sequence_length"]?.multiArrayConstraint != nil {
186 | // For example: [Cache0 Cache1][Input2 Input3 Pad0]
187 | // The causal mask prevents Pad0 from impacting
188 | // the InputN token predictions. We must include it
189 | // here so RoPE is correct for the InputN tokens.
190 | let fullSequenceLength = MLShapedArray(repeating: Int32(tokens.count + padCount), shape: [1])
191 | cacheFeatures["full_sequence_length"] = MLMultiArray(fullSequenceLength)
192 | }
193 |
194 | return try MLDictionaryFeatureProvider(dictionary: cacheFeatures)
195 | }
196 | }
197 |
--------------------------------------------------------------------------------
/Sources/Kit/PipelineInferenceConfiguration.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CoreML
3 |
4 | struct PipelineInferenceConfiguration {
5 | let vocabSize: Int
6 | let inputLength: Int // query length
7 | let contextLength: Int // key + value length
8 |
9 | var cacheLength: Int {
10 | contextLength - inputLength
11 | }
12 | }
13 |
14 | extension PipelineInferenceConfiguration {
15 | init?(from models: [MLModel]) {
16 | guard models.count > 2 else { return nil }
17 |
18 | guard let first = models.first, let last = models.last
19 | else { return nil }
21 | let innerModels = models[1..<(models.count - 1)]
--------------------------------------------------------------------------------
/Sources/Kit/Swift+Extensions.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | extension Array where Element == Double {
4 | func sum() -> Element {
5 | return self.reduce(0, +)
6 | }
7 |
8 | func mean() -> Element {
9 | return self.sum() / Element(self.count)
10 | }
11 |
12 | func stdev() -> Element {
13 | let mean = self.mean()
14 | let v = self.reduce(0, { $0 + ($1-mean)*($1-mean) })
15 | return sqrt(v / (Element(self.count) - 1))
16 | }
17 | }
18 |
19 | extension AsyncThrowingStream {
20 | /// Initialize a stream that runs its body closure asynchronously.
21 | public init(
22 | _ elementType: Element.Type = Element.self,
23 | bufferingPolicy limit: AsyncThrowingStream<Element, Failure>.Continuation.BufferingPolicy = .unbounded,
24 | _ build: @Sendable @escaping (AsyncThrowingStream<Element, Failure>.Continuation) async throws -> Void
25 | ) where Failure == Error {
26 | self = AsyncThrowingStream(elementType, bufferingPolicy: limit) { continuation in
27 | let task = Task {
28 | try await build(continuation)
29 | }
30 |
31 | continuation.onTermination = { _ in
32 | task.cancel()
33 | }
34 | }
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/Sources/Kit/TextGenerator.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import Tokenizers
3 |
4 | /// TextGenerator uses an LLM `ModelPipeline` to generate text.
5 | class TextGenerator {
6 | let pipeline: ModelPipeline
7 | let tokenizer: Tokenizer
8 | // var strategy: Strategy = .argmax // etc.
9 |
10 | init(pipeline: ModelPipeline, tokenizer: Tokenizer) {
11 | self.pipeline = pipeline
12 | self.tokenizer = tokenizer
13 | }
14 |
15 | func generate(text: String, maxNewTokens: Int) async throws {
16 |
17 | let loadTimer = CodeTimer()
18 | try pipeline.load()
19 | let loadDuration = loadTimer.elapsed()
20 |
21 | let tokens = tokenizer.encode(text: text)
22 |
23 | var predictions = [Prediction]()
24 | tokens.forEach { print($0, terminator: " ") }
25 | fflush(stdout)
26 |
27 | for try await prediction in try pipeline.predict(tokens: tokens, maxNewTokens: maxNewTokens) {
28 | predictions.append(prediction)
29 | print(prediction.newToken, terminator: " ")
30 | fflush(stdout)
31 | }
32 | print("\n")
33 |
34 | print(tokenizer.decode(tokens: predictions.last?.allTokens ?? tokens))
35 | print()
36 |
37 | print("Compile + Load: \(loadDuration.converted(to: .seconds).value.formatted(.number.precision(.fractionLength(2)))) sec")
38 | let numberFormat = FloatingPointFormatStyle.number.precision(.fractionLength(2))
39 |
40 | let promptLatencies = predictions.compactMap { $0.promptLatency?.converted(to: .milliseconds).value }
41 | if !promptLatencies.isEmpty {
42 | let averagePrompt = Measurement(value: promptLatencies.mean(), unit: UnitDuration.milliseconds)
43 | print("Prompt : \(averagePrompt.value.formatted(numberFormat)) ms")
44 | }
45 |
46 | print("Generate :", terminator: " ")
47 |
48 | let latencies = predictions.compactMap { $0.latency?.converted(to: .milliseconds).value }
49 | let average = Measurement(value: latencies.mean(), unit: UnitDuration.milliseconds)
50 | let stdev = Measurement(value: latencies.stdev(), unit: UnitDuration.milliseconds)
51 | print("\(average.value.formatted(numberFormat)) +/- \(stdev.value.formatted(numberFormat)) ms / token")
52 |
53 | let throughputs = predictions.compactMap { $0.latency?.converted(to: .seconds).value }.map { 1 / $0 }
54 | let averageThroughput = Measurement(value: throughputs.mean(), unit: UnitDuration.seconds)
55 | let stdevThroughput = Measurement(value: throughputs.stdev(), unit: UnitDuration.seconds)
56 | print(" \(averageThroughput.value.formatted(numberFormat)) +/- \(stdevThroughput.value.formatted(numberFormat)) token / sec")
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/Sources/Kit/Transformers+Extensions.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import Tokenizers
3 | import Hub
4 |
5 | extension AutoTokenizer {
6 | /// Load the tokenizer from a cache when offline.
7 | /// Workaround until https://github.com/huggingface/swift-transformers/issues/39
8 | static func from(
9 | cached model: String,
10 | hubApi: HubApi = .shared
11 | ) async throws -> Tokenizer {
12 | // Network first since there is no cache invalidation.
13 | do {
14 | let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
15 | guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
16 | let tokenizerData = try await config.tokenizerData
17 |
18 | return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
19 | } catch {
20 | let modelFolder = hubApi.localRepoLocation(.init(id: model))
21 | return try await from(modelFolder: modelFolder, hubApi: hubApi)
22 | }
23 | }
24 | }
25 |
26 | // Not public, so copied over.
27 | enum TokenizerError : Error {
28 | case missingConfig
29 | }
30 |
31 | extension HubApi {
32 | static var defaultTokenLocation: URL {
33 | FileManager.default.homeDirectoryForCurrentUser.appending(path: ".cache/huggingface/token")
34 | }
35 |
36 | static func defaultToken() -> String? {
37 | if let envToken = ProcessInfo.processInfo.environment["HF_TOKEN"] {
38 | return envToken
39 | }
40 | return try? String(contentsOf: defaultTokenLocation, encoding: .utf8)
41 | }
42 | }
43 |
--------------------------------------------------------------------------------