├── .gitignore
├── LLM
│   ├── __pycache__
│   │   ├── chat.cpython-310.pyc
│   │   └── mlx_lm.cpython-310.pyc
│   ├── chat.py
│   └── mlx_lm.py
├── README.md
├── STT
│   ├── __pycache__
│   │   └── lightning_whisper_mlx_handler.cpython-310.pyc
│   └── lightning_whisper_mlx_handler.py
├── TTS
│   ├── __pycache__
│   │   └── melotts.cpython-310.pyc
│   └── melotts.py
├── baseHandler.py
├── listen_and_play.py
├── local_audio_streamer.py
├── logo.png
├── requirements.txt
├── s2s_pipeline.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | __pycache__
 2 | tmp
--------------------------------------------------------------------------------
/LLM/__pycache__/chat.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/LLM/__pycache__/chat.cpython-310.pyc
--------------------------------------------------------------------------------
/LLM/__pycache__/mlx_lm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/LLM/__pycache__/mlx_lm.cpython-310.pyc
--------------------------------------------------------------------------------
/LLM/chat.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | 
 4 | class Chat:
 5 |     """
 6 |     Handles the chat history, capping its size to avoid OOM issues.
 7 |     """
 8 | 
 9 |     def __init__(self, size):
10 |         self.size = size
11 |         self.init_chat_message = None
12 |         # the buffer length must stay an even number of messages: each step appends a user prompt and an assistant answer
13 |         self.buffer = []
14 | 
15 |     def append(self, item):
16 |         self.buffer.append(item)
17 |         if len(self.buffer) == 2 * (self.size + 1):
18 |             self.buffer.pop(0)
19 |             self.buffer.pop(0)
20 | 
21 |     def init_chat(self, init_chat_message):
22 |         self.init_chat_message = init_chat_message
23 | 
24 |     def to_list(self):
25 |         if self.init_chat_message:
26 |             return [self.init_chat_message] + self.buffer
27 |         else:
28 |             return self.buffer
29 | 
--------------------------------------------------------------------------------
/LLM/mlx_lm.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from LLM.chat import Chat
 3 | from baseHandler import BaseHandler
 4 | from mlx_lm import load, stream_generate, generate
 5 | from rich.console import Console
 6 | import torch
 7 | logging.basicConfig(
 8 |     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 9 | )
10 | logger = logging.getLogger(__name__)
11 | 
12 | console = Console()
13 | 
14 | class MLXLanguageModelHandler(BaseHandler):
15 |     """
16 |     Handles the language model part.
17 |     """
18 | 
19 |     def setup(
20 |         self,
21 |         model_name,
22 |         device="mps",
23 |         torch_dtype="float16",
24 |         gen_kwargs={},
25 |         user_role="user",
26 |         chat_size=1,
27 |         init_chat_role=None,
28 |         #init_chat_prompt="You are a helpful AI assistant.",
29 |         init_chat_prompt="あなたは日本語に堪能な友人です。英語を使ってはいけません。全て日本語で回答します. 
You must answer in Japanese", 30 | ): 31 | #model_name="mlx-community/Mistral-7B-Instruct-v0.3-4bit" 32 | #model_name="mlx-community/mlx-community/Phi-3-mini-4k-instruct-4bit" 33 | #model_name="mlx-community/Llama-3-Swallow-8B-Instruct-v0.1-4bit" 34 | #model_name="mlx-community/Llama-3-Swallow-8B-Instruct-v0.1-8bit" 35 | model_name="mlx-community/Llama-3-Swallow-8B-Instruct-v0.1-8bit" 36 | print(model_name) 37 | self.model_name = model_name 38 | model_id = model_name#'microsoft/Phi-3-mini-4k-instruct' 39 | self.model, self.tokenizer = load(model_id) 40 | self.gen_kwargs = gen_kwargs 41 | 42 | self.chat = Chat(chat_size) 43 | #init_chat_role=None 44 | init_chat_prompt="あなたは日本語に堪能な友人です。英語を使ってはいけません。全て日本語で回答します. You must answer in Japanese" 45 | if init_chat_role: 46 | if not init_chat_prompt: 47 | raise ValueError( 48 | "An initial promt needs to be specified when setting init_chat_role." 49 | ) 50 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) 51 | self.user_role = user_role 52 | 53 | self.warmup() 54 | 55 | def warmup(self): 56 | logger.info(f"Warming up {self.__class__.__name__}") 57 | return 58 | 59 | dummy_input_text = "Write me a poem about Machine Learning." 60 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] 61 | 62 | n_steps = 2 63 | 64 | for _ in range(n_steps): 65 | prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False) 66 | generate(self.model, self.tokenizer, prompt=prompt, max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False) 67 | 68 | 69 | def process(self, prompt): 70 | logger.debug("infering language model...") 71 | 72 | 73 | self.chat.append({"role": self.user_role, "content": f"{prompt}, "}) 74 | prompt = self.tokenizer.apply_chat_template(self.chat.to_list(), tokenize=False, add_generation_prompt=True) 75 | output = "" 76 | curr_output = "" 77 | print(self.chat.to_list()) 78 | for t in stream_generate(self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]): 79 | output += t 80 | curr_output += t 81 | if curr_output.endswith(('.', '?','?',',','。','!','<|end|>','<|eot_id|>')): 82 | yield curr_output.replace('<|end|>', '') 83 | curr_output = "" 84 | print(f"cur:{curr_output}") 85 | generated_text = output.replace('<|end|>', '') 86 | print(f"generated_text:{generated_text}") 87 | torch.mps.empty_cache() 88 | 89 | self.chat.append({"role": "assistant", "content": generated_text}) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
 
3 | 4 |
 5 | 
 6 | # Speech To Speech: an effort for an open-sourced and modular GPT4-o
 7 | 
 8 | fork from: [https://github.com/eustlb/speech-to-speech](https://github.com/eustlb/speech-to-speech)
 9 | 
10 | # Japanese support
11 | 
12 | Tested and confirmed working with Python 3.10.
13 | 
14 | ```bash
15 | git clone https://github.com/shi3z/speech-to-speech-japanese.git
16 | cd speech-to-speech-japanese
17 | pip install git+https://github.com/nltk/nltk.git@3.8.2
18 | git clone https://github.com/reazon-research/ReazonSpeech
19 | pip install Cython
20 | pip install ReazonSpeech/pkg/nemo-asr
21 | git clone https://github.com/myshell-ai/MeloTTS
22 | cd MeloTTS
23 | pip install -e .
24 | python -m unidic download
25 | cd ..
26 | pip install -r requirements.txt
27 | pip install transformers==4.44.1
28 | pip install mlx-lm
29 | pip install protobuf --upgrade
30 | python s2s_pipeline.py --mode local --device mps
31 | ```
32 | Confirmed working on a MacBook Pro M2 Max (32 GB).
33 | Also confirmed working on a MacBook M1 (16 GB) (thanks to @necobit (https://github.com/necobit)).
34 | 
35 | ## 📖 Quick Index
36 | * [Approach](#approach)
37 |   - [Structure](#structure)
38 |   - [Modularity](#modularity)
39 | * [Setup](#setup)
40 | * [Usage](#usage)
41 |   - [Server/Client approach](#serverclient-approach)
42 |   - [Local approach](#local-approach)
43 | * [Command-line usage](#command-line-usage)
44 |   - [Model parameters](#model-parameters)
45 |   - [Generation parameters](#generation-parameters)
46 |   - [Notable parameters](#notable-parameters)
47 | 
48 | ## Approach
49 | 
50 | ### Structure
51 | This repository implements a speech-to-speech cascaded pipeline with consecutive parts:
52 | 1. **Voice Activity Detection (VAD)**: [silero VAD v5](https://github.com/snakers4/silero-vad)
53 | 2. **Speech to Text (STT)**: Whisper checkpoints (including [distilled versions](https://huggingface.co/distil-whisper)); this fork transcribes Japanese with [ReazonSpeech](https://github.com/reazon-research/ReazonSpeech)
54 | 3. **Language Model (LM)**: Any instruct model available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)! 🤗
55 | 4. **Text to Speech (TTS)**: [Parler-TTS](https://github.com/huggingface/parler-tts)🤗; this fork speaks Japanese with [MeloTTS](https://github.com/myshell-ai/MeloTTS)
56 | 
57 | ### Modularity
58 | The pipeline aims to provide a fully open and modular approach, leveraging models available on the Transformers library via the Hugging Face hub. The level of modularity intended for each part is as follows:
59 | - **VAD**: Uses the implementation from [Silero's repo](https://github.com/snakers4/silero-vad).
60 | - **STT**: Uses Whisper models exclusively; however, any Whisper checkpoint can be used, enabling options like [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v3) and [French Distil-Whisper](https://huggingface.co/eustlb/distil-large-v3-fr).
61 | - **LM**: This part is fully modular and can be changed by simply modifying the Hugging Face hub model ID. Users need to select an instruct model since the usage here involves interacting with it.
62 | - **TTS**: The mini architecture of Parler-TTS is standard, but different checkpoints, including fine-tuned multilingual checkpoints, can be used.
63 | 
64 | The code is designed to facilitate easy modification. Each component is implemented as a class and can be re-implemented to match specific needs. 
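For instance, a new pipeline part only has to subclass `BaseHandler` (see baseHandler.py) and implement `process`. A minimal sketch (the `EchoHandler` name and its `prefix` argument are illustrative, not part of the repo):

```python
from baseHandler import BaseHandler

class EchoHandler(BaseHandler):
    def setup(self, prefix="echo: "):   # called once with setup_kwargs
        self.prefix = prefix

    def process(self, item):            # consumes items from queue_in
        yield self.prefix + item        # yielded values are put on queue_out
```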
 65 | 
 66 | ## Setup
 67 | 
 68 | Clone the repository:
 69 | ```bash
 70 | git clone https://github.com/eustlb/speech-to-speech.git
 71 | cd speech-to-speech
 72 | ```
 73 | 
 74 | Install the required dependencies:
 75 | ```bash
 76 | pip install -r requirements.txt
 77 | ```
 78 | 
 79 | ## Usage
 80 | 
 81 | The pipeline can be run in two ways:
 82 | - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
 83 | - **Local approach**: Uses the same client/server method but with the loopback address.
 84 | 
 85 | ### Server/Client Approach
 86 | 
 87 | To run the pipeline on the server:
 88 | ```bash
 89 | python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
 90 | ```
 91 | 
 92 | Then run the client locally to handle sending microphone input and receiving generated audio:
 93 | ```bash
 94 | python listen_and_play.py --host <IP address of your server>
 95 | ```
 96 | 
 97 | ### Local Approach
 98 | Simply use the loopback address:
 99 | ```bash
100 | python s2s_pipeline.py --recv_host localhost --send_host localhost
101 | python listen_and_play.py --host localhost
102 | ```
103 | 
104 | You can pass `--device mps` to run it locally on a Mac.
105 | 
106 | ### Recommended usage
107 | 
108 | Leverage Torch Compile for Whisper and Parler-TTS:
109 | 
110 | ```bash
111 | python s2s_pipeline.py \
112 | 	--recv_host 0.0.0.0 \
113 | 	--send_host 0.0.0.0 \
114 | 	--lm_model_name microsoft/Phi-3-mini-4k-instruct \
115 | 	--init_chat_role system \
116 | 	--stt_compile_mode reduce-overhead \
117 | 	--tts_compile_mode default
118 | ```
119 | 
120 | For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
121 | 
122 | ## Command-line Usage
123 | 
124 | ### Model Parameters
125 | 
126 | `model_name`, `torch_dtype`, and `device` are exposed for each part leveraging the Transformers' implementations: Speech to Text, Language Model, and Text to Speech. Specify the targeted pipeline part with the corresponding prefix:
127 | - `stt` (Speech to Text)
128 | - `lm` (Language Model)
129 | - `tts` (Text to Speech)
130 | 
131 | For example:
132 | ```bash
133 | --lm_model_name google/gemma-2b-it
134 | ```
135 | 
136 | ### Generation Parameters
137 | 
138 | Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed (see `LanguageModelHandlerArguments` for example).
139 | 
140 | ### Notable Parameters
141 | 
142 | #### VAD Parameters
143 | - `--thresh`: Threshold value to trigger voice activity detection.
144 | - `--min_speech_ms`: Minimum duration of detected voice activity to be considered speech.
145 | - `--min_silence_ms`: Minimum length of silence intervals for segmenting speech, balancing sentence cutting and latency reduction.
146 | 
147 | #### Language Model
148 | - `--init_chat_role`: Defaults to `None`. Sets the initial role in the chat template, if applicable. Refer to the model's card to set this value (e.g. for [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) you have to set `--init_chat_role system`)
149 | - `--init_chat_prompt`: Defaults to `"You are a helpful AI assistant."` Required when setting `--init_chat_role`.
150 | 
151 | #### Text to Speech
152 | - `--description`: Sets the description for Parler-TTS generated voice. 
Defaults to: `"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."` 153 | 154 | - `--play_steps_s`: Specifies the duration of the first chunk sent during streaming output from Parler-TTS, impacting readiness and decoding steps. 155 | 156 | ## Citations 157 | 158 | ### Silero VAD 159 | ```bibtex 160 | @misc{Silero VAD, 161 | author = {Silero Team}, 162 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, 163 | year = {2021}, 164 | publisher = {GitHub}, 165 | journal = {GitHub repository}, 166 | howpublished = {\url{https://github.com/snakers4/silero-vad}}, 167 | commit = {insert_some_commit_here}, 168 | email = {hello@silero.ai} 169 | } 170 | ``` 171 | 172 | ### Distil-Whisper 173 | ```bibtex 174 | @misc{gandhi2023distilwhisper, 175 | title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling}, 176 | author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush}, 177 | year={2023}, 178 | eprint={2311.00430}, 179 | archivePrefix={arXiv}, 180 | primaryClass={cs.CL} 181 | } 182 | ``` 183 | 184 | ### Parler-TTS 185 | ```bibtex 186 | @misc{lacombe-etal-2024-parler-tts, 187 | author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi}, 188 | title = {Parler-TTS}, 189 | year = {2024}, 190 | publisher = {GitHub}, 191 | journal = {GitHub repository}, 192 | howpublished = {\url{https://github.com/huggingface/parler-tts}} 193 | } 194 | ``` 195 | -------------------------------------------------------------------------------- /STT/__pycache__/lightning_whisper_mlx_handler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/STT/__pycache__/lightning_whisper_mlx_handler.cpython-310.pyc -------------------------------------------------------------------------------- /STT/lightning_whisper_mlx_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import perf_counter 3 | from baseHandler import BaseHandler 4 | from lightning_whisper_mlx import LightningWhisperMLX 5 | import numpy as np 6 | from rich.console import Console 7 | import torch 8 | from reazonspeech.nemo.asr import load_model, transcribe, audio_from_path,audio_from_numpy 9 | 10 | 11 | logging.basicConfig( 12 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 13 | ) 14 | logger = logging.getLogger(__name__) 15 | 16 | console = Console() 17 | 18 | 19 | 20 | class LightningWhisperSTTHandler(BaseHandler): 21 | """ 22 | Handles the Speech To Text generation using a Whisper model. 
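    Note: in this Japanese fork, `setup` actually loads ReazonSpeech's NeMo ASR
    model and `process` transcribes with it; the original Whisper calls are kept
    only as commented-out reference code.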
23 |     """
24 | 
25 |     def setup(
26 |         self,
27 |         model_name="distil-whisper/distil-large-v3",
28 |         device="cuda",
29 |         torch_dtype="float16",
30 |         compile_mode=None,
31 |         gen_kwargs={},
32 |     ):
33 |         self.device = device
34 |         # The LightningWhisperMLX model is kept for reference; ReazonSpeech is used instead.
35 |         #self.model = LightningWhisperMLX(
36 |         #    model="distil-large-v3", batch_size=6, quant=None
37 |         #)
38 |         self.model = load_model()
39 | 
40 |         self.warmup()
41 | 
42 |     def warmup(self):
43 |         logger.info(f"Warming up {self.__class__.__name__}")
44 |         # Warmup is skipped for the ReazonSpeech model; the steps below targeted Whisper.
45 |         return
46 | 
47 |         # 2 warmup steps for no compile or compile mode with CUDA graphs capture
48 |         n_steps = 1
49 |         dummy_input = np.array([0] * 512)
50 | 
51 |         for _ in range(n_steps):
52 |             _ = self.model.transcribe(dummy_input)["text"].strip()
53 | 
54 |     def process(self, spoken_prompt):
55 |         logger.debug("running ASR inference...")
56 | 
57 |         global pipeline_start
58 |         pipeline_start = perf_counter()
59 |         RATE = 16000
60 |         audio = audio_from_numpy(spoken_prompt, RATE)
61 | 
62 |         #pred_text = self.model.transcribe(spoken_prompt, language="ja")["text"].strip()
63 |         result = transcribe(self.model, audio)
64 |         pred_text = result.text
65 |         torch.mps.empty_cache()
66 | 
67 |         logger.debug("finished ASR inference")
68 |         console.print(f"[yellow]USER: {pred_text}")
69 | 
70 |         yield pred_text
71 | 
--------------------------------------------------------------------------------
/TTS/__pycache__/melotts.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/TTS/__pycache__/melotts.cpython-310.pyc
--------------------------------------------------------------------------------
/TTS/melotts.py:
--------------------------------------------------------------------------------
 1 | from MeloTTS.melo.api import TTS
 2 | import logging
 3 | from baseHandler import BaseHandler
 4 | import librosa
 5 | import numpy as np
 6 | from rich.console import Console
 7 | import torch
 8 | 
 9 | logging.basicConfig(
10 |     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
11 | )
12 | logger = logging.getLogger(__name__)
13 | 
14 | console = Console()
15 | 
16 | 
17 | class MeloTTSHandler(BaseHandler):
18 |     def setup(
19 |         self,
20 |         should_listen,
21 |         device="mps",
22 |         #language="EN_NEWEST",
23 |         language="JP",
24 |         blocksize=512,
25 |     ):
26 |         self.should_listen = should_listen
27 |         self.device = device
28 |         self.model = TTS(language=language, device=device)
29 |         self.speaker_id = self.model.hps.data.spk2id[language]
30 |         self.blocksize = blocksize
31 |         self.warmup()
32 | 
33 |     def warmup(self):
34 |         logger.info(f"Warming up {self.__class__.__name__}")
35 |         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
36 | 
37 |     def process(self, llm_sentence):
38 |         # Spell out common English terms and single letters in katakana so MeloTTS reads them naturally.
39 |         # Order matters: multi-letter terms must be replaced before single letters.
40 |         katakana_readings = [
41 |             ("AI", "エーアイ"), ("Python", "パイソン"),
42 |             ("A", "エー"), ("B", "ビー"), ("C", "シー"), ("D", "ディー"),
43 |             ("E", "イー"), ("F", "エフ"), ("G", "ジー"), ("H", "エイチ"),
44 |             ("I", "アイ"), ("J", "ジェイ"), ("K", "ケイ"), ("L", "エル"),
45 |             ("M", "エム"), ("N", "エヌ"), ("O", "オー"), ("P", "ピー"),
46 |             ("Q", "キュー"), ("R", "アール"), ("S", "エス"), ("T", "ティー"),
47 |             ("U", "ユー"), ("V", "ブイ"), ("W", "ダブリュー"), ("X", "エックス"),
48 |             ("Y", "ワイ"), ("Z", "ゼット"),
49 |         ]
50 |         for letters, reading in katakana_readings:
51 |             llm_sentence = llm_sentence.replace(letters, reading)
52 |         console.print(f"[green]ASSISTANT: {llm_sentence}")
53 |         if self.device == "mps":
54 |             import time
55 |             start = time.time()
56 |             torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
57 |             torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
58 |             time_it_took = time.time() - start  # Removing this line makes it fail more often. I'm looking into it.
59 | 
60 |         audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
61 |         if len(audio_chunk) == 0:
62 |             self.should_listen.set()
63 |             return
64 |         audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
65 |         audio_chunk = (audio_chunk * 32768).astype(np.int16)
66 |         for i in range(0, len(audio_chunk), self.blocksize):
67 |             yield np.pad(
68 |                 audio_chunk[i : i + self.blocksize],
69 |                 (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
70 |             )
71 | 
72 |         self.should_listen.set()
73 | 
--------------------------------------------------------------------------------
/baseHandler.py:
--------------------------------------------------------------------------------
 1 | from time import perf_counter
 2 | import logging
 3 | 
 4 | logger = logging.getLogger(__name__)
 5 | 
 6 | 
 7 | class BaseHandler:
 8 |     """
 9 |     Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
10 |     The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
11 |     To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
12 |     Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
13 |     The cleanup method handles stopping the handler, and b"END" is placed in the output queue. 
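    A minimal usage sketch (UppercaseHandler is hypothetical, shown only to
    illustrate the queue protocol):

        from queue import Queue
        from threading import Event

        class UppercaseHandler(BaseHandler):
            def process(self, item):
                yield item.upper()

        stop, q_in, q_out = Event(), Queue(), Queue()
        handler = UppercaseHandler(stop, q_in, q_out)  # __init__ also runs setup()
        q_in.put("hello")
        q_in.put(b"END")   # sentinel: makes run() exit its loop
        handler.run()      # puts "HELLO", then b"END", on q_out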
14 |     """
15 | 
16 |     def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
17 |         self.stop_event = stop_event
18 |         self.queue_in = queue_in
19 |         self.queue_out = queue_out
20 |         self.setup(*setup_args, **setup_kwargs)
21 |         self._times = []
22 | 
23 |     def setup(self):
24 |         pass
25 | 
26 |     def process(self):
27 |         raise NotImplementedError
28 | 
29 |     def run(self):
30 |         while not self.stop_event.is_set():
31 |             input = self.queue_in.get()
32 |             if isinstance(input, bytes) and input == b"END":
33 |                 # sentinel signal to avoid queue deadlock
34 |                 logger.debug("Stopping thread")
35 |                 break
36 |             start_time = perf_counter()
37 |             for output in self.process(input):
38 |                 self._times.append(perf_counter() - start_time)
39 |                 logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
40 |                 self.queue_out.put(output)
41 |                 start_time = perf_counter()
42 | 
43 |         self.cleanup()
44 |         self.queue_out.put(b"END")
45 | 
46 |     @property
47 |     def last_time(self):
48 |         return self._times[-1]
49 | 
50 |     def cleanup(self):
51 |         pass
52 | 
--------------------------------------------------------------------------------
/listen_and_play.py:
--------------------------------------------------------------------------------
 1 | import socket
 2 | import numpy as np
 3 | import threading
 4 | from queue import Queue
 5 | from dataclasses import dataclass, field
 6 | import sounddevice as sd
 7 | from transformers import HfArgumentParser
 8 | 
 9 | 
10 | @dataclass
11 | class ListenAndPlayArguments:
12 |     send_rate: int = field(
13 |         default=16000,
14 |         metadata={
15 |             "help": "In Hz. Default is 16000."
16 |         }
17 |     )
18 |     recv_rate: int = field(
19 |         default=44100,
20 |         metadata={
21 |             "help": "In Hz. Default is 44100."
22 |         }
23 |     )
24 |     list_play_chunk_size: int = field(
25 |         default=1024,
26 |         metadata={
27 |             "help": "The size of data chunks (in bytes). Default is 1024."
28 |         }
29 |     )
30 |     host: str = field(
31 |         default="localhost",
32 |         metadata={
33 |             "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
34 |         }
35 |     )
36 |     send_port: int = field(
37 |         default=12345,
38 |         metadata={
39 |             "help": "The network port for sending data. Default is 12345."
40 |         }
41 |     )
42 |     recv_port: int = field(
43 |         default=12346,
44 |         metadata={
45 |             "help": "The network port for receiving data. Default is 12346. 
46 | } 47 | ) 48 | 49 | 50 | def listen_and_play( 51 | send_rate=16000, 52 | recv_rate=44100, 53 | list_play_chunk_size=1024, 54 | host="localhost", 55 | send_port=12345, 56 | recv_port=12346, 57 | ): 58 | 59 | send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 60 | send_socket.connect((host, send_port)) 61 | 62 | recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 63 | recv_socket.connect((host, recv_port)) 64 | 65 | print("Recording and streaming...") 66 | 67 | stop_event = threading.Event() 68 | recv_queue = Queue() 69 | send_queue = Queue() 70 | 71 | def callback_recv(outdata, frames, time, status): 72 | if not recv_queue.empty(): 73 | data = recv_queue.get() 74 | outdata[:len(data)] = data 75 | outdata[len(data):] = b'\x00' * (len(outdata) - len(data)) 76 | else: 77 | outdata[:] = b'\x00' * len(outdata) 78 | 79 | def callback_send(indata, frames, time, status): 80 | if recv_queue.empty(): 81 | data = bytes(indata) 82 | send_queue.put(data) 83 | 84 | def send(stop_event, send_queue): 85 | while not stop_event.is_set(): 86 | data = send_queue.get() 87 | send_socket.sendall(data) 88 | 89 | def recv(stop_event, recv_queue): 90 | 91 | def receive_full_chunk(conn, chunk_size): 92 | data = b'' 93 | while len(data) < chunk_size: 94 | packet = conn.recv(chunk_size - len(data)) 95 | if not packet: 96 | return None # Connection has been closed 97 | data += packet 98 | return data 99 | 100 | while not stop_event.is_set(): 101 | data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) 102 | if data: 103 | recv_queue.put(data) 104 | 105 | try: 106 | send_stream = sd.RawInputStream(samplerate=send_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_send) 107 | recv_stream = sd.RawOutputStream(samplerate=recv_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_recv) 108 | threading.Thread(target=send_stream.start).start() 109 | threading.Thread(target=recv_stream.start).start() 110 | 111 | send_thread = threading.Thread(target=send, args=(stop_event, send_queue)) 112 | send_thread.start() 113 | recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue)) 114 | recv_thread.start() 115 | 116 | input("Press Enter to stop...") 117 | 118 | except KeyboardInterrupt: 119 | print("Finished streaming.") 120 | 121 | finally: 122 | stop_event.set() 123 | recv_thread.join() 124 | send_thread.join() 125 | send_socket.close() 126 | recv_socket.close() 127 | print("Connection closed.") 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = HfArgumentParser((ListenAndPlayArguments,)) 132 | listen_and_play_kwargs, = parser.parse_args_into_dataclasses() 133 | listen_and_play(**vars(listen_and_play_kwargs)) 134 | 135 | -------------------------------------------------------------------------------- /local_audio_streamer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import sounddevice as sd 3 | import numpy as np 4 | 5 | import time 6 | 7 | 8 | class LocalAudioStreamer: 9 | def __init__( 10 | self, 11 | input_queue, 12 | output_queue, 13 | list_play_chunk_size=512, 14 | ): 15 | self.list_play_chunk_size = list_play_chunk_size 16 | 17 | self.stop_event = threading.Event() 18 | self.input_queue = input_queue 19 | self.output_queue = output_queue 20 | 21 | def run(self): 22 | def callback(indata, outdata, frames, time, status): 23 | if self.output_queue.empty(): 24 | self.input_queue.put(indata.copy()) 25 | outdata[:] = 0 * outdata 26 | else: 27 | 
outdata[:] = self.output_queue.get()[:, np.newaxis] 28 | 29 | with sd.Stream( 30 | samplerate=16000, 31 | dtype="int16", 32 | channels=1, 33 | callback=callback, 34 | blocksize=self.list_play_chunk_size, 35 | ): 36 | while not self.stop_event.is_set(): 37 | time.sleep(0.001) 38 | print("Stopping recording") 39 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/logo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.8.1 2 | parler_tts @ git+https://github.com/huggingface/parler-tts.git 3 | torch==2.4.0 4 | sounddevice==0.5.0 5 | -------------------------------------------------------------------------------- /s2s_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socket 4 | import sys 5 | import threading 6 | from copy import copy 7 | from dataclasses import dataclass, field 8 | from pathlib import Path 9 | from queue import Queue 10 | from threading import Event, Thread 11 | from time import perf_counter 12 | from typing import Optional 13 | 14 | from LLM.mlx_lm import MLXLanguageModelHandler 15 | from TTS.melotts import MeloTTSHandler 16 | from baseHandler import BaseHandler 17 | from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler 18 | import numpy as np 19 | import torch 20 | from nltk.tokenize import sent_tokenize 21 | from rich.console import Console 22 | from transformers import ( 23 | AutoModelForCausalLM, 24 | AutoModelForSpeechSeq2Seq, 25 | AutoProcessor, 26 | AutoTokenizer, 27 | HfArgumentParser, 28 | pipeline, 29 | TextIteratorStreamer, 30 | ) 31 | #from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer 32 | import librosa 33 | 34 | from local_audio_streamer import LocalAudioStreamer 35 | from utils import VADIterator, int2float, next_power_of_2 36 | 37 | # Ensure that the necessary NLTK resources are available 38 | # try: 39 | # nltk.data.find('tokenizers/punkt_tab') 40 | # except LookupError: 41 | # nltk.download('punkt_tab') 42 | 43 | # caching allows ~50% compilation time reduction 44 | # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma 45 | CURRENT_DIR = Path(__file__).resolve().parent 46 | os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") 47 | # torch._inductor.config.fx_graph_cache = True 48 | # # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS 49 | # torch._dynamo.config.cache_size_limit = 15 50 | 51 | 52 | console = Console() 53 | 54 | 55 | @dataclass 56 | class ModuleArguments: 57 | device: Optional[str] = field( 58 | default=None, 59 | metadata={"help": "If specified, overrides the device for all handlers."}, 60 | ) 61 | mode: Optional[str] = field( 62 | default="local", 63 | metadata={ 64 | "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'." 65 | }, 66 | ) 67 | log_level: str = field( 68 | default="info", 69 | metadata={ 70 | "help": "Provide logging level. Example --log_level debug, default=warning." 
 71 |         },
 72 |     )
 73 | 
 74 | 
 75 | class ThreadManager:
 76 |     """
 77 |     Manages multiple threads used to execute given handler tasks.
 78 |     """
 79 | 
 80 |     def __init__(self, handlers):
 81 |         self.handlers = handlers
 82 |         self.threads = []
 83 | 
 84 |     def start(self):
 85 |         for handler in self.handlers:
 86 |             thread = threading.Thread(target=handler.run)
 87 |             self.threads.append(thread)
 88 |             thread.start()
 89 | 
 90 |     def stop(self):
 91 |         for handler in self.handlers:
 92 |             handler.stop_event.set()
 93 |         for thread in self.threads:
 94 |             thread.join()
 95 | 
 96 | 
 97 | @dataclass
 98 | class SocketReceiverArguments:
 99 |     recv_host: str = field(
100 |         default="localhost",
101 |         metadata={
102 |             "help": "The host IP address for the socket connection. Default is 'localhost'; use '0.0.0.0' to bind to all "
103 |             "available interfaces on the host machine."
104 |         },
105 |     )
106 |     recv_port: int = field(
107 |         default=12345,
108 |         metadata={
109 |             "help": "The port number on which the socket server listens. Default is 12345."
110 |         },
111 |     )
112 |     chunk_size: int = field(
113 |         default=1024,
114 |         metadata={
115 |             "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
116 |         },
117 |     )
118 | 
119 | 
120 | class SocketReceiver:
121 |     """
122 |     Handles reception of the audio packets from the client.
123 |     """
124 | 
125 |     def __init__(
126 |         self,
127 |         stop_event,
128 |         queue_out,
129 |         should_listen,
130 |         host="0.0.0.0",
131 |         port=12345,
132 |         chunk_size=1024,
133 |     ):
134 |         self.stop_event = stop_event
135 |         self.queue_out = queue_out
136 |         self.should_listen = should_listen
137 |         self.chunk_size = chunk_size
138 |         self.host = host
139 |         self.port = port
140 | 
141 |     def receive_full_chunk(self, conn, chunk_size):
142 |         data = b""
143 |         while len(data) < chunk_size:
144 |             packet = conn.recv(chunk_size - len(data))
145 |             if not packet:
146 |                 # connection closed
147 |                 return None
148 |             data += packet
149 |         return data
150 | 
151 |     def run(self):
152 |         self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
153 |         self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
154 |         self.socket.bind((self.host, self.port))
155 |         self.socket.listen(1)
156 |         logger.info("Receiver waiting to be connected...")
157 |         self.conn, _ = self.socket.accept()
158 |         logger.info("receiver connected")
159 | 
160 |         self.should_listen.set()
161 |         while not self.stop_event.is_set():
162 |             audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
163 |             if audio_chunk is None:
164 |                 # connection closed
165 |                 self.queue_out.put(b"END")
166 |                 break
167 |             if self.should_listen.is_set():
168 |                 self.queue_out.put(audio_chunk)
169 |         self.conn.close()
170 |         logger.info("Receiver closed")
171 | 
172 | 
173 | @dataclass
174 | class SocketSenderArguments:
175 |     send_host: str = field(
176 |         default="localhost",
177 |         metadata={
178 |             "help": "The host IP address for the socket connection. Default is 'localhost'; use '0.0.0.0' to bind to all "
179 |             "available interfaces on the host machine."
180 |         },
181 |     )
182 |     send_port: int = field(
183 |         default=12346,
184 |         metadata={
185 |             "help": "The port number on which the socket server listens. Default is 12346."
186 |         },
187 |     )
188 | 
189 | 
190 | class SocketSender:
191 |     """
192 |     Handles sending generated audio packets to the clients. 
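    The stream is simply the raw bytes of each queued chunk (int16 PCM audio),
    terminated by the literal bytes b"END". A minimal client-side reader sketch
    (hypothetical, for illustration; listen_and_play.py is the real client):

        import socket

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(("localhost", 12346))  # default send_port
        while True:
            chunk = sock.recv(1024)
            if not chunk or chunk == b"END":  # naive check; b"END" may straddle reads
                break
            # ... play chunk back as int16 PCM samples ...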
193 | """ 194 | 195 | def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): 196 | self.stop_event = stop_event 197 | self.queue_in = queue_in 198 | self.host = host 199 | self.port = port 200 | 201 | def run(self): 202 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 203 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 204 | self.socket.bind((self.host, self.port)) 205 | self.socket.listen(1) 206 | logger.info("Sender waiting to be connected...") 207 | self.conn, _ = self.socket.accept() 208 | logger.info("sender connected") 209 | 210 | while not self.stop_event.is_set(): 211 | audio_chunk = self.queue_in.get() 212 | self.conn.sendall(audio_chunk) 213 | if isinstance(audio_chunk, bytes) and audio_chunk == b"END": 214 | break 215 | self.conn.close() 216 | logger.info("Sender closed") 217 | 218 | 219 | @dataclass 220 | class VADHandlerArguments: 221 | thresh: float = field( 222 | default=0.3, 223 | metadata={ 224 | "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection." 225 | }, 226 | ) 227 | sample_rate: int = field( 228 | default=16000, 229 | metadata={ 230 | "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio." 231 | }, 232 | ) 233 | min_silence_ms: int = field( 234 | default=250, 235 | metadata={ 236 | "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms." 237 | }, 238 | ) 239 | min_speech_ms: int = field( 240 | default=500, 241 | metadata={ 242 | "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms." 243 | }, 244 | ) 245 | max_speech_ms: float = field( 246 | default=float("inf"), 247 | metadata={ 248 | "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments." 249 | }, 250 | ) 251 | speech_pad_ms: int = field( 252 | default=250, 253 | metadata={ 254 | "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." 255 | }, 256 | ) 257 | 258 | 259 | class VADHandler(BaseHandler): 260 | """ 261 | Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed 262 | to the following part. 
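    Input chunks are raw int16 PCM bytes; a detected utterance is yielded as a
    single float32 numpy array (converted with utils.int2float) for the STT part.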
263 | """ 264 | 265 | def setup( 266 | self, 267 | should_listen, 268 | thresh=0.3, 269 | sample_rate=16000, 270 | min_silence_ms=1000, 271 | min_speech_ms=500, 272 | max_speech_ms=float("inf"), 273 | speech_pad_ms=30, 274 | ): 275 | self.should_listen = should_listen 276 | self.sample_rate = sample_rate 277 | self.min_silence_ms = min_silence_ms 278 | self.min_speech_ms = min_speech_ms 279 | self.max_speech_ms = max_speech_ms 280 | self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") 281 | self.iterator = VADIterator( 282 | self.model, 283 | threshold=thresh, 284 | sampling_rate=sample_rate, 285 | min_silence_duration_ms=min_silence_ms, 286 | speech_pad_ms=speech_pad_ms, 287 | ) 288 | 289 | def process(self, audio_chunk): 290 | audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16) 291 | audio_float32 = int2float(audio_int16) 292 | vad_output = self.iterator(torch.from_numpy(audio_float32)) 293 | if vad_output is not None and len(vad_output) != 0: 294 | logger.debug("VAD: end of speech detected") 295 | array = torch.cat(vad_output).cpu().numpy() 296 | duration_ms = len(array) / self.sample_rate * 1000 297 | if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: 298 | logger.debug( 299 | f"audio input of duration: {len(array) / self.sample_rate}s, skipping" 300 | ) 301 | else: 302 | self.should_listen.clear() 303 | logger.debug("Stop listening") 304 | yield array 305 | 306 | 307 | @dataclass 308 | class WhisperSTTHandlerArguments: 309 | stt_model_name: str = field( 310 | default="distil-whisper/distil-large-v3", 311 | metadata={ 312 | "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'." 313 | }, 314 | ) 315 | stt_device: str = field( 316 | default="cuda", 317 | metadata={ 318 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 319 | }, 320 | ) 321 | stt_torch_dtype: str = field( 322 | default="float16", 323 | metadata={ 324 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 325 | }, 326 | ) 327 | stt_compile_mode: str = field( 328 | default=None, 329 | metadata={ 330 | "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" 331 | }, 332 | ) 333 | stt_gen_max_new_tokens: int = field( 334 | default=128, 335 | metadata={ 336 | "help": "The maximum number of new tokens to generate. Default is 128." 337 | }, 338 | ) 339 | stt_gen_num_beams: int = field( 340 | default=1, 341 | metadata={ 342 | "help": "The number of beams for beam search. Default is 1, implying greedy decoding." 343 | }, 344 | ) 345 | stt_gen_return_timestamps: bool = field( 346 | default=False, 347 | metadata={ 348 | "help": "Whether to return timestamps with transcriptions. Default is False." 349 | }, 350 | ) 351 | # stt_gen_task: str = field( 352 | # default="transcribe", 353 | # metadata={ 354 | # "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." 355 | # }, 356 | # ) 357 | stt_gen_language: str = field( 358 | default="jp", 359 | metadata={ 360 | "help": "The language of the speech to transcribe. Default is 'en' for English." 361 | }, 362 | ) 363 | 364 | 365 | class WhisperSTTHandler(BaseHandler): 366 | """ 367 | Handles the Speech To Text generation using a Whisper model. 
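    Generation options come in through `gen_kwargs`, which `prepare_args` builds
    from the `--stt_gen_*` command-line flags; the input is the 16 kHz float32
    utterance array produced by the VAD part.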
368 | """ 369 | 370 | def setup( 371 | self, 372 | model_name="distil-whisper/distil-large-v3", 373 | device="cuda", 374 | torch_dtype="float16", 375 | compile_mode=None, 376 | gen_kwargs={}, 377 | ): 378 | self.device = device 379 | self.torch_dtype = getattr(torch, torch_dtype) 380 | self.compile_mode = compile_mode 381 | self.gen_kwargs = gen_kwargs 382 | 383 | self.processor = AutoProcessor.from_pretrained(model_name) 384 | self.model = AutoModelForSpeechSeq2Seq.from_pretrained( 385 | model_name, 386 | torch_dtype=self.torch_dtype, 387 | ).to(device) 388 | 389 | # compile 390 | if self.compile_mode: 391 | self.model.generation_config.cache_implementation = "static" 392 | self.model.forward = torch.compile( 393 | self.model.forward, mode=self.compile_mode, fullgraph=True 394 | ) 395 | self.warmup() 396 | 397 | def prepare_model_inputs(self, spoken_prompt): 398 | input_features = self.processor( 399 | spoken_prompt, sampling_rate=16000, return_tensors="pt" 400 | ).input_features 401 | input_features = input_features.to(self.device, dtype=self.torch_dtype) 402 | 403 | return input_features 404 | 405 | def warmup(self): 406 | logger.info(f"Warming up {self.__class__.__name__}") 407 | 408 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 409 | n_steps = 1 if self.compile_mode == "default" else 2 410 | dummy_input = torch.randn( 411 | (1, self.model.config.num_mel_bins, 3000), 412 | dtype=self.torch_dtype, 413 | device=self.device, 414 | ) 415 | if self.compile_mode not in (None, "default"): 416 | # generating more tokens than previously will trigger CUDA graphs capture 417 | # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation 418 | warmup_gen_kwargs = { 419 | "min_new_tokens": self.gen_kwargs["max_new_tokens"], 420 | "max_new_tokens": self.gen_kwargs["max_new_tokens"], 421 | **self.gen_kwargs, 422 | } 423 | else: 424 | warmup_gen_kwargs = self.gen_kwargs 425 | 426 | if self.device == "cuda": 427 | start_event = torch.cuda.Event(enable_timing=True) 428 | end_event = torch.cuda.Event(enable_timing=True) 429 | torch.cuda.synchronize() 430 | start_event.record() 431 | 432 | for _ in range(n_steps): 433 | _ = self.model.generate(dummy_input, **warmup_gen_kwargs) 434 | 435 | if self.device == "cuda": 436 | end_event.record() 437 | torch.cuda.synchronize() 438 | 439 | logger.info( 440 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 441 | ) 442 | 443 | def process(self, spoken_prompt): 444 | logger.debug("infering whisper...") 445 | 446 | global pipeline_start 447 | pipeline_start = perf_counter() 448 | 449 | input_features = self.prepare_model_inputs(spoken_prompt) 450 | pred_ids = self.model.generate(input_features, **self.gen_kwargs) 451 | pred_text = self.processor.batch_decode( 452 | pred_ids, skip_special_tokens=True, decode_with_timestamps=False 453 | )[0] 454 | 455 | logger.debug("finished whisper inference") 456 | console.print(f"[yellow]USER: {pred_text}") 457 | 458 | yield pred_text 459 | 460 | 461 | @dataclass 462 | class LanguageModelHandlerArguments: 463 | lm_model_name: str = field( 464 | default="HuggingFaceTB/SmolLM-360M-Instruct", 465 | metadata={ 466 | "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." 467 | }, 468 | ) 469 | lm_device: str = field( 470 | default="cuda", 471 | metadata={ 472 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 
473 |         },
474 |     )
475 |     lm_torch_dtype: str = field(
476 |         default="float16",
477 |         metadata={
478 |             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
479 |         },
480 |     )
481 |     user_role: str = field(
482 |         default="user",
483 |         metadata={
484 |             "help": "Role assigned to the user in the chat context. Default is 'user'."
485 |         },
486 |     )
487 |     init_chat_role: str = field(
488 |         default='system',
489 |         metadata={
490 |             "help": "Initial role for setting up the chat context. Default is 'system'."
491 |         },
492 |     )
493 |     init_chat_prompt: str = field(
494 |         default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
495 |         metadata={
496 |             "help": "The initial chat prompt to establish context for the language model. Default is a polite, concise-answers system prompt."
497 |         },
498 |     )
499 |     lm_gen_max_new_tokens: int = field(
500 |         default=128,
501 |         metadata={
502 |             "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
503 |         },
504 |     )
505 |     lm_gen_temperature: float = field(
506 |         default=0.0,
507 |         metadata={
508 |             "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
509 |         },
510 |     )
511 |     lm_gen_do_sample: bool = field(
512 |         default=False,
513 |         metadata={
514 |             "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
515 |         },
516 |     )
517 |     chat_size: int = field(
518 |         default=2,
519 |         metadata={
520 |             "help": "Number of assistant-user interaction pairs to keep for the chat. None for no limitation."
521 |         },
522 |     )
523 | 
524 | 
525 | class Chat:
526 |     """
527 |     Handles the chat history, capping its size to avoid OOM issues.
528 |     """
529 | 
530 |     def __init__(self, size):
531 |         self.size = size
532 |         self.init_chat_message = None
533 |         # the buffer length must stay an even number of messages: each step appends a user prompt and an assistant answer
534 |         self.buffer = []
535 | 
536 |     def append(self, item):
537 |         self.buffer.append(item)
538 |         if len(self.buffer) == 2 * (self.size + 1):
539 |             self.buffer.pop(0)
540 |             self.buffer.pop(0)
541 | 
542 |     def init_chat(self, init_chat_message):
543 |         self.init_chat_message = init_chat_message
544 | 
545 |     def to_list(self):
546 |         if self.init_chat_message:
547 |             return [self.init_chat_message] + self.buffer
548 |         else:
549 |             return self.buffer
550 | 
551 | 
552 | class LanguageModelHandler(BaseHandler):
553 |     """
554 |     Handles the language model part. 
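    Generation runs in a background thread feeding a TextIteratorStreamer; on
    non-MPS devices, complete sentences (split with nltk's sent_tokenize) are
    yielded as soon as they are available, so the TTS part can start speaking
    before the full answer has been generated.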
555 |     """
556 | 
557 |     def setup(
558 |         self,
559 |         model_name="microsoft/Phi-3-mini-4k-instruct",
560 |         device="cuda",
561 |         torch_dtype="float16",
562 |         gen_kwargs={},
563 |         user_role="user",
564 |         chat_size=1,
565 |         init_chat_role=None,
566 |         init_chat_prompt="You are a helpful AI assistant.",
567 |     ):
568 |         self.device = device
569 |         self.torch_dtype = getattr(torch, torch_dtype)
570 | 
571 |         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
572 |         self.model = AutoModelForCausalLM.from_pretrained(
573 |             model_name, torch_dtype=torch_dtype, trust_remote_code=True
574 |         ).to(device)
575 |         self.pipe = pipeline(
576 |             "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
577 |         )
578 |         self.streamer = TextIteratorStreamer(
579 |             self.tokenizer,
580 |             skip_prompt=True,
581 |             skip_special_tokens=True,
582 |         )
583 |         self.gen_kwargs = {
584 |             "streamer": self.streamer,
585 |             "return_full_text": False,
586 |             **gen_kwargs,
587 |         }
588 | 
589 |         self.chat = Chat(chat_size)
590 |         if init_chat_role:
591 |             if not init_chat_prompt:
592 |                 raise ValueError(
593 |                     "An initial prompt needs to be specified when setting init_chat_role."
594 |                 )
595 |             self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
596 |         self.user_role = user_role
597 | 
598 |         self.warmup()
599 | 
600 |     def warmup(self):
601 |         logger.info(f"Warming up {self.__class__.__name__}")
602 | 
603 |         dummy_input_text = "Write me a poem about Machine Learning."
604 |         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
605 |         warmup_gen_kwargs = {
606 |             "min_new_tokens": self.gen_kwargs["max_new_tokens"],
607 |             "max_new_tokens": self.gen_kwargs["max_new_tokens"],
608 |             **self.gen_kwargs,
609 |         }
610 | 
611 |         n_steps = 2
612 | 
613 |         if self.device == "cuda":
614 |             start_event = torch.cuda.Event(enable_timing=True)
615 |             end_event = torch.cuda.Event(enable_timing=True)
616 |             torch.cuda.synchronize()
617 |             start_event.record()
618 | 
619 |         for _ in range(n_steps):
620 |             thread = Thread(
621 |                 target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
622 |             )
623 |             thread.start()
624 |             for _ in self.streamer:
625 |                 pass
626 | 
627 |         if self.device == "cuda":
628 |             end_event.record()
629 |             torch.cuda.synchronize()
630 | 
631 |             logger.info(
632 |                 f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
633 |             )
634 | 
635 |     def process(self, prompt):
636 |         logger.debug("inferring language model...")
637 | 
638 |         self.chat.append({"role": self.user_role, "content": prompt})
639 |         thread = Thread(
640 |             target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
641 |         )
642 |         thread.start()
643 |         if self.device == "mps":
644 |             generated_text = ""
645 |             for new_text in self.streamer:
646 |                 generated_text += new_text
647 |             printable_text = generated_text
648 |             torch.mps.empty_cache()
649 |         else:
650 |             generated_text, printable_text = "", ""
651 |             for new_text in self.streamer:
652 |                 generated_text += new_text
653 |                 printable_text += new_text
654 |                 sentences = sent_tokenize(printable_text)
655 |                 if len(sentences) > 1:
656 |                     yield (sentences[0])
657 |                     printable_text = new_text
658 | 
659 |         self.chat.append({"role": "assistant", "content": generated_text})
660 | 
661 |         # don't forget last sentence
662 |         yield printable_text
663 | 
664 | 
665 | @dataclass
666 | class ParlerTTSHandlerArguments:
667 |     tts_model_name: str = field(
668 |         default="ylacombe/parler-tts-mini-jenny-30H",
669 |         metadata={
670 |             "help": "The pretrained TTS model to use. 
Default is 'ylacombe/parler-tts-mini-jenny-30H'."
671 |         },
672 |     )
673 |     tts_device: str = field(
674 |         default="cuda",
675 |         metadata={
676 |             "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
677 |         },
678 |     )
679 |     tts_torch_dtype: str = field(
680 |         default="float16",
681 |         metadata={
682 |             "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
683 |         },
684 |     )
685 |     tts_compile_mode: str = field(
686 |         default=None,
687 |         metadata={
688 |             "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)"
689 |         },
690 |     )
691 |     tts_gen_min_new_tokens: int = field(
692 |         default=64,
693 |         metadata={
694 |             "help": "Minimum number of new tokens to generate in a single completion. Default is 64."
695 |         },
696 |     )
697 |     tts_gen_max_new_tokens: int = field(
698 |         default=512,
699 |         metadata={
700 |             "help": "Maximum number of new tokens to generate in a single completion. Default is 512."
701 |         },
702 |     )
703 |     description: str = field(
704 |         default=(
705 |             "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
706 |             "She speaks very fast."
707 |         ),
708 |         metadata={
709 |             "help": "Description of the speaker's voice and speaking style to guide the TTS model."
710 |         },
711 |     )
712 |     play_steps_s: float = field(
713 |         default=1.0,
714 |         metadata={
715 |             "help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 second."
716 |         },
717 |     )
718 |     max_prompt_pad_length: int = field(
719 |         default=8,
720 |         metadata={
721 |             "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 allowed."
722 |         },
723 |     )
724 | 
725 | 
726 | class ParlerTTSHandler(BaseHandler):
727 |     def setup(
728 |         self,
729 |         should_listen,
730 |         model_name="ylacombe/parler-tts-mini-jenny-30H",
731 |         device="cuda",
732 |         torch_dtype="float16",
733 |         compile_mode=None,
734 |         gen_kwargs={},
735 |         max_prompt_pad_length=8,
736 |         description=(
737 |             "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
738 |             "She speaks very fast."
739 |         ),
740 |         play_steps_s=1,
741 |         blocksize=512,
742 |     ):
743 |         self.should_listen = should_listen
744 |         self.device = device
745 |         self.torch_dtype = getattr(torch, torch_dtype)
746 |         self.gen_kwargs = gen_kwargs
747 |         self.compile_mode = compile_mode
748 |         self.max_prompt_pad_length = max_prompt_pad_length
749 |         self.description = description
750 | 
751 |         self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
752 |         self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
753 |         self.model = ParlerTTSForConditionalGeneration.from_pretrained(
754 |             model_name, torch_dtype=self.torch_dtype
755 |         ).to(device)
756 | 
757 |         framerate = self.model.audio_encoder.config.frame_rate
758 |         self.play_steps = int(framerate * play_steps_s)
759 |         self.blocksize = blocksize
760 | 
761 |         if self.compile_mode not in (None, "default"):
762 |             logger.warning(
763 |                 "Torch compilation modes that capture CUDA graphs are not yet compatible with the STT part. 
Reverting to 'default'" 764 | ) 765 | self.compile_mode = "default" 766 | 767 | if self.compile_mode: 768 | self.model.generation_config.cache_implementation = "static" 769 | self.model.forward = torch.compile( 770 | self.model.forward, mode=self.compile_mode, fullgraph=True 771 | ) 772 | 773 | self.warmup() 774 | 775 | def prepare_model_inputs( 776 | self, 777 | prompt, 778 | max_length_prompt=50, 779 | pad=False, 780 | ): 781 | pad_args_prompt = ( 782 | {"padding": "max_length", "max_length": max_length_prompt} if pad else {} 783 | ) 784 | 785 | tokenized_description = self.description_tokenizer( 786 | self.description, return_tensors="pt" 787 | ) 788 | input_ids = tokenized_description.input_ids.to(self.device) 789 | attention_mask = tokenized_description.attention_mask.to(self.device) 790 | 791 | tokenized_prompt = self.prompt_tokenizer( 792 | prompt, return_tensors="pt", **pad_args_prompt 793 | ) 794 | prompt_input_ids = tokenized_prompt.input_ids.to(self.device) 795 | prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) 796 | 797 | gen_kwargs = { 798 | "input_ids": input_ids, 799 | "attention_mask": attention_mask, 800 | "prompt_input_ids": prompt_input_ids, 801 | "prompt_attention_mask": prompt_attention_mask, 802 | **self.gen_kwargs, 803 | } 804 | 805 | return gen_kwargs 806 | 807 | def warmup(self): 808 | logger.info(f"Warming up {self.__class__.__name__}") 809 | 810 | if self.device == "cuda": 811 | start_event = torch.cuda.Event(enable_timing=True) 812 | end_event = torch.cuda.Event(enable_timing=True) 813 | 814 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 815 | n_steps = 1 if self.compile_mode == "default" else 2 816 | 817 | if self.device == "cuda": 818 | torch.cuda.synchronize() 819 | start_event.record() 820 | if self.compile_mode: 821 | pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] 822 | for pad_length in pad_lengths[::-1]: 823 | model_kwargs = self.prepare_model_inputs( 824 | "dummy prompt", max_length_prompt=pad_length, pad=True 825 | ) 826 | for _ in range(n_steps): 827 | _ = self.model.generate(**model_kwargs) 828 | logger.info(f"Warmed up length {pad_length} tokens!") 829 | else: 830 | model_kwargs = self.prepare_model_inputs("dummy prompt") 831 | for _ in range(n_steps): 832 | _ = self.model.generate(**model_kwargs) 833 | 834 | if self.device == "cuda": 835 | end_event.record() 836 | torch.cuda.synchronize() 837 | logger.info( 838 | f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 839 | ) 840 | 841 | def process(self, llm_sentence): 842 | console.print(f"[green]ASSISTANT: {llm_sentence}") 843 | nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids) 844 | 845 | pad_args = {} 846 | if self.compile_mode: 847 | # pad to closest upper power of two 848 | pad_length = next_power_of_2(nb_tokens) 849 | logger.debug(f"padding to {pad_length}") 850 | pad_args["pad"] = True 851 | pad_args["max_length_prompt"] = pad_length 852 | 853 | tts_gen_kwargs = self.prepare_model_inputs( 854 | llm_sentence, 855 | **pad_args, 856 | ) 857 | 858 | streamer = ParlerTTSStreamer( 859 | self.model, device=self.device, play_steps=self.play_steps 860 | ) 861 | tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} 862 | torch.manual_seed(0) 863 | thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) 864 | thread.start() 865 | 866 | for i, audio_chunk in enumerate(streamer): 867 | if i == 0: 868 | logger.info( 869 | f"Time to first audio: {perf_counter() - pipeline_start:.3f}" 870 | ) 871 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) 872 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 873 | for i in range(0, len(audio_chunk), self.blocksize): 874 | yield np.pad( 875 | audio_chunk[i : i + self.blocksize], 876 | (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), 877 | ) 878 | 879 | self.should_listen.set() 880 | 881 | 882 | def prepare_args(args, prefix): 883 | """ 884 | Rename arguments by removing the prefix and prepares the gen_kwargs. 885 | """ 886 | 887 | gen_kwargs = {} 888 | for key in copy(args.__dict__): 889 | if key.startswith(prefix): 890 | value = args.__dict__.pop(key) 891 | new_key = key[len(prefix) + 1 :] # Remove prefix and underscore 892 | if new_key.startswith("gen_"): 893 | gen_kwargs[new_key[4:]] = value # Remove 'gen_' and add to dict 894 | else: 895 | args.__dict__[new_key] = value 896 | 897 | args.__dict__["gen_kwargs"] = gen_kwargs 898 | 899 | 900 | def main(): 901 | parser = HfArgumentParser( 902 | ( 903 | ModuleArguments, 904 | SocketReceiverArguments, 905 | SocketSenderArguments, 906 | VADHandlerArguments, 907 | WhisperSTTHandlerArguments, 908 | LanguageModelHandlerArguments, 909 | ParlerTTSHandlerArguments, 910 | ) 911 | ) 912 | 913 | # 0. Parse CLI arguments 914 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 915 | # Parse configurations from a JSON file if specified 916 | ( 917 | module_kwargs, 918 | socket_receiver_kwargs, 919 | socket_sender_kwargs, 920 | vad_handler_kwargs, 921 | whisper_stt_handler_kwargs, 922 | language_model_handler_kwargs, 923 | parler_tts_handler_kwargs, 924 | ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 925 | else: 926 | # Parse arguments from command line if no JSON file is provided 927 | ( 928 | module_kwargs, 929 | socket_receiver_kwargs, 930 | socket_sender_kwargs, 931 | vad_handler_kwargs, 932 | whisper_stt_handler_kwargs, 933 | language_model_handler_kwargs, 934 | parler_tts_handler_kwargs, 935 | ) = parser.parse_args_into_dataclasses() 936 | 937 | # 1. Handle logger 938 | global logger 939 | logging.basicConfig( 940 | level=module_kwargs.log_level.upper(), 941 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 942 | ) 943 | logger = logging.getLogger(__name__) 944 | 945 | # torch compile logs 946 | if module_kwargs.log_level == "debug": 947 | torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) 948 | 949 | # 2. 
900 | def main():
901 |     parser = HfArgumentParser(
902 |         (
903 |             ModuleArguments,
904 |             SocketReceiverArguments,
905 |             SocketSenderArguments,
906 |             VADHandlerArguments,
907 |             WhisperSTTHandlerArguments,
908 |             LanguageModelHandlerArguments,
909 |             ParlerTTSHandlerArguments,
910 |         )
911 |     )
912 | 
913 |     # 0. Parse CLI arguments
914 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
915 |         # Parse configurations from a JSON file if specified
916 |         (
917 |             module_kwargs,
918 |             socket_receiver_kwargs,
919 |             socket_sender_kwargs,
920 |             vad_handler_kwargs,
921 |             whisper_stt_handler_kwargs,
922 |             language_model_handler_kwargs,
923 |             parler_tts_handler_kwargs,
924 |         ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
925 |     else:
926 |         # Parse arguments from the command line if no JSON file is provided
927 |         (
928 |             module_kwargs,
929 |             socket_receiver_kwargs,
930 |             socket_sender_kwargs,
931 |             vad_handler_kwargs,
932 |             whisper_stt_handler_kwargs,
933 |             language_model_handler_kwargs,
934 |             parler_tts_handler_kwargs,
935 |         ) = parser.parse_args_into_dataclasses()
936 | 
937 |     # 1. Handle logger
938 |     global logger
939 |     logging.basicConfig(
940 |         level=module_kwargs.log_level.upper(),
941 |         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
942 |     )
943 |     logger = logging.getLogger(__name__)
944 | 
945 |     # torch compile logs
946 |     if module_kwargs.log_level == "debug":
947 |         torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
948 | 
949 |     # 2. Prepare each part's arguments
950 |     def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
951 |         if common_device:
952 |             for kwargs in handler_kwargs:
953 |                 if hasattr(kwargs, "lm_device"):
954 |                     kwargs.lm_device = common_device
955 |                 if hasattr(kwargs, "tts_device"):
956 |                     kwargs.tts_device = common_device
957 |                 if hasattr(kwargs, "stt_device"):
958 |                     kwargs.stt_device = common_device
959 | 
960 |     # Call this function with the common device and all the handlers
961 |     overwrite_device_argument(
962 |         module_kwargs.device,
963 |         language_model_handler_kwargs,
964 |         parler_tts_handler_kwargs,
965 |         whisper_stt_handler_kwargs,
966 |     )
967 | 
968 |     prepare_args(whisper_stt_handler_kwargs, "stt")
969 |     prepare_args(language_model_handler_kwargs, "lm")
970 |     prepare_args(parler_tts_handler_kwargs, "tts")
971 | 
972 |     # 3. Build the pipeline
973 |     stop_event = Event()
974 |     # used to stop putting received audio chunks in the queue until all sentences have been processed by the TTS
975 |     should_listen = Event()
976 |     recv_audio_chunks_queue = Queue()
977 |     send_audio_chunks_queue = Queue()
978 |     spoken_prompt_queue = Queue()
979 |     text_prompt_queue = Queue()
980 |     lm_response_queue = Queue()
981 | 
982 |     if module_kwargs.mode == "local":
983 |         local_audio_streamer = LocalAudioStreamer(
984 |             input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
985 |         )
986 |         comms_handlers = [local_audio_streamer]
987 |         should_listen.set()
988 |     else:
989 |         comms_handlers = [
990 |             SocketReceiver(
991 |                 stop_event,
992 |                 recv_audio_chunks_queue,
993 |                 should_listen,
994 |                 host=socket_receiver_kwargs.recv_host,
995 |                 port=socket_receiver_kwargs.recv_port,
996 |                 chunk_size=socket_receiver_kwargs.chunk_size,
997 |             ),
998 |             SocketSender(
999 |                 stop_event,
1000 |                 send_audio_chunks_queue,
1001 |                 host=socket_sender_kwargs.send_host,
1002 |                 port=socket_sender_kwargs.send_port,
1003 |             ),
1004 |         ]
1005 | 
1006 |     vad = VADHandler(
1007 |         stop_event,
1008 |         queue_in=recv_audio_chunks_queue,
1009 |         queue_out=spoken_prompt_queue,
1010 |         setup_args=(should_listen,),
1011 |         setup_kwargs=vars(vad_handler_kwargs),
1012 |     )
1013 |     stt = LightningWhisperSTTHandler(
1014 |         stop_event,
1015 |         queue_in=spoken_prompt_queue,
1016 |         queue_out=text_prompt_queue,
1017 |         setup_kwargs=vars(whisper_stt_handler_kwargs),
1018 |     )
1019 |     lm = MLXLanguageModelHandler(
1020 |         stop_event,
1021 |         queue_in=text_prompt_queue,
1022 |         queue_out=lm_response_queue,
1023 |         setup_kwargs=vars(language_model_handler_kwargs),
1024 |     )
1025 |     # tts = ParlerTTSHandler(
1026 |     #     stop_event,
1027 |     #     queue_in=lm_response_queue,
1028 |     #     queue_out=send_audio_chunks_queue,
1029 |     #     setup_args=(should_listen,),
1030 |     #     setup_kwargs=vars(parler_tts_handler_kwargs),
1031 |     # )
1032 |     tts = MeloTTSHandler(
1033 |         stop_event,
1034 |         queue_in=lm_response_queue,
1035 |         queue_out=send_audio_chunks_queue,
1036 |         setup_args=(should_listen,),
1037 |     )
1038 | 
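Every handler built above follows the same contract: block on queue_in, transform one item, put the result on queue_out, and stop when stop_event is set. The sketch below is not the repo's BaseHandler (see baseHandler.py for the real one); it is only a minimal illustration of why chaining queues this way lets VAD, STT, LM and TTS run concurrently, each in its own thread.

from queue import Queue
from threading import Event, Thread

def stage(stop_event, queue_in, queue_out, transform):
    # Consume-transform-forward loop; a None sentinel shuts the stage
    # down and is propagated so downstream stages stop too.
    while not stop_event.is_set():
        item = queue_in.get()
        if item is None:
            queue_out.put(None)
            break
        queue_out.put(transform(item))

stop_event = Event()
text_q, upper_q, out_q = Queue(), Queue(), Queue()
threads = [
    Thread(target=stage, args=(stop_event, text_q, upper_q, str.upper)),
    Thread(target=stage, args=(stop_event, upper_q, out_q, lambda s: s + "!")),
]
for t in threads:
    t.start()

text_q.put("hello")
print(out_q.get())  # -> "HELLO!"
text_q.put(None)    # shut the chain down
for t in threads:
    t.join()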
1039 |     # 4. Run the pipeline
1040 |     try:
1041 |         pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
1042 |         pipeline_manager.start()
1043 | 
1044 |     except KeyboardInterrupt:
1045 |         pipeline_manager.stop()
1046 | 
1047 | 
1048 | if __name__ == "__main__":
1049 |     main()
1050 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def next_power_of_2(x):
6 |     return 1 if x == 0 else 2**(x - 1).bit_length()
7 | 
8 | 
9 | def int2float(sound):
10 |     """
11 |     Taken from https://github.com/snakers4/silero-vad
12 |     """
13 | 
14 |     abs_max = np.abs(sound).max()
15 |     sound = sound.astype('float32')
16 |     if abs_max > 0:
17 |         sound *= 1/32768
18 |     sound = sound.squeeze()  # depends on the use case
19 |     return sound
20 | 
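The pipeline moves raw 16-bit PCM between stages while the models work on float32 in [-1, 1]: int2float above converts one way, and the TTS handler's (audio_chunk * 32768).astype(np.int16) step converts back. A small round-trip sketch, assuming int2float as defined above:

import numpy as np

pcm = np.array([0, 1024, -32768, 32767], dtype=np.int16)

as_float = int2float(pcm)                   # float32 in [-1, 1]
back = (as_float * 32768).astype(np.int16)  # the TTS handler's inverse step

print(as_float)                   # [ 0.  0.03125  -1.  0.9999695]
print(np.array_equal(pcm, back))  # True for these samples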
Cast it manually") 88 | 89 | window_size_samples = len(x[0]) if x.dim() == 2 else len(x) 90 | self.current_sample += window_size_samples 91 | 92 | speech_prob = self.model(x, self.sampling_rate).item() 93 | 94 | if (speech_prob >= self.threshold) and self.temp_end: 95 | self.temp_end = 0 96 | 97 | if (speech_prob >= self.threshold) and not self.triggered: 98 | self.triggered = True 99 | return None 100 | 101 | if (speech_prob < self.threshold - 0.15) and self.triggered: 102 | if not self.temp_end: 103 | self.temp_end = self.current_sample 104 | if self.current_sample - self.temp_end < self.min_silence_samples: 105 | return None 106 | else: 107 | # end of speak 108 | self.temp_end = 0 109 | self.triggered = False 110 | spoken_utterance = self.buffer 111 | self.buffer = [] 112 | return spoken_utterance 113 | 114 | if self.triggered: 115 | self.buffer.append(x) 116 | 117 | return None 118 | --------------------------------------------------------------------------------