├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── help_request.md ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENCE ├── README.md └── src ├── lib.rs ├── models ├── coqui.rs ├── gtts │ ├── languages.rs │ ├── mod.rs │ └── url.rs ├── meta │ ├── bs1770.rs │ ├── mod.rs │ └── utils.rs ├── mod.rs ├── msedge.rs ├── parler │ ├── mod.rs │ └── model.rs └── tts_rs.rs ├── test.rs └── utils.rs /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Report a bug to help improve the project 4 | title: "[Bug]: " 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | 12 | A clear and concise description of the bug. 13 | 14 | ## Steps to Reproduce 15 | 16 | 1. Go to '...' 17 | 2. Click on '...' 18 | 3. Scroll down to '...' 19 | 4. See the error 20 | 21 | ## Expected Behavior 22 | 23 | A clear and concise description of what you expected to happen. 24 | 25 | ## Screenshots (if applicable) 26 | 27 | If applicable, add screenshots to help explain the problem. 28 | 29 | ## System Information 30 | 31 | - **OS**: [e.g., Ubuntu 22.04, Windows 10] 32 | - **Compiler/Version**: [e.g., Clang 17, GCC 13] 33 | - **Hardware**: [e.g., x86_64, ARM] 34 | 35 | ### TIPS 36 | when giving the hardware information we recommend using `hostnamectl` if you're on a Linux system 37 | 38 | ## Additional Context 39 | 40 | Add any other context about the problem here. -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest a new feature or improvement for the project 4 | title: "[Feature]: " 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | 12 | A clear and concise description of the feature or improvement you’d like to see. 
13 | 14 | ## Motivation 15 | 16 | Why is this feature important? What problem does it solve? 17 | 18 | ## Proposed Solution 19 | 20 | Describe how this feature could be implemented or any ideas you have. 21 | 22 | ## Alternatives Considered 23 | 24 | Have you considered any alternatives? If so, please describe them. 25 | 26 | ## Additional Context 27 | 28 | Add any other context, screenshots, or references related to the feature request. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Help Request 3 | about: Request help or clarification on using the project 4 | title: "[Help]: " 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | 12 | A clear and concise description of what you need help with. 13 | 14 | ## What I’ve Tried 15 | 16 | Describe what steps you’ve already taken to solve the issue on your own. 17 | 18 | ## Expected Outcome 19 | 20 | What you were expecting to happen. 21 | 22 | ## System Information (if applicable) 23 | 24 | - **OS**: [e.g., Ubuntu 22.04, Windows 10] 25 | - **Compiler/Version**: [e.g., Clang 17, GCC 13] 26 | - **Hardware**: [e.g., x86_64, ARM] 27 | 28 | ### TIPS 29 | when giving the hardware information we recommend using `hostnamectl` if you're on a Linux system 30 | 31 | 32 | ## Additional Context 33 | 34 | Any extra details, screenshots, or relevant links. 
35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | # /src/main.rs 3 | .idea 4 | .vscode 5 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "natural-tts" 3 | version = "0.2.1" 4 | edition = "2021" 5 | license = "MIT" 6 | authors = ["Creative Coders "] 7 | description = "High-level bindings to a variety of text-to-speech libraries." 8 | repository = "https://github.com/CodersCreative/natural-tts" 9 | readme = "README.md" 10 | publish = true 11 | keywords = ["text-to-speech", "coqui-ai", "gtts", "parler", "tts"] 12 | categories = ["science", "api-bindings"] 13 | 14 | [features] 15 | meta = ["dep:rand", "dep:tracing-chrome", "dep:tracing-subscriber", "dep:serde_json", "dep:hf-hub", "dep:candle-core", "dep:candle-nn", "dep:candle-transformers"] 16 | tts-rs = ["dep:tts"] 17 | coqui = ["dep:pyo3"] 18 | parler = ["dep:tokenizers", "meta"] 19 | gtts = ["dep:percent-encoding", "dep:minreq"] 20 | msedge = ["dep:msedge-tts"] 21 | py_tts = ["coqui"] 22 | non_py_tts = ["parler", "msedge", "tts-rs", "gtts"] 23 | default = ["gtts"] 24 | test = ["default", "coqui"] 25 | 26 | [dependencies] 27 | candle-core = { version = "0.8.3", optional = true } 28 | candle-nn = { version = "0.8.3", optional = true } 29 | candle-transformers = { version = "0.8.3", optional = true} 30 | derive_builder = { version = "0.20.2"} 31 | hf-hub = {version = "0.4.2", optional = true} 32 | hound = {version = "3.5.1"} 33 | msedge-tts = {version = "0.2.4", optional = true} 34 | pyo3 = { version = "0.23.5", features = ["auto-initialize"], optional = true} 35 | rand = {version = "0.8.5", optional = true} 36 | rodio = {version = "0.20.1"} 37 | serde = "1.0.218" 38 | percent-encoding = {version = "2.1.0", optional = true} 
39 | minreq = { version="2.0.3", features=["https"], optional = true } 40 | serde_json = {version = "1.0.140", optional = true} 41 | thiserror = {version = "2.0.12"} 42 | tokenizers = {version = "0.21.0", optional = true} 43 | tracing-chrome = {version = "0.7.2", optional = true} 44 | tracing-subscriber = {version = "0.3.19", optional = true} 45 | tts = {version = "0.26.3", optional = true} 46 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | ***Copyright (c) 2024-2025 natural-tts*** 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Natural TTS [![Rust](https://img.shields.io/badge/Rust-%23000000.svg?e&logo=rust&logoColor=white)]() 2 | 3 | ![Linux](https://img.shields.io/badge/Linux-FCC624?style=for-the-badge&logo=linux&logoColor=black) ![Windows](https://img.shields.io/badge/Windows-0078D6?style=for-the-badge&logo=windows&logoColor=white) ![macOS](https://img.shields.io/badge/mac%20os-000000?style=for-the-badge&logo=macos&logoColor=F0F0F0) 4 | 5 | #### Natural TTS (natural-tts) is a Rust crate for easily implementing Text-To-Speech into your Rust programs. 6 | 7 | ### To Do: 8 | * [ ] Add support for [Piper TTS](https://github.com/rhasspy/piper). 9 | * [ ] Remove all pyo3 usage. 10 | 11 | ### Available TTS Engines / AIs: 12 | [Parler TTS](https://github.com/huggingface/parler-tts)\ 13 | [Google Gtts](https://github.com/pndurette/gTTS)\ 14 | [TTS-RS](https://github.com/ndarilek/tts-rs)\ 15 | [MSEdge TTS](https://github.com/hs-CN/msedge-tts)\ 16 | [MetaVoice TTS](https://github.com/metavoiceio/metavoice-src)\ 17 | [Coqui TTS](https://github.com/coqui-ai/TTS) 18 | 19 | ### Example of saying something using Gtts but initializing every model. 20 | 21 | ```Rust 22 | use std::error::Error; 23 | use natural_tts::{*, models::{gtts::GttsModel, tts_rs::TtsModel, parler::ParlerModel, msedge::MSEdgeModel, meta::MetaModel}}; 24 | 25 | fn main() -> Result<(), Box<dyn Error>>{ 26 | // Create the NaturalTts using the Builder pattern 27 | let mut natural = NaturalTtsBuilder::default() 28 | .default_model(Model::Gtts) 29 | .gtts_model(GttsModel::default()) 30 | .parler_model(ParlerModel::default()) 31 | .msedge_model(MSEdgeModel::default()) 32 | .meta_model(MetaModel::default()) 33 | .tts_model(TtsModel::default()) 34 | .build()?; 35 | 36 | // Use the pre-included function to say a message using the default_model. 
37 | let _ = natural.say_auto("Hello, World!".to_string())?; 38 | Ok(()) } 39 | 40 | ``` 41 | 42 | ### Example of saying something using Meta Voice. 43 | 44 | ```Rust 45 | use std::error::Error; 46 | use natural_tts::{*, models::meta::MetaModel}; 47 | 48 | fn main() -> Result<(), Box<dyn Error>>{ 49 | // Create the NaturalTts struct using the builder pattern. 50 | let mut natural = NaturalTtsBuilder::default() 51 | .meta_model(MetaModel::default()) 52 | .default_model(Model::Meta) 53 | .build()?; 54 | 55 | // Use the pre-included function to say a message using the default_model. 56 | let _ = natural.say_auto("Hello, World!".to_string())?; 57 | Ok(()) 58 | } 59 | 60 | ``` 61 | 62 | ### Example of saying something using Parler. 63 | 64 | ```Rust 65 | use std::error::Error; 66 | use natural_tts::{*, models::parler::ParlerModel}; 67 | 68 | fn main() -> Result<(), Box<dyn Error>>{ 69 | // Create the NaturalTts using the Builder pattern 70 | let mut natural = NaturalTtsBuilder::default() 71 | .parler_model(ParlerModel::default()) 72 | .default_model(Model::Parler) 73 | .build()?; 74 | 75 | // Use the pre-included function to say a message using the default_model. 76 | let _ = natural.say_auto("Hello, World!".to_string())?; 77 | Ok(()) } 78 | 79 | ``` 80 | 81 | ### Example of saying something using Gtts. 82 | 83 | ```Rust 84 | use std::error::Error; 85 | use natural_tts::{*, models::gtts::GttsModel}; 86 | 87 | fn main() -> Result<(), Box<dyn Error>>{ 88 | // Create the NaturalTts struct using the builder pattern. 89 | let mut natural = NaturalTtsBuilder::default() 90 | .gtts_model(GttsModel::default()) 91 | .default_model(Model::Gtts) 92 | .build()?; 93 | 94 | // Use the pre-included function to say a message using the default_model. 95 | let _ = natural.say_auto("Hello, World!".to_string())?; 96 | Ok(()) 97 | } 98 | 99 | ``` 100 | 101 | ### Example of saying something using MSEdge. 
102 | 103 | ```Rust 104 | use std::error::Error; 105 | use natural_tts::{*, models::msedge::MSEdgeModel}; 106 | 107 | fn main() -> Result<(), Box<dyn Error>>{ 108 | 109 | // Create the NaturalTts struct using the builder pattern. 110 | let mut natural = NaturalTtsBuilder::default() 111 | .msedge_model(MSEdgeModel::default()) 112 | .default_model(Model::MSEdge) 113 | .build()?; 114 | 115 | // Use the pre-included function to say a message using the default_model. 116 | let _ = natural.say_auto("Hello, World!".to_string())?; 117 | Ok(()) 118 | } 119 | 120 | ``` 121 | 122 | ### Example of saying something using TTS. 123 | 124 | ```Rust 125 | use std::error::Error; 126 | use natural_tts::{*, models::tts_rs::TtsModel}; 127 | 128 | fn main() -> Result<(), Box<dyn Error>>{ 129 | 130 | // Create the NaturalTts struct using the builder pattern. 131 | let mut natural = NaturalTtsBuilder::default() 132 | .tts_model(TtsModel::default()) 133 | .default_model(Model::TTS) 134 | .build()?; 135 | 136 | // Use the pre-included function to say a message using the default_model. 137 | let _ = natural.say_auto("Hello, World!".to_string())?; 138 | Ok(()) 139 | } 140 | 141 | ``` 142 | 143 | ### Example of saying something using Coqui Tts. 144 | #### Disclaimer : Currently only in test feature. 145 | 146 | ```Rust 147 | use std::error::Error; 148 | use natural_tts::{*, models::coqui::CoquiModel}; 149 | 150 | fn main() -> Result<(), Box<dyn Error>>{ 151 | 152 | // Create the NaturalTts struct using the builder pattern. 153 | let mut natural = NaturalTtsBuilder::default() 154 | .coqui_model(CoquiModel::default()) 155 | .default_model(Model::Coqui) 156 | .build().unwrap(); 157 | 158 | // Use the pre-included function to say a message using the default_model. 159 | let _ = natural.say_auto("Hello, World!".to_string())?; 160 | Ok(()) 161 | } 162 | 163 | ``` 164 | 165 | ## Contributing. 166 | 167 | Pull requests are welcome. For major changes, please open an issue first 168 | to discuss what you would like to change. 
169 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024-2025 natural-tts 2 | // 3 | // Permission is hereby granted, free of charge, to any person obtaining a copy 4 | // of this software and associated documentation files (the "Software"), to deal 5 | // in the Software without restriction, including without limitation the rights 6 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | // copies of the Software, and to permit persons to whom the Software is 8 | // furnished to do so, subject to the following conditions: 9 | // 10 | // The above copyright notice and this permission notice shall be included in all 11 | // copies or substantial portions of the Software. 12 | // 13 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | // SOFTWARE. 
20 | pub mod models; 21 | mod test; 22 | mod utils; 23 | 24 | use crate::models::NaturalModelTrait; 25 | use derive_builder::Builder; 26 | use std::error::Error; 27 | use thiserror::Error as TError; 28 | 29 | #[cfg(feature = "gtts")] 30 | use crate::models::gtts; 31 | #[cfg(feature = "meta")] 32 | use crate::models::meta; 33 | #[cfg(feature = "msedge")] 34 | use crate::models::msedge; 35 | #[cfg(feature = "parler")] 36 | use crate::models::parler; 37 | #[cfg(feature = "tts-rs")] 38 | use crate::models::tts_rs::TtsModel; 39 | 40 | #[derive(Builder, Clone, Default)] 41 | #[builder(setter(into))] 42 | pub struct NaturalTts { 43 | pub default_model: Option, 44 | #[cfg(feature = "tts-rs")] 45 | #[builder(default = "None")] 46 | pub tts_model: Option, 47 | 48 | #[cfg(feature = "parler")] 49 | #[builder(default = "None")] 50 | pub parler_model: Option, 51 | 52 | #[cfg(feature = "coqui")] 53 | #[builder(default = "None")] 54 | pub coqui_model: Option, 55 | 56 | #[cfg(feature = "gtts")] 57 | #[builder(default = "None")] 58 | pub gtts_model: Option, 59 | 60 | #[cfg(feature = "msedge")] 61 | #[builder(default = "None")] 62 | pub msedge_model: Option, 63 | 64 | #[cfg(feature = "meta")] 65 | #[builder(default = "None")] 66 | pub meta_model: Option, 67 | } 68 | 69 | impl NaturalModelTrait for NaturalTts { 70 | type SynthesizeType = f32; 71 | 72 | fn say(&mut self, message: String) -> Result<(), Box> { 73 | if let Some(model) = &self.default_model { 74 | return match model { 75 | #[cfg(feature = "coqui")] 76 | Model::Coqui => match &mut self.coqui_model { 77 | Some(x) => x.say(message), 78 | None => Err(Box::new(TtsError::NotLoaded)), 79 | }, 80 | #[cfg(feature = "parler")] 81 | Model::Parler => match &mut self.parler_model { 82 | Some(x) => x.say(message), 83 | None => Err(Box::new(TtsError::NotLoaded)), 84 | }, 85 | #[cfg(feature = "tts-rs")] 86 | Model::TTS => match &mut self.tts_model { 87 | Some(x) => x.say(message), 88 | None => Err(Box::new(TtsError::NotLoaded)), 89 | }, 
90 | #[cfg(feature = "msedge")] 91 | Model::MSEdge => match &mut self.msedge_model { 92 | Some(x) => x.say(message), 93 | None => Err(Box::new(TtsError::NotLoaded)), 94 | }, 95 | #[cfg(feature = "meta")] 96 | Model::Meta => match &mut self.meta_model { 97 | Some(x) => x.say(message), 98 | None => Err(Box::new(TtsError::NotLoaded)), 99 | }, 100 | #[cfg(feature = "gtts")] 101 | _ => match &mut self.gtts_model { 102 | Some(x) => x.say(message), 103 | None => Err(Box::new(TtsError::NotLoaded)), 104 | }, 105 | }; 106 | } 107 | 108 | Err(Box::new(TtsError::NoDefaultModel)) 109 | } 110 | 111 | fn synthesize( 112 | &mut self, 113 | message: String, 114 | ) -> Result, Box> { 115 | if let Some(model) = &self.default_model { 116 | return match model { 117 | #[cfg(feature = "coqui")] 118 | Model::Coqui => match &mut self.coqui_model { 119 | Some(x) => x.synthesize(message), 120 | None => Err(Box::new(TtsError::NotLoaded)), 121 | }, 122 | #[cfg(feature = "parler")] 123 | Model::Parler => match &mut self.parler_model { 124 | Some(x) => x.synthesize(message), 125 | None => Err(Box::new(TtsError::NotLoaded)), 126 | }, 127 | #[cfg(feature = "tts-rs")] 128 | Model::TTS => match &mut self.tts_model { 129 | Some(x) => x.synthesize(message), 130 | None => Err(Box::new(TtsError::NotLoaded)), 131 | }, 132 | #[cfg(feature = "msedge")] 133 | Model::MSEdge => match &mut self.msedge_model { 134 | Some(x) => x.synthesize(message), 135 | None => Err(Box::new(TtsError::NotLoaded)), 136 | }, 137 | #[cfg(feature = "meta")] 138 | Model::Meta => match &mut self.meta_model { 139 | Some(x) => x.synthesize(message), 140 | None => Err(Box::new(TtsError::NotLoaded)), 141 | }, 142 | #[cfg(feature = "gtts")] 143 | _ => match &mut self.gtts_model { 144 | Some(x) => x.synthesize(message), 145 | None => Err(Box::new(TtsError::NotLoaded)), 146 | }, 147 | }; 148 | } 149 | 150 | Err(Box::new(TtsError::NoDefaultModel)) 151 | } 152 | 153 | fn save(&mut self, message: String, path: String) -> Result<(), Box> { 154 
| if let Some(model) = &self.default_model { 155 | return match model { 156 | #[cfg(feature = "coqui")] 157 | Model::Coqui => match &mut self.coqui_model { 158 | Some(x) => x.save(message, path), 159 | None => Err(Box::new(TtsError::NotLoaded)), 160 | }, 161 | #[cfg(feature = "parler")] 162 | Model::Parler => match &mut self.parler_model { 163 | Some(x) => x.save(message, path), 164 | None => Err(Box::new(TtsError::NotLoaded)), 165 | }, 166 | #[cfg(feature = "tts-rs")] 167 | Model::TTS => match &mut self.tts_model { 168 | Some(x) => x.save(message, path), 169 | None => Err(Box::new(TtsError::NotLoaded)), 170 | }, 171 | #[cfg(feature = "msedge")] 172 | Model::MSEdge => match &mut self.msedge_model { 173 | Some(x) => x.save(message, path), 174 | None => Err(Box::new(TtsError::NotLoaded)), 175 | }, 176 | #[cfg(feature = "meta")] 177 | Model::Meta => match &mut self.meta_model { 178 | Some(x) => x.save(message, path), 179 | None => Err(Box::new(TtsError::NotLoaded)), 180 | }, 181 | #[cfg(feature = "gtts")] 182 | _ => match &mut self.gtts_model { 183 | Some(x) => x.save(message, path), 184 | None => Err(Box::new(TtsError::NotLoaded)), 185 | }, 186 | }; 187 | } 188 | 189 | Err(Box::new(TtsError::NoDefaultModel)) 190 | } 191 | } 192 | 193 | impl NaturalTts { 194 | pub fn say_auto(&mut self, message: String) -> Result<(), Box> { 195 | self.say(message) 196 | } 197 | 198 | pub fn save_auto(&mut self, message: String, path: String) -> Result<(), Box> { 199 | self.save(message, path) 200 | } 201 | 202 | pub fn synthesize_auto( 203 | &mut self, 204 | message: String, 205 | ) -> Result, Box> { 206 | self.synthesize(message) 207 | } 208 | } 209 | 210 | #[derive(Default, Clone)] 211 | pub enum Model { 212 | #[cfg(feature = "coqui")] 213 | Coqui, 214 | 215 | #[cfg(feature = "parler")] 216 | Parler, 217 | 218 | #[cfg(feature = "tts-rs")] 219 | TTS, 220 | 221 | #[cfg(feature = "msedge")] 222 | MSEdge, 223 | 224 | #[cfg(feature = "meta")] 225 | Meta, 226 | 227 | #[cfg(feature = 
"gtts")] 228 | #[default] 229 | Gtts, 230 | } 231 | 232 | #[derive(Debug, TError, Clone)] 233 | pub enum TtsError { 234 | #[error("Not Supported")] 235 | NotSupported, 236 | #[error("Operation failed")] 237 | OperationFailed, 238 | #[error("Model Not Loaded")] 239 | NotLoaded, 240 | #[error("Didn't Save")] 241 | NotSaved, 242 | #[error("Default model not set")] 243 | NoDefaultModel, 244 | #[error("Tensor Error")] 245 | Tensor, 246 | #[error("No Tokenizer Key")] 247 | NoTokenizerKey, 248 | #[error("Json Error")] 249 | Json, 250 | } 251 | -------------------------------------------------------------------------------- /src/models/coqui.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use pyo3::prelude::*; 3 | 4 | #[derive(Debug)] 5 | pub struct CoquiModel { 6 | model: Py, 7 | device: String, 8 | } 9 | 10 | impl Clone for CoquiModel { 11 | fn clone(&self) -> Self { 12 | return Python::with_gil(|py| -> Self { 13 | return Self { 14 | model: self.model.clone_ref(py), 15 | device: self.device.clone(), 16 | }; 17 | }); 18 | } 19 | } 20 | 21 | impl CoquiModel { 22 | pub fn new(model_name: String, use_gpu: bool) -> Result> { 23 | let m = Python::with_gil(|py| -> Result> { 24 | let torch = py.import("torch")?; 25 | let tts = py.import("TTS.api")?; 26 | 27 | let cuda: bool = torch 28 | .getattr("cuda")? 29 | .getattr("is_available")? 30 | .call0()? 31 | .extract()?; 32 | 33 | let device: String = if cuda && use_gpu { 34 | "cuda:0".to_string() 35 | } else { 36 | "cpu".to_string() 37 | }; 38 | 39 | let model = tts 40 | .getattr("TTS")? 41 | .call1((("model_name", model_name), ("progress_bar", false)))? 42 | .getattr("to")? 43 | .call1((device.clone(), ("return_tensors", "pt")))? 
44 | .unbind(); 45 | 46 | return Ok(Self { model, device }); 47 | }); 48 | 49 | return m; 50 | } 51 | 52 | pub fn generate(&self, message: String, path: String) -> Result<(), Box> { 53 | return Python::with_gil(|py| -> Result<(), Box> { 54 | self.model 55 | .getattr(py, "tts_to_file")? 56 | .call1(py, (("text", message), ("file_path", path)))?; 57 | Ok(()) 58 | }); 59 | } 60 | } 61 | 62 | impl Default for CoquiModel { 63 | fn default() -> Self { 64 | return Self::new("tts_models/en/ljspeech/vits".to_string(), true).unwrap(); 65 | } 66 | } 67 | 68 | impl NaturalModelTrait for CoquiModel { 69 | type SynthesizeType = f32; 70 | 71 | fn save(&mut self, message: String, path: String) -> Result<(), Box> { 72 | let _ = self.generate(message, path.clone())?; 73 | did_save(path.as_str()) 74 | } 75 | 76 | fn say(&mut self, message: String) -> Result<(), Box> { 77 | speak_model(self, message) 78 | } 79 | 80 | fn synthesize( 81 | &mut self, 82 | message: String, 83 | ) -> Result, Box> { 84 | synthesize_model(self, message) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/models/gtts/languages.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | /// Enum containing all the languages supported by the GTTS API 4 | #[derive(Debug, Clone)] 5 | pub enum Languages { 6 | /// ISO code: af 7 | Afrikaans, 8 | /// ISO code: ar 9 | Arabic, 10 | /// ISO code: bg 11 | Bulgarian, 12 | /// ISO code: bn 13 | Bengali, 14 | /// ISO code: bs 15 | Bosnian, 16 | /// ISO code: ca 17 | Catalan, 18 | /// ISO code: cs 19 | Czech, 20 | /// ISO code: cy 21 | Welsh, 22 | /// ISO code: da 23 | Danish, 24 | /// ISO code: de 25 | German, 26 | /// ISO code: el 27 | Greek, 28 | /// ISO code: en 29 | English, 30 | /// ISO code: eo 31 | Esperanto, 32 | /// ISO code: es 33 | Spanish, 34 | /// ISO code: et 35 | Estonian, 36 | /// ISO code: fi 37 | Finnish, 38 | /// ISO code: fr 39 | French, 40 | 
/// ISO code: gu 41 | Gujarati, 42 | /// ISO code: hi 43 | Hindi, 44 | /// ISO code: hr 45 | Croatian, 46 | /// ISO code: hu 47 | Hungarian, 48 | /// ISO code: hy 49 | Armenian, 50 | /// ISO code: id 51 | Indonesian, 52 | /// ISO code: is 53 | Icelandic, 54 | /// ISO code: it 55 | Italian, 56 | /// ISO code: ja 57 | Japanese, 58 | /// ISO code: jw 59 | Javanese, 60 | /// ISO code: km 61 | Khmer, 62 | /// ISO code: kn 63 | Kannada, 64 | /// ISO code: ko 65 | Korean, 66 | /// ISO code: la 67 | Latin, 68 | /// ISO code: lv 69 | Latvian, 70 | /// ISO code: mk 71 | Macedonian, 72 | /// ISO code: ml 73 | Malayalam, 74 | /// ISO code: mr 75 | Marathi, 76 | /// ISO code: my 77 | MyanmarAKABurmese, 78 | /// ISO code: ne 79 | Nepali, 80 | /// ISO code: nl 81 | Dutch, 82 | /// ISO code: no 83 | Norwegian, 84 | /// ISO code: pl 85 | Polish, 86 | /// ISO code: pt 87 | Portuguese, 88 | /// ISO code: ro 89 | Romanian, 90 | /// ISO code: ru 91 | Russian, 92 | /// ISO code: si 93 | Sinhala, 94 | /// ISO code: sk 95 | Slovak, 96 | /// ISO code: sq 97 | Albanian, 98 | /// ISO code: sr 99 | Serbian, 100 | /// ISO code: su 101 | Sundanese, 102 | /// ISO code: sv 103 | Swedish, 104 | /// ISO code: sw 105 | Swahili, 106 | /// ISO code: ta 107 | Tamil, 108 | /// ISO code: te 109 | Telugu, 110 | /// ISO code: th 111 | Thai, 112 | /// ISO code: tl 113 | Filipino, 114 | /// ISO code: tr 115 | Turkish, 116 | /// ISO code: uk 117 | Ukrainian, 118 | /// ISO code: ur 119 | Urdu, 120 | /// ISO code: vi 121 | Vietnamese, 122 | /// ISO code: zh-CN 123 | Chinese, 124 | } 125 | impl FromStr for Languages { 126 | type Err = String; 127 | fn from_str(s: &str) -> Result { 128 | match s { 129 | "af" => Ok(Languages::Afrikaans), 130 | "ar" => Ok(Languages::Arabic), 131 | "bg" => Ok(Languages::Bulgarian), 132 | "bn" => Ok(Languages::Bengali), 133 | "bs" => Ok(Languages::Bosnian), 134 | "ca" => Ok(Languages::Catalan), 135 | "cs" => Ok(Languages::Czech), 136 | "cy" => Ok(Languages::Welsh), 137 | "da" => 
Ok(Languages::Danish), 138 | "de" => Ok(Languages::German), 139 | "el" => Ok(Languages::Greek), 140 | "en" => Ok(Languages::English), 141 | "eo" => Ok(Languages::Esperanto), 142 | "es" => Ok(Languages::Spanish), 143 | "et" => Ok(Languages::Estonian), 144 | "fi" => Ok(Languages::Finnish), 145 | "fr" => Ok(Languages::French), 146 | "gu" => Ok(Languages::Gujarati), 147 | "hi" => Ok(Languages::Hindi), 148 | "hr" => Ok(Languages::Croatian), 149 | "hu" => Ok(Languages::Hungarian), 150 | "hy" => Ok(Languages::Armenian), 151 | "id" => Ok(Languages::Indonesian), 152 | "is" => Ok(Languages::Icelandic), 153 | "it" => Ok(Languages::Italian), 154 | "ja" => Ok(Languages::Japanese), 155 | "jw" => Ok(Languages::Javanese), 156 | "km" => Ok(Languages::Khmer), 157 | "kn" => Ok(Languages::Kannada), 158 | "ko" => Ok(Languages::Korean), 159 | "la" => Ok(Languages::Latin), 160 | "lv" => Ok(Languages::Latvian), 161 | "mk" => Ok(Languages::Macedonian), 162 | "ml" => Ok(Languages::Malayalam), 163 | "mr" => Ok(Languages::Marathi), 164 | "my" => Ok(Languages::MyanmarAKABurmese), 165 | "ne" => Ok(Languages::Nepali), 166 | "nl" => Ok(Languages::Dutch), 167 | "no" => Ok(Languages::Norwegian), 168 | "pl" => Ok(Languages::Polish), 169 | "pt" => Ok(Languages::Portuguese), 170 | "ro" => Ok(Languages::Romanian), 171 | "ru" => Ok(Languages::Russian), 172 | "si" => Ok(Languages::Sinhala), 173 | "sk" => Ok(Languages::Slovak), 174 | "sq" => Ok(Languages::Albanian), 175 | "sr" => Ok(Languages::Serbian), 176 | "su" => Ok(Languages::Sundanese), 177 | "sv" => Ok(Languages::Swedish), 178 | "sw" => Ok(Languages::Swahili), 179 | "ta" => Ok(Languages::Tamil), 180 | "te" => Ok(Languages::Telugu), 181 | "th" => Ok(Languages::Thai), 182 | "tl" => Ok(Languages::Filipino), 183 | "tr" => Ok(Languages::Turkish), 184 | "uk" => Ok(Languages::Ukrainian), 185 | "ur" => Ok(Languages::Urdu), 186 | "vi" => Ok(Languages::Vietnamese), 187 | "zh-CN" => Ok(Languages::Chinese), 188 | _ => Err(format!( 189 | "Unknown language: {}. 
Make sure to use all the supported languages", 190 | s 191 | )), 192 | } 193 | } 194 | } 195 | 196 | impl Languages { 197 | pub fn as_code(&self) -> &'static str { 198 | match self { 199 | Languages::Afrikaans => "af", 200 | Languages::Albanian => "sq", 201 | Languages::Arabic => "ar", 202 | Languages::Armenian => "hy", 203 | Languages::Bengali => "bn", 204 | Languages::Bosnian => "bs", 205 | Languages::Bulgarian => "bg", 206 | Languages::Catalan => "ca", 207 | Languages::Chinese => "zh-CN", 208 | Languages::Croatian => "hr", 209 | Languages::Czech => "cs", 210 | Languages::Danish => "da", 211 | Languages::Dutch => "nl", 212 | Languages::English => "en", 213 | Languages::Esperanto => "eo", 214 | Languages::Estonian => "et", 215 | Languages::Filipino => "tl", 216 | Languages::Finnish => "fi", 217 | Languages::French => "fr", 218 | Languages::German => "de", 219 | Languages::Greek => "el", 220 | Languages::Gujarati => "gu", 221 | Languages::Hindi => "hi", 222 | Languages::Hungarian => "hu", 223 | Languages::Icelandic => "is", 224 | Languages::Indonesian => "id", 225 | Languages::Italian => "it", 226 | Languages::Japanese => "ja", 227 | Languages::Javanese => "jw", 228 | Languages::Kannada => "kn", 229 | Languages::Khmer => "km", 230 | Languages::Korean => "ko", 231 | Languages::Latin => "la", 232 | Languages::Latvian => "lv", 233 | Languages::Macedonian => "mk", 234 | Languages::Marathi => "mr", 235 | Languages::Nepali => "ne", 236 | Languages::Norwegian => "no", 237 | Languages::Polish => "pl", 238 | Languages::Portuguese => "pt", 239 | Languages::Romanian => "ro", 240 | Languages::Russian => "ru", 241 | Languages::Serbian => "sr", 242 | Languages::Sinhala => "si", 243 | Languages::Slovak => "sk", 244 | Languages::Spanish => "es", 245 | Languages::Swahili => "sw", 246 | Languages::Swedish => "sv", 247 | Languages::Tamil => "ta", 248 | Languages::Telugu => "te", 249 | Languages::Thai => "th", 250 | Languages::Turkish => "tr", 251 | Languages::Ukrainian => "uk", 252 | 
Languages::Urdu => "ur", 253 | Languages::Vietnamese => "vi", 254 | Languages::Welsh => "cy", 255 | Languages::MyanmarAKABurmese => "my", 256 | Languages::Malayalam => "ml", 257 | Languages::Sundanese => "su", 258 | } 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /src/models/gtts/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod languages; 2 | pub mod url; 3 | use super::*; 4 | use minreq::get; 5 | use std::io::Write; 6 | use url::EncodedFragment; 7 | 8 | #[derive(Clone, Debug)] 9 | pub struct GttsModel { 10 | pub volume: f32, 11 | pub language: languages::Languages, 12 | pub tld: String, 13 | } 14 | 15 | pub enum Speed { 16 | Normal, 17 | Slow, 18 | } 19 | 20 | impl GttsModel { 21 | pub fn new(volume: f32, language: languages::Languages, tld: String) -> Self { 22 | Self { 23 | language, 24 | volume, 25 | tld, 26 | } 27 | } 28 | 29 | pub fn generate(&self, message: String, path: String) -> Result<(), Box> { 30 | let len = message.len(); 31 | if len > 100 { 32 | return Err(format!("The text is too long. 
Max length is {}", 100).into()); 33 | } 34 | let language = self.language.as_code(); 35 | let text = EncodedFragment::fragmenter(&message)?; 36 | let rep = get(format!("https://translate.google.{}/translate_tts?ie=UTF-8&q={}&tl={}&total=1&idx=0&textlen={}&tl={}&client=tw-ob", self.tld, text.encoded, language, len, language)) 37 | .send() 38 | .map_err(|e| format!("{}", e))?; 39 | let mut file = File::create(&path)?; 40 | let bytes = rep.as_bytes(); 41 | let _ = file.write_all(bytes)?; 42 | 43 | Ok(()) 44 | } 45 | } 46 | 47 | impl Default for GttsModel { 48 | fn default() -> Self { 49 | return Self::new(1.0, languages::Languages::English, String::from("com")); 50 | } 51 | } 52 | 53 | impl NaturalModelTrait for GttsModel { 54 | type SynthesizeType = f32; 55 | fn save(&mut self, message: String, path: String) -> Result<(), Box> { 56 | let _ = self.generate(message, path.clone())?; 57 | did_save(path.as_str()) 58 | } 59 | 60 | fn say(&mut self, message: String) -> Result<(), Box> { 61 | speak_model(self, message) 62 | } 63 | 64 | fn synthesize( 65 | &mut self, 66 | message: String, 67 | ) -> Result, Box> { 68 | synthesize_model(self, message) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/models/gtts/url.rs: -------------------------------------------------------------------------------- 1 | use percent_encoding::utf8_percent_encode; 2 | use percent_encoding::AsciiSet; 3 | use percent_encoding::CONTROLS; 4 | 5 | const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); 6 | 7 | pub struct EncodedFragment { 8 | pub encoded: String, 9 | pub decoded: String, 10 | } 11 | 12 | impl EncodedFragment { 13 | pub fn fragmenter(text: &str) -> Result { 14 | let raw_text = text; 15 | let text = utf8_percent_encode(raw_text, FRAGMENT).to_string(); 16 | if text.is_empty() { 17 | return Err("Empty text".to_string()); 18 | } 19 | Ok(Self { 20 | encoded: text, 21 | decoded: raw_text.to_string(), 22 | 
}) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/models/meta/bs1770.rs: -------------------------------------------------------------------------------- 1 | // Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs 2 | // BS1770 -- Loudness analysis library conforming to ITU-R BS.1770 3 | // Copyright 2020 Ruud van Asseldonk 4 | 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // A copy of the License has been included in the root of the repository. 8 | 9 | //! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704]. 10 | //! 11 | //! This library offers the building blocks to perform BS.1770 loudness 12 | //! measurements, but you need to put the pieces together yourself. 13 | //! 14 | //! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en 15 | //! 16 | //! # Stereo integrated loudness example 17 | //! 18 | //! ```ignore 19 | //! # fn load_stereo_audio() -> [Vec; 2] { 20 | //! # [vec![0; 48_000], vec![0; 48_000]] 21 | //! # } 22 | //! # 23 | //! let sample_rate_hz = 44_100; 24 | //! let bits_per_sample = 16; 25 | //! let channel_samples: [Vec; 2] = load_stereo_audio(); 26 | //! 27 | //! // When converting integer samples to float, note that the maximum amplitude 28 | //! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit. 29 | //! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32; 30 | //! 31 | //! let channel_power: Vec<_> = channel_samples.iter().map(|samples| { 32 | //! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz); 33 | //! meter.push(samples.iter().map(|&s| s as f32 * normalizer)); 34 | //! meter.into_100ms_windows() 35 | //! }).collect(); 36 | //! 37 | //! let stereo_power = bs1770::reduce_stereo( 38 | //! channel_power[0].as_ref(), 39 | //! channel_power[1].as_ref(), 40 | //! ); 41 | //! 42 | //! 
//! let gated_power = bs1770::gated_mean(
//!     stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```

use std::f32;

/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
    a1: f32,
    a2: f32,
    b0: f32,
    b1: f32,
    b2: f32,

    // The past two input and output samples.
    x1: f32,
    x2: f32,
    y1: f32,
    y2: f32,
}

