├── .gitignore ├── LICENSE ├── Package.resolved ├── Package.swift ├── README.md └── Sources ├── F5TTS ├── DiT.swift ├── Duration.swift ├── F5TTS.swift ├── Modules.swift └── Resources │ ├── mel_filters.npy │ └── test_en_1_ref_short.wav └── f5-tts-generate └── GenerateCommand.swift /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Xcode 4 | # 5 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 6 | 7 | ## User settings 8 | xcuserdata/ 9 | 10 | ## Obj-C/Swift specific 11 | *.hmap 12 | 13 | ## App packaging 14 | *.ipa 15 | *.dSYM.zip 16 | *.dSYM 17 | 18 | ## Playgrounds 19 | timeline.xctimeline 20 | playground.xcworkspace 21 | 22 | # Swift Package Manager 23 | # 24 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 25 | # Packages/ 26 | # Package.pins 27 | # Package.resolved 28 | # *.xcodeproj 29 | # 30 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 31 | # hence it is not needed unless you have added a package configuration file to your project 32 | # .swiftpm 33 | 34 | .build/ 35 | 36 | # CocoaPods 37 | # 38 | # We recommend against adding the Pods directory to your .gitignore. However 39 | # you should judge for yourself, the pros and cons are mentioned at: 40 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 41 | # 42 | # Pods/ 43 | # 44 | # Add this line if you want to avoid checking in source code from the Xcode workspace 45 | # *.xcworkspace 46 | 47 | # Carthage 48 | # 49 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 50 | # Carthage/Checkouts 51 | 52 | Carthage/Build/ 53 | 54 | # fastlane 55 | # 56 | # It is recommended to not store the screenshots in the git repo. 
57 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 58 | # For more information about the recommended setup visit: 59 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 60 | 61 | fastlane/report.xml 62 | fastlane/Preview.html 63 | fastlane/screenshots/**/*.png 64 | fastlane/test_output 65 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lucas Newman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "jinja", 5 | "kind" : "remoteSourceControl", 6 | "location" : "https://github.com/maiqingqiang/Jinja", 7 | "state" : { 8 | "revision" : "6dbe4c449469fb586d0f7339f900f0dd4d78b167", 9 | "version" : "1.0.6" 10 | } 11 | }, 12 | { 13 | "identity" : "mlx-swift", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/ml-explore/mlx-swift", 16 | "state" : { 17 | "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", 18 | "version" : "0.21.2" 19 | } 20 | }, 21 | { 22 | "identity" : "swift-argument-parser", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/apple/swift-argument-parser.git", 25 | "state" : { 26 | "revision" : "41982a3656a71c768319979febd796c6fd111d5c", 27 | "version" : "1.5.0" 28 | } 29 | }, 30 | { 31 | "identity" : "swift-numerics", 32 | "kind" : "remoteSourceControl", 33 | "location" : "https://github.com/apple/swift-numerics", 34 | "state" : { 35 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 36 | "version" : "1.0.2" 37 | } 38 | }, 39 | { 40 | "identity" : "swift-transformers", 41 | "kind" : "remoteSourceControl", 42 | "location" : "https://github.com/huggingface/swift-transformers", 43 | "state" : { 44 | "revision" : "d42fdae473c49ea216671da8caae58e102d28709", 45 | "version" : "0.1.14" 46 | } 47 | }, 48 | { 49 | "identity" : "vocos-swift", 50 | "kind" : "remoteSourceControl", 51 | "location" : "https://github.com/lucasnewman/vocos-swift.git", 52 | "state" : { 53 | "revision" : "021e7af9d0c0aff9f7b62bf9839c37554287f3af", 54 | "version" : "0.0.1" 55 | } 56 | } 57 | ], 58 | "version" : 2 59 | } 60 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 
| // swift-tools-version: 5.9 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | 4 | import PackageDescription 5 | 6 | let package = Package( 7 | name: "f5-tts-swift", 8 | platforms: [.macOS(.v14), .iOS(.v16)], 9 | products: [ 10 | .library( 11 | name: "F5TTS", 12 | targets: ["F5TTS"] 13 | ) 14 | ], 15 | dependencies: [ 16 | .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.18.0"), 17 | .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"), 18 | .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"), 19 | .package(url: "https://github.com/lucasnewman/vocos-swift.git", from: "0.0.1") 20 | ], 21 | targets: [ 22 | .target( 23 | name: "F5TTS", 24 | dependencies: [ 25 | .product(name: "MLX", package: "mlx-swift"), 26 | .product(name: "MLXNN", package: "mlx-swift"), 27 | .product(name: "MLXFast", package: "mlx-swift"), 28 | .product(name: "MLXFFT", package: "mlx-swift"), 29 | .product(name: "MLXLinalg", package: "mlx-swift"), 30 | .product(name: "MLXRandom", package: "mlx-swift"), 31 | .product(name: "Transformers", package: "swift-transformers"), 32 | .product(name: "Vocos", package: "vocos-swift"), 33 | ], 34 | path: "Sources/F5TTS", 35 | resources: [ 36 | .copy("Resources/test_en_1_ref_short.wav"), 37 | .copy("Resources/mel_filters.npy") 38 | ] 39 | ), 40 | .executableTarget( 41 | name: "f5-tts-generate", 42 | dependencies: [ 43 | "F5TTS", 44 | .product(name: "Vocos", package: "vocos-swift"), 45 | .product(name: "ArgumentParser", package: "swift-argument-parser"), 46 | .product(name: "MLX", package: "mlx-swift"), 47 | ], 48 | path: "Sources/f5-tts-generate" 49 | ) 50 | ] 51 | ) 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # F5 TTS for Swift 3 | 4 | Implementation of 
[F5-TTS](https://arxiv.org/abs/2410.06885) in Swift, using the [MLX Swift](https://github.com/ml-explore/mlx-swift) framework. 5 | 6 | You can listen to a [sample here](https://s3.amazonaws.com/lucasnewman.datasets/f5tts/sample.wav) that was generated in ~11 seconds on an M3 Max MacBook Pro. 7 | 8 | See the [Python repository](https://github.com/lucasnewman/f5-tts-mlx) for additional details on the model architecture. 9 | 10 | This repository is based on the original Pytorch implementation available [here](https://github.com/SWivid/F5-TTS). 11 | 12 | 13 | ## Installation 14 | 15 | The `F5TTS` Swift package can be built and run from Xcode or SwiftPM. 16 | 17 | A pretrained model is available [on Huggingface](https://hf.co/lucasnewman/f5-tts-mlx). 18 | 19 | 20 | ## Usage 21 | 22 | ```swift 23 | import F5TTS 24 | 25 | let f5tts = try await F5TTS.fromPretrained(repoId: "lucasnewman/f5-tts-mlx") 26 | 27 | let generatedAudio = try await f5tts.generate(text: "The quick brown fox jumped over the lazy dog.") 28 | ``` 29 | 30 | The result is an MLXArray with 24kHz audio samples. 31 | 32 | If you want to use your own reference audio sample, make sure it's a mono, 24kHz wav file of around 5-10 seconds: 33 | 34 | ```swift 35 | let generatedAudio = try await f5tts.generate( 36 | text: "The quick brown fox jumped over the lazy dog.", 37 | referenceAudioURL: ..., 38 | referenceAudioText: "This is the caption for the reference audio." 39 | ) 40 | ``` 41 | 42 | You can convert an audio file to the correct format with ffmpeg like this: 43 | 44 | ```bash 45 | ffmpeg -i /path/to/audio.wav -ac 1 -ar 24000 -sample_fmt s16 -t 10 /path/to/output_audio.wav 46 | ``` 47 | 48 | ## Appreciation 49 | 50 | [Yushen Chen](https://github.com/SWivid) for the original Pytorch implementation of F5 TTS and pretrained model. 51 | 52 | [Phil Wang](https://github.com/lucidrains) for the E2 TTS implementation that this model is based on. 
53 | 54 | ## Citations 55 | 56 | ```bibtex 57 | @article{chen-etal-2024-f5tts, 58 | title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching}, 59 | author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen}, 60 | journal={arXiv preprint arXiv:2410.06885}, 61 | year={2024}, 62 | } 63 | ``` 64 | 65 | ```bibtex 66 | @inproceedings{Eskimez2024E2TE, 67 | title = {E2 TTS: Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS}, 68 | author = {Sefik Emre Eskimez and Xiaofei Wang and Manthan Thakker and Canrun Li and Chung-Hsien Tsai and Zhen Xiao and Hemin Yang and Zirun Zhu and Min Tang and Xu Tan and Yanqing Liu and Sheng Zhao and Naoyuki Kanda}, 69 | year = {2024}, 70 | url = {https://api.semanticscholar.org/CorpusID:270738197} 71 | } 72 | ``` 73 | 74 | ## License 75 | 76 | The code in this repository is released under the MIT license as found in the 77 | [LICENSE](LICENSE) file. 78 | -------------------------------------------------------------------------------- /Sources/F5TTS/DiT.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | import MLXNN 4 | 5 | class TextEmbedding: Module { 6 | let text_embed: Embedding 7 | var extraModeling: Bool = false 8 | var precomputeMaxPos: Int = 4096 9 | var freqsCis: MLXArray? 10 | var text_blocks: Sequential? 
11 | 12 | init(textNumEmbeds: Int, textDim: Int, convLayers: Int = 0, convMult: Int = 2) { 13 | self.text_embed = Embedding(embeddingCount: textNumEmbeds + 1, dimensions: textDim) 14 | 15 | if convLayers > 0 { 16 | self.extraModeling = true 17 | self.freqsCis = precomputeFreqsCis(dim: textDim, end: precomputeMaxPos) 18 | self.text_blocks = Sequential( 19 | layers: (0 ..< convLayers).map { _ in ConvNeXtV2Block(dim: textDim, intermediateDim: textDim * convMult) } 20 | ) 21 | } 22 | 23 | super.init() 24 | } 25 | 26 | func callAsFunction(_ inText: MLXArray, seqLen: Int, dropText: Bool = false) -> MLXArray { 27 | var text = inText + MLXArray([1]) 28 | let batchSize = text.shape[0] 29 | let textLen = text.shape[1] 30 | 31 | if textLen > seqLen { 32 | text = text[0..., 0 ..< seqLen] 33 | } 34 | 35 | if textLen < seqLen { 36 | text = MLX.padded(text, widths: [.init((0, 0)), .init((0, seqLen - textLen))], value: MLXArray(0)) 37 | } 38 | 39 | if dropText { 40 | text = MLX.zeros(like: text) 41 | } 42 | 43 | var output = text_embed(text) 44 | 45 | if extraModeling, let freqsCis = freqsCis, let textBlocks = text_blocks { 46 | let batchStart = MLX.zeros([batchSize], type: Int32.self) 47 | let posIdx = getPosEmbedIndices(start: batchStart, length: seqLen, maxPos: precomputeMaxPos) 48 | let textPosEmbed = freqsCis[posIdx] 49 | output = output + textPosEmbed 50 | output = textBlocks(output) 51 | } 52 | 53 | return output 54 | } 55 | } 56 | 57 | class InputEmbedding: Module { 58 | let proj: Linear 59 | let conv_pos_embed: ConvPositionEmbedding 60 | 61 | init(melDim: Int, textDim: Int, outDim: Int) { 62 | self.proj = Linear(melDim * 2 + textDim, outDim) 63 | self.conv_pos_embed = ConvPositionEmbedding(dim: outDim) 64 | super.init() 65 | } 66 | 67 | func callAsFunction( 68 | x: MLXArray, 69 | cond: MLXArray, 70 | textEmbed: MLXArray, 71 | dropAudioCond: Bool = false 72 | ) -> MLXArray { 73 | var cond = cond 74 | if dropAudioCond { 75 | cond = MLX.zeros(like: cond) 76 | } 77 | 78 | let 
combined = MLX.concatenated([x, cond, textEmbed], axis: -1) 79 | var output = proj(combined) 80 | output = conv_pos_embed(output) + output 81 | return output 82 | } 83 | } 84 | 85 | // Transformer backbone using DiT blocks 86 | 87 | public class DiT: Module { 88 | let dim: Int 89 | let time_embed: TimestepEmbedding 90 | let text_embed: TextEmbedding 91 | let input_embed: InputEmbedding 92 | let rotary_embed: RotaryEmbedding 93 | let transformer_blocks: [DiTBlock] 94 | let norm_out: AdaLayerNormZero_Final 95 | let proj_out: Linear 96 | let depth: Int 97 | 98 | init( 99 | dim: Int, 100 | depth: Int = 8, 101 | heads: Int = 8, 102 | dimHead: Int = 64, 103 | dropout: Float = 0.1, 104 | ffMult: Int = 4, 105 | melDim: Int = 100, 106 | textNumEmbeds: Int = 256, 107 | textDim: Int? = nil, 108 | convLayers: Int = 0 109 | ) { 110 | self.dim = dim 111 | let actualTextDim = textDim ?? melDim 112 | self.time_embed = TimestepEmbedding(dim: dim) 113 | self.text_embed = TextEmbedding(textNumEmbeds: textNumEmbeds, textDim: actualTextDim, convLayers: convLayers) 114 | self.input_embed = InputEmbedding(melDim: melDim, textDim: actualTextDim, outDim: dim) 115 | self.rotary_embed = RotaryEmbedding(dim: dimHead) 116 | self.depth = depth 117 | 118 | self.transformer_blocks = (0 ..< depth).map { _ in 119 | DiTBlock(dim: dim, heads: heads, dimHead: dimHead, ffMult: ffMult, dropout: dropout) 120 | } 121 | 122 | self.norm_out = AdaLayerNormZero_Final(dim: dim) 123 | self.proj_out = Linear(dim, melDim) 124 | 125 | super.init() 126 | } 127 | 128 | func callAsFunction( 129 | x: MLXArray, 130 | cond: MLXArray, 131 | text: MLXArray, 132 | time: MLXArray, 133 | dropAudioCond: Bool, 134 | dropText: Bool, 135 | mask: MLXArray? = nil 136 | ) -> MLXArray { 137 | let batchSize = x.shape[0] 138 | let seqLen = x.shape[1] 139 | 140 | let time = (time.ndim == 0) ? 
MLX.repeated(time.expandedDimensions(axis: 0), count: batchSize, axis: 0) : time 141 | let t = time_embed(time) 142 | let textEmbed = text_embed(text, seqLen: seqLen, dropText: dropText) 143 | var x = input_embed(x: x, cond: cond, textEmbed: textEmbed, dropAudioCond: dropAudioCond) 144 | 145 | let rope = rotary_embed.forwardFromSeqLen(seqLen) 146 | 147 | for block in transformer_blocks { 148 | x = block(x, t: t, mask: mask, rope: rope) 149 | } 150 | 151 | x = norm_out(x, emb: t) 152 | return proj_out(x) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /Sources/F5TTS/Duration.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | import MLXNN 4 | 5 | class DurationInputEmbedding: Module { 6 | let proj: Linear 7 | let conv_pos_embed: ConvPositionEmbedding 8 | 9 | init(melDim: Int, textDim: Int, outDim: Int) { 10 | self.proj = Linear(melDim + textDim, outDim) 11 | self.conv_pos_embed = ConvPositionEmbedding(dim: outDim) 12 | super.init() 13 | } 14 | 15 | func callAsFunction( 16 | cond: MLXArray, 17 | textEmbed: MLXArray 18 | ) -> MLXArray { 19 | var output = proj(MLX.concatenated([cond, textEmbed], axis: -1)) 20 | output = conv_pos_embed(output) + output 21 | return output 22 | } 23 | } 24 | 25 | public class DurationBlock: Module { 26 | let attn_norm: LayerNorm 27 | let attn: Attention 28 | let ff_norm: LayerNorm 29 | let ff: FeedForward 30 | 31 | init(dim: Int, heads: Int, dimHead: Int, ffMult: Int = 4, dropout: Float = 0.1) { 32 | self.attn_norm = LayerNorm(dimensions: dim) 33 | self.attn = Attention(dim: dim, heads: heads, dimHead: dimHead, dropout: dropout) 34 | self.ff_norm = LayerNorm(dimensions: dim, eps: 1e-6, affine: false) 35 | self.ff = FeedForward(dim: dim, mult: ffMult, dropout: dropout, approximate: "tanh") 36 | 37 | super.init() 38 | } 39 | 40 | func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, rope: (MLXArray, Float)? 
= nil) -> MLXArray { 41 | let norm = attn_norm(x) 42 | let attnOutput = attn(norm, mask: mask, rope: rope) 43 | var output = x + attnOutput 44 | let normedOutput = ff_norm(output) 45 | let ffOutput = ff(normedOutput) 46 | output = output + ffOutput 47 | return output 48 | } 49 | } 50 | 51 | public class DurationTransformer: Module { 52 | let dim: Int 53 | let text_embed: TextEmbedding 54 | let input_embed: DurationInputEmbedding 55 | let rotary_embed: RotaryEmbedding 56 | let transformer_blocks: [DurationBlock] 57 | let norm_out: RMSNorm 58 | let depth: Int 59 | 60 | init( 61 | dim: Int, 62 | depth: Int = 8, 63 | heads: Int = 8, 64 | dimHead: Int = 64, 65 | dropout: Float = 0.0, 66 | ffMult: Int = 4, 67 | melDim: Int = 100, 68 | textNumEmbeds: Int = 256, 69 | textDim: Int? = nil, 70 | convLayers: Int = 0 71 | ) { 72 | self.dim = dim 73 | let actualTextDim = textDim ?? melDim 74 | self.text_embed = TextEmbedding(textNumEmbeds: textNumEmbeds, textDim: actualTextDim, convLayers: convLayers) 75 | self.input_embed = DurationInputEmbedding(melDim: melDim, textDim: actualTextDim, outDim: dim) 76 | self.rotary_embed = RotaryEmbedding(dim: dimHead) 77 | self.depth = depth 78 | 79 | self.transformer_blocks = (0 ..< depth).map { _ in 80 | DurationBlock(dim: dim, heads: heads, dimHead: dimHead, ffMult: ffMult, dropout: dropout) 81 | } 82 | 83 | self.norm_out = RMSNorm(dimensions: dim) 84 | 85 | super.init() 86 | } 87 | 88 | func callAsFunction( 89 | cond: MLXArray, 90 | text: MLXArray, 91 | mask: MLXArray? 
= nil 92 | ) -> MLXArray { 93 | let seqLen = cond.shape[1] 94 | 95 | let textEmbed = text_embed(text, seqLen: seqLen) 96 | var x = input_embed(cond: cond, textEmbed: textEmbed) 97 | 98 | let rope = rotary_embed.forwardFromSeqLen(seqLen) 99 | 100 | for block in transformer_blocks { 101 | x = block(x, mask: mask, rope: rope) 102 | } 103 | 104 | return norm_out(x) 105 | } 106 | } 107 | 108 | public class DurationPredictor: Module { 109 | enum DurationPredictorError: Error { 110 | case unableToLoadModel 111 | case unableToLoadReferenceAudio 112 | case unableToDetermineDuration 113 | } 114 | 115 | public let melSpec: MelSpec 116 | public let transformer: DurationTransformer 117 | 118 | let dim: Int 119 | let numChannels: Int 120 | let vocabCharMap: [String: Int] 121 | let to_pred: Sequential 122 | 123 | init( 124 | transformer: DurationTransformer, 125 | melSpec: MelSpec, 126 | vocabCharMap: [String: Int] 127 | ) { 128 | self.melSpec = melSpec 129 | self.numChannels = self.melSpec.nMels 130 | self.transformer = transformer 131 | self.dim = transformer.dim 132 | self.vocabCharMap = vocabCharMap 133 | 134 | self.to_pred = Sequential(layers: [ 135 | Linear(dim, 1, bias: false), Softplus() 136 | ]) 137 | 138 | super.init() 139 | } 140 | 141 | func callAsFunction(_ cond: MLXArray, text: [String]) -> MLXArray { 142 | var cond = cond 143 | 144 | // raw wave 145 | 146 | if cond.ndim == 2 { 147 | cond = cond.reshaped([cond.shape[1]]) 148 | cond = melSpec(x: cond) 149 | } 150 | 151 | let batch = cond.shape[0] 152 | let condSeqLen = cond.shape[1] 153 | var lens = MLX.full([batch], values: condSeqLen, type: Int.self) 154 | 155 | // text 156 | 157 | let inputText = listStrToIdx(text, vocabCharMap: vocabCharMap) 158 | let textLens = (inputText .!= -1).sum(axis: -1) 159 | lens = MLX.maximum(textLens, lens) 160 | 161 | var output = transformer(cond: cond, text: inputText) 162 | output = to_pred(output).mean().reshaped([batch, -1]) 163 | output.eval() 164 | 165 | return output 166 | } 
167 | } 168 | -------------------------------------------------------------------------------- /Sources/F5TTS/F5TTS.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Hub 3 | import MLX 4 | import MLXNN 5 | import MLXRandom 6 | import Vocos 7 | 8 | // MARK: - F5TTS 9 | 10 | func odeint_euler(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { 11 | var ys = [y0] 12 | var yCurrent = y0 13 | 14 | for i in 0..<(t.shape[0] - 1) { 15 | let tCurrent = t[i].item(Float.self) 16 | let dt = t[i + 1].item(Float.self) - tCurrent 17 | 18 | let k = fun(tCurrent, yCurrent) 19 | let yNext = yCurrent + dt * k 20 | 21 | ys.append(yNext) 22 | yCurrent = yNext 23 | } 24 | 25 | return MLX.stacked(ys, axis: 0) 26 | } 27 | 28 | func odeint_midpoint(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { 29 | var ys = [y0] 30 | var yCurrent = y0 31 | 32 | for i in 0..<(t.shape[0] - 1) { 33 | let tCurrent = t[i].item(Float.self) 34 | let dt = t[i + 1].item(Float.self) - tCurrent 35 | 36 | let k1 = fun(tCurrent, yCurrent) 37 | let mid = yCurrent + 0.5 * dt * k1 38 | 39 | let k2 = fun(tCurrent + 0.5 * dt, mid) 40 | let yNext = yCurrent + dt * k2 41 | 42 | ys.append(yNext) 43 | yCurrent = yNext 44 | } 45 | 46 | return MLX.stacked(ys, axis: 0) 47 | } 48 | 49 | func odeint_rk4(fun: (Float, MLXArray) -> MLXArray, y0: MLXArray, t: MLXArray) -> MLXArray { 50 | var ys = [y0] 51 | var yCurrent = y0 52 | 53 | for i in 0..<(t.shape[0] - 1) { 54 | let tCurrent = t[i].item(Float.self) 55 | let dt = t[i + 1].item(Float.self) - tCurrent 56 | 57 | let k1 = fun(tCurrent, yCurrent) 58 | let k2 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k1) 59 | let k3 = fun(tCurrent + 0.5 * dt, yCurrent + 0.5 * dt * k2) 60 | let k4 = fun(tCurrent + dt, yCurrent + dt * k3) 61 | 62 | let yNext = yCurrent + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) 63 | 64 | ys.append(yNext) 65 | yCurrent = yNext 66 | } 67 | 68 
| return MLX.stacked(ys) 69 | } 70 | 71 | public class F5TTS: Module { 72 | public enum ODEMethod: String { 73 | case euler 74 | case midpoint 75 | case rk4 76 | } 77 | 78 | enum F5TTSError: Error { 79 | case unableToLoadModel 80 | case unableToLoadReferenceAudio 81 | case unableToDetermineDuration 82 | } 83 | 84 | public let melSpec: MelSpec 85 | public let transformer: DiT 86 | 87 | let dim: Int 88 | let numChannels: Int 89 | let vocabCharMap: [String: Int] 90 | let _durationPredictor: DurationPredictor? 91 | 92 | init( 93 | transformer: DiT, 94 | melSpec: MelSpec, 95 | vocabCharMap: [String: Int], 96 | durationPredictor: DurationPredictor? = nil 97 | ) { 98 | self.melSpec = melSpec 99 | self.numChannels = self.melSpec.nMels 100 | self.transformer = transformer 101 | self.dim = transformer.dim 102 | self.vocabCharMap = vocabCharMap 103 | self._durationPredictor = durationPredictor 104 | 105 | super.init() 106 | } 107 | 108 | private func sample( 109 | cond: MLXArray, 110 | text: [String], 111 | duration: Int? = nil, 112 | lens: MLXArray? = nil, 113 | steps: Int = 8, 114 | method: ODEMethod = .rk4, 115 | cfgStrength: Double = 2.0, 116 | swayCoef: Double? = -1.0, 117 | seed: Int? = nil, 118 | maxDuration: Int = 4096, 119 | vocoder: ((MLXArray) -> MLXArray)? = nil, 120 | progressHandler: ((Double) -> Void)? = nil 121 | ) throws -> (MLXArray, MLXArray) { 122 | MLX.eval(self.parameters()) 123 | 124 | var cond = cond 125 | 126 | // raw wave 127 | 128 | if cond.ndim == 2 { 129 | cond = cond.reshaped([cond.shape[1]]) 130 | cond = self.melSpec(x: cond) 131 | } 132 | 133 | let batch = cond.shape[0] 134 | let condSeqLen = cond.shape[1] 135 | var lens = lens ?? 
MLX.full([batch], values: condSeqLen, type: Int.self) 136 | 137 | // text 138 | 139 | let inputText = listStrToIdx(text, vocabCharMap: vocabCharMap) 140 | let textLens = (inputText .!= -1).sum(axis: -1) 141 | lens = MLX.maximum(textLens, lens) 142 | 143 | var condMask = lensToMask(t: lens) 144 | 145 | // duration 146 | var resolvedDuration: MLXArray? = (duration != nil) ? MLXArray(duration!) : nil 147 | 148 | if resolvedDuration == nil, let durationPredictor = self._durationPredictor { 149 | let estimatedDurationInSeconds = durationPredictor(cond, text: text).item(Float32.self) 150 | resolvedDuration = MLXArray(Int(Double(estimatedDurationInSeconds) * F5TTS.framesPerSecond)) 151 | } 152 | 153 | guard let resolvedDuration else { 154 | throw F5TTSError.unableToDetermineDuration 155 | } 156 | 157 | print("Generating \(Double(resolvedDuration.item(Float32.self)) / F5TTS.framesPerSecond) seconds of audio...") 158 | 159 | var duration = resolvedDuration 160 | duration = MLX.clip(MLX.maximum(lens + 1, duration), min: 0, max: maxDuration) 161 | let maxDuration = duration.max().item(Int.self) 162 | 163 | cond = MLX.padded(cond, widths: [.init((0, 0)), .init((0, maxDuration - condSeqLen)), .init((0, 0))]) 164 | condMask = MLX.padded(condMask, widths: [.init((0, 0)), .init((0, maxDuration - condMask.shape[1]))], value: MLXArray(false)) 165 | condMask = condMask.expandedDimensions(axis: -1) 166 | let stepCond = MLX.where(condMask, cond, MLX.zeros(like: cond)) 167 | 168 | let mask: MLXArray? = (batch > 1) ? 
lensToMask(t: duration) : nil 169 | 170 | // neural ode 171 | 172 | let fn: (Float, MLXArray) -> MLXArray = { t, x in 173 | let pred = self.transformer( 174 | x: x, 175 | cond: stepCond, 176 | text: inputText, 177 | time: MLXArray(t), 178 | dropAudioCond: false, 179 | dropText: false, 180 | mask: mask 181 | ) 182 | 183 | guard cfgStrength > 1e-5 else { 184 | pred.eval() 185 | return pred 186 | } 187 | 188 | let nullPred = self.transformer( 189 | x: x, 190 | cond: stepCond, 191 | text: inputText, 192 | time: MLXArray(t), 193 | dropAudioCond: true, 194 | dropText: true, 195 | mask: mask 196 | ) 197 | 198 | progressHandler?(Double(t)) 199 | 200 | let output = pred + (pred - nullPred) * cfgStrength 201 | output.eval() 202 | 203 | return output 204 | } 205 | 206 | // noise input 207 | 208 | var y0: [MLXArray] = [] 209 | for dur in duration { 210 | if let seed { 211 | MLXRandom.seed(UInt64(seed)) 212 | } 213 | let noise = MLXRandom.normal([dur.item(Int.self), self.numChannels]) 214 | y0.append(noise) 215 | } 216 | let y0Padded = padSequence(y0, paddingValue: 0.0) 217 | 218 | var t = MLXArray.linspace(Float32(0.0), Float32(1.0), count: steps) 219 | 220 | if let coef = swayCoef { 221 | t = t + coef * (MLX.cos(MLXArray(.pi) / 2 * t) - 1 + t) 222 | } 223 | 224 | let odeintFn = switch method { 225 | case .euler: odeint_euler 226 | case .midpoint: odeint_midpoint 227 | case .rk4: odeint_rk4 228 | } 229 | 230 | let trajectory = odeintFn(fn, y0Padded, t) 231 | let sampled = trajectory[-1] 232 | var out = MLX.where(condMask, cond, sampled) 233 | 234 | if let vocoder { 235 | out = vocoder(out) 236 | } 237 | out.eval() 238 | 239 | return (out, trajectory) 240 | } 241 | 242 | public func generate( 243 | text: String, 244 | referenceAudioURL: URL? = nil, 245 | referenceAudioText: String? = nil, 246 | duration: TimeInterval? = nil, 247 | steps: Int = 8, 248 | method: ODEMethod = .rk4, 249 | cfg: Double = 2.0, 250 | sway: Double = -1.0, 251 | speed: Double = 1.0, 252 | seed: Int? 
= nil, 253 | progressHandler: ((Double) -> Void)? = nil 254 | ) async throws -> MLXArray { 255 | print("Loading Vocos model...") 256 | let vocos = try await Vocos.fromPretrained(repoId: "lucasnewman/vocos-mel-24khz-mlx") 257 | 258 | // load the reference audio + text 259 | 260 | var audio: MLXArray 261 | let referenceText: String 262 | 263 | if let referenceAudioURL { 264 | audio = try F5TTS.loadAudioArray(url: referenceAudioURL) 265 | referenceText = referenceAudioText ?? "" 266 | } else { 267 | let refAudioAndCaption = try F5TTS.referenceAudio() 268 | (audio, referenceText) = refAudioAndCaption 269 | } 270 | 271 | let refAudioDuration = Double(audio.shape[0]) / Double(F5TTS.sampleRate) 272 | print("Using reference audio with duration: \(refAudioDuration)") 273 | 274 | // generate the audio 275 | 276 | let normalizedAudio = F5TTS.normalizeAudio(audio: audio) 277 | let processedText = referenceText + " " + text 278 | 279 | let (outputAudio, _) = try self.sample( 280 | cond: normalizedAudio.expandedDimensions(axis: 0), 281 | text: [processedText], 282 | duration: nil, 283 | steps: steps, 284 | method: method, 285 | cfgStrength: cfg, 286 | swayCoef: sway, 287 | seed: seed, 288 | vocoder: vocos.decode 289 | ) { progress in 290 | print("Generation progress: \(progress)") 291 | progressHandler?(progress) 292 | } 293 | 294 | return outputAudio[audio.shape[0]...] 295 | } 296 | } 297 | 298 | // MARK: - Pretrained Models 299 | 300 | public extension F5TTS { 301 | static func fromPretrained(repoId: String, downloadProgress: ((Progress) -> Void)? 
= nil) async throws -> F5TTS { 302 | let modelDirectoryURL = try await Hub.snapshot(from: repoId, matching: ["*.safetensors", "*.txt"]) { progress in 303 | downloadProgress?(progress) 304 | } 305 | return try self.fromPretrained(modelDirectoryURL: modelDirectoryURL) 306 | } 307 | 308 | static func fromPretrained(modelDirectoryURL: URL) throws -> F5TTS { 309 | let modelURL = modelDirectoryURL.appendingPathComponent("model.safetensors") 310 | let modelWeights = try loadArrays(url: modelURL) 311 | 312 | // mel spec 313 | 314 | guard let filterbankURL = Bundle.module.url(forResource: "mel_filters", withExtension: "npy") else { 315 | throw F5TTSError.unableToLoadModel 316 | } 317 | let filterbank = try MLX.loadArray(url: filterbankURL) 318 | 319 | // vocab 320 | 321 | let vocabURL = modelDirectoryURL.appendingPathComponent("vocab.txt") 322 | guard let vocabString = try String(data: Data(contentsOf: vocabURL), encoding: .utf8) else { 323 | throw F5TTSError.unableToLoadModel 324 | } 325 | 326 | let vocabEntries = vocabString.split(separator: "\n").map { String($0) } 327 | let vocab = Dictionary(uniqueKeysWithValues: zip(vocabEntries, vocabEntries.indices)) 328 | 329 | // duration model 330 | 331 | var durationPredictor: DurationPredictor? 
// MARK: - Utilities

public extension F5TTS {
    /// Output sample rate of generated audio, in Hz.
    static var sampleRate: Int = 24000
    /// Hop length of the mel spectrogram, in samples.
    static var hopLength: Int = 256
    /// Mel frames per second of audio (`sampleRate / hopLength`).
    static var framesPerSecond: Double = .init(sampleRate) / Double(hopLength)

    /// Loads an audio file from disk into a 1-D sample array.
    static func loadAudioArray(url: URL) throws -> MLXArray {
        try AudioUtilities.loadAudioFile(url: url)
    }

    /// Returns the bundled reference audio clip together with its transcript.
    /// - Throws: `F5TTSError.unableToLoadReferenceAudio` if the bundled
    ///   resource cannot be located.
    static func referenceAudio() throws -> (MLXArray, String) {
        guard let url = Bundle.module.url(forResource: "test_en_1_ref_short", withExtension: "wav") else {
            throw F5TTSError.unableToLoadReferenceAudio
        }

        return try (
            self.loadAudioArray(url: url),
            "Some call me nature, others call me mother nature."
        )
    }

    /// Scales quiet audio up so its RMS reaches `targetRMS`; audio at or above
    /// the target level is returned unchanged.
    ///
    /// Fix: silent (all-zero) input previously fell into the scaling branch and
    /// divided by an RMS of 0, producing NaNs; it is now returned unchanged.
    static func normalizeAudio(audio: MLXArray, targetRMS: Double = 0.1) -> MLXArray {
        let rms = Double(audio.square().mean().sqrt().item(Float.self))
        guard rms > 0, rms < targetRMS else {
            return audio
        }
        return audio * targetRMS / rms
    }

    /// Heuristic duration estimate for generated speech: scales the reference
    /// clip's frames-per-UTF8-byte ratio by the new text's length and `speed`.
    /// Logs the chosen duration as a side effect.
    static func estimatedDuration(refAudio: MLXArray, refText: String, text: String, speed: Double = 1.0) -> TimeInterval {
        let refDurationInFrames = refAudio.shape[0] / self.hopLength
        let refTextLength = refText.utf8.count
        let genTextLength = text.utf8.count

        let refAudioToTextRatio = Double(refDurationInFrames) / Double(refTextLength)
        let textLength = Double(genTextLength) / speed
        let estimatedDurationInFrames = Int(refAudioToTextRatio * textLength)

        let estimatedDuration = TimeInterval(estimatedDurationInFrames) / Self.framesPerSecond
        print("Using duration of \(estimatedDuration) seconds (\(estimatedDurationInFrames) frames) for generated speech.")

        return estimatedDuration
    }
}
0.0) 445 | 446 | let padded: MLXArray 447 | switch ndim { 448 | case 1: 449 | padded = MLX.padded(t, widths: [.init((0, length - seqLen))], value: paddingValue) 450 | case 2: 451 | padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen))], value: paddingValue) 452 | case 3: 453 | padded = MLX.padded(t, widths: [.init((0, 0)), .init((0, length - seqLen)), .init((0, 0))], value: paddingValue) 454 | default: 455 | fatalError("Unsupported padding dims: \(ndim)") 456 | } 457 | 458 | return padded[0..., .ellipsis] 459 | } 460 | 461 | func padSequence(_ t: [MLXArray], paddingValue: Float = 0) -> MLXArray { 462 | let maxLen = t.map { $0.shape.last ?? 0 }.max() ?? 0 463 | let t = MLX.stacked(t, axis: 0) 464 | return padToLength(t, length: maxLen, value: paddingValue) 465 | } 466 | 467 | func listStrToIdx(_ text: [String], vocabCharMap: [String: Int], paddingValue: Int = -1) -> MLXArray { 468 | let listIdxTensors = text.map { str in str.map { char in vocabCharMap[String(char), default: 0] }} 469 | let mlxArrays = listIdxTensors.map { MLXArray($0) } 470 | let paddedText = padSequence(mlxArrays, paddingValue: Float(paddingValue)) 471 | return paddedText.asType(.int32) 472 | } 473 | -------------------------------------------------------------------------------- /Sources/F5TTS/Modules.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import MLX 3 | import MLXFast 4 | import MLXFFT 5 | import MLXLinalg 6 | import MLXNN 7 | import MLXRandom 8 | 9 | // rotary positional embedding related 10 | 11 | class RotaryEmbedding: Module { 12 | let inv_freq: MLXArray 13 | let interpolationFactor: Float 14 | 15 | init( 16 | dim: Int, 17 | useXpos: Bool = false, 18 | scaleBase: Int = 512, 19 | interpolationFactor: Float = 1.0, 20 | base: Float = 10000.0, 21 | baseRescaleFactor: Float = 1.0 22 | ) { 23 | let adjustedBase = base * pow(baseRescaleFactor, Float(dim) / Float(dim - 2)) 24 | self.inv_freq = 1.0 / 
pow(adjustedBase, MLXArray(stride(from: 0, to: dim, by: 2)).asType(.float32) / Float(dim)) 25 | 26 | assert(interpolationFactor >= 1.0, "Interpolation factor must be >= 1.0") 27 | self.interpolationFactor = interpolationFactor 28 | 29 | super.init() 30 | } 31 | 32 | func forwardFromSeqLen(_ seqLen: Int) -> (MLXArray, Float) { 33 | let t = MLXArray(0.. (MLXArray, Float) { 38 | var freqs = MLX.matmul(t.expandedDimensions(axis: 1).asType(inv_freq.dtype), inv_freq.expandedDimensions(axis: 0)) 39 | freqs = freqs / interpolationFactor 40 | 41 | freqs = MLX.stacked([freqs, freqs], axis: -1) 42 | let newShape = Array( 43 | freqs.shape.dropLast(2) + 44 | [freqs.shape[freqs.shape.count - 2] * freqs.shape[freqs.shape.count - 1]] 45 | ) 46 | freqs = MLX.reshaped(freqs, newShape) 47 | return (freqs, 1.0) 48 | } 49 | } 50 | 51 | func precomputeFreqsCis(dim: Int, end: Int, theta: Float = 10000.0, thetaRescaleFactor: Float = 1.0) -> MLXArray { 52 | let range = MLXArray(stride(from: 0, to: dim, by: 2)).asType(.float32)[0..<(dim / 2)] 53 | let freqs = 1.0 / MLX.pow(MLXArray(theta), range / Float(dim)) 54 | 55 | let t = MLXArray(0.. MLXArray { 65 | let scaleArray = MLX.ones(like: start).asType(.float32) * scale 66 | 67 | let pos = MLX.expandedDimensions(start, axis: 1) + 68 | (MLXArray(0.. MLXArray { 74 | let shape = x.shape 75 | let newShape = Array(shape.dropLast() + [shape.last! 
/ 2, 2]) 76 | let reshapedX = x.reshaped(newShape) 77 | 78 | let x1x2 = reshapedX.split(parts: 2, axis: -1) 79 | let x1 = x1x2[0] 80 | let x2 = x1x2[1] 81 | 82 | let squeezedX1 = x1.squeezed(axis: -1) 83 | let squeezedX2 = x2.squeezed(axis: -1) 84 | 85 | let stackedX = MLX.stacked([-squeezedX2, squeezedX1], axis: -1) 86 | 87 | let finalShape = Array(stackedX.shape.dropLast(2) + [stackedX.shape[stackedX.shape.count - 2] * stackedX.shape[stackedX.shape.count - 1]]) 88 | let result = stackedX.reshaped(finalShape) 89 | 90 | return result 91 | } 92 | 93 | func applyRotaryPosEmb(t: MLXArray, freqs: MLXArray, scale: Float = 1.0) -> MLXArray { 94 | let rotDim = freqs.shape[freqs.shape.count - 1] 95 | let seqLen = t.shape[t.shape.count - 2] 96 | 97 | let freqsTrimmed = freqs[(-seqLen)..., 0...] 98 | let scaleAdjusted = MLXArray(scale) 99 | 100 | var freqsRearranged = freqsTrimmed 101 | if t.ndim == 4 && freqsRearranged.ndim == 3 { 102 | freqsRearranged = freqsRearranged.reshaped([freqsRearranged.shape[0], 1, freqsRearranged.shape[1], freqsRearranged.shape[2]]) 103 | } 104 | 105 | let tRotated = t[.ellipsis, 0.. MLXArray { 138 | logMelSpectrogram(audio: x, nMels: nMels, nFFT: nFFT, hopLength: hopLength, filterbank: filterbank) 139 | } 140 | 141 | public func stft(x: MLXArray, window: MLXArray, nperseg: Int, noverlap: Int? = nil, nfft: Int? = nil) -> MLXArray { 142 | let nfft = nfft ?? nperseg 143 | let noverlap = noverlap ?? 
// global response normalization

/// Global Response Normalization (ConvNeXt-v2): divisively normalizes features
/// by their global L2 response over the sequence axis, with learned
/// scale (`gamma`) / shift (`beta`) and an identity residual.
class GRN: Module {
    var gamma: MLXArray
    var beta: MLXArray

    init(dim: Int) {
        self.gamma = MLX.zeros([1, 1, dim])
        self.beta = MLX.zeros([1, 1, dim])
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // Global response: L2 norm over the sequence (axis 1).
        let response = MLXLinalg.norm(x, ord: 2, axis: 1, keepDims: true)
        // Normalize across channels; epsilon guards against divide-by-zero.
        let normalized = response / (response.mean(axis: -1, keepDims: true) + 1e-6)
        return gamma * (x * normalized) + beta + x
    }
}

// ConvNeXt-v2 block

/// 1-D convolution layer that mirrors `Conv1d` but exposes a `groups`
/// parameter, enabling grouped and depthwise convolutions.
open class GroupableConv1d: Module, UnaryLayer {
    public let weight: MLXArray
    public let bias: MLXArray?
    public let padding: Int
    public let dilation: Int
    public let groups: Int
    public let stride: Int

    convenience init(_ inputChannels: Int, _ outputChannels: Int, kernelSize: Int, padding: Int, dilation: Int, groups: Int) {
        self.init(inputChannels: inputChannels, outputChannels: outputChannels, kernelSize: kernelSize, padding: padding, dilation: dilation, groups: groups)
    }

    public init(
        inputChannels: Int,
        outputChannels: Int,
        kernelSize: Int,
        stride: Int = 1,
        padding: Int = 0,
        dilation: Int = 1,
        groups: Int = 1,
        bias: Bool = true
    ) {
        // Uniform init bound from fan-in. NOTE(review): fan-in ignores `groups`
        // (PyTorch uses inputChannels / groups * kernelSize); harmless in
        // practice since weights are overwritten by loaded checkpoints.
        let scale = sqrt(1 / Float(inputChannels * kernelSize))

        self.weight = uniform(
            low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups]
        )
        self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
    }

    open func callAsFunction(_ x: MLXArray) -> MLXArray {
        var y = conv1d(x, weight, stride: stride, padding: padding, dilation: dilation, groups: groups)
        if let bias {
            y = y + bias
        }
        return y
    }
}

/// ConvNeXt-v2 block: depthwise conv → LayerNorm → pointwise expansion →
/// GELU → GRN → pointwise projection, all inside a residual connection.
class ConvNeXtV2Block: Module, UnaryLayer {
    let dwconv: GroupableConv1d
    let norm: LayerNorm
    let pwconv1: Linear
    let act: GELU
    let grn: GRN
    let pwconv2: Linear

    init(dim: Int, intermediateDim: Int, dilation: Int = 1) {
        // "Same" padding for the fixed kernel size of 7 at this dilation.
        let padding = (dilation * (7 - 1)) / 2
        // Fix: `dilation` was computed into the padding but never passed to the
        // depthwise conv, so dilation > 1 was silently ignored. Behavior for the
        // default dilation == 1 is unchanged.
        self.dwconv = GroupableConv1d(inputChannels: dim, outputChannels: dim, kernelSize: 7, padding: padding, dilation: dilation, groups: dim)
        self.norm = LayerNorm(dimensions: dim, eps: 1e-6)
        self.pwconv1 = Linear(dim, intermediateDim)
        self.act = GELU()
        self.grn = GRN(dim: intermediateDim)
        self.pwconv2 = Linear(intermediateDim, dim)

        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // Residual around the full conv/MLP stack.
        let h = pwconv2(grn(act(pwconv1(norm(dwconv(x))))))
        return x + h
    }
}

// AdaLayerNormZero
// return with modulated x for attn input, and params for later mlp modulation

/// Adaptive LayerNorm (adaLN-Zero): projects the conditioning embedding into
/// six modulation tensors, applies the attention shift/scale pair to the
/// normalized input, and returns the attention gate plus the MLP
/// shift/scale/gate for later use in the transformer block.
class AdaLayerNormZero: Module {
    let silu: SiLU
    let linear: Linear
    let norm: LayerNorm

    init(dim: Int) {
        self.silu = SiLU()
        self.linear = Linear(dim, dim * 6)
        self.norm = LayerNorm(dimensions: dim, eps: 1e-6, affine: false)
        super.init()
    }

    func callAsFunction(_ x: MLXArray, emb: MLXArray) -> (MLXArray, MLXArray, MLXArray, MLXArray, MLXArray) {
        // Six chunks: (shift, scale, gate) for attention, then for the MLP.
        let chunks = linear(silu(emb)).split(parts: 6, axis: 1)
        let shiftMsa = chunks[0]
        let scaleMsa = chunks[1]
        let gateMsa = chunks[2]
        let shiftMlp = chunks[3]
        let scaleMlp = chunks[4]
        let gateMlp = chunks[5]

        let modulatedX = norm(x) * (MLXArray(1) + MLX.expandedDimensions(scaleMsa, axis: 1)) + MLX.expandedDimensions(shiftMsa, axis: 1)
        return (modulatedX, gateMsa, shiftMlp, scaleMlp, gateMlp)
    }
}
// feed forward

/// Position-wise feed-forward network: Linear → GELU → Dropout → Linear.
/// Passing `approximate: "tanh"` selects the tanh GELU approximation.
class FeedForward: Module {
    let ff: Sequential

    init(dim: Int, dimOut: Int? = nil, mult: Int = 4, dropout: Float = 0.0, approximate: String = "none") {
        let innerDim = dim * mult
        let outputDim = dimOut ?? dim

        let activation = GELU(approximation: approximate == "tanh" ? .tanh : .none)

        let projectIn = Sequential(layers: [
            Linear(dim, innerDim),
            activation
        ])

        self.ff = Sequential(layers: [
            projectIn,
            Dropout(p: dropout),
            Linear(innerDim, outputDim)
        ])

        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        return ff(x)
    }
}

// attention

/// Multi-head self-attention with optional rotary position embeddings and an
/// optional boolean sequence mask (true = keep position).
class Attention: Module {
    let dim: Int
    let heads: Int
    let innerDim: Int
    let dropout: Float

    let to_q: Linear
    let to_k: Linear
    let to_v: Linear
    let to_out: Sequential

    init(dim: Int, heads: Int = 8, dimHead: Int = 64, dropout: Float = 0.0) {
        self.dim = dim
        self.heads = heads
        self.innerDim = heads * dimHead
        self.dropout = dropout

        self.to_q = Linear(dim, innerDim)
        self.to_k = Linear(dim, innerDim)
        self.to_v = Linear(dim, innerDim)

        self.to_out = Sequential(layers: [
            Linear(innerDim, dim),
            Dropout(p: dropout)
        ])

        super.init()
    }

    func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, rope: (MLXArray, Float)? = nil) -> MLXArray {
        let batch = x.shape[0]
        let seqLen = x.shape[1]

        var query = to_q(x)
        var key = to_k(x)
        var value = to_v(x)

        if let rope {
            // xpos-style scaling: queries scaled by s, keys by 1/s.
            let (freqs, xposScale) = rope
            query = applyRotaryPosEmb(t: query, freqs: freqs, scale: xposScale)
            key = applyRotaryPosEmb(t: key, freqs: freqs, scale: pow(xposScale, -1.0))
        }

        query = rearrangeQuery(query, heads: heads)
        key = rearrangeQuery(key, heads: heads)
        value = rearrangeQuery(value, heads: heads)

        // Broadcast the padding mask across heads: [B, N] -> [B, heads, 1, N].
        var attnMask: MLXArray? = nil
        if let mask = mask {
            let reshapedMask = mask.reshaped([mask.shape[0], 1, 1, mask.shape[1]])
            attnMask = MLX.repeated(reshapedMask, count: heads, axis: 1)
        }

        let scaleFactor = 1.0 / sqrt(Double(query.shape[query.shape.count - 1]))
        var output = MLXFast.scaledDotProductAttention(queries: query, keys: key, values: value, scale: Float(scaleFactor), mask: attnMask)

        output = output.transposed(axes: [0, 2, 1, 3]).reshaped([batch, seqLen, -1])
        output = to_out(output)

        if let mask = mask {
            // Zero out padded positions in the projected output.
            // Fix: the previous code computed `MLX.where(mask, !mask, 0)`,
            // which replaced the attention result with a negated-mask array and
            // discarded all computed values. Keep `output` where the mask is
            // true and zero elsewhere.
            let keepMask = mask.reshaped([batch, seqLen, 1])
            output = MLX.where(keepMask, output, 0.0)
        }

        return output
    }

    /// Splits the packed head dimension: [B, N, H*D] -> [B, H, N, D].
    private func rearrangeQuery(_ query: MLXArray, heads: Int) -> MLXArray {
        let batchSize = query.shape[0]
        let seqLength = query.shape[1]
        let headDim = query.shape[2] / heads
        return query.reshaped([batchSize, seqLength, heads, headDim]).transposed(axes: [0, 2, 1, 3])
    }
}
= 0.1) { 512 | self.attn_norm = AdaLayerNormZero(dim: dim) 513 | self.attn = Attention(dim: dim, heads: heads, dimHead: dimHead, dropout: dropout) 514 | self.ff_norm = LayerNorm(dimensions: dim, eps: 1e-6, affine: false) 515 | self.ff = FeedForward(dim: dim, mult: ffMult, dropout: dropout, approximate: "tanh") 516 | 517 | super.init() 518 | } 519 | 520 | func callAsFunction(_ x: MLXArray, t: MLXArray, mask: MLXArray? = nil, rope: (MLXArray, Float)? = nil) -> MLXArray { 521 | let (norm, gateMsa, shiftMlp, scaleMlp, gateMlp) = attn_norm(x, emb: t) 522 | let attnOutput = attn(norm, mask: mask, rope: rope) 523 | var output = x + gateMsa.expandedDimensions(axis: 1) * attnOutput 524 | let normedOutput = ff_norm(output) * (1 + scaleMlp.expandedDimensions(axis: 1)) + shiftMlp.expandedDimensions(axis: 1) 525 | let ffOutput = ff(normedOutput) 526 | output = output + MLX.expandedDimensions(gateMlp, axis: 1) * ffOutput 527 | return output 528 | } 529 | } 530 | 531 | // time step conditioning embedding 532 | 533 | class TimestepEmbedding: Module { 534 | let time_embed: SinusPositionEmbedding 535 | let time_mlp: Sequential 536 | 537 | init(dim: Int, freqEmbedDim: Int = 256) { 538 | self.time_embed = SinusPositionEmbedding(dim: freqEmbedDim) 539 | 540 | self.time_mlp = Sequential( 541 | layers: [Linear(freqEmbedDim, dim), 542 | SiLU(), 543 | Linear(dim, dim)] 544 | ) 545 | 546 | super.init() 547 | } 548 | 549 | func callAsFunction(_ timestep: MLXArray) -> MLXArray { 550 | let timeHidden = time_embed(timestep) 551 | let time = time_mlp(timeHidden) 552 | return time 553 | } 554 | } 555 | -------------------------------------------------------------------------------- /Sources/F5TTS/Resources/mel_filters.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasnewman/f5-tts-swift/36899100db60159022fa20620facfa0404871c64/Sources/F5TTS/Resources/mel_filters.npy 
import ArgumentParser
import F5TTS
import Foundation
import MLX
import Vocos

/// Command-line entry point: synthesizes speech from text with F5-TTS and
/// writes the result to a WAV file.
@main
struct GenerateAudio: AsyncParsableCommand {
    @Argument(help: "Text to generate speech from")
    var text: String

    @Option(name: .long, help: "Duration of the generated audio in seconds")
    var duration: Double?

    @Option(name: .long, help: "Path to the reference audio file")
    var refAudioPath: String?

    @Option(name: .long, help: "Text spoken in the reference audio")
    var refAudioText: String?

    @Option(name: .long, help: "Model name to use")
    var model: String = "lucasnewman/f5-tts-mlx"

    @Option(name: .long, help: "Output path for the generated audio")
    var outputPath: String = "output.wav"

    @Option(name: .long, help: "The number of steps to use for ODE sampling")
    var steps: Int = 8

    @Option(name: .long, help: "Method to use for ODE sampling. Options are 'euler', 'midpoint', and 'rk4'.")
    var method: String = "rk4"

    @Option(name: .long, help: "Strength of classifier free guidance")
    var cfg: Double = 2.0

    @Option(name: .long, help: "Coefficient for sway sampling")
    var sway: Double = -1.0

    @Option(name: .long, help: "Speed factor for the duration heuristic")
    var speed: Double = 1.0

    @Option(name: .long, help: "Seed for noise generation")
    var seed: Int?

    func run() async throws {
        // Validate the ODE method up front. Fix: this was previously
        // force-unwrapped at the call site, crashing the process on any
        // unrecognized --method value instead of reporting a usage error.
        guard let odeMethod = F5TTS.ODEMethod(rawValue: method) else {
            throw ValidationError("Unknown ODE method '\(method)'. Options are 'euler', 'midpoint', and 'rk4'.")
        }

        print("Loading F5-TTS model...")
        let f5tts = try await F5TTS.fromPretrained(repoId: model) { progress in
            print(" -- \(progress.completedUnitCount) of \(progress.totalUnitCount)")
        }

        let startTime = Date()

        let generatedAudio = try await f5tts.generate(
            text: text,
            referenceAudioURL: refAudioPath.map { URL(filePath: $0) },
            referenceAudioText: refAudioText,
            duration: duration,
            steps: steps,
            method: odeMethod,
            cfg: cfg,
            sway: sway,
            speed: speed,
            seed: seed
        )

        let elapsedTime = Date().timeIntervalSince(startTime)
        print("Generated \(Double(generatedAudio.shape[0]) / Double(F5TTS.sampleRate)) seconds of audio in \(elapsedTime) seconds.")

        try AudioUtilities.saveAudioFile(url: URL(filePath: outputPath), samples: generatedAudio)
        print("Saved audio to: \(outputPath)")
    }
}