├── kokoros ├── src │ ├── onn │ │ ├── ort_yolo.rs │ │ ├── mod.rs │ │ ├── ort_base.rs │ │ └── ort_koko.rs │ ├── lib.rs │ ├── utils │ │ ├── mod.rs │ │ ├── wav.rs │ │ ├── fileio.rs │ │ ├── debug.rs │ │ ├── mp3.rs │ │ └── opus.rs │ └── tts │ │ ├── mod.rs │ │ ├── vocab.rs │ │ ├── tokenize.rs │ │ ├── phonemizer.rs │ │ ├── normalize.rs │ │ └── koko.rs └── Cargo.toml ├── data └── .gitignore ├── checkpoints └── .gitignore ├── download_all.sh ├── Cargo.toml ├── scripts ├── requirements.txt ├── download_voices.sh ├── download_models.sh └── run_openai.py ├── .gitignore ├── koko ├── Cargo.toml └── src │ └── main.rs ├── kokoros-openai ├── Cargo.toml └── src │ └── lib.rs ├── Dockerfile ├── .github └── workflows │ └── docker-ghcr.yml ├── install.sh └── README.md /kokoros/src/onn/ort_yolo.rs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | *.bin -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | *.onnx 2 | !download_models.sh 3 | -------------------------------------------------------------------------------- /kokoros/src/onn/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod ort_base; 2 | pub mod ort_koko; 3 | -------------------------------------------------------------------------------- /kokoros/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod onn; 2 | pub mod tts; 3 | pub mod utils; 4 | -------------------------------------------------------------------------------- /download_all.sh: -------------------------------------------------------------------------------- 1 | bash scripts/download_models.sh 2 | bash scripts/download_voices.sh -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["koko", "kokoros", "kokoros-openai"] 3 | resolver = "2" 4 | -------------------------------------------------------------------------------- /scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.24.0 2 | openai>=1.61.1 3 | torch>=2.0.0 4 | requests>=2.31.0 5 | -------------------------------------------------------------------------------- /kokoros/src/utils/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod debug; 2 | pub mod fileio; 3 | pub mod mp3; 4 | pub mod opus; 5 | pub mod wav; 6 | -------------------------------------------------------------------------------- /kokoros/src/tts/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod koko; 2 | pub mod normalize; 3 | pub mod phonemizer; 4 | pub mod tokenize; 5 | pub mod vocab; 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | vendor/ 3 | *.wav 4 | *.DS_store 5 | .venv 6 | venv/ 7 | .vscode/ 8 | .worktrees/ 9 | .idea/ 10 | tmp/ -------------------------------------------------------------------------------- /scripts/download_voices.sh: 
-------------------------------------------------------------------------------- 1 | mkdir data 2 | wget "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin" -O "data/voices-v1.0.bin" -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | mkdir checkpoints 2 | wget https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/kokoro-v1.0.onnx -O "checkpoints/kokoro-v1.0.onnx" -------------------------------------------------------------------------------- /koko/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "koko" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | kokoros = { path = "../kokoros" } 8 | kokoros-openai = { path = "../kokoros-openai" } 9 | 10 | clap = { version = "4.5.39", features = ["derive"] } 11 | tokio = { version = "1.45.1", features = ["io-util", "rt-multi-thread"] } 12 | tracing = "0.1" 13 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 14 | -------------------------------------------------------------------------------- /scripts/run_openai.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | base_url = "http://localhost:3000/v1" 4 | 5 | client = OpenAI(base_url=base_url, api_key="sfrhg453656") 6 | 7 | speech_file_path = "tmp/speech.wav" 8 | response = client.audio.speech.create( 9 | model="anything can go here", 10 | voice="am_michael", # or voice=NotGiven(), (`from openai import NotGiven`) to use the server default 11 | input="Today is a wonderful day to build something people love!", 12 | ) 13 | response.write_to_file(speech_file_path) 14 | -------------------------------------------------------------------------------- /kokoros-openai/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kokoros-openai" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | kokoros = { path = "../kokoros" } 8 | 9 | axum = { version = "0.8.4", features = ["http2"] } 10 | futures = "0.3" 11 | serde = { version = "1.0.219", features = ["derive"] } 12 | serde_json = "1.0" 13 | tokio = { version = "1.0", features = ["full"] } 14 | tokio-stream = "0.1" 15 | tower-http = { version = "0.6.6", features = ["cors", "trace"] } 16 | tracing = "0.1" 17 | uuid = { version = "1.0", features = ["v4"] } 18 | regex = "1.0" 19 | -------------------------------------------------------------------------------- /kokoros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kokoros" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | espeak-rs = "0.1.9" 8 | hound = "3.5.1" 9 | indicatif = "0.17.11" 10 | ndarray = "0.16.1" 11 | lazy_static = "1.5.0" 12 | regex = "1.11.1" 13 | reqwest = { version = "0.12.19" } 14 | serde_json = "1.0.140" 15 | tokio = { version = "1.45.1", features = ["fs", "io-util"] } 16 | ndarray-npy = "0.9.1" 17 | mp3lame-encoder = "0.2.1" 18 | tracing = "0.1" 19 | uuid = { version = "1.0", features = ["v4"] } 20 | opus = "0.3" 21 | ogg = "0.9" 22 | 23 | # Base ONNX Runtime configuration 24 | ort = { version = "2.0.0-rc.10", default-features = true } 25 | 26 | [features] 27 | default = ["cpu"] 28 | cpu = [] 29 | cuda = ["ort/cuda"] 30 | 
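The `[features]` table of `kokoros/Cargo.toml` above selects the ONNX Runtime execution provider at compile time: `cpu` (the default) pulls in no extra dependencies, while `cuda` enables `ort/cuda`. A minimal sketch of selecting each backend from the workspace root; the `-p koko --features kokoros/cuda` form forwards the dependency feature through the `koko` binary, and is an assumption about how you want to build, not a documented project command:

```bash
# CPU build (default feature)
cargo build --release

# CUDA build: enable the kokoros crate's `cuda` feature when building koko
cargo build --release -p koko --features kokoros/cuda
```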
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM rust:1.86.0-slim-bookworm AS builderrs 3 | 4 | RUN apt-get update -qq && apt-get install -qq -y wget pkg-config libssl-dev clang git cmake && rustup component add rustfmt 5 | 6 | WORKDIR /app 7 | 8 | COPY . . 9 | COPY Cargo.toml . 10 | COPY Cargo.lock . 11 | 12 | RUN chmod +x ./download_all.sh && ./download_all.sh 13 | 14 | RUN cargo build --release 15 | 16 | FROM debian:sid-slim AS runner 17 | 18 | WORKDIR /app 19 | 20 | COPY --from=builderrs /app/target/release/build ./target/release/build 21 | COPY --from=builderrs /app/target/release/koko ./target/release/koko 22 | COPY --from=builderrs /app/data ./data 23 | COPY --from=builderrs /app/checkpoints ./checkpoints 24 | 25 | RUN chmod +x ./target/release/koko && apt-get update -qq && apt-get install -qq -y pkg-config libssl-dev 26 | 27 | EXPOSE 3000 28 | 29 | ENTRYPOINT [ "./target/release/koko" ] 30 | -------------------------------------------------------------------------------- /.github/workflows/docker-ghcr.yml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker Image to GHCR 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | workflow_dispatch: 7 | 8 | env: 9 | IMAGE_NAME: kokoros 10 | 11 | jobs: 12 | build-and-push: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | packages: write 17 | id-token: write 18 | 19 | steps: 20 | - name: Checkout repository 21 | uses: actions/checkout@v4 22 | 23 | - name: Log in to GitHub Container Registry 24 | uses: docker/login-action@v3 25 | with: 26 | registry: ghcr.io 27 | username: ${{ github.actor }} 28 | password: ${{ secrets.GITHUB_TOKEN }} 29 | 30 | - name: Extract metadata (tags, labels) 31 | id: meta 32 | uses: docker/metadata-action@v5 33 | with: 34 | images: ghcr.io/${{ github.repository_owner }}/kokoros 35 | 36 | - name: Build and push Docker image 37 | uses: docker/build-push-action@v5 38 | with: 39 | context: . 40 | file: Dockerfile 41 | push: true 42 | tags: ${{ steps.meta.outputs.tags }} 43 | labels: ${{ steps.meta.outputs.labels }} 44 | -------------------------------------------------------------------------------- /kokoros/src/tts/vocab.rs: -------------------------------------------------------------------------------- 1 | use lazy_static::lazy_static; 2 | use std::collections::HashMap; 3 | 4 | pub fn get_vocab() -> std::collections::HashMap<char, usize> { 5 | let pad = "$"; 6 | let punctuation = ";:,.!?¡¿—…\"«»“” "; 7 | let letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; 8 | let letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"; 9 | 10 | let symbols: String = [pad, punctuation, letters, letters_ipa].concat(); 11 | 12 | symbols 13 | .chars() 14 | .enumerate() 15 | .collect::<Vec<_>>() 16 | .into_iter() 17 | .map(|(idx, c)| (c, idx)) 18 | .collect() 19 | } 20 | 21 | pub fn get_reverse_vocab() -> HashMap<usize, char> { 22 | VOCAB.iter().map(|(&c, &idx)| (idx, c)).collect() 23 | } 24 | 25 | #[allow(dead_code)] 26 | pub fn print_sorted_reverse_vocab() { 27 | let mut sorted_keys: Vec<_> = REVERSE_VOCAB.keys().collect(); 28 | sorted_keys.sort(); 29 | 30 | for key in sorted_keys { 31 | eprintln!("{}: {}", key, REVERSE_VOCAB[key]); 32 | } 33 | } 34 | 35 | lazy_static!
{ 36 | pub static ref VOCAB: HashMap<char, usize> = get_vocab(); 37 | pub static ref REVERSE_VOCAB: HashMap<usize, char> = get_reverse_vocab(); 38 | } 39 | -------------------------------------------------------------------------------- /kokoros/src/utils/wav.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Write}; 2 | 3 | pub struct WavHeader { 4 | pub channels: u16, 5 | pub sample_rate: u32, 6 | pub bits_per_sample: u16, 7 | } 8 | 9 | impl WavHeader { 10 | pub fn new(channels: u16, sample_rate: u32, bits_per_sample: u16) -> Self { 11 | Self { 12 | channels, 13 | sample_rate, 14 | bits_per_sample, 15 | } 16 | } 17 | 18 | pub fn write_header<W: Write>(&self, writer: &mut W) -> io::Result<()> { 19 | // RIFF header 20 | writer.write_all(b"RIFF")?; 21 | writer.write_all(&[0xFF, 0xFF, 0xFF, 0xFF])?; // File size - 8 (placeholder) 22 | writer.write_all(b"WAVE")?; 23 | 24 | // Format chunk 25 | writer.write_all(b"fmt ")?; 26 | writer.write_all(&(16u32).to_le_bytes())?; // Format chunk size 27 | writer.write_all(&(3u16).to_le_bytes())?; // Format = 3 (IEEE float) 28 | writer.write_all(&self.channels.to_le_bytes())?; 29 | writer.write_all(&self.sample_rate.to_le_bytes())?; 30 | let byte_rate = 31 | self.sample_rate * u32::from(self.channels) * u32::from(self.bits_per_sample) / 8; 32 | writer.write_all(&byte_rate.to_le_bytes())?; 33 | let block_align = self.channels * self.bits_per_sample / 8; 34 | writer.write_all(&block_align.to_le_bytes())?; 35 | writer.write_all(&self.bits_per_sample.to_le_bytes())?; 36 | 37 | // Data chunk header 38 | writer.write_all(b"data")?; 39 | writer.write_all(&[0xFF, 0xFF, 0xFF, 0xFF])?; // Data size (placeholder) 40 | 41 | Ok(()) 42 | } 43 | } 44 | 45 | pub fn write_audio_chunk<W: Write>(writer: &mut W, samples: &[f32]) -> io::Result<()> { 46 | for sample in samples { 47 | writer.write_all(&sample.to_le_bytes())?; 48 | } 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /kokoros/src/utils/fileio.rs: -------------------------------------------------------------------------------- 1 | use indicatif::{ProgressBar, ProgressStyle}; 2 | use serde_json::Value; 3 | use std::{io::Read, path::Path}; 4 | use tokio::{fs::File, io::AsyncWriteExt}; 5 | 6 | pub async fn download_file_from_url( 7 | url: &str, 8 | path: &str, 9 | ) -> Result<(), Box<dyn std::error::Error>> { 10 | if let Some(parent) = Path::new(path).parent() { 11 | std::fs::create_dir_all(parent)?; 12 | } 13 | 14 | let mut resp = reqwest::get(url).await?; 15 | 16 | if resp.status().is_success() { 17 | let total_size = resp.content_length().unwrap_or(0); 18 | 19 | eprintln!("Downloading {} - total size: {}", path, total_size); 20 | 21 | let pb = ProgressBar::new(total_size); 22 | pb.set_style(ProgressStyle::default_bar() 23 | .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})") 24 | .unwrap() 25 | .progress_chars("#>-")); 26 | 27 | let mut file = File::create(path).await?; 28 | let mut downloaded = 0; 29 | 30 | while let Some(chunk) = resp.chunk().await?
{ 31 | file.write_all(&chunk).await?; 32 | downloaded += chunk.len(); 33 | pb.set_position(downloaded.try_into()?); 34 | } 35 | 36 | pb.finish_with_message("Download completed"); 37 | Ok(()) 38 | } else { 39 | Err(format!("Failed to download file: {}", resp.status()).into()) 40 | } 41 | } 42 | 43 | pub fn load_json_file(path: &str) -> Result<Value, String> { 44 | let file = std::fs::File::open(path); 45 | if file.is_err() { 46 | return Err(format!("failed to open file: {}", file.err().unwrap())); 47 | } 48 | 49 | let mut data = String::new(); 50 | file.unwrap() 51 | .read_to_string(&mut data) 52 | .map_err(|e| e.to_string())?; 53 | let json_value: Value = serde_json::from_str(&data).map_err(|e| e.to_string())?; 54 | 55 | Ok(json_value) 56 | } 57 | -------------------------------------------------------------------------------- /kokoros/src/utils/debug.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | // ANSI color codes for request ID colorization 4 | const COLORS: &[&str] = &[ 5 | "\x1b[31m", "\x1b[32m", "\x1b[33m", "\x1b[34m", "\x1b[35m", "\x1b[36m", 6 | "\x1b[91m", "\x1b[92m", "\x1b[93m", "\x1b[94m", "\x1b[95m", "\x1b[96m", 7 | "\x1b[37m", "\x1b[90m" 8 | ]; 9 | const RESET: &str = "\x1b[0m"; 10 | 11 | /// Get consistent color for a request ID using hash-based assignment 12 | pub fn get_request_id_color(request_id: &str) -> &'static str { 13 | let mut hash = 0u32; 14 | for byte in request_id.bytes() { 15 | hash = hash.wrapping_mul(31).wrapping_add(byte as u32); 16 | } 17 | let color_index = (hash as usize) % COLORS.len(); 18 | COLORS[color_index] 19 | } 20 | 21 | /// Format a debug prefix with colored request ID and instance ID 22 | pub fn format_debug_prefix(request_id: Option<&str>, instance_id: Option<&str>) -> String { 23 | match (request_id, instance_id) { 24 | (Some(req_id), Some(inst_id)) => { 25 | let color = get_request_id_color(req_id); 26 | format!("{}[{}]{}[{}]", color, req_id, RESET, inst_id) 27 | }, 28 | (Some(req_id), None) => { 29 | let color = get_request_id_color(req_id); 30 | format!("{}[{}]{}", color, req_id, RESET) 31 | }, 32 | (None, Some(inst_id)) => format!("[{}]", inst_id), 33 | (None, None) => String::new(), 34 | } 35 | } 36 | 37 | /// Get colored request ID with relative timing (enhanced version) 38 | pub fn get_colored_request_id_with_relative(request_id: &str, start_time: Instant) -> String { 39 | let color = get_request_id_color(request_id); 40 | 41 | // Get relative time from request start 42 | let elapsed_ms = start_time.elapsed().as_millis(); 43 | let relative_time = if elapsed_ms < 1 { 44 | " 0".to_string() // Show "0" right-aligned for initial request 45 | } else { 46 | format!("{:5}", elapsed_ms) // Right-aligned 5 digits 47 | }; 48 | 49 | format!("{}[{}]{} \x1b[90m{}\x1b[0m", color, request_id, RESET, relative_time) 50 | } -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check for required dependencies 4 | if [[ "$(uname)" == "Darwin" ]]; then 5 | if ! command -v brew &> /dev/null; then 6 | echo "Error: Homebrew is required for automatic dependency installation." 7 | echo "Please install Homebrew or manually install: pkg-config opus" 8 | exit 1 9 | fi 10 | 11 | if ! command -v pkg-config &> /dev/null; then 12 | echo "pkg-config not found. Installing via Homebrew..." 13 | brew install pkg-config 14 | fi 15 | 16 | if !
brew list opus &> /dev/null; then 17 | echo "opus not found. Installing via Homebrew..." 18 | brew install opus 19 | fi 20 | fi 21 | 22 | # Set variables 23 | VOICES_JSON_SRC="data/voices-v1.0.bin" 24 | VOICES_JSON_DEST="$HOME/.cache/kokoros/data.voices-v1.0.bin" 25 | KOKO_BIN_SRC="target/release/koko" 26 | KOKO_BIN_DEST="/usr/local/bin/koko" 27 | 28 | # Create the destination directory if it doesn't exist 29 | if [ ! -d "$(dirname "$VOICES_JSON_DEST")" ]; then 30 | echo "Creating directory: $(dirname "$VOICES_JSON_DEST")" 31 | mkdir -p "$(dirname "$VOICES_JSON_DEST")" 32 | fi 33 | 34 | # Copy voices to the cache directory 35 | if [ -f "$VOICES_JSON_SRC" ]; then 36 | echo "Copying $VOICES_JSON_SRC to $VOICES_JSON_DEST" 37 | cp "$VOICES_JSON_SRC" "$VOICES_JSON_DEST" 38 | else 39 | echo "Error: $VOICES_JSON_SRC not found. Aborting." 40 | exit 1 41 | fi 42 | 43 | # Copy koko binary to /usr/local/bin 44 | if [ -f "$KOKO_BIN_SRC" ]; then 45 | echo "Copying $KOKO_BIN_SRC to $KOKO_BIN_DEST" 46 | sudo cp "$KOKO_BIN_SRC" "$KOKO_BIN_DEST" 47 | else 48 | echo "$KOKO_BIN_SRC not found. Building it for you..." 49 | cargo build --release 50 | echo "Copying $KOKO_BIN_SRC to $KOKO_BIN_DEST" 51 | sudo cp "$KOKO_BIN_SRC" "$KOKO_BIN_DEST" 52 | fi 53 | 54 | # Provide user feedback 55 | if [ $? -eq 0 ]; then 56 | echo "Installation completed successfully!" 57 | echo "Voices configuration: $VOICES_JSON_DEST" 58 | echo "Executable installed at: $KOKO_BIN_DEST" 59 | echo '🎉 now try in terminal: koko ' 60 | else 61 | echo "Installation encountered an error. Please check the messages above." 62 | exit 1 63 | fi 64 | -------------------------------------------------------------------------------- /kokoros/src/onn/ort_base.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "cuda")] 2 | use ort::execution_providers::cuda::CUDAExecutionProvider; 3 | use ort::execution_providers::cpu::CPUExecutionProvider; 4 | use ort::session::builder::SessionBuilder; 5 | use ort::session::Session; 6 | use ort::logging::LogLevel; 7 | 8 | pub trait OrtBase { 9 | fn load_model(&mut self, model_path: String) -> Result<(), String> { 10 | #[cfg(feature = "cuda")] 11 | let providers = [CUDAExecutionProvider::default().build()]; 12 | 13 | #[cfg(not(feature = "cuda"))] 14 | let providers = [CPUExecutionProvider::default().build()]; 15 | 16 | match SessionBuilder::new() { 17 | Ok(builder) => { 18 | let session = builder 19 | .with_execution_providers(providers) 20 | .map_err(|e| format!("Failed to build session: {}", e))? 21 | .with_log_level(LogLevel::Warning) 22 | .map_err(|e| format!("Failed to set log level: {}", e))?
23 | .commit_from_file(model_path) 24 | .map_err(|e| format!("Failed to commit from file: {}", e))?; 25 | self.set_sess(session); 26 | Ok(()) 27 | } 28 | Err(e) => Err(format!("Failed to create session builder: {}", e)), 29 | } 30 | } 31 | 32 | fn print_info(&self) { 33 | if let Some(session) = self.sess() { 34 | eprintln!("Input names:"); 35 | for input in &session.inputs { 36 | eprintln!(" - {}", input.name); 37 | } 38 | eprintln!("Output names:"); 39 | for output in &session.outputs { 40 | eprintln!(" - {}", output.name); 41 | } 42 | 43 | #[cfg(feature = "cuda")] 44 | eprintln!("Configured with: CUDA execution provider"); 45 | 46 | #[cfg(not(feature = "cuda"))] 47 | eprintln!("Configured with: CPU execution provider"); 48 | } else { 49 | eprintln!("Session is not initialized."); 50 | } 51 | } 52 | 53 | fn set_sess(&mut self, sess: Session); 54 | fn sess(&self) -> Option<&Session>; 55 | } 56 | -------------------------------------------------------------------------------- /kokoros/src/tts/tokenize.rs: -------------------------------------------------------------------------------- 1 | use crate::tts::vocab::VOCAB; 2 | 3 | /// Tokenizes the given phonemes string into a vector of token indices. 4 | /// 5 | /// This function takes a phonemes string as input and converts it into a vector of token indices 6 | /// by looking up each character in the global `VOCAB` map and mapping it to the corresponding 7 | /// token index. The resulting vector contains the token indices for the input string. 8 | /// 9 | /// # Arguments 10 | /// * `phonemes` - The input phonemes string to be tokenized. 11 | /// 12 | /// # Returns 13 | /// A vector of `i64` token indices representing the input string. 14 | pub fn tokenize(phonemes: &str) -> Vec<i64> { 15 | phonemes 16 | .chars() 17 | .filter_map(|c| VOCAB.get(&c)) 18 | .map(|&idx| idx as i64) 19 | .collect() 20 | } 21 | 22 | #[cfg(test)] 23 | mod tests { 24 | use super::*; 25 | 26 | #[test] 27 | fn test_tokenize() { 28 | let text = "heɪ ðɪs ɪz ˈlʌvliː!"; 29 | let tokens = tokenize(text); 30 | 31 | // Expected tokens based on the vocabulary mapping defined in get_vocab() 32 | let expected = vec![24, 47, 54, 54, 57, 5]; 33 | 34 | assert_eq!(tokens, expected); 35 | 36 | // Test empty string 37 | let empty = ""; 38 | let empty_tokens = tokenize(empty); 39 | assert!(empty_tokens.is_empty()); 40 | 41 | // Test punctuation 42 | let punct = "..."; 43 | let punct_tokens = tokenize(punct); 44 | assert_eq!(punct_tokens.len(), 3); 45 | } 46 | } 47 | 48 | use crate::tts::vocab::REVERSE_VOCAB; 49 | 50 | pub fn tokens_to_phonemes(tokens: &[i64]) -> String { 51 | tokens 52 | .iter() 53 | .filter_map(|&t| REVERSE_VOCAB.get(&(t as usize))) 54 | .collect() 55 | } 56 | 57 | #[cfg(test)] 58 | mod tests2 { 59 | use super::*; 60 | 61 | #[test] 62 | fn test_tokens_to_phonemes() { 63 | let tokens = vec![24, 47, 54, 54, 57, 5]; 64 | let text = tokens_to_phonemes(&tokens); 65 | assert_eq!(text, "Hello!"); 66 | 67 | let tokens = vec![ 68 | 0, 50, 83, 54, 156, 57, 135, 3, 16, 65, 156, 87, 158, 54, 46, 5, 0, 69 | ]; 70 | 71 | let text = tokens_to_phonemes(&tokens); 72 | assert_eq!(text, "$həlˈoʊ, wˈɜːld!$"); 73 | 74 | // Test empty vector 75 | let empty_tokens: Vec<i64> = vec![]; 76 | assert_eq!(tokens_to_phonemes(&empty_tokens), ""); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /kokoros/src/utils/mp3.rs: -------------------------------------------------------------------------------- 1 | use mp3lame_encoder::{Builder, FlushNoGap, Id3Tag, MonoPcm}; 2 | 3 | pub
fn pcm_to_mp3(pcm_data: &[f32], sample_rate: u32) -> Result<Vec<u8>, std::io::Error> { 4 | let mut mp3_encoder = Builder::new().ok_or(std::io::Error::new( 5 | std::io::ErrorKind::Other, 6 | format!("Encoder init failed"), 7 | ))?; 8 | 9 | mp3_encoder.set_num_channels(1).map_err(|e| { 10 | std::io::Error::new( 11 | std::io::ErrorKind::Other, 12 | format!("Set channels failed: {:?}", e), 13 | ) 14 | })?; 15 | mp3_encoder.set_sample_rate(sample_rate).map_err(|e| { 16 | std::io::Error::new( 17 | std::io::ErrorKind::Other, 18 | format!("Set sample rate failed: {:?}", e), 19 | ) 20 | })?; 21 | mp3_encoder 22 | .set_brate(mp3lame_encoder::Bitrate::Kbps192) 23 | .map_err(|e| { 24 | std::io::Error::new( 25 | std::io::ErrorKind::Other, 26 | format!("Set bitrate failed: {:?}", e), 27 | ) 28 | })?; 29 | mp3_encoder 30 | .set_quality(mp3lame_encoder::Quality::Best) 31 | .map_err(|e| { 32 | std::io::Error::new( 33 | std::io::ErrorKind::Other, 34 | format!("Set quality failed: {:?}", e), 35 | ) 36 | })?; 37 | 38 | let _ = mp3_encoder.set_id3_tag(Id3Tag { 39 | title: b"Generated Audio", 40 | artist: b"TTS Model", 41 | album: b"Synthesized Speech", 42 | year: b"Current year", 43 | album_art: &[], 44 | comment: b"Generated by TTS", 45 | }); 46 | 47 | let mut mp3_encoder = mp3_encoder.build().map_err(|e| { 48 | std::io::Error::new( 49 | std::io::ErrorKind::Other, 50 | format!("Build encoder failed: {:?}", e), 51 | ) 52 | })?; 53 | 54 | let pcm_i16: Vec<i16> = pcm_data 55 | .iter() 56 | .map(|&x| (x * i16::MAX as f32) as i16) 57 | .collect(); 58 | let pcm = MonoPcm(&pcm_i16); 59 | 60 | let mut mp3_out_buffer = Vec::new(); 61 | mp3_out_buffer.reserve(mp3lame_encoder::max_required_buffer_size(pcm.0.len())); 62 | 63 | let encoded_size = mp3_encoder 64 | .encode(pcm, mp3_out_buffer.spare_capacity_mut()) 65 | .map_err(|e| { 66 | std::io::Error::new( 67 | std::io::ErrorKind::Other, 68 | format!("Encoding failed: {:?}", e), 69 | ) 70 | })?; 71 | 72 | unsafe { 73 | mp3_out_buffer.set_len(mp3_out_buffer.len().wrapping_add(encoded_size)); 74 | } 75 | 76 | let flush_size = mp3_encoder 77 | .flush::<FlushNoGap>(mp3_out_buffer.spare_capacity_mut()) 78 | .map_err(|e| { 79 | std::io::Error::new(std::io::ErrorKind::Other, format!("Flush failed: {:?}", e)) 80 | })?; 81 | unsafe { 82 | mp3_out_buffer.set_len(mp3_out_buffer.len().wrapping_add(flush_size)); 83 | } 84 | 85 | Ok(mp3_out_buffer) 86 | } 87 | -------------------------------------------------------------------------------- /kokoros/src/tts/phonemizer.rs: -------------------------------------------------------------------------------- 1 | use crate::tts::normalize; 2 | use crate::tts::vocab::VOCAB; 3 | use lazy_static::lazy_static; 4 | use regex::Regex; 5 | 6 | lazy_static!
{ 7 | static ref PHONEME_PATTERNS: Regex = Regex::new(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)").unwrap(); 8 | static ref Z_PATTERN: Regex = Regex::new(r#" z(?=[;:,.!?¡¿—…"«»“” ]|$)"#).unwrap(); 9 | static ref NINETY_PATTERN: Regex = Regex::new(r"(?<=nˈaɪn)ti(?!ː)").unwrap(); 10 | } 11 | 12 | // Placeholder for the EspeakBackend struct 13 | struct EspeakBackend { 14 | language: String, 15 | preserve_punctuation: bool, 16 | with_stress: bool, 17 | } 18 | 19 | impl EspeakBackend { 20 | fn new(language: &str, preserve_punctuation: bool, with_stress: bool) -> Self { 21 | EspeakBackend { 22 | language: language.to_string(), 23 | preserve_punctuation, 24 | with_stress, 25 | } 26 | } 27 | 28 | fn phonemize(&self, _text: &[String]) -> Option<Vec<String>> { 29 | // Implementation would go here 30 | // This is where you'd integrate with actual espeak bindings 31 | todo!("Implement actual phonemization") 32 | } 33 | } 34 | 35 | pub struct Phonemizer { 36 | lang: String, 37 | backend: EspeakBackend, 38 | } 39 | 40 | impl Phonemizer { 41 | pub fn new(lang: &str) -> Self { 42 | let backend = match lang { 43 | "a" => EspeakBackend::new("en-us", true, true), 44 | "b" => EspeakBackend::new("en-gb", true, true), 45 | _ => panic!("Unsupported language"), 46 | }; 47 | 48 | Phonemizer { 49 | lang: lang.to_string(), 50 | backend, 51 | } 52 | } 53 | 54 | pub fn phonemize(&self, text: &str, normalize: bool) -> String { 55 | let text = if normalize { 56 | normalize::normalize_text(text) 57 | } else { 58 | text.to_string() 59 | }; 60 | 61 | // Assume phonemize returns Option 62 | let mut ps = match self.backend.phonemize(&[text]) { 63 | Some(phonemes) => phonemes[0].clone(), 64 | None => String::new(), 65 | }; 66 | 67 | // Apply kokoro-specific replacements 68 | ps = ps 69 | .replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ") 70 | .replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ"); 71 | 72 | // Apply character replacements 73 | ps = ps 74 | .replace("ʲ", "j") 75 | .replace("r", "ɹ") 76 | .replace("x", "k") 77 | .replace("ɬ", "l"); 78 | 79 | // Apply regex patterns 80 | ps = PHONEME_PATTERNS.replace_all(&ps, " ").to_string(); 81 | ps = Z_PATTERN.replace_all(&ps, "z").to_string(); 82 | 83 | if self.lang == "a" { 84 | ps = NINETY_PATTERN.replace_all(&ps, "di").to_string(); 85 | } 86 | 87 | // Filter characters present in vocabulary 88 | ps = ps.chars().filter(|&c| VOCAB.contains_key(&c)).collect(); 89 | 90 | ps.trim().to_string() 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /kokoros/src/tts/normalize.rs: -------------------------------------------------------------------------------- 1 | use lazy_static::lazy_static; 2 | use regex::Regex; 3 | 4 | lazy_static! { 5 | static ref WHITESPACE_RE: Regex = Regex::new(r"[^\S \n]").unwrap(); 6 | static ref MULTI_SPACE_RE: Regex = Regex::new(r" +").unwrap(); 7 | static ref NEWLINE_SPACE_RE: Regex = Regex::new(r"(?<=\n) +(?=\n)").unwrap(); 8 | static ref DOCTOR_RE: Regex = Regex::new(r"\bD[Rr]\.(?= [A-Z])").unwrap(); 9 | static ref MISTER_RE: Regex = Regex::new(r"\b(?:Mr\.|MR\.(?= [A-Z]))").unwrap(); 10 | static ref MISS_RE: Regex = Regex::new(r"\b(?:Ms\.|MS\.(?= [A-Z]))").unwrap(); 11 | static ref MRS_RE: Regex = Regex::new(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))").unwrap(); 12 | static ref ETC_RE: Regex = Regex::new(r"\betc\.(?! [A-Z])").unwrap(); 13 | static ref YEAH_RE: Regex = Regex::new(r"(?i)\b(y)eah?\b").unwrap(); 14 | static ref NUMBERS_RE: Regex = 15 | Regex::new(r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\d\d?:\d\d\b(?!:)").unwrap(); // clock-time alternative; assumed from the upstream Kokoro rules
// The patterns below mirror the upstream Kokoro text-normalization rules.
static ref COMMA_NUM_RE: Regex = Regex::new(r"(?<=\d),(?=\d)").unwrap();
static ref RANGE_RE: Regex = Regex::new(r"(?<=\d)-(?=\d)").unwrap();
static ref S_AFTER_NUM_RE: Regex = Regex::new(r"(?<=\d)S").unwrap();
static ref POSSESSIVE_RE: Regex = Regex::new(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b").unwrap();
static ref X_POSSESSIVE_RE: Regex = Regex::new(r"(?<=X')S\b").unwrap();
static ref INITIALS_RE: Regex = Regex::new(r"(?:[A-Za-z]\.){2,} [a-z]").unwrap();
static ref ACRONYM_RE: Regex = Regex::new(r"(?i)(?<=[A-Z])\.(?=[A-Z])").unwrap();
}

pub fn normalize_text(text: &str) ->
String { 31 | let mut text = text.to_string(); 32 | 33 | // Replace special quotes and brackets 34 | text = text.replace('\u{2018}', "'").replace('\u{2019}', "'"); 35 | text = text.replace('«', "\u{201C}").replace('»', "\u{201D}"); 36 | text = text.replace('\u{201C}', "\"").replace('\u{201D}', "\""); 37 | text = text.replace('(', "«").replace(')', "»"); 38 | 39 | // Replace Chinese/Japanese punctuation 40 | let from_chars = ['、', '。', '!', ',', ':', ';', '?']; 41 | let to_chars = [',', '.', '!', ',', ':', ';', '?']; 42 | 43 | for (from, to) in from_chars.iter().zip(to_chars.iter()) { 44 | text = text.replace(*from, &format!("{} ", to)); 45 | } 46 | 47 | // Apply regex replacements 48 | text = WHITESPACE_RE.replace_all(&text, " ").to_string(); 49 | text = MULTI_SPACE_RE.replace_all(&text, " ").to_string(); 50 | text = NEWLINE_SPACE_RE.replace_all(&text, "").to_string(); 51 | text = DOCTOR_RE.replace_all(&text, "Doctor").to_string(); 52 | text = MISTER_RE.replace_all(&text, "Mister").to_string(); 53 | text = MISS_RE.replace_all(&text, "Miss").to_string(); 54 | text = MRS_RE.replace_all(&text, "Mrs").to_string(); 55 | text = ETC_RE.replace_all(&text, "etc").to_string(); 56 | text = YEAH_RE.replace_all(&text, "${1}e'a").to_string(); 57 | // Note: split_num, flip_money, and point_num functions need to be implemented 58 | text = COMMA_NUM_RE.replace_all(&text, "").to_string(); 59 | text = RANGE_RE.replace_all(&text, " to ").to_string(); 60 | text = S_AFTER_NUM_RE.replace_all(&text, " S").to_string(); 61 | text = POSSESSIVE_RE.replace_all(&text, "'S").to_string(); 62 | text = X_POSSESSIVE_RE.replace_all(&text, "s").to_string(); 63 | 64 | // Handle initials and acronyms 65 | text = INITIALS_RE 66 | .replace_all(&text, |caps: &regex::Captures| caps[0].replace('.', "-")) 67 | .to_string(); 68 | text = ACRONYM_RE.replace_all(&text, "-").to_string(); 69 | 70 | text.trim().to_string() 71 | } 72 | -------------------------------------------------------------------------------- /kokoros/src/utils/opus.rs: -------------------------------------------------------------------------------- 1 | use opus::{Encoder, Channels, Application, Bitrate}; 2 | use ogg::{PacketWriter, PacketWriteEndInfo}; 3 | use std::io::Cursor; 4 | use std::time::{SystemTime, UNIX_EPOCH}; 5 | 6 | pub fn pcm_to_opus_ogg(pcm_data: &[f32], sample_rate: u32) -> Result<Vec<u8>, std::io::Error> { 7 | // 1. Initialize Opus encoder with Audio application (better for high quality TTS) 8 | let mut encoder = Encoder::new(sample_rate, Channels::Mono, Application::Audio) 9 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Encoder init failed: {:?}", e)))?; 10 | 11 | encoder.set_bitrate(Bitrate::Bits(64000)) 12 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Set bitrate failed: {:?}", e)))?; 13 | 14 | // Get strict pre-skip value from the encoder 15 | let pre_skip = encoder.get_lookahead() 16 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Get lookahead failed: {:?}", e)))? as u16; 17 | 18 | // output buffer 19 | let mut ogg_buffer = Cursor::new(Vec::new()); 20 | let mut packet_writer = PacketWriter::new(&mut ogg_buffer); 21 | 22 | let serial_no = SystemTime::now() 23 | .duration_since(UNIX_EPOCH) 24 | .map(|d| d.subsec_nanos()) 25 | .unwrap_or(1); 26 | 27 | // --- 2.
Create header packet into OpusHead --- 28 | let mut id_header = Vec::new(); 29 | id_header.extend_from_slice(b"OpusHead"); 30 | id_header.push(1); // Version 31 | id_header.push(1); // Channels 32 | id_header.extend_from_slice(&pre_skip.to_le_bytes()); // Pre-skip (Corrected) 33 | id_header.extend_from_slice(&sample_rate.to_le_bytes()); // Input Sample Rate 34 | id_header.extend_from_slice(&0u16.to_le_bytes()); // Gain 35 | id_header.push(0); // Mapping Family 36 | 37 | packet_writer.write_packet(id_header, serial_no, PacketWriteEndInfo::EndPage, 0) 38 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; 39 | 40 | // --- 3. Create comment header into OpusTags --- 41 | let comments = vec![ 42 | ("TITLE", "Generated Audio"), 43 | ("ENCODER", "Kokoros TTS"), 44 | ]; 45 | 46 | let mut comment_header = Vec::new(); 47 | comment_header.extend_from_slice(b"OpusTags"); 48 | 49 | let vendor = b"Rust Opus Encoder"; 50 | comment_header.extend_from_slice(&(vendor.len() as u32).to_le_bytes()); 51 | comment_header.extend_from_slice(vendor); 52 | 53 | comment_header.extend_from_slice(&(comments.len() as u32).to_le_bytes()); 54 | 55 | for (key, value) in comments { 56 | let comment_str = format!("{}={}", key, value); 57 | let comment_bytes = comment_str.as_bytes(); 58 | comment_header.extend_from_slice(&(comment_bytes.len() as u32).to_le_bytes()); 59 | comment_header.extend_from_slice(comment_bytes); 60 | } 61 | 62 | packet_writer.write_packet(comment_header, serial_no, PacketWriteEndInfo::EndPage, 0) 63 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; 64 | 65 | // --- 4. Encode audio data --- 66 | let frame_size = (sample_rate as usize * 20) / 1000; // 20ms frames 67 | // Output buffer recommendation: 4000 bytes is generally enough for max Opus frame 68 | let mut output_buffer = vec![0u8; 4000]; 69 | 70 | let chunks: Vec<&[f32]> = pcm_data.chunks(frame_size).collect(); 71 | let total_chunks = chunks.len(); 72 | let mut samples_processed: u64 = 0; // Track total input samples to avoid drift 73 | 74 | for (i, chunk) in chunks.iter().enumerate() { 75 | let is_last_chunk = i == total_chunks - 1; 76 | 77 | // Padding for last chunk 78 | let input_frame = if chunk.len() < frame_size { 79 | let mut padded = chunk.to_vec(); 80 | padded.resize(frame_size, 0.0); 81 | std::borrow::Cow::Owned(padded) 82 | } else { 83 | std::borrow::Cow::Borrowed(*chunk) 84 | }; 85 | 86 | let encoded_len = encoder.encode_float(&input_frame, &mut output_buffer) 87 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("Encoding failed: {:?}", e)))?; 88 | 89 | // Calculate Granule Position based on TOTAL processed input samples 90 | // This avoids floating point accumulation errors. 
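// Worked example with illustrative numbers: at a 24 kHz input rate the 20 ms
// frame size is 480 samples, so after three frames samples_processed = 1440 and
// the granule position is 1440 * 48000 / 24000 = 2880; Ogg Opus granule
// positions are always expressed on a 48 kHz clock regardless of the input rate.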
91 | // Formula: GP = (Total Input Samples * 48000) / Input Sample Rate 92 | samples_processed += chunk.len() as u64; 93 | 94 | let granule_pos = (samples_processed * 48000) / sample_rate as u64; 95 | 96 | let end_info = if is_last_chunk { 97 | PacketWriteEndInfo::EndStream 98 | } else { 99 | PacketWriteEndInfo::NormalPacket 100 | }; 101 | 102 | let packet_data = output_buffer[..encoded_len].to_vec(); 103 | 104 | packet_writer.write_packet(packet_data, serial_no, end_info, granule_pos) 105 | .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; 106 | } 107 | 108 | drop(packet_writer); 109 | 110 | Ok(ogg_buffer.into_inner()) 111 | } -------------------------------------------------------------------------------- /kokoros/src/onn/ort_koko.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | 3 | use ndarray::{ArrayBase, IxDyn, OwnedRepr}; 4 | use ort::{ 5 | session::{Session, SessionInputValue, SessionInputs, SessionOutputs}, 6 | value::{Tensor, Value}, 7 | }; 8 | use model_schema::v1_0_timestamped::DURATIONS; 9 | use super::ort_base; 10 | use ort_base::OrtBase; 11 | use crate::utils::debug::format_debug_prefix; 12 | 13 | mod model_schema { 14 | pub const STYLE: &str = "style"; 15 | pub const SPEED: &str = "speed"; 16 | 17 | pub mod v1_0 { 18 | pub const TOKENS: &str = "tokens"; 19 | pub const AUDIO: &str = "audio"; 20 | } 21 | 22 | pub mod v1_0_timestamped { 23 | pub const TOKENS: &str = "input_ids"; 24 | pub const AUDIO: &str = "waveform"; 25 | // We define primary and fallback keys as a const array 26 | pub const DURATIONS: &str = "durations"; 27 | } 28 | } 29 | 30 | pub enum ModelStrategy { 31 | Standard(Session), 32 | Timestamped(Session), 33 | } 34 | 35 | pub struct OrtKoko { 36 | inner: Option<ModelStrategy>, 37 | } 38 | 39 | impl ModelStrategy { 40 | fn audio_key(&self) -> &'static str { 41 | match self { 42 | ModelStrategy::Standard(_) => model_schema::v1_0::AUDIO, 43 | ModelStrategy::Timestamped(_) => model_schema::v1_0_timestamped::AUDIO, 44 | } 45 | } 46 | 47 | fn tokens_key(&self) -> &'static str { 48 | match self { 49 | ModelStrategy::Standard(_) => model_schema::v1_0::TOKENS, 50 | ModelStrategy::Timestamped(_) => model_schema::v1_0_timestamped::TOKENS, 51 | } 52 | } 53 | } 54 | 55 | impl OrtBase for OrtKoko { 56 | fn set_sess(&mut self, sess: Session) { 57 | let output_count = sess.outputs.len(); 58 | 59 | let strategy = if output_count > 1 { 60 | tracing::info!("OrtKoko: Timestamped backend activated ({} outputs)", output_count); 61 | ModelStrategy::Timestamped(sess) 62 | } else { 63 | tracing::info!("OrtKoko: Standard backend activated ({} output)", output_count); 64 | ModelStrategy::Standard(sess) 65 | }; 66 | 67 | self.inner = Some(strategy); 68 | } 69 | 70 | fn sess(&self) -> Option<&Session> { 71 | self.inner.as_ref().map(|strategy| match strategy { 72 | ModelStrategy::Standard(sess) => sess, 73 | ModelStrategy::Timestamped(sess) => sess, 74 | }) 75 | } 76 | } 77 | impl OrtKoko { 78 | pub fn new(model_path: String) -> Result<Self, String> { 79 | let mut instance = OrtKoko { inner: None }; 80 | instance.load_model(model_path)?; 81 | Ok(instance) 82 | } 83 | 84 | pub fn strategy(&self) -> Option<&ModelStrategy> { 85 | self.inner.as_ref() 86 | } 87 | 88 | fn prepare_inputs( 89 | tokens_key: &'static str, 90 | tokens: Vec<Vec<i64>>, 91 | styles: Vec<Vec<f32>>, 92 | speed: f32, 93 | ) -> Result<Vec<(Cow<'static, str>, SessionInputValue<'static>)>, Box<dyn std::error::Error>> { 94 | let shape = [tokens.len(), tokens[0].len()]; 95 | let tokens_tensor = Tensor::from_array((shape,
tokens.into_iter().flatten().collect::<Vec<i64>>()))?; 96 | 97 | let shape_style = [styles.len(), styles[0].len()]; 98 | let style_tensor = Tensor::from_array((shape_style, styles.into_iter().flatten().collect::<Vec<f32>>()))?; 99 | 100 | let speed_tensor = Tensor::from_array(([1], vec![speed]))?; 101 | 102 | Ok(vec![ 103 | (Cow::Borrowed(tokens_key), SessionInputValue::Owned(Value::from(tokens_tensor))), 104 | (Cow::Borrowed(model_schema::STYLE), SessionInputValue::Owned(Value::from(style_tensor))), 105 | (Cow::Borrowed(model_schema::SPEED), SessionInputValue::Owned(Value::from(speed_tensor))), 106 | ]) 107 | } 108 | 109 | pub fn infer( 110 | &mut self, 111 | tokens: Vec<Vec<i64>>, 112 | styles: Vec<Vec<f32>>, 113 | speed: f32, 114 | request_id: Option<&str>, 115 | instance_id: Option<&str>, 116 | chunk_number: Option<usize>, 117 | ) -> Result<(ArrayBase<OwnedRepr<f32>, IxDyn>, Option<Vec<f32>>), Box<dyn std::error::Error>> { 118 | 119 | let debug_prefix = format_debug_prefix(request_id, instance_id); 120 | let chunk_info = chunk_number.map(|n| format!("Chunk: {}, ", n)).unwrap_or_default(); 121 | tracing::debug!("{} {}inference start. Tokens: {}", debug_prefix, chunk_info, tokens.len()); 122 | 123 | let strategy = self.inner.as_mut().ok_or("Session is not initialized.")?; 124 | let audio_key = strategy.audio_key(); 125 | let tokens_key = strategy.tokens_key(); 126 | let inputs = Self::prepare_inputs(tokens_key, tokens.clone(), styles, speed)?; 127 | match strategy { 128 | ModelStrategy::Standard(sess) => { 129 | let outputs = sess.run(SessionInputs::from(inputs))?; 130 | 131 | let (shape, data) = outputs[audio_key] 132 | .try_extract_tensor::<f32>() 133 | .or_else(|_| outputs["waveforms"].try_extract_tensor::<f32>()) 134 | .map_err(|_| "Standard Model: Could not find 'audio' output")?; 135 | 136 | let shape_vec: Vec<usize> = shape.into_iter().map(|&i| i as usize).collect(); 137 | let audio_array = ArrayBase::from_shape_vec(shape_vec, data.to_vec())?; 138 | 139 | Ok((audio_array, None)) 140 | } 141 | ModelStrategy::Timestamped(sess) => { 142 | let outputs = sess.run(SessionInputs::from(inputs))?; 143 | 144 | let (shape, data) = outputs[audio_key] 145 | .try_extract_tensor::<f32>() 146 | .or_else(|_| outputs["audio"].try_extract_tensor::<f32>()) 147 | .map_err(|_| "Timestamped Model: Could not find 'waveforms' or 'audio'")?; 148 | 149 | let shape_vec: Vec<usize> = shape.into_iter().map(|&i| i as usize).collect(); 150 | let audio_array = ArrayBase::from_shape_vec(shape_vec, data.to_vec())?; 151 | 152 | let durations_vec = outputs[DURATIONS] 153 | .try_extract_tensor::<f32>() 154 | .map(|(_, d)| d.to_vec()) 155 | .map_err(|_| format!( 156 | "Timestamped Model Error: Expected output tensor '{}' of type f32. \ 157 | If your model uses 'duration' (singular) or i64, please update the schema constants.", 158 | DURATIONS 159 | ))?; 160 | 161 | Ok((audio_array, Some(durations_vec))) 162 | } 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Banner 3 | 4 | # 🔥🔥🔥 Kokoro Rust
6 | 7 | ## [Zonos Rust Is On The Way?](https://github.com/lucasjinreal/Kokoros/issues/60) 8 | ## [Spark-TTS On The Way?](https://github.com/lucasjinreal/Kokoros/issues/75) 9 | ## [Orpheus-TTS On The Way?](https://github.com/lucasjinreal/Kokoros/issues/75) 10 | 11 | 12 | **ASMR** 13 | 14 | https://github.com/user-attachments/assets/1043dfd3-969f-4e10-8b56-daf8285e7420 15 | 16 | (typo in video, ignore it) 17 | 18 | **Digital Human** 19 | 20 | https://github.com/user-attachments/assets/9f5e8fe9-d352-47a9-b4a1-418ec1769567 21 | 22 |
23 | Give a star ⭐ if you like it! 24 |
25 | 26 | [Kokoro](https://huggingface.co/hexgrad/Kokoro-82M) is a trending top 2 TTS model on Hugging Face. 27 | This repo provides **insanely fast Kokoro inference in Rust**; you can now have your own TTS engine powered by Kokoro and run fast inference with just the `koko` command. 28 | 29 | `kokoros` is a `rust` crate that provides easy-to-use TTS functionality. 30 | One can directly call `koko` in the terminal to synthesize audio. 31 | 32 | `kokoros` uses a relatively small model of 87M params, yet delivers extremely good voice quality. 33 | 34 | Language support: 35 | 36 | - [x] English; 37 | - [x] Chinese (partly); 38 | - [x] Japanese (partly); 39 | - [x] German (partly); 40 | 41 | > 🔥🔥🔥🔥🔥🔥🔥🔥🔥 The Kokoros Rust version is getting a lot of attention now. If you are also interested in insanely fast inference, embedded builds, WASM support, etc., please star this repo! We keep updating it. 42 | 43 | New Discord community: https://discord.gg/E566zfDWqD. Please join us if you are interested in Rust Kokoro. 44 | 45 | ## Updates 46 | 47 | - **_`2025.07.12`_**: 🔥🔥🔥 **HTTP API streaming and parallel processing infrastructure.** The OpenAI-compatible server supports streaming audio generation with `"stream": true`, achieving 1-2s time-to-first-audio, work-in-progress parallel TTS processing with `--instances` flag support, an improved logging system with Unix timestamps, and natural-sounding voice generation through advanced chunking; 48 | - **_`2025.01.22`_**: 🔥🔥🔥 **CLI streaming mode supported.** You can now use `--stream` to have fun with stream mode, kudos to [mroigo](https://github.com/mrorigo); 49 | - **_`2025.01.17`_**: 🔥🔥🔥 Style mixing supported! Now, listen to the ASMR output effect by simply specifying a style such as `af_sky.4+af_nicole.5`; 50 | - **_`2025.01.15`_**: OpenAI-compatible server supported; the OpenAI format is still being polished! 51 | - **_`2025.01.15`_**: Phonemizer supported! Now `Kokoros` can run inference E2E without any other dependencies! Kudos to [@tstm](https://github.com/tstm); 52 | - **_`2025.01.13`_**: Espeak-ng tokenizer and phonemizer supported! Kudos to [@mindreframer](https://github.com/mindreframer); 53 | - **_`2025.01.12`_**: Released `Kokoros`; 54 | 55 | ## Prerequisites 56 | 57 | To build this project locally, you need the following system dependencies: 58 | 59 | ### macOS 60 | ```bash 61 | brew install pkg-config opus 62 | ``` 63 | 64 | ### Linux (Ubuntu/Debian) 65 | ```bash 66 | sudo apt-get install pkg-config libopus-dev 67 | ``` 68 | 69 | ## Installation 70 | 71 | 1. Download the required model and voice data files: 72 | 73 | ```bash 74 | bash download_all.sh 75 | ``` 76 | 77 | This will download: 78 | - The Kokoro ONNX model (`checkpoints/kokoro-v1.0.onnx`) 79 | - The voices data file (`data/voices-v1.0.bin`) 80 | 81 | Alternatively, you can download them separately: 82 | ```bash 83 | bash scripts/download_models.sh 84 | bash scripts/download_voices.sh 85 | ``` 86 | 87 | 2. Build the project: 88 | 89 | ```bash 90 | cargo build --release 91 | ``` 92 | 93 | 3. (Optional) Install Python dependencies for OpenAI client examples: 94 | 95 | ```bash 96 | pip install -r scripts/requirements.txt 97 | ``` 98 | 99 | 4. (Optional) Install the binary and voice data system-wide: 100 | 101 | ```bash 102 | bash install.sh 103 | ``` 104 | 105 | This will copy the `koko` binary to `/usr/local/bin` (making it available system-wide as `koko`) and copy the voice data to `$HOME/.cache/kokoros/`.
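To sanity-check the system-wide install, you can verify the paths that `install.sh` writes to (a quick illustrative check, not part of the script):

```bash
which koko                   # expected: /usr/local/bin/koko
ls "$HOME/.cache/kokoros/"   # expected: the copied voices data file
koko -h                      # confirm the binary runs
```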
106 | 107 | ## Usage 108 | 109 | ### View available options 110 | 111 | ```bash 112 | ./target/release/koko -h 113 | ``` 114 | 115 | ### Generate speech for some text 116 | 117 | ``` 118 | mkdir -p tmp 119 | ./target/release/koko text "Hello, this is a TTS test" 120 | ``` 121 | 122 | The generated audio will be saved to `tmp/output.wav` by default. You can customize the save location with the `--output` or `-o` option: 123 | 124 | ``` 125 | ./target/release/koko text "I hope you're having a great day today!" --output greeting.wav 126 | ``` 127 | 128 | ### Generate speech for each line in a file 129 | 130 | ``` 131 | ./target/release/koko file poem.txt 132 | ``` 133 | 134 | For a file with 3 lines of text, by default, speech audio files `tmp/output_0.wav`, `tmp/output_1.wav`, `tmp/output_2.wav` will be outputted. You can customize the save location with the `--output` or `-o` option, using `{line}` as the line number: 135 | 136 | ``` 137 | ./target/release/koko file lyrics.txt -o "song/lyric_{line}.wav" 138 | ``` 139 | 140 | ### Word-level timestamps (TSV sidecar) 141 | 142 | Add `--timestamps` to produce a `.tsv` file with per-word timings alongside the WAV output. The TSV contains three columns: `word`, `start_sec`, `end_sec`. 143 | 144 | Text mode example: 145 | 146 | ``` 147 | ./target/release/koko text \ 148 | --output tmp/output.wav \ 149 | --timestamps \ 150 | "Hello from the timestamped model" 151 | ``` 152 | 153 | This creates: 154 | - `tmp/output.wav` 155 | - `tmp/output.tsv` 156 | 157 | File mode example (one pair per line): 158 | 159 | ``` 160 | ./target/release/koko file input.txt \ 161 | --output tmp/line_{line}.wav \ 162 | --timestamps 163 | ``` 164 | 165 | For each line N, this creates `tmp/line_N.wav` and `tmp/line_N.tsv`. 166 | 167 | Notes: 168 | - The sidecar path is derived automatically by replacing the `.wav` extension with `.tsv`. 169 | - Sample rate is 24 kHz by default; times are in seconds with 3 decimal places. 170 | 171 | #### Quick start with the Hugging Face timestamped model (copy-paste) 172 | 173 | Copy and paste the following to run an end-to-end example using the timestamped Kokoro ONNX model hosted on Hugging Face. This will download the model and voice data to the expected paths and generate both `output.wav` and `output.tsv`. 174 | 175 | ``` 176 | mkdir -p checkpoints data tmp 177 | 178 | # 1) Download the timestamped ONNX model from Hugging Face 179 | curl -L \ 180 | "https://huggingface.co/onnx-community/Kokoro-82M-v1.0-ONNX-timestamped/resolve/main/onnx/model.onnx" \ 181 | -o checkpoints/kokoro-v1.0.onnx 182 | 183 | # 2) Download voices data (single binary used by existing models) 184 | curl -L \ 185 | "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin" \ 186 | -o data/voices-v1.0.bin 187 | 188 | # 3) Build the binary 189 | cargo build --release 190 | 191 | # 4) Run: generates tmp/output.wav and tmp/output.tsv 192 | ./target/release/koko text \ 193 | --output tmp/output.wav \ 194 | --timestamps \ 195 | "Hello from the timestamped model" 196 | ``` 197 | 198 | Notes: 199 | - We keep using the unified `voices-v1.0.bin`, which is compatible with the timestamped model. 200 | - If the files already exist in `checkpoints/` and `data/`, the CLI will use them directly. 
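For reference, a sidecar TSV produced with `--timestamps` looks like the following (tab-separated, three columns as written by the CLI; the words and timings below are illustrative values, not real model output):

```
word	start_sec	end_sec
Hello	0.000	0.312
from	0.312	0.501
the	0.501	0.598
timestamped	0.598	1.204
model	1.204	1.650
```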
201 | 202 | ### Parallel Processing Configuration 203 | 204 | Configure parallel TTS instances for the OpenAI-compatible server based on your performance preference: 205 | 206 | ``` 207 | # Best 0.5-2 seconds time-to-first-audio (lowest latency) 208 | ./target/release/koko openai --instances 1 209 | 210 | # Balanced performance (default, 2 instances, usually best throughput for CPU processing) 211 | ./target/release/koko openai 212 | 213 | # Best total processing time (Diminishing returns on CPU processing observed on Mac M2) 214 | ./target/release/koko openai --instances 4 215 | ``` 216 | 217 | **How to determine the optimal number of instances for your system configuration?** 218 | Choose your configuration based on use case: 219 | - Single instance for real-time applications requiring immediate audio response irrespective of system configuration. 220 | - Multiple instances for batch processing where total completion time matters more than initial latency. 221 | - This was benchmarked on a Mac M2 with 8 cores and 24GB RAM. 222 | - Tested with the message: 223 | > Welcome to our comprehensive technology demonstration session. Today we will explore advanced parallel processing systems thoroughly. These systems utilize multiple computational instances simultaneously for efficiency. Each instance processes different segments concurrently without interference. The coordination between instances ensures seamless output delivery consistently. Modern algorithms optimize resource utilization effectively across all components. Performance improvements are measurable and significant in real scenarios. Quality assurance validates each processing stage thoroughly before deployment. Integration testing confirms system reliability consistently under various conditions. User experience remains smooth throughout operation regardless of complexity. Advanced monitoring tracks system performance metrics continuously during execution. 224 | - Benchmark results (avg of 5) 225 | | No. of instances | TTFA | Total time | 226 | |------------------|------|------------| 227 | | 1 | 1.44s | 19.0s | 228 | | 2 | 2.44s | 16.1s | 229 | | 4 | 4.98s | 16.6s | 230 | - If you have a CPU, memory bandwidth will be the usual bottleneck. You will have to experiment to find a sweet spot of number of instances giving you optimal throughput on your system configuration. 231 | - If you have a NVIDIA GPU, you can try increasing the number of instances. You are expected to further improve throughput. 232 | - Attempts to [make this work on CoreML](https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html), would likely start with converting the ONNX model to CoreML or ORT. 233 | 234 | *Note: The `--instances` flag is currently supported in API server mode. CLI text commands will support parallel processing in future releases.* 235 | 236 | ### OpenAI-Compatible Server 237 | 238 | 1. Start the server: 239 | 240 | ```bash 241 | ./target/release/koko openai 242 | ``` 243 | 244 | 2. 
Make API requests using either curl or Python: 245 | 246 | Using curl: 247 | 248 | ```bash 249 | # Standard audio generation 250 | curl -X POST http://localhost:3000/v1/audio/speech \ 251 | -H "Content-Type: application/json" \ 252 | -d '{ 253 | "model": "tts-1", 254 | "input": "Hello, this is a test of the Kokoro TTS system!", 255 | "voice": "af_sky" 256 | }' \ 257 | --output sky-says-hello.wav 258 | 259 | # Streaming audio generation (PCM format only) 260 | curl -X POST http://localhost:3000/v1/audio/speech \ 261 | -H "Content-Type: application/json" \ 262 | -d '{ 263 | "model": "tts-1", 264 | "input": "This is a streaming test with real-time audio generation.", 265 | "voice": "af_sky", 266 | "stream": true 267 | }' \ 268 | --output streaming-audio.pcm 269 | 270 | # Live streaming playback (requires ffplay) 271 | curl -s -X POST http://localhost:3000/v1/audio/speech \ 272 | -H "Content-Type: application/json" \ 273 | -d '{ 274 | "model": "tts-1", 275 | "input": "Hello streaming world!", 276 | "voice": "af_sky", 277 | "stream": true 278 | }' | \ 279 | ffplay -f s16le -ar 24000 -nodisp -autoexit -loglevel quiet - 280 | ``` 281 | 282 | Using Python: 283 | 284 | ```bash 285 | python scripts/run_openai.py 286 | ``` 287 | 288 | ### Streaming 289 | 290 | The `stream` option will start the program, reading lines of input from stdin and outputting WAV audio to stdout. 291 | 292 | Use it in conjunction with piping. 293 | 294 | #### Typing manually 295 | 296 | ``` 297 | ./target/release/koko stream > live-audio.wav 298 | # Start typing some text to generate speech for and hit enter to submit 299 | # Speech will append to `live-audio.wav` as it is generated 300 | # Hit Ctrl-D to exit 301 | ``` 302 | 303 | #### Input from another source 304 | 305 | ``` 306 | echo "Suppose some other program was outputting lines of text" | ./target/release/koko stream > programmatic-audio.wav 307 | ``` 308 | 309 | ### With docker 310 | 311 | 1. Build or Pull Docker Image 312 | 313 | You can either **build the Docker image locally** or **pull the pre-built image from GitHub Container Registry (GHCR)**. 314 | 315 | ```bash 316 | # Build locally 317 | docker build -t kokoros . 318 | 319 | # Or pull pre-built image from GHCR 320 | docker pull ghcr.io/lucasjinreal/kokoros:main 321 | ``` 322 | 323 | 2. Run the image, passing options as described above 324 | 325 | ```bash 326 | # Basic text to speech 327 | docker run -v ./tmp:/app/tmp kokoros text "Hello from docker!" -o tmp/hello.wav 328 | 329 | # An OpenAI server (with appropriately bound port) 330 | docker run -p 3000:3000 kokoros openai 331 | ``` 332 | 333 | ## Roadmap 334 | 335 | Since Kokoro itself has not finalized its abilities yet, this repo will keep tracking the status of Kokoro, and hopefully we can have language support including: English, Mandarin, Japanese, German, French, etc. 336 | 337 | ## Copyright 338 | 339 | Copyright reserved by Lucas Jin under Apache License.
340 | -------------------------------------------------------------------------------- /koko/src/main.rs: -------------------------------------------------------------------------------- 1 | use clap::{Parser, Subcommand}; 2 | use kokoros::{ 3 | tts::koko::{TTSKoko, TTSOpts}, 4 | utils::wav::{write_audio_chunk, WavHeader}, 5 | }; 6 | use std::net::{IpAddr, SocketAddr}; 7 | use std::{ 8 | fs::{self}, 9 | io::Write, 10 | path::Path, 11 | }; 12 | use tokio::io::{AsyncBufReadExt, BufReader}; 13 | use tracing_subscriber::fmt::time::FormatTime; 14 | 15 | /// Custom Unix timestamp formatter for tracing logs 16 | struct UnixTimestampFormatter; 17 | 18 | impl FormatTime for UnixTimestampFormatter { 19 | fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result { 20 | let now = std::time::SystemTime::now() 21 | .duration_since(std::time::UNIX_EPOCH) 22 | .unwrap(); 23 | let timestamp = format!("{}.{:06}", now.as_secs(), now.subsec_micros()); 24 | write!(w, "{}", timestamp) 25 | } 26 | } 27 | 28 | #[derive(Subcommand, Debug)] 29 | enum Mode { 30 | /// Generate speech for a string of text 31 | #[command(alias = "t", long_flag_alias = "text", short_flag_alias = 't')] 32 | Text { 33 | /// Text to generate speech for 34 | #[arg( 35 | default_value = "Hello, This is Kokoro, your remarkable AI TTS. It's a TTS model with merely 82 million parameters yet delivers incredible audio quality. 36 | This is one of the top notch Rust based inference models, and I'm sure you'll love it. If you do, please give us a star. Thank you very much. 37 | As the night falls, I wish you all a peaceful and restful sleep. May your dreams be filled with joy and happiness. Good night, and sweet dreams!" 38 | )] 39 | text: String, 40 | 41 | /// Path to output the WAV file to on the filesystem 42 | #[arg( 43 | short = 'o', 44 | long = "output", 45 | value_name = "OUTPUT_PATH", 46 | default_value = "tmp/output.wav" 47 | )] 48 | save_path: String, 49 | }, 50 | 51 | /// Read from a file path and generate a speech file for each line 52 | #[command(alias = "f", long_flag_alias = "file", short_flag_alias = 'f')] 53 | File { 54 | /// Filesystem path to read lines from 55 | input_path: String, 56 | 57 | /// Format for the output path of each WAV file, where {line} will be replaced with the line number 58 | #[arg( 59 | short = 'o', 60 | long = "output", 61 | value_name = "OUTPUT_PATH_FORMAT", 62 | default_value = "tmp/output_{line}.wav" 63 | )] 64 | save_path_format: String, 65 | }, 66 | 67 | /// Continuously read from stdin to generate speech, outputting to stdout, for each line 68 | #[command(aliases = ["stdio", "stdin", "-"], long_flag_aliases = ["stdio", "stdin"])] 69 | Stream, 70 | 71 | /// Start an OpenAI-compatible HTTP server 72 | #[command(name = "openai", alias = "oai", long_flag_aliases = ["oai", "openai"])] 73 | OpenAI { 74 | /// IP address to bind to (typically 127.0.0.1 or 0.0.0.0) 75 | #[arg(long, default_value_t = [0, 0, 0, 0].into())] 76 | ip: IpAddr, 77 | 78 | /// Port to expose the HTTP server on 79 | #[arg(long, default_value_t = 3000)] 80 | port: u16, 81 | }, 82 | } 83 | 84 | #[derive(Parser, Debug)] 85 | #[command(name = "kokoros")] 86 | #[command(version = "0.1")] 87 | #[command(author = "Lucas Jin")] 88 | struct Cli { 89 | /// A language identifier from 90 | /// https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md 91 | #[arg( 92 | short = 'l', 93 | long = "lan", 94 | value_name = "LANGUAGE", 95 | default_value = "en-us" 96 | )] 97 | lan: String, 98 | 99 | /// Path 
to the Kokoro v1.0 ONNX model on the filesystem 100 | #[arg( 101 | short = 'm', 102 | long = "model", 103 | value_name = "MODEL_PATH", 104 | default_value = "checkpoints/kokoro-v1.0.onnx" 105 | )] 106 | model_path: String, 107 | 108 | /// Path to the voices data file on the filesystem 109 | #[arg( 110 | short = 'd', 111 | long = "data", 112 | value_name = "DATA_PATH", 113 | default_value = "data/voices-v1.0.bin" 114 | )] 115 | data_path: String, 116 | 117 | /// Which single voice to use or voices to combine to serve as the style of speech 118 | #[arg( 119 | short = 's', 120 | long = "style", 121 | value_name = "STYLE", 122 | // if users use `af_sarah.4+af_nicole.6` as style name 123 | // then we blend it, with 0.4*af_sarah + 0.6*af_nicole 124 | default_value = "af_sarah.4+af_nicole.6" 125 | )] 126 | style: String, 127 | 128 | /// Rate of speech, as a coefficient of the default 129 | /// (i.e. values below 1.0 are slower than the default, 130 | /// and values above 1.0 are faster) 131 | #[arg( 132 | short = 'p', 133 | long = "speed", 134 | value_name = "SPEED", 135 | default_value_t = 1.0 136 | )] 137 | speed: f32, 138 | 139 | /// Output audio in mono (as opposed to stereo) 140 | #[arg(long = "mono", default_value_t = false)] 141 | mono: bool, 142 | 143 | /// Initial silence duration in tokens 144 | #[arg(long = "initial-silence", value_name = "INITIAL_SILENCE")] 145 | initial_silence: Option<usize>, 146 | 147 | /// Also output a sidecar TSV file with word-level timestamps 148 | #[arg(long = "timestamps", default_value_t = false, global = true)] 149 | timestamps: bool, 150 | 151 | /// Number of TTS instances for parallel processing 152 | #[arg(long = "instances", value_name = "INSTANCES", default_value_t = 2)] 153 | instances: usize, 154 | 155 | #[command(subcommand)] 156 | mode: Mode, 157 | } 158 | 159 | fn derive_tsv_path_from_wav(path: &str) -> String { 160 | let p = Path::new(path); 161 | if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) { 162 | if let Some(parent) = p.parent() { 163 | return parent.join(format!("{stem}.tsv")).to_string_lossy().to_string(); 164 | } 165 | return format!("{stem}.tsv"); 166 | } 167 | // Fallback: just append .tsv 168 | format!("{path}.tsv") 169 | } 170 | 171 | fn write_wav_file(path: &str, samples: &[f32], sample_rate: u32, mono: bool) -> std::io::Result<()> { 172 | use std::fs::File; 173 | use std::io::Write; 174 | 175 | let channels: u16 = if mono { 1 } else { 2 }; 176 | let bits_per_sample: u16 = 32; // f32 177 | let bytes_per_sample: u32 = (bits_per_sample as u32) / 8; 178 | let block_align: u16 = channels * bits_per_sample / 8; 179 | let byte_rate: u32 = sample_rate * (block_align as u32); 180 | 181 | // Data size in bytes 182 | let num_frames: usize = samples.len(); 183 | let total_samples_to_write: usize = if mono { num_frames } else { num_frames * 2 }; 184 | let data_size: u32 = (total_samples_to_write as u32) * bytes_per_sample; 185 | let riff_chunk_size: u32 = 36 + data_size; // 4 + (8+16) + (8+data) 186 | 187 | let mut f = File::create(path)?; 188 | 189 | // RIFF header 190 | f.write_all(b"RIFF")?; 191 | f.write_all(&riff_chunk_size.to_le_bytes())?; 192 | f.write_all(b"WAVE")?; 193 | 194 | // fmt chunk 195 | f.write_all(b"fmt ")?; 196 | f.write_all(&(16u32).to_le_bytes())?; // PCM fmt chunk size 197 | f.write_all(&(3u16).to_le_bytes())?; // IEEE float = 3 198 | f.write_all(&channels.to_le_bytes())?; 199 | f.write_all(&sample_rate.to_le_bytes())?; 200 | f.write_all(&byte_rate.to_le_bytes())?; 201 | f.write_all(&block_align.to_le_bytes())?;
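// Worked example of the header fields above: stereo f32 at 24 kHz gives block_align = 2 ch * 4 B = 8 bytes and byte_rate = 24_000 * 8 = 192_000 B/s.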
202 | f.write_all(&bits_per_sample.to_le_bytes())?; 203 | 204 | // data chunk 205 | f.write_all(b"data")?; 206 | f.write_all(&data_size.to_le_bytes())?; 207 | 208 | // write samples 209 | if mono { 210 | for &s in samples { 211 | f.write_all(&s.to_le_bytes())?; 212 | } 213 | } else { 214 | for &s in samples { 215 | f.write_all(&s.to_le_bytes())?; // left 216 | f.write_all(&s.to_le_bytes())?; // right (duplicate for simple stereo) 217 | } 218 | } 219 | 220 | Ok(()) 221 | } 222 | 223 | fn write_tsv(path: &str, alignments: &[(String, f32, f32)]) -> std::io::Result<()> { 224 | use std::fs::File; 225 | use std::io::Write; 226 | let mut f = File::create(path)?; 227 | f.write_all(b"word\tstart_sec\tend_sec\n")?; 228 | for (w, s, e) in alignments { 229 | // Use 3 decimal places by default 230 | let line = format!("{}\t{:.3}\t{:.3}\n", w, s, e); 231 | f.write_all(line.as_bytes())?; 232 | } 233 | Ok(()) 234 | } 235 | 236 | fn main() -> Result<(), Box<dyn std::error::Error>> { 237 | // Initialize tracing with Unix timestamp format and environment-based log level 238 | tracing_subscriber::fmt() 239 | .with_timer(UnixTimestampFormatter) 240 | .with_env_filter( 241 | tracing_subscriber::EnvFilter::try_from_default_env() 242 | .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")) 243 | ) 244 | .init(); 245 | 246 | let rt = tokio::runtime::Runtime::new()?; 247 | rt.block_on(async { 248 | let Cli { 249 | lan, 250 | model_path, 251 | data_path, 252 | style, 253 | speed, 254 | initial_silence, 255 | mono, 256 | timestamps, 257 | instances, 258 | mode, 259 | } = Cli::parse(); 260 | 261 | let tts = TTSKoko::new(&model_path, &data_path).await; 262 | 263 | match mode { 264 | Mode::File { 265 | input_path, 266 | save_path_format, 267 | } => { 268 | let file_content = fs::read_to_string(input_path)?; 269 | for (i, line) in file_content.lines().enumerate() { 270 | let stripped_line = line.trim(); 271 | if stripped_line.is_empty() { 272 | continue; 273 | } 274 | 275 | let save_path = save_path_format.replace("{line}", &i.to_string()); 276 | if timestamps { 277 | match tts.tts_timestamped_raw_audio( 278 | stripped_line, 279 | &lan, 280 | &style, 281 | speed, 282 | initial_silence, 283 | None, 284 | None, 285 | None, 286 | ) { 287 | Ok(Some((audio, words))) => { 288 | // Write WAV 289 | // Note: current engine uses 24kHz 290 | write_wav_file(&save_path, &audio, 24_000, mono)?; 291 | 292 | // Write TSV sidecar 293 | let tsv_path = derive_tsv_path_from_wav(&save_path); 294 | let rows: Vec<(String, f32, f32)> = words 295 | .into_iter() 296 | .map(|w| (w.word, w.start_sec, w.end_sec)) 297 | .collect(); 298 | write_tsv(&tsv_path, &rows)?; 299 | eprintln!("Audio saved to {}", save_path); 300 | eprintln!("Timestamps saved to {}", tsv_path); 301 | } 302 | Ok(None) => { 303 | eprintln!("No audio produced for line {}", i + 1); 304 | } 305 | Err(e) => { 306 | eprintln!("Error processing line {}: {}", i + 1, e); 307 | } 308 | } 309 | } else { 310 | tts.tts(TTSOpts { 311 | txt: stripped_line, 312 | lan: &lan, 313 | style_name: &style, 314 | save_path: &save_path, 315 | mono, 316 | speed, 317 | initial_silence, 318 | })?; 319 | } 320 | } 321 | } 322 | 323 | Mode::Text { text, save_path } => { 324 | let s = std::time::Instant::now(); 325 | if timestamps { 326 | match tts.tts_timestamped_raw_audio( 327 | &text, 328 | &lan, 329 | &style, 330 | speed, 331 | initial_silence, 332 | None, 333 | None, 334 | None, 335 | ) { 336 | Ok(Some((audio, words))) => { 337 | write_wav_file(&save_path, &audio, 24_000, mono)?; 338 | let tsv_path =
derive_tsv_path_from_wav(&save_path); 339 | let rows: Vec<(String, f32, f32)> = words 340 | .into_iter() 341 | .map(|w| (w.word, w.start_sec, w.end_sec)) 342 | .collect(); 343 | write_tsv(&tsv_path, &rows)?; 344 | eprintln!("Audio saved to {}", save_path); 345 | eprintln!("Timestamps saved to {}", tsv_path); 346 | } 347 | Ok(None) => { 348 | eprintln!("No audio produced for input text"); 349 | } 350 | Err(e) => { 351 | eprintln!("Error processing input text: {}", e); 352 | } 353 | } 354 | } else { 355 | tts.tts(TTSOpts { 356 | txt: &text, 357 | lan: &lan, 358 | style_name: &style, 359 | save_path: &save_path, 360 | mono, 361 | speed, 362 | initial_silence, 363 | })?; 364 | } 365 | println!("Time taken: {:?}", s.elapsed()); 366 | let words_per_second = 367 | text.split_whitespace().count() as f32 / s.elapsed().as_secs_f32(); 368 | println!("Words per second: {:.2}", words_per_second); 369 | } 370 | 371 | Mode::OpenAI { ip, port } => { 372 | // Create multiple independent TTS instances for parallel processing 373 | let mut tts_instances = Vec::new(); 374 | for i in 0..instances { 375 | tracing::info!("Initializing TTS instance [{}] ({}/{})", format!("{:02x}", i), i + 1, instances); 376 | let instance = TTSKoko::new(&model_path, &data_path).await; 377 | tts_instances.push(instance); 378 | } 379 | let app = kokoros_openai::create_server(tts_instances).await; 380 | let addr = SocketAddr::from((ip, port)); 381 | let binding = tokio::net::TcpListener::bind(&addr).await?; 382 | tracing::info!("Starting OpenAI-compatible HTTP server on {}", addr); 383 | kokoros_openai::serve(binding, app.into_make_service()).await?; 384 | } 385 | 386 | Mode::Stream => { 387 | let stdin = tokio::io::stdin(); 388 | let reader = BufReader::new(stdin); 389 | let mut lines = reader.lines(); 390 | 391 | // Use std::io::stdout() for sync writing 392 | let mut stdout = std::io::stdout(); 393 | 394 | eprintln!( 395 | "Entering streaming mode. Type text and press Enter. Use Ctrl+D to exit." 396 | ); 397 | 398 | // Write WAV header first 399 | let header = WavHeader::new(1, 24000, 32); 400 | header.write_header(&mut stdout)?; 401 | stdout.flush()?; 402 | 403 | while let Some(line) = lines.next_line().await? { 404 | let stripped_line = line.trim(); 405 | if stripped_line.is_empty() { 406 | continue; 407 | } 408 | 409 | // Process the line and get audio data 410 | match tts.tts_raw_audio(&stripped_line, &lan, &style, speed, initial_silence, None, None, None) { 411 | Ok(raw_audio) => { 412 | // Write the raw audio samples directly 413 | write_audio_chunk(&mut stdout, &raw_audio)?; 414 | stdout.flush()?; 415 | eprintln!("Audio written to stdout. Ready for another line of text."); 416 | } 417 | Err(e) => eprintln!("Error processing line: {}", e), 418 | } 419 | } 420 | } 421 | } 422 | 423 | Ok(()) 424 | }) 425 | } 426 | -------------------------------------------------------------------------------- /kokoros-openai/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! OpenAI-compatible TTS HTTP server for Kokoros 2 | //! 3 | //! This module provides an HTTP API that is compatible with OpenAI's text-to-speech endpoints. 4 | //! It implements streaming and non-streaming audio generation with multiple format support. 5 | //! 6 | //! ## Implemented Features 7 | //! - `/v1/audio/speech` - Text-to-speech generation with streaming support 8 | //! - `/v1/audio/voices` - List available voices 9 | //! - `/v1/models` - List available models (static dummy list) 10 | //! 
- Multiple audio formats: MP3, WAV, PCM, OPUS, AAC, FLAC 11 | //! - Streaming audio generation for low-latency responses 12 | //! 13 | //! ## OpenAI API Compatibility Limitations 14 | //! - `return_download_link`: Not implemented (files are streamed directly) 15 | //! - `lang_code`: Not implemented (language auto-detected from voice prefix) 16 | //! - `volume_multiplier`: Not implemented (audio returned at original levels) 17 | //! - `download_format`: Not implemented (only response_format used) 18 | //! - `normalization_options`: Not implemented (basic text processing only) 19 | //! - Streaming only supports PCM format (other formats fall back to PCM) 20 | 21 | use std::error::Error; 22 | use std::io; 23 | use std::sync::Arc; 24 | use std::time::Instant; 25 | 26 | use axum::{ 27 | Json, Router, 28 | body::Body, 29 | extract::{Path, State}, 30 | http::{StatusCode, header}, 31 | response::{IntoResponse, Response}, 32 | routing::{get, post}, 33 | }; 34 | use futures::stream::StreamExt; 35 | use kokoros::{ 36 | tts::koko::{InitConfig as TTSKokoInitConfig, TTSKoko}, 37 | utils::mp3::pcm_to_mp3, 38 | utils::opus::pcm_to_opus_ogg, 39 | utils::wav::{WavHeader, write_audio_chunk}, 40 | }; 41 | use regex::Regex; 42 | use serde::{Deserialize, Serialize}; 43 | use tokio::sync::mpsc; 44 | use tower_http::cors::CorsLayer; 45 | use tracing::{debug, error, info}; 46 | use uuid::Uuid; 47 | 48 | /// Break words used for chunk splitting 49 | const BREAK_WORDS: &[&str] = &[ 50 | "and", "or", "but", "&", "because", "if", "since", "though", "although", "however", "which", 51 | ]; 52 | 53 | /// Split text into speech chunks for streaming 54 | /// 55 | /// Prioritizes sentence boundaries over word count for natural speech breaks 56 | /// Then applies center-break word splitting for long chunks 57 | fn split_text_into_speech_chunks(text: &str, words_per_chunk: usize) -> Vec<String> { 58 | let mut chunks = Vec::new(); 59 | let mut current_chunk = String::new(); 60 | let mut word_count = 0; 61 | 62 | // First pass: split by punctuation 63 | for word in text.split_whitespace() { 64 | if !current_chunk.is_empty() { 65 | current_chunk.push(' '); 66 | } 67 | // Check for numbered list patterns: 1. 2) 3: (4), 5(\s)[.\)\:] 68 | let is_numbered_break = is_numbered_list_item(word); 69 | 70 | if is_numbered_break && !current_chunk.is_empty() { 71 | chunks.push(current_chunk.trim().to_string()); 72 | current_chunk.clear(); 73 | word_count = 0; 74 | } 75 | current_chunk.push_str(word); 76 | word_count += 1; 77 | 78 | // Check for unconditional breaks (always break regardless of word count) 79 | let ends_with_unconditional = word.ends_with('.') 80 | || word.ends_with('!') 81 | || word.ends_with('?') 82 | || word.ends_with(':') 83 | || word.ends_with(';'); 84 | 85 | // Check for conditional breaks (commas - only break if enough words) 86 | let ends_with_conditional = word.ends_with(','); 87 | 88 | // Split conditions: 89 | // 1. Unconditional punctuation - always break 90 | // 2.
Conditional punctuation + target word count reached 91 | if ends_with_unconditional 92 | || is_numbered_break 93 | || (ends_with_conditional && word_count >= words_per_chunk) 94 | { 95 | chunks.push(current_chunk.trim().to_string()); 96 | current_chunk.clear(); 97 | word_count = 0; 98 | } 99 | } 100 | 101 | if !current_chunk.trim().is_empty() { 102 | chunks.push(current_chunk.trim().to_string()); 103 | } 104 | 105 | // Second pass: apply center-break splitting for long chunks 106 | // All chunks: ≥12 words 107 | // First 2 chunks: punctuation priority, Others: break words only 108 | let mut final_chunks = Vec::new(); 109 | for (index, chunk) in chunks.iter().enumerate() { 110 | let threshold = 12; 111 | let use_punctuation = index < 2; // First 2 chunks can use punctuation 112 | let split_chunks = split_long_chunk_with_depth(chunk, threshold, use_punctuation, 0); 113 | final_chunks.extend(split_chunks); 114 | } 115 | 116 | // Final processing: Move break words from end of chunks to beginning of next chunk 117 | for i in 0..final_chunks.len().saturating_sub(1) { // saturating_sub avoids underflow when there are no chunks 118 | let current_chunk = &final_chunks[i]; 119 | let words: Vec<&str> = current_chunk.trim().split_whitespace().collect(); 120 | 121 | if let Some(last_word) = words.last() { 122 | // Check if last word is a break word (case insensitive) 123 | if BREAK_WORDS.contains(&last_word.to_lowercase().as_str()) && words.len() > 1 { 124 | // Only move if it won't create an empty chunk (need more than 1 word) 125 | let new_current = words[..words.len() - 1].join(" "); 126 | 127 | // Add break word to beginning of next chunk 128 | let next_chunk = &final_chunks[i + 1]; 129 | let new_next = format!("{} {}", last_word, next_chunk); 130 | 131 | // Update the chunks 132 | final_chunks[i] = new_current; 133 | final_chunks[i + 1] = new_next; 134 | } 135 | } 136 | } 137 | 138 | // After all processing, there is no explicit filter to remove empty chunks. 139 | // If any empty string slipped through (e.g., from .trim().to_string() on 140 | // whitespace-only current_chunk, or from split_long_chunk), it would remain. 141 | // Don't filter empty chunks out here; leaving them in makes it easier to catch 142 | // potential bugs in the chunking logic. 143 | final_chunks 144 | } 145 | 146 | /// Check if a word is a numbered list item: 1. 2) 3: (4), 5(\s)[.\)\:] 147 | fn is_numbered_list_item(word: &str) -> bool { 148 | // Pattern matches: number followed by .
) or : 149 | // Examples: "1.", "2)", "3:", "(4)", "(5)," 150 | let numbered_regex = Regex::new(r"^\(?[0-9]+[.\)\:],?$").unwrap(); 151 | numbered_regex.is_match(word) 152 | } 153 | 154 | fn split_long_chunk_with_depth( 155 | chunk: &str, 156 | threshold: usize, 157 | use_punctuation: bool, 158 | depth: usize, 159 | ) -> Vec<String> { 160 | // Prevent infinite recursion 161 | if depth >= 3 { 162 | return vec![chunk.to_string()]; 163 | } 164 | let words: Vec<&str> = chunk.split_whitespace().collect(); 165 | let word_count = words.len(); 166 | 167 | // Only split if chunk meets the threshold 168 | if word_count < threshold { 169 | return vec![chunk.to_string()]; 170 | } 171 | 172 | let center = word_count / 2; 173 | 174 | if use_punctuation { 175 | // Priority 1: Search for commas closest to center 176 | if let Some(pos) = find_closest_punctuation(&words, center, &[","]) { 177 | if pos >= 3 && pos < words.len() { 178 | let first_chunk = words[..pos].join(" "); 179 | let second_chunk = words[pos..].join(" "); 180 | 181 | // Recursively split both chunks if they're still too long 182 | let mut result = Vec::new(); 183 | result.extend(split_long_chunk_with_depth( 184 | &first_chunk, 185 | threshold, 186 | use_punctuation, 187 | depth + 1, 188 | )); 189 | result.extend(split_long_chunk_with_depth( 190 | &second_chunk, 191 | threshold, 192 | use_punctuation, 193 | depth + 1, 194 | )); 195 | return result; 196 | } 197 | } 198 | } 199 | 200 | // Priority 2: Search for break words closest to center 201 | if let Some(pos) = find_closest_break_word(&words, center, BREAK_WORDS) { 202 | if pos >= 3 && pos < words.len() { 203 | let first_chunk = words[..pos].join(" "); 204 | let second_chunk = words[pos..].join(" "); 205 | 206 | // Recursively split both chunks if they're still too long 207 | let mut result = Vec::new(); 208 | result.extend(split_long_chunk_with_depth( 209 | &first_chunk, 210 | threshold, 211 | use_punctuation, 212 | depth + 1, 213 | )); 214 | result.extend(split_long_chunk_with_depth( 215 | &second_chunk, 216 | threshold, 217 | use_punctuation, 218 | depth + 1, 219 | )); 220 | return result; 221 | } 222 | } 223 | 224 | // No suitable break point found, return original chunk 225 | vec![chunk.to_string()] 226 | } 227 | 228 | /// Find closest punctuation to center 229 | fn find_closest_punctuation(words: &[&str], center: usize, punctuation: &[&str]) -> Option<usize> { 230 | let mut closest_pos = None; 231 | let mut min_distance = usize::MAX; 232 | 233 | for (i, word) in words.iter().enumerate() { 234 | if punctuation.iter().any(|p| word.ends_with(p)) { 235 | let distance = if i < center { center - i } else { i - center }; 236 | if distance < min_distance { 237 | min_distance = distance; 238 | closest_pos = Some(i + 1); // Split after the punctuation 239 | } 240 | } 241 | } 242 | 243 | closest_pos 244 | } 245 | 246 | /// Find closest break word to center 247 | fn find_closest_break_word(words: &[&str], center: usize, break_words: &[&str]) -> Option<usize> { 248 | let mut closest_pos = None; 249 | let mut min_distance = usize::MAX; 250 | 251 | for (i, word) in words.iter().enumerate() { 252 | if break_words.contains(&word.to_lowercase().as_str()) { 253 | let distance = if i < center { center - i } else { i - center }; 254 | if distance < min_distance { 255 | min_distance = distance; 256 | closest_pos = Some(i); // Break word becomes first word of second chunk 257 | } 258 | } 259 | } 260 | 261 | closest_pos 262 | } 263 | 264 | #[derive(Deserialize, Default, Debug)] 265 | #[serde(rename_all = "lowercase")] 266 |
enum AudioFormat { 267 | #[default] 268 | Mp3, 269 | Wav, 270 | Opus, 271 | Aac, 272 | Flac, 273 | Pcm, 274 | } 275 | 276 | #[derive(Deserialize)] 277 | struct Voice(String); 278 | 279 | impl Default for Voice { 280 | fn default() -> Self { 281 | Self("af_sky".into()) 282 | } 283 | } 284 | 285 | #[derive(Deserialize)] 286 | struct Speed(f32); 287 | 288 | impl Default for Speed { 289 | fn default() -> Self { 290 | Self(1.) 291 | } 292 | } 293 | 294 | #[derive(Deserialize)] 295 | struct SpeechRequest { 296 | // Only one Kokoro model exists 297 | #[allow(dead_code)] 298 | model: String, 299 | 300 | input: String, 301 | 302 | #[serde(default)] 303 | voice: Voice, 304 | 305 | #[serde(default)] 306 | response_format: AudioFormat, 307 | 308 | #[serde(default)] 309 | speed: Speed, 310 | 311 | #[serde(default)] 312 | initial_silence: Option<usize>, 313 | 314 | /// Enable streaming audio generation (implemented) 315 | #[serde(default)] 316 | stream: Option<bool>, 317 | 318 | // OpenAI API compatibility parameters - accepted but not implemented 319 | // These fields ensure request parsing compatibility with OpenAI clients 320 | /// Return download link after generation (not implemented) 321 | #[serde(default)] 322 | #[allow(dead_code)] 323 | return_download_link: Option<bool>, 324 | 325 | /// Language code for text processing (not implemented) 326 | #[serde(default)] 327 | #[allow(dead_code)] 328 | lang_code: Option<String>, 329 | 330 | /// Volume multiplier for output audio (not implemented) 331 | #[serde(default)] 332 | #[allow(dead_code)] 333 | volume_multiplier: Option<f32>, 334 | 335 | /// Format for download when different from response_format (not implemented) 336 | #[serde(default)] 337 | #[allow(dead_code)] 338 | download_format: Option<String>, 339 | 340 | /// Text normalization options (not implemented) 341 | #[serde(default)] 342 | #[allow(dead_code)] 343 | normalization_options: Option<serde_json::Value>, 344 | } 345 | 346 | /// Async TTS worker task 347 | #[derive(Debug)] 348 | struct TTSTask { 349 | id: usize, 350 | chunk: String, 351 | voice: String, 352 | speed: f32, 353 | initial_silence: Option<usize>, 354 | result_tx: mpsc::UnboundedSender<(usize, Vec<u8>)>, 355 | } 356 | 357 | /// Streaming session manager 358 | #[derive(Debug)] 359 | struct StreamingSession { 360 | session_id: Uuid, 361 | start_time: Instant, 362 | } 363 | 364 | /// TTS worker pool manager with multiple TTS instances 365 | #[derive(Clone)] 366 | struct TTSWorkerPool { 367 | tts_instances: Vec<Arc<TTSKoko>>, 368 | } 369 | 370 | impl TTSWorkerPool { 371 | fn new(tts_instances: Vec<TTSKoko>) -> Self { 372 | Self { 373 | tts_instances: tts_instances.into_iter().map(Arc::new).collect(), 374 | } 375 | } 376 | 377 | fn get_instance(&self, worker_id: usize) -> (Arc<TTSKoko>, String) { 378 | let index = worker_id % self.tts_instances.len(); 379 | let instance_id = format!("{:02x}", index); 380 | (Arc::clone(&self.tts_instances[index]), instance_id) 381 | } 382 | 383 | fn instance_count(&self) -> usize { 384 | self.tts_instances.len() 385 | } 386 | 387 | // process_chunk method removed - now handled inline in sequential queue processing 388 | } 389 | 390 | #[derive(Serialize)] 391 | struct VoicesResponse { 392 | voices: Vec<String>, 393 | } 394 | 395 | #[derive(Serialize)] 396 | struct ModelObject { 397 | id: String, 398 | object: String, 399 | created: u64, 400 | owned_by: String, 401 | } 402 | 403 | #[derive(Serialize)] 404 | struct ModelsResponse { 405 | object: String, 406 | data: Vec<ModelObject>, 407 | } 408 | 409 | pub async fn create_server(tts_instances: Vec<TTSKoko>) -> Router { 410 | info!("Starting TTS server with {} instances",
tts_instances.len()); 411 | 412 | // Use first instance for compatibility with non-streaming endpoints 413 | let tts_single = tts_instances 414 | .first() 415 | .cloned() 416 | .expect("At least one TTS instance required"); 417 | 418 | Router::new() 419 | .route("/", get(handle_home)) 420 | .route("/v1/audio/speech", post(handle_tts)) 421 | .route("/v1/audio/voices", get(handle_voices)) 422 | .route("/v1/models", get(handle_models)) 423 | .route("/v1/models/{model}", get(handle_model)) 424 | .layer(axum::middleware::from_fn(request_id_middleware)) 425 | .layer(CorsLayer::permissive()) 426 | .with_state((tts_single, tts_instances)) 427 | } 428 | 429 | pub use axum::serve; 430 | 431 | #[derive(Debug)] 432 | enum SpeechError { 433 | // Deciding to modify this example in order to see errors 434 | // (e.g. with tracing) is up to the developer 435 | #[allow(dead_code)] 436 | Koko(Box<dyn Error>), 437 | 438 | #[allow(dead_code)] 439 | Header(io::Error), 440 | 441 | #[allow(dead_code)] 442 | Chunk(io::Error), 443 | 444 | #[allow(dead_code)] 445 | Mp3Conversion(std::io::Error), 446 | 447 | #[allow(dead_code)] 448 | OpusConversion(std::io::Error), 449 | } 450 | 451 | impl std::fmt::Display for SpeechError { 452 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 453 | match self { 454 | SpeechError::Koko(e) => write!(f, "Koko TTS error: {}", e), 455 | SpeechError::Header(e) => write!(f, "Header error: {}", e), 456 | SpeechError::Chunk(e) => write!(f, "Chunk error: {}", e), 457 | SpeechError::Mp3Conversion(e) => write!(f, "MP3 conversion error: {}", e), 458 | SpeechError::OpusConversion(e) => write!(f, "Opus conversion error: {}", e), 459 | } 460 | } 461 | } 462 | 463 | impl IntoResponse for SpeechError { 464 | fn into_response(self) -> Response { 465 | // None of these errors make sense to expose to the user of the API 466 | StatusCode::INTERNAL_SERVER_ERROR.into_response() 467 | } 468 | } 469 | 470 | /// Returns a 200 OK response to make it easier to check if the server is 471 | /// running. 472 | async fn handle_home() -> &'static str { 473 | "OK" 474 | } 475 | 476 | async fn handle_tts( 477 | State((tts_single, tts_instances)): State<(TTSKoko, Vec<TTSKoko>)>, 478 | request: axum::extract::Request, 479 | ) -> Result<Response, SpeechError> { 480 | let (request_id, request_start) = request 481 | .extensions() 482 | .get::<(String, Instant)>() 483 | .cloned() 484 | .unwrap_or_else(|| ("unknown".to_string(), Instant::now())); 485 | 486 | // OpenAI TTS always streams by default - client decides how to consume 487 | // Only send complete file when explicitly requested via stream: false 488 | 489 | // Parse the JSON body 490 | let bytes = axum::body::to_bytes(request.into_body(), usize::MAX) 491 | .await 492 | .map_err(|e| { 493 | error!("Error reading request body: {:?}", e); 494 | SpeechError::Mp3Conversion(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) 495 | })?; 496 | 497 | let speech_request: SpeechRequest = serde_json::from_slice(&bytes).map_err(|e| { 498 | error!("JSON parsing error: {:?}", e); 499 | SpeechError::Mp3Conversion(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) 500 | })?; 501 | 502 | let SpeechRequest { 503 | input, 504 | voice: Voice(voice), 505 | response_format, 506 | speed: Speed(speed), 507 | initial_silence, 508 | stream, 509 | ..
510 | } = speech_request; 511 | 512 | // OpenAI-compliant behavior: Stream by default, only send complete file if stream: false 513 | let should_stream = stream.unwrap_or(true); // Default to streaming like OpenAI 514 | 515 | let colored_request_id = get_colored_request_id_with_relative(&request_id, request_start); 516 | debug!( 517 | "{} Streaming decision: stream_param={:?}, final_decision={}", 518 | colored_request_id, stream, should_stream 519 | ); 520 | 521 | if should_stream { 522 | return handle_tts_streaming( 523 | tts_instances, 524 | input, 525 | voice, 526 | response_format, 527 | speed, 528 | initial_silence, 529 | request_id, 530 | request_start, 531 | ) 532 | .await; 533 | } 534 | 535 | // Non-streaming mode (existing implementation) 536 | let raw_audio = tts_single 537 | .tts_raw_audio( 538 | &input, 539 | "en-us", 540 | &voice, 541 | speed, 542 | initial_silence, 543 | Some(&request_id), 544 | Some("00"), 545 | None, 546 | ) 547 | .map_err(SpeechError::Koko)?; 548 | 549 | let sample_rate = TTSKokoInitConfig::default().sample_rate; 550 | 551 | let (content_type, audio_data, format_name) = match response_format { 552 | AudioFormat::Wav => { 553 | let mut wav_data = Vec::default(); 554 | let header = WavHeader::new(1, sample_rate, 32); 555 | header 556 | .write_header(&mut wav_data) 557 | .map_err(SpeechError::Header)?; 558 | write_audio_chunk(&mut wav_data, &raw_audio).map_err(SpeechError::Chunk)?; 559 | 560 | ("audio/wav", wav_data, "WAV") 561 | } 562 | AudioFormat::Opus => { 563 | let opus_data = 564 | pcm_to_opus_ogg(&raw_audio, sample_rate).map_err(|e| SpeechError::OpusConversion(e))?; 565 | 566 | ("audio/opus", opus_data, "OPUS") 567 | } 568 | AudioFormat::Mp3 => { 569 | let mp3_data = 570 | pcm_to_mp3(&raw_audio, sample_rate).map_err(|e| SpeechError::Mp3Conversion(e))?; 571 | 572 | ("audio/mpeg", mp3_data, "MP3") 573 | } 574 | AudioFormat::Pcm => { 575 | // For PCM, we return the raw audio data directly 576 | // Convert f32 samples to 16-bit PCM 577 | let mut pcm_data = Vec::with_capacity(raw_audio.len() * 2); 578 | for sample in raw_audio { 579 | let pcm_sample = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16; 580 | pcm_data.extend_from_slice(&pcm_sample.to_le_bytes()); 581 | } 582 | ("audio/pcm", pcm_data, "PCM") 583 | } 584 | // For now, unsupported formats fall back to MP3 585 | _ => { 586 | let mp3_data = 587 | pcm_to_mp3(&raw_audio, sample_rate).map_err(|e| SpeechError::Mp3Conversion(e))?; 588 | 589 | ("audio/mpeg", mp3_data, "MP3") 590 | } 591 | }; 592 | 593 | let colored_request_id = get_colored_request_id_with_relative(&request_id, request_start); 594 | info!( 595 | "{} TTS non-streaming completed - {} bytes, {} format", 596 | colored_request_id, 597 | audio_data.len(), 598 | format_name 599 | ); 600 | 601 | Ok(Response::builder() 602 | .header(header::CONTENT_TYPE, content_type) 603 | .body(audio_data.into()) 604 | .map_err(|e| { 605 | SpeechError::Mp3Conversion(std::io::Error::new(std::io::ErrorKind::Other, e)) 606 | })?) 607 | } 608 | 609 | /// Handle streaming TTS requests with true async processing 610 | /// 611 | /// Uses micro-chunking and parallel processing for low-latency streaming. 612 | /// Maintains speech order while allowing out-of-order chunk completion. 
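/// For example, the input "Hello world. How are you today?" is split into the chunks ["Hello world.", "How are you today?"]; each chunk is synthesized on its own TTS instance, and the resulting PCM is emitted to the client strictly in chunk order.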
613 | async fn handle_tts_streaming( 614 | tts_instances: Vec<TTSKoko>, 615 | input: String, 616 | voice: String, 617 | response_format: AudioFormat, 618 | speed: f32, 619 | initial_silence: Option<usize>, 620 | request_id: String, 621 | request_start: Instant, 622 | ) -> Result<Response, SpeechError> { 623 | // Streaming implementation: PCM format for optimal performance 624 | let content_type = match response_format { 625 | AudioFormat::Pcm => "audio/pcm", 626 | _ => "audio/pcm", // Force PCM for optimal streaming performance 627 | }; 628 | 629 | // Create worker pool with vector of TTS instances for true parallelism 630 | let worker_pool = TTSWorkerPool::new(tts_instances); 631 | 632 | // Create speech chunks based on word count and punctuation 633 | let mut chunks = split_text_into_speech_chunks(&input, 10); 634 | 635 | // Add empty chunk at end as completion signal to client 636 | chunks.push(String::new()); 637 | let total_chunks = chunks.len(); 638 | 639 | let colored_request_id = get_colored_request_id_with_relative(&request_id, request_start); 640 | debug!( 641 | "{} Processing {} chunks for streaming with window size {}", 642 | colored_request_id, 643 | total_chunks, 644 | worker_pool.instance_count() 645 | ); 646 | 647 | if chunks.is_empty() { 648 | return Err(SpeechError::Mp3Conversion(std::io::Error::new( 649 | std::io::ErrorKind::InvalidInput, 650 | "No text to process", 651 | ))); 652 | } 653 | 654 | // Create channels for sequential chunk processing 655 | let (task_tx, mut task_rx) = mpsc::unbounded_channel::<TTSTask>(); 656 | let (audio_tx, audio_rx) = mpsc::unbounded_channel::<(usize, Vec<u8>)>(); // Tag chunks with order ID 657 | 658 | // Track total bytes transferred 659 | let total_bytes = Arc::new(std::sync::atomic::AtomicUsize::new(0)); 660 | 661 | // Create session for tracking 662 | let session = StreamingSession { 663 | session_id: Uuid::new_v4(), 664 | start_time: Instant::now(), 665 | }; 666 | 667 | let colored_request_id = get_colored_request_id_with_relative(&request_id, request_start); 668 | info!( 669 | "{} TTS session started - {} chunks streaming", 670 | colored_request_id, total_chunks 671 | ); 672 | 673 | // Queue all tasks in order for sequential processing 674 | for (id, chunk) in chunks.into_iter().enumerate() { 675 | let task = TTSTask { 676 | id, 677 | chunk, 678 | voice: voice.clone(), 679 | speed, 680 | initial_silence: if id == 0 { initial_silence } else { None }, 681 | result_tx: audio_tx.clone(), 682 | }; 683 | 684 | task_tx.send(task).unwrap(); 685 | } 686 | 687 | // Drop the task sender to signal completion 688 | drop(task_tx); 689 | 690 | // Windowed parallel processing: allow chunks to process concurrently up to available TTS instances 691 | let worker_pool_clone = worker_pool.clone(); 692 | let total_bytes_clone = total_bytes.clone(); 693 | let audio_tx_clone = audio_tx.clone(); 694 | let total_chunks_expected = total_chunks; 695 | tokio::spawn(async move { 696 | use std::collections::BTreeMap; 697 | 698 | let mut chunk_counter = 0; 699 | let mut pending_chunks: BTreeMap< 700 | usize, 701 | tokio::task::JoinHandle<Result<(usize, Vec<u8>), String>>, 702 | > = BTreeMap::new(); 703 | let mut next_to_send = 0; 704 | let mut chunks_processed = 0; 705 | let window_size = worker_pool_clone.instance_count(); // Allow chunks to process in parallel up to available TTS instances 706 | 707 | loop { 708 | // Receive new tasks while we have window space and tasks are available 709 | while pending_chunks.len() < window_size { 710 | // Use a non-blocking approach but with proper channel closure detection 711 | match
task_rx.try_recv() { 712 | Ok(task) => { 713 | let task_id = task.id; 714 | let worker_pool_clone = worker_pool_clone.clone(); 715 | let total_bytes_clone = total_bytes_clone.clone(); 716 | let request_id_clone = request_id.clone(); 717 | 718 | // Process chunk with dedicated TTS instance (alternates between instances) 719 | let (tts_instance, actual_instance_id) = 720 | worker_pool_clone.get_instance(chunk_counter); 721 | let chunk_text = task.chunk.clone(); 722 | let voice = task.voice.clone(); 723 | let speed = task.speed; 724 | let initial_silence = task.initial_silence; 725 | let chunk_num = chunk_counter; 726 | 727 | // Spawn parallel processing 728 | let handle = tokio::spawn(async move { 729 | // Handle empty chunks (completion signals) without TTS processing 730 | if chunk_text.trim().is_empty() { 731 | // Empty chunk - send as completion signal 732 | return Ok((task_id, Vec::new())); 733 | } 734 | 735 | let result = tokio::task::spawn_blocking(move || { 736 | let audio_result = tts_instance.tts_raw_audio( 737 | &chunk_text, 738 | "en-us", 739 | &voice, 740 | speed, 741 | initial_silence, 742 | Some(&request_id_clone), 743 | Some(&actual_instance_id), 744 | Some(chunk_num), 745 | ); 746 | 747 | audio_result 748 | .map(|audio| audio) 749 | .map_err(|e| format!("TTS processing error: {:?}", e)) 750 | }) 751 | .await; 752 | 753 | // Convert audio to PCM 754 | match result { 755 | Ok(Ok(audio_samples)) => { 756 | let mut pcm_data = Vec::with_capacity(audio_samples.len() * 2); 757 | for sample in audio_samples { 758 | let pcm_sample = 759 | (sample * 32767.0).clamp(-32768.0, 32767.0) as i16; 760 | pcm_data.extend_from_slice(&pcm_sample.to_le_bytes()); 761 | } 762 | total_bytes_clone.fetch_add( 763 | pcm_data.len(), 764 | std::sync::atomic::Ordering::Relaxed, 765 | ); 766 | Ok((task_id, pcm_data)) 767 | } 768 | Ok(Err(e)) => Err(e), 769 | Err(e) => Err(format!("Task execution error: {:?}", e)), 770 | } 771 | }); 772 | 773 | pending_chunks.insert(chunk_counter, handle); 774 | chunk_counter += 1; 775 | } 776 | Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { 777 | // No tasks available right now, break inner loop to check completed chunks 778 | break; 779 | } 780 | Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { 781 | // Channel is closed, no more tasks will come 782 | break; 783 | } 784 | } 785 | } 786 | 787 | // Check if we can send the next chunk in order 788 | if let Some(handle) = pending_chunks.remove(&next_to_send) { 789 | if handle.is_finished() { 790 | match handle.await { 791 | Ok(Ok((task_id, pcm_data))) => { 792 | if let Err(_) = audio_tx_clone.send((task_id, pcm_data)) { 793 | break; 794 | } 795 | next_to_send += 1; 796 | chunks_processed += 1; 797 | } 798 | Ok(Err(_e)) => { 799 | // TTS processing error - skip this chunk 800 | next_to_send += 1; 801 | chunks_processed += 1; 802 | } 803 | Err(_e) => { 804 | // Task execution error - skip this chunk 805 | next_to_send += 1; 806 | chunks_processed += 1; 807 | } 808 | } 809 | } else { 810 | // Not finished yet, put it back 811 | pending_chunks.insert(next_to_send, handle); 812 | } 813 | } 814 | 815 | // Check if all chunks have been processed and sent 816 | // We're done when we've processed all expected chunks 817 | if chunks_processed >= total_chunks_expected { 818 | break; 819 | } 820 | 821 | // Also check if we have no more work to do (fallback safety check) 822 | if pending_chunks.is_empty() 823 | && task_rx.is_empty() 824 | && chunks_processed < total_chunks_expected 825 | { 826 | // This shouldn't 
happen, but log it for debugging 827 | eprintln!( 828 | "Warning: Early termination detected - processed {} of {} chunks", 829 | chunks_processed, total_chunks_expected 830 | ); 831 | break; 832 | } 833 | 834 | // Small delay to prevent busy waiting 835 | tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; 836 | } 837 | 838 | // Wait for any remaining chunks to complete and collect them 839 | // This fixes the previous issue where only chunks matching next_to_send exactly were processed 840 | let mut remaining_chunks = Vec::new(); 841 | 842 | for (chunk_id, handle) in pending_chunks { 843 | match handle.await { 844 | Ok(Ok((task_id, pcm_data))) => { 845 | // Collect all successful chunks regardless of order 846 | remaining_chunks.push((chunk_id, task_id, pcm_data)); 847 | } 848 | Ok(Err(_e)) => { 849 | // TTS processing error - still count as processed 850 | chunks_processed += 1; 851 | } 852 | Err(_e) => { 853 | // Task execution error - still count as processed 854 | chunks_processed += 1; 855 | } 856 | } 857 | } 858 | 859 | // Sort remaining chunks by chunk_id to maintain proper order 860 | // This ensures audio continuity even for out-of-order completions 861 | remaining_chunks.sort_by_key(|(chunk_id, _, _)| *chunk_id); 862 | 863 | // Send all remaining chunks in order, preventing data loss 864 | for (chunk_id, task_id, pcm_data) in remaining_chunks { 865 | // Only send chunks that are in the expected sequence (>= next_to_send) 866 | // This prevents duplicate sends while ensuring no valid chunks are skipped 867 | if chunk_id >= next_to_send { 868 | let _ = audio_tx_clone.send((task_id, pcm_data)); 869 | chunks_processed += 1; 870 | } 871 | } 872 | 873 | let _session_time = session.start_time.elapsed(); 874 | 875 | // Log completion 876 | let bytes_transferred = total_bytes.load(std::sync::atomic::Ordering::Relaxed); 877 | // Calculate audio duration: 16-bit PCM (2 bytes per sample) at 24000 Hz 878 | let total_samples = bytes_transferred / 2; 879 | let duration_seconds = total_samples as f64 / 24000.0; 880 | let colored_request_id = get_colored_request_id_with_relative(&request_id, request_start); 881 | info!( 882 | "{} TTS session completed - {} chunks, {} bytes, {:.1}s audio, PCM format", 883 | colored_request_id, total_chunks, bytes_transferred, duration_seconds 884 | ); 885 | 886 | // Send termination signal 887 | let _ = audio_tx.send((total_chunks, vec![])); // Empty data as termination signal 888 | }); 889 | 890 | // No ordering needed - sequential processing guarantees order 891 | 892 | // Create immediate streaming - chunks are already sent in order from TTS processing 893 | let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(audio_rx) 894 | .map(|(_chunk_id, data)| -> Result<Vec<u8>, std::io::Error> { 895 | // Check for termination signal (empty data) 896 | if data.is_empty() { 897 | return Err(std::io::Error::new( 898 | std::io::ErrorKind::UnexpectedEof, 899 | "Stream complete", 900 | )); 901 | } 902 | Ok(data) 903 | }) 904 | .take_while(|result| { 905 | // Continue until we hit an error (termination signal) 906 | std::future::ready(result.is_ok()) 907 | }); 908 | 909 | // Convert to HTTP body with explicit ordering 910 | let body = Body::from_stream(stream); 911 | 912 | Ok(Response::builder() 913 | .header(header::CONTENT_TYPE, content_type) 914 | .header(header::CONNECTION, "keep-alive") 915 | .header(header::CACHE_CONTROL, "no-cache") 916 | .header("X-Accel-Buffering", "no") // Disable nginx buffering 917 | .header("Transfer-Encoding", "chunked") //
Enable HTTP chunked transfer encoding 918 | .header("Access-Control-Allow-Origin", "*") // CORS for browser clients 919 | .body(body) 920 | .map_err(|e| { 921 | SpeechError::Mp3Conversion(std::io::Error::new(std::io::ErrorKind::Other, e)) 922 | })?) 923 | } 924 | 925 | async fn handle_voices( 926 | State((tts_single, _tts_instances)): State<(TTSKoko, Vec<TTSKoko>)>, 927 | ) -> Json<VoicesResponse> { 928 | let voices = tts_single.get_available_voices(); 929 | Json(VoicesResponse { voices }) 930 | } 931 | 932 | /// Handle /v1/models endpoint 933 | /// 934 | /// Returns a static list of models for OpenAI API compatibility. 935 | /// Note: All models use the same underlying Kokoro TTS engine. 936 | async fn handle_models() -> Json<ModelsResponse> { 937 | let models = vec![ 938 | ModelObject { 939 | id: "tts-1".to_string(), 940 | object: "model".to_string(), 941 | created: 1686935002, 942 | owned_by: "kokoro".to_string(), 943 | }, 944 | ModelObject { 945 | id: "tts-1-hd".to_string(), 946 | object: "model".to_string(), 947 | created: 1686935002, 948 | owned_by: "kokoro".to_string(), 949 | }, 950 | ModelObject { 951 | id: "kokoro".to_string(), 952 | object: "model".to_string(), 953 | created: 1686935002, 954 | owned_by: "kokoro".to_string(), 955 | }, 956 | ]; 957 | 958 | Json(ModelsResponse { 959 | object: "list".to_string(), 960 | data: models, 961 | }) 962 | } 963 | 964 | async fn handle_model(Path(model_id): Path<String>) -> Result<Json<ModelObject>, StatusCode> { 965 | let model = match model_id.as_str() { 966 | "tts-1" => ModelObject { 967 | id: "tts-1".to_string(), 968 | object: "model".to_string(), 969 | created: 1686935002, 970 | owned_by: "kokoro".to_string(), 971 | }, 972 | "tts-1-hd" => ModelObject { 973 | id: "tts-1-hd".to_string(), 974 | object: "model".to_string(), 975 | created: 1686935002, 976 | owned_by: "kokoro".to_string(), 977 | }, 978 | "kokoro" => ModelObject { 979 | id: "kokoro".to_string(), 980 | object: "model".to_string(), 981 | created: 1686935002, 982 | owned_by: "kokoro".to_string(), 983 | }, 984 | _ => return Err(StatusCode::NOT_FOUND), 985 | }; 986 | 987 | Ok(Json(model)) 988 | } 989 | 990 | fn get_colored_request_id_with_relative(request_id: &str, start_time: Instant) -> String { 991 | kokoros::utils::debug::get_colored_request_id_with_relative(request_id, start_time) 992 | } 993 | 994 | async fn request_id_middleware( 995 | mut request: axum::extract::Request, 996 | next: axum::middleware::Next, 997 | ) -> axum::response::Response { 998 | let method = request.method().clone(); 999 | let uri = request.uri().path().to_string(); 1000 | let user_agent = request 1001 | .headers() 1002 | .get("user-agent") 1003 | .and_then(|h| h.to_str().ok()) 1004 | .unwrap_or("-") 1005 | .to_string(); 1006 | 1007 | let request_id = uuid::Uuid::new_v4().to_string()[..8].to_string(); 1008 | let start = std::time::Instant::now(); 1009 | let colored_request_id = get_colored_request_id_with_relative(&request_id, start); 1010 | request.extensions_mut().insert((request_id.clone(), start)); 1011 | 1012 | info!( 1013 | "{} {} {} \"{}\"", 1014 | colored_request_id, method, uri, user_agent 1015 | ); 1016 | 1017 | let response = next.run(request).await; 1018 | let _latency = start.elapsed(); 1019 | 1020 | let colored_request_id_response = get_colored_request_id_with_relative(&request_id, start); 1021 | info!("{} {}", colored_request_id_response, response.status()); 1022 | 1023 | response 1024 | } 1025 | -------------------------------------------------------------------------------- /kokoros/src/tts/koko.rs:
-------------------------------------------------------------------------------- 1 | use crate::onn::ort_koko::{self, ModelStrategy}; 2 | use crate::tts::tokenize::tokenize; 3 | use crate::utils; 4 | use crate::utils::debug::format_debug_prefix; 5 | use lazy_static::lazy_static; 6 | use ndarray::Array3; 7 | use ndarray_npy::NpzReader; 8 | use std::collections::HashMap; 9 | use std::error::Error; 10 | use std::fs::File; 11 | use std::path::Path; 12 | use std::sync::atomic::{AtomicBool, Ordering}; 13 | use std::sync::{Arc, Mutex}; 14 | 15 | use espeak_rs::text_to_phonemes; 16 | 17 | // Global mutex to serialize espeak-rs calls to prevent phoneme randomization 18 | // espeak-rs uses global state internally and is not thread-safe 19 | lazy_static! { 20 | static ref ESPEAK_MUTEX: Mutex<()> = Mutex::new(()); 21 | } 22 | 23 | // Flag to ensure voice styles are only logged once 24 | static VOICES_LOGGED: AtomicBool = AtomicBool::new(false); 25 | 26 | #[derive(Debug, Clone)] 27 | pub struct WordAlignment { 28 | pub word: String, 29 | pub start_sec: f32, 30 | pub end_sec: f32, 31 | } 32 | 33 | #[derive(Debug, Clone)] 34 | pub enum TtsOutput { 35 | /// Standard audio, no timing data 36 | Audio(Vec<f32>), 37 | /// Audio with synchronized word timestamps 38 | Aligned(Vec<f32>, Vec<WordAlignment>), 39 | } 40 | 41 | impl TtsOutput { 42 | pub fn raw_output(self) -> (Vec<f32>, Option<Vec<WordAlignment>>) { 43 | match self { 44 | TtsOutput::Audio(a) => (a, None), 45 | TtsOutput::Aligned(a, b) => (a, Some(b)) 46 | } 47 | } 48 | } 49 | 50 | enum ExecutionMode<'a> { 51 | /// Collects all data, adjusts timestamps to be global, returns it at the end. 52 | Batch, 53 | /// Yields chunks immediately with relative timestamps. Returns None at end. 54 | Stream(&'a mut dyn FnMut(TtsOutput) -> Result<(), Box<dyn Error>>), 55 | } 56 | 57 | #[derive(Debug, Clone)] 58 | pub struct TTSOpts<'a> { 59 | pub txt: &'a str, 60 | pub lan: &'a str, 61 | pub style_name: &'a str, 62 | pub save_path: &'a str, 63 | pub mono: bool, 64 | pub speed: f32, 65 | pub initial_silence: Option<usize>, 66 | } 67 | 68 | #[derive(Clone)] 69 | pub struct TTSKoko { 70 | #[allow(dead_code)] 71 | model_path: String, 72 | model: Arc<Mutex<ort_koko::OrtKoko>>, 73 | styles: HashMap<String, Vec<[[f32; 256]; 1]>>, 74 | init_config: InitConfig, 75 | } 76 | 77 | /// Parallel TTS with multiple ONNX instances for true concurrency 78 | #[derive(Clone)] 79 | pub struct TTSKokoParallel { 80 | #[allow(dead_code)] 81 | model_path: String, 82 | models: Vec<Arc<Mutex<ort_koko::OrtKoko>>>, 83 | styles: HashMap<String, Vec<[[f32; 256]; 1]>>, 84 | init_config: InitConfig, 85 | } 86 | 87 | #[derive(Clone)] 88 | pub struct InitConfig { 89 | pub model_url: String, 90 | pub voices_url: String, 91 | pub sample_rate: u32, 92 | } 93 | 94 | impl Default for InitConfig { 95 | fn default() -> Self { 96 | Self { 97 | model_url: "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/kokoro-v1.0.onnx".into(), 98 | voices_url: "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/voices-v1.0.bin".into(), 99 | sample_rate: 24000, 100 | } 101 | } 102 | } 103 | 104 | impl TTSKoko { 105 | pub async fn new(model_path: &str, voices_path: &str) -> Self { 106 | Self::from_config(model_path, voices_path, InitConfig::default()).await 107 | } 108 | 109 | pub async fn from_config(model_path: &str, voices_path: &str, cfg: InitConfig) -> Self { 110 | if !Path::new(model_path).exists() { 111 | utils::fileio::download_file_from_url(cfg.model_url.as_str(), model_path) 112 | .await 113 | .expect("download model failed."); 114 | } 115 | 116 | if !Path::new(voices_path).exists() { 117 |
utils::fileio::download_file_from_url(cfg.voices_url.as_str(), voices_path) 118 | .await 119 | .expect("download voices data file failed."); 120 | } 121 | 122 | let model = Arc::new(Mutex::new( 123 | ort_koko::OrtKoko::new(model_path.to_string()) 124 | .expect("Failed to create Kokoro TTS model"), 125 | )); 126 | // TODO: if(not streaming) { model.print_info(); } 127 | // model.print_info(); 128 | 129 | let styles = Self::load_voices(voices_path); 130 | 131 | TTSKoko { 132 | model_path: model_path.to_string(), 133 | model, 134 | styles, 135 | init_config: cfg, 136 | } 137 | } 138 | 139 | fn process_internal( 140 | &self, 141 | txt: &str, 142 | lan: &str, 143 | style_name: &str, 144 | speed: f32, 145 | initial_silence: Option<usize>, 146 | request_id: Option<&str>, 147 | instance_id: Option<&str>, 148 | chunk_number_start: Option<usize>, 149 | mut mode: ExecutionMode, 150 | ) -> Result<Option<(Vec<f32>, Vec<WordAlignment>)>, Box<dyn Error>> { 151 | 152 | let chunks = self.split_text_into_chunks(txt, 500); 153 | let start_chunk_num = chunk_number_start.unwrap_or(0); 154 | 155 | let debug_prefix = format_debug_prefix(request_id, instance_id); 156 | 157 | let process_one_chunk = |chunk: &str, chunk_num: usize| -> Result<TtsOutput, Box<dyn Error>> { 158 | 159 | let chunk_info = format!("Chunk: {}, ", chunk_num); 160 | tracing::debug!("{} {}text: '{}'", debug_prefix, chunk_info, chunk); 161 | 162 | // A. Tokenize 163 | // Only build the expensive alignment map if the loaded model supports timestamps. 164 | let use_alignment = { 165 | let model = self.model.lock().unwrap(); 166 | matches!(model.strategy(), Some(ModelStrategy::Timestamped(_))) 167 | }; 168 | 169 | let (mut tokens, word_map) = if use_alignment { 170 | self.tokenize_with_alignment(chunk, lan) 171 | } else { 172 | // Fast path for audio-only models: single eSpeak pass, no per-item calls 173 | self.tokenize_full_no_alignment(chunk, lan) 174 | }; 175 | 176 | // Log token count (helpful for debugging context limits) 177 | tracing::debug!("{} {}tokens generated: {}", debug_prefix, chunk_info, tokens.len()); 178 | 179 | // B. Silence 180 | let silence_count = initial_silence.unwrap_or(0); 181 | for _ in 0..silence_count { 182 | tokens.insert(0, 30); 183 | } 184 | 185 | // C. Style 186 | let styles = self.mix_styles(style_name, tokens.len())?; 187 | 188 | // D. Padding 189 | let mut padded_tokens = vec![0]; 190 | padded_tokens.extend(tokens); 191 | padded_tokens.push(0); 192 | 193 | let index_offset = 1 + silence_count; 194 | let tokens_batch = vec![padded_tokens]; 195 | 196 | // E. Infer 197 | let (chunk_audio_array, chunk_durations_opt) = self.model.lock().unwrap().infer( 198 | tokens_batch, 199 | styles, 200 | speed, 201 | request_id, 202 | instance_id, 203 | Some(chunk_num), 204 | )?; 205 | 206 | let chunk_audio: Vec<f32> = chunk_audio_array.iter().cloned().collect(); 207 | 208 | // F. Calculate Alignments 209 | if let Some(durations) = chunk_durations_opt { 210 | let mut alignments = Vec::new(); 211 | 212 | // Model durations are in frames (hop=600 @ 24 kHz) ⇒ 40 frames/sec. 213 | let frames_per_sec: f32 = 40.0; 214 | 215 | // Guard speed to avoid division by zero; timestamps should reflect the final render timeline. 216 | let speed_safe = if speed > 1e-6 { speed } else { 1.0 }; 217 | 218 | // Include initial "silence tokens" time into the local time cursor. You already shift the 219 | // durations index by `index_offset = 1 + silence_count`; here we also advance the cursor by 220 | // the skipped frames so the first word starts at the actual audio time.
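// Worked example: with hop = 600 samples at 24 kHz, one frame is 600 / 24_000 = 25 ms (40 frames/s), so 12 skipped silence frames advance the cursor by 0.3 s.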
221 | let mut chunk_time_cursor_frames: f32 = 0.0; 222 | if silence_count > 0 { 223 | let start = 1; // skip BOS 224 | let end = (1 + silence_count).min(durations.len()); 225 | if end > start { 226 | let silence_frames: f32 = durations[start..end].iter().sum(); 227 | chunk_time_cursor_frames += silence_frames; 228 | } 229 | } 230 | 231 | // Punctuation pause table in seconds (tune as needed). We scale by 1/speed so faster speech shortens pauses. 232 | let punct_pause_s = |label: &str| -> f32 { 233 | match label { 234 | "." | "!" | "?" => 0.300, // 300 ms 235 | "," => 0.150, // 150 ms 236 | ";" | ":" => 0.200, 237 | _ => 0.0, 238 | } 239 | }; 240 | 241 | for (word, start, end) in word_map { 242 | let adj_start = start + index_offset; 243 | let adj_end = end + index_offset; 244 | 245 | // Punctuation items are separate in word_map with zero token span; account for pause. 246 | let is_punct = word.len() == 1 && ".,!?:;".contains(word.as_str()); 247 | if is_punct { 248 | // Scale pauses by 1/speed so timestamps match rendered audio when speech rate changes. 249 | let pause_s = punct_pause_s(&word) / speed_safe; 250 | let pause_frames = pause_s * frames_per_sec; 251 | let start_sec = chunk_time_cursor_frames / frames_per_sec; 252 | let end_sec = (chunk_time_cursor_frames + pause_frames) / frames_per_sec; 253 | alignments.push(WordAlignment { word: word.clone(), start_sec, end_sec }); 254 | chunk_time_cursor_frames += pause_frames; 255 | continue; 256 | } 257 | 258 | // Normal word span: sum its frame durations and advance the cursor. 259 | if adj_start < adj_end && adj_end <= durations.len() { 260 | let mut word_frames: f32 = durations[adj_start..adj_end].iter().sum(); 261 | 262 | // If your ONNX `durations` do NOT already include speed scaling, uncomment this line: 263 | // word_frames /= speed_safe; 264 | // (Leave it commented if the model already produces speed-scaled durations.) 265 | 266 | let start_sec = chunk_time_cursor_frames / frames_per_sec; 267 | let end_sec = (chunk_time_cursor_frames + word_frames) / frames_per_sec; 268 | alignments.push(WordAlignment { word, start_sec, end_sec }); 269 | chunk_time_cursor_frames += word_frames; 270 | } 271 | } 272 | 273 | // Per-chunk closure: linearly scale the local alignment times to match this chunk's audio length. 274 | // This eliminates cumulative drift across chunks and prevents middle events from sliding late.
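// Worked example: if the alignment cursor ends at 3.9 s but the chunk renders 4.0 s of audio, s = 4.0 / 3.9 ≈ 1.026, so every timestamp below is stretched by about 2.6%.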
275 | let t_end_sec = chunk_time_cursor_frames / frames_per_sec; // alignment-derived duration (sec) 276 | let chunk_audio_sec = chunk_audio.len() as f32 / 24_000.0; // audio duration (sec) 277 | 278 | if t_end_sec > 0.0 { 279 | let s = chunk_audio_sec / t_end_sec; 280 | // Optionally clamp extreme corrections; typical values should be close to 1.0 281 | let s_clamped = s.clamp(0.8, 1.25); 282 | if (s_clamped - 1.0).abs() > 0.005 { // >0.5% correction 283 | tracing::debug!(scale = s_clamped, "Per-chunk alignment scaling applied (speed-aware)"); 284 | for al in &mut alignments { 285 | al.start_sec *= s_clamped; 286 | al.end_sec *= s_clamped; 287 | } 288 | } 289 | 290 | // Optional sanity log after scaling 291 | let diff_ms = (((t_end_sec * s_clamped) - chunk_audio_sec) * 1000.0).abs(); 292 | if diff_ms > 10.0 { 293 | tracing::warn!( 294 | chunk_t_end_sec = t_end_sec * s_clamped, 295 | chunk_audio_sec, 296 | diff_ms, 297 | "Alignment vs audio duration still off after scaling", 298 | ); 299 | } else { 300 | tracing::debug!( 301 | chunk_t_end_sec = t_end_sec * s_clamped, 302 | chunk_audio_sec, 303 | "Chunk alignment closure OK", 304 | ); 305 | } 306 | } 307 | 308 | Ok(TtsOutput::Aligned(chunk_audio, alignments)) 309 | } else { 310 | Ok(TtsOutput::Audio(chunk_audio)) 311 | } 312 | }; 313 | 314 | match &mut mode { 315 | ExecutionMode::Stream(callback) => { 316 | for (i, chunk) in chunks.iter().enumerate() { 317 | let output = process_one_chunk(chunk, start_chunk_num + i)?; 318 | callback(output)?; 319 | } 320 | Ok(None) 321 | } 322 | 323 | ExecutionMode::Batch => { 324 | let mut batch_audio = Vec::new(); 325 | let mut batch_alignments = Vec::new(); 326 | let mut global_time_offset = 0.0; 327 | let sample_rate = 24000.0; 328 | 329 | for (i, chunk) in chunks.iter().enumerate() { 330 | let output = process_one_chunk(chunk, start_chunk_num + i)?; 331 | 332 | match output { 333 | TtsOutput::Aligned(audio, alignments) => { 334 | let duration = audio.len() as f32 / sample_rate; 335 | batch_audio.extend_from_slice(&audio); 336 | 337 | for mut align in alignments { 338 | align.start_sec += global_time_offset; 339 | align.end_sec += global_time_offset; 340 | batch_alignments.push(align); 341 | } 342 | global_time_offset += duration; 343 | } 344 | TtsOutput::Audio(audio) => { 345 | let duration = audio.len() as f32 / sample_rate; 346 | batch_audio.extend_from_slice(&audio); 347 | global_time_offset += duration; 348 | } 349 | } 350 | } 351 | Ok(Some((batch_audio, batch_alignments))) 352 | } 353 | } 354 | } 355 | 356 | /// Prosody-aware tokenization 357 | fn tokenize_with_alignment(&self, text: &str, lan: &str) -> (Vec<i64>, Vec<(String, usize, usize)>) { 358 | // We will produce tokens from the full, context-aware phonemes (best prosody) 359 | // and build an alignment map by estimating per-word token spans using 360 | // per-word phoneme tokenization. This keeps audio natural while providing 361 | // robust timestamps even when eSpeak merges words (e.g., "the model"). 362 | 363 | // 1) Full-phrase phonemes and tokens (prosody source) 364 | let full_phonemes = { 365 | let _guard = ESPEAK_MUTEX.lock().unwrap(); 366 | text_to_phonemes(text, lan, None, true, false) 367 | .unwrap_or_default() 368 | .join("") 369 | }; 370 | let all_tokens = tokenize(&full_phonemes); 371 | 372 | // 2) Build a tokenization plan per original "word or punctuation" unit. 373 | // We want punctuation timestamps too, so we split words and punctuation as separate items.
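// For example, "Hello, world!" yields the items ["Hello", ",", "world", "!"].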
374 | // Simple heuristic: split on whitespace, then further split trailing/leading punctuation 375 | // for .,!?;: characters. 376 | fn split_words_and_punct(s: &str) -> Vec<String> { 377 | let mut out = Vec::new(); 378 | for raw in s.split_whitespace() { 379 | let chars: Vec<char> = raw.chars().collect(); 380 | let mut start = 0usize; 381 | let mut end = chars.len(); // index in chars, not bytes, so non-ASCII words don't panic 382 | 383 | // Leading punctuation 384 | while start < end { 385 | let c = chars[start]; 386 | if ".,!?:;".contains(c) { 387 | out.push(c.to_string()); 388 | start += 1; 389 | } else { 390 | break; 391 | } 392 | } 393 | // Trailing punctuation 394 | while end > start { 395 | let c = chars[end - 1]; 396 | if ".,!?:;".contains(c) { 397 | end -= 1; 398 | } else { 399 | break; 400 | } 401 | } 402 | if start < end { 403 | out.push(chars[start..end].iter().collect()); 404 | } 405 | // Push trailing punctuation in original order 406 | for i in end..chars.len() { 407 | out.push(chars[i].to_string()); 408 | } 409 | } 410 | out 411 | } 412 | 413 | let items = split_words_and_punct(text); 414 | 415 | // 3) For each item, get its standalone phonemes and token count. 416 | // Punctuation-only items get zero tokens but we still record them for timestamps. 417 | let mut per_item_token_counts: Vec<usize> = Vec::with_capacity(items.len()); 418 | let mut per_item_is_punct: Vec<bool> = Vec::with_capacity(items.len()); 419 | for it in &items { 420 | if it.len() == 1 && ".,!?:;".contains(it.chars().next().unwrap()) { 421 | per_item_token_counts.push(0); 422 | per_item_is_punct.push(true); 423 | } else { 424 | let ph = { 425 | let _guard = ESPEAK_MUTEX.lock().unwrap(); 426 | text_to_phonemes(it, lan, None, true, false) 427 | .unwrap_or_default() 428 | .join("") 429 | }; 430 | let cnt = tokenize(&ph).len(); 431 | per_item_token_counts.push(cnt); 432 | per_item_is_punct.push(false); 433 | } 434 | } 435 | 436 | // 4) Map per-item token counts onto the full token sequence length. 437 | // If sums differ (likely due to coarticulation/context differences), 438 | // rescale the counts to match the full length, keeping the distribution similar. 439 | let target_len = all_tokens.len(); 440 | let mut sum_counts: usize = per_item_token_counts.iter().sum(); 441 | 442 | let mut adjusted_counts: Vec<usize> = per_item_token_counts.clone(); 443 | if sum_counts != target_len && sum_counts > 0 { 444 | let scale = (target_len as f64) / (sum_counts as f64); 445 | let mut fractional: Vec<(usize, f64)> = Vec::with_capacity(adjusted_counts.len()); 446 | let mut new_sum = 0usize; 447 | for (i, &c) in per_item_token_counts.iter().enumerate() { 448 | let scaled = (c as f64) * scale; 449 | let floored = scaled.floor() as usize; 450 | adjusted_counts[i] = floored; 451 | new_sum += floored; 452 | fractional.push((i, scaled - floored as f64)); 453 | } 454 | // Distribute the remaining tokens to the largest fractional parts 455 | let mut remaining = target_len.saturating_sub(new_sum); 456 | fractional.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); 457 | for (i, _) in fractional { 458 | if remaining == 0 { break; } 459 | adjusted_counts[i] += 1; 460 | remaining -= 1; 461 | } 462 | tracing::debug!("Alignment: rescaled per-item token counts from {} to {} to match durations length {}.", sum_counts, adjusted_counts.iter().sum::<usize>(), target_len); 463 | sum_counts = adjusted_counts.iter().sum(); 464 | } 465 | 466 | // 5) Build the word_map by assigning contiguous spans across the token stream. 467 | // Punctuation items receive zero-length spans by design (timestamp markers).
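// Worked example: items ["the", "model", "."] with adjusted counts [2, 7, 0] yield spans (0, 2), (2, 9) and a zero-length "." marker at index 9.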

        // 5) Build the word_map by assigning contiguous spans across the token stream.
        //    Punctuation items receive zero-length spans by design (timestamp markers).
        let mut word_map: Vec<(String, usize, usize)> = Vec::with_capacity(items.len());
        let mut cursor = 0usize;
        for (idx, item) in items.iter().enumerate() {
            let cnt = adjusted_counts.get(idx).copied().unwrap_or(0);
            if per_item_is_punct[idx] {
                // Zero-length marker at the current cursor
                word_map.push((item.clone(), cursor, cursor));
            } else {
                let start_idx = cursor;
                let end_idx = cursor.saturating_add(cnt);
                word_map.push((item.clone(), start_idx, end_idx));
                cursor = end_idx;
            }
        }

        // If the mapping under-ran due to rounding, extend the last
        // non-punctuation item to cover all tokens.
        if cursor < target_len {
            if let Some(last_non_punct_pos) = (0..word_map.len()).rev().find(|&i| !per_item_is_punct[i]) {
                let (w, s, _e) = &word_map[last_non_punct_pos];
                word_map[last_non_punct_pos] = (w.clone(), *s, target_len);
            }
        }

        // For empty input, both the token stream and the mapping are empty.
        (all_tokens, word_map)
    }

    /// Fast tokenization path for audio-only models (no timestamps).
    /// Performs a single eSpeak phonemization for the full text and returns tokens with an empty word map.
    fn tokenize_full_no_alignment(&self, text: &str, lan: &str) -> (Vec<i64>, Vec<(String, usize, usize)>) {
        let full_phonemes = {
            let _guard = ESPEAK_MUTEX.lock().unwrap();
            text_to_phonemes(text, lan, None, true, false)
                .unwrap_or_default()
                .join("")
        };
        let all_tokens = tokenize(&full_phonemes);
        (all_tokens, Vec::new())
    }

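    /// Splits `text` into chunks whose phonemized token count stays within
    /// `max_tokens`, preferring sentence boundaries and falling back to word
    /// boundaries for oversized sentences.
    ///
    /// Illustrative sketch (assumes a constructed engine `tts`; the exact split
    /// points depend on eSpeak's phonemization):
    ///
    /// ```ignore
    /// let chunks = tts.split_text_into_chunks("First sentence. A second one!", 512);
    /// assert!(chunks.iter().all(|c| !c.is_empty()));
    /// ```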
    fn split_text_into_chunks(&self, text: &str, max_tokens: usize) -> Vec<String> {
        let mut chunks = Vec::new();

        // First split by sentences, using common sentence-ending punctuation.
        let sentences: Vec<&str> = text
            .split(|c| c == '.' || c == '?' || c == '!' || c == ';')
            .filter(|s| !s.trim().is_empty())
            .collect();

        let mut current_chunk = String::new();

        for sentence in sentences {
            // Trim the sentence and re-terminate it with a period
            // (the original terminator was consumed by the split).
            let sentence = format!("{}.", sentence.trim());

            // Convert to phonemes to check the token count.
            // Note: token counts are estimated with English phonemization.
            let sentence_phonemes = {
                let _guard = ESPEAK_MUTEX.lock().unwrap();
                text_to_phonemes(&sentence, "en", None, true, false)
                    .unwrap_or_default()
                    .join("")
            };
            let token_count = tokenize(&sentence_phonemes).len();

            if token_count > max_tokens {
                // If a single sentence is too long, split it by words.
                let words: Vec<&str> = sentence.split_whitespace().collect();
                let mut word_chunk = String::new();

                for word in words {
                    let test_chunk = if word_chunk.is_empty() {
                        word.to_string()
                    } else {
                        format!("{} {}", word_chunk, word)
                    };

                    let test_phonemes = {
                        let _guard = ESPEAK_MUTEX.lock().unwrap();
                        text_to_phonemes(&test_chunk, "en", None, true, false)
                            .unwrap_or_default()
                            .join("")
                    };
                    let test_tokens = tokenize(&test_phonemes).len();

                    if test_tokens > max_tokens {
                        if !word_chunk.is_empty() {
                            chunks.push(word_chunk);
                        }
                        word_chunk = word.to_string();
                    } else {
                        word_chunk = test_chunk;
                    }
                }

                if !word_chunk.is_empty() {
                    chunks.push(word_chunk);
                }
            } else if !current_chunk.is_empty() {
                // Try to append to the current chunk.
                let test_text = format!("{} {}", current_chunk, sentence);
                let test_phonemes = {
                    let _guard = ESPEAK_MUTEX.lock().unwrap();
                    text_to_phonemes(&test_text, "en", None, true, false)
                        .unwrap_or_default()
                        .join("")
                };
                let test_tokens = tokenize(&test_phonemes).len();

                if test_tokens > max_tokens {
                    // If combining would exceed the limit, start a new chunk.
                    chunks.push(current_chunk);
                    current_chunk = sentence;
                } else {
                    current_chunk = test_text;
                }
            } else {
                current_chunk = sentence;
            }
        }

        // Add the last chunk if it is not empty.
        if !current_chunk.is_empty() {
            chunks.push(current_chunk);
        }

        chunks
    }

    /// Smart word-based chunking for async streaming.
    /// Creates chunks based on natural speech boundaries using word count and punctuation.
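    ///
    /// Illustrative sketch (assumes an engine `tts`; the exact chunking depends
    /// on the punctuation in the input):
    ///
    /// ```ignore
    /// let chunks = tts.split_text_into_speech_chunks(
    ///     "Short one. This much longer sentence, split on its commas; ends here.",
    ///     5,
    /// );
    /// // "Short one." stays whole; the long sentence splits at ',' and ';'.
    /// ```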
    pub fn split_text_into_speech_chunks(&self, text: &str, max_words: usize) -> Vec<String> {
        let mut chunks = Vec::new();

        // Split by sentence-ending punctuation first.
        let sentences: Vec<&str> = text
            .split(|c| c == '.' || c == '!' || c == '?')
            .filter(|s| !s.trim().is_empty())
            .collect();

        for sentence in sentences {
            let sentence = sentence.trim();
            if sentence.is_empty() {
                continue;
            }

            // Count words in this sentence.
            let words: Vec<&str> = sentence.split_whitespace().collect();
            let word_count = words.len();

            if word_count <= max_words {
                // Small sentence: add it as a complete chunk, re-appending a period
                // (the split consumed the original terminator).
                chunks.push(format!("{}.", sentence));
            } else {
                // Large sentence: split on internal punctuation marks while preserving them.
                let mut sub_clauses = Vec::new();
                let mut current_pos = 0;

                for (i, ch) in sentence.char_indices() {
                    if ch == ',' || ch == ';' || ch == ':' {
                        if i > current_pos {
                            let clause_with_punct = format!("{}{}", &sentence[current_pos..i], ch);
                            sub_clauses.push(clause_with_punct);
                        }
                        current_pos = i + 1;
                    }
                }

                // Add the remaining text.
                if current_pos < sentence.len() {
                    sub_clauses.push(sentence[current_pos..].to_string());
                }

                let sub_clauses: Vec<&str> = sub_clauses
                    .iter()
                    .map(|s| s.trim())
                    .filter(|s| !s.is_empty())
                    .collect();

                let mut current_chunk = String::new();
                let mut current_word_count = 0;

                for clause in sub_clauses {
                    let clause = clause.trim();
                    let clause_words: Vec<&str> = clause.split_whitespace().collect();
                    let clause_word_count = clause_words.len();

                    if current_word_count + clause_word_count <= max_words {
                        // Add the clause to the current chunk (its punctuation is preserved).
                        if current_chunk.is_empty() {
                            current_chunk = clause.to_string();
                        } else {
                            current_chunk = format!("{} {}", current_chunk, clause);
                        }
                        current_word_count += clause_word_count;
                    } else {
                        // Start a new chunk.
                        if !current_chunk.is_empty() {
                            chunks.push(current_chunk);
                        }
                        current_chunk = clause.to_string();
                        current_word_count = clause_word_count;
                    }
                }

                // Add the final chunk.
                if !current_chunk.is_empty() {
                    chunks.push(current_chunk);
                }
            }
        }

        // If no sentences were found, fall back to word-based chunking.
        if chunks.is_empty() {
            let words: Vec<&str> = text.split_whitespace().collect();
            let mut current_chunk = String::new();
            let mut current_word_count = 0;

            for word in words {
                if current_word_count + 1 <= max_words {
                    if current_chunk.is_empty() {
                        current_chunk = word.to_string();
                    } else {
                        current_chunk = format!("{} {}", current_chunk, word);
                    }
                    current_word_count += 1;
                } else {
                    if !current_chunk.is_empty() {
                        chunks.push(current_chunk);
                    }
                    current_chunk = word.to_string();
                    current_word_count = 1;
                }
            }

            if !current_chunk.is_empty() {
                chunks.push(current_chunk);
            }
        }

        chunks
    }

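    /// Batch synthesis that also returns word-level timing data.
    ///
    /// Illustrative sketch (assumes an engine `tts`, an available voice such as
    /// "af_sky", and the crate's per-word timing struct, referred to as
    /// `WordAlignment` in the signatures below):
    ///
    /// ```ignore
    /// let (audio, alignments) = tts
    ///     .tts_timestamped_raw_audio("Hello world.", "en", "af_sky", 1.0, None, None, None, None)?
    ///     .expect("batch mode always yields audio");
    /// for a in &alignments {
    ///     println!("{:.2}s..{:.2}s", a.start_sec, a.end_sec);
    /// }
    /// ```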
    pub fn tts_timestamped_raw_audio(
        &self,
        txt: &str,
        lan: &str,
        style_name: &str,
        speed: f32,
        initial_silence: Option<usize>,
        request_id: Option<&str>,
        instance_id: Option<&str>,
        chunk_number: Option<usize>,
    ) -> Result<Option<(Vec<f32>, Vec<WordAlignment>)>, Box<dyn std::error::Error>> {
        self.process_internal(
            txt,
            lan,
            style_name,
            speed,
            initial_silence,
            request_id,
            instance_id,
            chunk_number,
            ExecutionMode::Batch,
        )
    }

    pub fn tts_raw_audio(
        &self,
        txt: &str,
        lan: &str,
        style_name: &str,
        speed: f32,
        initial_silence: Option<usize>,
        request_id: Option<&str>,
        instance_id: Option<&str>,
        chunk_number: Option<usize>,
    ) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
        let audio = self.process_internal(
            txt,
            lan,
            style_name,
            speed,
            initial_silence,
            request_id,
            instance_id,
            chunk_number,
            ExecutionMode::Batch,
        )?;

        Ok(audio.expect("batch mode always yields audio").0)
    }

    /// Streaming version that yields audio chunks as they are generated.
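    ///
    /// Illustrative sketch (assumes an engine `tts`): collect the streamed
    /// chunks into a single buffer via the callback.
    ///
    /// ```ignore
    /// let mut all_samples: Vec<f32> = Vec::new();
    /// tts.tts_raw_audio_streaming(
    ///     "One. Two. Three.", "en", "af_sky", 1.0, None, None, None, None,
    ///     |chunk| {
    ///         all_samples.extend_from_slice(&chunk);
    ///         Ok(())
    ///     },
    /// )?;
    /// ```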
    pub fn tts_raw_audio_streaming<F>(
        &self,
        txt: &str,
        lan: &str,
        style_name: &str,
        speed: f32,
        initial_silence: Option<usize>,
        request_id: Option<&str>,
        instance_id: Option<&str>,
        chunk_number: Option<usize>,
        mut chunk_callback: F,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        F: FnMut(Vec<f32>) -> Result<(), Box<dyn std::error::Error>>,
    {
        // Adapt the audio-only callback to the TtsOutput-based streaming mode.
        let mut adapter = |output: TtsOutput| -> Result<(), Box<dyn std::error::Error>> {
            chunk_callback(output.raw_output().0)
        };

        self.process_internal(
            txt,
            lan,
            style_name,
            speed,
            initial_silence,
            request_id,
            instance_id,
            chunk_number,
            // Pass the adapter, not the original callback.
            ExecutionMode::Stream(&mut adapter),
        )?;

        Ok(())
    }

    /// Streaming version that strictly requires a timestamped model.
    /// Yields audio chunks plus alignment data via the callback as they are generated.
    pub fn tts_timestamped_raw_audio_streaming<F>(
        &self,
        txt: &str,
        lan: &str,
        style_name: &str,
        speed: f32,
        initial_silence: Option<usize>,
        request_id: Option<&str>,
        instance_id: Option<&str>,
        chunk_number: Option<usize>,
        mut chunk_callback: F,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        // The callback receives each chunk's audio samples plus their word alignments.
        F: FnMut((Vec<f32>, Vec<WordAlignment>)) -> Result<(), Box<dyn std::error::Error>>,
    {
        let mut adapter = |output: TtsOutput| -> Result<(), Box<dyn std::error::Error>> {
            let audio = output.raw_output();
            chunk_callback((audio.0, audio.1.expect("timestamped model must produce alignments")))
        };

        self.process_internal(
            txt,
            lan,
            style_name,
            speed,
            initial_silence,
            request_id,
            instance_id,
            chunk_number,
            ExecutionMode::Stream(&mut adapter),
        )?;

        Ok(())
    }

    pub fn tts(
        &self,
        TTSOpts {
            txt,
            lan,
            style_name,
            save_path,
            mono,
            speed,
            initial_silence,
        }: TTSOpts,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let audio = self.tts_raw_audio(
            &txt,
            lan,
            style_name,
            speed,
            initial_silence,
            None,
            None,
            None,
        )?;

        // Save to file
        if mono {
            let spec = hound::WavSpec {
                channels: 1,
                sample_rate: self.init_config.sample_rate,
                bits_per_sample: 32,
                sample_format: hound::SampleFormat::Float,
            };

            let mut writer = hound::WavWriter::create(save_path, spec)?;
            for &sample in &audio {
                writer.write_sample(sample)?;
            }
            writer.finalize()?;
        } else {
            // Duplicate each sample into both channels for stereo output.
            let spec = hound::WavSpec {
                channels: 2,
                sample_rate: self.init_config.sample_rate,
                bits_per_sample: 32,
                sample_format: hound::SampleFormat::Float,
            };

            let mut writer = hound::WavWriter::create(save_path, spec)?;
            for &sample in &audio {
                writer.write_sample(sample)?;
                writer.write_sample(sample)?;
            }
            writer.finalize()?;
        }
        eprintln!("Audio saved to {}", save_path);
        Ok(())
    }

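    /// Resolves a single style, or blends several styles given the
    /// `name.weight+name.weight` syntax, where each weight digit is a tenth.
    ///
    /// Illustrative sketch (assumes an engine `tts` and that both voices exist
    /// in the loaded voice data):
    ///
    /// ```ignore
    /// // 0.4 * af_sky + 0.6 * af_nicole, blended element-wise over 256 dims:
    /// let style = tts.mix_styles("af_sky.4+af_nicole.6", tokens.len())?;
    /// ```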
=> "Japanese Female(jf)", 989 | "jm" => "Japanese Male(jm)", 990 | "pf" => "Portuguese Female(pf)", 991 | "pm" => "Portuguese Male(pm)", 992 | "zf" => "Chinese Female(zf)", 993 | "zm" => "Chinese Male(zm)", 994 | _ => prefix, 995 | }; 996 | 997 | let voices_str = voices_in_group.join(", "); 998 | // Gray out the voice information 999 | tracing::info!("\x1b[90m{}: {}\x1b[0m", category, voices_str); 1000 | } 1001 | 1002 | tracing::info!("=========================================="); 1003 | } 1004 | 1005 | voices 1006 | }; 1007 | 1008 | map 1009 | } 1010 | 1011 | // Returns a sorted list of available voice names 1012 | pub fn get_available_voices(&self) -> Vec { 1013 | let mut voices: Vec = self.styles.keys().cloned().collect(); 1014 | voices.sort(); 1015 | voices 1016 | } 1017 | } 1018 | 1019 | impl TTSKokoParallel { 1020 | pub async fn new_with_instances( 1021 | model_path: &str, 1022 | voices_path: &str, 1023 | num_instances: usize, 1024 | ) -> Self { 1025 | Self::from_config_with_instances( 1026 | model_path, 1027 | voices_path, 1028 | InitConfig::default(), 1029 | num_instances, 1030 | ) 1031 | .await 1032 | } 1033 | 1034 | pub async fn from_config_with_instances( 1035 | model_path: &str, 1036 | voices_path: &str, 1037 | cfg: InitConfig, 1038 | num_instances: usize, 1039 | ) -> Self { 1040 | if !Path::new(model_path).exists() { 1041 | utils::fileio::download_file_from_url(cfg.model_url.as_str(), model_path) 1042 | .await 1043 | .expect("download model failed."); 1044 | } 1045 | 1046 | if !Path::new(voices_path).exists() { 1047 | utils::fileio::download_file_from_url(cfg.voices_url.as_str(), voices_path) 1048 | .await 1049 | .expect("download voices data file failed."); 1050 | } 1051 | 1052 | // Create multiple ONNX model instances 1053 | let mut models = Vec::new(); 1054 | for i in 0..num_instances { 1055 | tracing::info!( 1056 | "Creating TTS instance [{}] ({}/{})", 1057 | format!("{:02x}", i), 1058 | i + 1, 1059 | num_instances 1060 | ); 1061 | let model = Arc::new(Mutex::new( 1062 | ort_koko::OrtKoko::new(model_path.to_string()) 1063 | .expect("Failed to create Kokoro TTS model"), 1064 | )); 1065 | models.push(model); 1066 | } 1067 | 1068 | let styles = TTSKoko::load_voices(voices_path); 1069 | 1070 | TTSKokoParallel { 1071 | model_path: model_path.to_string(), 1072 | models, 1073 | styles, 1074 | init_config: cfg, 1075 | } 1076 | } 1077 | 1078 | /// Get a specific model instance for a worker 1079 | pub fn get_model_instance(&self, worker_id: usize) -> Arc> { 1080 | let index = worker_id % self.models.len(); 1081 | Arc::clone(&self.models[index]) 1082 | } 1083 | 1084 | /// HELPER: Create a lightweight wrapper for a specific model --- 1085 | fn get_tts_wrapper(&self, model_instance: Arc>) -> TTSKoko { 1086 | TTSKoko { 1087 | model_path: self.model_path.clone(), 1088 | model: model_instance, 1089 | // TODO: This clones the HashMap. In a future PR, wrap styles in Arc<>! 
impl TTSKokoParallel {
    pub async fn new_with_instances(
        model_path: &str,
        voices_path: &str,
        num_instances: usize,
    ) -> Self {
        Self::from_config_with_instances(
            model_path,
            voices_path,
            InitConfig::default(),
            num_instances,
        )
        .await
    }

    pub async fn from_config_with_instances(
        model_path: &str,
        voices_path: &str,
        cfg: InitConfig,
        num_instances: usize,
    ) -> Self {
        if !Path::new(model_path).exists() {
            utils::fileio::download_file_from_url(cfg.model_url.as_str(), model_path)
                .await
                .expect("download model failed.");
        }

        if !Path::new(voices_path).exists() {
            utils::fileio::download_file_from_url(cfg.voices_url.as_str(), voices_path)
                .await
                .expect("download voices data file failed.");
        }

        // Create multiple ONNX model instances
        let mut models = Vec::new();
        for i in 0..num_instances {
            tracing::info!(
                "Creating TTS instance [{}] ({}/{})",
                format!("{:02x}", i),
                i + 1,
                num_instances
            );
            let model = Arc::new(Mutex::new(
                ort_koko::OrtKoko::new(model_path.to_string())
                    .expect("Failed to create Kokoro TTS model"),
            ));
            models.push(model);
        }

        let styles = TTSKoko::load_voices(voices_path);

        TTSKokoParallel {
            model_path: model_path.to_string(),
            models,
            styles,
            init_config: cfg,
        }
    }

    /// Get a specific model instance for a worker (round-robin by worker id).
    pub fn get_model_instance(&self, worker_id: usize) -> Arc<Mutex<ort_koko::OrtKoko>> {
        let index = worker_id % self.models.len();
        Arc::clone(&self.models[index])
    }

    /// Helper: create a lightweight `TTSKoko` wrapper around a specific model instance.
    fn get_tts_wrapper(&self, model_instance: Arc<Mutex<ort_koko::OrtKoko>>) -> TTSKoko {
        TTSKoko {
            model_path: self.model_path.clone(),
            model: model_instance,
            // TODO: This clones the HashMap. In a future PR, wrap styles in Arc<>!
            styles: self.styles.clone(),
            init_config: self.init_config.clone(),
        }
    }

    /// TTS with timestamps for a specific model instance.
    pub fn tts_timestamped_raw_audio_with_instance(
        &self,
        text: &str,
        language: &str,
        style_name: &str,
        speed: f32,
        initial_silence: Option<usize>,
        request_id: Option<&str>,
        instance_id: Option<&str>,
        chunk_number: Option<usize>,
        model_instance: Arc<Mutex<ort_koko::OrtKoko>>,
    ) -> Result<Option<(Vec<f32>, Vec<WordAlignment>)>, Box<dyn std::error::Error>> {
        let wrapper = self.get_tts_wrapper(model_instance);
        wrapper.tts_timestamped_raw_audio(
            text, language, style_name, speed, initial_silence, request_id, instance_id,
            chunk_number,
        )
    }

    /// TTS processing with a specific model instance (no global lock).
    pub fn tts_raw_audio_with_instance(
        &self,
        text: &str,
        language: &str,
        style_name: &str,
        speed: f32,
        initial_silence: Option<usize>,
        request_id: Option<&str>,
        instance_id: Option<&str>,
        chunk_number: Option<usize>,
        model_instance: Arc<Mutex<ort_koko::OrtKoko>>,
    ) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
        let wrapper = self.get_tts_wrapper(model_instance);

        wrapper.tts_raw_audio(
            text, language, style_name, speed, initial_silence, request_id, instance_id,
            chunk_number,
        )
    }

    /// Forward compatibility: split-text method.
    pub fn split_text_into_speech_chunks(&self, text: &str, max_words: usize) -> Vec<String> {
        // Use TTSKoko's implementation for now by creating a temporary instance.
        let temp_tts = TTSKoko {
            model_path: self.model_path.clone(),
            model: Arc::clone(&self.models[0]), // just for interface compatibility
            styles: self.styles.clone(),
            init_config: self.init_config.clone(),
        };
        temp_tts.split_text_into_speech_chunks(text, max_words)
    }

    /// Get available voices.
    pub fn get_available_voices(&self) -> Vec<String> {
        let mut voices: Vec<String> = self.styles.keys().cloned().collect();
        voices.sort();
        voices
    }
}
--------------------------------------------------------------------------------