impl Filter {
    /// Stage 1 of the BS.1770-4 pre-filter (high-shelf modelling head effects).
    pub fn high_shelf(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let gain_db = 3.999_843_8;
        let q = 0.707_175_25;
        let center_hz = 1_681.974_5;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        let vh = 10.0_f32.powf(gain_db / 20.0);
        let vb = vh.powf(0.499_666_78);
        let a0 = 1.0 + k / q + k * k;
        Filter {
            b0: (vh + vb * k / q + k * k) / a0,
            b1: 2.0 * (k * k - vh) / a0,
            b2: (vh - vb * k / q + k * k) / a0,
            a1: 2.0 * (k * k - 1.0) / a0,
            a2: (1.0 - k / q + k * k) / a0,

            // Filter memory starts silent.
            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Stage 2 of the BS.1770-4 pre-filter (high-pass).
    pub fn high_pass(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let q = 0.500_327_05;
        let center_hz = 38.135_47;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        Filter {
            a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
            a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
            b0: 1.0,
            b1: -2.0,
            b2: 1.0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Feed the next input sample, get the next output sample.
    #[inline(always)]
    pub fn apply(&mut self, x0: f32) -> f32 {
        // Direct-form-I difference equation; a0 is 1.0 so no output scaling.
        let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
            - self.a1 * self.y1
            - self.a2 * self.y2;

        self.x2 = self.x1;
        self.x1 = x0;
        self.y2 = self.y1;
        self.y1 = y0;

        y0
    }
}

/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
    sum: f32,
    residue: f32,
}

impl Sum {
    #[inline(always)]
    fn zero() -> Sum {
        Sum {
            sum: 0.0,
            residue: 0.0,
        }
    }

    /// Add `x`, carrying the rounding error forward (Neumaier-style compensation).
    #[inline(always)]
    fn add(&mut self, x: f32) {
        let sum = self.sum + (self.residue + x);
        self.residue = (self.residue + x) - (sum - self.sum);
        self.sum = sum;
    }
}

/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -LK LUFS by LK dB (by
/// multiplying the amplitude by 10LK/20) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);

impl Power {
    /// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
    ///
    /// This is the inverse of `loudness_lkfs`.
    pub fn from_lkfs(lkfs: f32) -> Power {
        // Inverse of Equation 2 (p.5) of BS.1770-4 — see `loudness_lkfs`.
        Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
    }

    /// Return the loudness of this window in Loudness Units, K-weighted,
    /// relative to Full Scale.
    ///
    /// This is the inverse of `from_lkfs`.
    pub fn loudness_lkfs(&self) -> f32 {
        // Equation 2 (p.5) of BS.1770-4.
        10.0 * self.0.log10() - 0.691
    }
}

/// A `T` of power values for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
    pub inner: T,
}

impl Windows100ms<Vec<Power>> {
    /// Wrap a new empty vector.
    pub fn new() -> Windows100ms<Vec<Power>> {
        Windows100ms { inner: Vec::new() }
    }
}

impl<T> Windows100ms<T> {
    /// Apply `as_ref` to the inner value.
    pub fn as_ref(&self) -> Windows100ms<&[Power]>
    where
        T: AsRef<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_ref(),
        }
    }

    /// Apply `as_mut` to the inner value.
    pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
    where
        T: AsMut<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_mut(),
        }
    }

    #[allow(clippy::len_without_is_empty)]
    /// Apply `len` to the inner value.
    pub fn len(&self) -> usize
    where
        T: AsRef<[Power]>,
    {
        self.inner.as_ref().len()
    }
}

