├── .swift-format
├── Moshi
│   ├── Assets.xcassets
│   │   ├── Contents.json
│   │   ├── AccentColor.colorset
│   │   │   └── Contents.json
│   │   └── AppIcon.appiconset
│   │       └── Contents.json
│   ├── Preview Content
│   │   └── Preview Assets.xcassets
│   │       └── Contents.json
│   ├── Info.plist
│   ├── moshi.entitlements
│   ├── moshiApp.swift
│   ├── DeviceStat.swift
│   ├── AudioRT.swift
│   ├── ModelView.swift
│   └── ContentView.swift
├── moshi.xcodeproj
│   ├── project.xcworkspace
│   │   ├── contents.xcworkspacedata
│   │   └── xcshareddata
│   │       └── swiftpm
│   │           └── Package.resolved
│   └── xcshareddata
│       └── xcschemes
│           ├── moshi-cli.xcscheme
│           └── Moshi.xcscheme
├── .github
│   └── PULL_REQUEST_TEMPLATE.md
├── MoshiLib
│   ├── MoshiLib.docc
│   │   └── MoshiLib.md
│   ├── MoshiLib.h
│   ├── Utils.swift
│   ├── ASR.swift
│   ├── Streaming.swift
│   ├── Mimi.swift
│   ├── Qwen2.swift
│   ├── KVCache.swift
│   ├── Quantization.swift
│   ├── Perf.swift
│   ├── Seanet.swift
│   ├── Conv.swift
│   ├── Transformer.swift
│   └── LM.swift
├── MoshiTests
│   └── moshiTests.swift
├── .gitignore
├── MoshiLibTests
│   └── moshi_libTests.swift
├── scripts
│   ├── sp.py
│   └── eval_mlx.py
├── Makefile
├── MoshiUITests
│   ├── moshiUITestsLaunchTests.swift
│   └── moshiUITests.swift
├── LICENSE
├── MoshiCLI
│   ├── RunHelium.swift
│   ├── Audio.swift
│   ├── AudioRT.swift
│   ├── RunMimi.swift
│   ├── RunMoshi.swift
│   └── CLI.swift
├── README.md
└── CONTRIBUTING.md
/.swift-format:
--------------------------------------------------------------------------------
1 | {
2 | "lineLength": 100,
3 | "indentation": {
4 | "spaces": 4
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Moshi/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/Moshi/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/moshi.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <Workspace
3 |    version = "1.0">
4 |    <FileRef
5 |       location = "self:">
6 |    </FileRef>
7 | </Workspace>
8 |
--------------------------------------------------------------------------------
/Moshi/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/Moshi/Info.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>UIFileSharingEnabled</key>
6 | 	<true/>
7 | </dict>
8 | </plist>
9 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Checklist
2 |
3 | - [ ] Read CONTRIBUTING.md and accept the CLA by including the provided snippet. We will not accept PRs without this.
4 | - [ ] Test manually that the app is still working.
5 |
6 | ## PR Description
7 |
8 |
9 |
--------------------------------------------------------------------------------
/MoshiLib/MoshiLib.docc/MoshiLib.md:
--------------------------------------------------------------------------------
1 | # ``MoshiLib``
2 |
3 | Summary
4 |
5 | ## Overview
6 |
7 | Text
8 |
9 | ## Topics
10 |
11 | ### Group
12 |
13 | - ``Symbol``
14 |
--------------------------------------------------------------------------------
/MoshiTests/moshiTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // moshiTests.swift
3 | // moshiTests
4 | //
5 | // Created by Laurent on 20/11/2024.
6 | //
7 |
8 | import Testing
9 |
10 | struct moshiTests {
11 |
12 | @Test func example() async throws {
13 | // Write your test here and use APIs like `#expect(...)` to check expected conditions.
14 | }
15 |
16 | }
17 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .helix
2 | .vscode
3 | *.dylib
4 | *.so
5 | *.swp
6 | *.swo
7 | trace-*.json
8 | .DS_Store
9 | .idea/*
10 | __pycache__
11 | build
12 | buildServer.json
13 | xcuserdata
14 | model*.safetensors
15 | mimi-*.safetensors
16 | moshi-*.safetensors
17 | codes-*.safetensors
18 | asr-*.safetensors
19 | moshi-out.wav
20 | moshi-trace.json
21 | bria*safetensors
22 | bria*wav
23 |
--------------------------------------------------------------------------------
/MoshiLibTests/moshi_libTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // moshi_libTests.swift
3 | // moshi-libTests
4 | //
5 | // Created by Laurent on 20/11/2024.
6 | //
7 |
8 | import Testing
9 |
10 | @testable import moshi_lib
11 |
12 | struct moshi_libTests {
13 |
14 | @Test func example() async throws {
15 | // Write your test here and use APIs like `#expect(...)` to check expected conditions.
16 | }
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/scripts/sp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sentencepiece
4 |
5 | home_dir = os.path.expanduser("~")
6 | text_tokenizer = sentencepiece.SentencePieceProcessor(home_dir + "/tmp/tokenizer_spm_32k_3.model")
7 |
8 | vocab = {}
9 | for id in range(text_tokenizer.vocab_size()):
10 | vocab[id] = text_tokenizer.id_to_piece(id)
11 | with open(home_dir + "/tmp/tokenizer_spm_32k_3.json", "w") as json_file:
12 | json.dump(vocab, json_file)
13 |
--------------------------------------------------------------------------------
/MoshiLib/MoshiLib.h:
--------------------------------------------------------------------------------
1 | //
2 | // moshi_lib.h
3 | // moshi-lib
4 | //
5 | // Created by Laurent on 20/11/2024.
6 | //
7 |
8 | #import <Foundation/Foundation.h>
9 |
10 | //! Project version number for moshi_lib.
11 | FOUNDATION_EXPORT double MoshiLibVersionNumber;
12 |
13 | //! Project version string for moshi_lib.
14 | FOUNDATION_EXPORT const unsigned char MoshiLibVersionString[];
15 |
16 | // In this header, you should import all the public headers of your framework using statements like #import <MoshiLib/PublicHeader.h>
17 |
18 |
19 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: format run-1b run-asr run-mimi run-helium run-qwen build
2 |
3 | format:
4 | swift-format format --in-place --recursive .
5 |
6 | run-1b: build
7 | ./build/Build/Products/Release/MoshiCLI run
8 |
9 | run-asr: build
10 | ./build/Build/Products/Release/MoshiCLI run-asr
11 |
12 | run-mimi: build
13 | ./build/Build/Products/Release/MoshiCLI run-mimi
14 |
15 | run-helium: build
16 | ./build/Build/Products/Release/MoshiCLI run-helium
17 |
18 | run-qwen: build
19 | ./build/Build/Products/Release/MoshiCLI run-qwen
20 |
21 | build:
22 | xcodebuild -scheme moshi-cli -derivedDataPath ./build
23 |
--------------------------------------------------------------------------------
/Moshi/moshi.entitlements:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>com.apple.developer.kernel.increased-memory-limit</key>
6 | 	<true/>
7 | 	<key>com.apple.security.app-sandbox</key>
8 | 	<true/>
9 | 	<key>com.apple.security.device.audio-input</key>
10 | 	<true/>
11 | 	<key>com.apple.security.files.downloads.read-write</key>
12 | 	<true/>
13 | 	<key>com.apple.security.files.user-selected.read-write</key>
14 | 	<true/>
15 | 	<key>com.apple.security.network.client</key>
16 | 	<true/>
17 | </dict>
18 | </plist>
19 |
--------------------------------------------------------------------------------
/MoshiUITests/moshiUITestsLaunchTests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // moshiUITestsLaunchTests.swift
3 | // moshiUITests
4 | //
5 | // Created by Laurent on 20/11/2024.
6 | //
7 |
8 | import XCTest
9 |
10 | final class moshiUITestsLaunchTests: XCTestCase {
11 |
12 | override class var runsForEachTargetApplicationUIConfiguration: Bool {
13 | true
14 | }
15 |
16 | override func setUpWithError() throws {
17 | continueAfterFailure = false
18 | }
19 |
20 | @MainActor
21 | func testLaunch() throws {
22 | let app = XCUIApplication()
23 | app.launch()
24 |
25 | // Insert steps here to perform after app launch but before taking a screenshot,
26 | // such as logging into a test account or navigating somewhere in the app
27 |
28 | let attachment = XCTAttachment(screenshot: app.screenshot())
29 | attachment.name = "Launch Screen"
30 | attachment.lifetime = .keepAlways
31 | add(attachment)
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any
2 | person obtaining a copy of this software and associated
3 | documentation files (the "Software"), to deal in the
4 | Software without restriction, including without
5 | limitation the rights to use, copy, modify, merge,
6 | publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software
8 | is furnished to do so, subject to the following
9 | conditions:
10 |
11 | The above copyright notice and this permission notice
12 | shall be included in all copies or substantial portions
13 | of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/MoshiUITests/moshiUITests.swift:
--------------------------------------------------------------------------------
1 | //
2 | // moshiUITests.swift
3 | // moshiUITests
4 | //
5 | // Created by Laurent on 20/11/2024.
6 | //
7 |
8 | import XCTest
9 |
10 | final class moshiUITests: XCTestCase {
11 |
12 | override func setUpWithError() throws {
13 | // Put setup code here. This method is called before the invocation of each test method in the class.
14 |
15 | // In UI tests it is usually best to stop immediately when a failure occurs.
16 | continueAfterFailure = false
17 |
18 | // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
19 | }
20 |
21 | override func tearDownWithError() throws {
22 | // Put teardown code here. This method is called after the invocation of each test method in the class.
23 | }
24 |
25 | @MainActor
26 | func testExample() throws {
27 | // UI tests must launch the application that they test.
28 | let app = XCUIApplication()
29 | app.launch()
30 |
31 | // Use XCTAssert and related functions to verify your tests produce the correct results.
32 | }
33 |
34 | @MainActor
35 | func testLaunchPerformance() throws {
36 | if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) {
37 | // This measures how long it takes to launch your application.
38 | measure(metrics: [XCTApplicationLaunchMetric()]) {
39 | XCUIApplication().launch()
40 | }
41 | }
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/Moshi/moshiApp.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import AVFoundation
6 | import Foundation
7 | import SwiftUI
8 |
9 | func requestMicrophoneAccess() {
10 | switch AVCaptureDevice.authorizationStatus(for: .audio) {
11 | case .authorized:
12 | return
13 | case .notDetermined:
14 | AVCaptureDevice.requestAccess(for: .audio) { granted in
15 | print("granted", granted)
16 | }
17 | case .denied: // The user has previously denied access.
18 | return
19 | case .restricted: // The user can't grant access due to restrictions.
20 | return
21 | case _:
22 | return
23 | }
24 | }
25 |
26 | @main
27 | struct moshiApp: App {
28 | @Environment(\.scenePhase) var scenePhase
29 |
30 | init() {
31 | requestMicrophoneAccess()
32 | }
33 |
34 | var body: some Scene {
35 | WindowGroup {
36 | ContentView()
37 | .environment(DeviceStat())
38 | }
39 | #if os(iOS)
40 | .onChange(of: scenePhase) { (phase) in
41 | switch phase {
42 | case .active:
43 | // In iOS 13+, idle timer needs to be set in scene to override default
44 | UIApplication.shared.isIdleTimerDisabled = true
45 | case .inactive: break
46 | case .background: break
47 | @unknown default: print("ScenePhase: unexpected state")
48 | }
49 | }
50 | #endif
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/Moshi/DeviceStat.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import MLX
3 |
4 | enum ThermalState: String {
5 | case nominal = "Nominal"
6 | case fair = "Fair"
7 | case serious = "Serious"
8 | case critical = "Critical"
9 | case unknown = "Unknown"
10 | }
11 |
12 | @Observable
13 | final class DeviceStat: @unchecked Sendable {
14 |
15 | @MainActor
16 | var gpuUsage = GPU.snapshot()
17 | @MainActor
18 | var thermalState: ThermalState = .nominal
19 |
20 | private let initialGPUSnapshot = GPU.snapshot()
21 | private var timer: Timer?
22 |
23 | init() {
24 | timer = Timer.scheduledTimer(withTimeInterval: 2.0, repeats: true) { [weak self] _ in
25 | self?.updateGPUUsages()
26 | }
27 | }
28 |
29 | deinit {
30 | timer?.invalidate()
31 | }
32 |
33 | private func updateGPUUsages() {
34 | let gpuSnapshotDelta = initialGPUSnapshot.delta(GPU.snapshot())
35 | let thermalState = ProcessInfo.processInfo.thermalState
36 | DispatchQueue.main.async { [weak self] in
37 | self?.gpuUsage = gpuSnapshotDelta
38 | switch thermalState {
39 | case .nominal:
40 | self?.thermalState = .nominal
41 | case .fair:
42 | self?.thermalState = .fair
43 | case .serious:
44 | self?.thermalState = .serious
45 | case .critical:
46 | self?.thermalState = .critical
47 | default:
48 | self?.thermalState = .unknown
49 | }
50 | }
51 | }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
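
DeviceStat above is injected into the SwiftUI environment by moshiApp.swift via `.environment(DeviceStat())`. A minimal sketch of a hypothetical consumer view (the `ThermalBadge` name and layout are illustrative, not part of the repo):

    import SwiftUI

    // Hypothetical view reading the observable DeviceStat from the environment.
    struct ThermalBadge: View {
        @Environment(DeviceStat.self) var deviceStat

        var body: some View {
            // thermalState is refreshed every two seconds by DeviceStat's timer.
            Text("Thermal: \(deviceStat.thermalState.rawValue)")
        }
    }
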
/Moshi/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "idiom" : "universal",
5 | "platform" : "ios",
6 | "size" : "1024x1024"
7 | },
8 | {
9 | "appearances" : [
10 | {
11 | "appearance" : "luminosity",
12 | "value" : "dark"
13 | }
14 | ],
15 | "idiom" : "universal",
16 | "platform" : "ios",
17 | "size" : "1024x1024"
18 | },
19 | {
20 | "appearances" : [
21 | {
22 | "appearance" : "luminosity",
23 | "value" : "tinted"
24 | }
25 | ],
26 | "idiom" : "universal",
27 | "platform" : "ios",
28 | "size" : "1024x1024"
29 | },
30 | {
31 | "idiom" : "mac",
32 | "scale" : "1x",
33 | "size" : "16x16"
34 | },
35 | {
36 | "idiom" : "mac",
37 | "scale" : "2x",
38 | "size" : "16x16"
39 | },
40 | {
41 | "idiom" : "mac",
42 | "scale" : "1x",
43 | "size" : "32x32"
44 | },
45 | {
46 | "idiom" : "mac",
47 | "scale" : "2x",
48 | "size" : "32x32"
49 | },
50 | {
51 | "idiom" : "mac",
52 | "scale" : "1x",
53 | "size" : "128x128"
54 | },
55 | {
56 | "idiom" : "mac",
57 | "scale" : "2x",
58 | "size" : "128x128"
59 | },
60 | {
61 | "idiom" : "mac",
62 | "scale" : "1x",
63 | "size" : "256x256"
64 | },
65 | {
66 | "idiom" : "mac",
67 | "scale" : "2x",
68 | "size" : "256x256"
69 | },
70 | {
71 | "idiom" : "mac",
72 | "scale" : "1x",
73 | "size" : "512x512"
74 | },
75 | {
76 | "idiom" : "mac",
77 | "scale" : "2x",
78 | "size" : "512x512"
79 | }
80 | ],
81 | "info" : {
82 | "author" : "xcode",
83 | "version" : 1
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/MoshiCLI/RunHelium.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import AVFoundation
6 | import Foundation
7 | import MLX
8 | import MLXNN
9 | import MoshiLib
10 |
11 | func makeHelium(_ url: URL, _ cfg: LmConfig) throws -> LM {
12 | let weights = try loadArrays(url: url)
13 | let parameters = ModuleParameters.unflattened(weights)
14 | let model = LM(cfg, bSize: 1)
15 | if url.lastPathComponent.hasSuffix("q4.safetensors") {
16 | quantize(model: model, groupSize: 64, bits: 4)
17 | } else if url.lastPathComponent.hasSuffix("q6.safetensors") {
18 | quantize(model: model, groupSize: 64, bits: 6)
19 | } else if url.lastPathComponent.hasSuffix("q8.safetensors") {
20 | quantize(model: model, groupSize: 64, bits: 8)
21 | }
22 | try model.update(parameters: parameters, verify: [.all])
23 | eval(model)
24 | return model
25 | }
26 |
27 | func runHelium(_ url: URL, cfg: LmConfig) throws {
28 | let stats = PerfStats()
29 | let helium = try makeHelium(url, cfg)
30 | let vocab = try loadVocab(cfg)
31 | helium.warmup()
32 | print("done warming up")
33 |
34 | let maxSteps = helium.cfg.transformer.maxSeqLen
35 | let sampler = Sampler()
36 |
37 | var lastToken = MLXArray([1])
38 | for stepIdx in 0...maxSteps {
39 | let (textToken, _) = helium.sample(
40 | textIds: lastToken.reshaped([1, 1]), audioIds: [], stepIdx: stepIdx,
41 | textSampler: sampler,
42 | audioSampler: sampler, cb: stats)
43 | let textTokenI: Int = textToken[0].item()
44 | if var v = vocab[textTokenI] {
45 | if v == "<0x0A>" {
46 | print()
47 | } else {
48 | v.replace("▁", with: " ")
49 | print(v, terminator: "")
50 | fflush(stdout)
51 | }
52 | }
53 | lastToken = textToken
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
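
makeHelium above selects the quantization level from the checkpoint's filename suffix, so the same loader handles float as well as 4/6/8-bit weights. A hedged sketch of a call from inside the MoshiCLI target (the path is a placeholder and `cfg` stands for an LmConfig built elsewhere):

    import Foundation

    // Placeholder path: the `q4.safetensors` suffix triggers 4-bit quantization
    // inside makeHelium before the weights are loaded.
    let url = URL(fileURLWithPath: "/tmp/helium-1b-q4.safetensors")
    let model = try makeHelium(url, cfg)  // cfg: LmConfig, constructed elsewhere
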
/moshi.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "originHash" : "982acab5e54d309f8a98a66985086b61e82a869bc8ebd717682e681f43aac314",
3 | "pins" : [
4 | {
5 | "identity" : "jinja",
6 | "kind" : "remoteSourceControl",
7 | "location" : "https://github.com/johnmai-dev/Jinja",
8 | "state" : {
9 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73",
10 | "version" : "1.1.1"
11 | }
12 | },
13 | {
14 | "identity" : "mlx-swift",
15 | "kind" : "remoteSourceControl",
16 | "location" : "https://github.com/ml-explore/mlx-swift",
17 | "state" : {
18 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab",
19 | "version" : "0.21.2"
20 | }
21 | },
22 | {
23 | "identity" : "swift-argument-parser",
24 | "kind" : "remoteSourceControl",
25 | "location" : "https://github.com/apple/swift-argument-parser.git",
26 | "state" : {
27 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c",
28 | "version" : "1.5.0"
29 | }
30 | },
31 | {
32 | "identity" : "swift-collections",
33 | "kind" : "remoteSourceControl",
34 | "location" : "https://github.com/apple/swift-collections.git",
35 | "state" : {
36 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7",
37 | "version" : "1.1.4"
38 | }
39 | },
40 | {
41 | "identity" : "swift-numerics",
42 | "kind" : "remoteSourceControl",
43 | "location" : "https://github.com/apple/swift-numerics",
44 | "state" : {
45 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
46 | "version" : "1.0.2"
47 | }
48 | },
49 | {
50 | "identity" : "swift-transformers",
51 | "kind" : "remoteSourceControl",
52 | "location" : "https://github.com/huggingface/swift-transformers",
53 | "state" : {
54 | "revision" : "8a83416cc00ab07a5de9991e6ad817a9b8588d20",
55 | "version" : "0.1.15"
56 | }
57 | }
58 | ],
59 | "version" : 3
60 | }
61 |
--------------------------------------------------------------------------------
/MoshiCLI/Audio.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import AVFoundation
6 | import Foundation
7 |
8 | func readAudioToPCMArray(fileURL: URL, channel: Int = 0) -> [Float]? {
9 | do {
10 | let audioFile = try AVAudioFile(forReading: fileURL)
11 | let format = audioFile.processingFormat
12 | let frameCount = UInt32(audioFile.length)
13 | guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: frameCount) else {
14 | print("failed to create buffer")
15 | return nil
16 | }
17 |
18 | try audioFile.read(into: buffer)
19 | guard let channelData = buffer.floatChannelData else {
20 | print("failed to get channel data")
21 | return nil
22 | }
23 | let _ = Int(format.channelCount)
24 | let frameLength = Int(buffer.frameLength)
25 | var pcmData: [Float] = []
26 | let samples = channelData[channel]
27 | pcmData.append(contentsOf: UnsafeBufferPointer(start: samples, count: frameLength))
28 | return pcmData
29 | } catch {
30 | print("error reading audio file: \(error)")
31 | return nil
32 | }
33 | }
34 |
35 | func writeWAVFile(_ pcmData: [Float], sampleRate: Double, outputURL: URL) throws {
36 | let format = AVAudioFormat(
37 | standardFormatWithSampleRate: sampleRate, channels: AVAudioChannelCount(1))
38 | guard let format = format else {
39 | throw NSError(
40 | domain: "AudioFormat", code: -1,
41 | userInfo: [NSLocalizedDescriptionKey: "Failed to create audio format"])
42 | }
43 | let audioFile = try AVAudioFile(forWriting: outputURL, settings: format.settings)
44 | guard
45 | let buffer = AVAudioPCMBuffer(
46 | pcmFormat: format, frameCapacity: AVAudioFrameCount(pcmData.count))
47 | else {
48 | throw NSError(
49 | domain: "PCMBuffer", code: -1,
50 | userInfo: [NSLocalizedDescriptionKey: "Failed to create PCM buffer"])
51 | }
52 | buffer.frameLength = AVAudioFrameCount(pcmData.count)
53 | let channelData = buffer.floatChannelData!
54 | channelData[0].update(from: Array(pcmData), count: pcmData.count)
55 | try audioFile.write(from: buffer)
56 | }
57 |
--------------------------------------------------------------------------------
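
A small usage sketch for the two helpers above, assuming the input file is already mono 24 kHz (neither helper resamples, so the sample rate passed to writeWAVFile must match the source; the paths are placeholders):

    import Foundation

    let input = URL(fileURLWithPath: "/tmp/in.wav")    // placeholder path
    let output = URL(fileURLWithPath: "/tmp/out.wav")  // placeholder path
    if let pcm = readAudioToPCMArray(fileURL: input) {
        // Round-trip: write the float samples back out as a mono WAV file.
        try writeWAVFile(pcm, sampleRate: 24000, outputURL: output)
    }
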
/MoshiLib/Utils.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 | import MLXRandom
9 |
10 | func topKSampling(logits: MLXArray, topK: Int, temp: Float) -> MLXArray {
11 | let c = logits.dim(-1)
12 | let sortedIndices = argSort(logits, axis: -1)
13 | let sortedLogits = logits[0..., sortedIndices.squeezed(axis: 0)]
14 | let topLogits = `where`(MLXArray(0..<c) .>= c - topK, sortedLogits, -Float.infinity)
15 | let sortedToken = categorical(topLogits / temp)
16 | let token = sortedIndices.squeezed(axis: 0)[sortedToken]
17 | return token
18 | }
19 |
20 | func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
21 | let probs = softmax(logits * (1 / temp), axis: -1)
22 | let sortedIndices = argSort(probs, axis: -1)
23 | let sortedProbs = probs[0..., sortedIndices.squeezed(axis: 0)]
24 | let cumulativeProbs = cumsum(sortedProbs, axis: -1)
25 | let topProbs = `where`(cumulativeProbs .> 1 - topP, sortedProbs, 0)
26 | let sortedToken = categorical(topProbs.log())
27 | let token = sortedIndices.squeezed(axis: 0)[sortedToken]
28 | return token
29 | }
30 |
31 | func categoricalSampling(logits: MLXArray, temp: Float) -> MLXArray {
32 | categorical(logits * (1 / temp))
33 | }
34 |
35 | public class Sampler {
36 | let temp: Float
37 | let topP: Float = 0.95
38 | let topK: Int = 0
39 |
40 | public init(temp: Float = 0.8) {
41 | self.temp = temp
42 | }
43 |
44 | public func callAsFunction(logits: MLXArray) -> (MLXArray, MLXArray) {
45 | if logits.shape.count != 2 {
46 | fatalError("expected a two dimensions logits array, got \(logits.shape)")
47 | }
48 | let logProbs = logits - logits.logSumExp()
49 | var tokens: MLXArray
50 | if temp <= 0.0 {
51 | tokens = logProbs.argMax(axis: -1)
52 | } else if self.topP > 0.0 && self.topP < 1.0 {
53 | tokens = topPSampling(logits: logits, topP: self.topP, temp: self.temp)
54 | } else if self.topK != 0 {
55 | tokens = topKSampling(logits: logits, topK: self.topK, temp: self.temp)
56 | } else {
57 | tokens = categoricalSampling(logits: logits, temp: self.temp)
58 | }
59 | return (tokens, logProbs)
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
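
With the defaults above (topP = 0.95, topK = 0), any positive temperature routes through topPSampling. A minimal sketch of driving the Sampler in isolation, with made-up logits and assuming MoshiLib is imported:

    import MLX
    import MoshiLib

    // Fake logits over a four-token vocabulary; Sampler expects shape [1, V].
    let logits = MLXArray([1.0, 2.0, 0.5, -1.0] as [Float]).reshaped([1, 4])
    let sampler = Sampler(temp: 0.8)
    let (token, _) = sampler(logits: logits)
    let tokenId: Int = token[0].item()
    print(tokenId)
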
/scripts/eval_mlx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import mlx.core as mx
4 | import mlx.nn as nn
5 | from moshi_mlx import models
6 |
7 | def config1b() -> models.lm.LmConfig:
8 | transformer = models.lm.TransformerConfig(
9 | d_model=2048,
10 | num_heads=16,
11 | num_layers=16,
12 | dim_feedforward=2048 * 4, # dim * hidden_scale
13 | causal=True,
14 | norm_first=True,
15 | bias_ff=False,
16 | bias_attn=False,
17 | layer_scale=None,
18 | context=750,
19 | max_period=100000,
20 | use_conv_block=False,
21 | use_conv_bias=True,
22 | cross_attention=False,
23 | gating=True,
24 | norm="rms_norm",
25 | positional_embedding="rope",
26 | conv_layout=False,
27 | conv_kernel_size=3,
28 | kv_repeat=1,
29 | max_seq_len=4096,
30 | )
31 | depformer = models.lm.DepFormerConfig(
32 | transformer=models.lm.TransformerConfig(
33 | d_model=1024,
34 | num_heads=16,
35 | num_layers=6,
36 | dim_feedforward=1024 * 4, # dim * hidden_scale
37 | causal=True,
38 | norm_first=True,
39 | bias_ff=False,
40 | bias_attn=False,
41 | layer_scale=None,
42 | context=8,
43 | max_period=10000,
44 | use_conv_block=False,
45 | use_conv_bias=True,
46 | cross_attention=False,
47 | gating=True,
48 | norm="rms_norm",
49 | positional_embedding="none",
50 | conv_layout=False,
51 | conv_kernel_size=3,
52 | kv_repeat=1,
53 | max_seq_len=4096,
54 | ),
55 | num_slices=0,
56 | )
57 | return models.lm.LmConfig(
58 | transformer=transformer,
59 | depformer=depformer,
60 | audio_vocab_size=2049,
61 | text_in_vocab_size=48001,
62 | text_out_vocab_size=48000,
63 | audio_codebooks=8,
64 | audio_delays=[0]*16
65 | )
66 | home_dir = os.path.expanduser("~")
67 |
68 | lm_config = config1b()
69 | model = models.Lm(lm_config)
70 | model.set_dtype(mx.bfloat16)
71 | model.load_weights(home_dir + "/tmp/asr-1b-8d2516b9@150.safetensors", strict=True)
72 |
73 | def run_one():
74 | text_token_ids = mx.array([48000]).reshape(1, 1)
75 | audio_token_ids = [mx.array([2048]).reshape(1, 1)] * 8
76 | xs = model.text_emb(text_token_ids)
77 | for token_ids, emb in zip(audio_token_ids, model.audio_embs):
78 | xs = xs + emb(token_ids)
79 | transformer_out = model.transformer(xs, cache=model.transformer_cache)
80 | transformer_out = model.out_norm(transformer_out)
81 | text_logits = model.text_linear(transformer_out)
82 | print(text_logits)
83 |
84 | run_one()
85 |
--------------------------------------------------------------------------------
/MoshiLib/ASR.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import AVFoundation
6 | import Foundation
7 | import Hub
8 | import MLX
9 | import MLXNN
10 | import MoshiLib
11 |
12 | public class ASR {
13 | let moshi: LM
14 | let vocab: [Int: String]
15 | let mimi: Mimi
16 | var prevTextToken: Int = 0
17 | let sampler: Sampler = Sampler(temp: 0.0)
18 | let cb: Callbacks
19 |
20 | public init(
21 | _ moshi: LM, _ mimi: Mimi, vocab: [Int: String], cb: Callbacks = EmptyCallbacks()
22 | ) {
23 | self.moshi = moshi
24 | self.mimi = mimi
25 | self.vocab = vocab
26 | self.cb = cb
27 | }
28 |
29 | public func reset() {
30 | mimi.resetState()
31 | moshi.resetCache()
32 | prevTextToken = self.moshi.cfg.textInitToken()
33 | let textIds = MLXArray([prevTextToken]).reshaped([1, 1])
34 | let audioIds = (0..<
[... lines 34-44 garbled in extraction; a public method taking `_ pcm: MLXArray` and returning `[String]` begins here ...]
45 | var tokens: [String] = []
46 | let codebooks = moshi.cfg.audioCodebooks
47 | cb.onEvent(.beginEncode)
48 | let codes = mimi.encodeStep(StreamArray(pcm))
49 | codes.eval()
50 | cb.onEvent(.endEncode)
51 | if let codes = codes.asArray() {
52 | cb.onInputAudioTokens(codes)
53 | let (_, _, steps) = codes.shape3
54 | for step in 0..<
[... remainder of ASR.swift lost in extraction ...]

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
[... lines 1-23 lost in extraction ...]
24 | > I, {your GitHub handle}, confirm that I have read and understood the terms of the CLA of Kyutai-labs, as outlined in the repository's CONTRIBUTING.md, and I agree to be bound by these terms.
25 |
26 | The full CLA is provided as follows:
27 |
28 | > I, {your GitHub handle}, hereby grant to Kyutai-labs a perpetual, worldwide, non-exclusive, royalty-free,
29 | > irrevocable license to use, modify, distribute, and sublicense my Contributions.
30 |
31 | > I understand and accept that Contributions are limited to modifications, improvements, or changes
32 | > to the project’s source code submitted via pull requests. I accept that Kyutai-labs has full discretion to
33 | > review, accept, reject, or request changes to any Contributions I submit, and that submitting
34 | > a pull request does not guarantee its inclusion in the project.
35 |
36 | > By submitting a Contribution, I grant Kyutai-labs a perpetual, worldwide license to use, modify,
37 | > reproduce, distribute, and create derivative works based on my Contributions.
38 | > I also agree to assign all patent rights for any inventions or improvements that arise from my Contributions,
39 | > giving Kyutai-labs full rights to file for and enforce patents.
40 | > I understand that Kyutai-labs may commercialize, relicense, or exploit the project and my Contributions without further notice or obligation to me.
41 | > I confirm that my Contributions are original and that I have the legal right to grant this license.
42 | > If my Contributions include third-party materials, I will ensure that I have the necessary permissions
43 | > and will disclose this information. I accept that once my Contributions are integrated, they may be altered or removed at Kyutai-labs's discretion.
44 |
45 | > I acknowledge that I am making these Contributions voluntarily and will not receive any compensation.
46 | > Furthermore, I understand that all Contributions, including mine, are provided on an "as-is" basis, with no warranties.
47 | > By submitting a pull request, I agree to be bound by these terms.
48 |
49 | ## Issues
50 |
51 | Please submit issues on our GitHub repository.
52 |
53 | ## License
54 |
55 | By contributing to Moshi, you agree that your contributions will be licensed
56 | under the MIT license as provided by the LICENSE file at the top of the repo.
57 |
--------------------------------------------------------------------------------
/moshi.xcodeproj/xcshareddata/xcschemes/moshi-cli.xcscheme:
--------------------------------------------------------------------------------
[... XML scheme content lost in extraction ...]
--------------------------------------------------------------------------------
/moshi.xcodeproj/xcshareddata/xcschemes/Moshi.xcscheme:
--------------------------------------------------------------------------------
[... XML scheme content lost in extraction ...]
--------------------------------------------------------------------------------
/MoshiLib/Streaming.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 | import MLXRandom
9 |
10 | public class StreamArray {
11 | let inner: MLXArray?
12 |
13 | public init(_ x: MLXArray? = nil) {
14 | self.inner = x
15 | }
16 |
17 | public func dim(_ dim: Int) -> Int {
18 | self.inner?.dim(dim) ?? 0
19 | }
20 |
21 | public func eval() {
22 | self.inner?.eval()
23 | }
24 |
25 | public func cat2(_ rhs: StreamArray, axis: Int) -> StreamArray {
26 | switch (self.inner, rhs.inner) {
27 | case (.none, .none): StreamArray()
28 | case (.some(let lhs), .none): StreamArray(lhs)
29 | case (.none, .some(let rhs)): StreamArray(rhs)
30 | case (.some(let lhs), .some(let rhs)): StreamArray(concatenated([lhs, rhs], axis: axis))
31 | }
32 | }
33 |
34 | public var shape: [Int]? {
35 | self.inner?.shape
36 | }
37 |
38 | public func asArray() -> MLXArray? {
39 | self.inner
40 | }
41 |
42 | public func narrow(_ offset: Int, _ len: Int, axis: Int) -> StreamArray {
43 | if let inner = self.inner {
44 | let totalLen = inner.dim(axis)
45 | if len <= 0 {
46 | return StreamArray()
47 | } else {
48 | // Not sure if there exists something closer to pytorch narrow
49 | let t = inner.split(indices: [offset, min(totalLen, offset + len)], axis: axis)
50 | return StreamArray(t[1])
51 | }
52 | } else {
53 | return StreamArray()
54 | }
55 | }
56 |
57 | public func split(lhsLen: Int, axis: Int) -> (StreamArray, StreamArray) {
58 | if let t = self.inner {
59 | let len = t.dim(axis)
60 | let lhsLen = min(len, lhsLen)
61 | if lhsLen == 0 {
62 | return (StreamArray(), StreamArray(t))
63 | } else if lhsLen == len {
64 | return (StreamArray(t), StreamArray())
65 | } else {
66 | let split = t.split(indices: [lhsLen], axis: axis)
67 | return (StreamArray(split[0]), StreamArray(split[1]))
68 | }
69 | } else {
70 | return (StreamArray(), StreamArray())
71 | }
72 | }
73 |
74 | public func elu() -> StreamArray {
75 | self.map { MLXNN.elu($0, alpha: 1.0) }
76 | }
77 |
78 | public func map(_ f: (MLXArray) -> MLXArray) -> StreamArray {
79 | switch self.inner {
80 | case .none: StreamArray()
81 | case .some(let x): StreamArray(f(x))
82 | }
83 | }
84 | }
85 |
86 | public protocol StreamingLayer {
87 | func resetState()
88 | func step(_ x: StreamArray) -> StreamArray
89 | }
90 |
91 | public class StreamingBinOp {
92 | enum BinOp {
93 | case add
94 | case mul
95 | case sub
96 | case div
97 | }
98 |
99 | var prevLHS: StreamArray
100 | var prevRHS: StreamArray
101 | let op: BinOp
102 | let axis: Int
103 |
104 | init(_ op: BinOp, axis: Int) {
105 | self.prevLHS = StreamArray()
106 | self.prevRHS = StreamArray()
107 | self.op = op
108 | self.axis = axis
109 | }
110 |
111 | public func resetState() {
112 | self.prevLHS = StreamArray()
113 | self.prevRHS = StreamArray()
114 | }
115 |
116 | public func step(_ lhs: StreamArray, _ rhs: StreamArray) -> StreamArray {
117 | let lhs = self.prevLHS.cat2(lhs, axis: self.axis)
118 | let rhs = self.prevRHS.cat2(rhs, axis: self.axis)
119 | let lhsLen = lhs.dim(self.axis)
120 | let rhsLen = rhs.dim(self.axis)
121 | let commonLen = min(lhsLen, rhsLen)
122 | let (lhs_, prevLHS) = lhs.split(lhsLen: commonLen, axis: self.axis)
123 | let (rhs_, prevRHS) = rhs.split(lhsLen: commonLen, axis: self.axis)
124 | self.prevLHS = prevLHS
125 | self.prevRHS = prevRHS
126 | switch (lhs_.inner, rhs_.inner) {
127 | case (.some(let l), .some(let r)):
128 | var res: MLXArray
129 | switch self.op {
130 | case .add: res = l + r
131 | case .sub: res = l - r
132 | case .mul: res = l * r
133 | case .div: res = l / r
134 | }
135 | return StreamArray(res)
136 | case (.none, .none): return StreamArray()
137 | case _: fatalError("internal error")
138 | }
139 | }
140 | }
141 |
--------------------------------------------------------------------------------
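
A toy illustration (not from the repo) of the StreamArray chunking primitives used throughout the streaming code: `cat2` merges buffered chunks and `split` takes a fixed-length prefix back off, which is how StreamingBinOp keeps two streams aligned when their chunks arrive at different rates:

    import MLX
    import MoshiLib

    let a = StreamArray(MLXArray([1, 2, 3]))
    let b = StreamArray(MLXArray([4, 5]))
    // The stream now holds [1, 2, 3, 4, 5].
    let joined = a.cat2(b, axis: 0)
    // Consume the first four elements, keep the fifth buffered.
    let (head, rest) = joined.split(lhsLen: 4, axis: 0)
    print(head.shape ?? [], rest.shape ?? [])  // [4] [1]
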
/MoshiLib/Mimi.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 | import MLXRandom
9 |
10 | public struct MimiConfig {
11 | public var channels: Int
12 | public var sampleRate: Float
13 | public var frameRate: Float
14 | public var renormalize: Bool
15 | // public var resampleMethod: String
16 | public var seanet: SeanetConfig
17 | public var transformer: TransformerConfig
18 | public var quantizerNQ: Int
19 | public var quantizerBins: Int
20 | public var quantizerDim: Int
21 |
22 | public static func mimi_2024_07(numCodebooks: Int = 16) -> MimiConfig {
23 | let seanet = SeanetConfig.v0_1()
24 | let transformer = TransformerConfig(
25 | dModel: seanet.dimension,
26 | numHeads: 8,
27 | numLayers: 8,
28 | causal: true,
29 | normFirst: true,
30 | biasFF: false,
31 | biasAttn: false,
32 | layerScale: 0.01,
33 | positionalEmbedding: .rope,
34 | useConvBias: true,
35 | gating: false,
36 | norm: .layerNorm,
37 | context: 250,
38 | maxPeriod: 10000,
39 | maxSeqLen: 8192,
40 | kvRepeat: 1,
41 | dimFeedForward: 2048,
42 | convLayout: true,
43 | useRotatingKVCache: true
44 | )
45 | return MimiConfig(
46 | channels: 1,
47 | sampleRate: 24000, frameRate: 12.5, renormalize: true, seanet: seanet,
48 | transformer: transformer, quantizerNQ: numCodebooks, quantizerBins: 2048,
49 | quantizerDim: 256)
50 | }
51 | }
52 |
53 | public class Mimi: Module {
54 | let cfg: MimiConfig
55 | let encoderCache: [KVCache]
56 | let decoderCache: [KVCache]
57 | @ModuleInfo(key: "encoder") var encoder: SeanetEncoder
58 | @ModuleInfo(key: "decoder") var decoder: SeanetDecoder
59 | @ModuleInfo(key: "encoder_transformer") var encoderTransformer: ProjectedTransformer
60 | @ModuleInfo(key: "decoder_transformer") var decoderTransformer: ProjectedTransformer
61 | @ModuleInfo(key: "downsample") var downsample: ConvDownsample1d
62 | @ModuleInfo(key: "upsample") var upsample: ConvTrUpsample1d
63 | @ModuleInfo(key: "quantizer") var quantizer: SplitResidualVectorQuantizer
64 |
65 | public init(_ cfg: MimiConfig, bSize: Int) {
66 | let dim = cfg.seanet.dimension
67 | self.cfg = cfg
68 | let encoderFrameRate = cfg.sampleRate / Float(cfg.seanet.ratios.reduce(1, *))
69 | let downsampleStride = Int(encoderFrameRate / cfg.frameRate)
70 | self._encoder.wrappedValue = SeanetEncoder(cfg.seanet)
71 | self._decoder.wrappedValue = SeanetDecoder(cfg.seanet)
72 | self._quantizer.wrappedValue = SplitResidualVectorQuantizer(
73 | dim: cfg.quantizerDim, inputDim: dim, outputDim: dim, nQ: cfg.quantizerNQ,
74 | bins: cfg.quantizerBins)
75 | self._encoderTransformer.wrappedValue = ProjectedTransformer(
76 | cfg.transformer, inputDim: dim, outputDims: [dim])
77 | self._decoderTransformer.wrappedValue = ProjectedTransformer(
78 | cfg.transformer, inputDim: dim, outputDims: [dim])
79 | self._downsample.wrappedValue = ConvDownsample1d(
80 | stride: downsampleStride, dim: dim, causal: true)
81 | self._upsample.wrappedValue = ConvTrUpsample1d(
82 | stride: downsampleStride, dim: dim, causal: true)
83 | self.encoderCache = self._encoderTransformer.wrappedValue.makeCache(bSize: bSize)
84 | self.decoderCache = self._decoderTransformer.wrappedValue.makeCache(bSize: bSize)
85 | }
86 |
87 | public func warmup() {
88 | let pcm = MLXArray.zeros([1, 1, 1920 * 4])
89 | let codes = self.encode(pcm)
90 | let pcmOut = self.decode(codes)
91 | eval(pcmOut)
92 | }
93 |
94 | public func encode(_ x: MLXArray) -> MLXArray {
95 | self.encoder.resetState()
96 | self.encoderCache.forEach { c in c.reset() }
97 | var x = self.encoder(x)
98 | x = self.encoderTransformer(x, cache: self.encoderCache)[0]
99 | x = self.downsample(x)
100 | let codes = self.quantizer.encode(x)
101 | return codes
102 | }
103 |
104 | public func encodeStep(_ x: StreamArray) -> StreamArray {
105 | var x = self.encoder.step(x)
106 | x = x.map { self.encoderTransformer($0, cache: self.encoderCache)[0] }
107 | x = self.downsample.step(x)
108 | let codes = x.map(self.quantizer.encode)
109 | return codes
110 | }
111 |
112 | public func decode(_ codes: MLXArray) -> MLXArray {
113 | self.decoder.resetState()
114 | self.decoderCache.forEach { c in c.reset() }
115 | let emb = self.quantizer.decode(codes)
116 | let embUp = self.upsample(emb)
117 | let outs = self.decoderTransformer(embUp, cache: self.decoderCache)
118 | return self.decoder(outs[0])
119 | }
120 |
121 | public func decodeStep(_ codes: StreamArray) -> StreamArray {
122 | var emb = codes.map { self.quantizer.decode($0) }
123 | emb = self.upsample.step(emb)
124 | let out = emb.map { self.decoderTransformer($0, cache: self.decoderCache)[0] }
125 | return self.decoder.step(out)
126 | }
127 |
128 | public func resetState() {
129 | self.encoder.resetState()
130 | self.decoder.resetState()
131 | self.encoderCache.forEach { c in c.reset() }
132 | self.decoderCache.forEach { c in c.reset() }
133 | }
134 | }
135 |
--------------------------------------------------------------------------------
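
A minimal offline round-trip through Mimi as a sketch; it assumes real weights have been loaded into the model (a freshly initialized Mimi runs but produces meaningless codes), and the shapes are indicative:

    import MLX
    import MoshiLib

    let cfg = MimiConfig.mimi_2024_07(numCodebooks: 8)
    let mimi = Mimi(cfg, bSize: 1)
    // One second of silence at the 24 kHz sample rate set by the config.
    let pcm = MLXArray.zeros([1, 1, 24000])
    let codes = mimi.encode(pcm)        // roughly [1, 8, 12] at a 12.5 Hz frame rate
    let reconstructed = mimi.decode(codes)
    print(codes.shape, reconstructed.shape)
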
/MoshiLib/Qwen2.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 |
9 | public struct QwenQuantization: Codable {
10 | public var groupSize: Int
11 | public var bits: Int
12 | }
13 |
14 | public struct QwenConfig: Codable {
15 | public var bosTokenId: Int
16 | public var eosTokenId: Int
17 | public var hiddenSize: Int
18 | public var intermediateSize: Int
19 | public var maxPositionEmbeddings: Int
20 | public var maxWindowLayers: Int
21 | public var numAttentionHeads: Int
22 | public var numHiddenLayers: Int
23 | public var numKeyValueHeads: Int
24 | public var rmsNormEps: Float
25 | public var ropeTheta: Float
26 | public var tieWordEmbeddings: Bool
27 | public var useSlidingWindow: Bool
28 | public var vocabSize: Int
29 | public var quantization: QwenQuantization? = nil
30 |
31 | public func headDim() -> Int {
32 | self.hiddenSize / self.numAttentionHeads
33 | }
34 | }
35 |
36 | private class Mlp: Module, UnaryLayer {
37 | @ModuleInfo(key: "gate_proj") var gateProj: Linear
38 | @ModuleInfo(key: "down_proj") var downProj: Linear
39 | @ModuleInfo(key: "up_proj") var upProj: Linear
40 |
41 | init(_ cfg: QwenConfig) {
42 | self._gateProj.wrappedValue = Linear(cfg.hiddenSize, cfg.intermediateSize, bias: false)
43 | self._upProj.wrappedValue = Linear(cfg.hiddenSize, cfg.intermediateSize, bias: false)
44 | self._downProj.wrappedValue = Linear(cfg.intermediateSize, cfg.hiddenSize, bias: false)
45 | }
46 |
47 | func callAsFunction(_ x: MLXArray) -> MLXArray {
48 | return downProj(silu(gateProj(x)) * upProj(x))
49 | }
50 | }
51 |
52 | private class Attention: Module {
53 | let cfg: QwenConfig
54 | let scale: Float
55 | let rope: RoPE
56 |
57 | @ModuleInfo(key: "q_proj") var qProj: Linear
58 | @ModuleInfo(key: "k_proj") var kProj: Linear
59 | @ModuleInfo(key: "v_proj") var vProj: Linear
60 | @ModuleInfo(key: "o_proj") var oProj: Linear
61 |
62 | init(_ cfg: QwenConfig) {
63 | self.cfg = cfg
64 | self.scale = 1.0 / sqrt(Float(cfg.headDim()))
65 | let headDim = cfg.headDim()
66 | self._qProj.wrappedValue = Linear(
67 | cfg.hiddenSize, cfg.numAttentionHeads * headDim, bias: true)
68 | self._kProj.wrappedValue = Linear(
69 | cfg.hiddenSize, cfg.numKeyValueHeads * headDim, bias: true)
70 | self._vProj.wrappedValue = Linear(
71 | cfg.hiddenSize, cfg.numKeyValueHeads * headDim, bias: true)
72 | self._oProj.wrappedValue = Linear(
73 | cfg.numAttentionHeads * headDim, cfg.hiddenSize, bias: false)
74 | self.rope =
75 | RoPE(dimensions: cfg.headDim(), traditional: false, base: Float(cfg.ropeTheta))
76 | }
77 |
78 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache?) -> MLXArray {
79 | let (B, T, H) = (x.dim(0), x.dim(1), x.dim(2))
80 | let headDim = cfg.headDim()
81 | var queryStates = qProj(x).reshaped(B, T, -1, headDim).transposed(0, 2, 1, 3)
82 | var keyStates = kProj(x).reshaped(B, T, -1, headDim).transposed(0, 2, 1, 3)
83 | var valueStates = vProj(x).reshaped(B, T, -1, headDim).transposed(0, 2, 1, 3)
84 | let offset = cache?.offset ?? 0
85 | queryStates = rope(queryStates, offset: offset)
86 | keyStates = rope(keyStates, offset: offset)
87 | if let cache {
88 | (keyStates, valueStates) = cache.update(keys: keyStates, values: valueStates)
89 | }
90 | // sliding window is not supported here
91 | var mask = mask
92 | if let m = mask {
93 | let maskLen = m.dim(-1)
94 | if keyStates.dim(2) < maskLen {
95 | let offset = maskLen - keyStates.dim(2)
96 | mask = m[0..., offset...]
97 | }
98 | }
99 | let x = MLXFast.scaledDotProductAttention(
100 | queries: queryStates, keys: keyStates, values: valueStates, scale: self.scale,
101 | mask: mask
102 | ).transposed(0, 2, 1, 3).reshaped(B, T, H)
103 | return oProj(x)
104 | }
105 | }
106 |
107 | private class Layer: Module {
108 | @ModuleInfo(key: "mlp") var mlp: Mlp
109 | @ModuleInfo(key: "input_layernorm") var inputNorm: RMSNorm
110 | @ModuleInfo(key: "post_attention_layernorm") var postAttnNorm: RMSNorm
111 | @ModuleInfo(key: "self_attn") var selfAttn: Attention
112 |
113 | init(_ cfg: QwenConfig) {
114 | self._mlp.wrappedValue = Mlp(cfg)
115 | self._inputNorm.wrappedValue = RMSNorm(dimensions: cfg.hiddenSize, eps: cfg.rmsNormEps)
116 | self._postAttnNorm.wrappedValue = RMSNorm(dimensions: cfg.hiddenSize, eps: cfg.rmsNormEps)
117 | self._selfAttn.wrappedValue = Attention(cfg)
118 | }
119 |
120 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache) -> MLXArray {
121 | var residual = x
122 | var x = x
123 | x = selfAttn(inputNorm(x), mask: mask, cache: cache)
124 | x = residual + x
125 | residual = x
126 | x = mlp(postAttnNorm(x))
127 | return residual + x
128 | }
129 | }
130 |
131 | public class QwenModel: Module {
132 | let cfg: QwenConfig
133 | private let norm: RMSNorm
134 | private let layers: [Layer]
135 | @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
136 |
137 | public init(_ cfg: QwenConfig) {
138 | self.cfg = cfg
139 | self.layers = (0..<cfg.numHiddenLayers).map { _ in Layer(cfg) }
140 | self.norm = RMSNorm(dimensions: cfg.hiddenSize, eps: cfg.rmsNormEps)
141 | self._embedTokens.wrappedValue = Embedding(
142 |     embeddingCount: cfg.vocabSize, dimensions: cfg.hiddenSize)
143 | }
144 |
145 | public func callAsFunction(_ x: MLXArray, cache: [KVCache]) -> MLXArray {
146 | var x = embedTokens(x)
147 | let mask = cache.first?.createAttentionMask(h: x)
148 | for (layer, c) in zip(self.layers, cache) {
149 | x = layer(x, mask: mask, cache: c)
150 | }
151 | return embedTokens.asLinear(norm(x))
152 | }
153 |
154 | public func makeCache(bSize: Int) -> [KVCache] {
155 | let kvHeads = cfg.numKeyValueHeads
156 | let cache = (0..<
[... remainder of Qwen2.swift lost in extraction ...]

--------------------------------------------------------------------------------
/MoshiCLI/AudioRT.swift:
--------------------------------------------------------------------------------
[... lines 1-3 lost in extraction ...]
4 | class ThreadSafeChannel<T> {
5 | private var buffer: [T] = []
6 | private let queue = DispatchQueue(label: "tschannel", attributes: .concurrent)
7 | private let semaphore = DispatchSemaphore(value: 0)
8 |
9 | func send(_ value: T) {
10 | queue.async(flags: .barrier) {
11 | self.buffer.append(value)
12 | self.semaphore.signal()
13 | }
14 | }
15 |
16 | func receive() -> T? {
17 | semaphore.wait()
18 | return queue.sync {
19 | guard !buffer.isEmpty else { return nil }
20 | return buffer.removeFirst()
21 | }
22 | }
23 | }
24 |
25 | // The code below is probably macos specific and unlikely to work on ios.
26 | class MicrophoneCapture {
27 | private let audioEngine: AVAudioEngine
28 | private let channel: ThreadSafeChannel<[Float]>
29 |
30 | init() {
31 | audioEngine = AVAudioEngine()
32 | channel = ThreadSafeChannel()
33 | }
34 |
35 | func startCapturing() {
36 | let inputNode = audioEngine.inputNode
37 |
38 | // Desired format: 1 channel (mono), 24kHz, Float32
39 | let desiredSampleRate: Double = 24000.0
40 | let desiredChannelCount: AVAudioChannelCount = 1
41 |
42 | let inputFormat = inputNode.inputFormat(forBus: 0)
43 |
44 | // Create a custom audio format with the desired settings
45 | guard
46 | let mono24kHzFormat = AVAudioFormat(
47 | commonFormat: .pcmFormatFloat32,
48 | sampleRate: desiredSampleRate,
49 | channels: desiredChannelCount,
50 | interleaved: false)
51 | else {
52 | print("Could not create target format")
53 | return
54 | }
55 |
56 | // Resample the buffer to match the desired format
57 | let converter = AVAudioConverter(from: inputFormat, to: mono24kHzFormat)
58 |
59 | // Install a tap to capture audio and resample to the target format
60 | inputNode.installTap(onBus: 0, bufferSize: 1920, format: inputFormat) { buffer, _ in
61 | let targetLen = Int(buffer.frameLength) * 24000 / Int(inputFormat.sampleRate)
62 | let convertedBuffer = AVAudioPCMBuffer(
63 | pcmFormat: mono24kHzFormat, frameCapacity: AVAudioFrameCount(targetLen))!
64 | var error: NSError? = nil
65 | let inputBlock: AVAudioConverterInputBlock = { inNumPackets, outStatus in
66 | outStatus.pointee = .haveData
67 | return buffer
68 | }
69 |
70 | converter?.convert(to: convertedBuffer, error: &error, withInputFrom: inputBlock)
71 |
72 | if let error = error {
73 | print("Conversion error: \(error)")
74 | return
75 | }
76 |
77 | self.processAudioBuffer(buffer: convertedBuffer)
78 | }
79 |
80 | // Start the audio engine
81 | do {
82 | audioEngine.prepare()
83 | try audioEngine.start()
84 | print("Microphone capturing started at 24kHz, mono")
85 | } catch {
86 | print("Error starting audio engine: \(error)")
87 | }
88 | }
89 |
90 | private func processAudioBuffer(buffer: AVAudioPCMBuffer) {
91 | guard let channelData = buffer.floatChannelData else { return }
92 | let frameCount = Int(buffer.frameLength)
93 |
94 | let pcmData = Array(UnsafeBufferPointer(start: channelData[0], count: frameCount)).map {
95 | $0
96 | }
97 | channel.send(pcmData)
98 | }
99 |
100 | func stopCapturing() {
101 | audioEngine.stop()
102 | audioEngine.inputNode.removeTap(onBus: 0)
103 | print("Microphone capturing stopped")
104 | }
105 |
106 | func receive() -> [Float]? {
107 | channel.receive()
108 | }
109 | }
110 |
111 | class FloatRingBuffer {
112 | private var buffer: [Float]
113 | private let capacity: Int
114 | private var readIndex = 0
115 | private var writeIndex = 0
116 | private var count = 0
117 |
118 | private let lock = NSLock()
119 |
120 | init(capacity: Int) {
121 | self.capacity = capacity
122 | self.buffer = [Float](repeating: 0, count: capacity)
123 | }
124 |
125 | func write(_ values: [Float]) -> Bool {
126 | lock.lock()
127 | defer { lock.unlock() }
128 |
129 | if values.count + count > capacity {
130 | return false
131 | }
132 | for value in values {
133 | buffer[writeIndex] = value
134 | writeIndex = (writeIndex + 1) % capacity
135 | count += 1
136 | }
137 | return true
138 | }
139 |
140 | func read(maxCount: Int) -> [Float] {
141 | lock.lock()
142 | defer { lock.unlock() }
143 |
144 | var values: [Float] = []
145 | for _ in 0..<
[... lines 145-171 lost in extraction; the ring-buffer read loop ends and an audio output class begins here, ending in an AVAudioSourceNode-style render closure ...]
172 | ... { _, _, frameCount, audioBufferList -> OSStatus in
173 | let audioBuffers = UnsafeMutableAudioBufferListPointer(audioBufferList)
174 | guard let channelData = audioBuffers[0].mData?.assumingMemoryBound(to: Float.self)
175 | else {
176 | return kAudioHardwareUnspecifiedError
177 | }
178 | let data = self.ringBuffer.read(maxCount: Int(frameCount))
179 | for i in 0..<
[... lines 179-190 lost in extraction ...]
191 | func send(_ values: [Float]) -> Bool {
192 | ringBuffer.write(values)
193 | }
194 | }
195 |
--------------------------------------------------------------------------------
/MoshiLib/KVCache.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 | //
5 | // Parts of this file came from:
6 | // https://github.com/ml-explore/mlx-swift-examples/blob/main/Libraries/LLM/KVCache.swift
7 | // Copyright © 2024 Apple Inc.
8 |
9 | import Foundation
10 | import MLX
11 |
12 | /// Interface for Key/Value cache for LLMs.
13 | ///
14 | /// See ``LLMModel/newCache(parameters:)-47tyu``
15 | public protocol KVCache: Evaluatable {
16 |
17 | /// get the current offset
18 | var offset: Int { get }
19 |
20 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray)
21 | func reset()
22 | func createAttentionMask(h: MLXArray) -> MLXArray?
23 | }
24 |
25 | func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray {
26 | let rinds = MLXArray(Int32(0)..<
[... lines 26-56 lost in extraction; createAdditiveCausalMask ends and class KVCacheSimple (keys, values, offset, step) begins here ...]
57 | func innerState() -> [MLXArray] {
58 | [self.keys, self.values].compactMap { $0 }
59 | }
60 |
61 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
62 | let previous = self.offset
63 |
64 | let reset =
65 | if let currentKeys = self.keys, (previous + keys.dim(2)) > currentKeys.dim(2) {
66 | true
67 | } else {
68 | self.keys == nil
69 | }
70 | if reset {
71 | let B = keys.dim(0)
72 | let nSteps = (step + keys.dim(2) - 1) / step
73 | let kShape = [B, kvHeads, nSteps * step, kHeadDim]
74 | let vShape = [B, kvHeads, nSteps * step, vHeadDim]
75 | let newK = MLXArray.zeros(kShape, dtype: keys.dtype)
76 | let newV = MLXArray.zeros(vShape, dtype: values.dtype)
77 |
78 | if var currentKeys = self.keys, var currentValues = self.values {
79 | if previous % step != 0 {
80 | currentKeys = currentKeys[.ellipsis, ..<previous, 0...]
[... lines 81-105 lost in extraction; the rest of update(keys:values:) and reset() are garbled ...]
106 | func createAttentionMask(h: MLXArray) -> MLXArray? {
107 | let t = h.dim(1)
108 | if t > 1 {
109 | let rinds = MLXArray(Int32(0)..<
[... lines 109-129 lost in extraction; createAttentionMask ends, KVCacheSimple closes, and a rotating KV cache class bounded by maxSize begins here ...]
130 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
131 | let t = keys.dim(2)
132 | if t > self.maxSize {
133 | fatalError("query to update with shape \(keys.shape) larger than maxSize \(maxSize)")
134 | }
135 | let currentOffset = self.offset % self.maxSize
136 | let tMax = min(self.maxSize, currentOffset + t)
137 | self.keys[0..., 0..., currentOffset..<
[... lines 137-152 garbled in extraction; the rotating update and reset() end here ...]
153 | func innerState() -> [MLXArray] {
154 | [self.keys, self.values].compactMap { $0 }
155 | }
156 |
157 | func createAttentionMask(h: MLXArray) -> MLXArray? {
158 | let t = h.dim(1)
159 | let finalOffset = self.offset + t
160 | let finalOffsetMod = finalOffset % self.maxSize
161 | // As a default we use finalOffset + 1 so that these slices cannot be seen.
162 | var rinds = Array(repeating: Int32(finalOffset + 1), count: self.maxSize)
163 | for i in 0..<
[... remainder of KVCache.swift lost in extraction ...]

--------------------------------------------------------------------------------
/MoshiLib/Quantization.swift:
--------------------------------------------------------------------------------
[... lines 1-28 lost in extraction; the imports and the start of class EuclideanCodebook are garbled ...]
29 | override func update(parameters: ModuleParameters, verify: VerifyUpdate) throws -> Self {
30 | try super.update(parameters: parameters, verify: verify)
31 | let clusterUsage = maximum(self._clusterUsage.wrappedValue, self.epsilon)[0..., .newAxis]
32 | self.embedding = self._embeddingSum.wrappedValue / clusterUsage
33 | self.c2 = self.embedding.square().sum(axis: -1) / 2
34 | return self
35 | }
36 |
37 | func encode(_ x: MLXArray) -> MLXArray {
38 | let targetShape = Array(x.shape.dropLast())
39 | let x = x.flattened(end: -2)
40 | let dotProd = x.matmul(embedding.swappedAxes(-1, -2))
41 | return (c2 - dotProd).argMin(axis: -1).reshaped(targetShape)
42 | }
43 |
44 | func decode(_ indexes: MLXArray) -> MLXArray {
45 | let finalDims = indexes.shape + [self.dim]
46 | let indexes = indexes.flattened()
47 | return embedding.take(indexes, axis: 0).reshaped(finalDims)
48 | }
49 | }
50 |
51 | class VectorQuantization: Module {
52 | @ModuleInfo(key: "project_in") var projectIn: Linear?
53 | @ModuleInfo(key: "project_out") var projectOut: Linear?
54 | @ModuleInfo(key: "_codebook") var codebook: EuclideanCodebook
55 |
56 | init(dim: Int, codebookSize: Int, codebookDim: Int? = nil) {
57 | let codebookDim = codebookDim ?? dim
58 | if codebookDim == dim {
59 | self._projectIn.wrappedValue = nil
60 | self._projectOut.wrappedValue = nil
61 | } else {
62 | self._projectIn.wrappedValue = Linear(dim, codebookDim)
63 | self._projectOut.wrappedValue = Linear(codebookDim, dim)
64 | }
65 | self._codebook.wrappedValue = EuclideanCodebook(
66 | dim: codebookDim, codebookSize: codebookSize)
67 | }
68 |
69 | func encode(_ x: MLXArray) -> MLXArray {
70 | var x = x.swappedAxes(-1, -2)
71 | if let projectIn = self.projectIn {
72 | x = projectIn(x)
73 | }
74 | return self.codebook.encode(x)
75 | }
76 |
77 | func decode(_ indexes: MLXArray) -> MLXArray {
78 | var quantized = self.codebook.decode(indexes)
79 | if let projectOut = self.projectOut {
80 | quantized = projectOut(quantized)
81 | }
82 | return quantized.swappedAxes(-1, -2)
83 | }
84 | }
85 |
86 | class ResidualVectorQuantization: Module {
87 | @ModuleInfo(key: "layers") var layers: [VectorQuantization]
88 |
89 | init(nQ: Int, dim: Int, codebookSize: Int, codebookDim: Int? = nil) {
90 | self._layers.wrappedValue = (0..<nQ).map { _ in
91 |     VectorQuantization(dim: dim, codebookSize: codebookSize, codebookDim: codebookDim)
92 | }
93 | }
94 |
95 | func encode(_ x: MLXArray) -> MLXArray {
96 | var codes: [MLXArray] = []
97 | var residual = x
98 | for layer in self.layers {
99 | let indices = layer.encode(residual)
100 | let quantized = layer.decode(indices)
101 | residual = residual - quantized
102 | codes.append(indices)
103 | }
104 | return stacked(codes, axis: 0)
105 | }
106 |
107 | func decode(_ indexes: MLXArray) -> MLXArray {
108 | let seqLen = indexes.dim(0)
109 | var quantized = self.layers[0].decode(indexes[0])
110 | for i in 1..<seqLen {
111 | quantized = quantized + self.layers[i].decode(indexes[i])
112 | }
113 | return quantized
114 | }
115 | }
141 | func encode(_ x: MLXArray) -> MLXArray {
142 | var x = x
143 | if let inputProj = self.inputProj {
144 | x = inputProj(x)
145 | }
146 | return self.vq.encode(x).swappedAxes(0, 1)
147 | }
148 |
149 | func decode(_ codes: MLXArray) -> MLXArray {
150 | let codes = codes.swappedAxes(0, 1)
151 | var quantized = self.vq.decode(codes)
152 | if let outputProj = self.outputProj {
153 | quantized = outputProj(quantized)
154 | }
155 | return quantized
156 | }
157 | }
158 |
159 | class SplitResidualVectorQuantizer: Module {
160 | let nQ: Int
161 | @ModuleInfo(key: "rvq_first") var rvqFirst: ResidualVectorQuantizer
162 | @ModuleInfo(key: "rvq_rest") var rvqRest: ResidualVectorQuantizer
163 |
164 | init(dim: Int, inputDim: Int?, outputDim: Int?, nQ: Int, bins: Int) {
165 | self.nQ = nQ
166 | self._rvqFirst.wrappedValue = ResidualVectorQuantizer(
167 | dim: dim, inputDim: inputDim, outputDim: outputDim, nQ: 1, bins: bins,
168 | forceProjection: true)
169 | self._rvqRest.wrappedValue = ResidualVectorQuantizer(
170 | dim: dim, inputDim: inputDim, outputDim: outputDim, nQ: nQ - 1, bins: bins,
171 | forceProjection: true)
172 | }
173 |
174 | func encode(_ x: MLXArray) -> MLXArray {
175 | var codes = self.rvqFirst.encode(x)
176 | if self.nQ > 1 {
177 | let restCodes = self.rvqRest.encode(x)
178 | codes = concatenated([codes, restCodes], axis: 1)
179 | }
180 | return codes
181 | }
182 |
183 | func decode(_ codes: MLXArray) -> MLXArray {
184 | var quantized = self.rvqFirst.decode(codes[0..., ..<1])
185 | if self.nQ > 1 {
186 | quantized = quantized + self.rvqRest.decode(codes[0..., 1...])
187 | }
188 | return quantized
189 | }
190 | }
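A shape sketch for the split quantizer (sizes illustrative): with nQ = 8, encode maps a [B, dim, T] latent to integer codes [B, 8, T], where row 0 comes from rvq_first and rows 1...7 from rvq_rest; decode sums the two partial reconstructions back to [B, dim, T].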
191 |
--------------------------------------------------------------------------------
/MoshiLib/Perf.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 | import MLX
5 | import os.signpost
6 |
7 | public enum EventKind {
8 | case beginStep
9 | case endStep
10 | case beginDepformer
11 | case endDepformer
12 | case beginDecode
13 | case endDecode
14 | case beginEncode
15 | case endEncode
16 | }
17 |
18 | public protocol Callbacks {
19 | func onReset()
20 | func onEvent(_ eventKind: EventKind)
21 | func onInputAudioTokens(_ codes: MLXArray)
22 | func onOutputTextToken(_ token: Int)
23 | func onOutputAudioTokens(_ codes: MLXArray)
24 | }
25 |
26 | public class EmptyCallbacks: Callbacks {
27 | public init() {}
28 | public func onReset() {}
29 | public func onEvent(_ eventKind: EventKind) {}
30 | public func onInputAudioTokens(_ codes: MLXArray) {}
31 | public func onOutputTextToken(_ token: Int) {}
32 | public func onOutputAudioTokens(_ codes: MLXArray) {}
33 | }
34 |
35 | public struct ChromeTraceEvent: Codable {
36 | let name: String
37 | let cat: String
38 | let ph: String
39 | let ts: Int
40 | let pid: Int
41 | let tid: Int
42 | }
43 |
44 | public struct StatsSummary {
45 | public struct Stats: Identifiable {
46 | public var min: Float = Float.infinity
47 | public var max: Float = 0.0
48 | public var sum: Float = 0.0
49 | public var cnt: Int = 0
50 | public let id = UUID()
51 |
52 | mutating func addValue(_ b: CFAbsoluteTime, _ e: CFAbsoluteTime) {
53 | let v = Float(e - b)
54 | self.min = Float.minimum(self.min, v)
55 | self.max = Float.maximum(self.max, v)
56 | self.sum += v
57 | self.cnt += 1
58 | }
59 | }
60 | public var encode: Stats = Stats()
61 | public var decode: Stats = Stats()
62 | public var step: Stats = Stats()
63 | public var depformer: Stats = Stats()
64 |
65 | public init() {
66 | }
67 | }
68 |
69 | public class PerfStats: Callbacks {
70 | private let log: OSLog
71 | private var events: [(CFAbsoluteTime, EventKind)] = []
72 | private var inputAudioTokens: [MLXArray] = []
73 | private var outputAudioTokens: [MLXArray] = []
74 | private var textTokens: [Int] = []
75 |
76 | public init() {
77 | self.log = OSLog(subsystem: "org.kyutai.moshi", category: "Performance")
78 | }
79 |
80 | func append(_ kind: EventKind) {
81 | events.append((CFAbsoluteTimeGetCurrent(), kind))
82 | }
83 |
84 | public func onReset() {
85 | events.removeAll()
86 | inputAudioTokens.removeAll()
87 | outputAudioTokens.removeAll()
88 | textTokens.removeAll()
89 | }
90 |
91 | public func onInputAudioTokens(_ codes: MLXArray) {
92 | codes.eval()
93 | inputAudioTokens.append(codes)
94 | }
95 |
96 | public func onOutputTextToken(_ token: Int) {
97 | textTokens.append(token)
98 | }
99 |
100 | public func onOutputAudioTokens(_ codes: MLXArray) {
101 | codes.eval()
102 | outputAudioTokens.append(codes)
103 | }
104 |
105 | public func getSummary(maxEvents: Int) -> StatsSummary {
106 | let startIdx = max(0, self.events.count - maxEvents)
107 | var summary = StatsSummary()
108 | var lastBeginEncode: CFAbsoluteTime? = nil
109 | var lastBeginDecode: CFAbsoluteTime? = nil
110 | var lastBeginStep: CFAbsoluteTime? = nil
111 | var lastBeginDepformer: CFAbsoluteTime? = nil
112 | for event in self.events[startIdx...] {
113 | let time = event.0
114 | switch event.1 {
115 | case .beginStep:
116 | lastBeginStep = time
117 | case .beginDecode:
118 | lastBeginDecode = time
119 | case .beginEncode:
120 | lastBeginEncode = time
121 | case .beginDepformer:
122 | lastBeginDepformer = time
123 | case .endStep:
124 | if let b = lastBeginStep {
125 | lastBeginStep = nil
126 | summary.step.addValue(b, time)
127 | }
128 | case .endDecode:
129 | if let b = lastBeginDecode {
130 | lastBeginDecode = nil
131 | summary.decode.addValue(b, time)
132 | }
133 | case .endEncode:
134 | if let b = lastBeginEncode {
135 | lastBeginEncode = nil
136 | summary.encode.addValue(b, time)
137 | }
138 | case .endDepformer:
139 | if let b = lastBeginDepformer {
140 | lastBeginDepformer = nil
141 | summary.depformer.addValue(b, time)
142 | }
143 | }
144 | }
145 | return summary
146 | }
147 |
148 | public func onEvent(_ kind: EventKind) {
149 | switch kind {
150 | case .beginStep:
151 | os_signpost(.begin, log: log, name: "step")
152 | case .endStep:
153 | os_signpost(.end, log: log, name: "step")
154 | case .beginDepformer:
155 | os_signpost(.begin, log: log, name: "depformer")
156 | case .endDepformer:
157 | os_signpost(.end, log: log, name: "depformer")
158 | case .beginEncode:
159 | os_signpost(.begin, log: log, name: "encode")
160 | case .endEncode:
161 | os_signpost(.end, log: log, name: "encode")
162 | case .beginDecode:
163 | os_signpost(.begin, log: log, name: "decode")
164 | case .endDecode:
165 | os_signpost(.end, log: log, name: "decode")
166 | }
167 | append(kind)
168 | }
169 |
170 | public func writeJSONTrace(url: URL) throws {
171 | let encoder = JSONEncoder()
172 | var traceEvents: [ChromeTraceEvent] = []
173 | for (time, kind) in events {
174 | let ts = Int((time - events[0].0) * 1e6)
175 | let (name, ph) =
176 | switch kind {
177 | case .beginStep: ("step", "B")
178 | case .endStep: ("step", "E")
179 | case .beginEncode: ("encode", "B")
180 | case .endEncode: ("encode", "E")
181 | case .beginDepformer: ("depformer", "B")
182 | case .endDepformer: ("depformer", "E")
183 | case .beginDecode: ("decode", "B")
184 | case .endDecode: ("decode", "E")
185 | }
186 | traceEvents.append(
187 | ChromeTraceEvent(name: name, cat: "", ph: ph, ts: ts, pid: 42, tid: 1))
188 | }
189 | let jsonData = try encoder.encode(traceEvents)
190 | try jsonData.write(to: url)
191 | }
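A usage sketch for the tracing path (file names are illustrative). The emitted JSON is a flat array of Chrome duration events, "ph" being "B"/"E" with microsecond timestamps relative to the first event, so it loads directly in chrome://tracing or Perfetto:

let stats = PerfStats()
// ... run generation with `stats` passed as the Callbacks implementation ...
try stats.writeJSONTrace(url: URL(fileURLWithPath: "moshi-trace.json"))
try stats.writeCodes(url: URL(fileURLWithPath: "moshi-codes.safetensors"))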
192 |
193 | public func allInputAudioTokens() -> MLXArray? {
194 | inputAudioTokens.isEmpty ? nil : concatenated(inputAudioTokens, axis: 2)
195 | }
196 |
197 | public func allOutputAudioTokens() -> MLXArray? {
198 | outputAudioTokens.isEmpty ? nil : concatenated(outputAudioTokens, axis: 2)
199 | }
200 |
201 | public func allTextTokens() -> MLXArray? {
202 | textTokens.isEmpty ? nil : MLXArray(textTokens)
203 | }
204 |
205 | public func writeCodes(url: URL) throws {
206 | var arrays: [String: MLXArray] = [:]
207 | if let a = allInputAudioTokens() {
208 | arrays["input_audio_tokens"] = a
209 | }
210 | if let a = allOutputAudioTokens() {
211 | arrays["output_audio_tokens"] = a
212 | }
213 | if let a = allTextTokens() {
214 | arrays["text_tokens"] = a
215 | }
216 | try save(arrays: arrays, url: url)
217 | }
218 | }
219 |
--------------------------------------------------------------------------------
/Moshi/AudioRT.swift:
--------------------------------------------------------------------------------
1 | import AVFoundation
2 | import Foundation
3 | import Synchronization
4 |
5 | class ThreadSafeChannel<T> {
6 | private var buffer: [T] = []
7 | private let queue = DispatchQueue(label: "tschannel", attributes: .concurrent)
8 | private let semaphore = DispatchSemaphore(value: 0)
9 |
10 | func send(_ value: T) {
11 | queue.async(flags: .barrier) {
12 | self.buffer.append(value)
13 | self.semaphore.signal()
14 | }
15 | }
16 |
17 | func receive() -> T? {
18 | semaphore.wait()
19 | return queue.sync {
20 | guard !buffer.isEmpty else { return nil }
21 | return buffer.removeFirst()
22 | }
23 | }
24 | }
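A minimal usage sketch (values illustrative): send may be called from any thread, and receive blocks on the semaphore until a value is available.

let channel = ThreadSafeChannel<[Float]>()
Thread.detachNewThread {
    channel.send([0.0, 0.5, -0.5])  // producer side
}
if let chunk = channel.receive() {  // blocks until the producer has sent
    print("received \(chunk.count) samples")
}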
25 |
26 | // The code below is probably macos specific and unlikely to work on ios.
27 | class MicrophoneCapture {
28 | private let audioEngine: AVAudioEngine
29 | private let channel: ThreadSafeChannel<[Float]>
30 | private let bufferedLen: Atomic<Int> = .init(0)
31 |
32 | init() {
33 | audioEngine = AVAudioEngine()
34 | channel = ThreadSafeChannel()
35 | }
36 |
37 | func startCapturing() {
38 | let inputNode = audioEngine.inputNode
39 | // Setting the voice mode on macos causes weird hangs of the microphone
40 | // so we discard it for now.
41 | #if os(iOS)
42 | do {
43 | // try inputNode.setVoiceProcessingEnabled(true)
44 | } catch {
45 | print("could not set voice processing on the input node")
46 | }
47 | #endif
48 |
49 | // Desired format: 1 channel (mono), 24kHz, Float32
50 | let desiredSampleRate: Double = 24000.0
51 | let desiredChannelCount: AVAudioChannelCount = 1
52 |
53 | let inputFormat = inputNode.inputFormat(forBus: 0)
54 | // Create a custom audio format with the desired settings
55 | guard
56 | let mono24kHzFormat = AVAudioFormat(
57 | commonFormat: .pcmFormatFloat32,
58 | sampleRate: desiredSampleRate,
59 | channels: desiredChannelCount,
60 | interleaved: false)
61 | else {
62 | print("Could not create target format")
63 | return
64 | }
65 |
66 | // Resample the buffer to match the desired format
67 | let converter = AVAudioConverter(from: inputFormat, to: mono24kHzFormat)
68 |
69 | // Install a tap to capture audio and resample to the target format
70 | inputNode.installTap(onBus: 0, bufferSize: 1920, format: inputFormat) { buffer, _ in
71 | let targetLen = Int(buffer.frameLength) * 24000 / Int(inputFormat.sampleRate)
72 | let convertedBuffer = AVAudioPCMBuffer(
73 | pcmFormat: mono24kHzFormat, frameCapacity: AVAudioFrameCount(targetLen))!
74 | var error: NSError? = nil
75 | let inputBlock: AVAudioConverterInputBlock = { inNumPackets, outStatus in
76 | outStatus.pointee = .haveData
77 | return buffer
78 | }
79 |
80 | converter?.convert(to: convertedBuffer, error: &error, withInputFrom: inputBlock)
81 |
82 | if let error = error {
83 | print("Conversion error: \(error)")
84 | return
85 | }
86 |
87 | self.processAudioBuffer(buffer: convertedBuffer)
88 | }
89 |
90 | // Start the audio engine
91 | do {
92 | audioEngine.prepare()
93 | try audioEngine.start()
94 | print("Microphone capturing started at 24kHz, mono")
95 | } catch {
96 | print("Error starting audio engine: \(error)")
97 | }
98 | }
99 |
100 | private func processAudioBuffer(buffer: AVAudioPCMBuffer) {
101 | guard let channelData = buffer.floatChannelData else { return }
102 | let frameCount = Int(buffer.frameLength)
103 |
104 | let pcmData = Array(UnsafeBufferPointer(start: channelData[0], count: frameCount)).map {
105 | $0
106 | }
107 | bufferedLen.add(pcmData.count, ordering: .sequentiallyConsistent)
108 | channel.send(pcmData)
109 | }
110 |
111 | func stopCapturing() {
112 | audioEngine.stop()
113 | audioEngine.inputNode.removeTap(onBus: 0)
114 | print("Microphone capturing stopped")
115 | }
116 |
117 | func receive() -> [Float]? {
118 | let data = channel.receive()
119 | if let data = data {
120 | bufferedLen.subtract(data.count, ordering: .sequentiallyConsistent)
121 | }
122 | return data
123 | }
124 |
125 | func bufferedDuration() -> Double {
126 | return Double(bufferedLen.load(ordering: .sequentiallyConsistent)) / 24000.0
127 | }
128 | }
129 |
130 | class FloatRingBuffer {
131 | private var buffer: [Float]
132 | private let capacity: Int
133 | private var readIndex = 0
134 | private var writeIndex = 0
135 | private var count = 0
136 |
137 | private let lock = NSLock()
138 |
139 | init(capacity: Int) {
140 | self.capacity = capacity
141 | self.buffer = [Float](repeating: 0, count: capacity)
142 | }
143 |
144 | func currentCount() -> Int {
145 | lock.lock()
146 | defer { lock.unlock() }
147 | return count
148 | }
149 |
150 | func write(_ values: [Float]) -> Bool {
151 | lock.lock()
152 | defer { lock.unlock() }
153 |
154 | if values.count + count > capacity {
155 | return false
156 | }
157 | for value in values {
158 | buffer[writeIndex] = value
159 | writeIndex = (writeIndex + 1) % capacity
160 | count += 1
161 | }
162 | return true
163 | }
164 |
165 | func read(maxCount: Int) -> [Float] {
166 | lock.lock()
167 | defer { lock.unlock() }
168 |
169 | var values: [Float] = []
170 | for _ in 0..<min(maxCount, count) {
171 | values.append(buffer[readIndex])
172 | readIndex = (readIndex + 1) % capacity
173 | count -= 1
174 | }
175 | return values
176 | }
177 | }
194 | func bufferedDuration() -> Double {
195 | Double(ringBuffer.currentCount()) / sampleRate
196 | }
197 |
198 | func startPlaying() throws {
199 | let audioFormat = AVAudioFormat(standardFormatWithSampleRate: self.sampleRate, channels: 1)!
200 | let sourceNode = AVAudioSourceNode(format: audioFormat) {
201 | _, _, frameCount, audioBufferList -> OSStatus in
202 | let audioBuffers = UnsafeMutableAudioBufferListPointer(audioBufferList)
203 | guard let channelData = audioBuffers[0].mData?.assumingMemoryBound(to: Float.self)
204 | else {
205 | // TODO: Get a proper error here that would work on ios.
206 | return 1
207 | }
208 | let data = self.ringBuffer.read(maxCount: Int(frameCount))
209 | for i in 0..<Int(frameCount) {
210 | channelData[i] = i < data.count ? data[i] : 0
211 | }
212 | return noErr
213 | }
222 | func send(_ values: [Float]) -> Bool {
223 | ringBuffer.write(values)
224 | }
225 | }
226 |
227 | func setDefaultToStd() {
228 | #if os(iOS)
229 | do {
230 | let audioSession = AVAudioSession.sharedInstance()
231 | try audioSession.setCategory(.playAndRecord, mode: .default, options: [.allowAirPlay, .allowBluetooth])
232 | try audioSession.setActive(true)
233 | try audioSession.overrideOutputAudioPort(.none)
234 | } catch {
235 | print("failed to configure audio session: \(error.localizedDescription)")
236 | }
237 | #endif
238 | }
239 |
240 | func setDefaultToSpeaker() {
241 | #if os(iOS)
242 | do {
243 | let audioSession = AVAudioSession.sharedInstance()
244 | try audioSession.setCategory(.playAndRecord, mode: .default, options: .defaultToSpeaker)
245 | try audioSession.setActive(true)
246 | try audioSession.overrideOutputAudioPort(.speaker)
247 | } catch {
248 | print("failed to configure audio session: \(error.localizedDescription)")
249 | }
250 | #endif
251 | }
252 |
--------------------------------------------------------------------------------
/MoshiCLI/RunMimi.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import AVFoundation
6 | import Foundation
7 | import Hub
8 | import MLX
9 | import MLXNN
10 | import MoshiLib
11 |
12 | func makeMimi(numCodebooks: Int) throws -> Mimi {
13 | let cfg = MimiConfig.mimi_2024_07(numCodebooks: numCodebooks)
14 | let model = Mimi(cfg, bSize: 1)
15 |
16 | let url = try downloadFromHub(
17 | id: "lmz/moshi-swift",
18 | filename: "tokenizer-dbaa9758-checkpoint125.safetensors")
19 | let origWeights = try loadArrays(url: url)
20 | var weights: [String: MLXArray] = [:]
21 | for (var key, var weight) in origWeights {
22 | // Mutating the keys while iterating over the map seems pretty dodgy, but the loop only reads
23 | // origWeights and key/weight are local value-type copies, so this is fine (see the sketch after this function).
24 | if key.hasPrefix("encoder.model") {
25 | key.replace("encoder.model.", with: "encoder.")
26 | }
27 | if key.hasPrefix("decoder.model") {
28 | key.replace("decoder.model.", with: "decoder.")
29 | }
30 | if key.hasSuffix(".in_proj_weight") {
31 | key.replace(".in_proj_weight", with: ".in_proj.weight")
32 | }
33 | if key.hasSuffix(".linear1.weight") {
34 | key.replace(".linear1.weight", with: ".gating.linear1.weight")
35 | }
36 | if key.hasSuffix(".linear2.weight") {
37 | key.replace(".linear2.weight", with: ".gating.linear2.weight")
38 | }
39 | // Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
40 | for (layerIdx, decoderIdx) in [2, 5, 8, 11].enumerated() {
41 | key.replace("decoder.\(decoderIdx).", with: "decoder.layers.\(layerIdx).upsample.")
42 | key.replace(
43 | "decoder.\(decoderIdx + 1).", with: "decoder.layers.\(layerIdx).residuals.0.")
44 | }
45 | for (layerIdx, encoderIdx) in [1, 4, 7, 10].enumerated() {
46 | key.replace("encoder.\(encoderIdx).", with: "encoder.layers.\(layerIdx).residuals.0.")
47 | key.replace(
48 | "encoder.\(encoderIdx + 2).", with: "encoder.layers.\(layerIdx).downsample.")
49 | }
50 | key.replace("decoder.0.", with: "decoder.init_conv1d.")
51 | key.replace("decoder.14.", with: "decoder.final_conv1d.")
52 | key.replace("encoder.0.", with: "encoder.init_conv1d.")
53 | key.replace("encoder.14.", with: "encoder.final_conv1d.")
54 | key.replace(".block.1.", with: ".block.0.")
55 | key.replace(".block.3.", with: ".block.1.")
56 | // PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
57 | if key.hasSuffix(".conv.weight") || key.hasSuffix(".output_proj.weight")
58 | || key.hasSuffix(".input_proj.weight")
59 | {
60 | weight = weight.swappedAxes(-1, -2)
61 | }
62 | // PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC
63 | if key.hasSuffix(".convtr.weight") {
64 | weight = weight.transposed(axes: [1, 2, 0])
65 | }
66 |
67 | // print(key, weight.shape)
68 | weights[key] = weight
69 | }
70 | let parameters = ModuleParameters.unflattened(weights)
71 | try model.update(parameters: parameters, verify: [.all])
72 | return model
73 | }
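As mentioned in the comment inside the loop above, a more declarative way to express the same remapping is to build the dictionary in a single pass; a sketch (renameKey and fixLayout are hypothetical helpers wrapping the same replace chain and transpositions):

let weights: [String: MLXArray] = origWeights.reduce(into: [:]) { acc, entry in
    let (key, weight) = entry
    acc[renameKey(key)] = fixLayout(key: key, weight: weight)  // hypothetical helpers
}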
74 |
75 | func runMimi(streaming: Bool, audioFile: URL?, channel: Int = 0) throws {
76 | let model = try makeMimi(numCodebooks: 16)
77 | print("using device \(Device.defaultDevice().description)")
78 | let sampleURL =
79 | switch audioFile {
80 | case .none: try downloadFromHub(id: "lmz/moshi-swift", filename: "bria-24khz.mp3")
81 | case .some(let url): url
82 | }
83 | let pcm = readAudioToPCMArray(fileURL: sampleURL, channel: channel)!
84 |
85 | if streaming {
86 | let chunkSize = 1920
87 | var pcmOuts: [[Float]] = []
88 | var elapsedTimes: [Double] = []
89 | var nSteps = 0
90 | for start in stride(from: 0, to: pcm.count, by: chunkSize) {
91 | let startTime = CFAbsoluteTimeGetCurrent()
92 | let pct = 100 * start / pcm.count
93 | let end = min(start + chunkSize, pcm.count)
94 | let pcmA = MLXArray(pcm[start..<end])[.newAxis, .newAxis]
199 | if nSteps > 125 {
200 | break
201 | }
202 | }
203 | }
204 | try save(
205 | arrays: ["codes": concatenated(allCodes, axis: -1)],
206 | url: URL(fileURLWithPath: "mimi-codes.safetensors"))
207 | microphoneCapture.stopCapturing()
208 |
209 | }
210 |
--------------------------------------------------------------------------------
/MoshiCLI/RunMoshi.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import AVFoundation
6 | import Foundation
7 | import MLX
8 | import MLXNN
9 | import MoshiLib
10 |
11 | func makeMoshi(_ url: URL, _ cfg: LmConfig) throws -> LM {
12 | let weights = try loadArrays(url: url)
13 | let parameters = ModuleParameters.unflattened(weights)
14 | let model = LM(cfg, bSize: 1)
15 | if url.lastPathComponent.hasSuffix(".q4.safetensors") {
16 | quantize(model: model, groupSize: 32, bits: 4)
17 | } else if url.lastPathComponent.hasSuffix(".q6.safetensors") {
18 | quantize(model: model, groupSize: 64, bits: 6)
19 | } else if url.lastPathComponent.hasSuffix(".q8.safetensors") {
20 | quantize(model: model, groupSize: 64, bits: 8)
21 | }
22 | try model.update(parameters: parameters, verify: [.all])
23 | eval(model)
24 | return model
25 | }
26 |
27 | func loadVocab(_ cfg: LmConfig) throws -> [Int: String] {
28 | let filename =
29 | switch cfg.textOutVocabSize {
30 | case 48000: "tokenizer_spm_48k_multi6_2.json"
31 | case 32000: "tokenizer_spm_32k_3.json"
32 | case 8000: "tokenizer_spm_8k_0.json"
33 | case 4000: "test_en_audio_4000.json"
34 | case let other: fatalError("unexpected text vocab size \(other)")
35 | }
36 | let fileURL = try downloadFromHub(id: "lmz/moshi-swift", filename: filename)
37 | let jsonData = try Data(contentsOf: fileURL)
38 | let dictionary = try JSONDecoder().decode([Int: String].self, from: jsonData)
39 | return dictionary
40 | }
41 |
42 | func runMoshiMic(_ url: URL, cfg: LmConfig) throws {
43 | let mimi = try makeMimi(numCodebooks: 16)
44 | let moshi = try makeMoshi(url, cfg)
45 | let vocab = try loadVocab(cfg)
46 | print("using device \(Device.defaultDevice().description)")
47 | print("warming up mimi")
48 | mimi.warmup()
49 | print("warming up moshi")
50 | moshi.warmup()
51 | print("done warming up")
52 |
53 | let maxSteps = moshi.cfg.transformer.maxSeqLen
54 | let gen = LMGen(moshi, maxSteps: maxSteps, audioSampler: Sampler(), textSampler: Sampler())
55 |
56 | let microphoneCapture = MicrophoneCapture()
57 | microphoneCapture.startCapturing()
58 | let player = AudioPlayer(sampleRate: 24000)
59 | try player.startPlaying()
60 | print("started the audio loops")
61 |
62 | while let pcm = microphoneCapture.receive() {
63 | let pcm = MLXArray(pcm)[.newAxis, .newAxis]
64 | let codes = mimi.encodeStep(StreamArray(pcm))
65 | if let codes = codes.asArray() {
66 | let (_, _, steps) = codes.shape3
67 | for step in 0..<steps {
--------------------------------------------------------------------------------
/MoshiCLI/CLI.swift:
--------------------------------------------------------------------------------
14 | func downloadFromHub(id: String, filename: String) throws -> URL {
15 | let targetURL = HubApi().localRepoLocation(Hub.Repo(id: id)).appending(path: filename)
16 | if FileManager.default.fileExists(atPath: targetURL.path) {
17 | print("using cached file \(targetURL.path)")
18 | return targetURL
19 | }
20 | var url: URL? = nil
21 | let semaphore = DispatchSemaphore(value: 0)
22 | Task {
23 | let repo = Hub.Repo(id: id)
24 | do {
25 | url = try await Hub.snapshot(from: repo, matching: filename) { progress in
26 | let pct = Int(progress.fractionCompleted * 100)
27 | print("\rretrieving \(filename): \(pct)%", terminator: "")
28 | }
29 | } catch {
30 | fatalError("cannot fetch \(id) \(filename): \(error)")
31 | }
32 | semaphore.signal()
33 | }
34 | semaphore.wait()
35 | print("\rretrieved \(filename)")
36 | return url!.appending(path: filename)
37 | }
38 |
39 | func maybeDownloadFromHub(filename: String) throws -> URL {
40 | // downloadFromHub(id: id, filename: filename)
41 | let prefix = "hf://"
42 | if filename.hasPrefix(prefix) {
43 | let rest = filename.dropFirst(prefix.count)
44 | let components = rest.split(separator: "/", omittingEmptySubsequences: false)
45 | let id = components[0..<2].joined(separator: "/")
53 | func makeTokenizer(hfRepo: String) throws -> any Tokenizer {
54 | var tokenizer: (any Tokenizer)? = nil
55 | let semaphore = DispatchSemaphore(value: 0)
56 | Task {
57 | do {
58 | tokenizer = try await AutoTokenizer.from(pretrained: hfRepo)
59 | } catch {
60 | fatalError("cannot build tokenizer \(error)")
61 | }
62 | semaphore.signal()
63 | }
64 | semaphore.wait()
65 | return tokenizer!
66 | }
67 |
68 | @main
69 | struct Moshi: ParsableCommand {
70 | static let configuration = CommandConfiguration(
71 | subcommands: [
72 | Run.self, RunHelium.self, RunMimi.self, AudioToCodes.self, CodesToAudio.self,
73 | RunAsr.self, RunQwen.self,
74 | ]
75 | )
76 | }
77 |
78 | public enum Config: String, CaseIterable, ExpressibleByArgument {
79 | case moshi1b
80 | case moshi7b
81 | }
82 |
83 | struct Run: ParsableCommand {
84 | @Argument(help: "the model to run")
85 | var model: String
86 |
87 | @Option(help: "the file to process, use 'mic' for the microphone")
88 | var input: String?
89 |
90 | @Option(help: "the audio delay to apply")
91 | var audioDelay: Int = 2
92 |
93 | @Option(help: "the audio channel from the input file to be used")
94 | var channel: Int = 0
95 |
96 | @Option(help: "the config size")
97 | var config: Config = .moshi1b
98 |
99 | mutating func run() throws {
100 | let model = URL(fileURLWithPath: model)
101 | let cfg =
102 | switch config {
103 | case .moshi1b: LmConfig.moshi1b(audioDelay: audioDelay)
104 | case .moshi7b: LmConfig.moshi_2024_07()
105 | }
106 |
107 | switch input {
108 | case .none:
109 | try runMoshi(model, cfg: cfg, audioFile: nil)
110 | case .some("mic"): try runMoshiMic(model, cfg: cfg)
111 | case .some(let input):
112 | let audioFile = URL(fileURLWithPath: input)
113 | try runMoshi(
114 | model, cfg: cfg, audioFile: audioFile, channel: channel)
115 | }
116 | }
117 | }
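ArgumentParser derives the subcommand name `run` from the type above, so a hypothetical invocation (the binary name follows the moshi-cli scheme and the model path is a placeholder) looks like:

moshi-cli run moshi-1b.q4.safetensors --input mic --config moshi1b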
118 |
119 | struct RunMimi: ParsableCommand {
120 | @Option(help: "whether to use the streaming mode or not")
121 | var streaming: Bool = false
122 |
123 | @Option(help: "the file to process")
124 | var input: String?
125 |
126 | @Option(help: "the audio channel from the input file to be used")
127 | var channel: Int = 0
128 |
129 | mutating func run() throws {
130 | let audioFile = input.flatMap { URL(fileURLWithPath: $0) }
131 | try runMimi(streaming: streaming, audioFile: audioFile, channel: channel)
132 | }
133 | }
134 |
135 | public enum HeliumConfig: String, CaseIterable, ExpressibleByArgument {
136 | case q4
137 | case q6
138 | case q8
139 | case bf16
140 | }
141 |
142 | struct RunQwen: ParsableCommand {
143 | @Option(help: "the config")
144 | var hfRepo: String = "Qwen/Qwen2.5-0.5B-Instruct"
145 |
146 | @Option(help: "the prompt to be used")
147 | var prompt: String = "Describe the swift programming language."
148 |
149 | @Option(help: "the number of tokens to generate")
150 | var n: Int = 256
151 |
152 | mutating func run() throws {
153 | let tokenizer = try makeTokenizer(hfRepo: hfRepo)
154 | let messages = [["role": "user", "content": prompt]]
155 | let encodedPrompt = try tokenizer.applyChatTemplate(messages: messages)
156 | let configUrl = try downloadFromHub(id: hfRepo, filename: "config.json")
157 | let configData = try Data(contentsOf: configUrl)
158 | let decoder = JSONDecoder()
159 | decoder.keyDecodingStrategy = .convertFromSnakeCase
160 | let config = try decoder.decode(QwenConfig.self, from: configData)
161 | print("config \(config)")
162 | let modelUrl = try downloadFromHub(id: hfRepo, filename: "model.safetensors")
163 | print("model \(modelUrl)")
164 | let weights = try loadArrays(url: modelUrl)
165 | guard let modelItem = ModuleParameters.unflattened(weights)["model"] else {
166 | fatalError("no model key in {configUrl}")
167 | }
168 | let parameters =
169 | switch modelItem {
170 | case .dictionary(let d): NestedDictionary(values: d)
171 | default: fatalError("model key in \(modelUrl) is not a dict")
172 | }
173 |
174 | let model = QwenModel(config)
175 | if let q = config.quantization {
176 | quantize(model: model, groupSize: q.groupSize, bits: q.bits)
177 | }
178 | try model.update(parameters: parameters, verify: [.all])
179 | eval(model)
180 | let cache = model.makeCache(bSize: 1)
181 | let sampler = Sampler()
182 | var lastToken = config.bosTokenId
183 | let startTime = CFAbsoluteTimeGetCurrent()
184 | var nTokens = 0
185 | for index in 0..<(n + encodedPrompt.count) {
186 | let logits = model(MLXArray([lastToken]).reshaped(1, 1), cache: cache)
187 | if index < encodedPrompt.count {
188 | lastToken = encodedPrompt[index]
189 | } else {
190 | let (tok, _) = sampler(logits: logits[0])
191 | lastToken = tok.item()
192 | }
193 | let s = tokenizer.decode(tokens: [lastToken])
194 | print("\(s)", terminator: "")
195 | fflush(stdout)
196 | nTokens += 1
197 | }
198 | print()
199 | let elapsedTime = CFAbsoluteTimeGetCurrent() - startTime
200 | print("\(nTokens) tokens generated, \(Double(nTokens) / elapsedTime) tok/s")
201 | }
202 | }
203 |
204 | struct RunHelium: ParsableCommand {
205 | @Option(help: "the config")
206 | var config: HeliumConfig = .q4
207 |
208 | mutating func run() throws {
209 | let cfg = LmConfig.helium2b()
210 | let filename =
211 | switch config {
212 | case .q4: "helium-1-preview-2b-q4.safetensors"
213 | case .q6: "helium-1-preview-2b-q6.safetensors"
214 | case .q8: "helium-1-preview-2b-q8.safetensors"
215 | case .bf16: "helium-1-preview-2b-bf16.safetensors"
216 | }
217 | let url = try downloadFromHub(id: "kyutai/helium-1-preview-2b-mlx", filename: filename)
218 | try runHelium(url, cfg: cfg)
219 | }
220 | }
221 |
222 | struct AudioToCodes: ParsableCommand {
223 | mutating func run() throws {
224 | try runAudioToCodes()
225 | }
226 | }
227 |
228 | struct CodesToAudio: ParsableCommand {
229 | @Option(help: "whether to write the output file or not")
230 | var writeFile: Bool = false
231 |
232 | mutating func run() throws {
233 | try runCodesToAudio(writeFile: writeFile)
234 | }
235 | }
236 |
237 | struct RunAsr: ParsableCommand {
238 | @Argument(help: "the model to run")
239 | var model: String
240 |
241 | @Option(help: "the file to process, use 'mic' for the microphone")
242 | var input: String?
243 |
244 | @Option(help: "the audio channel from the input file to be used")
245 | var channel: Int = 0
246 |
247 | mutating func run() throws {
248 | let model = try maybeDownloadFromHub(filename: model)
249 | let weights = try loadArrays(url: model)
250 | let cfg =
251 | switch weights["out_norm.weight"]?.shape {
252 | case .none: fatalError("no out_norm.weight tensor in \(model)")
253 | case .some([1024]): LmConfig.asr300m()
254 | case .some([2048]): LmConfig.asr1b()
255 | case .some([2560]): LmConfig.asr2b()
256 | case .some(let s): fatalError("unexpected shape for out_norm.weight \(s)")
257 | }
258 | print("here2")
259 | switch input {
260 | case .none:
261 | try runAsr(model, cfg, audioFile: nil, channel: channel)
262 | case .some("mic"): try runAsrMic(model, cfg)
263 | case .some(let input):
264 | let audioFile = URL(fileURLWithPath: input)
265 | try runAsr(model, cfg, audioFile: audioFile, channel: channel)
266 | }
267 | }
268 | }
269 |
--------------------------------------------------------------------------------
/MoshiLib/Seanet.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 | import MLXRandom
9 |
10 | public struct SeanetConfig {
11 | public var dimension: Int
12 | public var channels: Int
13 | public var causal: Bool
14 | public var nFilters: Int
15 | public var nResidualLayers: Int
16 | public var ratios: [Int]
17 | // public var activation: String: hardcoded to Elu(1) for now
18 | // public var norm: Norm
19 | public var kernelSize: Int
20 | public var residualKernelSize: Int
21 | public var lastKernelSize: Int
22 | public var dilationBase: Int
23 | public var padMode: PadMode
24 | public var trueSkip: Bool
25 | public var compress: Int
26 | // public var disableNormOuterBlocks: Int
27 | // public var finalActivation: String?: hardcoded to None for now
28 |
29 | public static func v0_1() -> SeanetConfig {
30 | SeanetConfig(
31 | dimension: 512, channels: 1, causal: true, nFilters: 64, nResidualLayers: 1,
32 | ratios: [8, 6, 5, 4], kernelSize: 7, residualKernelSize: 3, lastKernelSize: 3,
33 | dilationBase: 2, padMode: .constant, trueSkip: true, compress: 2)
34 | }
35 | }
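For scale, the v0_1 ratios determine the frame rate; a back-of-the-envelope check (assuming Mimi's additional stride-2 downsampling):

// Total encoder stride: 8 * 6 * 5 * 4 = 960 samples per latent step,
// i.e. 24000 Hz / 960 = 25 Hz; the stride-2 ConvDownsample1d in Conv.swift
// halves this to Mimi's 12.5 Hz frame rate, one frame per 1920 samples
// (the chunkSize used in RunMimi.swift).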
36 |
37 | class SeanetResnetBlock: Module, UnaryLayer, StreamingLayer {
38 | let skipOp: StreamingBinOp
39 | @ModuleInfo(key: "block") var block: [StreamableConv1d]
40 | @ModuleInfo(key: "shortcut") var shortcut: StreamableConv1d?
41 |
42 | init(_ cfg: SeanetConfig, dim: Int, kSizesAndDilations: [(Int, Int)]) {
43 | self.skipOp = StreamingBinOp(.add, axis: -1)
44 | var block: [StreamableConv1d] = []
45 | var shortcut: StreamableConv1d? = nil
46 | let hidden = dim / cfg.compress
47 |
48 | for (i, (kSize, dilation)) in kSizesAndDilations.enumerated() {
49 | let inC = i == 0 ? dim : hidden
50 | let outC = i == kSizesAndDilations.count - 1 ? dim : hidden
51 | let c = StreamableConv1d(
52 | inC: inC, outC: outC, kSize: kSize, stride: 1, dilation: dilation, groups: 1,
53 | bias: true, causal: cfg.causal, padMode: cfg.padMode)
54 | block.append(c)
55 | }
56 | if !cfg.trueSkip {
57 | shortcut = StreamableConv1d(
58 | inC: dim, outC: dim, kSize: 1, stride: 1, dilation: 1, groups: 1, bias: true,
59 | causal: cfg.causal, padMode: cfg.padMode)
60 | }
61 |
62 | self._block.wrappedValue = block
63 | self._shortcut.wrappedValue = shortcut
64 | }
65 |
66 | func callAsFunction(_ x: MLXArray) -> MLXArray {
67 | let residual = x
68 | var x = x
69 | for b in self.block {
70 | x = b(elu(x, alpha: 1.0))
71 | }
72 | if let shortcut = self.shortcut {
73 | return x + shortcut(residual)
74 | } else {
75 | return x + residual
76 | }
77 | }
78 |
79 | func resetState() {
80 | self.skipOp.resetState()
81 | for b in self.block {
82 | b.resetState()
83 | }
84 | if let s = self.shortcut {
85 | s.resetState()
86 | }
87 | }
88 |
89 | func step(_ x: StreamArray) -> StreamArray {
90 | let residual = x
91 | var x = x
92 | for b in self.block {
93 | x = b.step(x.elu())
94 | }
95 | if let shortcut = self.shortcut {
96 | return self.skipOp.step(x, shortcut.step(residual))
97 | } else {
98 | return self.skipOp.step(x, residual)
99 | }
100 | }
101 | }
102 |
103 | class EncoderLayer: Module, UnaryLayer, StreamingLayer {
104 | @ModuleInfo(key: "residuals") var residuals: [SeanetResnetBlock]
105 | @ModuleInfo(key: "downsample") var downsample: StreamableConv1d
106 |
107 | init(_ cfg: SeanetConfig, ratio: Int, mult: Int) {
108 | var residuals: [SeanetResnetBlock] = []
109 | var dilation: Int = 1
110 | for _ in 0..<cfg.nResidualLayers {
126 | func callAsFunction(_ x: MLXArray) -> MLXArray {
127 | var x = x
128 | for r in self.residuals {
129 | x = r(x)
130 | }
131 | let y = self.downsample(elu(x, alpha: 1.0))
132 | return y
133 | }
134 |
135 | func resetState() {
136 | for r in self.residuals {
137 | r.resetState()
138 | }
139 | self.downsample.resetState()
140 | }
141 |
142 | func step(_ x: StreamArray) -> StreamArray {
143 | var x = x
144 | for r in self.residuals {
145 | x = r.step(x)
146 | }
147 | return self.downsample.step(x.elu())
148 | }
149 | }
150 |
151 | class SeanetEncoder: Module, StreamingLayer {
152 | @ModuleInfo(key: "init_conv1d") var initConv1d: StreamableConv1d
153 | @ModuleInfo(key: "layers") var layers: [EncoderLayer]
154 | @ModuleInfo(key: "final_conv1d") var finalConv1d: StreamableConv1d
155 |
156 | init(_ cfg: SeanetConfig) {
157 | var mult = 1
158 | let initConv1d = StreamableConv1d(
159 | inC: cfg.channels, outC: mult * cfg.nFilters, kSize: cfg.kernelSize, stride: 1,
160 | dilation: 1, groups: 1, bias: true, causal: cfg.causal,
161 | padMode: cfg.padMode)
162 | var layers: [EncoderLayer] = []
163 | for ratio in cfg.ratios.reversed() {
164 | let layer = EncoderLayer(cfg, ratio: ratio, mult: mult)
165 | layers.append(layer)
166 | mult *= 2
167 | }
168 | let finalConv1d = StreamableConv1d(
169 | inC: mult * cfg.nFilters, outC: cfg.dimension, kSize: cfg.lastKernelSize, stride: 1,
170 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, padMode: cfg.padMode)
171 | self._initConv1d.wrappedValue = initConv1d
172 | self._layers.wrappedValue = layers
173 | self._finalConv1d.wrappedValue = finalConv1d
174 | }
175 |
176 | func callAsFunction(_ x: MLXArray) -> MLXArray {
177 | var x = self.initConv1d(x)
178 | for layer in self.layers {
179 | x = layer(x)
180 | }
181 | return self.finalConv1d(elu(x, alpha: 1.0))
182 | }
183 |
184 | func resetState() {
185 | self.initConv1d.resetState()
186 | self.finalConv1d.resetState()
187 | for l in self.layers {
188 | l.resetState()
189 | }
190 | }
191 |
192 | func step(_ x: StreamArray) -> StreamArray {
193 | var x = self.initConv1d.step(x)
194 | for layer in self.layers {
195 | x = layer.step(x)
196 | }
197 | return self.finalConv1d.step(x.elu())
198 | }
199 | }
200 |
201 | class DecoderLayer: Module, UnaryLayer, StreamingLayer {
202 | @ModuleInfo(key: "upsample") var upsample: StreamableConvTranspose1d
203 | @ModuleInfo(key: "residuals") var residuals: [SeanetResnetBlock]
204 |
205 | init(_ cfg: SeanetConfig, ratio: Int, mult: Int) {
206 | var residuals: [SeanetResnetBlock] = []
207 | var dilation = 1
208 | for _ in 0..<cfg.nResidualLayers {
224 | func callAsFunction(_ x: MLXArray) -> MLXArray {
225 | var x = self.upsample(elu(x, alpha: 1.0))
226 | for r in self.residuals {
227 | x = r(x)
228 | }
229 | return x
230 | }
231 |
232 | func resetState() {
233 | self.upsample.resetState()
234 | for r in self.residuals {
235 | r.resetState()
236 | }
237 | }
238 |
239 | func step(_ x: StreamArray) -> StreamArray {
240 | var x = self.upsample.step(x.elu())
241 | for r in self.residuals {
242 | x = r.step(x)
243 | }
244 | return x
245 | }
246 | }
247 |
248 | class SeanetDecoder: Module, StreamingLayer {
249 | @ModuleInfo(key: "init_conv1d") var initConv1d: StreamableConv1d
250 | @ModuleInfo(key: "layers") var layers: [DecoderLayer]
251 | @ModuleInfo(key: "final_conv1d") var finalConv1d: StreamableConv1d
252 |
253 | init(_ cfg: SeanetConfig) {
254 | var layers: [DecoderLayer] = []
255 | var mult = 1 << cfg.ratios.count
256 | let initConv1d = StreamableConv1d(
257 | inC: cfg.dimension, outC: mult * cfg.nFilters, kSize: cfg.kernelSize, stride: 1,
258 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, padMode: cfg.padMode)
259 | for ratio in cfg.ratios {
260 | let l = DecoderLayer(cfg, ratio: ratio, mult: mult)
261 | layers.append(l)
262 | mult /= 2
263 | }
264 | let finalConv1d = StreamableConv1d(
265 | inC: cfg.nFilters, outC: cfg.channels, kSize: cfg.lastKernelSize, stride: 1,
266 | dilation: 1, groups: 1, bias: true, causal: cfg.causal, padMode: cfg.padMode)
267 | self._initConv1d.wrappedValue = initConv1d
268 | self._layers.wrappedValue = layers
269 | self._finalConv1d.wrappedValue = finalConv1d
270 | }
271 |
272 | func callAsFunction(_ x: MLXArray) -> MLXArray {
273 | var x = self.initConv1d(x)
274 | for layer in self.layers {
275 | x = layer(x)
276 | }
277 | return self.finalConv1d(elu(x, alpha: 1.0))
278 | }
279 |
280 | func resetState() {
281 | self.initConv1d.resetState()
282 | self.finalConv1d.resetState()
283 | for l in self.layers {
284 | l.resetState()
285 | }
286 | }
287 |
288 | func step(_ x: StreamArray) -> StreamArray {
289 | var x = self.initConv1d.step(x)
290 | for layer in self.layers {
291 | x = layer.step(x)
292 | }
293 | return self.finalConv1d.step(x.elu())
294 | }
295 | }
296 |
--------------------------------------------------------------------------------
/MoshiLib/Conv.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 | import MLXRandom
9 |
10 | // Conv1d + dilation
11 | class Conv1d: Module, UnaryLayer {
12 | @ModuleInfo(key: "weight") var weight: MLXArray
13 | @ModuleInfo(key: "bias") var bias: MLXArray?
14 | let padding: Int
15 | let groups: Int
16 | let stride: Int
17 | let dilation: Int
18 |
19 | init(
20 | inputChannels: Int,
21 | outputChannels: Int,
22 | kernelSize: Int,
23 | stride: Int = 1,
24 | padding: Int = 0,
25 | groups: Int = 1,
26 | dilation: Int = 1,
27 | bias: Bool = true
28 | ) {
29 | let scale = sqrt(1 / Float(inputChannels * kernelSize))
30 |
31 | self._weight.wrappedValue = uniform(
32 | low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups])
33 | self._bias.wrappedValue = bias ? MLXArray.zeros([outputChannels]) : nil
34 | self.padding = padding
35 | self.groups = groups
36 | self.stride = stride
37 | self.dilation = dilation
38 | }
39 |
40 | func callAsFunction(_ x: MLXArray) -> MLXArray {
41 | // MLX uses NLC whereas pytorch/candle use NCL
42 | var y = conv1d(
43 | x.swappedAxes(-1, -2), weight, stride: stride, padding: padding, dilation: dilation,
44 | groups: groups
45 | )
46 | if let bias {
47 | y = y + bias
48 | }
49 | y = y.swappedAxes(-1, -2)
50 | return y
51 | }
52 | }
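To make the NLC/NCL layout comment in callAsFunction concrete, a small sketch (shapes illustrative):

let conv = Conv1d(inputChannels: 16, outputChannels: 32, kernelSize: 3)
let x = MLXArray.zeros([1, 16, 100])  // callers pass NCL, PyTorch-style
let y = conv(x)  // swapped to NLC for MLX's conv1d and back: y is [1, 32, 98]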
53 |
54 | // ConvTranspose1d + groups
55 | class ConvTransposed1d: Module, UnaryLayer {
56 | @ModuleInfo(key: "weight") var weight: MLXArray
57 | @ModuleInfo(key: "bias") var bias: MLXArray?
58 | let padding: Int
59 | let stride: Int
60 | let groups: Int
61 | let inC: Int
62 | let outC: Int
63 | let kSize: Int
64 | var expandedWeight: MLXArray
65 | var expandedGroups: Int
66 |
67 | init(
68 | inputChannels: Int,
69 | outputChannels: Int,
70 | kernelSize: Int,
71 | stride: Int = 1,
72 | padding: Int = 0,
73 | groups: Int = 1,
74 | bias: Bool = true
75 | ) {
76 | let scale = sqrt(1 / Float(inputChannels * kernelSize))
77 |
78 | self._weight.wrappedValue = uniform(
79 | low: -scale, high: scale, [outputChannels / groups, kernelSize, inputChannels])
80 | let weight = self._weight.wrappedValue
81 | self._bias.wrappedValue = bias ? MLXArray.zeros([outputChannels]) : nil
82 | self.padding = padding
83 | self.stride = stride
84 | self.groups = groups
85 | self.inC = inputChannels
86 | self.outC = outputChannels
87 | self.kSize = kernelSize
88 | if groups == inC && groups == outC {
89 | let eye = repeated(
90 | eye(outC).asType(weight.dtype).reshaped([outC, 1, outC]), count: kSize, axis: 1)
91 | self.expandedWeight = repeated(weight, count: groups, axis: 0) * eye
92 | self.expandedGroups = 1
93 | } else if groups > 1 {
94 | fatalError("groups are not supported in ConvTranspose1d, \(groups), \(inC), \(outC)")
95 | } else {
96 | self.expandedWeight = weight
97 | self.expandedGroups = groups
98 | }
99 | }
100 |
101 | override func update(parameters: ModuleParameters, verify: Module.VerifyUpdate) throws -> Self {
102 | try super.update(parameters: parameters, verify: verify)
103 | if groups == inC && groups == outC {
104 | let eye = repeated(
105 | eye(outC).asType(weight.dtype).reshaped([outC, 1, outC]), count: kSize, axis: 1)
106 | self.expandedWeight = repeated(weight, count: groups, axis: 0) * eye
107 | self.expandedGroups = 1
108 | } else if groups > 1 {
109 | fatalError("groups are not supported in ConvTranspose1d, \(groups), \(inC), \(outC)")
110 | } else {
111 | self.expandedWeight = weight
112 | self.expandedGroups = groups
113 | }
114 | return self
115 | }
116 |
117 | open func callAsFunction(_ x: MLXArray) -> MLXArray {
118 | // Groups are not supported in convTransposed1d as of 0.18.1 so we hack our way around it.
119 | var y = convTransposed1d(
120 | x.swappedAxes(-1, -2), expandedWeight, stride: stride, padding: padding,
121 | groups: expandedGroups
122 | )
123 | if let bias {
124 | y = y + bias
125 | }
126 | return y.swappedAxes(-1, -2)
127 | }
128 | }
129 |
130 | // weight-norm is handled externally when generating the weights file so this class is just
131 | // a wrapper around Conv1d
132 | class NormConv1d: Module, UnaryLayer {
133 | @ModuleInfo(key: "conv") var conv: Conv1d
134 |
135 | init(
136 | inC: Int, outC: Int, kSize: Int,
137 | stride: Int = 1,
138 | padding: Int = 0,
139 | groups: Int = 1,
140 | dilation: Int = 1,
141 | bias: Bool = true
142 | ) {
143 | self._conv.wrappedValue = Conv1d(
144 | inputChannels: inC, outputChannels: outC, kernelSize: kSize, stride: stride,
145 | padding: padding,
146 | groups: groups, dilation: dilation, bias: bias)
147 | }
148 |
149 | func callAsFunction(_ x: MLXArray) -> MLXArray {
150 | self.conv(x)
151 | }
152 | }
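On the comment above NormConv1d: weight normalization parameterizes each filter as w = g · v / ‖v‖. Both g and v are constant at inference time, so the exporter folds them into a single conv weight tensor, which is why this wrapper adds no parameters or computation of its own.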
153 |
154 | class NormConvTranspose1d: Module, UnaryLayer {
155 | @ModuleInfo(key: "convtr") var convtr: ConvTransposed1d
156 |
157 | init(
158 | inC: Int, outC: Int, kSize: Int,
159 | stride: Int = 1,
160 | padding: Int = 0,
161 | groups: Int = 1,
162 | bias: Bool = true
163 | ) {
164 | self._convtr.wrappedValue = ConvTransposed1d(
165 | inputChannels: inC, outputChannels: outC, kernelSize: kSize, stride: stride,
166 | padding: padding, groups: groups, bias: bias)
167 | }
168 |
169 | func callAsFunction(_ x: MLXArray) -> MLXArray {
170 | self.convtr(x)
171 | }
172 | }
173 |
174 | func getExtraPaddingForConv1d(_ x: MLXArray, kSize: Int, stride: Int, paddingTotal: Int) -> Int {
175 | let len = x.dim(-1)
176 | let nFrames = Float(max(len + paddingTotal - kSize, 0)) / Float(stride) + 1.0
177 | let idealLen = (Int(nFrames.rounded(.up)) - 1) * stride + kSize - paddingTotal
178 | return max(0, idealLen - len)
179 | }
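A worked example of the padding computation above (numbers illustrative):

// len = 10, kSize = 4, stride = 3, paddingTotal = 1:
//   nFrames  = (10 + 1 - 4) / 3 + 1 ≈ 3.33, rounded up to 4 frames
//   idealLen = (4 - 1) * 3 + 4 - 1 = 12
//   extra    = 12 - 10 = 2 trailing samples so the last frame is complete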
180 |
181 | func unpad1d(_ x: MLXArray, unpadL: Int, unpadR: Int) -> MLXArray {
182 | let len = x.dim(-1)
183 | let left = unpadL
184 | let right = len - unpadR
185 | return x[.ellipsis, left..<right]
186 | }
208 | func callAsFunction(_ x: MLXArray) -> MLXArray {
209 | var kSize = self.kSize
210 | // Effective kernel size with dilations.
211 | kSize = (kSize - 1) * self.conv.conv.dilation + 1
212 | let paddingTotal = kSize - self.conv.conv.stride
213 | let extraPadding = getExtraPaddingForConv1d(
214 | x, kSize: kSize, stride: self.conv.conv.stride, paddingTotal: paddingTotal)
215 | var pd: MLXArray
216 | let z = IntOrPair.init((0, 0))
217 | if self.causal {
218 | let widths = [z, z, IntOrPair((paddingTotal, extraPadding))]
219 | pd = padded(x, widths: widths, mode: self.padMode)
220 | } else {
221 | let paddingRight = paddingTotal / 2
222 | let paddingLeft = paddingTotal - paddingRight
223 | let widths = [z, z, IntOrPair((paddingLeft, paddingRight + extraPadding))]
224 | pd = padded(x, widths: widths, mode: self.padMode)
225 | }
226 | let y = self.conv(pd)
227 | return y
228 | }
229 |
230 | func resetState() {
231 | statePrevXs = StreamArray()
232 | self.leftPadApplied = false
233 | }
234 |
235 | func step(_ x: StreamArray) -> StreamArray {
236 | if var inner = x.inner {
237 | let stride = self.conv.conv.stride
238 | let dilation = self.conv.conv.dilation
239 | if !self.leftPadApplied {
240 | self.leftPadApplied = true
241 | let kSize = (self.kSize - 1) * dilation + 1
242 | let paddingTotal = kSize - stride
243 | let z = IntOrPair.init((0, 0))
244 | let widths = [z, z, IntOrPair((paddingTotal, 0))]
245 | inner = padded(inner, widths: widths, mode: self.padMode)
246 | }
247 | let kernel = (self.kSize - 1) * dilation + 1
248 | var x = StreamArray(inner)
249 | x = self.statePrevXs.cat2(x, axis: -1)
250 | let seqLen = x.dim(-1)
251 | let numFrames = max(seqLen + stride - kernel, 0) / stride
252 | if numFrames > 0 {
253 | let offset = numFrames * stride
254 | self.statePrevXs = x.narrow(offset, seqLen - offset, axis: -1)
255 | let inL = (numFrames - 1) * stride + kernel
256 | x = x.narrow(0, inL, axis: -1)
257 | if let x = x.inner {
258 | return StreamArray(self.conv.conv(x))
259 | } else {
260 | return StreamArray()
261 | }
262 | } else {
263 | self.statePrevXs = x
264 | return StreamArray()
265 | }
266 | } else {
267 | return StreamArray()
268 | }
269 | }
270 | }
271 |
272 | class StreamableConvTranspose1d: Module, UnaryLayer, StreamingLayer {
273 | let causal: Bool
274 | let kSize: Int
275 | var statePrevYs: StreamArray = StreamArray()
276 | @ModuleInfo(key: "convtr") var convtr: NormConvTranspose1d
277 |
278 | init(
279 | inC: Int, outC: Int, kSize: Int, stride: Int, groups: Int, bias: Bool,
280 | causal: Bool
281 | ) {
282 | self.causal = causal
283 | self.kSize = kSize
284 | self._convtr.wrappedValue = NormConvTranspose1d(
285 | inC: inC, outC: outC, kSize: kSize, stride: stride, groups: groups, bias: bias)
286 | }
287 |
288 | func callAsFunction(_ x: MLXArray) -> MLXArray {
289 | let stride = self.convtr.convtr.stride
290 | let paddingTotal = max(self.kSize - stride, 0)
291 | let x = self.convtr(x)
292 | if self.causal {
293 | return unpad1d(x, unpadL: 0, unpadR: paddingTotal)
294 | } else {
295 | let unpadR = paddingTotal / 2
296 | let unpadL = paddingTotal - unpadR
297 | return unpad1d(x, unpadL: unpadL, unpadR: unpadR)
298 | }
299 | }
300 |
301 | func resetState() {
302 | statePrevYs = StreamArray()
303 | }
304 |
305 | func step(_ x: StreamArray) -> StreamArray {
306 | if let x = x.inner {
307 | let stride = self.convtr.convtr.stride
308 | var ys = self.convtr(x)
309 | let ot = ys.dim(-1)
310 | if var prevYs = self.statePrevYs.inner {
311 | let pt = prevYs.dim(-1)
312 | if let bias = self.convtr.convtr.bias {
313 | prevYs = prevYs - bias[.newAxis, 0..., .newAxis]
314 | }
315 | let ys1 = ys[.ellipsis, 0.. MLXArray {
339 | self.conv(x)
340 | }
341 |
342 | func resetState() {
343 | self.conv.resetState()
344 | }
345 |
346 | func step(_ x: StreamArray) -> StreamArray {
347 | self.conv.step(x)
348 | }
349 | }
350 |
351 | class ConvTrUpsample1d: Module, UnaryLayer {
352 | @ModuleInfo(key: "convtr") var convtr: StreamableConvTranspose1d
353 |
354 | init(stride: Int, dim: Int, causal: Bool) {
355 | self._convtr.wrappedValue = StreamableConvTranspose1d(
356 | inC: dim, outC: dim, kSize: 2 * stride, stride: stride, groups: dim, bias: false,
357 | causal: causal)
358 | }
359 |
360 | func callAsFunction(_ x: MLXArray) -> MLXArray {
361 | self.convtr(x)
362 | }
363 |
364 | func resetState() {
365 | self.convtr.resetState()
366 | }
367 |
368 | func step(_ x: StreamArray) -> StreamArray {
369 | self.convtr.step(x)
370 | }
371 | }
372 |
--------------------------------------------------------------------------------
/MoshiLib/Transformer.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 |
9 | public enum Norm {
10 | case layerNorm
11 | case rmsNorm
12 | }
13 |
14 | public enum PositionalEmbedding {
15 | case none
16 | case rope
17 | }
18 |
19 | public struct TransformerConfig {
20 | public var dModel: Int
21 | public var numHeads: Int
22 | public var numLayers: Int
23 | public var causal: Bool
24 | public var normFirst: Bool
25 | public var biasFF: Bool
26 | public var biasAttn: Bool
27 | public var layerScale: Float?
28 | public var positionalEmbedding: PositionalEmbedding
29 | public var useConvBias: Bool
30 | public var gating: Bool
31 | public var norm: Norm
32 | public var context: Int
33 | public var maxPeriod: Int
34 | public var maxSeqLen: Int
35 | public var kvRepeat: Int
36 | public var dimFeedForward: Int
37 | public var convLayout: Bool
38 | public var useRotatingKVCache: Bool
39 |
40 | public func headDim() -> Int {
41 | self.dModel / self.numHeads
42 | }
43 |
44 | // kyutai/neilz/mimi_exp/xps/f28fe6d5/.hydra/config.yaml
45 | public static func v1_300m() -> TransformerConfig {
46 | TransformerConfig(
47 | dModel: 1024,
48 | numHeads: 8,
49 | numLayers: 16,
50 | causal: true,
51 | normFirst: true,
52 | biasFF: false,
53 | biasAttn: false,
54 | layerScale: nil,
55 | positionalEmbedding: .rope,
56 | useConvBias: false,
57 | gating: true,
58 | norm: .rmsNorm,
59 | context: 750,
60 | maxPeriod: 100000,
61 | maxSeqLen: 4096,
62 | kvRepeat: 1,
63 | dimFeedForward: 1024 * 4,
64 | convLayout: false,
65 | useRotatingKVCache: false
66 | )
67 | }
68 |
69 | // kyutai/neilz/mimi_exp/xps/8d2516b9/.hydra/config.yaml
70 | public static func v1_1b() -> TransformerConfig {
71 | TransformerConfig(
72 | dModel: 2048,
73 | numHeads: 16,
74 | numLayers: 16,
75 | causal: true,
76 | normFirst: true,
77 | biasFF: false,
78 | biasAttn: false,
79 | layerScale: nil,
80 | positionalEmbedding: .rope,
81 | useConvBias: false,
82 | gating: true,
83 | norm: .rmsNorm,
84 | context: 3000,
85 | maxPeriod: 100000,
86 | maxSeqLen: 4096,
87 | kvRepeat: 1,
88 | dimFeedForward: 2048 * 4,
89 | convLayout: false,
90 | useRotatingKVCache: false
91 | )
92 | }
93 |
94 | // kyutai/neilz/mimi_exp/xps/b3f79570/.hydra/config.yaml
95 | public static func v1_2b() -> TransformerConfig {
96 | TransformerConfig(
97 | dModel: 2560,
98 | numHeads: 20,
99 | numLayers: 24,
100 | causal: true,
101 | normFirst: true,
102 | biasFF: false,
103 | biasAttn: false,
104 | layerScale: nil,
105 | positionalEmbedding: .rope,
106 | useConvBias: false,
107 | gating: true,
108 | norm: .rmsNorm,
109 | context: 3000,
110 | maxPeriod: 100000,
111 | maxSeqLen: 4096,
112 | kvRepeat: 1,
113 | dimFeedForward: 2560 * 4,
114 | convLayout: false,
115 | useRotatingKVCache: false
116 | )
117 | }
118 |
119 | public static func v1_7b() -> TransformerConfig {
120 | TransformerConfig(
121 | dModel: 4096,
122 | numHeads: 32,
123 | numLayers: 32,
124 | causal: true,
125 | normFirst: true,
126 | biasFF: false,
127 | biasAttn: false,
128 | layerScale: nil,
129 | positionalEmbedding: .rope,
130 | useConvBias: false,
131 | gating: true,
132 | norm: .rmsNorm,
133 | context: 3000,
134 | maxPeriod: 10000,
135 | maxSeqLen: 4096,
136 | kvRepeat: 1,
137 | dimFeedForward: 4096 * 4,
138 | convLayout: false,
139 | useRotatingKVCache: false
140 | )
141 | }
142 | }
143 |
144 | private class MlpGating: Module, UnaryLayer {
145 | @ModuleInfo(key: "linear_in") var linear_in: Linear
146 | @ModuleInfo(key: "linear_out") var linear_out: Linear
147 |
148 | init(_ cfg: TransformerConfig) {
149 | let hidden =
150 | cfg.dimFeedForward == 4 * cfg.dModel ? 11 * cfg.dModel / 4 : 2 * cfg.dimFeedForward / 3
151 | self._linear_in.wrappedValue = Linear(cfg.dModel, 2 * hidden, bias: cfg.biasFF)
152 | self._linear_out.wrappedValue = Linear(hidden, cfg.dModel, bias: cfg.biasFF)
153 | }
154 |
155 | func callAsFunction(_ x: MLXArray) -> MLXArray {
156 | let x = linear_in(x)
157 | let (B, T) = (x.dim(0), x.dim(1))
158 | let x_reshaped = x.reshaped(B, T, 2, -1)
159 | return linear_out(silu(x_reshaped[0..., 0..., 0]) * x_reshaped[0..., 0..., 1])
160 | }
161 | }
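The hidden-size expression above keeps the gated MLP roughly parameter-matched with the plain 4x MLP it replaces. A worked example (dModel illustrative):

// dModel = 1024, dimFeedForward = 4096 (= 4 * dModel):
//   hidden = 11 * 1024 / 4 = 2816
//   linear_in:  1024 -> 5632 (both halves of the gate)
//   linear_out: 2816 -> 1024
// ~8.7M parameters vs ~8.4M for the ungated 1024 -> 4096 -> 1024 MLP.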
162 |
163 | private class MlpNoGating: Module, UnaryLayer {
164 | @ModuleInfo(key: "linear1") var linear1: Linear
165 | @ModuleInfo(key: "linear2") var linear2: Linear
166 |
167 | init(_ cfg: TransformerConfig) {
168 | self._linear1.wrappedValue = Linear(cfg.dModel, cfg.dimFeedForward, bias: cfg.biasFF)
169 | self._linear2.wrappedValue = Linear(cfg.dimFeedForward, cfg.dModel, bias: cfg.biasFF)
170 | }
171 |
172 | func callAsFunction(_ x: MLXArray) -> MLXArray {
173 | linear2(geluApproximate(linear1(x)))
174 | }
175 | }
176 |
177 | private class Attention: Module {
178 | let cfg: TransformerConfig
179 | let scale: Float
180 | let rope: RoPE?
181 |
182 | @ModuleInfo(key: "in_proj") var inProj: Linear
183 | @ModuleInfo(key: "out_proj") var outProj: Linear
184 |
185 | init(_ cfg: TransformerConfig) {
186 | self.cfg = cfg
187 | self.scale = 1.0 / sqrt(Float(cfg.headDim()))
188 | let numKV = cfg.numHeads / cfg.kvRepeat
189 | let outDim = cfg.dModel + 2 * numKV * cfg.dModel / cfg.numHeads
190 | self._inProj.wrappedValue = Linear(cfg.dModel, outDim, bias: cfg.biasAttn)
191 | self._outProj.wrappedValue = Linear(cfg.dModel, cfg.dModel, bias: cfg.biasAttn)
192 | self.rope =
193 | switch cfg.positionalEmbedding {
194 | case .none: nil
195 | case .rope:
196 | RoPE(dimensions: cfg.headDim(), traditional: true, base: Float(cfg.maxPeriod))
197 | }
198 | }
199 |
200 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache?) -> MLXArray {
201 | let (B, T, H) = (x.dim(0), x.dim(1), x.dim(2))
202 | let qkv = inProj(x).reshaped(B, T, 3, cfg.numHeads, cfg.headDim())
203 | var q = qkv[0..., 0..., 0].transposed(0, 2, 1, 3)
204 | var k = qkv[0..., 0..., 1].transposed(0, 2, 1, 3)
205 | var v = qkv[0..., 0..., 2].transposed(0, 2, 1, 3)
206 | if let rope {
207 | let offset = cache?.offset ?? 0
208 | q = rope(q, offset: offset)
209 | k = rope(k, offset: offset)
210 | }
211 | if let cache {
212 | (k, v) = cache.update(keys: k, values: v)
213 | }
214 | let kLen = k.dim(2)
215 | let kTargetLen = T + min(self.cfg.context, kLen - T)
216 | if kTargetLen < kLen {
217 | let offset = kLen - kTargetLen
218 | k = k[0..., 0..., offset...]
219 | v = v[0..., 0..., offset...]
220 | }
221 |
222 | var mask = mask
223 | if let m = mask {
224 | let maskLen = m.dim(-1)
225 | if k.dim(2) < maskLen {
226 | let offset = maskLen - k.dim(2)
227 | mask = m[0..., offset...]
228 | }
229 | }
230 | let x = MLXFast.scaledDotProductAttention(
231 | queries: q, keys: k, values: v, scale: self.scale, mask: mask
232 | ).transposed(0, 2, 1, 3).reshaped(B, T, H)
233 | return outProj(x)
234 | }
235 | }
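The slicing after the cache update enforces a sliding attention window of cfg.context positions; a worked example (numbers illustrative):

// context = 3000, incremental decoding with T = 1, cache grown to kLen = 3500:
//   kTargetLen = 1 + min(3000, 3500 - 1) = 3001
//   offset     = 3500 - 3001 = 499
// so only the 3001 most recent key/value positions take part in attention,
// and the mask is trimmed from the left to match.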
236 |
237 | class LayerScale: Module, UnaryLayer {
238 | @ModuleInfo(key: "scale") var scale: MLXArray
239 |
240 | init(_ dModel: Int, initValue: Float) {
241 | self._scale.wrappedValue = ones([dModel]) * initValue
242 | }
243 |
244 | func callAsFunction(_ x: MLXArray) -> MLXArray {
245 | x * self.scale
246 | }
247 | }
248 |
249 | private class TransformerLayer: Module {
250 | @ModuleInfo(key: "gating") var gating: UnaryLayer
251 | @ModuleInfo(key: "norm1") var norm1: UnaryLayer
252 | @ModuleInfo(key: "norm2") var norm2: UnaryLayer
253 | @ModuleInfo(key: "layer_scale_1") var layerScale1: LayerScale?
254 | @ModuleInfo(key: "layer_scale_2") var layerScale2: LayerScale?
255 | @ModuleInfo(key: "self_attn") var selfAttn: Attention
256 |
257 | init(_ cfg: TransformerConfig) {
258 | self._gating.wrappedValue = cfg.gating ? MlpGating(cfg) : MlpNoGating(cfg)
259 | self._norm1.wrappedValue =
260 | switch cfg.norm {
261 | case .layerNorm:
262 | LayerNorm(dimensions: cfg.dModel, eps: 1e-5)
263 | case .rmsNorm: RMSNorm(dimensions: cfg.dModel, eps: 1e-8)
264 | }
265 | self._norm2.wrappedValue =
266 | switch cfg.norm {
267 | case .layerNorm:
268 | LayerNorm(dimensions: cfg.dModel, eps: 1e-5)
269 | case .rmsNorm: RMSNorm(dimensions: cfg.dModel, eps: 1e-8)
270 | }
271 | self._selfAttn.wrappedValue = Attention(cfg)
272 |
273 | if let scale = cfg.layerScale {
274 | self._layerScale1.wrappedValue = LayerScale(cfg.dModel, initValue: scale)
275 | self._layerScale2.wrappedValue = LayerScale(cfg.dModel, initValue: scale)
276 | } else {
277 | self._layerScale1.wrappedValue = nil
278 | self._layerScale2.wrappedValue = nil
279 | }
280 | }
281 |
282 | func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache) -> MLXArray {
283 | var residual = x
284 | var x = x
285 | x = selfAttn(norm1(x), mask: mask, cache: cache)
286 | if let layerScale1 = self.layerScale1 {
287 | x = layerScale1(x)
288 | }
289 | x = residual + x
290 | residual = x
291 | x = gating(norm2(x))
292 | if let layerScale2 = self.layerScale2 {
293 | x = layerScale2(x)
294 | }
295 | return residual + x
296 | }
297 | }
298 |
299 | public class Transformer: Module {
300 | let cfg: TransformerConfig
301 | private let layers: [TransformerLayer]
302 |
303 | public init(_ cfg: TransformerConfig) {
304 | self.cfg = cfg
305 | self.layers = (0..<cfg.numLayers).map { _ in TransformerLayer(cfg) }
306 | }
307 | 
308 | public func callAsFunction(_ x: MLXArray, cache: [KVCache]) -> MLXArray {
309 | var x = x
310 | let mask = cache.first?.createAttentionMask(h: x)
311 | for (layer, c) in zip(self.layers, cache) {
312 | x = layer(x, mask: mask, cache: c)
313 | }
314 | return x
315 | }
316 |
317 | public func makeCache(bSize: Int) -> [KVCache] {
318 | let kvHeads = cfg.numHeads / cfg.kvRepeat
319 | let dtype = self.layers.first!.selfAttn.inProj.weight.dtype
320 | let cache = (0..<cfg.numLayers).map { _ in
340 | public class ProjectedTransformer: Module {
341 | let convLayout: Bool
342 | @ModuleInfo(key: "transformer") var transformer: Transformer
343 | @ModuleInfo(key: "input_proj") var inputProj: Linear?
344 | @ModuleInfo(key: "output_projs") var outputProjs: [Linear?]
358 | public func callAsFunction(_ x: MLXArray, cache: [KVCache]) -> [MLXArray] {
359 | var x = x
360 | if self.convLayout {
361 | x = x.swappedAxes(1, 2)
362 | }
363 | if let inputProj = self.inputProj {
364 | x = inputProj(x)
365 | }
366 | x = self.transformer(x, cache: cache)
367 | var outs: [MLXArray] = []
368 | for outputProj in self.outputProjs {
369 | var out = outputProj?(x) ?? x
370 | if self.convLayout {
371 | out = out.swappedAxes(1, 2)
372 | }
373 | outs.append(out)
374 | }
375 | return outs
376 | }
377 |
378 | public func makeCache(bSize: Int) -> [KVCache] {
379 | self.transformer.makeCache(bSize: bSize)
380 | }
381 | }
382 |
--------------------------------------------------------------------------------
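A quick sanity sketch of how the streaming pieces in Transformer.swift compose. The preset, batch size, and zero-filled frames are illustrative values only; the calls themselves (Transformer(_:), makeCache(bSize:), and the per-frame callAsFunction) are the ones defined above, with the sliding-window logic in Attention bounding each cache at cfg.context past tokens:

    import MLX
    import MoshiLib

    let cfg = TransformerConfig.v1_300m()
    let transformer = Transformer(cfg)
    let cache = transformer.makeCache(bSize: 1)  // one KVCache per layer
    // Streaming: feed one frame at a time; the caches carry the attention state across calls.
    for _ in 0..<3 {
        let frame = MLXArray.zeros([1, 1, cfg.dModel])
        eval(transformer(frame, cache: cache))
    }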
/Moshi/ModelView.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | //
6 | //  ModelView.swift
7 | // Moshi
8 | //
9 | // Created by Sebastien Mazare on 06/12/2024.
10 | //
11 |
12 | import Hub
13 | import MLX
14 | import MLXNN
15 | import MLXRandom
16 | import Metal
17 | import MoshiLib
18 | import SwiftUI
19 | import Synchronization
20 | import AVFoundation
21 |
22 | struct ModelView: View {
23 | @Binding var model: Evaluator
24 | let modelType: ModelSelect
25 | @Environment(DeviceStat.self) private var deviceStat
26 | @State var sendToSpeaker = false
27 | @State private var showSettings = false
28 |
29 | var body: some View {
30 | VStack(spacing: 16) {
31 | Group {
32 | if deviceStat.gpuUsage.activeMemory > 0 || model.statsSummary.step.cnt > 0 {
33 | CombinedStatsView(
34 | summary: model.statsSummary,
35 | deviceStat: deviceStat,
36 | modelInfo: model.modelInfo,
37 | modelName: model.modelName,
38 | urls: model.urls
39 | )
40 | .padding()
41 | .background(RoundedRectangle(cornerRadius: 12).fill(.secondary.opacity(0.1)))
42 | }
43 | }
44 | .transition(.push(from: .top))
45 | .animation(.easeInOut(duration: 0.2), value: deviceStat.gpuUsage.activeMemory > 0 || model.statsSummary.step.cnt > 0)
46 |
47 | Group {
48 | if !model.running && model.output.isEmpty {
49 | ModelInfoView(modelType: modelType)
50 | } else {
51 | OutputSection(output: model.output)
52 | }
53 | }
54 | .transition(.opacity)
55 | .animation(.easeInOut(duration: 0.2), value: !model.running && model.output.isEmpty)
56 |
57 | Group {
58 | if model.running || !model.output.isEmpty {
59 | StatusSection(model: model)
60 | }
61 | }
62 | .transition(.push(from: .bottom))
63 | .animation(.easeInOut(duration: 0.2), value: model.running || !model.output.isEmpty)
64 |
65 | // Bottom controls
66 | ZStack {
67 | // Centered Start/Stop button
68 | Button(action: model.running ? stopGenerate : generate) {
69 | Label(model.running ? "Stop" : "Start",
70 | systemImage: model.running ? "stop.circle.fill" : "play.circle.fill")
71 | .font(.title2)
72 | }
73 | .buttonStyle(.borderedProminent)
74 |
75 | // Right-aligned settings button
76 | HStack {
77 | Spacer()
78 | Button(action: { showSettings.toggle() }) {
79 | Image(systemName: "gear")
80 | .font(.title2)
81 | }
82 | .popover(isPresented: $showSettings, arrowEdge: .bottom) {
83 | Toggle(isOn: $sendToSpeaker) {
84 | Label("Use External Speaker", systemImage: "speaker.wave.2")
85 | }
86 | .padding()
87 | .onChange(of: sendToSpeaker) { (_, newValue) in
88 | if newValue {
89 | setDefaultToSpeaker()
90 | } else {
91 | setDefaultToStd()
92 | }
93 | }
94 | .presentationCompactAdaptation(.popover)
95 | }
96 | }
97 | }
98 | .padding()
99 | }
100 | .padding()
101 | .navigationTitle("Moshi: \(modelType.name)")
102 | }
103 |
104 | private func generate() {
105 | Task {
106 | await model.generate(self.modelType)
107 | }
108 | }
109 |
110 | private func stopGenerate() {
111 | Task {
112 | await model.stopGenerate()
113 | }
114 | }
115 | }
116 |
117 | struct StatusSection: View {
118 | let model: Evaluator
119 |
120 | var body: some View {
121 | VStack(spacing: 8) {
122 | if let progress = model.progress {
123 | ProgressView(progress)
124 | .transition(.slide)
125 | }
126 |
127 | if model.running {
128 | let tint = model.totalDuration > 0 ? Color.blue : Color.red
129 | HStack(spacing: 12) {
130 | Label("\(Int(model.totalDuration))s", systemImage: "clock")
131 |
132 | Image(systemName: "microphone.circle.fill")
133 | .foregroundStyle(.blue)
134 | .font(.title2)
135 | .symbolEffect(.bounce, options: .repeating)
136 |
137 | VStack(alignment: .leading) {
138 | Text("Buffer")
139 | .font(.caption)
140 | .foregroundStyle(.secondary)
141 | Gauge(value: model.bufferedDuration * 1000.0, in: 0...500) {
142 | EmptyView()
143 | } currentValueLabel: {
144 | Text("\(Int(model.bufferedDuration * 1000.0))ms")
145 | }
146 | .tint(.blue)
147 | }
148 | }
149 | .padding()
150 | .background(RoundedRectangle(cornerRadius: 10).fill(tint.opacity(0.1)))
151 | }
152 | }
153 | }
154 | }
155 |
156 | struct OutputSection: View {
157 | let output: String
158 |
159 | var body: some View {
160 | ScrollView(.vertical) {
161 | ScrollViewReader { proxy in
162 | Text(output)
163 | .textSelection(.enabled)
164 | .multilineTextAlignment(.leading)
165 | .frame(maxWidth: .infinity, alignment: .leading)
166 | .padding()
167 | .onChange(of: output) { _, _ in
168 | withAnimation {
169 | proxy.scrollTo("bottom", anchor: .bottom)
170 | }
171 | }
172 | Color.clear
173 | .frame(height: 1)
174 | .id("bottom")
175 | }
176 | }
177 | .background(RoundedRectangle(cornerRadius: 12).fill(.secondary.opacity(0.1)))
178 | }
179 | }
180 |
181 | struct DeviceStatsView: View {
182 | let deviceStat: DeviceStat
183 |
184 | var body: some View {
185 | VStack(alignment: .leading, spacing: 8) {
186 | MemoryStatRow(
187 | label: "Active Memory",
188 | value: deviceStat.gpuUsage.activeMemory,
189 | total: GPU.memoryLimit
190 | )
191 | MemoryStatRow(
192 | label: "Cache Memory",
193 | value: deviceStat.gpuUsage.cacheMemory,
194 | total: GPU.cacheLimit
195 | )
196 | HStack {
197 | MemoryStatRow(
198 | label: "Peak Memory",
199 | value: deviceStat.gpuUsage.peakMemory
200 | )
201 | Spacer()
202 | VStack(alignment: .leading) {
203 | Text("Thermal State")
204 | .foregroundStyle(.secondary)
205 | Text(deviceStat.thermalState.rawValue)
206 | }
207 | }
208 |
209 | }
210 | }
211 | }
212 |
213 | struct DebugView: View {
214 | let modelInfo: String
215 | let modelName: String
216 | let urls: (URL, URL)?
217 |
218 | var body: some View {
219 | VStack(alignment: .leading, spacing: 8) {
220 | Text("Last Info").bold()
221 | Text(modelInfo)
222 | Text("Model Name").bold()
223 | Text(modelName)
224 | if let (traceURL, codesURL) = urls {
225 | HStack {
226 | ShareLink(item: traceURL) {
227 | Label("Trace", systemImage: "square.and.arrow.up")
228 | }
229 | .buttonStyle(BorderedButtonStyle())
230 | ShareLink(item: codesURL) {
231 | Label("Codes", systemImage: "square.and.arrow.up")
232 | }
233 | .buttonStyle(BorderedButtonStyle())
234 | }
235 | .padding()
236 | }
237 | }
238 | }
239 | }
240 |
241 | struct MemoryStatRow: View {
242 | let label: String
243 | let value: Int
244 | var total: Int?
245 |
246 | var body: some View {
247 | VStack(alignment: .leading, spacing: 4) {
248 | Text(label)
249 | .foregroundStyle(.secondary)
250 | if let total = total {
251 | Gauge(value: Double(value), in: 0...Double(total)) {
252 | Text(value.formatted(.byteCount(style: .memory)))
253 | }
254 | .tint(.blue)
255 | } else {
256 | Text(value.formatted(.byteCount(style: .memory)))
257 | }
258 | }
259 | }
260 | }
261 |
262 | struct StatsView: View {
263 | let summary: StatsSummary
264 |
265 | var body: some View {
266 | Grid(alignment: .leading) {
267 | GridRow {
268 | Text("Step")
269 | Text("Avg (ms)")
270 | Text("Min (ms)")
271 | Text("Max (ms)")
272 | Text("Count")
273 | }
274 | .bold()
275 |
276 | Divider()
277 |
278 | StatRow(label: "Encode", stats: summary.encode)
279 | StatRow(label: "Main", stats: summary.step)
280 | StatRow(label: "Depformer", stats: summary.depformer)
281 | StatRow(label: "Decode", stats: summary.decode)
282 | }
283 | .font(.system(.body, design: .monospaced))
284 | }
285 | }
286 |
287 | struct StatRow: View {
288 | let label: String
289 | let stats: StatsSummary.Stats
290 |
291 | var body: some View {
292 | GridRow {
293 | Text(label)
294 | Text((1000 * stats.sum / Float(stats.cnt)).rounded(), format: .number)
295 | Text((1000 * stats.min).rounded(), format: .number)
296 | Text((1000 * stats.max).rounded(), format: .number)
297 | Text(stats.cnt, format: .number)
298 | }
299 | }
300 | }
301 |
302 | struct CombinedStatsView: View {
303 | let summary: StatsSummary
304 | let deviceStat: DeviceStat
305 | let modelInfo: String
306 | let modelName: String
307 | let urls: (URL, URL)?
308 | @State private var currentPage = 0
309 | @State private var isExpanded = true
310 |
311 | var body: some View {
312 | VStack(spacing: 0) {
313 | // Header with page indicator and collapse button
314 | ZStack {
315 | // Left-aligned content
316 | HStack {
317 | if isExpanded {
318 | HStack(spacing: 16) {
319 | ForEach(0..<3) { index in
320 | Button(action: { withAnimation { currentPage = index } }) {
321 | let text = switch index {
322 | case 0: "Model"
323 | case 1: "Device"
324 | case 2: "Details"
325 | case _: "unk"
326 | }
327 | VStack(spacing: 4) {
328 | Text(text)
329 | .font(.subheadline)
330 | .foregroundStyle(currentPage == index ? .primary : .secondary)
331 | Rectangle()
332 | .fill(currentPage == index ? .blue : .clear)
333 | .frame(height: 2)
334 | }
335 | }
336 | .buttonStyle(.plain)
337 | }
338 | }
339 | .frame(maxWidth: .infinity)
340 | } else {
341 | Text("Details")
342 | .font(.headline)
343 | .foregroundStyle(.secondary)
344 | .frame(maxWidth: .infinity, alignment: .leading)
345 | }
346 | }
347 |
348 | // Right-aligned button (always in the same position)
349 | HStack {
350 | Spacer()
351 | Button(action: {
352 | isExpanded.toggle()
353 | }) {
354 | Image(systemName: isExpanded ? "chevron.up.circle.fill" : "chevron.down.circle.fill")
355 | .foregroundStyle(.secondary)
356 | .font(.title3)
357 | }
358 | }
359 | }
360 | .padding(.bottom, isExpanded ? 8 : 0)
361 | // Add tap gesture only when collapsed
362 | .contentShape(Rectangle()) // Make entire area tappable
363 | .onTapGesture {
364 | if !isExpanded {
365 | withAnimation {
366 | isExpanded.toggle()
367 | }
368 | }
369 | }
370 |
371 | if isExpanded {
372 | TabView(selection: $currentPage) {
373 | StatsView(summary: summary)
374 | .padding(.vertical)
375 | .frame(height: 250)
376 | .tag(0)
377 | DeviceStatsView(deviceStat: deviceStat)
378 | .padding(.vertical)
379 | .tag(1)
380 | DebugView(modelInfo: modelInfo, modelName: modelName, urls: urls)
381 | .padding(.vertical)
382 | .tag(2)
383 | }
384 | #if os(iOS)
385 | .tabViewStyle(.page(indexDisplayMode: .never))
386 | #endif
387 | }
388 | }
389 | .frame(height: isExpanded ? 250 : 44)
390 | .animation(.easeInOut(duration: 0.2), value: isExpanded)
391 | }
392 | }
393 |
394 | struct ModelInfoView: View {
395 | let modelType: ModelSelect
396 |
397 | var body: some View {
398 | VStack(spacing: 24) {
399 | Image(systemName: "waveform.and.mic")
400 | .font(.system(size: 64))
401 | .foregroundStyle(.blue)
402 |
403 | VStack(spacing: 12) {
404 | Text(modelType.name)
405 | .font(.title)
406 | .bold()
407 |
408 | Text(modelType.description)
409 | .font(.body)
410 | .multilineTextAlignment(.center)
411 | .foregroundStyle(.secondary)
412 | }
413 | .padding(.horizontal)
414 | }
415 | .frame(maxWidth: .infinity, maxHeight: .infinity)
416 | .background(RoundedRectangle(cornerRadius: 12).fill(.secondary.opacity(0.1)))
417 | }
418 | }
419 |
--------------------------------------------------------------------------------
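The timing grid in StatsView/StatRow above multiplies by 1000 everywhere because PerfStats accumulates durations in seconds while the table displays milliseconds. A minimal stand-in for that aggregation, assuming a Stats shape like the one StatsSummary exposes from Perf.swift (the fields sum/min/max/cnt match their uses above; the rest is illustrative):

    struct Stats {
        var sum: Float = 0  // total duration, in seconds
        var min: Float = .greatestFiniteMagnitude
        var max: Float = 0
        var cnt: Int = 0

        mutating func record(seconds: Float) {
            sum += seconds
            min = Swift.min(min, seconds)
            max = Swift.max(max, seconds)
            cnt += 1
        }
    }

    var step = Stats()
    for d: Float in [0.080, 0.075, 0.091] { step.record(seconds: d) }
    // Same arithmetic as StatRow: average, min, and max in milliseconds, plus the count.
    print((1000 * step.sum / Float(step.cnt)).rounded(), 1000 * step.min, 1000 * step.max, step.cnt)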
/MoshiLib/LM.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import MLX
6 | import MLXFast
7 | import MLXNN
8 |
9 | public struct DepformerConfig {
10 | var transformer: TransformerConfig
11 | var numSlices: Int
12 | }
13 |
14 | class DepformerSlice: Module {
15 | @ModuleInfo(key: "emb") var emb: Embedding
16 | @ModuleInfo(key: "linear_in") var linearIn: Linear
17 | @ModuleInfo(key: "linear_out") var linearOut: Linear
18 | @ModuleInfo(key: "transformer") var transformer: Transformer
19 |
20 | public init(
21 | inVocabSize: Int, outVocabSize: Int, mainTransformerDim: Int, cfg: TransformerConfig
22 | ) {
23 | self._emb.wrappedValue = Embedding(embeddingCount: inVocabSize, dimensions: cfg.dModel)
24 | self._linearIn.wrappedValue = Linear(mainTransformerDim, cfg.dModel, bias: false)
25 | self._linearOut.wrappedValue = Linear(cfg.dModel, outVocabSize, bias: false)
26 | self._transformer.wrappedValue = Transformer(cfg)
27 | }
28 | }
29 |
30 | class Depformer: Module {
31 | let cfg: LmConfig
32 | let transformerCache: [KVCache]
33 | @ModuleInfo(key: "slices") var slices: [DepformerSlice]
34 |
35 | public init(_ cfg: LmConfig, _ cfgDepformer: DepformerConfig, bSize: Int) {
36 | self.cfg = cfg
37 | let slices = (0..<cfgDepformer.numSlices).map { sliceIdx in
38 | DepformerSlice(
39 | inVocabSize: sliceIdx == 0 ? cfg.textInVocabSize : cfg.audioVocabSize,
40 | outVocabSize: cfg.audioVocabSize - 1,
41 | mainTransformerDim: cfg.transformer.dModel,
42 | cfg: cfgDepformer.transformer)
43 | }
44 | self._slices.wrappedValue = slices
45 | self.transformerCache = slices[0].transformer.makeCache(bSize: bSize)
46 | }
47 | 
48 | public func sample(
49 | mainTransformerOut: MLXArray, stepIdx: Int, sampler: Sampler, textToken: MLXArray
50 | ) -> MLXArray {
51 | for c in self.transformerCache {
52 | c.reset()
53 | }
54 | var lastToken = textToken
55 | var tokens: [MLXArray] = []
56 | for (sliceIdx, slice) in slices.enumerated() {
57 | if sliceIdx != 0 && stepIdx < self.cfg.audioDelays[sliceIdx - 1] {
58 | lastToken = MLXArray([self.cfg.audioPaddingToken()])
59 | }
60 | var xs = slice.linearIn(mainTransformerOut) + slice.emb(lastToken)
61 | xs = slice.transformer(xs, cache: self.transformerCache)
62 | let logits = slice.linearOut(xs)
63 | (lastToken, _) = sampler(logits: logits[0])
64 | tokens.append(lastToken)
65 | }
66 | return concatenated(tokens)
67 | }
68 | }
69 |
70 | public struct LmConfig {
71 | public var transformer: TransformerConfig
72 | public var depformer: DepformerConfig?
73 | public var textInVocabSize: Int
74 | public var textOutVocabSize: Int
75 | public var audioVocabSize: Int
76 | public var audioCodebooks: Int
77 | public var audioDelays: [Int]
78 |
79 | public func audioEOSToken() -> Int {
80 | self.audioVocabSize - 2
81 | }
82 |
83 | public func audioPaddingToken() -> Int {
84 | self.audioVocabSize - 1
85 | }
86 |
87 | public func textInitToken() -> Int {
88 | self.textInVocabSize - 1
89 | }
90 |
91 | public func depformerSlices() -> Int {
92 | self.depformer?.numSlices ?? 0
93 | }
94 |
95 | public static func moshi_2024_07() -> LmConfig {
96 | let depformer = DepformerConfig(
97 | transformer:
98 | TransformerConfig(
99 | dModel: 1024,
100 | numHeads: 16,
101 | numLayers: 6,
102 | causal: true,
103 | normFirst: true,
104 | biasFF: false,
105 | biasAttn: false,
106 | layerScale: nil,
107 | positionalEmbedding: .none,
108 | useConvBias: false,
109 | gating: true,
110 | norm: .rmsNorm,
111 | context: 8,
112 | maxPeriod: 10000,
113 | maxSeqLen: 4096,
114 | kvRepeat: 1,
115 | dimFeedForward: 1024 * 4,
116 | convLayout: false,
117 | useRotatingKVCache: false
118 | ), numSlices: 8)
119 | return LmConfig(
120 | transformer: TransformerConfig.v1_7b(),
121 | depformer: depformer,
122 | textInVocabSize: 32001,
123 | textOutVocabSize: 32000,
124 | audioVocabSize: 2049,
125 | audioCodebooks: 16,
126 | audioDelays: [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
127 | )
128 | }
129 |
130 | public static func moshi1b(audioDelay: Int) -> LmConfig {
131 | let audioDelays = [0] + Array(repeating: audioDelay, count: 7)
132 | let depformer = DepformerConfig(
133 | transformer:
134 | TransformerConfig(
135 | dModel: 1024,
136 | numHeads: 16,
137 | numLayers: 6,
138 | causal: true,
139 | normFirst: true,
140 | biasFF: false,
141 | biasAttn: false,
142 | layerScale: nil,
143 | positionalEmbedding: .none,
144 | useConvBias: false,
145 | gating: true,
146 | norm: .rmsNorm,
147 | context: 8,
148 | maxPeriod: 10000,
149 | maxSeqLen: 4096,
150 | kvRepeat: 1,
151 | dimFeedForward: 1024 * 4,
152 | convLayout: false,
153 | useRotatingKVCache: false
154 | ), numSlices: 8)
155 | return LmConfig(
156 | transformer: TransformerConfig.v1_1b(),
157 | depformer: depformer,
158 | textInVocabSize: 48001,
159 | textOutVocabSize: 48000,
160 | audioVocabSize: 2049,
161 | audioCodebooks: 16,
162 | audioDelays: audioDelays + audioDelays
163 | )
164 | }
165 |
166 | public static func moshi2b(audioDelay: Int) -> LmConfig {
167 | let audioDelays = [0] + Array(repeating: audioDelay, count: 7)
168 | let depformer = DepformerConfig(
169 | transformer:
170 | TransformerConfig(
171 | dModel: 1024,
172 | numHeads: 16,
173 | numLayers: 6,
174 | causal: true,
175 | normFirst: true,
176 | biasFF: false,
177 | biasAttn: false,
178 | layerScale: nil,
179 | positionalEmbedding: .none,
180 | useConvBias: false,
181 | gating: true,
182 | norm: .rmsNorm,
183 | context: 8,
184 | maxPeriod: 10000,
185 | maxSeqLen: 4096,
186 | kvRepeat: 1,
187 | dimFeedForward: 1024 * 4,
188 | convLayout: false,
189 | useRotatingKVCache: false
190 | ), numSlices: 8)
191 | return LmConfig(
192 | transformer: TransformerConfig.v1_2b(),
193 | depformer: depformer,
194 | textInVocabSize: 48001,
195 | textOutVocabSize: 48000,
196 | audioVocabSize: 2049,
197 | audioCodebooks: 16,
198 | audioDelays: audioDelays + audioDelays
199 | )
200 | }
201 |
202 | public static func asr300m() -> LmConfig {
203 | return LmConfig(
204 | transformer: TransformerConfig.v1_300m(),
205 | depformer: nil,
206 | textInVocabSize: 48001,
207 | textOutVocabSize: 48000,
208 | audioVocabSize: 2049,
209 | audioCodebooks: 32,
210 | audioDelays: Array(repeating: 0, count: 32)
211 | )
212 | }
213 |
214 | public static func asr1b() -> LmConfig {
215 | return LmConfig(
216 | transformer: TransformerConfig.v1_1b(),
217 | depformer: nil,
218 | textInVocabSize: 8001,
219 | textOutVocabSize: 8000,
220 | audioVocabSize: 2049,
221 | audioCodebooks: 32,
222 | audioDelays: Array(repeating: 0, count: 32)
223 | )
224 | }
225 |
226 | public static func asr2b() -> LmConfig {
227 | return LmConfig(
228 | transformer: TransformerConfig.v1_2b(),
229 | depformer: nil,
230 | textInVocabSize: 4001,
231 | textOutVocabSize: 4000,
232 | audioVocabSize: 2049,
233 | audioCodebooks: 32,
234 | audioDelays: Array(repeating: 0, count: 32)
235 | )
236 | }
237 |
238 | public static func helium2b() -> LmConfig {
239 | return LmConfig(
240 | transformer: TransformerConfig.v1_2b(),
241 | depformer: nil,
242 | textInVocabSize: 48000,
243 | textOutVocabSize: 48000,
244 | audioVocabSize: 2049,
245 | audioCodebooks: 0,
246 | audioDelays: Array(repeating: 0, count: 0)
247 | )
248 | }
249 | }
250 |
251 | public class LM: Module {
252 | var transformerCache: [KVCache]
253 | public let cfg: LmConfig
254 | @ModuleInfo(key: "depformer") var depformer: Depformer?
255 | @ModuleInfo(key: "transformer") public var transformer: Transformer
256 | @ModuleInfo(key: "text_emb") var textEmb: Embedding
257 | @ModuleInfo(key: "out_norm") var outNorm: UnaryLayer
258 | @ModuleInfo(key: "text_linear") var textLinear: Linear
259 | @ModuleInfo(key: "audio_embs") var audioEmbs: [Embedding]
260 |
261 | public init(_ cfg: LmConfig, bSize: Int) {
262 | self.cfg = cfg
263 | self._transformer.wrappedValue = Transformer(cfg.transformer)
264 | self._depformer.wrappedValue = cfg.depformer.map { Depformer(cfg, $0, bSize: bSize) }
265 | self._textEmb.wrappedValue = Embedding(
266 | embeddingCount: cfg.textInVocabSize, dimensions: cfg.transformer.dModel)
267 | self._outNorm.wrappedValue =
268 | switch cfg.transformer.norm {
269 | case .layerNorm:
270 | LayerNorm(dimensions: cfg.transformer.dModel, eps: 1e-5)
271 | case .rmsNorm: RMSNorm(dimensions: cfg.transformer.dModel, eps: 1e-8)
272 | }
273 | self._textLinear.wrappedValue = Linear(
274 | cfg.transformer.dModel, cfg.textOutVocabSize, bias: false)
275 | self._audioEmbs.wrappedValue = (0..<cfg.audioCodebooks).map { _ in
276 | Embedding(embeddingCount: cfg.audioVocabSize, dimensions: cfg.transformer.dModel)
277 | }
278 | self.transformerCache = self._transformer.wrappedValue.makeCache(bSize: bSize)
279 | }
280 | 
281 | public func callAsFunction(_ x: MLXArray) -> MLXArray {
282 | var x = textEmb(x)
283 | x = transformer(x, cache: self.transformerCache)
284 | return textLinear(outNorm(x))
285 | }
286 |
287 | public func resetCache() {
288 | for cache in self.transformerCache {
289 | cache.reset()
290 | }
291 | }
292 |
293 | public func stepMain(textIds: MLXArray?, audioIds: [MLXArray]) -> (MLXArray, MLXArray) {
294 | var x = textIds.flatMap { textEmb($0) }
295 | for (a, emb) in zip(audioIds, self.audioEmbs) {
296 | let e = emb(a)
297 | x = x.map { $0 + e } ?? e
298 | }
299 | let out = outNorm(transformer(x!, cache: self.transformerCache))
300 | let logits = textLinear(out[0..., -1, 0...])
301 | return (out, logits)
302 | }
303 |
304 | public func sample(
305 | textIds: MLXArray?,
306 | audioIds: [MLXArray],
307 | stepIdx: Int,
308 | textSampler: Sampler,
309 | audioSampler: Sampler,
310 | cb: Callbacks
311 | ) -> (MLXArray, MLXArray) {
312 | cb.onEvent(.beginStep)
313 | var x = textIds.flatMap { textEmb($0) }
314 | for (a, emb) in zip(audioIds, self.audioEmbs) {
315 | let e = emb(a)
316 | x = x.map { $0 + e } ?? e
317 | }
318 | let mainTransformerOut = outNorm(transformer(x!, cache: self.transformerCache))
319 | let textLogits = textLinear(mainTransformerOut[0..., -1, 0...])
320 | let (textToken, _) = textSampler(logits: textLogits)
321 | textToken.eval()
322 | cb.onEvent(.endStep)
323 | if let depformer = self.depformer {
324 | cb.onEvent(.beginDepformer)
325 | let audioTokens = depformer.sample(
326 | mainTransformerOut: mainTransformerOut,
327 | stepIdx: stepIdx,
328 | sampler: audioSampler,
329 | textToken: textToken)
330 | audioTokens.eval()
331 | cb.onEvent(.endDepformer)
332 | return (textToken, audioTokens)
333 | } else {
334 | return (textToken, MLXArray())
335 | }
336 | }
337 |
338 | public func warmup() {
339 | let sampler = Sampler()
340 | let textIds = MLXArray.zeros([1, 1], dtype: .int32)
341 | let audioIds = (0..<cfg.audioCodebooks).map { _ in MLXArray.zeros([1, 1], dtype: .int32) }
350 | public class LMGen {
385 | public func step(otherAudioTokens: MLXArray) -> MLXArray? {
386 | if self.stepIdx >= self.maxSteps {
387 | return nil
388 | }
389 | let textIds: MLXArray
390 | if self.stepIdx == 0 {
391 | textIds = MLXArray([self.model.cfg.textOutVocabSize]).reshaped([1, 1])
392 | } else {
393 | textIds = self.genSequence[.newAxis, 0..., 0, self.stepIdx - 1]
394 | }
395 | self.genSequence[0..., (1 + self.mainCodebooks)..., self.stepIdx] = otherAudioTokens
396 | var audioIds: [MLXArray] = []
397 | for (cbIdx, delay) in self.model.cfg.audioDelays.enumerated() {
398 | let genIdx = self.stepIdx - 1 - delay
399 | let audioToken: MLXArray
400 | if genIdx >= 0 {
401 | audioToken = self.genSequence[.newAxis, 0..., 1 + cbIdx, genIdx]
402 | } else {
403 | audioToken = MLXArray([self.model.cfg.audioPaddingToken()]).reshaped([1, 1])
404 | }
405 | if (audioToken .== MLXArray(ungeneratedToken)).any().item() {
406 | fatalError("ungenerated value in audio tokens, cb \(cbIdx), step \(stepIdx)")
407 | }
408 | assert(audioToken.shape == [1, 1])
409 | audioIds.append(audioToken)
410 | }
411 | if (textIds .== MLXArray(ungeneratedToken)).any().item() {
412 | fatalError("ungenerated value in text tokens, step \(stepIdx)")
413 | }
414 | assert(textIds.shape == [1, 1])
415 | let (tt, at) = model.sample(
416 | textIds: textIds,
417 | audioIds: audioIds,
418 | stepIdx: self.stepIdx,
419 | textSampler: self.textSampler,
420 | audioSampler: self.audioSampler,
421 | cb: self.cb
422 | )
423 | assert(tt.shape == [1])
424 | assert(at.shape == [self.model.cfg.depformerSlices()])
425 |
426 | self.genSequence[0..., 0, self.stepIdx] = tt
427 | for cbIdx in 0..<self.model.cfg.depformerSlices() {
428 | let delay = self.model.cfg.audioDelays[cbIdx]
429 | let genIdx = self.stepIdx - delay
430 | if genIdx >= 0 {
431 | self.genSequence[0..., cbIdx + 1, genIdx] = at[cbIdx]
432 | }
433 | }
434 |
435 | self.stepIdx += 1
436 | return textIds
437 | }
438 |
439 | public func lastAudioTokens() -> MLXArray? {
440 | let genIdx = self.stepIdx - 1 - (self.model.cfg.audioDelays.max() ?? 0)
441 | if genIdx < 0 {
442 | return nil
443 | }
444 | let tokens = self.genSequence[0..., 1...self.mainCodebooks, genIdx]
445 | if (tokens .== MLXArray(ungeneratedToken)).any().item() {
446 | fatalError("ungenerated value in audio tokens, step \(stepIdx)")
447 | }
448 | if (tokens .== MLXArray(self.model.cfg.audioPaddingToken())).any().item() {
449 | return nil
450 | }
451 | return tokens
452 | }
453 |
454 | public func reset() {
455 | self.stepIdx = 0
456 | self.model.resetCache()
457 | self.genSequence[0...] = MLXArray(ungeneratedToken)
458 | cb.onReset()
459 | }
460 | }
461 |
--------------------------------------------------------------------------------
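The most delicate part of LM.swift is the acoustic-delay bookkeeping in LMGen: at step t, codebook cb consumes the token generated at index t - 1 - audioDelays[cb], feeding audioPaddingToken() while that index is still negative, and lastAudioTokens() only yields a frame once the most-delayed codebook has caught up. A tiny worked example of the index arithmetic, using one stream of the moshi1b(audioDelay: 2) layout (step value chosen arbitrarily):

    let audioDelays = [0] + Array(repeating: 2, count: 7)  // one stream of moshi1b(audioDelay: 2)
    let stepIdx = 3
    for (cbIdx, delay) in audioDelays.enumerated() {
        let genIdx = stepIdx - 1 - delay  // the read index used by LMGen.step
        if genIdx >= 0 {
            print("cb \(cbIdx): read the token generated at index \(genIdx)")
        } else {
            print("cb \(cbIdx): nothing generated yet, feed audioPaddingToken()")
        }
    }
    // A complete audio frame exists once stepIdx - 1 - audioDelays.max()! >= 0,
    // which is the condition lastAudioTokens() checks before handing codes to Mimi.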
/Moshi/ContentView.swift:
--------------------------------------------------------------------------------
1 | // Copyright (c) Kyutai, all rights reserved.
2 | // This source code is licensed under the license found in the
3 | // LICENSE file in the root directory of this source tree.
4 |
5 | import Hub
6 | import MLX
7 | import MLXNN
8 | import MLXRandom
9 | import Metal
10 | import MoshiLib
11 | import SwiftUI
12 | import Synchronization
13 | import Tokenizers
14 |
15 | struct CustomError: Error {
16 | let message: String
17 |
18 | init(_ s: String) {
19 | message = s
20 | }
21 | }
22 |
23 | enum ModelSelect: String, CaseIterable, Identifiable {
24 | case mimi
25 | case asr
26 | case hibiki
27 | case helium
28 | case qwen
29 |
30 | var id: Self { return self }
31 |
32 | var name: String {
33 | switch self {
34 | case .hibiki:
35 | return "Hibiki 1B"
36 | default:
37 | return rawValue.capitalized
38 | }
39 | }
40 |
41 | var description: String {
42 | switch self {
43 | case .hibiki:
44 | return
45 | "A French to English simultaneous translation model designed for real-time speech translation."
46 | default:
47 | return ""
48 | }
49 | }
50 | }
51 |
52 | struct ContentView: View {
53 | @State var model = Evaluator()
54 | @State var selectedModel: ModelSelect = .mimi
55 | @State var sendToSpeaker = false
56 | @Environment(DeviceStat.self) private var deviceStat
57 |
58 | // Currently available models
59 | private let availableModels: [ModelSelect] = [.hibiki, .asr]
60 | var body: some View {
61 | Group {
62 | if availableModels.count == 1 {
63 | // Skip navigation if only one model
64 | ModelView(
65 | model: $model,
66 | modelType: availableModels[0]
67 | )
68 | } else {
69 | NavigationSplitView(
70 | sidebar: {
71 | VStack {
72 | List {
73 | ForEach(availableModels) { modelType in
74 | NavigationLink(
75 | modelType.rawValue,
76 | destination: {
77 | ModelView(
78 | model: $model,
79 | modelType: modelType
80 | )
81 | }
82 | )
83 | }
84 | }
85 | }
86 | .navigationTitle("Available Models")
87 | #if os(iOS)
88 | .navigationBarTitleDisplayMode(.inline)
89 | #endif
90 | },
91 | detail: {
92 | Text("Please choose a model type")
93 | }
94 | )
95 | }
96 | }
97 | }
98 | }
99 |
100 | #Preview {
101 | ContentView()
102 | }
103 |
104 | @Observable
105 | @MainActor
106 | class Evaluator {
107 | var running = false
108 | var modelInfo = "...moshi..."
109 | var modelName = ""
110 | var urls: (URL, URL)? = nil
111 | var stat = ""
112 | var output = ""
113 | var progress: Progress? = nil
114 | var statsSummary: StatsSummary = StatsSummary()
115 | var bufferedDuration: Double = 0.0
116 | var totalDuration: Double = 0.0
117 | let shouldStop: Atomic = .init(false)
118 | let cb: PerfStats = PerfStats()
119 |
120 | enum LoadState {
121 | case idle
122 | case loaded(ModelState, ModelSelect)
123 | }
124 |
125 | var loadState = LoadState.idle
126 |
127 | func downloadFromHub(id: String, filename: String) async throws -> URL {
128 | guard
129 | let downloadDir = FileManager.default.urls(
130 | for: .applicationSupportDirectory, in: .userDomainMask
131 | ).first
132 | else {
133 | throw CustomError("cannot find the app support directory")
134 | }
135 | let api = HubApi(downloadBase: downloadDir)
136 | let repo = Hub.Repo(id: id)
137 | let targetURL = api.localRepoLocation(repo).appending(path: filename)
138 | if FileManager.default.fileExists(atPath: targetURL.path) {
139 | print("using cached file \(targetURL.path)")
140 | return targetURL
141 | }
142 | let url = try await api.snapshot(from: repo, matching: filename) { progress in
143 | Task { @MainActor in
144 | self.progress = progress
145 | }
146 | }
147 | // TODO: also set this back to nil on errors.
148 | self.progress = nil
149 | return url.appending(path: filename)
150 | }
151 |
152 | func loadVocab(_ cfg: LmConfig) async throws -> [Int: String] {
153 | let filename =
154 | switch cfg.textOutVocabSize {
155 | case 48000: "tokenizer_spm_48k_multi6_2.json"
156 | case 32000: "tokenizer_spm_32k_3.json"
157 | case 8000: "tokenizer_spm_8k_0.json"
158 | case 4000: "test_en_audio_4000.json"
159 | case let other: fatalError("unexpected text vocab size \(other)")
160 | }
161 | let fileURL = try await downloadFromHub(id: "lmz/moshi-swift", filename: filename)
162 | let jsonData = try Data(contentsOf: fileURL)
163 | let dictionary = try JSONDecoder().decode([Int: String].self, from: jsonData)
164 | return dictionary
165 | }
166 |
167 | func makeMoshi(_ url: URL, _ cfg: LmConfig) throws -> LM {
168 | let weights = try loadArrays(url: url)
169 | let parameters = ModuleParameters.unflattened(weights)
170 | let model = LM(cfg, bSize: 1)
171 | if url.lastPathComponent.hasSuffix(".q4.safetensors") {
172 | quantize(model: model, groupSize: 32, bits: 4)
173 | } else if url.lastPathComponent.hasSuffix(".q6.safetensors") {
174 | quantize(model: model, groupSize: 64, bits: 6)
175 | } else if url.lastPathComponent.hasSuffix(".q8.safetensors") {
176 | quantize(model: model, groupSize: 64, bits: 8)
177 | }
178 | try model.update(parameters: parameters, verify: [.all])
179 | eval(model)
180 | return model
181 | }
182 |
183 | func makeHelium(_ url: URL, _ cfg: LmConfig) throws -> LM {
184 | let weights = try loadArrays(url: url)
185 | let parameters = ModuleParameters.unflattened(weights)
186 | let model = LM(cfg, bSize: 1)
187 | if url.lastPathComponent.hasSuffix("q4.safetensors") {
188 | quantize(model: model, groupSize: 64, bits: 4)
189 | } else if url.lastPathComponent.hasSuffix(".q6.safetensors") {
190 | quantize(model: model, groupSize: 64, bits: 6)
191 | } else if url.lastPathComponent.hasSuffix("q8.safetensors") {
192 | quantize(model: model, groupSize: 64, bits: 8)
193 | }
194 | try model.update(parameters: parameters, verify: [.all])
195 | eval(model)
196 | return model
197 | }
198 |
199 | func makeQwen(_ url: URL, _ cfg: QwenConfig) throws -> QwenModel {
200 | let weights = try loadArrays(url: url)
201 | let parameters = ModuleParameters.unflattened(weights)
202 | let model = QwenModel(cfg)
203 | if let q = cfg.quantization {
204 | quantize(model: model, groupSize: q.groupSize, bits: q.bits)
205 | }
206 | try model.update(parameters: parameters, verify: [.all])
207 | eval(model)
208 | return model
209 | }
210 |
211 | func makeMimi(numCodebooks: Int) async throws -> Mimi {
212 | let cfg = MimiConfig.mimi_2024_07(numCodebooks: numCodebooks)
213 | let model = Mimi(cfg, bSize: 1)
214 |
215 | let url = try await downloadFromHub(
216 | id: "lmz/moshi-swift",
217 | filename: "tokenizer-dbaa9758-checkpoint125.safetensors")
218 | let origWeights = try loadArrays(url: url)
219 | var weights: [String: MLXArray] = [:]
220 | for (var key, var weight) in origWeights {
221 | // Mutating the keys while iterating over the map seems pretty dodgy; not sure what the
222 | // idiomatic way to do this is in Swift. Hopefully this is copy-on-write and it's all good :)
223 | if key.hasPrefix("encoder.model") {
224 | key.replace("encoder.model.", with: "encoder.")
225 | }
226 | if key.hasPrefix("decoder.model") {
227 | key.replace("decoder.model.", with: "decoder.")
228 | }
229 | if key.hasSuffix(".in_proj_weight") {
230 | key.replace(".in_proj_weight", with: ".in_proj.weight")
231 | }
232 | if key.hasSuffix(".linear1.weight") {
233 | key.replace(".linear1.weight", with: ".gating.linear1.weight")
234 | }
235 | if key.hasSuffix(".linear2.weight") {
236 | key.replace(".linear2.weight", with: ".gating.linear2.weight")
237 | }
238 | // Awfully hardcoded matching between the pytorch layers and their mlx equivalent :(
239 | for (layerIdx, decoderIdx) in [2, 5, 8, 11].enumerated() {
240 | key.replace("decoder.\(decoderIdx).", with: "decoder.layers.\(layerIdx).upsample.")
241 | key.replace(
242 | "decoder.\(decoderIdx + 1).", with: "decoder.layers.\(layerIdx).residuals.0.")
243 | }
244 | for (layerIdx, encoderIdx) in [1, 4, 7, 10].enumerated() {
245 | key.replace(
246 | "encoder.\(encoderIdx).", with: "encoder.layers.\(layerIdx).residuals.0.")
247 | key.replace(
248 | "encoder.\(encoderIdx + 2).", with: "encoder.layers.\(layerIdx).downsample.")
249 | }
250 | key.replace("decoder.0.", with: "decoder.init_conv1d.")
251 | key.replace("decoder.14.", with: "decoder.final_conv1d.")
252 | key.replace("encoder.0.", with: "encoder.init_conv1d.")
253 | key.replace("encoder.14.", with: "encoder.final_conv1d.")
254 | key.replace(".block.1.", with: ".block.0.")
255 | key.replace(".block.3.", with: ".block.1.")
256 | // PyTorch layout for conv weights is outC, inC, kSize, for MLX it's outC, kSize, inC
257 | if key.hasSuffix(".conv.weight") || key.hasSuffix(".output_proj.weight")
258 | || key.hasSuffix(".input_proj.weight")
259 | {
260 | weight = weight.swappedAxes(-1, -2)
261 | }
262 | // PyTorch layout for conv-transposed weights is inC, outC, kSize, for MLX it's outC, kSize, inC
263 | if key.hasSuffix(".convtr.weight") {
264 | weight = weight.transposed(axes: [1, 2, 0])
265 | }
266 |
267 | // print(key, weight.shape)
268 | weights[key] = weight
269 | }
270 | let parameters = ModuleParameters.unflattened(weights)
271 | try model.update(parameters: parameters, verify: [.all])
272 | return model
273 | }
274 |
275 | func stopGenerate() async {
276 | self.shouldStop.store(true, ordering: .relaxed)
277 | }
278 |
279 | func generate(_ sm: ModelSelect) async {
280 | guard !running else { return }
281 |
282 | self.shouldStop.store(false, ordering: .relaxed)
283 | self.modelInfo = "starting"
284 | self.output = ""
285 | self.totalDuration = 0.0
286 | running = true
287 | do {
288 | let model = try await load(sm)
289 | let urls = try await model.perform { model in
290 | model.reset()
291 | await self.cb.onReset()
292 | // TODO: Do not create a fresh audio input/output on each session.
293 | let microphoneCapture = MicrophoneCapture()
294 | microphoneCapture.startCapturing()
295 | let ap = AudioPlayer(sampleRate: 24000)
296 | try ap.startPlaying()
297 | print("started the audio loops")
298 |
299 | var step = 0
300 | let startTime = CFAbsoluteTimeGetCurrent()
301 | while let pcm = microphoneCapture.receive() {
302 | if shouldStop.load(ordering: .relaxed) {
303 | break
304 | }
305 | let pcm = MLXArray(pcm)[.newAxis, .newAxis]
306 | if !model.onMicrophonePcm(pcm, ap: ap, ev: self) {
307 | break
308 | }
309 | step += 1
310 | if step % 128 == 0 {
311 | GPU.clearCache()
312 | }
313 | let currentStep = step
314 | Task { @MainActor in
315 | if currentStep % 5 == 0 {
316 | self.bufferedDuration =
317 | ap.bufferedDuration() + microphoneCapture.bufferedDuration()
318 | self.totalDuration = CFAbsoluteTimeGetCurrent() - startTime
319 | }
320 | if currentStep % 20 == 0 {
321 | self.statsSummary = self.cb.getSummary(maxEvents: 100)
322 | }
323 | }
324 | }
325 | print()
326 | let traceURL = FileManager.default.temporaryDirectory.appendingPathComponent(
327 | "moshi-trace.json")
328 | try await self.cb.writeJSONTrace(url: traceURL)
329 | let codesURL = FileManager.default.temporaryDirectory.appendingPathComponent(
330 | "moshi-codes.safetensors")
331 | do {
332 | let timestamp = Int(Date().timeIntervalSince1970)
333 | guard
334 | let documentsDirectory = FileManager.default.urls(
335 | for: .documentDirectory, in: .userDomainMask
336 | ).first
337 | else {
338 | throw NSError(
339 | domain: "FileManagerError", code: 1,
340 | userInfo: [NSLocalizedDescriptionKey: "Documents directory not found"])
341 | }
342 | let docURL = documentsDirectory.appendingPathComponent(
343 | "moshi-codes-\(timestamp).safetensors")
344 | let fileManager = FileManager.default
345 | try fileManager.copyItem(at: codesURL, to: docURL)
346 | } catch {
347 | print("Error copying file: \(error)")
348 | }
349 | try await self.cb.writeCodes(url: codesURL)
350 | microphoneCapture.stopCapturing()
351 | return (traceURL, codesURL)
352 | }
353 | self.urls = urls
354 | self.modelInfo = "finished generating"
355 | } catch {
356 | self.modelInfo = "failed: \(error)"
357 | }
358 | self.bufferedDuration = 0.0
359 | self.statsSummary = StatsSummary()
360 | running = false
361 | }
362 |
363 | func load(_ sm: ModelSelect) async throws -> ModelState {
364 | if case .loaded(let m, sm) = self.loadState {
365 | return m
366 | }
367 | // Start by resetting loadState so as to release the memory used
368 | // by the currently loaded model, if any.
369 | self.loadState = .idle
370 | let m: ModelState
371 | switch sm {
372 | case .hibiki:
373 | let model = try await MoshiModel(self, self.cb)
374 | m = ModelState(model)
375 | case .mimi:
376 | let model = try await MimiModel(self, self.cb)
377 | m = ModelState(model)
378 | case .asr:
379 | let model = try await AsrModel(self, self.cb)
380 | m = ModelState(model)
381 | case .helium:
382 | let model = try await HeliumModel(self, self.cb)
383 | m = ModelState(model)
384 | case .qwen:
385 | let model = try await QwenModel_(self, self.cb)
386 | m = ModelState(model)
387 | }
388 | self.loadState = .loaded(m, sm)
389 | return m
390 | }
391 |
392 | func setModelName(_ s: String) {
393 | modelName = s
394 | }
395 | func setModelInfo(_ s: String) {
396 | modelInfo = s
397 | }
398 | }
399 |
400 | protocol Model {
401 | init(_ ev: Evaluator, _ cb: Callbacks) async throws
402 | mutating func reset()
403 | // If onMicrophonePcm returns true, continue; otherwise break.
404 | mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool
405 | }
406 |
407 | struct MimiModel: Model {
408 | let mimi: Mimi
409 | let codes: MLXArray
410 | var currentStep: Int
411 | let totalSteps: Int
412 | let cb: Callbacks
413 |
414 | init(_ ev: Evaluator, _ cb: Callbacks) async throws {
415 | await ev.setModelInfo("building model")
416 | self.mimi = try await ev.makeMimi(numCodebooks: 16)
417 | await ev.setModelInfo("model built")
418 | let codeURL = try await ev.downloadFromHub(
419 | id: "lmz/moshi-swift", filename: "bria-codes.safetensors")
420 | guard let codes = try loadArrays(url: codeURL)["codes"] else {
421 | throw CustomError("no codes in \(codeURL)")
422 | }
423 | self.codes = codes
424 | self.currentStep = 0
425 | self.totalSteps = codes.dim(-1)
426 | self.cb = cb
427 | }
428 |
429 | mutating func reset() {
430 | mimi.resetState()
431 | }
432 |
433 | mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
434 | let sampleText =
435 | "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
436 | let sampleWords = sampleText.components(separatedBy: " ")
437 | self.cb.onEvent(.beginEncode)
438 | let micCodes = mimi.encodeStep(StreamArray(pcm))
439 | micCodes.eval()
440 | self.cb.onEvent(.endEncode)
441 | if let micCodes = micCodes.asArray() {
442 | // As of 2024-12-04, there is a memory leak if this eval is removed. It is
443 | // triggered even without the decoding and audio playback; only the line
444 | // below is needed:
445 | // let audioTokens = codes[.ellipsis, currentStep...currentStep]
446 | eval(micCodes)
447 | self.cb.onInputAudioTokens(micCodes)
448 | let (_, _, steps) = micCodes.shape3
449 | for _ in 0..<steps {
450 | if currentStep >= totalSteps {
451 | break
452 | }
453 | let audioTokens = codes[.ellipsis, currentStep...currentStep]
454 | self.cb.onEvent(.beginDecode)
455 | self.cb.onOutputAudioTokens(audioTokens)
456 | let pcmOut = mimi.decodeStep(StreamArray(audioTokens))
457 | pcmOut.eval()
458 | self.cb.onEvent(.endDecode)
459 | if let p = pcmOut.asArray() {
460 | let _ = ap.send(p.asArray(Float.self))
461 | }
462 | if currentStep % 4 == 0 {
463 | let v = sampleWords[(currentStep / 4) % sampleWords.count] + " "
464 | Task { @MainActor in
465 | ev.output += v
466 | }
467 | }
468 | currentStep += 1
469 | }
470 | }
471 | return currentStep < totalSteps
472 | }
473 | }
474 |
475 | struct QwenModel_: Model {
476 | let qwen: QwenModel
477 | let tokenizer: any Tokenizer
478 |
479 | init(_ ev: Evaluator, _ cb: Callbacks) async throws {
480 | let hfRepo = "Qwen/Qwen2.5-0.5B-Instruct"
481 | await ev.setModelInfo("building model")
482 |
483 | let configUrl = try await ev.downloadFromHub(id: hfRepo, filename: "config.json")
484 | let configData = try Data(contentsOf: configUrl)
485 | let decoder = JSONDecoder()
486 | decoder.keyDecodingStrategy = .convertFromSnakeCase
487 | let cfg = try decoder.decode(QwenConfig.self, from: configData)
488 |
489 | let url = try await ev.downloadFromHub(id: hfRepo, filename: "model.safetensors")
490 | qwen = try await ev.makeQwen(url, cfg)
491 | await ev.setModelInfo("model built")
492 | tokenizer = try await AutoTokenizer.from(pretrained: hfRepo)
493 | }
494 |
495 | mutating func reset() {
496 | }
497 |
498 | mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
499 | let sampler = Sampler()
500 | let cache = qwen.makeCache(bSize: 1)
501 |
502 | var lastToken = 151643  // Qwen2.5 "<|endoftext|>" token id
503 | for stepIdx in 0...500 {
504 | let logits = qwen(MLXArray([lastToken]).reshaped(1, 1), cache: cache)
505 | let (tok, _) = sampler(logits: logits[0])
506 | lastToken = tok.item()
507 | let s = tokenizer.decode(tokens: [lastToken])
508 | Task { @MainActor in
509 | ev.output += s
510 | }
511 | }
512 |
513 | return false
514 | }
515 | }
516 |
517 | struct HeliumModel: Model {
518 | let helium: LM
519 | let vocab: [Int: String]
520 |
521 | init(_ ev: Evaluator, _ cb: Callbacks) async throws {
522 | await ev.setModelInfo("building model")
523 | let cfg = LmConfig.helium2b()
524 | let url = try await ev.downloadFromHub(
525 | id: "kyutai/helium-1-preview-2b-mlx", filename: "helium-1-preview-2b-q4.safetensors")
526 | helium = try await ev.makeHelium(url, cfg)
527 | await ev.setModelInfo("model built")
528 | vocab = try await ev.loadVocab(cfg)
529 | await ev.setModelInfo("warming up helium")
530 | helium.warmup()
531 | await ev.setModelInfo("done warming up")
532 | }
533 |
534 | mutating func reset() {
535 | }
536 |
537 | mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
538 | let maxSteps = helium.cfg.transformer.maxSeqLen
539 | let sampler = Sampler()
540 | let cb = PerfStats()
541 |
542 | var lastToken = MLXArray([1])
543 | for stepIdx in 0...min(500, maxSteps) {
544 | let (textToken, _) = helium.sample(
545 | textIds: lastToken.reshaped([1, 1]), audioIds: [], stepIdx: stepIdx,
546 | textSampler: sampler,
547 | audioSampler: sampler, cb: cb)
548 | let textTokenI: Int = textToken[0].item()
549 | if var v = vocab[textTokenI] {
550 | if v == "<0x0A>" {
551 | Task { @MainActor in
552 | ev.output += "\n"
553 | }
554 | } else {
555 | v.replace("▁", with: " ")
556 | Task { @MainActor in
557 | ev.output += v
558 | }
559 | }
560 | }
561 | lastToken = textToken
562 | }
563 |
564 | return false
565 | }
566 | }
567 |
568 | struct AsrModel: Model {
569 | var asr: ASR
570 |
571 | init(_ ev: Evaluator, _ cb: Callbacks) async throws {
572 | await ev.setModelInfo("building model")
573 | let url: URL
574 | let localURL = Bundle.main.url(forResource: "stt-model", withExtension: "safetensors")
575 | switch localURL {
576 | case .none:
577 | url = try await ev.downloadFromHub(
578 | id: "lmz/moshi-swift", filename: "moshi-70f8f0ea@500.q8.safetensors")
579 | case .some(let localURL):
580 | url = localURL
581 | }
582 | let cfg = LmConfig.asr1b()
583 | let moshi = try await ev.makeMoshi(url, cfg)
584 | let mimi = try await ev.makeMimi(numCodebooks: 32)
585 | await ev.setModelInfo("model built")
586 | let vocab = try await ev.loadVocab(cfg)
587 | await ev.setModelInfo("warming up mimi")
588 | mimi.warmup()
589 | await ev.setModelInfo("warming up moshi")
590 | moshi.warmup()
591 | await ev.setModelInfo("done warming up")
592 | self.asr = ASR(moshi, mimi, vocab: vocab, cb: cb)
593 | }
594 |
595 | mutating func reset() {
596 | asr.reset()
597 | }
598 |
599 | mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
600 | let tokens = asr.onPcmInput(pcm)
601 | for v in tokens {
602 | print(v, terminator: "")
603 | fflush(stdout)
604 | Task { @MainActor in
605 | ev.output += v
606 | }
607 | }
608 | return true
609 | }
610 | }
611 |
612 | struct MoshiModel: Model {
613 | let moshi: LM
614 | let vocab: [Int: String]
615 | let mimi: Mimi
616 | let gen: LMGen
617 | let cb: Callbacks
618 |
619 | init(_ ev: Evaluator, _ cb: Callbacks) async throws {
620 | await ev.setModelInfo("building model")
621 | let url: URL
622 | let cfg = LmConfig.moshi1b(audioDelay: 2)
623 | let localURL = Bundle.main.url(
624 | forResource: "moshi-37c6cfd6@200.q6", withExtension: "safetensors")
625 | switch localURL {
626 | case .none:
627 | url = try await ev.downloadFromHub(
628 | id: "lmz/moshi-swift", filename: "moshi-37c6cfd6@200.q6.safetensors")
629 | case .some(let localURL):
630 | url = localURL
631 | }
632 | self.moshi = try await ev.makeMoshi(url, cfg)
633 | await ev.setModelName(url.lastPathComponent)
634 | self.mimi = try await ev.makeMimi(numCodebooks: 16)
635 | await ev.setModelInfo("model built")
636 | let maxSteps = cfg.transformer.maxSeqLen
637 | self.cb = cb
638 | self.gen = LMGen(
639 | moshi, maxSteps: maxSteps, audioSampler: Sampler(), textSampler: Sampler(), cb: self.cb)
640 | self.vocab = try await ev.loadVocab(cfg)
641 | await ev.setModelInfo("warming up mimi")
642 | self.mimi.warmup()
643 | await ev.setModelInfo("warming up moshi")
644 | self.moshi.warmup()
645 | await ev.setModelInfo("done warming up")
646 | }
647 |
648 | mutating func reset() {
649 | mimi.resetState()
650 | gen.reset()
651 | }
652 |
653 | mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
654 | cb.onEvent(.beginEncode)
655 | let codes = mimi.encodeStep(StreamArray(pcm))
656 | codes.eval()
657 | cb.onEvent(.endEncode)
658 | if let codes = codes.asArray() {
659 | self.cb.onInputAudioTokens(codes)
660 | let (_, _, steps) = codes.shape3
661 | for step in 0..<steps {
690 | actor ModelState {
691 | var model: Model
692 | 
693 | init(_ model: Model) {
694 | self.model = model
695 | }
696 | 
702 | func perform<R>(_ action: @Sendable (inout Model) async throws -> R)
703 | async rethrows
704 | -> R
705 | {
706 | var model = model
707 | let result = try await action(&model)
708 | return result
709 | }
710 | }
711 |
--------------------------------------------------------------------------------
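A closing note on the weight remapping in makeMimi above: besides the key renames, the only real transformation is the conv-weight layout change between PyTorch and MLX. The shapes below are made up, but the axis permutations are exactly the ones applied in ContentView.swift:

    import MLX

    // PyTorch Conv1d weight: (outC, inC, kSize) -> MLX: (outC, kSize, inC).
    let torchConv = MLXArray.zeros([64, 32, 7])
    let mlxConv = torchConv.swappedAxes(-1, -2)  // [64, 7, 32]

    // PyTorch ConvTranspose1d weight: (inC, outC, kSize) -> MLX: (outC, kSize, inC).
    let torchConvTr = MLXArray.zeros([32, 64, 7])
    let mlxConvTr = torchConvTr.transposed(axes: [1, 2, 0])  // [64, 7, 32]

    print(mlxConv.shape, mlxConvTr.shape)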