├── .gitignore
├── Dockerfile.cuda.cn
├── README.md
├── app.py
├── assets
│   ├── app.js
│   ├── audio_process.js
│   ├── images
│   │   ├── record.svg
│   │   └── speaking.svg
│   ├── index.html
│   └── voice.png
├── examples
│   └── sherpa_examples.py
├── requirements.cuda.txt
├── requirements.txt
├── screenshot.jpg
└── voiceapi
    ├── asr.py
    └── tts.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 | 
8 | # Test binary, built with `go test -c`
9 | *.test
10 | venv
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 | 
14 | # Dependency directories (remove the comment below to include it)
15 | # vendor/
16 | .env*
17 | .vscode/
18 | .idea/
19 | # Mac OS file
20 | .DS_Store
21 | *.pyc
22 | __pycache__/
23 | *~
24 | .venv
25 | frpc.ini
26 | 
27 | models/
28 | examples/*.wav
29 | models
30 | 
--------------------------------------------------------------------------------
/Dockerfile.cuda.cn:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda:11.8.0-devel-ubuntu22.04
2 | 
3 | RUN sed -i 's/archive.ubuntu.com/mirrors.aliyun.com/g' /etc/apt/sources.list
4 | RUN apt-get update -y && apt-get install -y python3 python3-pip libasound2 libcublas-12-6 libcudnn8-dev
5 | RUN pip3 config set global.index-url https://mirrors.aliyun.com/pypi/web/simple
6 | RUN pip3 install sherpa-onnx==1.11.1+cuda -f https://k2-fsa.github.io/sherpa/onnx/cuda-cn.html
7 | 
8 | WORKDIR /app
9 | ADD requirements.cuda.txt /app/
10 | RUN pip3 install -r requirements.cuda.txt
11 | 
12 | ADD . /app/
13 | ENTRYPOINT ["python3", "app.py"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # voiceapi - A simple and clean voice transcription/synthesis API with sherpa-onnx
2 | 
3 | Thanks to [k2-fsa/sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx), we can easily build a voice API with Python.
4 | 
5 | 
6 | ## Supported models
7 | | Model                                  | Language                                          | Type        | Description                            |
8 | | -------------------------------------- | ------------------------------------------------- | ----------- | -------------------------------------- |
9 | | zipformer-bilingual-zh-en-2023-02-20   | Chinese + English                                 | Online ASR  | Streaming Zipformer, Bilingual         |
10 | | sense-voice-zh-en-ja-ko-yue-2024-07-17 | Chinese + English + Japanese + Korean + Cantonese | Offline ASR | SenseVoice, Multilingual               |
11 | | paraformer-trilingual-zh-cantonese-en  | Chinese + Cantonese + English                     | Offline ASR | Paraformer, Trilingual                 |
12 | | paraformer-en-2024-03-09               | English                                           | Offline ASR | Paraformer, English                    |
13 | | vits-zh-hf-theresa                     | Chinese                                           | TTS         | VITS, Chinese, 804 speakers            |
14 | | melo-tts-zh_en                         | Chinese + English                                 | TTS         | Melo, Chinese + English, 1 speaker     |
15 | | kokoro-multi-lang-v1_0                 | Chinese + English                                 | TTS         | Kokoro, Chinese + English, 53 speakers |
16 | 
17 | ## Run the app locally
18 | Python 3.10+ is required.
19 | 
20 | ```shell
21 | python3 -m venv venv
22 | . venv/bin/activate
23 | 
24 | pip install -r requirements.txt
25 | python app.py
26 | ```
27 | 
28 | Visit `http://localhost:8000/` to see the demo page.
29 | 
30 | ## Build the CUDA image (for users in China)
31 | ```shell
32 | docker build -t voiceapi:cuda_dev -f Dockerfile.cuda.cn .
33 | ```
34 | 
35 | ## Streaming API (via WebSocket)
36 | ### /asr
37 | Send 16-bit PCM audio data to the server, and the server will return the transcription result.
38 | - `samplerate` can be set in the query string, default is 16000.
39 | 
40 | The server will return the transcription result in JSON format, with the following fields:
41 | - `text`: the transcription result
42 | - `finished`: whether the segment is finished
43 | - `idx`: the index of the segment
44 | 
45 | ```javascript
46 | const ws = new WebSocket('ws://localhost:8000/asr?samplerate=16000');
47 | ws.onopen = () => {
48 |     console.log('connected');
49 |     ws.send('{"sid": 0}');
50 | };
51 | ws.onmessage = (e) => {
52 |     const data = JSON.parse(e.data);
53 |     const { text, finished, idx } = data;
54 |     // do something with text
55 |     // finished is true when the segment is finished
56 | };
57 | // send audio data
58 | // PCM 16bit, with samplerate
59 | ws.send(int16Array.buffer);
60 | ```
61 | ### /tts
62 | Send text to the server, and the server will return the synthesized audio data.
63 | - `samplerate` can be set in the query string, default is 16000.
64 | - `sid` is the Speaker ID, default is 0.
65 | - `speed` is the speed of the synthesized audio, default is 1.0.
66 | - `chunk_size` is the size of each audio chunk, default is 1024.
67 | 
68 | The server returns the synthesized audio as binary data.
69 | - The audio data is 16-bit PCM, sent as binary WebSocket messages.
70 | - When synthesis finishes, the server sends the result in JSON format, with the following fields:
71 |     - `elapsed`: the elapsed time
72 |     - `progress`: the progress of the synthesis
73 |     - `duration`: the duration of the synthesized audio
74 |     - `size`: the size of the synthesized audio data
75 | 
76 | ```javascript
77 | const ws = new WebSocket('ws://localhost:8000/tts?samplerate=16000');
78 | ws.onopen = () => {
79 |     console.log('connected');
80 |     ws.send('Your text here');
81 | };
82 | ws.onmessage = (e) => {
83 |     if (e.data instanceof Blob) {
84 |         // Chunked audio data
85 |         e.data.arrayBuffer().then((arrayBuffer) => {
86 |             const int16Array = new Int16Array(arrayBuffer);
87 |             let float32Array = new Float32Array(int16Array.length);
88 |             for (let i = 0; i < int16Array.length; i++) {
89 |                 float32Array[i] = int16Array[i] / 32768.;
90 |             }
91 |             playNode.port.postMessage({ message: 'audioData', audioData: float32Array });
92 |         });
93 |     } else {
94 |         // The final message carries the synthesis result
95 |         const { elapsed, progress, duration, size } = JSON.parse(e.data);
96 |         this.elapsedTime = elapsed;
97 |     }
98 | };
99 | ```
100 | 
101 | ### Non-streaming API
102 | #### /tts
103 | Send text to the server, and the server will return the synthesized audio data.
104 | 
105 | - `text` is the text to be synthesized.
106 | - `samplerate` is the sample rate of the generated audio, default is 16000.
107 | - `sid` is the Speaker ID, default is 0.
108 | - `speed` is the speed of the synthesized audio, default is 1.0.
109 | 
110 | ```shell
111 | curl -X POST "http://localhost:8000/tts" \
112 |     -H "Content-Type: application/json" \
113 |     -d '{
114 |         "text": "Hello, world!",
115 |         "sid": 0,
116 |         "samplerate": 16000
117 |     }' -o helloworld.wav
118 | ```
119 | 
120 | ## Download models
121 | All models are stored in the `models` directory.
122 | Only download the models you need. The default models are:
123 | - asr models: `sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20` (Streaming, Bilingual, Chinese + English).
124 | - tts models: `vits-zh-hf-theresa` (Chinese)
125 | 
126 | ### silero_vad.onnx
127 | > silero_vad is required for the offline ASR models
128 | ```bash
129 | mkdir -p silero_vad
130 | cd silero_vad
131 | curl -SL -o silero_vad.onnx https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx
132 | ```
133 | 
134 | ### FireRedASR-AED-L
135 | ```bash
136 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
137 | ```
138 | ### kokoro-multi-lang-v1_0
139 | ```bash
140 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-multi-lang-v1_0.tar.bz2
141 | ```
142 | 
143 | ### vits-zh-hf-theresa
144 | ```bash
145 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-zh-hf-theresa.tar.bz2
146 | ```
147 | 
148 | ### vits-melo-tts-zh_en
149 | ```bash
150 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-melo-tts-zh_en.tar.bz2
151 | ```
152 | ### sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
153 | ```bash
154 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
155 | ```
156 | 
157 | ### sherpa-onnx-paraformer-trilingual-zh-cantonese-en
158 | ```bash
159 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-trilingual-zh-cantonese-en.tar.bz2
160 | ```
161 | ### whisper
162 | ```bash
163 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
164 | ```
165 | ### sensevoice
166 | ```bash
167 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
168 | ```
169 | 
170 | ### sherpa-onnx-streaming-paraformer-bilingual-zh-en
171 | ```bash
172 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
173 | ```
174 | 
175 | ### sherpa-onnx-paraformer-en
176 | ```bash
177 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-en-2024-03-09.tar.bz2
178 | ```
179 | 
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect, Query
3 | from fastapi.responses import HTMLResponse, StreamingResponse
4 | from fastapi.staticfiles import StaticFiles
5 | import asyncio
6 | import logging
7 | from pydantic import BaseModel, Field
8 | import uvicorn
9 | from voiceapi.tts import TTSResult, start_tts_stream, TTSStream
10 | from voiceapi.asr import start_asr_stream, ASRStream, ASRResult
11 | import logging
12 | import argparse
13 | import os
14 | 
15 | app = FastAPI()
16 | logger = logging.getLogger(__file__)
17 | 
18 | 
19 | @app.websocket("/asr")
20 | async def websocket_asr(websocket: WebSocket,
21 |                         samplerate: int = Query(16000, title="Sample Rate",
22 |                                                 description="The sample rate of the audio."),):
23 |     await websocket.accept()
24 | 
25 | 
asr_stream: ASRStream = await start_asr_stream(samplerate, args) 26 | if not asr_stream: 27 | logger.error("failed to start ASR stream") 28 | await websocket.close() 29 | return 30 | 31 | async def task_recv_pcm(): 32 | while True: 33 | pcm_bytes = await websocket.receive_bytes() 34 | if not pcm_bytes: 35 | return 36 | await asr_stream.write(pcm_bytes) 37 | 38 | async def task_send_result(): 39 | while True: 40 | result: ASRResult = await asr_stream.read() 41 | if not result: 42 | return 43 | await websocket.send_json(result.to_dict()) 44 | try: 45 | await asyncio.gather(task_recv_pcm(), task_send_result()) 46 | except WebSocketDisconnect: 47 | logger.info("asr: disconnected") 48 | finally: 49 | await asr_stream.close() 50 | 51 | 52 | @app.websocket("/tts") 53 | async def websocket_tts(websocket: WebSocket, 54 | samplerate: int = Query(16000, 55 | title="Sample Rate", 56 | description="The sample rate of the generated audio."), 57 | interrupt: bool = Query(True, 58 | title="Interrupt", 59 | description="Interrupt the current TTS stream when a new text is received."), 60 | sid: int = Query(0, 61 | title="Speaker ID", 62 | description="The ID of the speaker to use for TTS."), 63 | chunk_size: int = Query(1024, 64 | title="Chunk Size", 65 | description="The size of the chunk to send to the client."), 66 | speed: float = Query(1.0, 67 | title="Speed", 68 | description="The speed of the generated audio."), 69 | split: bool = Query(True, 70 | title="Split", 71 | description="Split the text into sentences.")): 72 | 73 | await websocket.accept() 74 | tts_stream: TTSStream = None 75 | 76 | async def task_recv_text(): 77 | nonlocal tts_stream 78 | while True: 79 | text = await websocket.receive_text() 80 | if not text: 81 | return 82 | 83 | if interrupt or not tts_stream: 84 | if tts_stream: 85 | await tts_stream.close() 86 | logger.info("tts: stream interrupt") 87 | 88 | tts_stream = await start_tts_stream(sid, samplerate, speed, args) 89 | if not tts_stream: 90 | logger.error("tts: failed to allocate tts stream") 91 | await websocket.close() 92 | return 93 | logger.info(f"tts: received: {text} (split={split})") 94 | await tts_stream.write(text, split) 95 | 96 | async def task_send_pcm(): 97 | nonlocal tts_stream 98 | while not tts_stream: 99 | # wait for tts stream to be created 100 | await asyncio.sleep(0.1) 101 | 102 | while True: 103 | result: TTSResult = await tts_stream.read() 104 | if not result: 105 | return 106 | 107 | if result.finished: 108 | await websocket.send_json(result.to_dict()) 109 | else: 110 | for i in range(0, len(result.pcm_bytes), chunk_size): 111 | await websocket.send_bytes(result.pcm_bytes[i:i+chunk_size]) 112 | 113 | try: 114 | await asyncio.gather(task_recv_text(), task_send_pcm()) 115 | except WebSocketDisconnect: 116 | logger.info("tts: disconnected") 117 | finally: 118 | if tts_stream: 119 | await tts_stream.close() 120 | 121 | 122 | class TTSRequest(BaseModel): 123 | text: str = Field(..., title="Text", 124 | description="The text to be converted to speech.", 125 | examples=["Hello, world!"]) 126 | sid: int = Field(0, title="Speaker ID", 127 | description="The ID of the speaker to use for TTS.") 128 | samplerate: int = Field(16000, title="Sample Rate", 129 | description="The sample rate of the generated audio.") 130 | speed: float = Field(1.0, title="Speed", 131 | description="The speed of the generated audio.") 132 | 133 | 134 | @ app.post("/tts", 135 | description="Generate speech audio from text.", 136 | response_class=StreamingResponse, responses={200: 
{"content": {"audio/wav": {}}}}) 137 | async def tts_generate(req: TTSRequest): 138 | if not req.text: 139 | raise HTTPException(status_code=400, detail="text is required") 140 | 141 | tts_stream = await start_tts_stream(req.sid, req.samplerate, req.speed, args) 142 | if not tts_stream: 143 | raise HTTPException( 144 | status_code=500, detail="failed to start TTS stream") 145 | 146 | r = await tts_stream.generate(req.text) 147 | return StreamingResponse(r, media_type="audio/wav") 148 | 149 | 150 | if __name__ == "__main__": 151 | models_root = './models' 152 | 153 | for d in ['.', '..', '../..']: 154 | if os.path.isdir(f'{d}/models'): 155 | models_root = f'{d}/models' 156 | break 157 | 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("--port", type=int, default=8000, help="port number") 160 | parser.add_argument("--addr", type=str, 161 | default="0.0.0.0", help="serve address") 162 | 163 | parser.add_argument("--asr-provider", type=str, 164 | default="cpu", help="asr provider, cpu or cuda") 165 | parser.add_argument("--tts-provider", type=str, 166 | default="cpu", help="tts provider, cpu or cuda") 167 | 168 | parser.add_argument("--threads", type=int, default=2, 169 | help="number of threads") 170 | 171 | parser.add_argument("--models-root", type=str, default=models_root, 172 | help="model root directory") 173 | 174 | parser.add_argument("--asr-model", type=str, default='sensevoice', 175 | help="ASR model name: zipformer-bilingual, sensevoice, paraformer-trilingual, paraformer-en, fireredasr") 176 | 177 | parser.add_argument("--asr-lang", type=str, default='zh', 178 | help="ASR language, zh, en, ja, ko, yue") 179 | 180 | parser.add_argument("--tts-model", type=str, default='vits-zh-hf-theresa', 181 | help="TTS model name: vits-zh-hf-theresa, vits-melo-tts-zh_en, kokoro-multi-lang-v1_0") 182 | 183 | args = parser.parse_args() 184 | 185 | if args.tts_model == 'vits-melo-tts-zh_en' and args.tts_provider == 'cuda': 186 | logger.warning( 187 | "vits-melo-tts-zh_en does not support CUDA fallback to CPU") 188 | args.tts_provider = 'cpu' 189 | 190 | app.mount("/", app=StaticFiles(directory="./assets", html=True), name="assets") 191 | 192 | logging.basicConfig(format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s', 193 | level=logging.INFO) 194 | uvicorn.run(app, host=args.addr, port=args.port) 195 | -------------------------------------------------------------------------------- /assets/app.js: -------------------------------------------------------------------------------- 1 | const demoapp = { 2 | text: '讲个冷笑话吧,要很好笑的那种。', 3 | recording: false, 4 | asrWS: null, 5 | currentText: null, 6 | disabled: false, 7 | elapsedTime: null, 8 | logs: [{ idx: 0, text: 'Happily here at ruzhila.cn.' 
}], 9 | async init() { 10 | }, 11 | async dotts() { 12 | let audioContext = new AudioContext({ sampleRate: 16000 }) 13 | await audioContext.audioWorklet.addModule('./audio_process.js') 14 | 15 | const ws = new WebSocket('/tts'); 16 | ws.onopen = () => { 17 | ws.send(this.text); 18 | }; 19 | const playNode = new AudioWorkletNode(audioContext, 'play-audio-processor'); 20 | playNode.connect(audioContext.destination); 21 | 22 | this.disabled = true; 23 | ws.onmessage = async (e) => { 24 | if (e.data instanceof Blob) { 25 | e.data.arrayBuffer().then((arrayBuffer) => { 26 | const int16Array = new Int16Array(arrayBuffer); 27 | let float32Array = new Float32Array(int16Array.length); 28 | for (let i = 0; i < int16Array.length; i++) { 29 | float32Array[i] = int16Array[i] / 32768.; 30 | } 31 | playNode.port.postMessage({ message: 'audioData', audioData: float32Array }); 32 | }); 33 | } else { 34 | this.elapsedTime = JSON.parse(e.data)?.elapsed; 35 | this.disabled = false; 36 | } 37 | } 38 | }, 39 | 40 | async stopasr() { 41 | if (!this.asrWS) { 42 | return; 43 | } 44 | this.asrWS.close(); 45 | this.asrWS = null; 46 | this.recording = false; 47 | if (this.currentText) { 48 | this.logs.push({ idx: this.logs.length + 1, text: this.currentText }); 49 | } 50 | this.currentText = null; 51 | 52 | }, 53 | 54 | async doasr() { 55 | const audioConstraints = { 56 | video: false, 57 | audio: true, 58 | }; 59 | 60 | const mediaStream = await navigator.mediaDevices.getUserMedia(audioConstraints); 61 | 62 | const ws = new WebSocket('/asr'); 63 | let currentMessage = ''; 64 | 65 | ws.onopen = () => { 66 | this.logs = []; 67 | }; 68 | 69 | ws.onmessage = (e) => { 70 | const data = JSON.parse(e.data); 71 | const { text, finished, idx } = data; 72 | 73 | currentMessage = text; 74 | this.currentText = text 75 | 76 | if (finished) { 77 | this.logs.push({ text: currentMessage, idx: idx }); 78 | currentMessage = ''; 79 | this.currentText = null 80 | } 81 | }; 82 | 83 | let audioContext = new AudioContext({ sampleRate: 16000 }) 84 | await audioContext.audioWorklet.addModule('./audio_process.js') 85 | 86 | const recordNode = new AudioWorkletNode(audioContext, 'record-audio-processor'); 87 | recordNode.connect(audioContext.destination); 88 | recordNode.port.onmessage = (event) => { 89 | if (ws && ws.readyState === WebSocket.OPEN) { 90 | const int16Array = event.data.data; 91 | ws.send(int16Array.buffer); 92 | } 93 | } 94 | const source = audioContext.createMediaStreamSource(mediaStream); 95 | source.connect(recordNode); 96 | this.asrWS = ws; 97 | this.recording = true; 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /assets/audio_process.js: -------------------------------------------------------------------------------- 1 | class PlayerAudioProcessor extends AudioWorkletProcessor { 2 | constructor() { 3 | super(); 4 | this.buffer = new Float32Array(); 5 | this.port.onmessage = (event) => { 6 | let newFetchedData = new Float32Array(this.buffer.length + event.data.audioData.length); 7 | newFetchedData.set(this.buffer, 0); 8 | newFetchedData.set(event.data.audioData, this.buffer.length); 9 | this.buffer = newFetchedData; 10 | }; 11 | } 12 | 13 | process(inputs, outputs, parameters) { 14 | const output = outputs[0]; 15 | const channel = output[0]; 16 | const bufferLength = this.buffer.length; 17 | for (let i = 0; i < channel.length; i++) { 18 | channel[i] = (i < bufferLength) ? 
this.buffer[i] : 0; 19 | } 20 | this.buffer = this.buffer.slice(channel.length); 21 | return true; 22 | } 23 | } 24 | 25 | class RecordAudioProcessor extends AudioWorkletProcessor { 26 | constructor() { 27 | super(); 28 | } 29 | 30 | process(inputs, outputs, parameters) { 31 | const channel = inputs[0][0]; 32 | if (!channel || channel.length === 0) { 33 | return true; 34 | } 35 | const int16Array = new Int16Array(channel.length); 36 | for (let i = 0; i < channel.length; i++) { 37 | int16Array[i] = channel[i] * 32767; 38 | } 39 | this.port.postMessage({ data: int16Array }); 40 | return true 41 | } 42 | } 43 | 44 | registerProcessor('play-audio-processor', PlayerAudioProcessor); 45 | registerProcessor('record-audio-processor', RecordAudioProcessor); -------------------------------------------------------------------------------- /assets/images/record.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/images/speaking.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | voiceapi demo 13 | 19 | 20 | 58 | 59 | 60 | 61 | 87 |
(index.html markup was not captured in this dump; visible page text: "VoiceAPI Demo / ruzhila.cn")
176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /assets/voice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruzhila/voiceapi/8612463d354213debdf62dbee2fa09fcc76eecec/assets/voice.png -------------------------------------------------------------------------------- /examples/sherpa_examples.py: -------------------------------------------------------------------------------- 1 | #!/bin/env python3 2 | """ 3 | Real-time ASR using microphone 4 | """ 5 | 6 | import argparse 7 | import logging 8 | import sherpa_onnx 9 | import os 10 | import time 11 | import struct 12 | import asyncio 13 | import soundfile 14 | 15 | try: 16 | import pyaudio 17 | except ImportError: 18 | raise ImportError('Please install pyaudio with `pip install pyaudio`') 19 | 20 | logger = logging.getLogger(__name__) 21 | SAMPLE_RATE = 16000 22 | 23 | pactx = pyaudio.PyAudio() 24 | models_root: str = None 25 | num_threads: int = 1 26 | 27 | 28 | def create_zipformer(args) -> sherpa_onnx.OnlineRecognizer: 29 | d = os.path.join( 30 | models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20') 31 | encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx") 32 | decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx") 33 | joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx") 34 | tokens = os.path.join(d, "tokens.txt") 35 | 36 | recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( 37 | tokens=tokens, 38 | encoder=encoder, 39 | decoder=decoder, 40 | joiner=joiner, 41 | provider=args.provider, 42 | num_threads=num_threads, 43 | sample_rate=SAMPLE_RATE, 44 | feature_dim=80, 45 | enable_endpoint_detection=True, 46 | rule1_min_trailing_silence=2.4, 47 | rule2_min_trailing_silence=1.2, 48 | rule3_min_utterance_length=20, # it essentially disables this rule 49 | ) 50 | return recognizer 51 | 52 | 53 | def create_sensevoice(args) -> sherpa_onnx.OfflineRecognizer: 54 | model = os.path.join( 55 | models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'model.onnx') 56 | tokens = os.path.join( 57 | models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'tokens.txt') 58 | recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( 59 | model=model, 60 | tokens=tokens, 61 | num_threads=num_threads, 62 | use_itn=True, 63 | debug=0, 64 | language=args.lang, 65 | ) 66 | return recognizer 67 | 68 | 69 | async def run_online(buf, recognizer): 70 | stream = recognizer.create_stream() 71 | last_result = "" 72 | segment_id = 0 73 | logger.info('Start real-time recognizer') 74 | while True: 75 | samples = await buf.get() 76 | stream.accept_waveform(SAMPLE_RATE, samples) 77 | while recognizer.is_ready(stream): 78 | recognizer.decode_stream(stream) 79 | 80 | is_endpoint = recognizer.is_endpoint(stream) 81 | result = recognizer.get_result(stream) 82 | 83 | if result and (last_result != result): 84 | last_result = result 85 | logger.info(f' > {segment_id}:{result}') 86 | 87 | if is_endpoint: 88 | if result: 89 | logger.info(f'{segment_id}: {result}') 90 | segment_id += 1 91 | recognizer.reset(stream) 92 | 93 | 94 | async def run_offline(buf, recognizer): 95 | config = sherpa_onnx.VadModelConfig() 96 | config.silero_vad.model = os.path.join( 97 | models_root, 'silero_vad', 'silero_vad.onnx') 98 | config.silero_vad.min_silence_duration = 0.25 99 | config.sample_rate = SAMPLE_RATE 100 | vad = sherpa_onnx.VoiceActivityDetector( 101 | config, buffer_size_in_seconds=100) 102 | 103 | 
logger.info('Start offline recognizer with VAD') 104 | texts = [] 105 | while True: 106 | samples = await buf.get() 107 | vad.accept_waveform(samples) 108 | while not vad.empty(): 109 | stream = recognizer.create_stream() 110 | stream.accept_waveform(SAMPLE_RATE, vad.front.samples) 111 | 112 | vad.pop() 113 | recognizer.decode_stream(stream) 114 | 115 | text = stream.result.text.strip().lower() 116 | if len(text): 117 | idx = len(texts) 118 | texts.append(text) 119 | logger.info(f"{idx}: {text}") 120 | 121 | 122 | async def handle_asr(args): 123 | action_func = None 124 | if args.model == 'zipformer': 125 | recognizer = create_zipformer(args) 126 | action_func = run_online 127 | elif args.model == 'sensevoice': 128 | recognizer = create_sensevoice(args) 129 | action_func = run_offline 130 | else: 131 | raise ValueError(f'Unknown model: {args.model}') 132 | buf = asyncio.Queue() 133 | recorder_task = asyncio.create_task(run_record(buf)) 134 | asr_task = asyncio.create_task(action_func(buf, recognizer)) 135 | await asyncio.gather(asr_task, recorder_task) 136 | 137 | 138 | async def handle_tts(args): 139 | model = os.path.join( 140 | models_root, 'vits-melo-tts-zh_en', 'model.onnx') 141 | lexicon = os.path.join( 142 | models_root, 'vits-melo-tts-zh_en', 'lexicon.txt') 143 | dict_dir = os.path.join( 144 | models_root, 'vits-melo-tts-zh_en', 'dict') 145 | tokens = os.path.join( 146 | models_root, 'vits-melo-tts-zh_en', 'tokens.txt') 147 | tts_config = sherpa_onnx.OfflineTtsConfig( 148 | model=sherpa_onnx.OfflineTtsModelConfig( 149 | vits=sherpa_onnx.OfflineTtsVitsModelConfig( 150 | model=model, 151 | lexicon=lexicon, 152 | dict_dir=dict_dir, 153 | tokens=tokens, 154 | ), 155 | provider=args.provider, 156 | debug=0, 157 | num_threads=num_threads, 158 | ), 159 | max_num_sentences=args.max_num_sentences, 160 | ) 161 | if not tts_config.validate(): 162 | raise ValueError("Please check your config") 163 | 164 | tts = sherpa_onnx.OfflineTts(tts_config) 165 | 166 | start = time.time() 167 | audio = tts.generate(args.text, sid=args.sid, 168 | speed=args.speed) 169 | elapsed_seconds = time.time() - start 170 | audio_duration = len(audio.samples) / audio.sample_rate 171 | real_time_factor = elapsed_seconds / audio_duration 172 | 173 | if args.output: 174 | logger.info(f"Saved to {args.output}") 175 | soundfile.write( 176 | args.output, 177 | audio.samples, 178 | samplerate=audio.sample_rate, 179 | subtype="PCM_16", 180 | ) 181 | 182 | logger.info(f"The text is '{args.text}'") 183 | logger.info(f"Elapsed seconds: {elapsed_seconds:.3f}") 184 | logger.info(f"Audio duration in seconds: {audio_duration:.3f}") 185 | logger.info( 186 | f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") 187 | 188 | 189 | async def run_record(buf: asyncio.Queue[list[float]]): 190 | loop = asyncio.get_event_loop() 191 | 192 | def on_input(in_data, frame_count, time_info, status): 193 | samples = [ 194 | v/32768.0 for v in list(struct.unpack('<' + 'h' * frame_count, in_data))] 195 | loop.create_task(buf.put(samples)) 196 | return (None, pyaudio.paContinue) 197 | 198 | frame_size = 320 199 | recorder = pactx.open(format=pyaudio.paInt16, channels=1, 200 | rate=SAMPLE_RATE, input=True, 201 | frames_per_buffer=frame_size, 202 | stream_callback=on_input) 203 | recorder.start_stream() 204 | logger.info('Start recording') 205 | 206 | while recorder.is_active(): 207 | await asyncio.sleep(0.1) 208 | 209 | 210 | async def main(): 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument('--provider', 
default='cpu', 213 | help='onnxruntime provider, default is cpu, use cuda for GPU') 214 | 215 | subparsers = parser.add_subparsers(help='commands help') 216 | 217 | asr_parser = subparsers.add_parser('asr', help='run asr mode') 218 | asr_parser.add_argument('--model', default='zipformer', 219 | help='model name, default is zipformer') 220 | asr_parser.add_argument('--lang', default='zh', 221 | help='language, default is zh') 222 | asr_parser.set_defaults(func=handle_asr) 223 | 224 | tts_parser = subparsers.add_parser('tts', help='run tts mode') 225 | tts_parser.add_argument('--sid', type=int, default=0, help="""Speaker ID. Used only for multi-speaker models, e.g. 226 | models trained using the VCTK dataset. Not used for single-speaker 227 | models, e.g., models trained using the LJ speech dataset. 228 | """) 229 | tts_parser.add_argument('--output', type=str, default='output.wav', 230 | help='output file name, default is output.wav') 231 | tts_parser.add_argument( 232 | "--speed", 233 | type=float, 234 | default=1.0, 235 | help="Speech speed. Larger->faster; smaller->slower", 236 | ) 237 | tts_parser.add_argument( 238 | "--max-num-sentences", 239 | type=int, 240 | default=2, 241 | help="""Max number of sentences in a batch to avoid OOM if the input 242 | text is very long. Set it to -1 to process all the sentences in a 243 | single batch. A smaller value does not mean it is slower compared 244 | to a larger one on CPU. 245 | """, 246 | ) 247 | tts_parser.add_argument( 248 | "text", 249 | type=str, 250 | help="The input text to generate audio for", 251 | ) 252 | tts_parser.set_defaults(func=handle_tts) 253 | 254 | args = parser.parse_args() 255 | 256 | if hasattr(args, 'func'): 257 | await args.func(args) 258 | else: 259 | parser.print_help() 260 | 261 | if __name__ == '__main__': 262 | logging.basicConfig( 263 | format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s') 264 | logging.getLogger().setLevel(logging.INFO) 265 | painfo = pactx.get_default_input_device_info() 266 | assert painfo['maxInputChannels'] >= 1, 'No input device' 267 | logger.info('Default input device: %s', painfo['name']) 268 | 269 | for d in ['.', '..', '../..']: 270 | if os.path.isdir(f'{d}/models'): 271 | models_root = f'{d}/models' 272 | break 273 | assert models_root, 'Could not find models directory' 274 | asyncio.run(main()) 275 | -------------------------------------------------------------------------------- /requirements.cuda.txt: -------------------------------------------------------------------------------- 1 | onnxruntime-gpu == 1.19.2 2 | soundfile == 0.12.1 3 | fastapi == 0.114.1 4 | uvicorn == 0.30.6 5 | scipy == 1.13.1 6 | numpy == 1.26.4 7 | websockets == 13.0.1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sherpa-onnx == 1.11.3 2 | soundfile == 0.12.1 3 | fastapi == 0.114.1 4 | uvicorn == 0.30.6 5 | scipy == 1.13.1 6 | numpy == 1.26.4 7 | websockets == 13.0.1 -------------------------------------------------------------------------------- /screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruzhila/voiceapi/8612463d354213debdf62dbee2fa09fcc76eecec/screenshot.jpg -------------------------------------------------------------------------------- /voiceapi/asr.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import 
logging 3 | import time 4 | import logging 5 | import sherpa_onnx 6 | import os 7 | import asyncio 8 | import numpy as np 9 | 10 | logger = logging.getLogger(__file__) 11 | _asr_engines = {} 12 | 13 | 14 | class ASRResult: 15 | def __init__(self, text: str, finished: bool, idx: int): 16 | self.text = text 17 | self.finished = finished 18 | self.idx = idx 19 | 20 | def to_dict(self): 21 | return {"text": self.text, "finished": self.finished, "idx": self.idx} 22 | 23 | 24 | class ASRStream: 25 | def __init__(self, recognizer: Union[sherpa_onnx.OnlineRecognizer | sherpa_onnx.OfflineRecognizer], sample_rate: int) -> None: 26 | self.recognizer = recognizer 27 | self.inbuf = asyncio.Queue() 28 | self.outbuf = asyncio.Queue() 29 | self.sample_rate = sample_rate 30 | self.is_closed = False 31 | self.online = isinstance(recognizer, sherpa_onnx.OnlineRecognizer) 32 | 33 | async def start(self): 34 | if self.online: 35 | asyncio.create_task(self.run_online()) 36 | else: 37 | asyncio.create_task(self.run_offline()) 38 | 39 | async def run_online(self): 40 | stream = self.recognizer.create_stream() 41 | last_result = "" 42 | segment_id = 0 43 | logger.info('asr: start real-time recognizer') 44 | while not self.is_closed: 45 | samples = await self.inbuf.get() 46 | stream.accept_waveform(self.sample_rate, samples) 47 | while self.recognizer.is_ready(stream): 48 | self.recognizer.decode_stream(stream) 49 | 50 | is_endpoint = self.recognizer.is_endpoint(stream) 51 | result = self.recognizer.get_result(stream) 52 | 53 | if result and (last_result != result): 54 | last_result = result 55 | logger.info(f' > {segment_id}:{result}') 56 | self.outbuf.put_nowait( 57 | ASRResult(result, False, segment_id)) 58 | 59 | if is_endpoint: 60 | if result: 61 | logger.info(f'{segment_id}: {result}') 62 | self.outbuf.put_nowait( 63 | ASRResult(result, True, segment_id)) 64 | segment_id += 1 65 | self.recognizer.reset(stream) 66 | 67 | async def run_offline(self): 68 | vad = _asr_engines['vad'] 69 | segment_id = 0 70 | st = None 71 | while not self.is_closed: 72 | samples = await self.inbuf.get() 73 | vad.accept_waveform(samples) 74 | while not vad.empty(): 75 | if not st: 76 | st = time.time() 77 | stream = self.recognizer.create_stream() 78 | stream.accept_waveform(self.sample_rate, vad.front.samples) 79 | 80 | vad.pop() 81 | self.recognizer.decode_stream(stream) 82 | 83 | result = stream.result.text.strip() 84 | if result: 85 | duration = time.time() - st 86 | logger.info(f'{segment_id}:{result} ({duration:.2f}s)') 87 | self.outbuf.put_nowait(ASRResult(result, True, segment_id)) 88 | segment_id += 1 89 | st = None 90 | 91 | async def close(self): 92 | self.is_closed = True 93 | self.outbuf.put_nowait(None) 94 | 95 | async def write(self, pcm_bytes: bytes): 96 | pcm_data = np.frombuffer(pcm_bytes, dtype=np.int16) 97 | samples = pcm_data.astype(np.float32) / 32768.0 98 | self.inbuf.put_nowait(samples) 99 | 100 | async def read(self) -> ASRResult: 101 | return await self.outbuf.get() 102 | 103 | 104 | def create_zipformer(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer: 105 | d = os.path.join( 106 | args.models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20') 107 | if not os.path.exists(d): 108 | raise ValueError(f"asr: model not found {d}") 109 | 110 | encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx") 111 | decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx") 112 | joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx") 113 | tokens = os.path.join(d, "tokens.txt") 114 | 115 | recognizer = 
sherpa_onnx.OnlineRecognizer.from_transducer( 116 | tokens=tokens, 117 | encoder=encoder, 118 | decoder=decoder, 119 | joiner=joiner, 120 | provider=args.asr_provider, 121 | num_threads=args.threads, 122 | sample_rate=samplerate, 123 | feature_dim=80, 124 | enable_endpoint_detection=True, 125 | rule1_min_trailing_silence=2.4, 126 | rule2_min_trailing_silence=1.2, 127 | rule3_min_utterance_length=20, # it essentially disables this rule 128 | ) 129 | return recognizer 130 | 131 | 132 | def create_sensevoice(samplerate: int, args) -> sherpa_onnx.OfflineRecognizer: 133 | d = os.path.join(args.models_root, 134 | 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17') 135 | 136 | if not os.path.exists(d): 137 | raise ValueError(f"asr: model not found {d}") 138 | 139 | recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( 140 | model=os.path.join(d, 'model.onnx'), 141 | tokens=os.path.join(d, 'tokens.txt'), 142 | num_threads=args.threads, 143 | sample_rate=samplerate, 144 | use_itn=True, 145 | debug=0, 146 | language=args.asr_lang, 147 | ) 148 | return recognizer 149 | 150 | 151 | def create_paraformer_trilingual(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer: 152 | d = os.path.join( 153 | args.models_root, 'sherpa-onnx-paraformer-trilingual-zh-cantonese-en') 154 | if not os.path.exists(d): 155 | raise ValueError(f"asr: model not found {d}") 156 | 157 | recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( 158 | paraformer=os.path.join(d, 'model.onnx'), 159 | tokens=os.path.join(d, 'tokens.txt'), 160 | num_threads=args.threads, 161 | sample_rate=samplerate, 162 | debug=0, 163 | provider=args.asr_provider, 164 | ) 165 | return recognizer 166 | 167 | 168 | def create_paraformer_en(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer: 169 | d = os.path.join( 170 | args.models_root, 'sherpa-onnx-paraformer-en') 171 | if not os.path.exists(d): 172 | raise ValueError(f"asr: model not found {d}") 173 | 174 | recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( 175 | paraformer=os.path.join(d, 'model.onnx'), 176 | tokens=os.path.join(d, 'tokens.txt'), 177 | num_threads=args.threads, 178 | sample_rate=samplerate, 179 | use_itn=True, 180 | debug=0, 181 | provider=args.asr_provider, 182 | ) 183 | return recognizer 184 | 185 | def create_fireredasr(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer: 186 | d = os.path.join( 187 | args.models_root, 'sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16') 188 | if not os.path.exists(d): 189 | raise ValueError(f"asr: model not found {d}") 190 | 191 | encoder = os.path.join(d, "encoder.int8.onnx") 192 | decoder = os.path.join(d, "decoder.int8.onnx") 193 | tokens = os.path.join(d, "tokens.txt") 194 | 195 | recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr( 196 | encoder=encoder, 197 | decoder=decoder, 198 | tokens=tokens, 199 | debug=0, 200 | provider=args.asr_provider, 201 | ) 202 | return recognizer 203 | 204 | 205 | 206 | def load_asr_engine(samplerate: int, args) -> sherpa_onnx.OnlineRecognizer: 207 | cache_engine = _asr_engines.get(args.asr_model) 208 | if cache_engine: 209 | return cache_engine 210 | st = time.time() 211 | if args.asr_model == 'zipformer-bilingual': 212 | cache_engine = create_zipformer(samplerate, args) 213 | elif args.asr_model == 'sensevoice': 214 | cache_engine = create_sensevoice(samplerate, args) 215 | _asr_engines['vad'] = load_vad_engine(samplerate, args) 216 | elif args.asr_model == 'paraformer-trilingual': 217 | cache_engine = create_paraformer_trilingual(samplerate, args) 218 | 
_asr_engines['vad'] = load_vad_engine(samplerate, args) 219 | elif args.asr_model == 'paraformer-en': 220 | cache_engine = create_paraformer_en(samplerate, args) 221 | _asr_engines['vad'] = load_vad_engine(samplerate, args) 222 | elif args.asr_model == 'fireredasr': 223 | cache_engine = create_fireredasr(samplerate, args) 224 | _asr_engines['vad'] = load_vad_engine(samplerate, args) 225 | else: 226 | raise ValueError(f"asr: unknown model {args.asr_model}") 227 | _asr_engines[args.asr_model] = cache_engine 228 | logger.info(f"asr: engine loaded in {time.time() - st:.2f}s") 229 | return cache_engine 230 | 231 | 232 | def load_vad_engine(samplerate: int, args, min_silence_duration: float = 0.25, buffer_size_in_seconds: int = 100) -> sherpa_onnx.VoiceActivityDetector: 233 | config = sherpa_onnx.VadModelConfig() 234 | d = os.path.join(args.models_root, 'silero_vad') 235 | if not os.path.exists(d): 236 | raise ValueError(f"vad: model not found {d}") 237 | 238 | config.silero_vad.model = os.path.join(d, 'silero_vad.onnx') 239 | config.silero_vad.min_silence_duration = min_silence_duration 240 | config.sample_rate = samplerate 241 | config.provider = args.asr_provider 242 | config.num_threads = args.threads 243 | 244 | vad = sherpa_onnx.VoiceActivityDetector( 245 | config, 246 | buffer_size_in_seconds=buffer_size_in_seconds) 247 | return vad 248 | 249 | 250 | async def start_asr_stream(samplerate: int, args) -> ASRStream: 251 | """ 252 | Start a ASR stream 253 | """ 254 | stream = ASRStream(load_asr_engine(samplerate, args), samplerate) 255 | await stream.start() 256 | return stream 257 | -------------------------------------------------------------------------------- /voiceapi/tts.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import os 3 | import time 4 | import sherpa_onnx 5 | import logging 6 | import numpy as np 7 | import asyncio 8 | import time 9 | import soundfile 10 | from scipy.signal import resample 11 | import io 12 | import re 13 | 14 | logger = logging.getLogger(__file__) 15 | 16 | splitter = re.compile(r'[,,。.!?!?;;、\n]') 17 | _tts_engines = {} 18 | 19 | tts_configs = { 20 | 'vits-zh-hf-theresa': { 21 | 'model': 'theresa.onnx', 22 | 'lexicon': 'lexicon.txt', 23 | 'dict_dir': 'dict', 24 | 'tokens': 'tokens.txt', 25 | 'sample_rate': 22050, 26 | 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst', 'new_heteronym.fst'], 27 | }, 28 | 'vits-melo-tts-zh_en': { 29 | 'model': 'model.onnx', 30 | 'lexicon': 'lexicon.txt', 31 | 'dict_dir': 'dict', 32 | 'tokens': 'tokens.txt', 33 | 'sample_rate': 44100, 34 | 'rule_fsts': ['phone.fst', 'date.fst', 'number.fst', 'new_heteronym.fst'], 35 | }, 36 | 'kokoro-multi-lang-v1_0': { 37 | 'model': 'model.onnx', 38 | #'lexicon': ['lexicon-zh.txt','lexicon-us-en.txt','lexicon-gb-en.txt'], 39 | 'lexicon': 'lexicon-zh.txt', 40 | 'dict_dir': 'dict', 41 | 'tokens': 'tokens.txt', 42 | 'sample_rate': 24000, 43 | 'rule_fsts': ['date-zh.fst', 'number-zh.fst'], 44 | }, 45 | } 46 | 47 | 48 | def load_tts_model(name: str, model_root: str, provider: str, num_threads: int = 1, max_num_sentences: int = 20) -> sherpa_onnx.OfflineTtsConfig: 49 | cfg = tts_configs[name] 50 | fsts = [] 51 | model_dir = os.path.join(model_root, name) 52 | for f in cfg.get('rule_fsts', ''): 53 | fsts.append(os.path.join(model_dir, f)) 54 | tts_rule_fsts = ','.join(fsts) if fsts else '' 55 | 56 | if 'kokoro' in name: 57 | kokoro_model_config = sherpa_onnx.OfflineTtsKokoroModelConfig( 58 | model=os.path.join(model_dir, 
cfg['model']), 59 | voices=os.path.join(model_dir, 'voices.bin'), 60 | lexicon=os.path.join(model_dir, cfg['lexicon']), 61 | data_dir=os.path.join(model_dir, 'espeak-ng-data'), 62 | dict_dir=os.path.join(model_dir, cfg['dict_dir']), 63 | tokens=os.path.join(model_dir, cfg['tokens']), 64 | ) 65 | model_config = sherpa_onnx.OfflineTtsModelConfig( 66 | kokoro=kokoro_model_config, 67 | provider=provider, 68 | debug=0, 69 | num_threads=num_threads, 70 | ) 71 | elif 'vits' in name: 72 | vits_model_config = sherpa_onnx.OfflineTtsVitsModelConfig( 73 | model=os.path.join(model_dir, cfg['model']), 74 | lexicon=os.path.join(model_dir, cfg['lexicon']), 75 | dict_dir=os.path.join(model_dir, cfg['dict_dir']), 76 | tokens=os.path.join(model_dir, cfg['tokens']), 77 | ) 78 | model_config = sherpa_onnx.OfflineTtsModelConfig( 79 | vits=vits_model_config, 80 | provider=provider, 81 | debug=0, 82 | num_threads=num_threads, 83 | ) 84 | 85 | tts_config = sherpa_onnx.OfflineTtsConfig( 86 | model=model_config, 87 | rule_fsts=tts_rule_fsts, 88 | max_num_sentences=max_num_sentences) 89 | 90 | if not tts_config.validate(): 91 | raise ValueError("tts: invalid config") 92 | 93 | return tts_config 94 | 95 | 96 | def get_tts_engine(args) -> Tuple[sherpa_onnx.OfflineTts, int]: 97 | sample_rate = tts_configs[args.tts_model]['sample_rate'] 98 | cache_engine = _tts_engines.get(args.tts_model) 99 | if cache_engine: 100 | return cache_engine, sample_rate 101 | st = time.time() 102 | tts_config = load_tts_model( 103 | args.tts_model, args.models_root, args.tts_provider) 104 | 105 | cache_engine = sherpa_onnx.OfflineTts(tts_config) 106 | elapsed = time.time() - st 107 | logger.info(f"tts: loaded {args.tts_model} in {elapsed:.2f}s") 108 | _tts_engines[args.tts_model] = cache_engine 109 | 110 | return cache_engine, sample_rate 111 | 112 | 113 | class TTSResult: 114 | def __init__(self, pcm_bytes: bytes, finished: bool): 115 | self.pcm_bytes = pcm_bytes 116 | self.finished = finished 117 | self.progress: float = 0.0 118 | self.elapsed: float = 0.0 119 | self.audio_duration: float = 0.0 120 | self.audio_size: int = 0 121 | 122 | def to_dict(self): 123 | return { 124 | "progress": self.progress, 125 | "elapsed": f'{int(self.elapsed * 1000)}ms', 126 | "duration": f'{self.audio_duration:.2f}s', 127 | "size": self.audio_size 128 | } 129 | 130 | 131 | class TTSStream: 132 | def __init__(self, engine, sid: int, speed: float = 1.0, sample_rate: int = 16000, original_sample_rate: int = 16000): 133 | self.engine = engine 134 | self.sid = sid 135 | self.speed = speed 136 | self.outbuf: asyncio.Queue[TTSResult | None] = asyncio.Queue() 137 | self.is_closed = False 138 | self.target_sample_rate = sample_rate 139 | self.original_sample_rate = original_sample_rate 140 | 141 | def on_process(self, chunk: np.ndarray, progress: float): 142 | if self.is_closed: 143 | return 0 144 | 145 | # resample to target sample rate 146 | if self.target_sample_rate != self.original_sample_rate: 147 | num_samples = int( 148 | len(chunk) * self.target_sample_rate / self.original_sample_rate) 149 | resampled_chunk = resample(chunk, num_samples) 150 | chunk = resampled_chunk.astype(np.float32) 151 | 152 | scaled_chunk = chunk * 32768.0 153 | clipped_chunk = np.clip(scaled_chunk, -32768, 32767) 154 | int16_chunk = clipped_chunk.astype(np.int16) 155 | samples = int16_chunk.tobytes() 156 | self.outbuf.put_nowait(TTSResult(samples, False)) 157 | return self.is_closed and 0 or 1 158 | 159 | async def write(self, text: str, split: bool, pause: float = 0.2): 160 | start = 
time.time() 161 | if split: 162 | texts = re.split(splitter, text) 163 | else: 164 | texts = [text] 165 | 166 | audio_duration = 0.0 167 | audio_size = 0 168 | 169 | for idx, text in enumerate(texts): 170 | text = text.strip() 171 | if not text: 172 | continue 173 | sub_start = time.time() 174 | 175 | audio = await asyncio.to_thread(self.engine.generate, 176 | text, self.sid, self.speed, 177 | self.on_process) 178 | 179 | if not audio or not audio.sample_rate or not audio.samples: 180 | logger.error(f"tts: failed to generate audio for " 181 | f"'{text}' (audio={audio})") 182 | continue 183 | 184 | if split and idx < len(texts) - 1: # add a pause between sentences 185 | noise = np.zeros(int(audio.sample_rate * pause)) 186 | self.on_process(noise, 1.0) 187 | audio.samples = np.concatenate([audio.samples, noise]) 188 | 189 | audio_duration += len(audio.samples) / audio.sample_rate 190 | audio_size += len(audio.samples) 191 | elapsed_seconds = time.time() - sub_start 192 | logger.info(f"tts: generated audio for '{text}', " 193 | f"audio duration: {audio_duration:.2f}s, " 194 | f"elapsed: {elapsed_seconds:.2f}s") 195 | 196 | elapsed_seconds = time.time() - start 197 | logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, " 198 | f"audio duration: {audio_duration:.2f}s") 199 | 200 | r = TTSResult(None, True) 201 | r.elapsed = elapsed_seconds 202 | r.audio_duration = audio_duration 203 | r.progress = 1.0 204 | r.finished = True 205 | await self.outbuf.put(r) 206 | 207 | async def close(self): 208 | self.is_closed = True 209 | self.outbuf.put_nowait(None) 210 | logger.info("tts: stream closed") 211 | 212 | async def read(self) -> TTSResult: 213 | return await self.outbuf.get() 214 | 215 | async def generate(self, text: str) -> io.BytesIO: 216 | start = time.time() 217 | audio = await asyncio.to_thread(self.engine.generate, 218 | text, self.sid, self.speed) 219 | elapsed_seconds = time.time() - start 220 | audio_duration = len(audio.samples) / audio.sample_rate 221 | 222 | logger.info(f"tts: generated audio in {elapsed_seconds:.2f}s, " 223 | f"audio duration: {audio_duration:.2f}s, " 224 | f"sample rate: {audio.sample_rate}") 225 | 226 | if self.target_sample_rate != audio.sample_rate: 227 | audio.samples = resample(audio.samples, 228 | int(len(audio.samples) * self.target_sample_rate / audio.sample_rate)) 229 | audio.sample_rate = self.target_sample_rate 230 | 231 | output = io.BytesIO() 232 | soundfile.write(output, 233 | audio.samples, 234 | samplerate=audio.sample_rate, 235 | subtype="PCM_16", 236 | format="WAV") 237 | output.seek(0) 238 | return output 239 | 240 | 241 | async def start_tts_stream(sid: int, sample_rate: int, speed: float, args) -> TTSStream: 242 | engine, original_sample_rate = get_tts_engine(args) 243 | return TTSStream(engine, sid, speed, sample_rate, original_sample_rate) 244 | --------------------------------------------------------------------------------
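
The streaming `/asr` endpoint is documented above only with a browser-side JavaScript example. The sketch below is a minimal Python client for the same protocol (send raw 16-bit PCM frames, receive JSON `{text, finished, idx}` messages). It is not part of the repository: the file name `test.wav`, the chunking values, and the local URL are illustrative assumptions, and it relies only on `websockets`, `soundfile`, and `numpy`, which are already listed in `requirements.txt`.

```python
"""Sketch of a Python client for the streaming /asr WebSocket endpoint."""
import asyncio
import json

import numpy as np
import soundfile as sf
import websockets


async def transcribe(path: str, url: str = "ws://localhost:8000/asr?samplerate=16000"):
    # Read the file and convert it to 16-bit PCM mono, which the endpoint expects.
    samples, samplerate = sf.read(path, dtype="int16")
    if samples.ndim > 1:
        samples = samples.mean(axis=1).astype(np.int16)  # down-mix to mono
    assert samplerate == 16000, "resample the file to 16 kHz first"

    async with websockets.connect(url) as ws:

        async def send_pcm():
            chunk = 3200  # 100 ms of audio at 16 kHz
            for i in range(0, len(samples), chunk):
                await ws.send(samples[i:i + chunk].tobytes())
                await asyncio.sleep(0.1)  # pace the writes roughly in real time
            await asyncio.sleep(2.0)      # give the server time to flush the last segment
            await ws.close()

        async def recv_results():
            # Each message is JSON: {"text": ..., "finished": ..., "idx": ...}
            async for message in ws:
                result = json.loads(message)
                marker = "*" if result["finished"] else " "
                print(f"{marker} {result['idx']}: {result['text']}")

        await asyncio.gather(send_pcm(), recv_results())


if __name__ == "__main__":
    asyncio.run(transcribe("test.wav"))
```

Pacing the writes at about 100 ms per chunk mirrors how the browser demo streams microphone audio, which keeps the online recognizer's endpoint detection close to real-use behaviour.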
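
For the streaming `/tts` endpoint, a similar hedged sketch: it sends one piece of text, collects the binary 16-bit PCM chunks, and writes a WAV file once the final JSON summary (`elapsed`, `progress`, `duration`, `size`) arrives. The output file name and query parameters are assumptions for illustration, not part of the project.

```python
"""Sketch of a Python client for the streaming /tts WebSocket endpoint."""
import asyncio
import json

import numpy as np
import soundfile as sf
import websockets


async def synthesize(text: str, out_path: str = "tts_output.wav", samplerate: int = 16000):
    url = f"ws://localhost:8000/tts?samplerate={samplerate}&sid=0&speed=1.0"
    pcm = bytearray()
    async with websockets.connect(url) as ws:
        await ws.send(text)
        async for message in ws:
            if isinstance(message, bytes):
                pcm.extend(message)          # binary frames carry 16-bit PCM audio
            else:
                info = json.loads(message)   # final text frame: elapsed/progress/duration/size
                print("synthesis finished:", info)
                break
    samples = np.frombuffer(bytes(pcm), dtype=np.int16)
    sf.write(out_path, samples, samplerate, subtype="PCM_16")


if __name__ == "__main__":
    asyncio.run(synthesize("Hello from voiceapi!"))
```

Because the server interleaves binary audio frames with a single JSON text frame per request, the client only needs to branch on the frame type, exactly as the JavaScript demo does.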