// Measures K-weighted power of non-overlapping 100ms windows of a single
// channel of audio.
//
// # Output
//
// The output of the meter is an intermediate result in the form of power for
// 100ms non-overlapping windows. The windows need to be processed further to
// get one of the instantaneous, momentary, and integrated loudness
// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
///     .unwrap_or(bs1770::Power(0.0))
///     .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
    /// The number of samples that fit in 100ms of audio.
    samples_per_100ms: u32,

    /// Stage 1 filter (head effects, high shelf).
    filter_stage1: Filter,

    /// Stage 2 filter (high-pass).
    filter_stage2: Filter,

    /// Sum of the squares over non-overlapping windows of 100ms.
    windows: Windows100ms<Vec<Power>>,

    /// The number of samples in the current unfinished window.
    count: u32,

    /// The sum of the squares of the samples in the current unfinished window.
    square_sum: Sum,
}

impl ChannelLoudnessMeter {
    /// Construct a new loudness meter for the given sample rate.
    pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
        ChannelLoudnessMeter {
            samples_per_100ms: sample_rate_hz / 10,
            filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
            filter_stage2: Filter::high_pass(sample_rate_hz as f32),
            windows: Windows100ms::new(),
            count: 0,
            square_sum: Sum::zero(),
        }
    }

    /// Feed input samples for loudness analysis.
    ///
    /// # Full scale
    ///
    /// Full scale for the input samples is the interval [-1.0, 1.0]. If your
    /// input consists of signed integer samples, you can convert as follows:
    ///
    /// ```ignore
    /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
    /// # let bits_per_sample = 16_usize;
    /// # let samples = &[0_i16];
    /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
    /// // one bit is the sign bit.
    /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
    /// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
    /// ```
    ///
    /// # Repeated calls
    ///
    /// You can call `push` multiple times to feed multiple batches of samples.
    /// This is equivalent to feeding a single chained iterator. The leftover of
    /// samples that did not fill a full 100ms window is not discarded:
    ///
    /// ```ignore
    /// # use std::iter;
    /// # use bs1770::ChannelLoudnessMeter;
    /// let sample_rate_hz = 44_100;
    /// let samples_per_100ms = sample_rate_hz / 10;
    /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    ///
    /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
    /// assert_eq!(meter.as_100ms_windows().len(), 0);
    ///
    /// meter.push(iter::once(0.0));
    /// assert_eq!(meter.as_100ms_windows().len(), 1);
    /// ```
    pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
        // 1/N for turning a window's sum of squares into a mean of squares.
        let normalizer = 1.0 / self.samples_per_100ms as f32;

        // LLVM, if you could go ahead and inline those apply calls, and then
        // unroll and vectorize the loop, that'd be terrific.
        for x in samples {
            // K-weighting: high shelf, then high pass.
            let y = self.filter_stage1.apply(x);
            let z = self.filter_stage2.apply(y);

            self.square_sum.add(z * z);
            self.count += 1;

            // TODO: Should this branch be marked cold?
            if self.count == self.samples_per_100ms {
                let mean_squares = Power(self.square_sum.sum * normalizer);
                self.windows.inner.push(mean_squares);
                // We intentionally do not reset the residue. That way, leftover
                // energy from this window is not lost, so for the file overall,
                // the sum remains more accurate.
                self.square_sum.sum = 0.0;
                self.count = 0;
            }
        }
    }

    /// Return a reference to the 100ms windows analyzed so far.
    pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
        self.windows.as_ref()
    }

    /// Return all 100ms windows analyzed so far.
    pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
        self.windows
    }
}

