├── .swift-format ├── Moshi ├── Assets.xcassets │ ├── Contents.json │ ├── AccentColor.colorset │ │ └── Contents.json │ └── AppIcon.appiconset │ │ └── Contents.json ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json ├── Info.plist ├── moshi.entitlements ├── moshiApp.swift ├── DeviceStat.swift ├── AudioRT.swift ├── ModelView.swift └── ContentView.swift ├── moshi.xcodeproj ├── project.xcworkspace │ ├── contents.xcworkspacedata │ └── xcshareddata │ │ └── swiftpm │ │ └── Package.resolved └── xcshareddata │ └── xcschemes │ ├── moshi-cli.xcscheme │ └── Moshi.xcscheme ├── .github └── PULL_REQUEST_TEMPLATE.md ├── MoshiLib ├── MoshiLib.docc │ └── MoshiLib.md ├── MoshiLib.h ├── Utils.swift ├── ASR.swift ├── Streaming.swift ├── Mimi.swift ├── Qwen2.swift ├── KVCache.swift ├── Quantization.swift ├── Perf.swift ├── Seanet.swift ├── Conv.swift ├── Transformer.swift └── LM.swift ├── MoshiTests └── moshiTests.swift ├── .gitignore ├── MoshiLibTests └── moshi_libTests.swift ├── scripts ├── sp.py └── eval_mlx.py ├── Makefile ├── MoshiUITests ├── moshiUITestsLaunchTests.swift └── moshiUITests.swift ├── LICENSE ├── MoshiCLI ├── RunHelium.swift ├── Audio.swift ├── AudioRT.swift ├── RunMimi.swift ├── RunMoshi.swift └── CLI.swift ├── README.md └── CONTRIBUTING.md /.swift-format: -------------------------------------------------------------------------------- 1 | { 2 | "lineLength": 100, 3 | "indentation": { 4 | "spaces": 4 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Moshi/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Moshi/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : 
"xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /moshi.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Moshi/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Moshi/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | UIFileSharingEnabled 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Checklist 2 | 3 | - [ ] Read CONTRIBUTING.md, and accept the CLA by including the provided snippet. We will not accept PR without this. 4 | - [ ] Test manually that the app is still working. 5 | 6 | ## PR Description 7 | 8 | 9 | -------------------------------------------------------------------------------- /MoshiLib/MoshiLib.docc/MoshiLib.md: -------------------------------------------------------------------------------- 1 | # ``MoshiLib`` 2 | 3 | Summary 4 | 5 | ## Overview 6 | 7 | Text 8 | 9 | ## Topics 10 | 11 | ### Group 12 | 13 | - ``Symbol`` 14 | -------------------------------------------------------------------------------- /MoshiTests/moshiTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // moshiTests.swift 3 | // moshiTests 4 | // 5 | // Created by Laurent on 20/11/2024. 
6 | // 7 | 8 | import Testing 9 | 10 | struct moshiTests { 11 | 12 | @Test func example() async throws { 13 | // Write your test here and use APIs like `#expect(...)` to check expected conditions. 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .helix 2 | .vscode 3 | *.dylib 4 | *.so 5 | *.swp 6 | *.swo 7 | trace-*.json 8 | .DS_Store 9 | .idea/* 10 | __pycache__ 11 | build 12 | buildServer.json 13 | xcuserdata 14 | model*.safetensors 15 | mimi-*.safetensors 16 | moshi-*.safetensors 17 | codes-*.safetensors 18 | asr-*.safetensors 19 | moshi-out.wav 20 | moshi-trace.json 21 | bria*safetensors 22 | bria*wav 23 | -------------------------------------------------------------------------------- /MoshiLibTests/moshi_libTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // moshi_libTests.swift 3 | // moshi-libTests 4 | // 5 | // Created by Laurent on 20/11/2024. 6 | // 7 | 8 | import Testing 9 | 10 | @testable import moshi_lib 11 | 12 | struct moshi_libTests { 13 | 14 | @Test func example() async throws { 15 | // Write your test here and use APIs like `#expect(...)` to check expected conditions. 
16 | } 17 | 18 | } 19 | -------------------------------------------------------------------------------- /scripts/sp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sentencepiece 4 | 5 | home_dir = os.path.expanduser("~") 6 | text_tokenizer = sentencepiece.SentencePieceProcessor(home_dir + "/tmp/tokenizer_spm_32k_3.model") 7 | 8 | vocab = {} 9 | for id in range(text_tokenizer.vocab_size()): 10 | vocab[id] = text_tokenizer.id_to_piece(id) 11 | with open(home_dir + "/tmp/tokenizer_spm_32k_3.json", "w") as json_file: 12 | json.dump(vocab, json_file) 13 | -------------------------------------------------------------------------------- /MoshiLib/MoshiLib.h: -------------------------------------------------------------------------------- 1 | // 2 | // moshi_lib.h 3 | // moshi-lib 4 | // 5 | // Created by Laurent on 20/11/2024. 6 | // 7 | 8 | #import 9 | 10 | //! Project version number for moshi_lib. 11 | FOUNDATION_EXPORT double MoshiLibVersionNumber; 12 | 13 | //! Project version string for moshi_lib. 14 | FOUNDATION_EXPORT const unsigned char MoshiLibVersionString[]; 15 | 16 | // In this header, you should import all the public headers of your framework using statements like #import 17 | 18 | 19 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: format run-1b run-asr build 2 | 3 | format: 4 | swift-format format --in-place --recursive . 
5 | 6 | run-1b: build 7 | ./build/Build/Products/Release/MoshiCLI run 8 | 9 | run-asr: build 10 | ./build/Build/Products/Release/MoshiCLI run-asr 11 | 12 | run-mimi: build 13 | ./build/Build/Products/Release/MoshiCLI run-mimi 14 | 15 | run-helium: build 16 | ./build/Build/Products/Release/MoshiCLI run-helium 17 | 18 | run-qwen: build 19 | ./build/Build/Products/Release/MoshiCLI run-qwen 20 | 21 | build: 22 | xcodebuild -scheme moshi-cli -derivedDataPath ./build 23 | -------------------------------------------------------------------------------- /Moshi/moshi.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.device.audio-input 10 | 11 | com.apple.security.files.downloads.read-write 12 | 13 | com.apple.security.files.user-selected.read-write 14 | 15 | com.apple.security.network.client 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /MoshiUITests/moshiUITestsLaunchTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // moshiUITestsLaunchTests.swift 3 | // moshiUITests 4 | // 5 | // Created by Laurent on 20/11/2024. 
6 | // 7 | 8 | import XCTest 9 | 10 | final class moshiUITestsLaunchTests: XCTestCase { 11 | 12 | override class var runsForEachTargetApplicationUIConfiguration: Bool { 13 | true 14 | } 15 | 16 | override func setUpWithError() throws { 17 | continueAfterFailure = false 18 | } 19 | 20 | @MainActor 21 | func testLaunch() throws { 22 | let app = XCUIApplication() 23 | app.launch() 24 | 25 | // Insert steps here to perform after app launch but before taking a screenshot, 26 | // such as logging into a test account or navigating somewhere in the app 27 | 28 | let attachment = XCTAttachment(screenshot: app.screenshot()) 29 | attachment.name = "Launch Screen" 30 | attachment.lifetime = .keepAlways 31 | add(attachment) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /MoshiUITests/moshiUITests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // moshiUITests.swift 3 | // moshiUITests 4 | // 5 | // Created by Laurent on 20/11/2024. 6 | // 7 | 8 | import XCTest 9 | 10 | final class moshiUITests: XCTestCase { 11 | 12 | override func setUpWithError() throws { 13 | // Put setup code here. This method is called before the invocation of each test method in the class. 14 | 15 | // In UI tests it is usually best to stop immediately when a failure occurs. 16 | continueAfterFailure = false 17 | 18 | // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this. 19 | } 20 | 21 | override func tearDownWithError() throws { 22 | // Put teardown code here. This method is called after the invocation of each test method in the class. 23 | } 24 | 25 | @MainActor 26 | func testExample() throws { 27 | // UI tests must launch the application that they test. 28 | let app = XCUIApplication() 29 | app.launch() 30 | 31 | // Use XCTAssert and related functions to verify your tests produce the correct results. 32 | } 33 | 34 | @MainActor 35 | func testLaunchPerformance() throws { 36 | if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) { 37 | // This measures how long it takes to launch your application. 
38 | measure(metrics: [XCTApplicationLaunchMetric()]) { 39 | XCUIApplication().launch() 40 | } 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /Moshi/moshiApp.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import AVFoundation 6 | import Foundation 7 | import SwiftUI 8 | 9 | func requestMicrophoneAccess() { 10 | switch AVCaptureDevice.authorizationStatus(for: .audio) { 11 | case .authorized: 12 | return 13 | case .notDetermined: 14 | AVCaptureDevice.requestAccess(for: .audio) { granted in 15 | print("granted", granted) 16 | } 17 | case .denied: // The user has previously denied access. 18 | return 19 | case .restricted: // The user can't grant access due to restrictions. 20 | return 21 | case _: 22 | return 23 | } 24 | } 25 | 26 | @main 27 | struct moshiApp: App { 28 | @Environment(\.scenePhase) var scenePhase 29 | 30 | init() { 31 | requestMicrophoneAccess() 32 | } 33 | 34 | var body: some Scene { 35 | WindowGroup { 36 | ContentView() 37 | .environment(DeviceStat()) 38 | } 39 | #if os(iOS) 40 | .onChange(of: scenePhase) { (phase) in 41 | switch phase { 42 | case .active: 43 | // In iOS 13+, idle timer needs to be set in scene to override default 44 | UIApplication.shared.isIdleTimerDisabled = true 45 | case .inactive: break 46 | case .background: break 47 | @unknown default: print("ScenePhase: unexpected state") 48 | } 49 | } 50 | #endif 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /Moshi/DeviceStat.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | 4 | enum ThermalState: String { 5 | case nominal = "Nominal" 6 | case fair = "Fair" 7 | case serious = 
"Serious" 8 | case critical = "Critical" 9 | case unknown = "Unknown" 10 | } 11 | 12 | @Observable 13 | final class DeviceStat: @unchecked Sendable { 14 | 15 | @MainActor 16 | var gpuUsage = GPU.snapshot() 17 | @MainActor 18 | var thermalState: ThermalState = .nominal 19 | 20 | private let initialGPUSnapshot = GPU.snapshot() 21 | private var timer: Timer? 22 | 23 | init() { 24 | timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in 25 | self?.updateGPUUsages() 26 | } 27 | } 28 | 29 | deinit { 30 | timer?.invalidate() 31 | } 32 | 33 | private func updateGPUUsages() { 34 | let gpuSnapshotDelta = initialGPUSnapshot.delta(GPU.snapshot()) 35 | let thermalState = ProcessInfo.processInfo.thermalState 36 | DispatchQueue.main.async { [weak self] in 37 | self?.gpuUsage = gpuSnapshotDelta 38 | switch thermalState { 39 | case .nominal: 40 | self?.thermalState = .nominal 41 | case .fair: 42 | self?.thermalState = .fair 43 | case .serious: 44 | self?.thermalState = .serious 45 | case .critical: 46 | self?.thermalState = .critical 47 | default: 48 | self?.thermalState = .unknown 49 | } 50 | } 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /Moshi/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "appearances" : [ 10 | { 11 | "appearance" : "luminosity", 12 | "value" : "dark" 13 | } 14 | ], 15 | "idiom" : "universal", 16 | "platform" : "ios", 17 | "size" : "1024x1024" 18 | }, 19 | { 20 | "appearances" : [ 21 | { 22 | "appearance" : "luminosity", 23 | "value" : "tinted" 24 | } 25 | ], 26 | "idiom" : "universal", 27 | "platform" : "ios", 28 | "size" : "1024x1024" 29 | }, 30 | { 31 | "idiom" : "mac", 32 | "scale" : "1x", 33 | "size" : "16x16" 34 | }, 35 | { 36 | "idiom" : "mac", 37 | "scale" : 
"2x", 38 | "size" : "16x16" 39 | }, 40 | { 41 | "idiom" : "mac", 42 | "scale" : "1x", 43 | "size" : "32x32" 44 | }, 45 | { 46 | "idiom" : "mac", 47 | "scale" : "2x", 48 | "size" : "32x32" 49 | }, 50 | { 51 | "idiom" : "mac", 52 | "scale" : "1x", 53 | "size" : "128x128" 54 | }, 55 | { 56 | "idiom" : "mac", 57 | "scale" : "2x", 58 | "size" : "128x128" 59 | }, 60 | { 61 | "idiom" : "mac", 62 | "scale" : "1x", 63 | "size" : "256x256" 64 | }, 65 | { 66 | "idiom" : "mac", 67 | "scale" : "2x", 68 | "size" : "256x256" 69 | }, 70 | { 71 | "idiom" : "mac", 72 | "scale" : "1x", 73 | "size" : "512x512" 74 | }, 75 | { 76 | "idiom" : "mac", 77 | "scale" : "2x", 78 | "size" : "512x512" 79 | } 80 | ], 81 | "info" : { 82 | "author" : "xcode", 83 | "version" : 1 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /MoshiCLI/RunHelium.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
4 | 5 | import AVFoundation 6 | import Foundation 7 | import MLX 8 | import MLXNN 9 | import MoshiLib 10 | 11 | func makeHelium(_ url: URL, _ cfg: LmConfig) throws -> LM { 12 | let weights = try loadArrays(url: url) 13 | let parameters = ModuleParameters.unflattened(weights) 14 | let model = LM(cfg, bSize: 1) 15 | if url.lastPathComponent.hasSuffix("q4.safetensors") { 16 | quantize(model: model, groupSize: 64, bits: 4) 17 | } else if url.lastPathComponent.hasSuffix("q6.safetensors") { 18 | quantize(model: model, groupSize: 64, bits: 6) 19 | } else if url.lastPathComponent.hasSuffix("q8.safetensors") { 20 | quantize(model: model, groupSize: 64, bits: 8) 21 | } 22 | try model.update(parameters: parameters, verify: [.all]) 23 | eval(model) 24 | return model 25 | } 26 | 27 | func runHelium(_ url: URL, cfg: LmConfig) throws { 28 | let stats = PerfStats() 29 | let helium = try makeHelium(url, cfg) 30 | let vocab = try loadVocab(cfg) 31 | helium.warmup() 32 | print("done warming up") 33 | 34 | let maxSteps = helium.cfg.transformer.maxSeqLen 35 | let sampler = Sampler() 36 | 37 | var lastToken = MLXArray([1]) 38 | for stepIdx in 0...maxSteps { 39 | let (textToken, _) = helium.sample( 40 | textIds: lastToken.reshaped([1, 1]), audioIds: [], stepIdx: stepIdx, 41 | textSampler: sampler, 42 | audioSampler: sampler, cb: stats) 43 | let textTokenI: Int = textToken[0].item() 44 | if var v = vocab[textTokenI] { 45 | if v == "<0x0A>" { 46 | print() 47 | } else { 48 | v.replace("▁", with: " ") 49 | print(v, terminator: "") 50 | fflush(stdout) 51 | } 52 | } 53 | lastToken = textToken 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /moshi.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "982acab5e54d309f8a98a66985086b61e82a869bc8ebd717682e681f43aac314", 3 | "pins" : [ 4 | { 5 | "identity" : "jinja", 6 | "kind" 
: "remoteSourceControl", 7 | "location" : "https://github.com/johnmai-dev/Jinja", 8 | "state" : { 9 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73", 10 | "version" : "1.1.1" 11 | } 12 | }, 13 | { 14 | "identity" : "mlx-swift", 15 | "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/ml-explore/mlx-swift", 17 | "state" : { 18 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", 19 | "version" : "0.21.2" 20 | } 21 | }, 22 | { 23 | "identity" : "swift-argument-parser", 24 | "kind" : "remoteSourceControl", 25 | "location" : "https://github.com/apple/swift-argument-parser.git", 26 | "state" : { 27 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c", 28 | "version" : "1.5.0" 29 | } 30 | }, 31 | { 32 | "identity" : "swift-collections", 33 | "kind" : "remoteSourceControl", 34 | "location" : "https://github.com/apple/swift-collections.git", 35 | "state" : { 36 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", 37 | "version" : "1.1.4" 38 | } 39 | }, 40 | { 41 | "identity" : "swift-numerics", 42 | "kind" : "remoteSourceControl", 43 | "location" : "https://github.com/apple/swift-numerics", 44 | "state" : { 45 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 46 | "version" : "1.0.2" 47 | } 48 | }, 49 | { 50 | "identity" : "swift-transformers", 51 | "kind" : "remoteSourceControl", 52 | "location" : "https://github.com/huggingface/swift-transformers", 53 | "state" : { 54 | "revision" : "8a83416cc00ab07a5de9991e6ad817a9b8588d20", 55 | "version" : "0.1.15" 56 | } 57 | } 58 | ], 59 | "version" : 3 60 | } 61 | -------------------------------------------------------------------------------- /MoshiCLI/Audio.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
4 | 5 | import AVFoundation 6 | import Foundation 7 | 8 | func readAudioToPCMArray(fileURL: URL, channel: Int = 0) -> [Float]? { 9 | do { 10 | let audioFile = try AVAudioFile(forReading: fileURL) 11 | let format = audioFile.processingFormat 12 | let frameCount = UInt32(audioFile.length) 13 | guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: frameCount) else { 14 | print("failed to create buffer") 15 | return nil 16 | } 17 | 18 | try audioFile.read(into: buffer) 19 | guard let channelData = buffer.floatChannelData else { 20 | print("failed to get channel data") 21 | return nil 22 | } 23 | let _ = Int(format.channelCount) 24 | let frameLength = Int(buffer.frameLength) 25 | var pcmData: [Float] = [] 26 | let samples = channelData[channel] 27 | pcmData.append(contentsOf: UnsafeBufferPointer(start: samples, count: frameLength)) 28 | return pcmData 29 | } catch { 30 | print("error reading audio file: \(error)") 31 | return nil 32 | } 33 | } 34 | 35 | func writeWAVFile(_ pcmData: [Float], sampleRate: Double, outputURL: URL) throws { 36 | let format = AVAudioFormat( 37 | standardFormatWithSampleRate: sampleRate, channels: AVAudioChannelCount(1)) 38 | guard let format = format else { 39 | throw NSError( 40 | domain: "AudioFormat", code: -1, 41 | userInfo: [NSLocalizedDescriptionKey: "Failed to create audio format"]) 42 | } 43 | let audioFile = try AVAudioFile(forWriting: outputURL, settings: format.settings) 44 | guard 45 | let buffer = AVAudioPCMBuffer( 46 | pcmFormat: format, frameCapacity: AVAudioFrameCount(pcmData.count)) 47 | else { 48 | throw NSError( 49 | domain: "PCMBuffer", code: -1, 50 | userInfo: [NSLocalizedDescriptionKey: "Failed to create PCM buffer"]) 51 | } 52 | buffer.frameLength = AVAudioFrameCount(pcmData.count) 53 | let channelData = buffer.floatChannelData! 
54 | channelData[0].update(from: Array(pcmData), count: pcmData.count) 55 | try audioFile.write(from: buffer) 56 | } 57 | -------------------------------------------------------------------------------- /MoshiLib/Utils.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | import MLXRandom 9 | 10 | func topKSampling(logits: MLXArray, topK: Int, temp: Float) -> MLXArray { 11 | let c = logits.dim(-1) 12 | let sortedIndices = argSort(logits, axis: -1) 13 | let sortedLogits = logits[0..., sortedIndices.squeezed(axis: 0)] 14 | let topLogits = `where`(MLXArray(0..= c - topK, sortedLogits, -Float.infinity) 15 | let sortedToken = categorical(topLogits / temp) 16 | let token = sortedIndices.squeezed(axis: 0)[sortedToken] 17 | return token 18 | } 19 | 20 | func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray { 21 | let probs = softmax(logits * (1 / temp), axis: -1) 22 | let sortedIndices = argSort(probs, axis: -1) 23 | let sortedProbs = probs[0..., sortedIndices.squeezed(axis: 0)] 24 | let cumulativeProbs = cumsum(sortedProbs, axis: -1) 25 | let topProbs = `where`(cumulativeProbs .> 1 - topP, sortedProbs, 0) 26 | let sortedToken = categorical(topProbs.log()) 27 | let token = sortedIndices.squeezed(axis: 0)[sortedToken] 28 | return token 29 | } 30 | 31 | func categoricalSampling(logits: MLXArray, temp: Float) -> MLXArray { 32 | categorical(logits * (1 / temp)) 33 | } 34 | 35 | public class Sampler { 36 | let temp: Float 37 | let topP: Float = 0.95 38 | let topK: Int = 0 39 | 40 | public init(temp: Float = 0.8) { 41 | self.temp = temp 42 | } 43 | 44 | public func callAsFunction(logits: MLXArray) -> (MLXArray, MLXArray) { 45 | if logits.shape.count != 2 { 46 | fatalError("expected a two 
dimensions logits array, got \(logits.shape)") 47 | } 48 | let logProbs = logits - logits.logSumExp() 49 | var tokens: MLXArray 50 | if temp <= 0.0 { 51 | tokens = logProbs.argMax(axis: -1) 52 | } else if self.topP > 0.0 && self.topP < 1.0 { 53 | tokens = topPSampling(logits: logits, topP: self.topP, temp: self.temp) 54 | } else if self.topK != 0 { 55 | tokens = topKSampling(logits: logits, topK: self.topK, temp: self.temp) 56 | } else { 57 | tokens = categoricalSampling(logits: logits, temp: self.temp) 58 | } 59 | return (tokens, logProbs) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /scripts/eval_mlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | from moshi_mlx import models 6 | 7 | def config1b() -> models.lm.LmConfig: 8 | transformer = models.lm.TransformerConfig( 9 | d_model=2048, 10 | num_heads=16, 11 | num_layers=16, 12 | dim_feedforward=2048 * 4, # dim * hidden_scale 13 | causal=True, 14 | norm_first=True, 15 | bias_ff=False, 16 | bias_attn=False, 17 | layer_scale=None, 18 | context=750, 19 | max_period=100000, 20 | use_conv_block=False, 21 | use_conv_bias=True, 22 | cross_attention=False, 23 | gating=True, 24 | norm="rms_norm", 25 | positional_embedding="rope", 26 | conv_layout=False, 27 | conv_kernel_size=3, 28 | kv_repeat=1, 29 | max_seq_len=4096, 30 | ) 31 | depformer = models.lm.DepFormerConfig( 32 | transformer=models.lm.TransformerConfig( 33 | d_model=1024, 34 | num_heads=16, 35 | num_layers=6, 36 | dim_feedforward=1024 * 4, # dim * hidden_scale 37 | causal=True, 38 | norm_first=True, 39 | bias_ff=False, 40 | bias_attn=False, 41 | layer_scale=None, 42 | context=8, 43 | max_period=10000, 44 | use_conv_block=False, 45 | use_conv_bias=True, 46 | cross_attention=False, 47 | gating=True, 48 | norm="rms_norm", 49 | positional_embedding="none", 50 | conv_layout=False, 51 | 
conv_kernel_size=3, 52 | kv_repeat=1, 53 | max_seq_len=4096, 54 | ), 55 | num_slices=0, 56 | ) 57 | return models.lm.LmConfig( 58 | transformer=transformer, 59 | depformer=depformer, 60 | audio_vocab_size=2049, 61 | text_in_vocab_size=48001, 62 | text_out_vocab_size=48000, 63 | audio_codebooks=8, 64 | audio_delays=[0]*16 65 | ) 66 | home_dir = os.path.expanduser("~") 67 | 68 | lm_config = config1b() 69 | model = models.Lm(lm_config) 70 | model.set_dtype(mx.bfloat16) 71 | model.load_weights(home_dir + "/tmp/asr-1b-8d2516b9@150.safetensors", strict=True) 72 | 73 | def run_one(): 74 | text_token_ids = mx.array([48000]).reshape(1, 1) 75 | audio_token_ids = [mx.array([2048]).reshape(1, 1)] * 8 76 | xs = model.text_emb(text_token_ids) 77 | for token_ids, emb in zip(audio_token_ids, model.audio_embs): 78 | xs = xs + emb(token_ids) 79 | transformer_out = model.transformer(xs, cache=model.transformer_cache) 80 | transformer_out = model.out_norm(transformer_out) 81 | text_logits = model.text_linear(transformer_out) 82 | print(text_logits) 83 | 84 | run_one() 85 | -------------------------------------------------------------------------------- /MoshiLib/ASR.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
4 | 5 | import AVFoundation 6 | import Foundation 7 | import Hub 8 | import MLX 9 | import MLXNN 10 | import MoshiLib 11 | 12 | public class ASR { 13 | let moshi: LM 14 | let vocab: [Int: String] 15 | let mimi: Mimi 16 | var prevTextToken: Int = 0 17 | let sampler: Sampler = Sampler(temp: 0.0) 18 | let cb: Callbacks 19 | 20 | public init( 21 | _ moshi: LM, _ mimi: Mimi, vocab: [Int: String], cb: Callbacks = EmptyCallbacks() 22 | ) { 23 | self.moshi = moshi 24 | self.mimi = mimi 25 | self.vocab = vocab 26 | self.cb = cb 27 | } 28 | 29 | public func reset() { 30 | mimi.resetState() 31 | moshi.resetCache() 32 | prevTextToken = self.moshi.cfg.textInitToken() 33 | let textIds = MLXArray([prevTextToken]).reshaped([1, 1]) 34 | let audioIds = (0.. [String] { 45 | var tokens: [String] = [] 46 | let codebooks = moshi.cfg.audioCodebooks 47 | cb.onEvent(.beginEncode) 48 | let codes = mimi.encodeStep(StreamArray(pcm)) 49 | codes.eval() 50 | cb.onEvent(.endEncode) 51 | if let codes = codes.asArray() { 52 | cb.onInputAudioTokens(codes) 53 | let (_, _, steps) = codes.shape3 54 | for step in 0.. I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms. 25 | 26 | The full CLA is provided as follows: 27 | 28 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free, 29 | > irrevocable license to use, modify, distribute, and sublicense my Contributions. 30 | 31 | > I understand and accept that Contributions are limited to modifications, improvements, or changes 32 | > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to 33 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting 34 | > a pull request does not guarantee its inclusion in the project. 
35 | 36 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify, 37 | > reproduce, distribute, and create derivative works based on my Contributions. 38 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions, 39 | > giving the Kyutai-labs full rights to file for and enforce patents. 40 | > I understand that the Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me. 41 | > I confirm that my Contributions are original and that I have the legal right to grant this license. 42 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions 43 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at the Kyutai-labs’s discretion. 44 | 45 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation. 46 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties. 47 | > By submitting a pull request, I agree to be bound by these terms. 48 | 49 | ## Issues 50 | 51 | Please submit issues on our Github repository. 52 | 53 | ## License 54 | 55 | By contributing to Moshi, you agree that your contributions will be licensed 56 | under the MIT license as provided by the LICENSE file at the top of the repo. 
57 | -------------------------------------------------------------------------------- /moshi.xcodeproj/xcshareddata/xcschemes/moshi-cli.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 44 | 46 | 52 | 53 | 54 | 55 | 58 | 59 | 62 | 63 | 64 | 65 | 71 | 73 | 79 | 80 | 81 | 82 | 84 | 85 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /moshi.xcodeproj/xcshareddata/xcschemes/Moshi.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 9 | 10 | 16 | 22 | 23 | 24 | 25 | 26 | 32 | 33 | 36 | 42 | 43 | 44 | 47 | 53 | 54 | 55 | 56 | 57 | 67 | 69 | 75 | 76 | 77 | 78 | 84 | 86 | 92 | 93 | 94 | 95 | 97 | 98 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /MoshiLib/Streaming.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | import MLXRandom 9 | 10 | public class StreamArray { 11 | let inner: MLXArray? 12 | 13 | public init(_ x: MLXArray? = nil) { 14 | self.inner = x 15 | } 16 | 17 | public func dim(_ dim: Int) -> Int { 18 | self.inner?.dim(dim) ?? 0 19 | } 20 | 21 | public func eval() { 22 | self.inner?.eval() 23 | } 24 | 25 | public func cat2(_ rhs: StreamArray, axis: Int) -> StreamArray { 26 | switch (self.inner, rhs.inner) { 27 | case (.none, .none): StreamArray() 28 | case (.some(let lhs), .none): StreamArray(lhs) 29 | case (.none, .some(let rhs)): StreamArray(rhs) 30 | case (.some(let lhs), .some(let rhs)): StreamArray(concatenated([lhs, rhs], axis: axis)) 31 | } 32 | } 33 | 34 | public var shape: [Int]? 
{ 35 | self.inner?.shape 36 | } 37 | 38 | public func asArray() -> MLXArray? { 39 | self.inner 40 | } 41 | 42 | public func narrow(_ offset: Int, _ len: Int, axis: Int) -> StreamArray { 43 | if let inner = self.inner { 44 | let totalLen = inner.dim(axis) 45 | if len <= 0 { 46 | return StreamArray() 47 | } else { 48 | // Not sure if there exists something closer to pytorch narrow 49 | let t = inner.split(indices: [offset, min(totalLen, offset + len)], axis: axis) 50 | return StreamArray(t[1]) 51 | } 52 | } else { 53 | return StreamArray() 54 | } 55 | } 56 | 57 | public func split(lhsLen: Int, axis: Int) -> (StreamArray, StreamArray) { 58 | if let t = self.inner { 59 | let len = t.dim(axis) 60 | let lhsLen = min(len, lhsLen) 61 | if lhsLen == 0 { 62 | return (StreamArray(), StreamArray(t)) 63 | } else if lhsLen == len { 64 | return (StreamArray(t), StreamArray()) 65 | } else { 66 | let split = t.split(indices: [lhsLen], axis: axis) 67 | return (StreamArray(split[0]), StreamArray(split[1])) 68 | } 69 | } else { 70 | return (StreamArray(), StreamArray()) 71 | } 72 | } 73 | 74 | public func elu() -> StreamArray { 75 | self.map { MLXNN.elu($0, alpha: 1.0) } 76 | } 77 | 78 | public func map(_ f: (MLXArray) -> MLXArray) -> StreamArray { 79 | switch self.inner { 80 | case .none: StreamArray() 81 | case .some(let x): StreamArray(f(x)) 82 | } 83 | } 84 | } 85 | 86 | public protocol StreamingLayer { 87 | func resetState() 88 | func step(_ x: StreamArray) -> StreamArray 89 | } 90 | 91 | public class StreamingBinOp { 92 | enum BinOp { 93 | case add 94 | case mul 95 | case sub 96 | case div 97 | } 98 | 99 | var prevLHS: StreamArray 100 | var prevRHS: StreamArray 101 | let op: BinOp 102 | let axis: Int 103 | 104 | init(_ op: BinOp, axis: Int) { 105 | self.prevLHS = StreamArray() 106 | self.prevRHS = StreamArray() 107 | self.op = op 108 | self.axis = axis 109 | } 110 | 111 | public func resetState() { 112 | self.prevLHS = StreamArray() 113 | self.prevRHS = StreamArray() 114 | } 115 
| 116 | public func step(_ lhs: StreamArray, _ rhs: StreamArray) -> StreamArray { 117 | let lhs = self.prevLHS.cat2(lhs, axis: self.axis) 118 | let rhs = self.prevRHS.cat2(rhs, axis: self.axis) 119 | let lhsLen = lhs.dim(self.axis) 120 | let rhsLen = rhs.dim(self.axis) 121 | let commonLen = min(lhsLen, rhsLen) 122 | let (lhs_, prevLHS) = lhs.split(lhsLen: commonLen, axis: self.axis) 123 | let (rhs_, prevRHS) = rhs.split(lhsLen: commonLen, axis: self.axis) 124 | self.prevLHS = prevLHS 125 | self.prevRHS = prevRHS 126 | switch (lhs_.inner, rhs_.inner) { 127 | case (.some(let l), .some(let r)): 128 | var res: MLXArray 129 | switch self.op { 130 | case .add: res = l + r 131 | case .sub: res = l - r 132 | case .mul: res = l * r 133 | case .div: res = l / r 134 | } 135 | return StreamArray(res) 136 | case (.none, .none): return StreamArray() 137 | case _: fatalError("internal error") 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /MoshiLib/Mimi.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | import MLXRandom 9 | 10 | public struct MimiConfig { 11 | public var channels: Int 12 | public var sampleRate: Float 13 | public var frameRate: Float 14 | public var renormalize: Bool 15 | // public var resampleMethod: String 16 | public var seanet: SeanetConfig 17 | public var transformer: TransformerConfig 18 | public var quantizerNQ: Int 19 | public var quantizerBins: Int 20 | public var quantizerDim: Int 21 | 22 | public static func mimi_2024_07(numCodebooks: Int = 16) -> MimiConfig { 23 | let seanet = SeanetConfig.v0_1() 24 | let transformer = TransformerConfig( 25 | dModel: seanet.dimension, 26 | numHeads: 8, 27 | numLayers: 8, 28 | causal: true, 29 | normFirst: true, 30 | biasFF: false, 31 | biasAttn: false, 32 | layerScale: 0.01, 33 | positionalEmbedding: .rope, 34 | useConvBias: true, 35 | gating: false, 36 | norm: .layerNorm, 37 | context: 250, 38 | maxPeriod: 10000, 39 | maxSeqLen: 8192, 40 | kvRepeat: 1, 41 | dimFeedForward: 2048, 42 | convLayout: true, 43 | useRotatingKVCache: true 44 | ) 45 | return MimiConfig( 46 | channels: 1, 47 | sampleRate: 24000, frameRate: 12.5, renormalize: true, seanet: seanet, 48 | transformer: transformer, quantizerNQ: numCodebooks, quantizerBins: 2048, 49 | quantizerDim: 256) 50 | } 51 | } 52 | 53 | public class Mimi: Module { 54 | let cfg: MimiConfig 55 | let encoderCache: [KVCache] 56 | let decoderCache: [KVCache] 57 | @ModuleInfo(key: "encoder") var encoder: SeanetEncoder 58 | @ModuleInfo(key: "decoder") var decoder: SeanetDecoder 59 | @ModuleInfo(key: "encoder_transformer") var encoderTransformer: ProjectedTransformer 60 | @ModuleInfo(key: "decoder_transformer") var decoderTransformer: ProjectedTransformer 61 | @ModuleInfo(key: "downsample") var downsample: ConvDownsample1d 62 | @ModuleInfo(key: "upsample") var upsample: ConvTrUpsample1d 63 | @ModuleInfo(key: "quantizer") var quantizer: SplitResidualVectorQuantizer 64 | 65 | public init(_ cfg: MimiConfig, 
bSize: Int) { 66 | let dim = cfg.seanet.dimension 67 | self.cfg = cfg 68 | let encoderFrameRate = cfg.sampleRate / Float(cfg.seanet.ratios.reduce(1, *)) 69 | let downsampleStride = Int(encoderFrameRate / cfg.frameRate) 70 | self._encoder.wrappedValue = SeanetEncoder(cfg.seanet) 71 | self._decoder.wrappedValue = SeanetDecoder(cfg.seanet) 72 | self._quantizer.wrappedValue = SplitResidualVectorQuantizer( 73 | dim: cfg.quantizerDim, inputDim: dim, outputDim: dim, nQ: cfg.quantizerNQ, 74 | bins: cfg.quantizerBins) 75 | self._encoderTransformer.wrappedValue = ProjectedTransformer( 76 | cfg.transformer, inputDim: dim, outputDims: [dim]) 77 | self._decoderTransformer.wrappedValue = ProjectedTransformer( 78 | cfg.transformer, inputDim: dim, outputDims: [dim]) 79 | self._downsample.wrappedValue = ConvDownsample1d( 80 | stride: downsampleStride, dim: dim, causal: true) 81 | self._upsample.wrappedValue = ConvTrUpsample1d( 82 | stride: downsampleStride, dim: dim, causal: true) 83 | self.encoderCache = self._encoderTransformer.wrappedValue.makeCache(bSize: bSize) 84 | self.decoderCache = self._decoderTransformer.wrappedValue.makeCache(bSize: bSize) 85 | } 86 | 87 | public func warmup() { 88 | let pcm = MLXArray.zeros([1, 1, 1920 * 4]) 89 | let codes = self.encode(pcm) 90 | let pcmOut = self.decode(codes) 91 | eval(pcmOut) 92 | } 93 | 94 | public func encode(_ x: MLXArray) -> MLXArray { 95 | self.encoder.resetState() 96 | self.encoderCache.forEach { c in c.reset() } 97 | var x = self.encoder(x) 98 | x = self.encoderTransformer(x, cache: self.encoderCache)[0] 99 | x = self.downsample(x) 100 | let codes = self.quantizer.encode(x) 101 | return codes 102 | } 103 | 104 | public func encodeStep(_ x: StreamArray) -> StreamArray { 105 | var x = self.encoder.step(x) 106 | x = x.map { self.encoderTransformer($0, cache: self.encoderCache)[0] } 107 | x = self.downsample.step(x) 108 | let codes = x.map(self.quantizer.encode) 109 | return codes 110 | } 111 | 112 | public func decode(_ codes: 
MLXArray) -> MLXArray { 113 | self.decoder.resetState() 114 | self.decoderCache.forEach { c in c.reset() } 115 | let emb = self.quantizer.decode(codes) 116 | let embUp = self.upsample(emb) 117 | let outs = self.decoderTransformer(embUp, cache: self.decoderCache) 118 | return self.decoder(outs[0]) 119 | } 120 | 121 | public func decodeStep(_ codes: StreamArray) -> StreamArray { 122 | var emb = codes.map { self.quantizer.decode($0) } 123 | emb = self.upsample.step(emb) 124 | let out = emb.map { self.decoderTransformer($0, cache: self.decoderCache)[0] } 125 | return self.decoder.step(out) 126 | } 127 | 128 | public func resetState() { 129 | self.encoder.resetState() 130 | self.decoder.resetState() 131 | self.encoderCache.forEach { c in c.reset() } 132 | self.decoderCache.forEach { c in c.reset() } 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /MoshiLib/Qwen2.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | 9 | public struct QwenQuantization: Codable { 10 | public var groupSize: Int 11 | public var bits: Int 12 | } 13 | 14 | public struct QwenConfig: Codable { 15 | public var bosTokenId: Int 16 | public var eosTokenId: Int 17 | public var hiddenSize: Int 18 | public var intermediateSize: Int 19 | public var maxPositionEmbeddings: Int 20 | public var maxWindowLayers: Int 21 | public var numAttentionHeads: Int 22 | public var numHiddenLayers: Int 23 | public var numKeyValueHeads: Int 24 | public var rmsNormEps: Float 25 | public var ropeTheta: Float 26 | public var tieWordEmbeddings: Bool 27 | public var useSlidingWindow: Bool 28 | public var vocabSize: Int 29 | public var quantization: QwenQuantization? 
= nil 30 | 31 | public func headDim() -> Int { 32 | self.hiddenSize / self.numAttentionHeads 33 | } 34 | } 35 | 36 | private class Mlp: Module, UnaryLayer { 37 | @ModuleInfo(key: "gate_proj") var gateProj: Linear 38 | @ModuleInfo(key: "down_proj") var downProj: Linear 39 | @ModuleInfo(key: "up_proj") var upProj: Linear 40 | 41 | init(_ cfg: QwenConfig) { 42 | self._gateProj.wrappedValue = Linear(cfg.hiddenSize, cfg.intermediateSize, bias: false) 43 | self._upProj.wrappedValue = Linear(cfg.hiddenSize, cfg.intermediateSize, bias: false) 44 | self._downProj.wrappedValue = Linear(cfg.intermediateSize, cfg.hiddenSize, bias: false) 45 | } 46 | 47 | func callAsFunction(_ x: MLXArray) -> MLXArray { 48 | return downProj(silu(gateProj(x)) * upProj(x)) 49 | } 50 | } 51 | 52 | private class Attention: Module { 53 | let cfg: QwenConfig 54 | let scale: Float 55 | let rope: RoPE 56 | 57 | @ModuleInfo(key: "q_proj") var qProj: Linear 58 | @ModuleInfo(key: "k_proj") var kProj: Linear 59 | @ModuleInfo(key: "v_proj") var vProj: Linear 60 | @ModuleInfo(key: "o_proj") var oProj: Linear 61 | 62 | init(_ cfg: QwenConfig) { 63 | self.cfg = cfg 64 | self.scale = 1.0 / sqrt(Float(cfg.headDim())) 65 | let headDim = cfg.headDim() 66 | self._qProj.wrappedValue = Linear( 67 | cfg.hiddenSize, cfg.numAttentionHeads * headDim, bias: true) 68 | self._kProj.wrappedValue = Linear( 69 | cfg.hiddenSize, cfg.numKeyValueHeads * headDim, bias: true) 70 | self._vProj.wrappedValue = Linear( 71 | cfg.hiddenSize, cfg.numKeyValueHeads * headDim, bias: true) 72 | self._oProj.wrappedValue = Linear( 73 | cfg.numAttentionHeads * headDim, cfg.hiddenSize, bias: false) 74 | self.rope = 75 | RoPE(dimensions: cfg.headDim(), traditional: false, base: Float(cfg.ropeTheta)) 76 | } 77 | 78 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache?) 
-> MLXArray { 79 | let (B, T, H) = (x.dim(0), x.dim(1), x.dim(2)) 80 | let headDim = cfg.headDim() 81 | var queryStates = qProj(x).reshaped(B, T, -1, headDim).transposed(0, 2, 1, 3) 82 | var keyStates = kProj(x).reshaped(B, T, -1, headDim).transposed(0, 2, 1, 3) 83 | var valueStates = vProj(x).reshaped(B, T, -1, headDim).transposed(0, 2, 1, 3) 84 | let offset = cache?.offset ?? 0 85 | queryStates = rope(queryStates, offset: offset) 86 | keyStates = rope(keyStates, offset: offset) 87 | if let cache { 88 | (keyStates, valueStates) = cache.update(keys: keyStates, values: valueStates) 89 | } 90 | // sliding window is not supported here 91 | var mask = mask 92 | if let m = mask { 93 | let maskLen = m.dim(-1) 94 | if keyStates.dim(2) < maskLen { 95 | let offset = maskLen - keyStates.dim(2) 96 | mask = m[0..., offset...] 97 | } 98 | } 99 | let x = MLXFast.scaledDotProductAttention( 100 | queries: queryStates, keys: keyStates, values: valueStates, scale: self.scale, 101 | mask: mask 102 | ).transposed(0, 2, 1, 3).reshaped(B, T, H) 103 | return oProj(x) 104 | } 105 | } 106 | 107 | private class Layer: Module { 108 | @ModuleInfo(key: "mlp") var mlp: Mlp 109 | @ModuleInfo(key: "input_layernorm") var inputNorm: RMSNorm 110 | @ModuleInfo(key: "post_attention_layernorm") var postAttnNorm: RMSNorm 111 | @ModuleInfo(key: "self_attn") var selfAttn: Attention 112 | 113 | init(_ cfg: QwenConfig) { 114 | self._mlp.wrappedValue = Mlp(cfg) 115 | self._inputNorm.wrappedValue = RMSNorm(dimensions: cfg.hiddenSize, eps: cfg.rmsNormEps) 116 | self._postAttnNorm.wrappedValue = RMSNorm(dimensions: cfg.hiddenSize, eps: cfg.rmsNormEps) 117 | self._selfAttn.wrappedValue = Attention(cfg) 118 | } 119 | 120 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache) -> MLXArray { 121 | var residual = x 122 | var x = x 123 | x = selfAttn(inputNorm(x), mask: mask, cache: cache) 124 | x = residual + x 125 | residual = x 126 | x = mlp(postAttnNorm(x)) 127 | return residual + x 128 | } 129 | } 
130 | 131 | public class QwenModel: Module { 132 | let cfg: QwenConfig 133 | private let norm: RMSNorm 134 | private let layers: [Layer] 135 | @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding 136 | 137 | public init(_ cfg: QwenConfig) { 138 | self.cfg = cfg 139 | self.layers = (0.. MLXArray { 146 | var x = embedTokens(x) 147 | let mask = cache.first?.createAttentionMask(h: x) 148 | for (layer, c) in zip(self.layers, cache) { 149 | x = layer(x, mask: mask, cache: c) 150 | } 151 | return embedTokens.asLinear(norm(x)) 152 | } 153 | 154 | public func makeCache(bSize: Int) -> [KVCache] { 155 | let kvHeads = cfg.numKeyValueHeads 156 | let cache = (0.. { 5 | private var buffer: [T] = [] 6 | private let queue = DispatchQueue(label: "tschannel", attributes: .concurrent) 7 | private let semaphore = DispatchSemaphore(value: 0) 8 | 9 | func send(_ value: T) { 10 | queue.async(flags: .barrier) { 11 | self.buffer.append(value) 12 | self.semaphore.signal() 13 | } 14 | } 15 | 16 | func receive() -> T? { 17 | semaphore.wait() 18 | return queue.sync { 19 | guard !buffer.isEmpty else { return nil } 20 | return buffer.removeFirst() 21 | } 22 | } 23 | } 24 | 25 | // The code below is probably macos specific and unlikely to work on ios. 
26 | class MicrophoneCapture { 27 | private let audioEngine: AVAudioEngine 28 | private let channel: ThreadSafeChannel<[Float]> 29 | 30 | init() { 31 | audioEngine = AVAudioEngine() 32 | channel = ThreadSafeChannel() 33 | } 34 | 35 | func startCapturing() { 36 | let inputNode = audioEngine.inputNode 37 | 38 | // Desired format: 1 channel (mono), 24kHz, Float32 39 | let desiredSampleRate: Double = 24000.0 40 | let desiredChannelCount: AVAudioChannelCount = 1 41 | 42 | let inputFormat = inputNode.inputFormat(forBus: 0) 43 | 44 | // Create a custom audio format with the desired settings 45 | guard 46 | let mono24kHzFormat = AVAudioFormat( 47 | commonFormat: .pcmFormatFloat32, 48 | sampleRate: desiredSampleRate, 49 | channels: desiredChannelCount, 50 | interleaved: false) 51 | else { 52 | print("Could not create target format") 53 | return 54 | } 55 | 56 | // Resample the buffer to match the desired format 57 | let converter = AVAudioConverter(from: inputFormat, to: mono24kHzFormat) 58 | 59 | // Install a tap to capture audio and resample to the target format 60 | inputNode.installTap(onBus: 0, bufferSize: 1920, format: inputFormat) { buffer, _ in 61 | let targetLen = Int(buffer.frameLength) * 24000 / Int(inputFormat.sampleRate) 62 | let convertedBuffer = AVAudioPCMBuffer( 63 | pcmFormat: mono24kHzFormat, frameCapacity: AVAudioFrameCount(targetLen))! 64 | var error: NSError? 
= nil 65 | let inputBlock: AVAudioConverterInputBlock = { inNumPackets, outStatus in 66 | outStatus.pointee = .haveData 67 | return buffer 68 | } 69 | 70 | converter?.convert(to: convertedBuffer, error: &error, withInputFrom: inputBlock) 71 | 72 | if let error = error { 73 | print("Conversion error: \(error)") 74 | return 75 | } 76 | 77 | self.processAudioBuffer(buffer: convertedBuffer) 78 | } 79 | 80 | // Start the audio engine 81 | do { 82 | audioEngine.prepare() 83 | try audioEngine.start() 84 | print("Microphone capturing started at 24kHz, mono") 85 | } catch { 86 | print("Error starting audio engine: \(error)") 87 | } 88 | } 89 | 90 | private func processAudioBuffer(buffer: AVAudioPCMBuffer) { 91 | guard let channelData = buffer.floatChannelData else { return } 92 | let frameCount = Int(buffer.frameLength) 93 | 94 | let pcmData = Array(UnsafeBufferPointer(start: channelData[0], count: frameCount)).map { 95 | $0 96 | } 97 | channel.send(pcmData) 98 | } 99 | 100 | func stopCapturing() { 101 | audioEngine.stop() 102 | audioEngine.inputNode.removeTap(onBus: 0) 103 | print("Microphone capturing stopped") 104 | } 105 | 106 | func receive() -> [Float]? 
{ 107 | channel.receive() 108 | } 109 | } 110 | 111 | class FloatRingBuffer { 112 | private var buffer: [Float] 113 | private let capacity: Int 114 | private var readIndex = 0 115 | private var writeIndex = 0 116 | private var count = 0 117 | 118 | private let lock = NSLock() 119 | 120 | init(capacity: Int) { 121 | self.capacity = capacity 122 | self.buffer = [Float](repeating: 0, count: capacity) 123 | } 124 | 125 | func write(_ values: [Float]) -> Bool { 126 | lock.lock() 127 | defer { lock.unlock() } 128 | 129 | if values.count + count > capacity { 130 | return false 131 | } 132 | for value in values { 133 | buffer[writeIndex] = value 134 | writeIndex = (writeIndex + 1) % capacity 135 | count += 1 136 | } 137 | return true 138 | } 139 | 140 | func read(maxCount: Int) -> [Float] { 141 | lock.lock() 142 | defer { lock.unlock() } 143 | 144 | var values: [Float] = [] 145 | for _ in 0.. OSStatus in 173 | let audioBuffers = UnsafeMutableAudioBufferListPointer(audioBufferList) 174 | guard let channelData = audioBuffers[0].mData?.assumingMemoryBound(to: Float.self) 175 | else { 176 | return kAudioHardwareUnspecifiedError 177 | } 178 | let data = self.ringBuffer.read(maxCount: Int(frameCount)) 179 | for i in 0.. Bool { 192 | ringBuffer.write(values) 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /MoshiLib/KVCache.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | // 5 | // Parts of this file came from: 6 | // https://github.com/ml-explore/mlx-swift-examples/blob/main/Libraries/LLM/KVCache.swift 7 | // Copyright © 2024 Apple Inc. 8 | 9 | import Foundation 10 | import MLX 11 | 12 | /// Interface for Key/Value cache for LLMs. 
13 | /// 14 | /// See ``LLMModel/newCache(parameters:)-47tyu`` 15 | public protocol KVCache: Evaluatable { 16 | 17 | /// get the current offset 18 | var offset: Int { get } 19 | 20 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) 21 | func reset() 22 | func createAttentionMask(h: MLXArray) -> MLXArray? 23 | } 24 | 25 | func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray { 26 | let rinds = MLXArray(Int32(0).. [MLXArray] { 58 | [self.keys, self.values].compactMap { $0 } 59 | } 60 | 61 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { 62 | let previous = self.offset 63 | 64 | let reset = 65 | if let currentKeys = self.keys, (previous + keys.dim(2)) > currentKeys.dim(2) { 66 | true 67 | } else { 68 | self.keys == nil 69 | } 70 | if reset { 71 | let B = keys.dim(0) 72 | let nSteps = (step + keys.dim(2) - 1) / step 73 | let kShape = [B, kvHeads, nSteps * step, kHeadDim] 74 | let vShape = [B, kvHeads, nSteps * step, vHeadDim] 75 | let newK = MLXArray.zeros(kShape, dtype: keys.dtype) 76 | let newV = MLXArray.zeros(vShape, dtype: values.dtype) 77 | 78 | if var currentKeys = self.keys, var currentValues = self.values { 79 | if previous % step != 0 { 80 | currentKeys = currentKeys[.ellipsis, .. MLXArray? { 107 | let t = h.dim(1) 108 | if t > 1 { 109 | let rinds = MLXArray(Int32(0).. (MLXArray, MLXArray) { 131 | let t = keys.dim(2) 132 | if t > self.maxSize { 133 | fatalError("query to update with shape \(keys.shape) larger than maxSize \(maxSize)") 134 | } 135 | let currentOffset = self.offset % self.maxSize 136 | let tMax = min(self.maxSize, currentOffset + t) 137 | self.keys[0..., 0..., currentOffset.. [MLXArray] { 154 | [self.keys, self.values].compactMap { $0 } 155 | } 156 | 157 | func createAttentionMask(h: MLXArray) -> MLXArray? 
{ 158 | let t = h.dim(1) 159 | let finalOffset = self.offset + t 160 | let finalOffsetMod = finalOffset % self.maxSize 161 | // As a default we use finalOffset + 1 so that these slices cannot be seen. 162 | var rinds = Array(repeating: Int32(finalOffset + 1), count: self.maxSize) 163 | for i in 0.. Self { 30 | try super.update(parameters: parameters, verify: verify) 31 | let clusterUsage = maximum(self._clusterUsage.wrappedValue, self.epsilon)[0..., .newAxis] 32 | self.embedding = self._embeddingSum.wrappedValue / clusterUsage 33 | self.c2 = self.embedding.square().sum(axis: -1) / 2 34 | return self 35 | } 36 | 37 | func encode(_ x: MLXArray) -> MLXArray { 38 | let targetShape = Array(x.shape.dropLast()) 39 | let x = x.flattened(end: -2) 40 | let dotProd = x.matmul(embedding.swappedAxes(-1, -2)) 41 | return (c2 - dotProd).argMin(axis: -1).reshaped(targetShape) 42 | } 43 | 44 | func decode(_ indexes: MLXArray) -> MLXArray { 45 | let finalDims = indexes.shape + [self.dim] 46 | let indexes = indexes.flattened() 47 | return embedding.take(indexes, axis: 0).reshaped(finalDims) 48 | } 49 | } 50 | 51 | class VectorQuantization: Module { 52 | @ModuleInfo(key: "project_in") var projectIn: Linear? 53 | @ModuleInfo(key: "project_out") var projectOut: Linear? 54 | @ModuleInfo(key: "_codebook") var codebook: EuclideanCodebook 55 | 56 | init(dim: Int, codebookSize: Int, codebookDim: Int? = nil) { 57 | let codebookDim = codebookDim ?? 
dim 58 | if codebookDim == dim { 59 | self._projectIn.wrappedValue = nil 60 | self._projectOut.wrappedValue = nil 61 | } else { 62 | self._projectIn.wrappedValue = Linear(dim, codebookDim) 63 | self._projectOut.wrappedValue = Linear(codebookDim, dim) 64 | } 65 | self._codebook.wrappedValue = EuclideanCodebook( 66 | dim: codebookDim, codebookSize: codebookSize) 67 | } 68 | 69 | func encode(_ x: MLXArray) -> MLXArray { 70 | var x = x.swappedAxes(-1, -2) 71 | if let projectIn = self.projectIn { 72 | x = projectIn(x) 73 | } 74 | return self.codebook.encode(x) 75 | } 76 | 77 | func decode(_ indexes: MLXArray) -> MLXArray { 78 | var quantized = self.codebook.decode(indexes) 79 | if let projectOut = self.projectOut { 80 | quantized = projectOut(quantized) 81 | } 82 | return quantized.swappedAxes(-1, -2) 83 | } 84 | } 85 | 86 | class ResidualVectorQuantization: Module { 87 | @ModuleInfo(key: "layers") var layers: [VectorQuantization] 88 | 89 | init(nQ: Int, dim: Int, codebookSize: Int, codebookDim: Int? = nil) { 90 | self._layers.wrappedValue = (0.. MLXArray { 96 | var codes: [MLXArray] = [] 97 | var residual = x 98 | for layer in self.layers { 99 | let indices = layer.encode(residual) 100 | let quantized = layer.decode(indices) 101 | residual = residual - quantized 102 | codes.append(indices) 103 | } 104 | return stacked(codes, axis: 0) 105 | } 106 | 107 | func decode(_ indexes: MLXArray) -> MLXArray { 108 | let seqLen = indexes.dim(0) 109 | var quantized = self.layers[0].decode(indexes[0]) 110 | for i in 1.. 
MLXArray { 142 | var x = x 143 | if let inputProj = self.inputProj { 144 | x = inputProj(x) 145 | } 146 | return self.vq.encode(x).swappedAxes(0, 1) 147 | } 148 | 149 | func decode(_ codes: MLXArray) -> MLXArray { 150 | let codes = codes.swappedAxes(0, 1) 151 | var quantized = self.vq.decode(codes) 152 | if let outputProj = self.outputProj { 153 | quantized = outputProj(quantized) 154 | } 155 | return quantized 156 | } 157 | } 158 | 159 | class SplitResidualVectorQuantizer: Module { 160 | let nQ: Int 161 | @ModuleInfo(key: "rvq_first") var rvqFirst: ResidualVectorQuantizer 162 | @ModuleInfo(key: "rvq_rest") var rvqRest: ResidualVectorQuantizer 163 | 164 | init(dim: Int, inputDim: Int?, outputDim: Int?, nQ: Int, bins: Int) { 165 | self.nQ = nQ 166 | self._rvqFirst.wrappedValue = ResidualVectorQuantizer( 167 | dim: dim, inputDim: inputDim, outputDim: outputDim, nQ: 1, bins: bins, 168 | forceProjection: true) 169 | self._rvqRest.wrappedValue = ResidualVectorQuantizer( 170 | dim: dim, inputDim: inputDim, outputDim: outputDim, nQ: nQ - 1, bins: bins, 171 | forceProjection: true) 172 | } 173 | 174 | func encode(_ x: MLXArray) -> MLXArray { 175 | var codes = self.rvqFirst.encode(x) 176 | if self.nQ > 1 { 177 | let restCodes = self.rvqRest.encode(x) 178 | codes = concatenated([codes, restCodes], axis: 1) 179 | } 180 | return codes 181 | } 182 | 183 | func decode(_ codes: MLXArray) -> MLXArray { 184 | var quantized = self.rvqFirst.decode(codes[0..., ..<1]) 185 | if self.nQ > 1 { 186 | quantized = quantized + self.rvqRest.decode(codes[0..., 1...]) 187 | } 188 | return quantized 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /MoshiLib/Perf.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
4 | import MLX 5 | import os.signpost 6 | 7 | public enum EventKind { 8 | case beginStep 9 | case endStep 10 | case beginDepformer 11 | case endDepformer 12 | case beginDecode 13 | case endDecode 14 | case beginEncode 15 | case endEncode 16 | } 17 | 18 | public protocol Callbacks { 19 | func onReset() 20 | func onEvent(_ eventKind: EventKind) 21 | func onInputAudioTokens(_ codes: MLXArray) 22 | func onOutputTextToken(_ token: Int) 23 | func onOutputAudioTokens(_ codes: MLXArray) 24 | } 25 | 26 | public class EmptyCallbacks: Callbacks { 27 | public init() {} 28 | public func onReset() {} 29 | public func onEvent(_ eventKind: EventKind) {} 30 | public func onInputAudioTokens(_ codes: MLXArray) {} 31 | public func onOutputTextToken(_ token: Int) {} 32 | public func onOutputAudioTokens(_ codes: MLXArray) {} 33 | } 34 | 35 | public struct ChromeTraceEvent: Codable { 36 | let name: String 37 | let cat: String 38 | let ph: String 39 | let ts: Int 40 | let pid: Int 41 | let tid: Int 42 | } 43 | 44 | public struct StatsSummary { 45 | public struct Stats: Identifiable { 46 | public var min: Float = Float.infinity 47 | public var max: Float = 0.0 48 | public var sum: Float = 0.0 49 | public var cnt: Int = 0 50 | public let id = UUID() 51 | 52 | mutating func addValue(_ b: CFAbsoluteTime, _ e: CFAbsoluteTime) { 53 | let v = Float(e - b) 54 | self.min = Float.minimum(self.min, v) 55 | self.max = Float.maximum(self.max, v) 56 | self.sum += v 57 | self.cnt += 1 58 | } 59 | } 60 | public var encode: Stats = Stats() 61 | public var decode: Stats = Stats() 62 | public var step: Stats = Stats() 63 | public var depformer: Stats = Stats() 64 | 65 | public init() { 66 | } 67 | } 68 | 69 | public class PerfStats: Callbacks { 70 | private let log: OSLog 71 | private var events: [(CFAbsoluteTime, EventKind)] = [] 72 | private var inputAudioTokens: [MLXArray] = [] 73 | private var outputAudioTokens: [MLXArray] = [] 74 | private var textTokens: [Int] = [] 75 | 76 | public init() { 77 | 
self.log = OSLog(subsystem: "org.kyutai.moshi", category: "Performance") 78 | } 79 | 80 | func append(_ kind: EventKind) { 81 | events.append((CFAbsoluteTimeGetCurrent(), kind)) 82 | } 83 | 84 | public func onReset() { 85 | events.removeAll() 86 | inputAudioTokens.removeAll() 87 | outputAudioTokens.removeAll() 88 | textTokens.removeAll() 89 | } 90 | 91 | public func onInputAudioTokens(_ codes: MLXArray) { 92 | codes.eval() 93 | inputAudioTokens.append(codes) 94 | } 95 | 96 | public func onOutputTextToken(_ token: Int) { 97 | textTokens.append(token) 98 | } 99 | 100 | public func onOutputAudioTokens(_ codes: MLXArray) { 101 | codes.eval() 102 | outputAudioTokens.append(codes) 103 | } 104 | 105 | public func getSummary(maxEvents: Int) -> StatsSummary { 106 | let startIdx = max(0, self.events.count - maxEvents) 107 | var summary = StatsSummary() 108 | var lastBeginEncode: CFAbsoluteTime? = nil 109 | var lastBeginDecode: CFAbsoluteTime? = nil 110 | var lastBeginStep: CFAbsoluteTime? = nil 111 | var lastBeginDepformer: CFAbsoluteTime? = nil 112 | for event in self.events[startIdx...] 
{ 113 | let time = event.0 114 | switch event.1 { 115 | case .beginStep: 116 | lastBeginStep = time 117 | case .beginDecode: 118 | lastBeginDecode = time 119 | case .beginEncode: 120 | lastBeginEncode = time 121 | case .beginDepformer: 122 | lastBeginDepformer = time 123 | case .endStep: 124 | if let b = lastBeginStep { 125 | lastBeginStep = nil 126 | summary.step.addValue(b, time) 127 | } 128 | case .endDecode: 129 | if let b = lastBeginDecode { 130 | lastBeginDecode = nil 131 | summary.decode.addValue(b, time) 132 | } 133 | case .endEncode: 134 | if let b = lastBeginEncode { 135 | lastBeginEncode = nil 136 | summary.encode.addValue(b, time) 137 | } 138 | case .endDepformer: 139 | if let b = lastBeginDepformer { 140 | lastBeginDepformer = nil 141 | summary.depformer.addValue(b, time) 142 | } 143 | } 144 | } 145 | return summary 146 | } 147 | 148 | public func onEvent(_ kind: EventKind) { 149 | switch kind { 150 | case .beginStep: 151 | os_signpost(.begin, log: log, name: "step") 152 | case .endStep: 153 | os_signpost(.end, log: log, name: "step") 154 | case .beginDepformer: 155 | os_signpost(.begin, log: log, name: "depformer") 156 | case .endDepformer: 157 | os_signpost(.end, log: log, name: "depformer") 158 | case .beginEncode: 159 | os_signpost(.begin, log: log, name: "encode") 160 | case .endEncode: 161 | os_signpost(.end, log: log, name: "encode") 162 | case .beginDecode: 163 | os_signpost(.begin, log: log, name: "decode") 164 | case .endDecode: 165 | os_signpost(.end, log: log, name: "decode") 166 | } 167 | append(kind) 168 | } 169 | 170 | public func writeJSONTrace(url: URL) throws { 171 | let encoder = JSONEncoder() 172 | var traceEvents: [ChromeTraceEvent] = [] 173 | for (time, kind) in events { 174 | let ts = Int((time - events[0].0) * 1e6) 175 | let (name, ph) = 176 | switch kind { 177 | case .beginStep: ("step", "B") 178 | case .endStep: ("step", "E") 179 | case .beginEncode: ("encode", "B") 180 | case .endEncode: ("encode", "E") 181 | case 
.beginDepformer: ("depformer", "B") 182 | case .endDepformer: ("depformer", "E") 183 | case .beginDecode: ("decode", "B") 184 | case .endDecode: ("decode", "E") 185 | } 186 | traceEvents.append( 187 | ChromeTraceEvent(name: name, cat: "", ph: ph, ts: ts, pid: 42, tid: 1)) 188 | } 189 | let jsonData = try encoder.encode(traceEvents) 190 | try jsonData.write(to: url) 191 | } 192 | 193 | public func allInputAudioTokens() -> MLXArray? { 194 | inputAudioTokens.isEmpty ? nil : concatenated(inputAudioTokens, axis: 2) 195 | } 196 | 197 | public func allOutputAudioTokens() -> MLXArray? { 198 | outputAudioTokens.isEmpty ? nil : concatenated(outputAudioTokens, axis: 2) 199 | } 200 | 201 | public func allTextTokens() -> MLXArray? { 202 | textTokens.isEmpty ? nil : MLXArray(textTokens) 203 | } 204 | 205 | public func writeCodes(url: URL) throws { 206 | var arrays: [String: MLXArray] = [:] 207 | if let a = allInputAudioTokens() { 208 | arrays["input_audio_tokens"] = a 209 | } 210 | if let a = allOutputAudioTokens() { 211 | arrays["output_audio_tokens"] = a 212 | } 213 | if let a = allTextTokens() { 214 | arrays["text_tokens"] = a 215 | } 216 | try save(arrays: arrays, url: url) 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /Moshi/AudioRT.swift: -------------------------------------------------------------------------------- 1 | import AVFoundation 2 | import Foundation 3 | import Synchronization 4 | 5 | class ThreadSafeChannel { 6 | private var buffer: [T] = [] 7 | private let queue = DispatchQueue(label: "tschannel", attributes: .concurrent) 8 | private let semaphore = DispatchSemaphore(value: 0) 9 | 10 | func send(_ value: T) { 11 | queue.async(flags: .barrier) { 12 | self.buffer.append(value) 13 | self.semaphore.signal() 14 | } 15 | } 16 | 17 | func receive() -> T? 
{ 18 | semaphore.wait() 19 | return queue.sync { 20 | guard !buffer.isEmpty else { return nil } 21 | return buffer.removeFirst() 22 | } 23 | } 24 | } 25 | 26 | // The code below is probably macos specific and unlikely to work on ios. 27 | class MicrophoneCapture { 28 | private let audioEngine: AVAudioEngine 29 | private let channel: ThreadSafeChannel<[Float]> 30 | private let bufferedLen: Atomic = .init(0) 31 | 32 | init() { 33 | audioEngine = AVAudioEngine() 34 | channel = ThreadSafeChannel() 35 | } 36 | 37 | func startCapturing() { 38 | let inputNode = audioEngine.inputNode 39 | // Setting the voice mode on macos causes weird hangs of the microphone 40 | // so we discard it for now. 41 | #if os(iOS) 42 | do { 43 | // try inputNode.setVoiceProcessingEnabled(true) 44 | } catch { 45 | print("could not set voice processing on the input node") 46 | } 47 | #endif 48 | 49 | // Desired format: 1 channel (mono), 24kHz, Float32 50 | let desiredSampleRate: Double = 24000.0 51 | let desiredChannelCount: AVAudioChannelCount = 1 52 | 53 | let inputFormat = inputNode.inputFormat(forBus: 0) 54 | // Create a custom audio format with the desired settings 55 | guard 56 | let mono24kHzFormat = AVAudioFormat( 57 | commonFormat: .pcmFormatFloat32, 58 | sampleRate: desiredSampleRate, 59 | channels: desiredChannelCount, 60 | interleaved: false) 61 | else { 62 | print("Could not create target format") 63 | return 64 | } 65 | 66 | // Resample the buffer to match the desired format 67 | let converter = AVAudioConverter(from: inputFormat, to: mono24kHzFormat) 68 | 69 | // Install a tap to capture audio and resample to the target format 70 | inputNode.installTap(onBus: 0, bufferSize: 1920, format: inputFormat) { buffer, _ in 71 | let targetLen = Int(buffer.frameLength) * 24000 / Int(inputFormat.sampleRate) 72 | let convertedBuffer = AVAudioPCMBuffer( 73 | pcmFormat: mono24kHzFormat, frameCapacity: AVAudioFrameCount(targetLen))! 74 | var error: NSError? 
= nil 75 | let inputBlock: AVAudioConverterInputBlock = { inNumPackets, outStatus in 76 | outStatus.pointee = .haveData 77 | return buffer 78 | } 79 | 80 | converter?.convert(to: convertedBuffer, error: &error, withInputFrom: inputBlock) 81 | 82 | if let error = error { 83 | print("Conversion error: \(error)") 84 | return 85 | } 86 | 87 | self.processAudioBuffer(buffer: convertedBuffer) 88 | } 89 | 90 | // Start the audio engine 91 | do { 92 | audioEngine.prepare() 93 | try audioEngine.start() 94 | print("Microphone capturing started at 24kHz, mono") 95 | } catch { 96 | print("Error starting audio engine: \(error)") 97 | } 98 | } 99 | 100 | private func processAudioBuffer(buffer: AVAudioPCMBuffer) { 101 | guard let channelData = buffer.floatChannelData else { return } 102 | let frameCount = Int(buffer.frameLength) 103 | 104 | let pcmData = Array(UnsafeBufferPointer(start: channelData[0], count: frameCount)).map { 105 | $0 106 | } 107 | bufferedLen.add(pcmData.count, ordering: .sequentiallyConsistent) 108 | channel.send(pcmData) 109 | } 110 | 111 | func stopCapturing() { 112 | audioEngine.stop() 113 | audioEngine.inputNode.removeTap(onBus: 0) 114 | print("Microphone capturing stopped") 115 | } 116 | 117 | func receive() -> [Float]? 
{ 118 | let data = channel.receive() 119 | if let data = data { 120 | bufferedLen.subtract(data.count, ordering: .sequentiallyConsistent) 121 | } 122 | return data 123 | } 124 | 125 | func bufferedDuration() -> Double { 126 | return Double(bufferedLen.load(ordering: .sequentiallyConsistent)) / 24000.0 127 | } 128 | } 129 | 130 | class FloatRingBuffer { 131 | private var buffer: [Float] 132 | private let capacity: Int 133 | private var readIndex = 0 134 | private var writeIndex = 0 135 | private var count = 0 136 | 137 | private let lock = NSLock() 138 | 139 | init(capacity: Int) { 140 | self.capacity = capacity 141 | self.buffer = [Float](repeating: 0, count: capacity) 142 | } 143 | 144 | func currentCount() -> Int { 145 | lock.lock() 146 | defer { lock.unlock() } 147 | return count 148 | } 149 | 150 | func write(_ values: [Float]) -> Bool { 151 | lock.lock() 152 | defer { lock.unlock() } 153 | 154 | if values.count + count > capacity { 155 | return false 156 | } 157 | for value in values { 158 | buffer[writeIndex] = value 159 | writeIndex = (writeIndex + 1) % capacity 160 | count += 1 161 | } 162 | return true 163 | } 164 | 165 | func read(maxCount: Int) -> [Float] { 166 | lock.lock() 167 | defer { lock.unlock() } 168 | 169 | var values: [Float] = [] 170 | for _ in 0.. Double { 195 | Double(ringBuffer.currentCount()) / sampleRate 196 | } 197 | 198 | func startPlaying() throws { 199 | let audioFormat = AVAudioFormat(standardFormatWithSampleRate: self.sampleRate, channels: 1)! 200 | let sourceNode = AVAudioSourceNode(format: audioFormat) { 201 | _, _, frameCount, audioBufferList -> OSStatus in 202 | let audioBuffers = UnsafeMutableAudioBufferListPointer(audioBufferList) 203 | guard let channelData = audioBuffers[0].mData?.assumingMemoryBound(to: Float.self) 204 | else { 205 | // TODO: Get a proper error here that would work on ios. 206 | return 1 207 | } 208 | let data = self.ringBuffer.read(maxCount: Int(frameCount)) 209 | for i in 0.. 
Bool { 223 | ringBuffer.write(values)
224 | }
225 | }
226 |
    // iOS only: configure the shared AVAudioSession for play-and-record and
    // keep the system-default output route (AirPlay / Bluetooth permitted).
    // Compiled out (no-op) on macOS, where AVAudioSession is unavailable.
227 | func setDefaultToStd() {
228 |     #if os(iOS)
229 |         do {
230 |             let audioSession = AVAudioSession.sharedInstance()
231 |             try audioSession.setCategory(.playAndRecord, mode: .default, options: [.allowAirPlay, .allowBluetooth])
232 |             try audioSession.setActive(true)
                // .none removes any previous port override, restoring the default route.
233 |             try audioSession.overrideOutputAudioPort(.none)
234 |         } catch {
                // Best-effort: a failed session config is logged, not fatal.
235 |             print("failed to configure audio session: \(error.localizedDescription)")
236 |         }
237 |     #endif
238 | }
239 |
    // iOS only: force audio output to the built-in speaker (play-and-record
    // category). Compiled out (no-op) on macOS.
240 | func setDefaultToSpeaker() {
241 |     #if os(iOS)
242 |         do {
243 |             let audioSession = AVAudioSession.sharedInstance()
244 |             try audioSession.setCategory(.playAndRecord, mode: .default, options: .defaultToSpeaker)
245 |             try audioSession.setActive(true)
246 |             try audioSession.overrideOutputAudioPort(.speaker)
247 |         } catch {
                // Best-effort: a failed session config is logged, not fatal.
248 |             print("failed to configure audio session: \(error.localizedDescription)")
249 |         }
250 |     #endif
251 | }
252 |
--------------------------------------------------------------------------------
/MoshiCLI/RunMimi.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 | 5 | import AVFoundation 6 | import Foundation 7 | import Hub 8 | import MLX 9 | import MLXNN 10 | import MoshiLib 11 | 12 | func makeMimi(numCodebooks: Int) throws -> Mimi { 13 | let cfg = MimiConfig.mimi_2024_07(numCodebooks: numCodebooks) 14 | let model = Mimi(cfg, bSize: 1) 15 | 16 | let url = try downloadFromHub( 17 | id: "lmz/moshi-swift", 18 | filename: "tokenizer-dbaa9758-checkpoint125.safetensors") 19 | let origWeights = try loadArrays(url: url) 20 | var weights: [String: MLXArray] = [:] 21 | for (var key, var weight) in origWeights { 22 | // Mutating the keys while iterating over the map seems pretty dodgy, not sure what the idiomatic 23 | // way to do this is in swift. Hopefully this is copy on write and it's all good :) 24 | if key.hasPrefix("encoder.model") { 25 | key.replace("encoder.model.", with: "encoder.") 26 | } 27 | if key.hasPrefix("decoder.model") { 28 | key.replace("decoder.model.", with: "decoder.") 29 | } 30 | if key.hasSuffix(".in_proj_weight") { 31 | key.replace(".in_proj_weight", with: ".in_proj.weight") 32 | } 33 | if key.hasSuffix(".linear1.weight") { 34 | key.replace(".linear1.weight", with: ".gating.linear1.weight") 35 | } 36 | if key.hasSuffix(".linear2.weight") { 37 | key.replace(".linear2.weight", with: ".gating.linear2.weight") 38 | } 39 | // Awfully hardcoded matching between the pytorch layers and their mlx equivalent :( 40 | for (layerIdx, decoderIdx) in [2, 5, 8, 11].enumerated() { 41 | key.replace("decoder.\(decoderIdx).", with: "decoder.layers.\(layerIdx).upsample.") 42 | key.replace( 43 | "decoder.\(decoderIdx + 1).", with: "decoder.layers.\(layerIdx).residuals.0.") 44 | } 45 | for (layerIdx, encoderIdx) in [1, 4, 7, 10].enumerated() { 46 | key.replace("encoder.\(encoderIdx).", with: "encoder.layers.\(layerIdx).residuals.0.") 47 | key.replace( 48 | "encoder.\(encoderIdx + 2).", with: "encoder.layers.\(layerIdx).downsample.") 49 | } 50 | key.replace("decoder.0.", with: "decoder.init_conv1d.") 51 | 
key.replace("decoder.14.", with: "decoder.final_conv1d.") 52 | key.replace("encoder.0.", with: "encoder.init_conv1d.") 53 | key.replace("encoder.14.", with: "encoder.final_conv1d.") 54 | key.replace(".block.1.", with: ".block.0.") 55 | key.replace(".block.3.", with: ".block.1.") 56 | // PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC 57 | if key.hasSuffix(".conv.weight") || key.hasSuffix(".output_proj.weight") 58 | || key.hasSuffix(".input_proj.weight") 59 | { 60 | weight = weight.swappedAxes(-1, -2) 61 | } 62 | // PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC 63 | if key.hasSuffix(".convtr.weight") { 64 | weight = weight.transposed(axes: [1, 2, 0]) 65 | } 66 | 67 | // print(key, weight.shape) 68 | weights[key] = weight 69 | } 70 | let parameters = ModuleParameters.unflattened(weights) 71 | try model.update(parameters: parameters, verify: [.all]) 72 | return model 73 | } 74 | 75 | func runMimi(streaming: Bool, audioFile: URL?, channel: Int = 0) throws { 76 | let model = try makeMimi(numCodebooks: 16) 77 | print("using device \(Device.defaultDevice().description)") 78 | let sampleURL = 79 | switch audioFile { 80 | case .none: try downloadFromHub(id: "lmz/moshi-swift", filename: "bria-24khz.mp3") 81 | case .some(let url): url 82 | } 83 | let pcm = readAudioToPCMArray(fileURL: sampleURL, channel: channel)! 84 | 85 | if streaming { 86 | let chunkSize = 1920 87 | var pcmOuts: [[Float]] = [] 88 | var elapsedTimes: [Double] = [] 89 | var nSteps = 0 90 | for start in stride(from: 0, to: pcm.count, by: chunkSize) { 91 | let startTime = CFAbsoluteTimeGetCurrent() 92 | let pct = 100 * start / pcm.count 93 | let end = min(start + chunkSize, pcm.count) 94 | let pcmA = MLXArray(pcm[start.. 
125 { 200 | break 201 | } 202 | } 203 | } 204 | try save( 205 | arrays: ["codes": concatenated(allCodes, axis: -1)], 206 | url: URL(fileURLWithPath: "mimi-codes.safetensors")) 207 | microphoneCapture.stopCapturing() 208 | 209 | } 210 | -------------------------------------------------------------------------------- /MoshiCLI/RunMoshi.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import AVFoundation 6 | import Foundation 7 | import MLX 8 | import MLXNN 9 | import MoshiLib 10 | 11 | func makeMoshi(_ url: URL, _ cfg: LmConfig) throws -> LM { 12 | let weights = try loadArrays(url: url) 13 | let parameters = ModuleParameters.unflattened(weights) 14 | let model = LM(cfg, bSize: 1) 15 | if url.lastPathComponent.hasSuffix(".q4.safetensors") { 16 | quantize(model: model, groupSize: 32, bits: 4) 17 | } else if url.lastPathComponent.hasSuffix(".q6.safetensors") { 18 | quantize(model: model, groupSize: 64, bits: 6) 19 | } else if url.lastPathComponent.hasSuffix(".q8.safetensors") { 20 | quantize(model: model, groupSize: 64, bits: 8) 21 | } 22 | try model.update(parameters: parameters, verify: [.all]) 23 | eval(model) 24 | return model 25 | } 26 | 27 | func loadVocab(_ cfg: LmConfig) throws -> [Int: String] { 28 | let filename = 29 | switch cfg.textOutVocabSize { 30 | case 48000: "tokenizer_spm_48k_multi6_2.json" 31 | case 32000: "tokenizer_spm_32k_3.json" 32 | case 8000: "tokenizer_spm_8k_0.json" 33 | case 4000: "test_en_audio_4000.json" 34 | case let other: fatalError("unexpected text vocab size \(other)") 35 | } 36 | let fileURL = try downloadFromHub(id: "lmz/moshi-swift", filename: filename) 37 | let jsonData = try Data(contentsOf: fileURL) 38 | let dictionary = try JSONDecoder().decode([Int: String].self, from: jsonData) 39 | return dictionary 40 
| } 41 | 42 | func runMoshiMic(_ url: URL, cfg: LmConfig) throws { 43 | let mimi = try makeMimi(numCodebooks: 16) 44 | let moshi = try makeMoshi(url, cfg) 45 | let vocab = try loadVocab(cfg) 46 | print("using device \(Device.defaultDevice().description)") 47 | print("warming up mimi") 48 | mimi.warmup() 49 | print("warming up moshi") 50 | moshi.warmup() 51 | print("done warming up") 52 | 53 | let maxSteps = moshi.cfg.transformer.maxSeqLen 54 | let gen = LMGen(moshi, maxSteps: maxSteps, audioSampler: Sampler(), textSampler: Sampler()) 55 | 56 | let microphoneCapture = MicrophoneCapture() 57 | microphoneCapture.startCapturing() 58 | let player = AudioPlayer(sampleRate: 24000) 59 | try player.startPlaying() 60 | print("started the audio loops") 61 | 62 | while let pcm = microphoneCapture.receive() { 63 | let pcm = MLXArray(pcm)[.newAxis, .newAxis] 64 | let codes = mimi.encodeStep(StreamArray(pcm)) 65 | if let codes = codes.asArray() { 66 | let (_, _, steps) = codes.shape3 67 | for step in 0.. URL { 15 | let targetURL = HubApi().localRepoLocation(Hub.Repo(id: id)).appending(path: filename) 16 | if FileManager.default.fileExists(atPath: targetURL.path) { 17 | print("using cached file \(targetURL.path)") 18 | return targetURL 19 | } 20 | var url: URL? 
= nil 21 | let semaphore = DispatchSemaphore(value: 0) 22 | Task { 23 | let repo = Hub.Repo(id: id) 24 | do { 25 | url = try await Hub.snapshot(from: repo, matching: filename) { progress in 26 | let pct = Int(progress.fractionCompleted * 100) 27 | print("\rretrieving \(filename): \(pct)%", terminator: "") 28 | } 29 | } catch { 30 | fatalError("cannot fetch \(id) \(filename): \(error)") 31 | } 32 | semaphore.signal() 33 | } 34 | semaphore.wait() 35 | print("\rretrieved \(filename)") 36 | return url!.appending(path: filename) 37 | } 38 | 39 | func maybeDownloadFromHub(filename: String) throws -> URL { 40 | // dowloadFromHub(id: id, filename: filename) 41 | let prefix = "hf://" 42 | if filename.hasPrefix(prefix) { 43 | let rest = filename.dropFirst(prefix.count) 44 | let components = rest.split(separator: "/", omittingEmptySubsequences: false) 45 | let id = components[0.. any Tokenizer { 54 | var tokenizer: (any Tokenizer)? = nil 55 | let semaphore = DispatchSemaphore(value: 0) 56 | Task { 57 | do { 58 | tokenizer = try await AutoTokenizer.from(pretrained: hfRepo) 59 | } catch { 60 | fatalError("cannot build tokenizer \(error)") 61 | } 62 | semaphore.signal() 63 | } 64 | semaphore.wait() 65 | return tokenizer! 66 | } 67 | 68 | @main 69 | struct Moshi: ParsableCommand { 70 | static let configuration = CommandConfiguration( 71 | subcommands: [ 72 | Run.self, RunHelium.self, RunMimi.self, AudioToCodes.self, CodesToAudio.self, 73 | RunAsr.self, RunQwen.self, 74 | ] 75 | ) 76 | } 77 | 78 | public enum Config: String, CaseIterable, ExpressibleByArgument { 79 | case moshi1b 80 | case moshi7b 81 | } 82 | 83 | struct Run: ParsableCommand { 84 | @Argument(help: "the model to run") 85 | var model: String 86 | 87 | @Option(help: "the file to process, use 'mic' for the microphone") 88 | var input: String? 
89 | 90 | @Option(help: "the audio delay to apply") 91 | var audioDelay: Int = 2 92 | 93 | @Option(help: "the audio channel from the input file to be used") 94 | var channel: Int = 0 95 | 96 | @Option(help: "the config size") 97 | var config: Config = .moshi1b 98 | 99 | mutating func run() throws { 100 | let model = URL(fileURLWithPath: model) 101 | let cfg = 102 | switch config { 103 | case .moshi1b: LmConfig.moshi1b(audioDelay: audioDelay) 104 | case .moshi7b: LmConfig.moshi_2024_07() 105 | } 106 | 107 | switch input { 108 | case .none: 109 | try runMoshi(model, cfg: cfg, audioFile: nil) 110 | case .some("mic"): try runMoshiMic(model, cfg: cfg) 111 | case .some(let input): 112 | let audioFile = URL(fileURLWithPath: input) 113 | try runMoshi( 114 | model, cfg: cfg, audioFile: audioFile, channel: channel) 115 | } 116 | } 117 | } 118 | 119 | struct RunMimi: ParsableCommand { 120 | @Option(help: "whether to use the streaming mode or not") 121 | var streaming: Bool = false 122 | 123 | @Option(help: "the file to process") 124 | var input: String? 125 | 126 | @Option(help: "the audio channel from the input file to be used") 127 | var channel: Int = 0 128 | 129 | mutating func run() throws { 130 | let audioFile = input.flatMap { URL(fileURLWithPath: $0) } 131 | try runMimi(streaming: streaming, audioFile: audioFile, channel: channel) 132 | } 133 | } 134 | 135 | public enum HeliumConfig: String, CaseIterable, ExpressibleByArgument { 136 | case q4 137 | case q6 138 | case q8 139 | case bf16 140 | } 141 | 142 | struct RunQwen: ParsableCommand { 143 | @Option(help: "the config") 144 | var hfRepo: String = "Qwen/Qwen2.5-0.5B-Instruct" 145 | 146 | @Option(help: "the prompt to be used") 147 | var prompt: String = "Describe the swift programming language." 
148 | 149 | @Option(help: "the number of tokens to generate") 150 | var n: Int = 256 151 | 152 | mutating func run() throws { 153 | let tokenizer = try makeTokenizer(hfRepo: hfRepo) 154 | let messages = [["role": "user", "content": prompt]] 155 | let encodedPrompt = try tokenizer.applyChatTemplate(messages: messages) 156 | let configUrl = try downloadFromHub(id: hfRepo, filename: "config.json") 157 | let configData = try Data(contentsOf: configUrl) 158 | let decoder = JSONDecoder() 159 | decoder.keyDecodingStrategy = .convertFromSnakeCase 160 | let config = try decoder.decode(QwenConfig.self, from: configData) 161 | print("config \(config)") 162 | let modelUrl = try downloadFromHub(id: hfRepo, filename: "model.safetensors") 163 | print("model \(modelUrl)") 164 | let weights = try loadArrays(url: modelUrl) 165 | guard let modelItem = ModuleParameters.unflattened(weights)["model"] else { 166 | fatalError("no model key in {configUrl}") 167 | } 168 | let parameters = 169 | switch modelItem { 170 | case .dictionary(let d): NestedDictionary(values: d) 171 | default: fatalError("model key in {configUrl} is not a dict") 172 | } 173 | 174 | let model = QwenModel(config) 175 | if let q = config.quantization { 176 | quantize(model: model, groupSize: q.groupSize, bits: q.bits) 177 | } 178 | try model.update(parameters: parameters, verify: [.all]) 179 | eval(model) 180 | let cache = model.makeCache(bSize: 1) 181 | let sampler = Sampler() 182 | var lastToken = config.bosTokenId 183 | let startTime = CFAbsoluteTimeGetCurrent() 184 | var nTokens = 0 185 | for index in 0...(n + prompt.count) { 186 | let logits = model(MLXArray([lastToken]).reshaped(1, 1), cache: cache) 187 | if index < encodedPrompt.count { 188 | lastToken = encodedPrompt[index] 189 | } else { 190 | let (tok, _) = sampler(logits: logits[0]) 191 | lastToken = tok.item() 192 | } 193 | let s = tokenizer.decode(tokens: [lastToken]) 194 | print("\(s)", terminator: "") 195 | fflush(stdout) 196 | nTokens += 1 197 | } 198 
| print()
199 |         let elapsedTime = CFAbsoluteTimeGetCurrent() - startTime
200 |         print("\(nTokens) tokens generated, \(Double(nTokens) / elapsedTime) tok/s")
201 |     }
202 | }
203 |
    // CLI subcommand: run the Helium-1 2B preview model. Downloads the
    // checkpoint matching the requested quantization from the Hub and
    // delegates to runHelium.
204 | struct RunHelium: ParsableCommand {
205 |     @Option(help: "the config")
206 |     var config: HeliumConfig = .q4
207 |
208 |     mutating func run() throws {
209 |         let cfg = LmConfig.helium2b()
            // Map the quantization flag to the matching checkpoint filename.
210 |         let filename =
211 |             switch config {
212 |             case .q4: "helium-1-preview-2b-q4.safetensors"
213 |             case .q6: "helium-1-preview-2b-q6.safetensors"
214 |             case .q8: "helium-1-preview-2b-q8.safetensors"
215 |             case .bf16: "helium-1-preview-2b-bf16.safetensors"
216 |             }
217 |         let url = try downloadFromHub(id: "kyutai/helium-1-preview-2b-mlx", filename: filename)
218 |         try runHelium(url, cfg: cfg)
219 |     }
220 | }
221 |
    // CLI subcommand: thin wrapper around runAudioToCodes (defined in RunMimi.swift).
222 | struct AudioToCodes: ParsableCommand {
223 |     mutating func run() throws {
224 |         try runAudioToCodes()
225 |     }
226 | }
227 |
    // CLI subcommand: thin wrapper around runCodesToAudio; --write-file
    // controls whether the decoded audio is written to disk.
228 | struct CodesToAudio: ParsableCommand {
229 |     @Option(help: "whether to write the output file or not")
230 |     var writeFile: Bool = false
231 |
232 |     mutating func run() throws {
233 |         try runCodesToAudio(writeFile: writeFile)
234 |     }
235 | }
236 |
237 | struct RunAsr: ParsableCommand {
238 |     @Argument(help: "the model to run")
239 |     var model: String
240 |
241 |     @Option(help: "the file to process, use 'mic' for the microphone")
242 |     var input: String?
243 |
244 |     @Option(help: "the audio channel from the input file to be used")
245 |     var channel: Int = 0
246 |
        // Resolves the checkpoint (local path or hf:// URI), infers the model
        // size from the out_norm.weight shape, then transcribes either an
        // audio file or the live microphone input.
247 |     mutating func run() throws {
248 |         let model = try maybeDownloadFromHub(filename: model)
249 |         let weights = try loadArrays(url: model)
            // The final-norm hidden dim uniquely identifies the checkpoint size.
250 |         let cfg =
251 |             switch weights["out_norm.weight"]?.shape {
252 |             case .none: fatalError("no out_norm.weight tensor in \(model)")
253 |             case .some([1024]): LmConfig.asr300m()
254 |             case .some([2048]): LmConfig.asr1b()
255 |             case .some([2560]): LmConfig.asr2b()
256 |             case .some(let s): fatalError("unexpected shape for out_norm.weight \(s)")
257 |             }
259 |         switch input {
260 |         case .none:
261 |             try runAsr(model, cfg, audioFile: nil, channel: channel)
262 |         case .some("mic"): try runAsrMic(model, cfg)
263 |         case .some(let input):
264 |             let audioFile = URL(fileURLWithPath: input)
265 |             try runAsr(model, cfg, audioFile: audioFile, channel: channel)
266 |         }
267 |     }
268 | }
269 |
--------------------------------------------------------------------------------
/MoshiLib/Seanet.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | import MLXRandom 9 | 10 | public struct SeanetConfig { 11 | public var dimension: Int 12 | public var channels: Int 13 | public var causal: Bool 14 | public var nFilters: Int 15 | public var nResidualLayers: Int 16 | public var ratios: [Int] 17 | // public var activation: String: hardcoded to Elu(1) for now 18 | // public var norm: Norm 19 | public var kernelSize: Int 20 | public var residualKernelSize: Int 21 | public var lastKernelSize: Int 22 | public var dilationBase: Int 23 | public var padMode: PadMode 24 | public var trueSkip: Bool 25 | public var compress: Int 26 | // public var disableNormOuterBlocks: Int 27 | // public var finalActivation: String?: hardcoded to None for now 28 | 29 | public static func v0_1() -> SeanetConfig { 30 | SeanetConfig( 31 | dimension: 512, channels: 1, causal: true, nFilters: 64, nResidualLayers: 1, 32 | ratios: [8, 6, 5, 4], kernelSize: 7, residualKernelSize: 3, lastKernelSize: 3, 33 | dilationBase: 2, padMode: .constant, trueSkip: true, compress: 2) 34 | } 35 | } 36 | 37 | class SeanetResnetBlock: Module, UnaryLayer, StreamingLayer { 38 | let skipOp: StreamingBinOp 39 | @ModuleInfo(key: "block") var block: [StreamableConv1d] 40 | @ModuleInfo(key: "shortcut") var shortcut: StreamableConv1d? 41 | 42 | init(_ cfg: SeanetConfig, dim: Int, kSizesAndDilations: [(Int, Int)]) { 43 | self.skipOp = StreamingBinOp(.add, axis: -1) 44 | var block: [StreamableConv1d] = [] 45 | var shortcut: StreamableConv1d? = nil 46 | let hidden = dim / cfg.compress 47 | 48 | for (i, (kSize, dilation)) in kSizesAndDilations.enumerated() { 49 | let inC = i == 0 ? dim : hidden 50 | let outC = i == kSizesAndDilations.count - 1 ? 
dim : hidden 51 | let c = StreamableConv1d( 52 | inC: inC, outC: outC, kSize: kSize, stride: 1, dilation: dilation, groups: 1, 53 | bias: true, causal: cfg.causal, padMode: cfg.padMode) 54 | block.append(c) 55 | } 56 | if !cfg.trueSkip { 57 | shortcut = StreamableConv1d( 58 | inC: dim, outC: dim, kSize: 1, stride: 1, dilation: 1, groups: 1, bias: true, 59 | causal: cfg.causal, padMode: cfg.padMode) 60 | } 61 | 62 | self._block.wrappedValue = block 63 | self._shortcut.wrappedValue = shortcut 64 | } 65 | 66 | func callAsFunction(_ x: MLXArray) -> MLXArray { 67 | let residual = x 68 | var x = x 69 | for b in self.block { 70 | x = b(elu(x, alpha: 1.0)) 71 | } 72 | if let shortcut = self.shortcut { 73 | return x + shortcut(residual) 74 | } else { 75 | return x + residual 76 | } 77 | } 78 | 79 | func resetState() { 80 | self.skipOp.resetState() 81 | for b in self.block { 82 | b.resetState() 83 | } 84 | if let s = self.shortcut { 85 | s.resetState() 86 | } 87 | } 88 | 89 | func step(_ x: StreamArray) -> StreamArray { 90 | let residual = x 91 | var x = x 92 | for b in self.block { 93 | x = b.step(x.elu()) 94 | } 95 | if let shortcut = self.shortcut { 96 | return self.skipOp.step(x, shortcut.step(residual)) 97 | } else { 98 | return self.skipOp.step(x, residual) 99 | } 100 | } 101 | } 102 | 103 | class EncoderLayer: Module, UnaryLayer, StreamingLayer { 104 | @ModuleInfo(key: "residuals") var residuals: [SeanetResnetBlock] 105 | @ModuleInfo(key: "downsample") var downsample: StreamableConv1d 106 | 107 | init(_ cfg: SeanetConfig, ratio: Int, mult: Int) { 108 | var residuals: [SeanetResnetBlock] = [] 109 | var dilation: Int = 1 110 | for _ in 0.. 
MLXArray { 127 | var x = x 128 | for r in self.residuals { 129 | x = r(x) 130 | } 131 | let y = self.downsample(elu(x, alpha: 1.0)) 132 | return y 133 | } 134 | 135 | func resetState() { 136 | for r in self.residuals { 137 | r.resetState() 138 | } 139 | self.downsample.resetState() 140 | } 141 | 142 | func step(_ x: StreamArray) -> StreamArray { 143 | var x = x 144 | for r in self.residuals { 145 | x = r.step(x) 146 | } 147 | return self.downsample.step(x.elu()) 148 | } 149 | } 150 | 151 | class SeanetEncoder: Module, StreamingLayer { 152 | @ModuleInfo(key: "init_conv1d") var initConv1d: StreamableConv1d 153 | @ModuleInfo(key: "layers") var layers: [EncoderLayer] 154 | @ModuleInfo(key: "final_conv1d") var finalConv1d: StreamableConv1d 155 | 156 | init(_ cfg: SeanetConfig) { 157 | var mult = 1 158 | let initConv1d = StreamableConv1d( 159 | inC: cfg.channels, outC: mult * cfg.nFilters, kSize: cfg.kernelSize, stride: 1, 160 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, 161 | padMode: cfg.padMode) 162 | var layers: [EncoderLayer] = [] 163 | for ratio in cfg.ratios.reversed() { 164 | let layer = EncoderLayer(cfg, ratio: ratio, mult: mult) 165 | layers.append(layer) 166 | mult *= 2 167 | } 168 | let finalConv1d = StreamableConv1d( 169 | inC: mult * cfg.nFilters, outC: cfg.dimension, kSize: cfg.lastKernelSize, stride: 1, 170 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, padMode: cfg.padMode) 171 | self._initConv1d.wrappedValue = initConv1d 172 | self._layers.wrappedValue = layers 173 | self._finalConv1d.wrappedValue = finalConv1d 174 | } 175 | 176 | func callAsFunction(_ x: MLXArray) -> MLXArray { 177 | var x = self.initConv1d(x) 178 | for layer in self.layers { 179 | x = layer(x) 180 | } 181 | return self.finalConv1d(elu(x, alpha: 1.0)) 182 | } 183 | 184 | func resetState() { 185 | self.initConv1d.resetState() 186 | self.finalConv1d.resetState() 187 | for l in self.layers { 188 | l.resetState() 189 | } 190 | } 191 | 192 | func step(_ x: 
StreamArray) -> StreamArray { 193 | var x = self.initConv1d.step(x) 194 | for layer in self.layers { 195 | x = layer.step(x) 196 | } 197 | return self.finalConv1d.step(x.elu()) 198 | } 199 | } 200 | 201 | class DecoderLayer: Module, UnaryLayer, StreamingLayer { 202 | @ModuleInfo(key: "upsample") var upsample: StreamableConvTranspose1d 203 | @ModuleInfo(key: "residuals") var residuals: [SeanetResnetBlock] 204 | 205 | init(_ cfg: SeanetConfig, ratio: Int, mult: Int) { 206 | var residuals: [SeanetResnetBlock] = [] 207 | var dilation = 1 208 | for _ in 0.. MLXArray { 225 | var x = self.upsample(elu(x, alpha: 1.0)) 226 | for r in self.residuals { 227 | x = r(x) 228 | } 229 | return x 230 | } 231 | 232 | func resetState() { 233 | self.upsample.resetState() 234 | for r in self.residuals { 235 | r.resetState() 236 | } 237 | } 238 | 239 | func step(_ x: StreamArray) -> StreamArray { 240 | var x = self.upsample.step(x.elu()) 241 | for r in self.residuals { 242 | x = r.step(x) 243 | } 244 | return x 245 | } 246 | } 247 | 248 | class SeanetDecoder: Module, StreamingLayer { 249 | @ModuleInfo(key: "init_conv1d") var initConv1d: StreamableConv1d 250 | @ModuleInfo(key: "layers") var layers: [DecoderLayer] 251 | @ModuleInfo(key: "final_conv1d") var finalConv1d: StreamableConv1d 252 | 253 | init(_ cfg: SeanetConfig) { 254 | var layers: [DecoderLayer] = [] 255 | var mult = 1 << cfg.ratios.count 256 | let initConv1d = StreamableConv1d( 257 | inC: cfg.dimension, outC: mult * cfg.nFilters, kSize: cfg.kernelSize, stride: 1, 258 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, padMode: cfg.padMode) 259 | for ratio in cfg.ratios { 260 | let l = DecoderLayer(cfg, ratio: ratio, mult: mult) 261 | layers.append(l) 262 | mult /= 2 263 | } 264 | let finalConv1d = StreamableConv1d( 265 | inC: cfg.nFilters, outC: cfg.channels, kSize: cfg.lastKernelSize, stride: 1, 266 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, padMode: cfg.padMode) 267 | self._initConv1d.wrappedValue = 
initConv1d 268 | self._layers.wrappedValue = layers 269 | self._finalConv1d.wrappedValue = finalConv1d 270 | } 271 | 272 | func callAsFunction(_ x: MLXArray) -> MLXArray { 273 | var x = self.initConv1d(x) 274 | for layer in self.layers { 275 | x = layer(x) 276 | } 277 | return self.finalConv1d(elu(x, alpha: 1.0)) 278 | } 279 | 280 | func resetState() { 281 | self.initConv1d.resetState() 282 | self.finalConv1d.resetState() 283 | for l in self.layers { 284 | l.resetState() 285 | } 286 | } 287 | 288 | func step(_ x: StreamArray) -> StreamArray { 289 | var x = self.initConv1d.step(x) 290 | for layer in self.layers { 291 | x = layer.step(x) 292 | } 293 | return self.finalConv1d.step(x.elu()) 294 | } 295 | } 296 | -------------------------------------------------------------------------------- /MoshiLib/Conv.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | import MLXRandom 9 | 10 | // Conv1d + dilation 11 | class Conv1d: Module, UnaryLayer { 12 | @ModuleInfo(key: "weight") var weight: MLXArray 13 | @ModuleInfo(key: "bias") var bias: MLXArray? 14 | let padding: Int 15 | let groups: Int 16 | let stride: Int 17 | let dilation: Int 18 | 19 | init( 20 | inputChannels: Int, 21 | outputChannels: Int, 22 | kernelSize: Int, 23 | stride: Int = 1, 24 | padding: Int = 0, 25 | groups: Int = 1, 26 | dilation: Int = 1, 27 | bias: Bool = true 28 | ) { 29 | let scale = sqrt(1 / Float(inputChannels * kernelSize)) 30 | 31 | self._weight.wrappedValue = uniform( 32 | low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups]) 33 | self._bias.wrappedValue = bias ? 
MLXArray.zeros([outputChannels]) : nil 34 | self.padding = padding 35 | self.groups = groups 36 | self.stride = stride 37 | self.dilation = dilation 38 | } 39 | 40 | func callAsFunction(_ x: MLXArray) -> MLXArray { 41 | // MLX uses NLC whereas pytorch/candle use NCL 42 | var y = conv1d( 43 | x.swappedAxes(-1, -2), weight, stride: stride, padding: padding, dilation: dilation, 44 | groups: groups 45 | ) 46 | if let bias { 47 | y = y + bias 48 | } 49 | y = y.swappedAxes(-1, -2) 50 | return y 51 | } 52 | } 53 | 54 | // ConvTranspose1d + groups 55 | class ConvTransposed1d: Module, UnaryLayer { 56 | @ModuleInfo(key: "weight") var weight: MLXArray 57 | @ModuleInfo(key: "bias") var bias: MLXArray? 58 | let padding: Int 59 | let stride: Int 60 | let groups: Int 61 | let inC: Int 62 | let outC: Int 63 | let kSize: Int 64 | var expandedWeight: MLXArray 65 | var expandedGroups: Int 66 | 67 | init( 68 | inputChannels: Int, 69 | outputChannels: Int, 70 | kernelSize: Int, 71 | stride: Int = 1, 72 | padding: Int = 0, 73 | groups: Int = 1, 74 | bias: Bool = true 75 | ) { 76 | let scale = sqrt(1 / Float(inputChannels * kernelSize)) 77 | 78 | self._weight.wrappedValue = uniform( 79 | low: -scale, high: scale, [outputChannels / groups, kernelSize, inputChannels]) 80 | let weight = self._weight.wrappedValue 81 | self._bias.wrappedValue = bias ? 
MLXArray.zeros([outputChannels]) : nil 82 | self.padding = padding 83 | self.stride = stride 84 | self.groups = groups 85 | self.inC = inputChannels 86 | self.outC = outputChannels 87 | self.kSize = kernelSize 88 | if groups == inC && groups == outC { 89 | let eye = repeated( 90 | eye(outC).asType(weight.dtype).reshaped([outC, 1, outC]), count: kSize, axis: 1) 91 | self.expandedWeight = repeated(weight, count: groups, axis: 0) * eye 92 | self.expandedGroups = 1 93 | } else if groups > 1 { 94 | fatalError("groups are not supported in ConvTranspose1d, \(groups), \(inC), \(outC)") 95 | } else { 96 | self.expandedWeight = weight 97 | self.expandedGroups = groups 98 | } 99 | } 100 | 101 | override func update(parameters: ModuleParameters, verify: Module.VerifyUpdate) throws -> Self { 102 | try super.update(parameters: parameters, verify: verify) 103 | if groups == inC && groups == outC { 104 | let eye = repeated( 105 | eye(outC).asType(weight.dtype).reshaped([outC, 1, outC]), count: kSize, axis: 1) 106 | self.expandedWeight = repeated(weight, count: groups, axis: 0) * eye 107 | self.expandedGroups = 1 108 | } else if groups > 1 { 109 | fatalError("groups are not supported in ConvTranspose1d, \(groups), \(inC), \(outC)") 110 | } else { 111 | self.expandedWeight = weight 112 | self.expandedGroups = groups 113 | } 114 | return self 115 | } 116 | 117 | open func callAsFunction(_ x: MLXArray) -> MLXArray { 118 | // Groups are not supported in convTransposed1d as of 0.18.1 so we hack our way around it. 
119 | var y = convTransposed1d( 120 | x.swappedAxes(-1, -2), expandedWeight, stride: stride, padding: padding, 121 | groups: expandedGroups 122 | ) 123 | if let bias { 124 | y = y + bias 125 | } 126 | return y.swappedAxes(-1, -2) 127 | } 128 | } 129 | 130 | // weight-norm is handled externally when generating the weights file so this class is just 131 | // a wrapper around Conv1d 132 | class NormConv1d: Module, UnaryLayer { 133 | @ModuleInfo(key: "conv") var conv: Conv1d 134 | 135 | init( 136 | inC: Int, outC: Int, kSize: Int, 137 | stride: Int = 1, 138 | padding: Int = 0, 139 | groups: Int = 1, 140 | dilation: Int = 1, 141 | bias: Bool = true 142 | ) { 143 | self._conv.wrappedValue = Conv1d( 144 | inputChannels: inC, outputChannels: outC, kernelSize: kSize, stride: stride, 145 | padding: padding, 146 | groups: groups, dilation: dilation, bias: bias) 147 | } 148 | 149 | func callAsFunction(_ x: MLXArray) -> MLXArray { 150 | self.conv(x) 151 | } 152 | } 153 | 154 | class NormConvTranspose1d: Module, UnaryLayer { 155 | @ModuleInfo(key: "convtr") var convtr: ConvTransposed1d 156 | 157 | init( 158 | inC: Int, outC: Int, kSize: Int, 159 | stride: Int = 1, 160 | padding: Int = 0, 161 | groups: Int = 1, 162 | bias: Bool = true 163 | ) { 164 | self._convtr.wrappedValue = ConvTransposed1d( 165 | inputChannels: inC, outputChannels: outC, kernelSize: kSize, stride: stride, 166 | padding: padding, groups: groups, bias: bias) 167 | } 168 | 169 | func callAsFunction(_ x: MLXArray) -> MLXArray { 170 | self.convtr(x) 171 | } 172 | } 173 | 174 | func getExtraPaddingForConv1d(_ x: MLXArray, kSize: Int, stride: Int, paddingTotal: Int) -> Int { 175 | let len = x.dim(-1) 176 | let nFrames = Float(max(len + paddingTotal - kSize, 0)) / Float(stride) + 1.0 177 | let idealLen = (Int(nFrames.rounded(.up)) - 1) * stride + kSize - paddingTotal 178 | return max(0, idealLen - len) 179 | } 180 | 181 | func unpad1d(_ x: MLXArray, unpadL: Int, unpadR: Int) -> MLXArray { 182 | let len = x.dim(-1) 
183 | let left = unpadL 184 | let right = len - unpadR 185 | return x[.ellipsis, left.. MLXArray { 209 | var kSize = self.kSize 210 | // Effective kernel size with dilations. 211 | kSize = (kSize - 1) * self.conv.conv.dilation + 1 212 | let paddingTotal = kSize - self.conv.conv.stride 213 | let extraPadding = getExtraPaddingForConv1d( 214 | x, kSize: kSize, stride: self.conv.conv.stride, paddingTotal: paddingTotal) 215 | var pd: MLXArray 216 | let z = IntOrPair.init((0, 0)) 217 | if self.causal { 218 | let widths = [z, z, IntOrPair((paddingTotal, extraPadding))] 219 | pd = padded(x, widths: widths, mode: self.padMode) 220 | } else { 221 | let paddingRight = paddingTotal / 2 222 | let paddingLeft = paddingTotal - paddingRight 223 | let widths = [z, z, IntOrPair((paddingLeft, paddingRight + extraPadding))] 224 | pd = padded(x, widths: widths, mode: self.padMode) 225 | } 226 | let y = self.conv(pd) 227 | return y 228 | } 229 | 230 | func resetState() { 231 | statePrevXs = StreamArray() 232 | self.leftPadApplied = false 233 | } 234 | 235 | func step(_ x: StreamArray) -> StreamArray { 236 | if var inner = x.inner { 237 | let stride = self.conv.conv.stride 238 | let dilation = self.conv.conv.dilation 239 | if !self.leftPadApplied { 240 | self.leftPadApplied = true 241 | let kSize = (self.kSize - 1) * dilation + 1 242 | let paddingTotal = kSize - stride 243 | let z = IntOrPair.init((0, 0)) 244 | let widths = [z, z, IntOrPair((paddingTotal, 0))] 245 | inner = padded(inner, widths: widths, mode: self.padMode) 246 | } 247 | let kernel = (self.kSize - 1) * dilation + 1 248 | var x = StreamArray(inner) 249 | x = self.statePrevXs.cat2(x, axis: -1) 250 | let seqLen = x.dim(-1) 251 | let numFrames = max(seqLen + stride - kernel, 0) / stride 252 | if numFrames > 0 { 253 | let offset = numFrames * stride 254 | self.statePrevXs = x.narrow(offset, seqLen - offset, axis: -1) 255 | let inL = (numFrames - 1) * stride + kernel 256 | x = x.narrow(0, inL, axis: -1) 257 | if let x = x.inner 
{ 258 | return StreamArray(self.conv.conv(x)) 259 | } else { 260 | return StreamArray() 261 | } 262 | } else { 263 | self.statePrevXs = x 264 | return StreamArray() 265 | } 266 | } else { 267 | return StreamArray() 268 | } 269 | } 270 | } 271 | 272 | class StreamableConvTranspose1d: Module, UnaryLayer, StreamingLayer { 273 | let causal: Bool 274 | let kSize: Int 275 | var statePrevYs: StreamArray = StreamArray() 276 | @ModuleInfo(key: "convtr") var convtr: NormConvTranspose1d 277 | 278 | init( 279 | inC: Int, outC: Int, kSize: Int, stride: Int, groups: Int, bias: Bool, 280 | causal: Bool 281 | ) { 282 | self.causal = causal 283 | self.kSize = kSize 284 | self._convtr.wrappedValue = NormConvTranspose1d( 285 | inC: inC, outC: outC, kSize: kSize, stride: stride, groups: groups, bias: bias) 286 | } 287 | 288 | func callAsFunction(_ x: MLXArray) -> MLXArray { 289 | let stride = self.convtr.convtr.stride 290 | let paddingTotal = max(self.kSize - stride, 0) 291 | let x = self.convtr(x) 292 | if self.causal { 293 | return unpad1d(x, unpadL: 0, unpadR: paddingTotal) 294 | } else { 295 | let unpadR = paddingTotal / 2 296 | let unpadL = paddingTotal - unpadR 297 | return unpad1d(x, unpadL: unpadL, unpadR: unpadR) 298 | } 299 | } 300 | 301 | func resetState() { 302 | statePrevYs = StreamArray() 303 | } 304 | 305 | func step(_ x: StreamArray) -> StreamArray { 306 | if let x = x.inner { 307 | let stride = self.convtr.convtr.stride 308 | var ys = self.convtr(x) 309 | let ot = ys.dim(-1) 310 | if var prevYs = self.statePrevYs.inner { 311 | let pt = prevYs.dim(-1) 312 | if let bias = self.convtr.convtr.bias { 313 | prevYs = prevYs - bias[.newAxis, 0..., .newAxis] 314 | } 315 | let ys1 = ys[.ellipsis, 0.. 
MLXArray { 339 | self.conv(x) 340 | } 341 | 342 | func resetState() { 343 | self.conv.resetState() 344 | } 345 | 346 | func step(_ x: StreamArray) -> StreamArray { 347 | self.conv.step(x) 348 | } 349 | } 350 | 351 | class ConvTrUpsample1d: Module, UnaryLayer { 352 | @ModuleInfo(key: "convtr") var convtr: StreamableConvTranspose1d 353 | 354 | init(stride: Int, dim: Int, causal: Bool) { 355 | self._convtr.wrappedValue = StreamableConvTranspose1d( 356 | inC: dim, outC: dim, kSize: 2 * stride, stride: stride, groups: dim, bias: false, 357 | causal: causal) 358 | } 359 | 360 | func callAsFunction(_ x: MLXArray) -> MLXArray { 361 | self.convtr(x) 362 | } 363 | 364 | func resetState() { 365 | self.convtr.resetState() 366 | } 367 | 368 | func step(_ x: StreamArray) -> StreamArray { 369 | self.convtr.step(x) 370 | } 371 | } 372 | -------------------------------------------------------------------------------- /MoshiLib/Transformer.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | 9 | public enum Norm { 10 | case layerNorm 11 | case rmsNorm 12 | } 13 | 14 | public enum PositionalEmbedding { 15 | case none 16 | case rope 17 | } 18 | 19 | public struct TransformerConfig { 20 | public var dModel: Int 21 | public var numHeads: Int 22 | public var numLayers: Int 23 | public var causal: Bool 24 | public var normFirst: Bool 25 | public var biasFF: Bool 26 | public var biasAttn: Bool 27 | public var layerScale: Float? 
28 | public var positionalEmbedding: PositionalEmbedding 29 | public var useConvBias: Bool 30 | public var gating: Bool 31 | public var norm: Norm 32 | public var context: Int 33 | public var maxPeriod: Int 34 | public var maxSeqLen: Int 35 | public var kvRepeat: Int 36 | public var dimFeedForward: Int 37 | public var convLayout: Bool 38 | public var useRotatingKVCache: Bool 39 | 40 | public func headDim() -> Int { 41 | self.dModel / self.numHeads 42 | } 43 | 44 | // kyutai/neilz/mimi_exp/xps/f28fe6d5/.hydra/config.yaml 45 | public static func v1_300m() -> TransformerConfig { 46 | TransformerConfig( 47 | dModel: 1024, 48 | numHeads: 8, 49 | numLayers: 16, 50 | causal: true, 51 | normFirst: true, 52 | biasFF: false, 53 | biasAttn: false, 54 | layerScale: nil, 55 | positionalEmbedding: .rope, 56 | useConvBias: false, 57 | gating: true, 58 | norm: .rmsNorm, 59 | context: 750, 60 | maxPeriod: 100000, 61 | maxSeqLen: 4096, 62 | kvRepeat: 1, 63 | dimFeedForward: 1024 * 4, 64 | convLayout: false, 65 | useRotatingKVCache: false 66 | ) 67 | } 68 | 69 | // kyutai/neilz/mimi_exp/xps/8d2516b9/.hydra/config.yaml 70 | public static func v1_1b() -> TransformerConfig { 71 | TransformerConfig( 72 | dModel: 2048, 73 | numHeads: 16, 74 | numLayers: 16, 75 | causal: true, 76 | normFirst: true, 77 | biasFF: false, 78 | biasAttn: false, 79 | layerScale: nil, 80 | positionalEmbedding: .rope, 81 | useConvBias: false, 82 | gating: true, 83 | norm: .rmsNorm, 84 | context: 3000, 85 | maxPeriod: 100000, 86 | maxSeqLen: 4096, 87 | kvRepeat: 1, 88 | dimFeedForward: 2048 * 4, 89 | convLayout: false, 90 | useRotatingKVCache: false 91 | ) 92 | } 93 | 94 | // kyutai/neilz/mimi_exp/xps/b3f79570/.hydra/config.yaml 95 | public static func v1_2b() -> TransformerConfig { 96 | TransformerConfig( 97 | dModel: 2560, 98 | numHeads: 20, 99 | numLayers: 24, 100 | causal: true, 101 | normFirst: true, 102 | biasFF: false, 103 | biasAttn: false, 104 | layerScale: nil, 105 | positionalEmbedding: .rope, 106 | 
useConvBias: false, 107 | gating: true, 108 | norm: .rmsNorm, 109 | context: 3000, 110 | maxPeriod: 100000, 111 | maxSeqLen: 4096, 112 | kvRepeat: 1, 113 | dimFeedForward: 2560 * 4, 114 | convLayout: false, 115 | useRotatingKVCache: false 116 | ) 117 | } 118 | 119 | public static func v1_7b() -> TransformerConfig { 120 | TransformerConfig( 121 | dModel: 4096, 122 | numHeads: 32, 123 | numLayers: 32, 124 | causal: true, 125 | normFirst: true, 126 | biasFF: false, 127 | biasAttn: false, 128 | layerScale: nil, 129 | positionalEmbedding: .rope, 130 | useConvBias: false, 131 | gating: true, 132 | norm: .rmsNorm, 133 | context: 3000, 134 | maxPeriod: 10000, 135 | maxSeqLen: 4096, 136 | kvRepeat: 1, 137 | dimFeedForward: 4096 * 4, 138 | convLayout: false, 139 | useRotatingKVCache: false 140 | ) 141 | } 142 | } 143 | 144 | private class MlpGating: Module, UnaryLayer { 145 | @ModuleInfo(key: "linear_in") var linear_in: Linear 146 | @ModuleInfo(key: "linear_out") var linear_out: Linear 147 | 148 | init(_ cfg: TransformerConfig) { 149 | let hidden = 150 | cfg.dimFeedForward == 4 * cfg.dModel ? 
11 * cfg.dModel / 4 : 2 * cfg.dimFeedForward / 3 151 | self._linear_in.wrappedValue = Linear(cfg.dModel, 2 * hidden, bias: cfg.biasFF) 152 | self._linear_out.wrappedValue = Linear(hidden, cfg.dModel, bias: cfg.biasFF) 153 | } 154 | 155 | func callAsFunction(_ x: MLXArray) -> MLXArray { 156 | let x = linear_in(x) 157 | let (B, T) = (x.dim(0), x.dim(1)) 158 | let x_reshaped = x.reshaped(B, T, 2, -1) 159 | return linear_out(silu(x_reshaped[0..., 0..., 0]) * x_reshaped[0..., 0..., 1]) 160 | } 161 | } 162 | 163 | private class MlpNoGating: Module, UnaryLayer { 164 | @ModuleInfo(key: "linear1") var linear1: Linear 165 | @ModuleInfo(key: "linear2") var linear2: Linear 166 | 167 | init(_ cfg: TransformerConfig) { 168 | self._linear1.wrappedValue = Linear(cfg.dModel, cfg.dimFeedForward, bias: cfg.biasFF) 169 | self._linear2.wrappedValue = Linear(cfg.dimFeedForward, cfg.dModel, bias: cfg.biasFF) 170 | } 171 | 172 | func callAsFunction(_ x: MLXArray) -> MLXArray { 173 | linear2(geluApproximate(linear1(x))) 174 | } 175 | } 176 | 177 | private class Attention: Module { 178 | let cfg: TransformerConfig 179 | let scale: Float 180 | let rope: RoPE? 181 | 182 | @ModuleInfo(key: "in_proj") var inProj: Linear 183 | @ModuleInfo(key: "out_proj") var outProj: Linear 184 | 185 | init(_ cfg: TransformerConfig) { 186 | self.cfg = cfg 187 | self.scale = 1.0 / sqrt(Float(cfg.headDim())) 188 | let numKV = cfg.numHeads / cfg.kvRepeat 189 | let outDim = cfg.dModel + 2 * numKV * cfg.dModel / cfg.numHeads 190 | self._inProj.wrappedValue = Linear(cfg.dModel, outDim, bias: cfg.biasAttn) 191 | self._outProj.wrappedValue = Linear(cfg.dModel, cfg.dModel, bias: cfg.biasAttn) 192 | self.rope = 193 | switch cfg.positionalEmbedding { 194 | case .none: nil 195 | case .rope: 196 | RoPE(dimensions: cfg.headDim(), traditional: true, base: Float(cfg.maxPeriod)) 197 | } 198 | } 199 | 200 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache?) 
-> MLXArray { 201 | let (B, T, H) = (x.dim(0), x.dim(1), x.dim(2)) 202 | let qkv = inProj(x).reshaped(B, T, 3, cfg.numHeads, cfg.headDim()) 203 | var q = qkv[0..., 0..., 0].transposed(0, 2, 1, 3) 204 | var k = qkv[0..., 0..., 1].transposed(0, 2, 1, 3) 205 | var v = qkv[0..., 0..., 2].transposed(0, 2, 1, 3) 206 | if let rope { 207 | let offset = cache?.offset ?? 0 208 | q = rope(q, offset: offset) 209 | k = rope(k, offset: offset) 210 | } 211 | if let cache { 212 | (k, v) = cache.update(keys: k, values: v) 213 | } 214 | let kLen = k.dim(2) 215 | let kTargetLen = T + min(self.cfg.context, kLen - T) 216 | if kTargetLen < kLen { 217 | let offset = kLen - kTargetLen 218 | k = k[0..., 0..., offset...] 219 | v = v[0..., 0..., offset...] 220 | } 221 | 222 | var mask = mask 223 | if let m = mask { 224 | let maskLen = m.dim(-1) 225 | if k.dim(2) < maskLen { 226 | let offset = maskLen - k.dim(2) 227 | mask = m[0..., offset...] 228 | } 229 | } 230 | let x = MLXFast.scaledDotProductAttention( 231 | queries: q, keys: k, values: v, scale: self.scale, mask: mask 232 | ).transposed(0, 2, 1, 3).reshaped(B, T, H) 233 | return outProj(x) 234 | } 235 | } 236 | 237 | class LayerScale: Module, UnaryLayer { 238 | @ModuleInfo(key: "scale") var scale: MLXArray 239 | 240 | init(_ dModel: Int, initValue: Float) { 241 | self._scale.wrappedValue = ones([dModel]) * initValue 242 | } 243 | 244 | func callAsFunction(_ x: MLXArray) -> MLXArray { 245 | x * self.scale 246 | } 247 | } 248 | 249 | private class TransformerLayer: Module { 250 | @ModuleInfo(key: "gating") var gating: UnaryLayer 251 | @ModuleInfo(key: "norm1") var norm1: UnaryLayer 252 | @ModuleInfo(key: "norm2") var norm2: UnaryLayer 253 | @ModuleInfo(key: "layer_scale_1") var layerScale1: LayerScale? 254 | @ModuleInfo(key: "layer_scale_2") var layerScale2: LayerScale? 255 | @ModuleInfo(key: "self_attn") var selfAttn: Attention 256 | 257 | init(_ cfg: TransformerConfig) { 258 | self._gating.wrappedValue = cfg.gating ? 
MlpGating(cfg) : MlpNoGating(cfg) 259 | self._norm1.wrappedValue = 260 | switch cfg.norm { 261 | case .layerNorm: 262 | LayerNorm(dimensions: cfg.dModel, eps: 1e-5) 263 | case .rmsNorm: RMSNorm(dimensions: cfg.dModel, eps: 1e-8) 264 | } 265 | self._norm2.wrappedValue = 266 | switch cfg.norm { 267 | case .layerNorm: 268 | LayerNorm(dimensions: cfg.dModel, eps: 1e-5) 269 | case .rmsNorm: RMSNorm(dimensions: cfg.dModel, eps: 1e-8) 270 | } 271 | self._selfAttn.wrappedValue = Attention(cfg) 272 | 273 | if let scale = cfg.layerScale { 274 | self._layerScale1.wrappedValue = LayerScale(cfg.dModel, initValue: scale) 275 | self._layerScale2.wrappedValue = LayerScale(cfg.dModel, initValue: scale) 276 | } else { 277 | self._layerScale1.wrappedValue = nil 278 | self._layerScale2.wrappedValue = nil 279 | } 280 | } 281 | 282 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache) -> MLXArray { 283 | var residual = x 284 | var x = x 285 | x = selfAttn(norm1(x), mask: mask, cache: cache) 286 | if let layerScale1 = self.layerScale1 { 287 | x = layerScale1(x) 288 | } 289 | x = residual + x 290 | residual = x 291 | x = gating(norm2(x)) 292 | if let layerScale2 = self.layerScale2 { 293 | x = layerScale2(x) 294 | } 295 | return residual + x 296 | } 297 | } 298 | 299 | public class Transformer: Module { 300 | let cfg: TransformerConfig 301 | private let layers: [TransformerLayer] 302 | 303 | public init(_ cfg: TransformerConfig) { 304 | self.cfg = cfg 305 | self.layers = (0.. MLXArray { 309 | var x = x 310 | let mask = cache.first?.createAttentionMask(h: x) 311 | for (layer, c) in zip(self.layers, cache) { 312 | x = layer(x, mask: mask, cache: c) 313 | } 314 | return x 315 | } 316 | 317 | public func makeCache(bSize: Int) -> [KVCache] { 318 | let kvHeads = cfg.numHeads / cfg.kvRepeat 319 | let dtype = self.layers.first!.selfAttn.inProj.weight.dtype 320 | let cache = (0.. 
[MLXArray] { 359 | var x = x 360 | if self.convLayout { 361 | x = x.swappedAxes(1, 2) 362 | } 363 | if let inputProj = self.inputProj { 364 | x = inputProj(x) 365 | } 366 | x = self.transformer(x, cache: cache) 367 | var outs: [MLXArray] = [] 368 | for outputProj in self.outputProjs { 369 | var out = outputProj?(x) ?? x 370 | if self.convLayout { 371 | out = out.swappedAxes(1, 2) 372 | } 373 | outs.append(out) 374 | } 375 | return outs 376 | } 377 | 378 | public func makeCache(bSize: Int) -> [KVCache] { 379 | self.transformer.makeCache(bSize: bSize) 380 | } 381 | } 382 | -------------------------------------------------------------------------------- /Moshi/ModelView.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 4 | 5 | // 6 | // SwiftUIView.swift 7 | // Moshi 8 | // 9 | // Created by Sebastien Mazare on 06/12/2024. 
10 | // 11 | 12 | import Hub 13 | import MLX 14 | import MLXNN 15 | import MLXRandom 16 | import Metal 17 | import MoshiLib 18 | import SwiftUI 19 | import Synchronization 20 | import AVFoundation 21 | 22 | struct ModelView: View { 23 | @Binding var model: Evaluator 24 | let modelType: ModelSelect 25 | @Environment(DeviceStat.self) private var deviceStat 26 | @State var sendToSpeaker = false 27 | @State private var showSettings = false 28 | 29 | var body: some View { 30 | VStack(spacing: 16) { 31 | Group { 32 | if deviceStat.gpuUsage.activeMemory > 0 || model.statsSummary.step.cnt > 0 { 33 | CombinedStatsView( 34 | summary: model.statsSummary, 35 | deviceStat: deviceStat, 36 | modelInfo: model.modelInfo, 37 | modelName: model.modelName, 38 | urls: model.urls 39 | ) 40 | .padding() 41 | .background(RoundedRectangle(cornerRadius: 12).fill(.secondary.opacity(0.1))) 42 | } 43 | } 44 | .transition(.push(from: .top)) 45 | .animation(.easeInOut(duration: 0.2), value: deviceStat.gpuUsage.activeMemory > 0 || model.statsSummary.step.cnt > 0) 46 | 47 | Group { 48 | if !model.running && model.output.isEmpty { 49 | ModelInfoView(modelType: modelType) 50 | } else { 51 | OutputSection(output: model.output) 52 | } 53 | } 54 | .transition(.opacity) 55 | .animation(.easeInOut(duration: 0.2), value: !model.running && model.output.isEmpty) 56 | 57 | Group { 58 | if model.running || !model.output.isEmpty { 59 | StatusSection(model: model) 60 | } 61 | } 62 | .transition(.push(from: .bottom)) 63 | .animation(.easeInOut(duration: 0.2), value: model.running || !model.output.isEmpty) 64 | 65 | // Bottom controls 66 | ZStack { 67 | // Centered Start/Stop button 68 | Button(action: model.running ? stopGenerate : generate) { 69 | Label(model.running ? "Stop" : "Start", 70 | systemImage: model.running ? 
"stop.circle.fill" : "play.circle.fill") 71 | .font(.title2) 72 | } 73 | .buttonStyle(.borderedProminent) 74 | 75 | // Right-aligned settings button 76 | HStack { 77 | Spacer() 78 | Button(action: { showSettings.toggle() }) { 79 | Image(systemName: "gear") 80 | .font(.title2) 81 | } 82 | .popover(isPresented: $showSettings, arrowEdge: .bottom) { 83 | Toggle(isOn: $sendToSpeaker) { 84 | Label("Use External Speaker", systemImage: "speaker.wave.2") 85 | } 86 | .padding() 87 | .onChange(of: sendToSpeaker) { (_, newValue) in 88 | if newValue { 89 | setDefaultToSpeaker() 90 | } else { 91 | setDefaultToStd() 92 | } 93 | } 94 | .presentationCompactAdaptation(.popover) 95 | } 96 | } 97 | } 98 | .padding() 99 | } 100 | .padding() 101 | .navigationTitle("Moshi: \(modelType.name)") 102 | } 103 | 104 | private func generate() { 105 | Task { 106 | await model.generate(self.modelType) 107 | } 108 | } 109 | 110 | private func stopGenerate() { 111 | Task { 112 | await model.stopGenerate() 113 | } 114 | } 115 | } 116 | 117 | struct StatusSection: View { 118 | let model: Evaluator 119 | 120 | var body: some View { 121 | VStack(spacing: 8) { 122 | if let progress = model.progress { 123 | ProgressView(progress) 124 | .transition(.slide) 125 | } 126 | 127 | if model.running { 128 | let tint = model.totalDuration > 0 ? 
Color.blue : Color.red 129 | HStack(spacing: 12) { 130 | Label("\(Int(model.totalDuration))s", systemImage: "clock") 131 | 132 | Image(systemName: "microphone.circle.fill") 133 | .foregroundStyle(.blue) 134 | .font(.title2) 135 | .symbolEffect(.bounce, options: .repeating) 136 | 137 | VStack(alignment: .leading) { 138 | Text("Buffer") 139 | .font(.caption) 140 | .foregroundStyle(.secondary) 141 | Gauge(value: model.bufferedDuration * 1000.0, in: 0...500) { 142 | EmptyView() 143 | } currentValueLabel: { 144 | Text("\(Int(model.bufferedDuration * 1000.0))ms") 145 | } 146 | .tint(.blue) 147 | } 148 | } 149 | .padding() 150 | .background(RoundedRectangle(cornerRadius: 10).fill(tint.opacity(0.1))) 151 | } 152 | } 153 | } 154 | } 155 | 156 | struct OutputSection: View { 157 | let output: String 158 | 159 | var body: some View { 160 | ScrollView(.vertical) { 161 | ScrollViewReader { proxy in 162 | Text(output) 163 | .textSelection(.enabled) 164 | .multilineTextAlignment(.leading) 165 | .frame(maxWidth: .infinity, alignment: .leading) 166 | .padding() 167 | .onChange(of: output) { _, _ in 168 | withAnimation { 169 | proxy.scrollTo("bottom", anchor: .bottom) 170 | } 171 | } 172 | Color.clear 173 | .frame(height: 1) 174 | .id("bottom") 175 | } 176 | } 177 | .background(RoundedRectangle(cornerRadius: 12).fill(.secondary.opacity(0.1))) 178 | } 179 | } 180 | 181 | struct DeviceStatsView: View { 182 | let deviceStat: DeviceStat 183 | 184 | var body: some View { 185 | VStack(alignment: .leading, spacing: 8) { 186 | MemoryStatRow( 187 | label: "Active Memory", 188 | value: deviceStat.gpuUsage.activeMemory, 189 | total: GPU.memoryLimit 190 | ) 191 | MemoryStatRow( 192 | label: "Cache Memory", 193 | value: deviceStat.gpuUsage.cacheMemory, 194 | total: GPU.cacheLimit 195 | ) 196 | HStack { 197 | MemoryStatRow( 198 | label: "Peak Memory", 199 | value: deviceStat.gpuUsage.peakMemory 200 | ) 201 | Spacer() 202 | VStack(alignment: .leading) { 203 | Text("Thermal State") 204 | 
.foregroundStyle(.secondary) 205 | Text(deviceStat.thermalState.rawValue) 206 | } 207 | } 208 | 209 | } 210 | } 211 | } 212 | 213 | struct DebugView: View { 214 | let modelInfo: String 215 | let modelName: String 216 | let urls: (URL, URL)? 217 | 218 | var body: some View { 219 | VStack(alignment: .leading, spacing: 8) { 220 | Text("Last Info").bold() 221 | Text(modelInfo) 222 | Text("Model Name").bold() 223 | Text(modelName) 224 | if let (traceURL, codesURL) = urls { 225 | HStack { 226 | ShareLink(item: traceURL) { 227 | Label("Trace", systemImage: "square.and.arrow.up") 228 | } 229 | .buttonStyle(BorderedButtonStyle()) 230 | ShareLink(item: codesURL) { 231 | Label("Codes", systemImage: "square.and.arrow.up") 232 | } 233 | .buttonStyle(BorderedButtonStyle()) 234 | } 235 | .padding() 236 | } 237 | } 238 | } 239 | } 240 | 241 | struct MemoryStatRow: View { 242 | let label: String 243 | let value: Int 244 | var total: Int? 245 | 246 | var body: some View { 247 | VStack(alignment: .leading, spacing: 4) { 248 | Text(label) 249 | .foregroundStyle(.secondary) 250 | if let total = total { 251 | Gauge(value: Double(value), in: 0...Double(total)) { 252 | Text(value.formatted(.byteCount(style: .memory))) 253 | } 254 | .tint(.blue) 255 | } else { 256 | Text(value.formatted(.byteCount(style: .memory))) 257 | } 258 | } 259 | } 260 | } 261 | 262 | struct StatsView: View { 263 | let summary: StatsSummary 264 | 265 | var body: some View { 266 | Grid(alignment: .leading) { 267 | GridRow { 268 | Text("Step") 269 | Text("Avg (ms)") 270 | Text("Min (ms)") 271 | Text("Max (ms)") 272 | Text("Count") 273 | } 274 | .bold() 275 | 276 | Divider() 277 | 278 | StatRow(label: "Encode", stats: summary.encode) 279 | StatRow(label: "Main", stats: summary.step) 280 | StatRow(label: "Depformer", stats: summary.depformer) 281 | StatRow(label: "Decode", stats: summary.decode) 282 | } 283 | .font(.system(.body, design: .monospaced)) 284 | } 285 | } 286 | 287 | struct StatRow: View { 288 | let label: 
String
289 |     let stats: StatsSummary.Stats
290 | 
291 |     var body: some View {
292 |         GridRow {
293 |             Text(label)
294 |             Text(stats.cnt > 0 ? (1000 * stats.sum / Float(stats.cnt)).rounded() : 0, format: .number)
295 |             Text((1000 * stats.min).rounded(), format: .number)
296 |             Text((1000 * stats.max).rounded(), format: .number)
297 |             Text(stats.cnt, format: .number)
298 |         }
299 |     }
300 | }
301 | 
302 | struct CombinedStatsView: View {
303 |     let summary: StatsSummary
304 |     let deviceStat: DeviceStat
305 |     let modelInfo: String
306 |     let modelName: String
307 |     let urls: (URL, URL)?
308 |     @State private var currentPage = 0
309 |     @State private var isExpanded = true
310 | 
311 |     var body: some View {
312 |         VStack(spacing: 0) {
313 |             // Header with page indicator and collapse button
314 |             ZStack {
315 |                 // Left-aligned content
316 |                 HStack {
317 |                     if isExpanded {
318 |                         HStack(spacing: 16) {
319 |                             ForEach(0..<3) { index in
320 |                                 Button(action: { withAnimation { currentPage = index } }) {
321 |                                     let text = switch index {
322 |                                     case 0: "Model"
323 |                                     case 1: "Device"
324 |                                     case 2: "Details"
325 |                                     case _: "unk"
326 |                                     }
327 |                                     VStack(spacing: 4) {
328 |                                         Text(text)
329 |                                             .font(.subheadline)
330 |                                             .foregroundStyle(currentPage == index ? .primary : .secondary)
331 |                                         Rectangle()
332 |                                             .fill(currentPage == index ? .blue : .clear)
333 |                                             .frame(height: 2)
334 |                                     }
335 |                                 }
336 |                                 .buttonStyle(.plain)
337 |                             }
338 |                         }
339 |                         .frame(maxWidth: .infinity)
340 |                     } else {
341 |                         Text("Details")
342 |                             .font(.headline)
343 |                             .foregroundStyle(.secondary)
344 |                             .frame(maxWidth: .infinity, alignment: .leading)
345 |                     }
346 |                 }
347 | 
348 |                 // Right-aligned button (always in the same position)
349 |                 HStack {
350 |                     Spacer()
351 |                     Button(action: {
352 |                         isExpanded.toggle()
353 |                     }) {
354 |                         Image(systemName: isExpanded ? "chevron.up.circle.fill" : "chevron.down.circle.fill")
355 |                             .foregroundStyle(.secondary)
356 |                             .font(.title3)
357 |                     }
358 |                 }
359 |             }
360 |             .padding(.bottom, isExpanded ? 
8 : 0) 361 | // Add tap gesture only when collapsed 362 | .contentShape(Rectangle()) // Make entire area tappable 363 | .onTapGesture { 364 | if !isExpanded { 365 | withAnimation { 366 | isExpanded.toggle() 367 | } 368 | } 369 | } 370 | 371 | if isExpanded { 372 | TabView(selection: $currentPage) { 373 | StatsView(summary: summary) 374 | .padding(.vertical) 375 | .frame(height: 250) 376 | .tag(0) 377 | DeviceStatsView(deviceStat: deviceStat) 378 | .padding(.vertical) 379 | .tag(1) 380 | DebugView(modelInfo: modelInfo, modelName: modelName, urls: urls) 381 | .padding(.vertical) 382 | .tag(2) 383 | } 384 | #if os(iOS) 385 | .tabViewStyle(.page(indexDisplayMode: .never)) 386 | #endif 387 | } 388 | } 389 | .frame(height: isExpanded ? 250 : 44) 390 | .animation(.easeInOut(duration: 0.2), value: isExpanded) 391 | } 392 | } 393 | 394 | struct ModelInfoView: View { 395 | let modelType: ModelSelect 396 | 397 | var body: some View { 398 | VStack(spacing: 24) { 399 | Image(systemName: "waveform.and.mic") 400 | .font(.system(size: 64)) 401 | .foregroundStyle(.blue) 402 | 403 | VStack(spacing: 12) { 404 | Text(modelType.name) 405 | .font(.title) 406 | .bold() 407 | 408 | Text(modelType.description) 409 | .font(.body) 410 | .multilineTextAlignment(.center) 411 | .foregroundStyle(.secondary) 412 | } 413 | .padding(.horizontal) 414 | } 415 | .frame(maxWidth: .infinity, maxHeight: .infinity) 416 | .background(RoundedRectangle(cornerRadius: 12).fill(.secondary.opacity(0.1))) 417 | } 418 | } 419 | -------------------------------------------------------------------------------- /MoshiLib/LM.swift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Kyutai, all rights reserved. 2 | // This source code is licensed under the license found in the 3 | // LICENSE file in the root directory of this source tree. 
4 | 5 | import MLX 6 | import MLXFast 7 | import MLXNN 8 | 9 | public struct DepformerConfig { 10 | var transformer: TransformerConfig 11 | var numSlices: Int 12 | } 13 | 14 | class DepformerSlice: Module { 15 | @ModuleInfo(key: "emb") var emb: Embedding 16 | @ModuleInfo(key: "linear_in") var linearIn: Linear 17 | @ModuleInfo(key: "linear_out") var linearOut: Linear 18 | @ModuleInfo(key: "transformer") var transformer: Transformer 19 | 20 | public init( 21 | inVocabSize: Int, outVocabSize: Int, mainTransformerDim: Int, cfg: TransformerConfig 22 | ) { 23 | self._emb.wrappedValue = Embedding(embeddingCount: inVocabSize, dimensions: cfg.dModel) 24 | self._linearIn.wrappedValue = Linear(mainTransformerDim, cfg.dModel, bias: false) 25 | self._linearOut.wrappedValue = Linear(cfg.dModel, outVocabSize, bias: false) 26 | self._transformer.wrappedValue = Transformer(cfg) 27 | } 28 | } 29 | 30 | class Depformer: Module { 31 | let cfg: LmConfig 32 | let transformerCache: [KVCache] 33 | @ModuleInfo(key: "slices") var slices: [DepformerSlice] 34 | 35 | public init(_ cfg: LmConfig, _ cfgDepformer: DepformerConfig, bSize: Int) { 36 | self.cfg = cfg 37 | let slices = (0.. MLXArray { 51 | for c in self.transformerCache { 52 | c.reset() 53 | } 54 | var lastToken = textToken 55 | var tokens: [MLXArray] = [] 56 | for (sliceIdx, slice) in slices.enumerated() { 57 | if sliceIdx != 0 && stepIdx < self.cfg.audioDelays[sliceIdx - 1] { 58 | lastToken = MLXArray([self.cfg.audioPaddingToken()]) 59 | } 60 | var xs = slice.linearIn(mainTransformerOut) + slice.emb(lastToken) 61 | xs = slice.transformer(xs, cache: self.transformerCache) 62 | let logits = slice.linearOut(xs) 63 | (lastToken, _) = sampler(logits: logits[0]) 64 | tokens.append(lastToken) 65 | } 66 | return concatenated(tokens) 67 | } 68 | } 69 | 70 | public struct LmConfig { 71 | public var transformer: TransformerConfig 72 | public var depformer: DepformerConfig? 
public var textInVocabSize: Int
    public var textOutVocabSize: Int
    public var audioVocabSize: Int
    public var audioCodebooks: Int
    public var audioDelays: [Int]

    /// Audio end-of-stream token (second-to-last id of the audio vocabulary).
    public func audioEOSToken() -> Int {
        self.audioVocabSize - 2
    }

    /// Audio padding token (last id of the audio vocabulary).
    public func audioPaddingToken() -> Int {
        self.audioVocabSize - 1
    }

    /// Initial text token (last id of the input text vocabulary).
    public func textInitToken() -> Int {
        self.textInVocabSize - 1
    }

    /// Number of depformer slices, i.e. audio codebooks generated per step;
    /// zero when there is no depformer.
    public func depformerSlices() -> Int {
        self.depformer?.numSlices ?? 0
    }

    /// Depformer configuration shared by the Moshi variants below: a 6-layer,
    /// 1024-dim transformer producing 8 audio slices. Factored out because
    /// moshi_2024_07 / moshi1b / moshi2b previously carried three identical
    /// inline copies of this construction.
    private static func moshiDepformerConfig() -> DepformerConfig {
        DepformerConfig(
            transformer:
                TransformerConfig(
                    dModel: 1024,
                    numHeads: 16,
                    numLayers: 6,
                    causal: true,
                    normFirst: true,
                    biasFF: false,
                    biasAttn: false,
                    layerScale: nil,
                    positionalEmbedding: .none,
                    useConvBias: false,
                    gating: true,
                    norm: .rmsNorm,
                    context: 8,
                    maxPeriod: 10000,
                    maxSeqLen: 4096,
                    kvRepeat: 1,
                    dimFeedForward: 1024 * 4,
                    convLayout: false,
                    useRotatingKVCache: false
                ), numSlices: 8)
    }

    /// Moshi 7B (July 2024 release): 16 audio codebooks with the delay
    /// pattern of the released checkpoint.
    public static func moshi_2024_07() -> LmConfig {
        return LmConfig(
            transformer: TransformerConfig.v1_7b(),
            depformer: moshiDepformerConfig(),
            textInVocabSize: 32001,
            textOutVocabSize: 32000,
            audioVocabSize: 2049,
            audioCodebooks: 16,
            audioDelays: [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
        )
    }

    /// Moshi 1B. The first codebook of each stream has no delay, the other 7
    /// share `audioDelay`; the pattern is duplicated for the two streams.
    public static func moshi1b(audioDelay: Int) -> LmConfig {
        let audioDelays = [0] + Array(repeating: audioDelay, count: 7)
        return LmConfig(
            transformer: TransformerConfig.v1_1b(),
            depformer: moshiDepformerConfig(),
            textInVocabSize: 48001,
            textOutVocabSize: 48000,
            audioVocabSize: 2049,
            audioCodebooks: 16,
            audioDelays: audioDelays + audioDelays
        )
    }

    /// Moshi 2B: same layout as `moshi1b` with a larger main transformer.
    public static func moshi2b(audioDelay: Int) -> LmConfig {
        let audioDelays = [0] + Array(repeating: audioDelay, count: 7)
        return LmConfig(
            transformer: TransformerConfig.v1_2b(),
            depformer: moshiDepformerConfig(),
            textInVocabSize: 48001,
            textOutVocabSize: 48000,
            audioVocabSize: 2049,
            audioCodebooks: 16,
            audioDelays: audioDelays + audioDelays
        )
    }

    /// 300M ASR model: no depformer, 32 input codebooks, no delays.
    public static func asr300m() -> LmConfig {
        return LmConfig(
            transformer: TransformerConfig.v1_300m(),
            depformer: nil,
            textInVocabSize: 48001,
            textOutVocabSize: 48000,
            audioVocabSize: 2049,
            audioCodebooks: 32,
            audioDelays: Array(repeating: 0, count: 32)
        )
    }

    /// 1B ASR model: no depformer, 32 input codebooks, no delays.
    public static func asr1b() -> LmConfig {
        return LmConfig(
            transformer: TransformerConfig.v1_1b(),
            depformer: nil,
            textInVocabSize: 8001,
            textOutVocabSize: 8000,
            audioVocabSize: 2049,
            audioCodebooks: 32,
            audioDelays: Array(repeating: 0, count: 32)
)
    }

    /// 2B ASR model: no depformer, 32 input codebooks, no delays.
    public static func asr2b() -> LmConfig {
        return LmConfig(
            transformer: TransformerConfig.v1_2b(),
            depformer: nil,
            textInVocabSize: 4001,
            textOutVocabSize: 4000,
            audioVocabSize: 2049,
            audioCodebooks: 32,
            audioDelays: Array(repeating: 0, count: 32)
        )
    }

    /// Helium 2B text-only model: no depformer and no audio codebooks.
    public static func helium2b() -> LmConfig {
        return LmConfig(
            transformer: TransformerConfig.v1_2b(),
            depformer: nil,
            textInVocabSize: 48000,
            textOutVocabSize: 48000,
            audioVocabSize: 2049,
            audioCodebooks: 0,
            audioDelays: Array(repeating: 0, count: 0)
        )
    }
}

/// The main language model: a large text/audio transformer plus an optional
/// depformer that expands each step into per-codebook audio tokens.
public class LM: Module {
    var transformerCache: [KVCache]
    public let cfg: LmConfig
    @ModuleInfo(key: "depformer") var depformer: Depformer?
    @ModuleInfo(key: "transformer") public var transformer: Transformer
    @ModuleInfo(key: "text_emb") var textEmb: Embedding
    @ModuleInfo(key: "out_norm") var outNorm: UnaryLayer
    @ModuleInfo(key: "text_linear") var textLinear: Linear
    @ModuleInfo(key: "audio_embs") var audioEmbs: [Embedding]

    public init(_ cfg: LmConfig, bSize: Int) {
        self.cfg = cfg
        self._transformer.wrappedValue = Transformer(cfg.transformer)
        self._depformer.wrappedValue = cfg.depformer.map { Depformer(cfg, $0, bSize: bSize) }
        self._textEmb.wrappedValue = Embedding(
            embeddingCount: cfg.textInVocabSize, dimensions: cfg.transformer.dModel)
        self._outNorm.wrappedValue =
            switch cfg.transformer.norm {
            case .layerNorm:
                LayerNorm(dimensions: cfg.transformer.dModel, eps: 1e-5)
            case .rmsNorm: RMSNorm(dimensions: cfg.transformer.dModel, eps: 1e-8)
            }
        self._textLinear.wrappedValue = Linear(
            cfg.transformer.dModel, cfg.textOutVocabSize, bias: false)
        // NOTE(review): garbled in this dump — the map body, the rest of the
        // init, and the callAsFunction signature were lost during extraction
        // (original lines 276-281 are missing).
        self._audioEmbs.wrappedValue = (0..
MLXArray {
        var x = textEmb(x)
        x = transformer(x, cache: self.transformerCache)
        return textLinear(outNorm(x))
    }

    /// Resets the main transformer's KV cache.
    public func resetCache() {
        for cache in self.transformerCache {
            cache.reset()
        }
    }

    /// Runs one step of the main transformer on the sum of the text and
    /// per-codebook audio embeddings. Returns the normalized hidden states
    /// and the text logits for the last position.
    public func stepMain(textIds: MLXArray?, audioIds: [MLXArray]) -> (MLXArray, MLXArray) {
        var x = textIds.flatMap { textEmb($0) }
        for (a, emb) in zip(audioIds, self.audioEmbs) {
            let e = emb(a)
            x = x.map { $0 + e } ?? e
        }
        // x! is only safe when textIds or at least one audio stream is given;
        // callers are expected to guarantee this.
        let out = outNorm(transformer(x!, cache: self.transformerCache))
        let logits = textLinear(out[0..., -1, 0...])
        return (out, logits)
    }

    /// Samples the next text token and, when a depformer is present, the
    /// audio tokens for this step. Emits perf callbacks around each phase.
    public func sample(
        textIds: MLXArray?,
        audioIds: [MLXArray],
        stepIdx: Int,
        textSampler: Sampler,
        audioSampler: Sampler,
        cb: Callbacks
    ) -> (MLXArray, MLXArray) {
        cb.onEvent(.beginStep)
        var x = textIds.flatMap { textEmb($0) }
        for (a, emb) in zip(audioIds, self.audioEmbs) {
            let e = emb(a)
            x = x.map { $0 + e } ?? e
        }
        let mainTransformerOut = outNorm(transformer(x!, cache: self.transformerCache))
        let textLogits = textLinear(mainTransformerOut[0..., -1, 0...])
        let (textToken, _) = textSampler(logits: textLogits)
        // Force evaluation so the endStep timing reflects the actual work.
        textToken.eval()
        cb.onEvent(.endStep)
        if let depformer = self.depformer {
            cb.onEvent(.beginDepformer)
            let audioTokens = depformer.sample(
                mainTransformerOut: mainTransformerOut,
                stepIdx: stepIdx,
                sampler: audioSampler,
                textToken: textToken)
            audioTokens.eval()
            cb.onEvent(.endDepformer)
            return (textToken, audioTokens)
        } else {
            // No depformer: text-only model, no audio tokens to return.
            return (textToken, MLXArray())
        }
    }

    public func warmup() {
        let sampler = Sampler()
        let textIds = MLXArray.zeros([1, 1], dtype: .int32)
        // NOTE(review): garbled in this dump — the rest of warmup, the LMGen
        // class declaration, its properties/init, and the step(...) signature
        // were lost during extraction (original lines 342-385 are missing).
        let audioIds = (0.. MLXArray?
{
        // Body of LMGen.step (its declaration above was lost in this dump).
        // Runs a single generation step, writing new tokens into the delayed
        // `genSequence` buffer; returns nil once `maxSteps` is reached.
        if self.stepIdx >= self.maxSteps {
            return nil
        }
        let textIds: MLXArray
        if self.stepIdx == 0 {
            // First step is seeded with textOutVocabSize as the start token.
            textIds = MLXArray([self.model.cfg.textOutVocabSize]).reshaped([1, 1])
        } else {
            textIds = self.genSequence[.newAxis, 0..., 0, self.stepIdx - 1]
        }
        // Store the other stream's audio tokens for the current step.
        self.genSequence[0..., (1 + self.mainCodebooks)..., self.stepIdx] = otherAudioTokens
        var audioIds: [MLXArray] = []
        for (cbIdx, delay) in self.model.cfg.audioDelays.enumerated() {
            // Read each codebook with its configured delay; fall back to the
            // padding token while the delayed index is still negative.
            let genIdx = self.stepIdx - 1 - delay
            let audioToken: MLXArray
            if genIdx >= 0 {
                audioToken = self.genSequence[.newAxis, 0..., 1 + cbIdx, genIdx]
            } else {
                audioToken = MLXArray([self.model.cfg.audioPaddingToken()]).reshaped([1, 1])
            }
            // Sanity check: the slot we read must have been filled already.
            if (audioToken .== MLXArray(ungeneratedToken)).any().item() {
                fatalError("ungenerated value in audio tokens, cb \(cbIdx), step \(stepIdx)")
            }
            assert(audioToken.shape == [1, 1])
            audioIds.append(audioToken)
        }
        if (textIds .== MLXArray(ungeneratedToken)).any().item() {
            fatalError("ungenerated value in text tokens, step \(stepIdx)")
        }
        assert(textIds.shape == [1, 1])
        let (tt, at) = model.sample(
            textIds: textIds,
            audioIds: audioIds,
            stepIdx: self.stepIdx,
            textSampler: self.textSampler,
            audioSampler: self.audioSampler,
            cb: self.cb
        )
        assert(tt.shape == [1])
        assert(at.shape == [self.model.cfg.depformerSlices()])

        self.genSequence[0..., 0, self.stepIdx] = tt
        for cbIdx in 0..= 0 {
                self.genSequence[0..., cbIdx + 1, genIdx] = at[cbIdx]
            }
        }

        self.stepIdx += 1
        return textIds
    }

    /// Returns the most recent fully-generated audio tokens for the main
    /// codebooks, or nil while delays are not yet satisfied / on padding.
    public func lastAudioTokens() -> MLXArray? {
        let genIdx = self.stepIdx - 1 - (self.model.cfg.audioDelays.max() ??
0)
        if genIdx < 0 {
            return nil
        }
        let tokens = self.genSequence[0..., 1...self.mainCodebooks, genIdx]
        if (tokens .== MLXArray(ungeneratedToken)).any().item() {
            // Fixed message: this check is on audio tokens, the previous
            // message wrongly said "text tokens".
            fatalError("ungenerated value in audio tokens, step \(stepIdx)")
        }
        if (tokens .== MLXArray(self.model.cfg.audioPaddingToken())).any().item() {
            return nil
        }
        return tokens
    }

    /// Resets the generator: step counter, model caches, and the generated
    /// sequence buffer (refilled with the ungenerated sentinel).
    public func reset() {
        self.stepIdx = 0
        self.model.resetCache()
        self.genSequence[0...] = MLXArray(ungeneratedToken)
        cb.onReset()
    }
}
--------------------------------------------------------------------------------
/Moshi/ContentView.swift:
--------------------------------------------------------------------------------
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

import Hub
import MLX
import MLXNN
import MLXRandom
import Metal
import MoshiLib
import SwiftUI
import Synchronization
import Tokenizers

/// Simple string-carrying error used throughout this file.
struct CustomError: Error {
    let message: String

    init(_ s: String) {
        message = s
    }
}

/// The models selectable from the UI.
enum ModelSelect: String, CaseIterable, Identifiable {
    case mimi
    case asr
    case hibiki
    case helium
    case qwen

    var id: Self { return self }

    /// Human-readable name shown in the UI.
    var name: String {
        switch self {
        case .hibiki:
            return "Hibiki 1B"
        default:
            return rawValue.capitalized
        }
    }

    /// Short description shown in the UI (empty when none is available).
    var description: String {
        switch self {
        case .hibiki:
            return
                "A French to English simultaneous translation model designed for real-time speech translation."
default:
            return ""
        }
    }
}

/// Root view: jumps straight to the model view when a single model is
/// available, otherwise shows a navigation list of the available models.
struct ContentView: View {
    @State var model = Evaluator()
    @State var selectedModel: ModelSelect = .mimi
    @State var sendToSpeaker = false
    @Environment(DeviceStat.self) private var deviceStat

    // Currently available models
    private let availableModels: [ModelSelect] = [.hibiki, .asr]
    var body: some View {
        Group {
            if availableModels.count == 1 {
                // Skip navigation if only one model
                ModelView(
                    model: $model,
                    modelType: availableModels[0]
                )
            } else {
                NavigationSplitView(
                    sidebar: {
                        VStack {
                            List {
                                ForEach(availableModels) { modelType in
                                    NavigationLink(
                                        // Show the display name ("Hibiki 1B")
                                        // rather than the raw value ("hibiki");
                                        // `name` was defined but unused before.
                                        modelType.name,
                                        destination: {
                                            ModelView(
                                                model: $model,
                                                modelType: modelType
                                            )
                                        }
                                    )
                                }
                            }
                        }
                        .navigationTitle("Available Models")
                        #if os(iOS)
                            .navigationBarTitleDisplayMode(.inline)
                        #endif
                    },
                    detail: {
                        Text("Please choose a model type")
                    }
                )
            }
        }
    }
}

#Preview {
    ContentView()
}

/// Drives model loading and the realtime generation loop; the published
/// properties are observed by the UI on the main actor.
@Observable
@MainActor
class Evaluator {
    var running = false
    var modelInfo = "...moshi..."
    var modelName = ""
    var urls: (URL, URL)? = nil
    var stat = ""
    var output = ""
    var progress: Progress?
= nil
    var statsSummary: StatsSummary = StatsSummary()
    var bufferedDuration: Double = 0.0
    var totalDuration: Double = 0.0
    // Restored `<Bool>`: the generic argument was lost in this dump; the
    // store(true/false, ordering:) call sites show this is an atomic
    // boolean stop flag shared with the generation loop.
    let shouldStop: Atomic<Bool> = .init(false)
    let cb: PerfStats = PerfStats()

    enum LoadState {
        case idle
        case loaded(ModelState, ModelSelect)
    }

    var loadState = LoadState.idle

    /// Downloads `filename` from the HuggingFace repo `id` into the app
    /// support directory, returning the cached copy when it already exists.
    /// Download progress is mirrored into `self.progress` for the UI.
    func downloadFromHub(id: String, filename: String) async throws -> URL {
        guard
            let downloadDir = FileManager.default.urls(
                for: .applicationSupportDirectory, in: .userDomainMask
            ).first
        else {
            throw CustomError("cannot find the app support directory")
        }
        let api = HubApi(downloadBase: downloadDir)
        let repo = Hub.Repo(id: id)
        let targetURL = api.localRepoLocation(repo).appending(path: filename)
        if FileManager.default.fileExists(atPath: targetURL.path) {
            print("using cached file \(targetURL.path)")
            return targetURL
        }
        let url = try await api.snapshot(from: repo, matching: filename) { progress in
            Task { @MainActor in
                self.progress = progress
            }
        }
        // TODO: also set this back to nil on errors.
self.progress = nil
        return url.appending(path: filename)
    }

    /// Loads the id -> token-piece vocabulary matching the model's text
    /// output vocabulary size.
    func loadVocab(_ cfg: LmConfig) async throws -> [Int: String] {
        let filename =
            switch cfg.textOutVocabSize {
            case 48000: "tokenizer_spm_48k_multi6_2.json"
            case 32000: "tokenizer_spm_32k_3.json"
            case 8000: "tokenizer_spm_8k_0.json"
            case 4000: "test_en_audio_4000.json"
            case let other: fatalError("unexpected text vocab size \(other)")
            }
        let vocabURL = try await downloadFromHub(id: "lmz/moshi-swift", filename: filename)
        let rawJSON = try Data(contentsOf: vocabURL)
        return try JSONDecoder().decode([Int: String].self, from: rawJSON)
    }

    /// Builds a Moshi LM from a safetensors checkpoint, quantizing first when
    /// the filename carries a `.qN.safetensors` suffix.
    func makeMoshi(_ url: URL, _ cfg: LmConfig) throws -> LM {
        let weights = try loadArrays(url: url)
        let parameters = ModuleParameters.unflattened(weights)
        let model = LM(cfg, bSize: 1)
        // Filename suffix -> quantization settings; the suffixes are mutually
        // exclusive so first-match equals the previous if/else chain.
        let quantSpecs: [(suffix: String, groupSize: Int, bits: Int)] = [
            (".q4.safetensors", 32, 4),
            (".q6.safetensors", 64, 6),
            (".q8.safetensors", 64, 8),
        ]
        if let spec = quantSpecs.first(where: { url.lastPathComponent.hasSuffix($0.suffix) }) {
            quantize(model: model, groupSize: spec.groupSize, bits: spec.bits)
        }
        try model.update(parameters: parameters, verify: [.all])
        eval(model)
        return model
    }

    /// Same as `makeMoshi` but with the Helium quantization settings.
    /// NOTE(review): the q4/q8 patterns omit the leading dot — presumably so
    /// names like "helium-1-preview-2b-q4.safetensors" match — and q4 uses
    /// groupSize 64 here vs 32 in `makeMoshi`; confirm before unifying.
    func makeHelium(_ url: URL, _ cfg: LmConfig) throws -> LM {
        let weights = try loadArrays(url: url)
        let parameters = ModuleParameters.unflattened(weights)
        let model = LM(cfg, bSize: 1)
        let quantSpecs: [(suffix: String, groupSize: Int, bits: Int)] = [
            ("q4.safetensors", 64, 4),
            (".q6.safetensors", 64, 6),
            ("q8.safetensors", 64, 8),
        ]
        if let spec = quantSpecs.first(where: { url.lastPathComponent.hasSuffix($0.suffix) }) {
            quantize(model: model, groupSize: spec.groupSize, bits: spec.bits)
        }
try model.update(parameters: parameters, verify: [.all])
        eval(model)
        return model
    }

    /// Builds a Qwen model from a safetensors checkpoint, applying the
    /// quantization described in its config when present.
    func makeQwen(_ url: URL, _ cfg: QwenConfig) throws -> QwenModel {
        let weights = try loadArrays(url: url)
        let parameters = ModuleParameters.unflattened(weights)
        let model = QwenModel(cfg)
        if let q = cfg.quantization {
            quantize(model: model, groupSize: q.groupSize, bits: q.bits)
        }
        try model.update(parameters: parameters, verify: [.all])
        eval(model)
        return model
    }

    /// Builds the Mimi audio tokenizer: downloads its checkpoint and remaps
    /// the PyTorch weight names/layouts to their MLX equivalents.
    func makeMimi(numCodebooks: Int) async throws -> Mimi {
        let cfg = MimiConfig.mimi_2024_07(numCodebooks: numCodebooks)
        let model = Mimi(cfg, bSize: 1)

        let url = try await downloadFromHub(
            id: "lmz/moshi-swift",
            filename: "tokenizer-dbaa9758-checkpoint125.safetensors")
        let origWeights = try loadArrays(url: url)
        var weights: [String: MLXArray] = [:]
        for (var key, var weight) in origWeights {
            // Mutating the keys while iterating over the map seems pretty dodgy, not sure what the idiomatic
            // way to do this is in swift.
// Hopefully this is copy on write and it's all good :)
            // Drop the extra "model" level present in the PyTorch key names.
            if key.hasPrefix("encoder.model") {
                key.replace("encoder.model.", with: "encoder.")
            }
            if key.hasPrefix("decoder.model") {
                key.replace("decoder.model.", with: "decoder.")
            }
            // Rename the fused attention projection to its nested MLX name.
            if key.hasSuffix(".in_proj_weight") {
                key.replace(".in_proj_weight", with: ".in_proj.weight")
            }
            // The MLP linears live under a `gating` submodule on the MLX side.
            if key.hasSuffix(".linear1.weight") {
                key.replace(".linear1.weight", with: ".gating.linear1.weight")
            }
            if key.hasSuffix(".linear2.weight") {
                key.replace(".linear2.weight", with: ".gating.linear2.weight")
            }
            // Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
            // Decoder indices 2/5/8/11 are upsample layers, each followed by
            // its residual block; encoder is the mirror image below.
            for (layerIdx, decoderIdx) in [2, 5, 8, 11].enumerated() {
                key.replace("decoder.\(decoderIdx).", with: "decoder.layers.\(layerIdx).upsample.")
                key.replace(
                    "decoder.\(decoderIdx + 1).", with: "decoder.layers.\(layerIdx).residuals.0.")
            }
            for (layerIdx, encoderIdx) in [1, 4, 7, 10].enumerated() {
                key.replace(
                    "encoder.\(encoderIdx).", with: "encoder.layers.\(layerIdx).residuals.0.")
                key.replace(
                    "encoder.\(encoderIdx + 2).", with: "encoder.layers.\(layerIdx).downsample.")
            }
            // First/last layers are the init/final conv1d of each stack.
            // (These run after the loops above, so indices 0/14 are distinct
            // from the indices already rewritten.)
            key.replace("decoder.0.", with: "decoder.init_conv1d.")
            key.replace("decoder.14.", with: "decoder.final_conv1d.")
            key.replace("encoder.0.", with: "encoder.init_conv1d.")
            key.replace("encoder.14.", with: "encoder.final_conv1d.")
            key.replace(".block.1.", with: ".block.0.")
            key.replace(".block.3.", with: ".block.1.")
            // PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
            if key.hasSuffix(".conv.weight") || key.hasSuffix(".output_proj.weight")
                || key.hasSuffix(".input_proj.weight")
            {
                weight = weight.swappedAxes(-1, -2)
            }
            // PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC
            if
key.hasSuffix(".convtr.weight") {
                weight = weight.transposed(axes: [1, 2, 0])
            }

            // print(key, weight.shape)
            weights[key] = weight
        }
        let parameters = ModuleParameters.unflattened(weights)
        try model.update(parameters: parameters, verify: [.all])
        return model
    }

    /// Asks the generation loop in `generate` to stop after the current step.
    func stopGenerate() async {
        self.shouldStop.store(true, ordering: .relaxed)
    }

    /// Runs the realtime session for the selected model: microphone in,
    /// speaker out, with periodic UI stat updates and a trace/codes dump at
    /// the end of the session.
    func generate(_ sm: ModelSelect) async {
        guard !running else { return }

        self.shouldStop.store(false, ordering: .relaxed)
        self.modelInfo = "starting"
        self.output = ""
        self.totalDuration = 0.0
        running = true
        do {
            let model = try await load(sm)
            let urls = try await model.perform { model in
                model.reset()
                await self.cb.onReset()
                // TODO: Do not create a fresh audio input/output on each session.
                let microphoneCapture = MicrophoneCapture()
                microphoneCapture.startCapturing()
                let ap = AudioPlayer(sampleRate: 24000)
                try ap.startPlaying()
                print("started the audio loops")

                var step = 0
                let startTime = CFAbsoluteTimeGetCurrent()
                while let pcm = microphoneCapture.receive() {
                    if shouldStop.load(ordering: .relaxed) {
                        break
                    }
                    let pcm = MLXArray(pcm)[.newAxis, .newAxis]
                    if !model.onMicrophonePcm(pcm, ap: ap, ev: self) {
                        break
                    }
                    step += 1
                    // Periodically drop cached GPU buffers to bound memory.
                    if step % 128 == 0 {
                        GPU.clearCache()
                    }
                    let currentStep = step
                    Task { @MainActor in
                        // Throttled UI updates: buffering stats every 5 steps,
                        // perf summary every 20 steps.
                        if currentStep % 5 == 0 {
                            self.bufferedDuration =
                                ap.bufferedDuration() + microphoneCapture.bufferedDuration()
                            self.totalDuration = CFAbsoluteTimeGetCurrent() - startTime
                        }
                        if currentStep % 20 == 0 {
                            self.statsSummary = self.cb.getSummary(maxEvents: 100)
                        }
                    }
                }
                print()
                let traceURL = FileManager.default.temporaryDirectory.appendingPathComponent(
"moshi-trace.json")
                try await self.cb.writeJSONTrace(url: traceURL)
                let codesURL = FileManager.default.temporaryDirectory.appendingPathComponent(
                    "moshi-codes.safetensors")
                // Write the codes file first; the previous version copied it
                // into the documents directory *before* writing it, so the
                // copy either failed or duplicated a stale file.
                try await self.cb.writeCodes(url: codesURL)
                do {
                    let timestamp = Int(Date().timeIntervalSince1970)
                    guard
                        let documentsDirectory = FileManager.default.urls(
                            for: .documentDirectory, in: .userDomainMask
                        ).first
                    else {
                        throw NSError(
                            domain: "FileManagerError", code: 1,
                            userInfo: [NSLocalizedDescriptionKey: "Documents directory not found"])
                    }
                    let docURL = documentsDirectory.appendingPathComponent(
                        "moshi-codes-\(timestamp).safetensors")
                    let fileManager = FileManager.default
                    try fileManager.copyItem(at: codesURL, to: docURL)
                } catch {
                    // Best effort: keep the session result even if the copy
                    // into the documents directory fails.
                    print("Error copying file: \(error)")
                }
                microphoneCapture.stopCapturing()
                return (traceURL, codesURL)
            }
            self.urls = urls
            self.modelInfo = "finished generating"
        } catch {
            self.modelInfo = "failed: \(error)"
        }
        self.bufferedDuration = 0.0
        self.statsSummary = StatsSummary()
        running = false
    }

    /// Returns the cached model when `sm` is already loaded, otherwise
    /// releases the previous model (if any) and loads the requested one.
    func load(_ sm: ModelSelect) async throws -> ModelState {
        if case .loaded(let m, sm) = self.loadState {
            return m
        }
        // Start by resetting loadState so as to release the memory used
        // by the currently loaded model if any.
self.loadState = .idle
        let m: ModelState
        switch sm {
        case .hibiki:
            let model = try await MoshiModel(self, self.cb)
            m = ModelState(model)
        case .mimi:
            let model = try await MimiModel(self, self.cb)
            m = ModelState(model)
        case .asr:
            let model = try await AsrModel(self, self.cb)
            m = ModelState(model)
        case .helium:
            let model = try await HeliumModel(self, self.cb)
            m = ModelState(model)
        case .qwen:
            let model = try await QwenModel_(self, self.cb)
            m = ModelState(model)
        }
        self.loadState = .loaded(m, sm)
        return m
    }

    func setModelName(_ s: String) {
        modelName = s
    }
    func setModelInfo(_ s: String) {
        modelInfo = s
    }
}

/// A runnable model: constructed asynchronously, reset per session, and fed
/// microphone PCM one chunk at a time by the generation loop.
protocol Model {
    init(_ ev: Evaluator, _ cb: Callbacks) async throws
    mutating func reset()
    // If onMicrophonePcm returns true continue, otherwise break.
    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool
}

/// Demo model: encodes the microphone audio with Mimi while playing back a
/// pre-recorded code stream, one decoded step per encoded step.
struct MimiModel: Model {
    let mimi: Mimi
    let codes: MLXArray
    var currentStep: Int
    let totalSteps: Int
    let cb: Callbacks

    init(_ ev: Evaluator, _ cb: Callbacks) async throws {
        await ev.setModelInfo("building model")
        self.mimi = try await ev.makeMimi(numCodebooks: 16)
        await ev.setModelInfo("model built")
        let codeURL = try await ev.downloadFromHub(
            id: "lmz/moshi-swift", filename: "bria-codes.safetensors")
        guard let codes = try loadArrays(url: codeURL)["codes"] else {
            throw CustomError("no codes in \(codeURL)")
        }
        self.codes = codes
        self.currentStep = 0
        self.totalSteps = codes.dim(-1)
        self.cb = cb
    }

    mutating func reset() {
        mimi.resetState()
    }

    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
        let sampleText =
"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
        let sampleWords = sampleText.components(separatedBy: " ")
        self.cb.onEvent(.beginEncode)
        let micCodes = mimi.encodeStep(StreamArray(pcm))
        micCodes.eval()
        self.cb.onEvent(.endEncode)
        if let micCodes = micCodes.asArray() {
            // As of 2024-12-04, there is a memory leak if this eval is removed, this is
            // triggered even without the decoding and audio playing, only the line
            // below is needed:
            // let audioTokens = codes[.ellipsis, currentStep...currentStep]
            eval(micCodes)
            self.cb.onInputAudioTokens(micCodes)
            let (_, _, steps) = micCodes.shape3
            // Restored from the garbled dump (`0..<steps` / `>= totalSteps`
            // were eaten by extraction): decode one stored step per freshly
            // encoded step, stopping at the end of the recording.
            for _ in 0..<steps {
                if currentStep >= totalSteps {
                    break
                }
                let audioTokens = codes[.ellipsis, currentStep...currentStep]
                self.cb.onEvent(.beginDecode)
                self.cb.onOutputAudioTokens(audioTokens)
                let pcmOut = mimi.decodeStep(StreamArray(audioTokens))
                pcmOut.eval()
                self.cb.onEvent(.endDecode)
                if let p = pcmOut.asArray() {
                    let _ = ap.send(p.asArray(Float.self))
                }
                // Stream one placeholder word every 4 decoded steps.
                if currentStep % 4 == 0 {
                    let v = sampleWords[(currentStep / 4) % sampleWords.count] + " "
                    Task { @MainActor in
                        ev.output += v
                    }
                }
                currentStep += 1
            }
        }
        return currentStep < totalSteps
    }
}

/// Minimal Qwen text-generation demo; the microphone input is ignored.
struct QwenModel_: Model {
    let qwen: QwenModel
    let tokenizer: any Tokenizer

    init(_ ev: Evaluator, _ cb: Callbacks) async throws {
        let hfRepo = "Qwen/Qwen2.5-0.5B-Instruct"
        await ev.setModelInfo("building 
model")

        let configUrl = try await ev.downloadFromHub(id: hfRepo, filename: "config.json")
        let configData = try Data(contentsOf: configUrl)
        let decoder = JSONDecoder()
        decoder.keyDecodingStrategy = .convertFromSnakeCase
        let cfg = try decoder.decode(QwenConfig.self, from: configData)

        let url = try await ev.downloadFromHub(id: hfRepo, filename: "model.safetensors")
        qwen = try await ev.makeQwen(url, cfg)
        await ev.setModelInfo("model built")
        tokenizer = try await AutoTokenizer.from(pretrained: hfRepo)
    }

    mutating func reset() {
    }

    /// Generation demo: 501 decode steps seeded with token 151643
    /// (presumably Qwen's end-of-text id — TODO confirm), streaming each
    /// decoded piece to the UI. Always returns false to end the session.
    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
        let sampler = Sampler()
        let cache = qwen.makeCache(bSize: 1)

        var lastToken = 151643
        // The step index was unused; the loop count is kept identical.
        for _ in 0...500 {
            let logits = qwen(MLXArray([lastToken]).reshaped(1, 1), cache: cache)
            let (tok, _) = sampler(logits: logits[0])
            lastToken = tok.item()
            let s = tokenizer.decode(tokens: [lastToken])
            Task { @MainActor in
                ev.output += s
            }
        }

        return false
    }
}

/// Helium text LM demo: samples up to 500 tokens and streams the
/// detokenized text to the UI.
struct HeliumModel: Model {
    let helium: LM
    let vocab: [Int: String]

    init(_ ev: Evaluator, _ cb: Callbacks) async throws {
        await ev.setModelInfo("building model")
        let cfg = LmConfig.helium2b()
        let url = try await ev.downloadFromHub(
            id: "kyutai/helium-1-preview-2b-mlx", filename: "helium-1-preview-2b-q4.safetensors")
        helium = try await ev.makeHelium(url, cfg)
        await ev.setModelInfo("model built")
        vocab = try await ev.loadVocab(cfg)
        await ev.setModelInfo("warming up helium")
        helium.warmup()
        await ev.setModelInfo("done warming up")
    }

    mutating func reset() {
    }

    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
        let maxSteps = 
helium.cfg.transformer.maxSeqLen
        let sampler = Sampler()
        let cb = PerfStats()

        var lastToken = MLXArray([1])
        for stepIdx in 0...min(500, maxSteps) {
            let (textToken, _) = helium.sample(
                textIds: lastToken.reshaped([1, 1]), audioIds: [], stepIdx: stepIdx,
                textSampler: sampler,
                audioSampler: sampler, cb: cb)
            let tokenId: Int = textToken[0].item()
            // "<0x0A>" is the newline piece; "▁" marks a leading space.
            if let rawPiece = vocab[tokenId] {
                let piece = rawPiece == "<0x0A>" ? "\n" : rawPiece.replacing("▁", with: " ")
                Task { @MainActor in
                    ev.output += piece
                }
            }
            lastToken = textToken
        }

        return false
    }
}

/// Streaming speech-to-text: Mimi tokenizes the microphone audio and the
/// Moshi ASR model turns the tokens into text pieces.
struct AsrModel: Model {
    var asr: ASR

    init(_ ev: Evaluator, _ cb: Callbacks) async throws {
        await ev.setModelInfo("building model")
        // Prefer a checkpoint bundled with the app; otherwise download one.
        let url: URL
        if let bundled = Bundle.main.url(forResource: "stt-model", withExtension: "safetensors") {
            url = bundled
        } else {
            url = try await ev.downloadFromHub(
                id: "lmz/moshi-swift", filename: "moshi-70f8f0ea@500.q8.safetensors")
        }
        let cfg = LmConfig.asr1b()
        let moshi = try await ev.makeMoshi(url, cfg)
        let mimi = try await ev.makeMimi(numCodebooks: 32)
        await ev.setModelInfo("model built")
        let vocab = try await ev.loadVocab(cfg)
        await ev.setModelInfo("warming up mimi")
        mimi.warmup()
        await ev.setModelInfo("warming up moshi")
        moshi.warmup()
        await ev.setModelInfo("done warming up")
        self.asr = ASR(moshi, mimi, vocab: vocab, cb: cb)
    }

    mutating func reset() {
        asr.reset()
    }

    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
        let tokens = asr.onPcmInput(pcm)
        for v in tokens {
            // Mirror the recognized text both to stdout and to the UI.
            print(v, terminator: "")
            fflush(stdout)
            Task {
@MainActor in
                ev.output += v
            }
        }
        return true
    }
}

/// Full Moshi speech-to-speech session: Mimi encodes the microphone, LMGen
/// produces text + audio tokens, and Mimi decodes the audio back to PCM.
struct MoshiModel: Model {
    let moshi: LM
    let vocab: [Int: String]
    let mimi: Mimi
    let gen: LMGen
    let cb: Callbacks

    init(_ ev: Evaluator, _ cb: Callbacks) async throws {
        await ev.setModelInfo("building model")
        let url: URL
        let cfg = LmConfig.moshi1b(audioDelay: 2)
        // Prefer a checkpoint bundled with the app; otherwise download one.
        let localURL = Bundle.main.url(
            forResource: "moshi-37c6cfd6@200.q6", withExtension: "safetensors")
        switch localURL {
        case .none:
            url = try await ev.downloadFromHub(
                id: "lmz/moshi-swift", filename: "moshi-37c6cfd6@200.q6.safetensors")
        case .some(let localURL):
            url = localURL
        }
        self.moshi = try await ev.makeMoshi(url, cfg)
        await ev.setModelName(url.lastPathComponent)
        self.mimi = try await ev.makeMimi(numCodebooks: 16)
        await ev.setModelInfo("model built")
        let maxSteps = cfg.transformer.maxSeqLen
        self.cb = cb
        self.gen = LMGen(
            moshi, maxSteps: maxSteps, audioSampler: Sampler(), textSampler: Sampler(), cb: self.cb)
        self.vocab = try await ev.loadVocab(cfg)
        await ev.setModelInfo("warming up mimi")
        self.mimi.warmup()
        await ev.setModelInfo("warming up moshi")
        self.moshi.warmup()
        await ev.setModelInfo("done warming up")
    }

    mutating func reset() {
        mimi.resetState()
        gen.reset()
    }

    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
        cb.onEvent(.beginEncode)
        let codes = mimi.encodeStep(StreamArray(pcm))
        codes.eval()
        cb.onEvent(.endEncode)
        if let codes = codes.asArray() {
            self.cb.onInputAudioTokens(codes)
            let (_, _, steps) = codes.shape3
            // NOTE(review): garbled in this dump — the remainder of this loop
            // and the head of the ModelState actor below were lost during
            // extraction (original lines 661-702 are missing).
            for step in 0..(_ action: @Sendable (inout Model) async throws -> R)
    async rethrows
    -> R
    {
        var model = model
        let result = try await 
action(&model) 708 | return result 709 | } 710 | } 711 | --------------------------------------------------------------------------------