/// Combine power for multiple channels by taking a weighted sum.
414 | /// 415 | /// Note that BS.1770-4 defines power for a multi-channel signal as a weighted 416 | /// sum over channels which is not normalized. This means that a stereo signal 417 | /// is inherently louder than a mono signal. For a mono signal played back on 418 | /// stereo speakers, you should therefore still apply `reduce_stereo`, passing 419 | /// in the same signal for both channels. 420 | pub fn reduce_stereo( 421 | left: Windows100ms<&[Power]>, 422 | right: Windows100ms<&[Power]>, 423 | ) -> Windows100ms> { 424 | assert_eq!( 425 | left.len(), 426 | right.len(), 427 | "Channels must have the same length." 428 | ); 429 | let mut result = Vec::with_capacity(left.len()); 430 | for (l, r) in left.inner.iter().zip(right.inner) { 431 | result.push(Power(l.0 + r.0)); 432 | } 433 | Windows100ms { inner: result } 434 | } 435 | 436 | /// In-place version of `reduce_stereo` that stores the result in the former left channel. 437 | pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) { 438 | assert_eq!( 439 | left.len(), 440 | right.len(), 441 | "Channels must have the same length." 442 | ); 443 | for (l, r) in left.inner.iter_mut().zip(right.inner) { 444 | l.0 += r.0; 445 | } 446 | } 447 | 448 | /// Perform gating and averaging for a BS.1770-4 integrated loudness measurement. 449 | /// 450 | /// The integrated loudness measurement is not just the average power over the 451 | /// entire signal. BS.1770-4 defines two stages of gating that exclude 452 | /// parts of the signal, to ensure that silent parts do not contribute to the 453 | /// loudness measurement. This function performs that gating, and returns the 454 | /// average power over the windows that were not excluded. 455 | /// 456 | /// The result of this function is the integrated loudness measurement. 457 | /// 458 | /// When no signal remains after applying the gate, this function returns 459 | /// `None`. 
In particular, this happens when all of the signal is softer than 460 | /// -70 LKFS, including a signal that consists of pure silence. 461 | pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option { 462 | let mut gating_blocks = Vec::with_capacity(windows_100ms.len()); 463 | 464 | // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.) 465 | let absolute_threshold = Power::from_lkfs(-70.0); 466 | 467 | // Iterate over all 400ms windows. 468 | for window in windows_100ms.inner.windows(4) { 469 | // Note that the sum over channels has already been performed at this point. 470 | let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::()); 471 | 472 | if gating_block_power > absolute_threshold { 473 | gating_blocks.push(gating_block_power); 474 | } 475 | } 476 | 477 | if gating_blocks.is_empty() { 478 | return None; 479 | } 480 | 481 | // Compute the loudness after applying the absolute gate, in order to 482 | // determine the threshold for the relative gate. 483 | let mut sum_power = Sum::zero(); 484 | for &gating_block_power in &gating_blocks { 485 | sum_power.add(gating_block_power.0); 486 | } 487 | let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32)); 488 | 489 | // Stage 2: Apply the relative gate. 
490 | let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0); 491 | let mut sum_power = Sum::zero(); 492 | let mut n_blocks = 0_usize; 493 | for &gating_block_power in &gating_blocks { 494 | if gating_block_power > relative_threshold { 495 | sum_power.add(gating_block_power.0); 496 | n_blocks += 1; 497 | } 498 | } 499 | 500 | if n_blocks == 0 { 501 | return None; 502 | } 503 | 504 | let relative_gated_power = Power(sum_power.sum / n_blocks as f32); 505 | Some(relative_gated_power) 506 | } 507 | -------------------------------------------------------------------------------- /src/models/meta/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod bs1770; 2 | pub mod utils; 3 | 4 | use super::{did_save, NaturalModelTrait, SynthesizedAudio}; 5 | use crate::{ 6 | utils::{get_path, play_wav_file}, 7 | TtsError, 8 | }; 9 | use candle_core::{DType, Device, IndexOp, Tensor}; 10 | use candle_nn::VarBuilder; 11 | use candle_transformers::{ 12 | generation::LogitsProcessor, 13 | models::{ 14 | encodec, 15 | metavoice::{adapters, gpt, transformer}, 16 | }, 17 | }; 18 | use derive_builder::Builder; 19 | use hf_hub::api::sync::Api; 20 | use hound::WavSpec; 21 | use rand::{distributions::Distribution, SeedableRng}; 22 | use std::{error::Error, io::Write, path::PathBuf}; 23 | use utils::*; 24 | 25 | const MODEL_NAME: &str = "lmz/candle-metavoice"; 26 | 27 | #[derive(Builder, Clone, Debug, PartialEq)] 28 | pub struct MetaModelOptions { 29 | #[builder(default = "false")] 30 | pub cpu: bool, 31 | #[builder(default = "MODEL_NAME.to_string()")] 32 | model_name: String, 33 | #[builder(default = "false")] 34 | pub tracing: bool, 35 | #[builder(default = "None")] 36 | pub spk_emb: Option, 37 | #[builder(default = "1024")] 38 | pub encodec_ntokens: u32, 39 | #[builder(default = "299792458")] 40 | pub seed: u64, 41 | #[builder(default = "8")] 42 | pub max_tokens: u64, 43 | #[builder(default = "3.0")] 44 | pub 
guidance_scale: f64,
    #[builder(default = "1.0")]
    pub temperature: f64,
}

/// Metavoice TTS model: first-stage transformer + second-stage GPT + EnCodec decoder.
#[derive(Clone)]
pub struct MetaModel {
    pub first_stage_model: transformer::Model,
    pub device: Device,
    // Parsed `first_stage.meta.json`; holds the tokenizer definition.
    pub first_stage_meta: serde_json::Value,
    pub dtype: DType,
    // Weight files are kept as paths; the second stage and EnCodec are
    // loaded lazily inside `generate`.
    pub encodec_weights: PathBuf,
    pub second_stage_weights: PathBuf,
    pub repo_path: String,
    pub seed: u64,
    // NOTE(review): currently unused in `generate` — the guidance mix is
    // commented out there.
    pub guidance_scale: f64,
    pub temperature: f64,
    pub max_tokens: u64,
    pub spk_emb: Option<String>,
    pub encodec_ntokens: u32,
}

impl MetaModel {
    /// Download (or reuse cached) weights from the Hugging Face hub and load
    /// the first-stage transformer. Fails on network/parse/weight errors.
    pub fn new(options: MetaModelOptions) -> Result<Self, Box<dyn Error>> {
        use tracing_chrome::ChromeLayerBuilder;
        use tracing_subscriber::prelude::*;

        println!("{:?}", &options);

        // Keep the tracing guard alive for the duration of construction only.
        let _guard = if options.tracing {
            let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
            tracing_subscriber::registry().with(chrome_layer).init();
            Some(guard)
        } else {
            None
        };

        let device = device(options.cpu)?;
        let api = Api::new()?;

        let repo = api.model(options.model_name.clone());

        let first_stage_meta = repo.get("first_stage.meta.json")?;
        let first_stage_meta: serde_json::Value =
            serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;

        let dtype = DType::F32;
        let first_stage_config = transformer::Config::cfg1b_v0_1();

        let first_stage_weights = repo.get("first_stage.safetensors")?;
        let first_stage_vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
        let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;

        let second_stage_weights = repo.get("second_stage.safetensors")?;

        // EnCodec weights live in a separate repo.
        let encodec_weights = Api::new()?
            .model("sanchit-gandhi/encodec_24khz".to_string())
            .get("model.safetensors")?;

        println!("Done");
        return Ok(Self {
            first_stage_model,
            device,
            first_stage_meta,
            dtype,
            encodec_weights,
            second_stage_weights,
            repo_path: options.model_name,
            seed: options.seed,
            guidance_scale: options.guidance_scale,
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            spk_emb: options.spk_emb,
            encodec_ntokens: options.encodec_ntokens,
        });
    }

    /// Synthesize `prompt` into PCM audio (first stage → second stage → EnCodec).
    pub fn generate(&mut self, prompt: String) -> Result<SynthesizedAudio<f32>, Box<dyn Error>> {
        let second_stage_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &[self.second_stage_weights.clone()],
                self.dtype,
                &self.device,
            )?
        };
        let second_stage_config = gpt::Config::cfg1b_v0_1();
        let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;

        // EnCodec decoding is not supported on Metal; fall back to CPU there.
        let encodec_device = if self.device.is_metal() {
            candle_core::Device::Cpu
        } else {
            self.device.clone()
        };

        let encodec_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &[self.encodec_weights.clone()],
                self.dtype,
                &encodec_device,
            )?
};
        let encodec_config = encodec::Config::default();
        let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;

        // Tokenize the prompt with the BPE tokenizer embedded in first_stage.meta.json.
        let fs_tokenizer = get_fs_tokenizer(self.first_stage_meta.clone())?;
        let prompt_tokens = fs_tokenizer.encode(&prompt)?;
        let mut tokens = prompt_tokens.clone();

        let api = Api::new()?;
        let repo = api.model(self.repo_path.clone());

        // Speaker embedding: user-supplied path, or the repo default.
        let spk_emb_file = match &self.spk_emb {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("spk_emb.safetensors")?,
        };

        let spk_emb = candle_core::safetensors::load(&spk_emb_file, &candle_core::Device::Cpu)?;
        let spk_emb = match spk_emb.get("spk_emb") {
            None => return Err(TtsError::Tensor.into()),
            Some(spk_emb) => spk_emb.to_dtype(self.dtype)?,
        };
        let spk_emb = spk_emb.to_device(&self.device)?;
        let mut logits_processor = LogitsProcessor::new(self.seed, Some(self.temperature), None);

        // First stage: autoregressively sample up to `max_tokens` tokens.
        // (Leftover debug `println!`s and the per-token stdout flush were
        // removed — they spammed library users' stdout.)
        for index in 0..self.max_tokens {
            // After the first step only the newly generated token is fed back.
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?;
            let input = Tensor::stack(&[&input, &input], 0)?;
            let logits =
                self.first_stage_model
                    .forward(&input, &spk_emb, tokens.len() - context_size)?;
            // NOTE(review): `self.guidance_scale` is unused — the original's
            // classifier-free-guidance mix of the two batch rows was commented
            // out; only row 0's last position is sampled.
            let logits = logits.i((0, logits.dim(1)? - 1))?;
            let logits = logits.to_dtype(self.dtype)?;
            let next_token = match logits_processor.sample(&logits) {
                Ok(x) => x,
                Err(e) => {
                    // Best-effort: skip an unsampleable step rather than abort.
                    println!("{}", e.to_string());
                    continue;
                }
            };
            tokens.push(next_token);
            // 2048 is the first stage's end-of-sequence token.
            if next_token == 2048 {
                break;
            }
        }

        // Split interleaved first-stage output into the two codebook id streams.
        let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(self.encodec_ntokens);
        let (_, ids1, ids2) = fie2c.decode(&tokens);
        let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed + 1337);
        let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
        let mut hierarchies_in1 = [
            encoded_text.as_slice(),
            ids1.as_slice(),
            &[self.encodec_ntokens],
        ]
        .concat();
        let mut hierarchies_in2 = [
            vec![self.encodec_ntokens; encoded_text.len()].as_slice(),
            ids2.as_slice(),
            &[self.encodec_ntokens],
        ]
        .concat();

        // Pad both hierarchies to the second stage's fixed block size.
        hierarchies_in1.resize(second_stage_config.block_size, self.encodec_ntokens);
        hierarchies_in2.resize(second_stage_config.block_size, self.encodec_ntokens);
        let in_x1 = Tensor::new(hierarchies_in1, &self.device)?;
        let in_x2 = Tensor::new(hierarchies_in2, &self.device)?;
        let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;

        // Second stage: sample codebook codes from the GPT logits.
        let logits = second_stage_model.forward(&in_x)?;
        let mut codes = vec![];
        for logits in logits.iter() {
            let logits = logits.squeeze(0)?;
            let (seq_len, _) = logits.dims2()?;
            let mut codes_ = Vec::with_capacity(seq_len);
            for step in 0..seq_len {
                let logits = logits.i(step)?.to_dtype(DType::F32)?;
                let prs = candle_nn::ops::softmax_last_dim(&logits)?.to_vec1::<f32>()?;
                let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
                codes_.push(distr.sample(&mut rng) as u32)
            }
            codes.push(codes_)
        }

        let codes = Tensor::new(codes, &self.device)?.unsqueeze(0)?;
        let codes = Tensor::cat(&[in_x, codes], 1)?;
        let tilted_encodec = adapters::TiltedEncodec::new(self.encodec_ntokens);
        let codes = codes.i(0)?.to_vec2::<u32>()?;
        let (_, audio_ids) = tilted_encodec.decode(&codes);
        // Fix: these were `.unwrap()`ed in a function that returns Result —
        // propagate tensor errors instead of panicking.
        let audio_ids = Tensor::new(audio_ids, &encodec_device)?.unsqueeze(0)?;

        // EnCodec decode to PCM, then loudness-normalize to roughly -14 LUFS.
        let pcm = encodec_model.decode(&audio_ids)?;
        let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
        let pcm = normalize_loudness(&pcm, 24_000, true)?;

        let pcm = pcm.to_vec1::<f32>()?;
        Ok(SynthesizedAudio::new(
            pcm,
            super::Spec::Wav(WavSpec {
                sample_rate: encodec_config.sampling_rate as u32,
                channels: encodec_config.audio_channels as u16,
                sample_format: hound::SampleFormat::Float,
                // Fix: was `sampling_rate as u16` (24000) — IEEE-float WAV
                // samples are f32, i.e. 32 bits per sample.
                bits_per_sample: 32,
            }),
            None,
        ))
    }
}

impl Default for MetaModel {
    /// Build with default options; panics if the weights cannot be fetched.
    fn default() -> Self {
        Self::new(MetaModelOptionsBuilder::default().build().unwrap()).unwrap()
    }
}

impl NaturalModelTrait for MetaModel {
    type SynthesizeType = f32;

    /// Synthesize `message` and write it as a 16-bit PCM WAV at `path`.
    fn save(&mut self, message: String, path: String) -> Result<(), Box<dyn Error>> {
        let data = self.synthesize(message)?;
        let mut output = std::fs::File::create(&path)?;
        write_pcm_as_wav(&mut output, &data.data, 24_000)?;
        did_save(path.as_str())
    }

    /// Synthesize to a temp file, play it, then clean up.
    fn say(&mut self, message: String) -> Result<(), Box<dyn Error>> {
        let path = get_path("temp.wav".to_string());
        self.save(message, path.clone())?;
        play_wav_file(&path)?;
        std::fs::remove_file(path)?;
        Ok(())
    }

| fn synthesize( 299 | &mut self, 300 | message: String, 301 | ) -> Result, Box> { 302 | self.generate(message) 303 | } 304 | } 305 | -------------------------------------------------------------------------------- /src/models/meta/utils.rs: -------------------------------------------------------------------------------- 1 | use super::bs1770; 2 | use crate::TtsError; 3 | use candle_core::utils::{cuda_is_available, metal_is_available}; 4 | use candle_core::{Device, Tensor}; 5 | use candle_transformers::models::metavoice::{tokenizers, transformer}; 6 | use candle_transformers::models::quantized_metavoice::transformer as qtransformer; 7 | use std::error::Error; 8 | use std::io::Write; 9 | use std::path::PathBuf; 10 | 11 | #[derive(Clone, Debug)] 12 | pub enum Transformer { 13 | Normal(transformer::Model), 14 | Quantized(qtransformer::Model), 15 | } 16 | 17 | pub fn get_fs_tokenizer( 18 | first_stage_meta: serde_json::Value, 19 | ) -> Result> { 20 | let first_stage_tokenizer = match first_stage_meta.as_object() { 21 | None => return Err(TtsError::Json.into()), 22 | Some(j) => match j.get("tokenizer") { 23 | None => return Err(TtsError::NoTokenizerKey.into()), 24 | Some(j) => j, 25 | }, 26 | }; 27 | 28 | Ok(tokenizers::BPE::from_json(first_stage_tokenizer, 512)?) 29 | } 30 | pub fn device(cpu: bool) -> Result> { 31 | if cpu { 32 | Ok(Device::Cpu) 33 | } else if cuda_is_available() { 34 | Ok(Device::new_cuda(0)?) 35 | } else if metal_is_available() { 36 | Ok(Device::new_metal(0)?) 
37 | } else { 38 | Ok(Device::Cpu) 39 | } 40 | } 41 | 42 | pub fn hub_load_safetensors( 43 | repo: &hf_hub::api::sync::ApiRepo, 44 | json_file: &str, 45 | ) -> Result, Box> { 46 | let json_file = repo.get(json_file).unwrap(); 47 | let json_file = std::fs::File::open(json_file)?; 48 | let json: serde_json::Value = serde_json::from_reader(&json_file).unwrap(); 49 | let weight_map = match json.get("weight_map") { 50 | None => return Err("no weight map in {json_file:?}".into()), 51 | Some(serde_json::Value::Object(map)) => map, 52 | Some(_) => return Err("weight map in {json_file:?} is not a map".into()), 53 | }; 54 | let mut safetensors_files = std::collections::HashSet::new(); 55 | for value in weight_map.values() { 56 | if let Some(file) = value.as_str() { 57 | safetensors_files.insert(file.to_string()); 58 | } 59 | } 60 | let safetensors_files = safetensors_files 61 | .iter() 62 | .map(|v| repo.get(v).unwrap()) 63 | .collect::>(); 64 | Ok(safetensors_files) 65 | } 66 | 67 | pub trait Sample { 68 | fn to_i16(&self) -> i16; 69 | } 70 | 71 | impl Sample for f32 { 72 | fn to_i16(&self) -> i16 { 73 | (self.clamp(-1.0, 1.0) * 32767.0) as i16 74 | } 75 | } 76 | 77 | impl Sample for f64 { 78 | fn to_i16(&self) -> i16 { 79 | (self.clamp(-1.0, 1.0) * 32767.0) as i16 80 | } 81 | } 82 | 83 | impl Sample for i16 { 84 | fn to_i16(&self) -> i16 { 85 | *self 86 | } 87 | } 88 | 89 | pub fn write_pcm_as_wav( 90 | w: &mut W, 91 | samples: &[S], 92 | sample_rate: u32, 93 | ) -> std::io::Result<()> { 94 | let len = 12u32; // header 95 | let len = len + 24u32; // fmt 96 | let len = len + samples.len() as u32 * 2 + 8; // data 97 | let n_channels = 1u16; 98 | let bytes_per_second = sample_rate * 2 * n_channels as u32; 99 | w.write_all(b"RIFF")?; 100 | w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes 101 | w.write_all(b"WAVE")?; 102 | 103 | // Format block 104 | w.write_all(b"fmt ")?; 105 | w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes 106 | 
w.write_all(&1u16.to_le_bytes())?; // PCM 107 | w.write_all(&n_channels.to_le_bytes())?; // one channel 108 | w.write_all(&sample_rate.to_le_bytes())?; 109 | w.write_all(&bytes_per_second.to_le_bytes())?; 110 | w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample 111 | w.write_all(&16u16.to_le_bytes())?; // bits per sample 112 | 113 | // Data block 114 | w.write_all(b"data")?; 115 | w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?; 116 | for sample in samples.iter() { 117 | w.write_all(&sample.to_i16().to_le_bytes())? 118 | } 119 | Ok(()) 120 | } 121 | 122 | pub fn normalize_loudness( 123 | wav: &Tensor, 124 | sample_rate: u32, 125 | loudness_compressor: bool, 126 | ) -> Result> { 127 | let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::()?; 128 | if energy < 2e-3 { 129 | return Ok(wav.clone()); 130 | } 131 | let wav_array = wav.to_vec1::()?; 132 | let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate); 133 | meter.push(wav_array.into_iter()); 134 | let power = meter.as_100ms_windows(); 135 | let loudness = match bs1770::gated_mean(power) { 136 | None => return Ok(wav.clone()), 137 | Some(gp) => gp.loudness_lkfs() as f64, 138 | }; 139 | let delta_loudness = -14. - loudness; 140 | let gain = 10f64.powf(delta_loudness / 20.); 141 | let wav = (wav * gain)?; 142 | if loudness_compressor { 143 | Ok(wav.tanh()?) 
144 | } else { 145 | Ok(wav) 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "coqui")] 2 | pub mod coqui; 3 | #[cfg(feature = "gtts")] 4 | pub mod gtts; 5 | #[cfg(feature = "meta")] 6 | pub mod meta; 7 | #[cfg(feature = "msedge")] 8 | pub mod msedge; 9 | #[cfg(feature = "parler")] 10 | pub mod parler; 11 | #[cfg(feature = "tts-rs")] 12 | pub mod tts_rs; 13 | 14 | use crate::{ 15 | utils::{get_path, play_wav_file, read_wav_file}, 16 | TtsError, 17 | }; 18 | use hound::WavSpec; 19 | #[cfg(feature = "msedge")] 20 | use msedge_tts::tts::AudioMetadata; 21 | use rodio::Sample; 22 | use std::{error::Error, fs::File}; 23 | 24 | pub trait NaturalModelTrait { 25 | type SynthesizeType: Sample + Send; 26 | fn save(&mut self, message: String, path: String) -> Result<(), Box>; 27 | fn say(&mut self, message: String) -> Result<(), Box>; 28 | fn synthesize( 29 | &mut self, 30 | message: String, 31 | ) -> Result, Box>; 32 | } 33 | 34 | pub fn speak_model( 35 | model: &mut T, 36 | message: String, 37 | ) -> Result<(), Box> { 38 | let path = "output.wav"; 39 | let actual = get_path(path.to_string()); 40 | let _ = std::fs::remove_file(actual.clone()); 41 | let _ = model.save(message.clone(), actual.clone()); 42 | let _ = play_wav_file(&actual); 43 | let _ = std::fs::remove_file(actual); 44 | Ok(()) 45 | } 46 | 47 | pub fn synthesize_model( 48 | model: &mut T, 49 | message: String, 50 | ) -> Result, Box> { 51 | let path = "output.wav"; 52 | let actual = get_path(path.to_string()); 53 | let _ = std::fs::remove_file(actual.clone()); 54 | let _ = model.save(message.clone(), actual.clone()); 55 | let rwf = read_wav_file(&actual)?; 56 | let _ = std::fs::remove_file(actual); 57 | Ok(rwf) 58 | } 59 | 60 | pub enum Spec { 61 | Wav(WavSpec), 62 | #[cfg(feature = "msedge")] 63 | Synthesized(String, Vec), 64 | Unknown, 65 | 
} 66 | 67 | pub struct SynthesizedAudio { 68 | pub spec: Spec, 69 | pub data: Vec, 70 | pub duration: Option, 71 | } 72 | 73 | impl SynthesizedAudio { 74 | pub fn new(data: Vec, spec: Spec, duration: Option) -> Self { 75 | return Self { 76 | data, 77 | spec, 78 | duration, 79 | }; 80 | } 81 | } 82 | 83 | pub fn did_save(path: &str) -> Result<(), Box> { 84 | let file = File::open(path); 85 | match file { 86 | Ok(_) => Ok(()), 87 | Err(_) => Err(Box::new(TtsError::NotSaved)), 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/models/msedge.rs: -------------------------------------------------------------------------------- 1 | use super::{NaturalModelTrait, Spec, SynthesizedAudio}; 2 | use crate::utils::{play_audio, save_wav}; 3 | use msedge_tts::{ 4 | tts::{client::connect, SpeechConfig as OtherConfig}, 5 | voice::{get_voices_list, Voice}, 6 | }; 7 | 8 | #[derive(Clone, Debug)] 9 | pub struct MSEdgeModel { 10 | config: SpeechConfig, 11 | } 12 | 13 | impl MSEdgeModel { 14 | pub fn new_from_voice(voice: Voice) -> Self { 15 | return Self { 16 | config: SpeechConfig::from(&voice), 17 | }; 18 | } 19 | 20 | pub fn new(config: SpeechConfig) -> Self { 21 | return Self { config }; 22 | } 23 | } 24 | 25 | impl Default for MSEdgeModel { 26 | fn default() -> Self { 27 | let voice = get_voices_list().unwrap(); 28 | return Self::new(SpeechConfig::from(voice.first().unwrap())); 29 | } 30 | } 31 | 32 | impl NaturalModelTrait for MSEdgeModel { 33 | type SynthesizeType = f32; 34 | fn say(&mut self, message: String) -> Result<(), Box> { 35 | let synthesized = Self::synthesize(self, message)?; 36 | 37 | let rate = match self.config.rate { 38 | x if x <= 0 => 16000, 39 | x => x, 40 | }; 41 | 42 | match synthesized.spec { 43 | Spec::Wav(x) => play_audio(synthesized.data, x.sample_rate), 44 | _ => play_audio(synthesized.data, rate as u32), 45 | } 46 | Ok(()) 47 | } 48 | 49 | fn save(&mut self, message: String, path: String) -> 
Result<(), Box> { 50 | let synthesized = Self::synthesize(self, message)?; 51 | 52 | let rate = match self.config.rate { 53 | x if x <= 0 => 16000, 54 | x => x, 55 | }; 56 | 57 | let _ = save_wav(&synthesized.data, path.as_str(), rate as u32); 58 | Ok(()) 59 | } 60 | 61 | fn synthesize( 62 | &mut self, 63 | message: String, 64 | ) -> Result, Box> { 65 | let mut tts = connect().unwrap(); 66 | let audio = tts.synthesize(message.as_str(), &self.config.as_msedge())?; 67 | return Ok(SynthesizedAudio::new( 68 | audio.audio_bytes.iter().map(|x| x.clone() as f32).collect(), 69 | Spec::Synthesized(audio.audio_format, audio.audio_metadata), 70 | None, 71 | )); 72 | } 73 | } 74 | 75 | #[derive(Debug, Clone)] 76 | pub struct SpeechConfig { 77 | pub voice_name: String, 78 | pub audio_format: String, 79 | pub pitch: i32, 80 | pub rate: i32, 81 | pub volume: i32, 82 | } 83 | 84 | impl SpeechConfig { 85 | pub fn as_msedge(&self) -> OtherConfig { 86 | return OtherConfig { 87 | voice_name: self.voice_name.clone(), 88 | audio_format: self.audio_format.clone(), 89 | pitch: self.pitch, 90 | rate: self.rate, 91 | volume: self.volume, 92 | }; 93 | } 94 | } 95 | 96 | impl From<&msedge_tts::tts::SpeechConfig> for SpeechConfig { 97 | fn from(config: &msedge_tts::tts::SpeechConfig) -> Self { 98 | return Self { 99 | voice_name: config.voice_name.clone(), 100 | audio_format: config.audio_format.clone(), 101 | pitch: config.pitch, 102 | rate: config.rate, 103 | volume: config.volume, 104 | }; 105 | } 106 | } 107 | 108 | impl From<&msedge_tts::voice::Voice> for SpeechConfig { 109 | fn from(voice: &msedge_tts::voice::Voice) -> Self { 110 | let mscfg = OtherConfig::from(voice); 111 | return Self::from(&mscfg); 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /src/models/parler/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod model; 2 | use super::{did_save, speak_model, NaturalModelTrait, 
SynthesizedAudio}; 3 | use candle_core::{DType, Device, IndexOp, Tensor}; 4 | use candle_nn::VarBuilder; 5 | use derive_builder::Builder; 6 | use hf_hub::api::sync::Api; 7 | use hound::WavSpec; 8 | use model::*; 9 | use std::error::Error; 10 | use tokenizers::Tokenizer; 11 | 12 | use super::meta::utils::*; 13 | 14 | const MODEL_NAME: &str = "parler-tts/parler-tts-mini-v1"; 15 | 16 | #[derive(Builder, Clone, Default)] 17 | #[builder(setter(into))] 18 | pub struct ParlerModelOptions { 19 | #[builder(default = "false")] 20 | cpu: bool, 21 | description: String, 22 | #[builder(default = "false")] 23 | tracing: bool, 24 | #[builder(default = "1.0")] 25 | temperature: f64, 26 | #[builder(default = "None")] 27 | top_p: Option, 28 | #[builder(default = "299792458")] 29 | seed: u64, 30 | #[builder(default = "MODEL_NAME.to_string()")] 31 | model_name: String, 32 | } 33 | 34 | #[derive(Clone)] 35 | pub struct ParlerModel { 36 | device: Device, 37 | config: Config, 38 | model: Model, 39 | description: String, 40 | temperature: f64, 41 | top_p: Option, 42 | seed: u64, 43 | tokenizer: Tokenizer, 44 | } 45 | 46 | impl ParlerModel { 47 | pub fn new(options: ParlerModelOptions) -> Result> { 48 | use tracing_chrome::ChromeLayerBuilder; 49 | use tracing_subscriber::prelude::*; 50 | 51 | let _guard = if options.tracing { 52 | let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); 53 | tracing_subscriber::registry().with(chrome_layer).init(); 54 | Some(guard) 55 | } else { 56 | None 57 | }; 58 | 59 | let api = Api::new()?; 60 | let device = device(options.cpu)?; 61 | 62 | let revision = "main".to_string(); 63 | 64 | let repo = api.repo(hf_hub::Repo::with_revision( 65 | options.model_name, 66 | hf_hub::RepoType::Model, 67 | revision.clone(), 68 | )); 69 | 70 | let config = repo.get("config.json")?; 71 | 72 | let model_files = match repo.get("model.safetensors") { 73 | Ok(x) => vec![x], 74 | Err(_) => hub_load_safetensors(&repo, "model.safetensors.index.json")?, 75 | }; 76 | 77 | 
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? }; 78 | 79 | let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?; 80 | let model = Model::new(&config, vb)?; 81 | 82 | let tokenizer = repo.get("tokenizer.json")?; 83 | 84 | let tokenizer = Tokenizer::from_file(tokenizer).unwrap(); 85 | 86 | return Ok(Self { 87 | device, 88 | config, 89 | top_p: options.top_p, 90 | description: options.description, 91 | model, 92 | tokenizer, 93 | seed: options.seed, 94 | temperature: options.temperature, 95 | }); 96 | } 97 | 98 | pub fn generate(&mut self, message: String) -> Result, Box> { 99 | let description_tokens = self 100 | .tokenizer 101 | .encode(self.description.clone(), true) 102 | .unwrap() 103 | .get_ids() 104 | .to_vec(); 105 | let description_tokens = Tensor::new(description_tokens, &self.device)?.unsqueeze(0)?; 106 | let prompt_tokens = self 107 | .tokenizer 108 | .encode(message, true) 109 | .unwrap() 110 | .get_ids() 111 | .to_vec(); 112 | let prompt_tokens = Tensor::new(prompt_tokens, &self.device)?.unsqueeze(0)?; 113 | 114 | let lp = candle_transformers::generation::LogitsProcessor::new( 115 | self.seed, 116 | Some(self.temperature), 117 | self.top_p, 118 | ); 119 | 120 | println!("Starting."); 121 | 122 | let codes = self 123 | .model 124 | .generate(&prompt_tokens, &description_tokens, lp, 512)?; 125 | let codes = codes.to_dtype(DType::I64)?; 126 | codes.save_safetensors("codes", "out.safetensors")?; 127 | let codes = codes.unsqueeze(0)?; 128 | 129 | let pcm = self 130 | .model 131 | .audio_encoder 132 | .decode_codes(&codes.to_device(&self.device)?)?; 133 | let pcm = pcm.i((0, 0))?; 134 | let pcm = normalize_loudness(&pcm, 24_000, true)?; 135 | let pcm = pcm.to_vec1::()?; 136 | 137 | return Ok(SynthesizedAudio::new( 138 | pcm, 139 | super::Spec::Wav(WavSpec { 140 | sample_rate: self.config.audio_encoder.sampling_rate, 141 | channels: 1, 142 | sample_format: hound::SampleFormat::Float, 143 | 
bits_per_sample: self.config.audio_encoder.model_bitrate as u16, 144 | }), 145 | None, 146 | )); 147 | } 148 | } 149 | 150 | impl Default for ParlerModel { 151 | fn default() -> Self { 152 | let desc = "A female speaker in fast calming voice in a quiet environment".to_string(); 153 | let model = "parler-tts/parler-tts-mini-expresso".to_string(); 154 | return Self::new( 155 | ParlerModelOptionsBuilder::default() 156 | .model_name(model) 157 | .description(desc) 158 | .build() 159 | .unwrap(), 160 | ) 161 | .unwrap(); 162 | } 163 | } 164 | 165 | impl NaturalModelTrait for ParlerModel { 166 | type SynthesizeType = f32; 167 | fn save(&mut self, message: String, path: String) -> Result<(), Box> { 168 | let data = self.synthesize(message)?; 169 | let mut output = std::fs::File::create(&path)?; 170 | println!("{:?}", self.config.audio_encoder.sampling_rate); 171 | write_pcm_as_wav( 172 | &mut output, 173 | &data.data, 174 | self.config.audio_encoder.sampling_rate, 175 | )?; 176 | did_save(path.as_str()) 177 | } 178 | 179 | fn say(&mut self, message: String) -> Result<(), Box> { 180 | speak_model(self, message) 181 | } 182 | 183 | fn synthesize( 184 | &mut self, 185 | message: String, 186 | ) -> Result, Box> { 187 | self.generate(message) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /src/models/parler/model.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{DType, Device, IndexOp, Result, Tensor}; 2 | use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder}; 3 | use candle_transformers::{ 4 | generation::LogitsProcessor, 5 | models::{dac, t5}, 6 | utils::repeat_kv, 7 | }; 8 | 9 | #[derive(serde::Deserialize, Debug, Clone)] 10 | pub struct DecoderConfig { 11 | pub vocab_size: usize, 12 | pub max_position_embeddings: usize, 13 | pub num_hidden_layers: usize, 14 | pub ffn_dim: usize, 15 | pub num_attention_heads: usize, 16 | pub 
num_key_value_heads: Option, 17 | pub num_cross_attention_key_value_heads: Option, 18 | pub activation_function: Activation, 19 | pub hidden_size: usize, 20 | pub scale_embedding: bool, 21 | pub num_codebooks: usize, 22 | pub pad_token_id: usize, 23 | pub bos_token_id: usize, 24 | pub eos_token_id: usize, 25 | pub tie_word_embeddings: bool, 26 | } 27 | 28 | #[derive(serde::Deserialize, Debug, Clone)] 29 | pub struct Config { 30 | pub decoder_start_token_id: u32, 31 | pub pad_token_id: u32, 32 | pub decoder: DecoderConfig, 33 | pub text_encoder: t5::Config, 34 | pub vocab_size: usize, 35 | pub audio_encoder: dac::Config, 36 | } 37 | 38 | #[derive(Debug, Clone)] 39 | pub struct Attention { 40 | k_proj: Linear, 41 | v_proj: Linear, 42 | q_proj: Linear, 43 | out_proj: Linear, 44 | is_causal: bool, 45 | kv_cache: Option<(Tensor, Tensor)>, 46 | scaling: f64, 47 | num_heads: usize, 48 | num_kv_heads: usize, 49 | num_kv_groups: usize, 50 | head_dim: usize, 51 | } 52 | 53 | impl Attention { 54 | fn new( 55 | num_kv_heads: usize, 56 | is_causal: bool, 57 | cfg: &DecoderConfig, 58 | vb: VarBuilder, 59 | ) -> Result { 60 | let embed_dim = cfg.hidden_size; 61 | let head_dim = embed_dim / cfg.num_attention_heads; 62 | let kv_out_dim = num_kv_heads * head_dim; 63 | let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?; 64 | let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?; 65 | let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?; 66 | let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?; 67 | Ok(Self { 68 | k_proj, 69 | v_proj, 70 | q_proj, 71 | out_proj, 72 | is_causal, 73 | kv_cache: None, 74 | scaling: (head_dim as f64).powf(-0.5), 75 | num_heads: cfg.num_attention_heads, 76 | num_kv_heads, 77 | num_kv_groups: cfg.num_attention_heads / num_kv_heads, 78 | head_dim, 79 | }) 80 | } 81 | 82 | fn forward( 83 | &mut self, 84 | xs: &Tensor, 85 | key_value_states: Option<&Tensor>, 86 | attention_mask: Option<&Tensor>, 
87 | ) -> Result { 88 | let (b_sz, tgt_len, _) = xs.dims3()?; 89 | let query_states = (xs.apply(&self.q_proj)? * self.scaling)? 90 | .reshape((b_sz, tgt_len, self.num_heads, self.head_dim))? 91 | .transpose(1, 2)? 92 | .contiguous()?; 93 | let key_states = match key_value_states { 94 | Some(states) => states.apply(&self.k_proj)?, 95 | None => xs.apply(&self.k_proj)?, 96 | }; 97 | let key_states = key_states 98 | .reshape((b_sz, (), self.num_kv_heads, self.head_dim))? 99 | .transpose(1, 2)? 100 | .contiguous()?; 101 | let value_states = match key_value_states { 102 | Some(states) => states.apply(&self.v_proj)?, 103 | None => xs.apply(&self.v_proj)?, 104 | }; 105 | let value_states = value_states 106 | .reshape((b_sz, (), self.num_kv_heads, self.head_dim))? 107 | .transpose(1, 2)? 108 | .contiguous()?; 109 | 110 | let (key_states, value_states) = match &self.kv_cache { 111 | None => (key_states, value_states), 112 | Some((prev_k, prev_v)) => { 113 | let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; 114 | let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; 115 | (key_states, value_states) 116 | } 117 | }; 118 | if self.is_causal { 119 | self.kv_cache = Some((key_states.clone(), value_states.clone())); 120 | } 121 | 122 | let key_states = repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; 123 | let value_states = repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; 124 | 125 | let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?; 126 | let attn_weights = match attention_mask { 127 | None => attn_weights, 128 | Some(mask) => attn_weights.broadcast_add(mask)?, 129 | }; 130 | let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; 131 | let attn_output = attn_weights.matmul(&value_states)?; 132 | attn_output 133 | .transpose(1, 2)? 134 | .reshape((b_sz, tgt_len, ()))? 
135 | .apply(&self.out_proj) 136 | } 137 | 138 | fn clear_kv_cache(&mut self) { 139 | self.kv_cache = None 140 | } 141 | } 142 | 143 | #[derive(Debug, Clone)] 144 | pub struct DecoderLayer { 145 | self_attn: Attention, 146 | self_attn_layer_norm: LayerNorm, 147 | encoder_attn: Attention, 148 | encoder_attn_layer_norm: LayerNorm, 149 | fc1: Linear, 150 | fc2: Linear, 151 | final_layer_norm: LayerNorm, 152 | activation: Activation, 153 | } 154 | 155 | impl DecoderLayer { 156 | fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result { 157 | let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads); 158 | let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads); 159 | 160 | let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?; 161 | let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?; 162 | let self_attn_layer_norm = 163 | layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?; 164 | let encoder_attn_layer_norm = 165 | layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?; 166 | let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?; 167 | let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?; 168 | let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?; 169 | Ok(Self { 170 | self_attn, 171 | self_attn_layer_norm, 172 | encoder_attn, 173 | encoder_attn_layer_norm, 174 | fc1, 175 | fc2, 176 | final_layer_norm, 177 | activation: cfg.activation_function, 178 | }) 179 | } 180 | 181 | fn forward( 182 | &mut self, 183 | xs: &Tensor, 184 | attention_mask: Option<&Tensor>, 185 | encoder_xs: &Tensor, 186 | encoder_attention_mask: Option<&Tensor>, 187 | ) -> Result { 188 | // Self attention 189 | let residual = xs; 190 | let xs = xs.apply(&self.self_attn_layer_norm)?; 191 | let xs = self.self_attn.forward(&xs, None, attention_mask)?; 192 | let xs = (residual + xs)?; 193 | 194 | // Cross attention 195 | 
let residual = &xs; 196 | let xs = xs.apply(&self.encoder_attn_layer_norm)?; 197 | let xs = self 198 | .encoder_attn 199 | .forward(&xs, Some(encoder_xs), encoder_attention_mask)?; 200 | let xs = (residual + xs)?; 201 | 202 | // Fully connected 203 | let residual = &xs; 204 | let xs = xs 205 | .apply(&self.final_layer_norm)? 206 | .apply(&self.fc1)? 207 | .apply(&self.activation)? 208 | .apply(&self.fc2)?; 209 | residual + xs 210 | } 211 | 212 | fn clear_kv_cache(&mut self) { 213 | self.self_attn.clear_kv_cache(); 214 | self.encoder_attn.clear_kv_cache(); 215 | } 216 | } 217 | 218 | #[derive(Debug, Clone)] 219 | pub struct Decoder { 220 | embed_tokens: Vec, 221 | embed_positions: Tensor, 222 | layers: Vec, 223 | layer_norm: LayerNorm, 224 | num_codebooks: usize, 225 | hidden_size: usize, 226 | lm_heads: Vec, 227 | dtype: DType, 228 | } 229 | 230 | impl Decoder { 231 | pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result { 232 | let vb_d = vb.pp("model.decoder"); 233 | let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks); 234 | let vb_e = vb_d.pp("embed_tokens"); 235 | for embed_idx in 0..cfg.num_codebooks { 236 | let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?; 237 | embed_tokens.push(e) 238 | } 239 | let embed_positions = vb_d.get( 240 | (cfg.max_position_embeddings, cfg.hidden_size), 241 | "embed_positions.weights", 242 | )?; 243 | let mut layers = Vec::with_capacity(cfg.num_hidden_layers); 244 | let vb_l = vb_d.pp("layers"); 245 | for layer_idx in 0..cfg.num_hidden_layers { 246 | let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?; 247 | layers.push(layer) 248 | } 249 | let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?; 250 | 251 | let mut lm_heads = Vec::with_capacity(cfg.num_codebooks); 252 | let vb_l = vb.pp("lm_heads"); 253 | for lm_idx in 0..cfg.num_codebooks { 254 | let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?; 255 | lm_heads.push(lm_head) 256 
| } 257 | Ok(Self { 258 | embed_tokens, 259 | embed_positions, 260 | layers, 261 | layer_norm, 262 | num_codebooks: cfg.num_codebooks, 263 | lm_heads, 264 | hidden_size: cfg.hidden_size, 265 | dtype: vb.dtype(), 266 | }) 267 | } 268 | 269 | pub fn forward( 270 | &mut self, 271 | input_ids: &Tensor, 272 | prompt_hidden_states: Option<&Tensor>, 273 | attention_mask: Option<&Tensor>, 274 | encoder_xs: &Tensor, 275 | encoder_attention_mask: Option<&Tensor>, 276 | seqlen_offset: usize, 277 | ) -> Result> { 278 | let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?; 279 | let mut inputs_embeds = Tensor::zeros( 280 | (b_sz, seq_len, self.hidden_size), 281 | self.dtype, 282 | input_ids.device(), 283 | )?; 284 | for (idx, embs) in self.embed_tokens.iter().enumerate() { 285 | let e = input_ids.i((.., idx))?.apply(embs)?; 286 | inputs_embeds = (inputs_embeds + e)? 287 | } 288 | let inputs_embeds = match prompt_hidden_states { 289 | None => inputs_embeds, 290 | Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?, 291 | }; 292 | let embed_positions = self 293 | .embed_positions 294 | .i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?; 295 | let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?; 296 | for layer in self.layers.iter_mut() { 297 | xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?; 298 | } 299 | let xs = xs.apply(&self.layer_norm)?; 300 | let mut lm_logits = Vec::with_capacity(self.num_codebooks); 301 | for lm_head in self.lm_heads.iter() { 302 | let logits = xs.apply(lm_head)?; 303 | lm_logits.push(logits) 304 | } 305 | Ok(lm_logits) 306 | } 307 | 308 | pub fn clear_kv_cache(&mut self) { 309 | for layer in self.layers.iter_mut() { 310 | layer.clear_kv_cache() 311 | } 312 | } 313 | } 314 | 315 | #[derive(Debug, Clone)] 316 | pub struct Model { 317 | pub embed_prompts: candle_nn::Embedding, 318 | pub enc_to_dec_proj: Option, 319 | pub decoder: Decoder, 320 | pub text_encoder: t5::T5EncoderModel, 321 | pub 
decoder_start_token_id: u32, 322 | pub pad_token_id: u32, 323 | pub audio_encoder: dac::Model, 324 | } 325 | 326 | impl Model { 327 | pub fn new(cfg: &Config, vb: VarBuilder) -> Result { 328 | let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?; 329 | let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?; 330 | let embed_prompts = candle_nn::embedding( 331 | cfg.vocab_size, 332 | cfg.decoder.hidden_size, 333 | vb.pp("embed_prompts"), 334 | )?; 335 | let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size { 336 | let proj = linear( 337 | cfg.text_encoder.d_model, 338 | cfg.decoder.hidden_size, 339 | true, 340 | vb.pp("enc_to_dec_proj"), 341 | )?; 342 | Some(proj) 343 | } else { 344 | None 345 | }; 346 | let audio_encoder = dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?; 347 | Ok(Self { 348 | decoder, 349 | text_encoder, 350 | embed_prompts, 351 | enc_to_dec_proj, 352 | decoder_start_token_id: cfg.decoder_start_token_id, 353 | pad_token_id: cfg.pad_token_id, 354 | audio_encoder, 355 | }) 356 | } 357 | 358 | /// Note that the returned tensor uses the CPU device. 
    /// Autoregressive sampling loop: starting from the BOS token in every
    /// codebook, repeatedly runs the decoder and samples one token per
    /// codebook with `lp`, for at most `max_steps` steps.
    pub fn generate(
        &mut self,
        prompt_tokens: &Tensor,
        description_tokens: &Tensor,
        mut lp: LogitsProcessor,
        max_steps: usize,
    ) -> Result<Tensor> {
        // Fresh generation: drop any KV cache left from a previous call.
        self.decoder.clear_kv_cache();
        self.text_encoder.clear_kv_cache();
        // Encode the voice description with the T5 encoder, projecting to the
        // decoder width when the two dims differ (enc_to_dec_proj is Some).
        let encoded = self.text_encoder.forward(description_tokens)?;
        let encoded = match self.enc_to_dec_proj.as_ref() {
            None => encoded,
            Some(proj) => encoded.apply(proj)?,
        };
        let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
        let num_codebooks = self.decoder.num_codebooks;
        // Current token per codebook (all start at the decoder BOS token),
        // plus the accumulated per-codebook output streams.
        let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
        let mut all_audio_tokens = vec![vec![]; num_codebooks];
        let prompt_len = prompt_hidden_states.dim(1)?;
        for step in 0..max_steps {
            let input_ids = Tensor::from_slice(
                audio_tokens.as_slice(),
                (1, num_codebooks, 1),
                prompt_tokens.device(),
            )?;
            // The prompt embeddings are only fed on the first step; afterwards
            // they live in the decoder KV cache and `pos` accounts for their
            // length when indexing position embeddings.
            let (prompt_hidden_states, pos) = if step == 0 {
                (Some(&prompt_hidden_states), 0)
            } else {
                (None, step + prompt_len)
            };
            let causal_mask = if pos == 0 {
                self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
            } else {
                self.prepare_causal_mask(1, pos + 1, input_ids.device())?
            };
            let logits = self.decoder.forward(
                &input_ids,
                prompt_hidden_states,
                Some(&causal_mask),
                &encoded,
                None,
                pos,
            )?;
            for (logit_idx, logit) in logits.iter().enumerate() {
                // Codebook k only starts sampling at step k (delay pattern).
                if logit_idx > step {
                    break;
                }
                // A codebook that has already emitted PAD is finished.
                if audio_tokens[logit_idx] != self.pad_token_id {
                    let logit = logit.i((0, logit.dim(1)? - 1))?;
                    let token = lp.sample(&logit)?;
                    audio_tokens[logit_idx] = token
                }
            }
            // Stop early once every codebook has padded out.
            if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
                break;
            }
            for (cb_idx, &token) in audio_tokens.iter().enumerate() {
                if token != self.decoder_start_token_id && token != self.pad_token_id {
                    all_audio_tokens[cb_idx].push(token)
                }
            }
        }

        // Truncate every codebook to the shortest stream so the nested vecs
        // form a rectangular tensor.
        let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
        all_audio_tokens.iter_mut().for_each(|v| {
            v.resize(min_len, 0);
        });
        let all_audio_tokens = Tensor::new(all_audio_tokens, &Device::Cpu)?;
        Ok(all_audio_tokens)
    }

    /// Builds an additive attention mask of shape (q_len, kv_len):
    /// 0.0 where attending is allowed, -inf where it is masked.
    fn prepare_causal_mask(&self, q_len: usize, kv_len: usize, device: &Device) -> Result<Tensor> {
        let mask: Vec<_> = (0..q_len)
            .flat_map(|i| {
                (0..kv_len).map(move |j| {
                    if i + kv_len < j + q_len {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (q_len, kv_len), device)
    }
}
--------------------------------------------------------------------------------
/src/models/tts_rs.rs:
--------------------------------------------------------------------------------
use super::{NaturalModelTrait, SynthesizedAudio};
use crate::TtsError;
use std::error::Error;
use tts::Tts;

/// Thin wrapper around the platform speech synthesizer from the `tts` crate.
#[derive(Clone)]
pub struct TtsModel(pub Tts);

impl TtsModel {
    /// Creates a model backed by the system's default TTS engine.
    pub fn new() -> Result<Self, Box<dyn Error>> {
        let def = Tts::default()?;
        return Ok(Self(def));
    }
}

impl Default for TtsModel {
    fn default() -> Self {
        // NOTE(review): panics when no system TTS backend is available.
        return Self::new().unwrap();
    }
}

impl NaturalModelTrait for TtsModel {
    type SynthesizeType = f32;

    /// Saving to a file is not supported by the system-TTS backend.
    fn save(&mut self, message: String, path: String) -> Result<(), Box<dyn Error>> {
        Err(TtsError::NotSupported.into())
    }

    fn say(&mut self, message: String) -> Result<(), Box<dyn Error>> {
        // If the engine is already speaking, silently skip this utterance.
        let is_speaking = self.0.is_speaking();

        if let Ok(speaking) = is_speaking {
            if speaking {
                return Ok(());
            }
        }

        // Speak errors are intentionally ignored (best effort).
        let _ = self.0.speak(message, false);
        Ok(())
    }

    /// Raw-sample synthesis is not supported by the system-TTS backend.
    fn synthesize(
        &mut self,
        message: String,
    ) -> Result<SynthesizedAudio<Self::SynthesizeType>, Box<dyn Error>> {
        Err(TtsError::NotSupported.into())
    }
}
--------------------------------------------------------------------------------
/src/test.rs:
--------------------------------------------------------------------------------
// Smoke tests: each builds a NaturalTts with a single backend enabled and
// speaks a short phrase. They exercise the real audio stack, so they need a
// working audio device (and, for some backends, network access).
#[cfg(test)]
use crate::{
    models::{
        gtts::GttsModel, meta::MetaModel, msedge::MSEdgeModel, parler::ParlerModel,
        tts_rs::TtsModel,
    },
    *,
};

#[cfg(feature = "gtts")]
#[test]
fn gtts_test() {
    let mut natural = NaturalTtsBuilder::default()
        .gtts_model(GttsModel::default())
        .default_model(Model::Gtts)
        .build()
        .unwrap();
    let _ = natural.say("Hello, World!".to_string());
}

#[cfg(feature = "parler")]
#[test]
fn parler_test() {
    let mut natural = NaturalTtsBuilder::default()
        .parler_model(ParlerModel::default())
        .default_model(Model::Parler)
        .build()
        .unwrap();
    let _ = natural.say("Hello, World!".to_string());
}

#[cfg(feature = "msedge")]
#[test]
fn msedge_test() {
    let mut natural = NaturalTtsBuilder::default()
        .msedge_model(MSEdgeModel::default())
        .default_model(Model::MSEdge)
        .build()
        .unwrap();
    let _ = natural.say("Hello, World!".to_string());
}

#[cfg(feature = "tts-rs")]
#[test]
fn tts_test() {
    let mut natural = NaturalTtsBuilder::default()
        .tts_model(TtsModel::default())
        .default_model(Model::TTS)
        .build()
        .unwrap();
    let _ = natural.say("Hello, World!".to_string());
}

#[cfg(feature = "meta")]
#[test]
fn meta_test() {
    let mut natural = NaturalTtsBuilder::default()
        .meta_model(MetaModel::default())
        .default_model(Model::Meta)
        .build()
        .unwrap();
    let _ = natural.say("Hello, World!".to_string());
}

#[cfg(feature = "coqui")]
#[test]
fn coqui_test() {
    // NOTE(review): CoquiModel is not in the cfg(test) import list above;
    // presumably it reaches scope via the `crate::*` glob — verify.
    let mut natural = NaturalTtsBuilder::default()
        .coqui_model(CoquiModel::default())
        .default_model(Model::Coqui)
        .build()
        .unwrap();
    let _ = natural.say("Hello, World!".to_string());
}
--------------------------------------------------------------------------------
/src/utils.rs:
--------------------------------------------------------------------------------
// Copyright (c) 2024-2025 natural-tts
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
20 | use crate::models::{Spec::Wav, SynthesizedAudio}; 21 | use hound::WavReader; 22 | use rodio::{buffer::SamplesBuffer, cpal::FromSample, Decoder, OutputStream, Sink}; 23 | use std::{error::Error, io::Write}; 24 | 25 | pub fn read_wav_file(path: &str) -> Result, Box> { 26 | let mut reader = WavReader::open(path)?; 27 | let mut f32_samples: Vec = Vec::new(); 28 | 29 | reader.samples::().for_each(|s| { 30 | if let Ok(sample) = s { 31 | f32_samples.push(sample as f32); 32 | } 33 | }); 34 | 35 | Ok(SynthesizedAudio::new( 36 | f32_samples, 37 | Wav(reader.spec()), 38 | Some(reader.duration() as i32), 39 | )) 40 | } 41 | 42 | pub fn get_path(path: String) -> String { 43 | let mut new_path = env!("CARGO_MANIFEST_DIR").to_string(); 44 | new_path.push_str(&format!("/src/{}", path)); 45 | return new_path; 46 | } 47 | 48 | pub fn play_audio(data: Vec, rate: u32) 49 | where 50 | T: rodio::Sample + Send + 'static, 51 | f32: FromSample, 52 | { 53 | let (_stream, handle) = rodio::OutputStream::try_default().unwrap(); 54 | let source = SamplesBuffer::new(1, rate, data); 55 | let sink = rodio::Sink::try_new(&handle).unwrap(); 56 | 57 | sink.append(source); 58 | 59 | sink.sleep_until_end(); 60 | } 61 | 62 | pub fn play_wav_file(path: &str) -> Result<(), Box> { 63 | let file = std::fs::File::open(path)?; 64 | let decoder = Decoder::new(file)?; 65 | let (_stream, stream_handle) = OutputStream::try_default()?; 66 | let sink = Sink::try_new(&stream_handle)?; 67 | 68 | sink.append(decoder); 69 | sink.sleep_until_end(); 70 | 71 | Ok(()) 72 | } 73 | 74 | pub fn save_wav(data: &[f32], filename: &str, sample_rate: u32) -> Result<(), std::io::Error> { 75 | let mut file = std::fs::File::create(filename)?; 76 | 77 | // Write WAV header 78 | let (chunk_size, bits_per_sample) = (44 + data.len() * 4, 32); 79 | let pcm = 1; // PCM format 80 | 81 | let header = [ 82 | // RIFF chunk 83 | b'R', 84 | b'I', 85 | b'F', 86 | b'F', 87 | (chunk_size & 0xff) as u8, 88 | (chunk_size >> 8 & 0xff) as u8, 89 | 
(chunk_size >> 16 & 0xff) as u8, 90 | (chunk_size >> 24 & 0xff) as u8, 91 | // WAVE chunk 92 | b'W', 93 | b'A', 94 | b'V', 95 | b'E', 96 | // fmt subchunk 97 | b'f', 98 | b'm', 99 | b't', 100 | b' ', 101 | 16, 102 | 0, 103 | 0, 104 | 0, // Subchunk size (16 for PCM format) 105 | pcm as u8, 106 | 0, // PCM format 107 | 1, 108 | 0, // Mono channel 109 | (sample_rate & 0xff) as u8, 110 | (sample_rate >> 8 & 0xff) as u8, 111 | (sample_rate >> 16 & 0xff) as u8, 112 | (sample_rate >> 24 & 0xff) as u8, 113 | 4, 114 | 0, // Average bytes per second (4 for 32-bit samples) 115 | 4, 116 | 0, // Block align (4 for 32-bit samples, mono) 117 | bits_per_sample as u8, 118 | 0, // Bits per sample 119 | // data subchunk 120 | b'd', 121 | b'a', 122 | b't', 123 | b'a', 124 | (data.len() * 4) as u8, 125 | (data.len() * 4 >> 8 & 0xff) as u8, 126 | (data.len() * 4 >> 16 & 0xff) as u8, 127 | (data.len() * 4 >> 24 & 0xff) as u8, 128 | ]; 129 | file.write_all(&header)?; 130 | 131 | // Write audio data (assuming f32 samples between -1.0 and 1.0) 132 | for sample in data { 133 | let i32_sample = (sample * 2147483647.0) as i32; // Convert to 32-bit signed integer 134 | let bytes = i32_sample.to_le_bytes(); 135 | file.write_all(&bytes)?; 136 | } 137 | 138 | Ok(()) 139 | } 140 | --------------------------------------------------------------------------------