├── .dockerignore ├── .gitignore ├── Dockerfile ├── Dockerfile.arm64 ├── LICENSE ├── LLM ├── chat.py ├── language_model.py ├── mlx_language_model.py └── openai_api_language_model.py ├── README.md ├── STT ├── lightning_whisper_mlx_handler.py ├── paraformer_handler.py └── whisper_stt_handler.py ├── TTS ├── chatTTS_handler.py ├── melo_handler.py └── parler_handler.py ├── VAD ├── vad_handler.py └── vad_iterator.py ├── arguments_classes ├── chat_tts_arguments.py ├── language_model_arguments.py ├── melo_tts_arguments.py ├── mlx_language_model_arguments.py ├── module_arguments.py ├── open_api_language_model_arguments.py ├── paraformer_stt_arguments.py ├── parler_tts_arguments.py ├── socket_receiver_arguments.py ├── socket_sender_arguments.py ├── vad_arguments.py └── whisper_stt_arguments.py ├── baseHandler.py ├── connections ├── local_audio_streamer.py ├── socket_receiver.py └── socket_sender.py ├── docker-compose.yml ├── listen_and_play.py ├── logo.png ├── requirements.txt ├── requirements_mac.txt ├── s2s_pipeline.py └── utils ├── thread_manager.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | tmp 2 | cache 3 | Dockerfile 4 | docker-compose.yml 5 | .dockerignore 6 | .gitignore 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | tmp 3 | cache -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel 2 | 3 | ENV PYTHONUNBUFFERED 1 4 | 5 | WORKDIR /usr/src/app 6 | 7 | # Install packages 8 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* 9 | 10 | COPY requirements.txt ./ 11 | RUN pip install --no-cache-dir -r requirements.txt 12 | 13 | COPY . . 14 | -------------------------------------------------------------------------------- /Dockerfile.arm64: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3 2 | 3 | ENV PYTHONUNBUFFERED 1 4 | 5 | WORKDIR /usr/src/app 6 | 7 | # Install packages 8 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* 9 | 10 | COPY requirements.txt ./ 11 | RUN pip install --no-cache-dir -r requirements.txt 12 | 13 | COPY . . -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2024] [The HuggingFace Inc. team] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /LLM/chat.py: -------------------------------------------------------------------------------- 1 | class Chat: 2 | """ 3 | Handles the chat using to avoid OOM issues. 4 | """ 5 | 6 | def __init__(self, size): 7 | self.size = size 8 | self.init_chat_message = None 9 | # maxlen is necessary pair, since a each new step we add an prompt and assitant answer 10 | self.buffer = [] 11 | 12 | def append(self, item): 13 | self.buffer.append(item) 14 | if len(self.buffer) == 2 * (self.size + 1): 15 | self.buffer.pop(0) 16 | self.buffer.pop(0) 17 | 18 | def init_chat(self, init_chat_message): 19 | self.init_chat_message = init_chat_message 20 | 21 | def to_list(self): 22 | if self.init_chat_message: 23 | return [self.init_chat_message] + self.buffer 24 | else: 25 | return self.buffer 26 | -------------------------------------------------------------------------------- /LLM/language_model.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | from transformers import ( 3 | AutoModelForCausalLM, 4 | AutoTokenizer, 5 | pipeline, 6 | TextIteratorStreamer, 7 | ) 8 | import torch 9 | 10 | from LLM.chat import Chat 11 | from baseHandler import BaseHandler 12 | from rich.console import Console 13 | import logging 14 | from nltk import sent_tokenize 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | console = Console() 19 | 20 | 21 | WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { 22 | "en": "english", 23 | "fr": "french", 24 | "es": "spanish", 25 | "zh": "chinese", 26 | "ja": "japanese", 27 | "ko": "korean", 28 | } 29 | 30 | class LanguageModelHandler(BaseHandler): 31 | """ 32 | Handles the language model part. 33 | """ 34 | 35 | def setup( 36 | self, 37 | model_name="microsoft/Phi-3-mini-4k-instruct", 38 | device="cuda", 39 | torch_dtype="float16", 40 | gen_kwargs={}, 41 | user_role="user", 42 | chat_size=1, 43 | init_chat_role=None, 44 | init_chat_prompt="You are a helpful AI assistant.", 45 | ): 46 | self.device = device 47 | self.torch_dtype = getattr(torch, torch_dtype) 48 | 49 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 50 | self.model = AutoModelForCausalLM.from_pretrained( 51 | model_name, torch_dtype=torch_dtype, trust_remote_code=True 52 | ).to(device) 53 | self.pipe = pipeline( 54 | "text-generation", model=self.model, tokenizer=self.tokenizer, device=device 55 | ) 56 | self.streamer = TextIteratorStreamer( 57 | self.tokenizer, 58 | skip_prompt=True, 59 | skip_special_tokens=True, 60 | ) 61 | self.gen_kwargs = { 62 | "streamer": self.streamer, 63 | "return_full_text": False, 64 | **gen_kwargs, 65 | } 66 | 67 | self.chat = Chat(chat_size) 68 | if init_chat_role: 69 | if not init_chat_prompt: 70 | raise ValueError( 71 | "An initial promt needs to be specified when setting init_chat_role." 
72 | ) 73 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) 74 | self.user_role = user_role 75 | 76 | self.warmup() 77 | 78 | def warmup(self): 79 | logger.info(f"Warming up {self.__class__.__name__}") 80 | 81 | dummy_input_text = "Repeat the word 'home'." 82 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] 83 | warmup_gen_kwargs = { 84 | "min_new_tokens": self.gen_kwargs["min_new_tokens"], 85 | "max_new_tokens": self.gen_kwargs["max_new_tokens"], 86 | **self.gen_kwargs, 87 | } 88 | 89 | n_steps = 2 90 | 91 | if self.device == "cuda": 92 | start_event = torch.cuda.Event(enable_timing=True) 93 | end_event = torch.cuda.Event(enable_timing=True) 94 | torch.cuda.synchronize() 95 | start_event.record() 96 | 97 | for _ in range(n_steps): 98 | thread = Thread( 99 | target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs 100 | ) 101 | thread.start() 102 | for _ in self.streamer: 103 | pass 104 | 105 | if self.device == "cuda": 106 | end_event.record() 107 | torch.cuda.synchronize() 108 | 109 | logger.info( 110 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 111 | ) 112 | 113 | def process(self, prompt): 114 | logger.debug("infering language model...") 115 | language_code = None 116 | if isinstance(prompt, tuple): 117 | prompt, language_code = prompt 118 | prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt 119 | 120 | self.chat.append({"role": self.user_role, "content": prompt}) 121 | thread = Thread( 122 | target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs 123 | ) 124 | thread.start() 125 | if self.device == "mps": 126 | generated_text = "" 127 | for new_text in self.streamer: 128 | generated_text += new_text 129 | printable_text = generated_text 130 | torch.mps.empty_cache() 131 | else: 132 | generated_text, printable_text = "", "" 133 | for new_text in self.streamer: 134 | generated_text += new_text 135 | printable_text += new_text 136 | sentences = sent_tokenize(printable_text) 137 | if len(sentences) > 1: 138 | yield (sentences[0], language_code) 139 | printable_text = new_text 140 | 141 | self.chat.append({"role": "assistant", "content": generated_text}) 142 | 143 | # don't forget last sentence 144 | yield (printable_text, language_code) 145 | -------------------------------------------------------------------------------- /LLM/mlx_language_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from LLM.chat import Chat 3 | from baseHandler import BaseHandler 4 | from mlx_lm import load, stream_generate, generate 5 | from rich.console import Console 6 | import torch 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | console = Console() 11 | 12 | WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { 13 | "en": "english", 14 | "fr": "french", 15 | "es": "spanish", 16 | "zh": "chinese", 17 | "ja": "japanese", 18 | "ko": "korean", 19 | } 20 | 21 | class MLXLanguageModelHandler(BaseHandler): 22 | """ 23 | Handles the language model part. 
24 | """ 25 | 26 | def setup( 27 | self, 28 | model_name="microsoft/Phi-3-mini-4k-instruct", 29 | device="mps", 30 | torch_dtype="float16", 31 | gen_kwargs={}, 32 | user_role="user", 33 | chat_size=1, 34 | init_chat_role=None, 35 | init_chat_prompt="You are a helpful AI assistant.", 36 | ): 37 | self.model_name = model_name 38 | self.model, self.tokenizer = load(self.model_name) 39 | self.gen_kwargs = gen_kwargs 40 | 41 | self.chat = Chat(chat_size) 42 | if init_chat_role: 43 | if not init_chat_prompt: 44 | raise ValueError( 45 | "An initial promt needs to be specified when setting init_chat_role." 46 | ) 47 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) 48 | self.user_role = user_role 49 | 50 | self.warmup() 51 | 52 | def warmup(self): 53 | logger.info(f"Warming up {self.__class__.__name__}") 54 | 55 | dummy_input_text = "Repeat the word 'home'." 56 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] 57 | 58 | n_steps = 2 59 | 60 | for _ in range(n_steps): 61 | prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False) 62 | generate( 63 | self.model, 64 | self.tokenizer, 65 | prompt=prompt, 66 | max_tokens=self.gen_kwargs["max_new_tokens"], 67 | verbose=False, 68 | ) 69 | 70 | def process(self, prompt): 71 | logger.debug("infering language model...") 72 | language_code = None 73 | 74 | if isinstance(prompt, tuple): 75 | prompt, language_code = prompt 76 | prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt 77 | 78 | self.chat.append({"role": self.user_role, "content": prompt}) 79 | 80 | # Remove system messages if using a Gemma model 81 | if "gemma" in self.model_name.lower(): 82 | chat_messages = [ 83 | msg for msg in self.chat.to_list() if msg["role"] != "system" 84 | ] 85 | else: 86 | chat_messages = self.chat.to_list() 87 | 88 | prompt = self.tokenizer.apply_chat_template( 89 | chat_messages, tokenize=False, add_generation_prompt=True 90 | ) 91 | output = "" 92 | curr_output = "" 93 | for t in stream_generate( 94 | self.model, 95 | self.tokenizer, 96 | prompt, 97 | max_tokens=self.gen_kwargs["max_new_tokens"], 98 | ): 99 | output += t 100 | curr_output += t 101 | if curr_output.endswith((".", "?", "!", "<|end|>")): 102 | yield (curr_output.replace("<|end|>", ""), language_code) 103 | curr_output = "" 104 | generated_text = output.replace("<|end|>", "") 105 | torch.mps.empty_cache() 106 | 107 | self.chat.append({"role": "assistant", "content": generated_text}) -------------------------------------------------------------------------------- /LLM/openai_api_language_model.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from LLM.chat import Chat 3 | from baseHandler import BaseHandler 4 | from rich.console import Console 5 | import logging 6 | import time 7 | logger = logging.getLogger(__name__) 8 | 9 | console = Console() 10 | from nltk import sent_tokenize 11 | 12 | class OpenApiModelHandler(BaseHandler): 13 | """ 14 | Handles the language model part. 
15 | """ 16 | def setup( 17 | self, 18 | model_name="deepseek-chat", 19 | device="cuda", 20 | gen_kwargs={}, 21 | base_url =None, 22 | api_key=None, 23 | stream=False, 24 | user_role="user", 25 | chat_size=1, 26 | init_chat_role="system", 27 | init_chat_prompt="You are a helpful AI assistant.", 28 | ): 29 | self.model_name = model_name 30 | self.stream = stream 31 | self.chat = Chat(chat_size) 32 | if init_chat_role: 33 | if not init_chat_prompt: 34 | raise ValueError( 35 | "An initial promt needs to be specified when setting init_chat_role." 36 | ) 37 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) 38 | self.user_role = user_role 39 | self.client = OpenAI(api_key=api_key, base_url=base_url) 40 | self.warmup() 41 | 42 | def warmup(self): 43 | logger.info(f"Warming up {self.__class__.__name__}") 44 | start = time.time() 45 | response = self.client.chat.completions.create( 46 | model=self.model_name, 47 | messages=[ 48 | {"role": "system", "content": "You are a helpful assistant"}, 49 | {"role": "user", "content": "Hello"}, 50 | ], 51 | stream=self.stream 52 | ) 53 | end = time.time() 54 | logger.info( 55 | f"{self.__class__.__name__}: warmed up! time: {(end - start):.3f} s" 56 | ) 57 | def process(self, prompt): 58 | logger.debug("call api language model...") 59 | self.chat.append({"role": self.user_role, "content": prompt}) 60 | 61 | language_code = None 62 | if isinstance(prompt, tuple): 63 | prompt, language_code = prompt 64 | 65 | response = self.client.chat.completions.create( 66 | model=self.model_name, 67 | messages=[ 68 | {"role": self.user_role, "content": prompt}, 69 | ], 70 | stream=self.stream 71 | ) 72 | if self.stream: 73 | generated_text, printable_text = "", "" 74 | for chunk in response: 75 | new_text = chunk.choices[0].delta.content or "" 76 | generated_text += new_text 77 | printable_text += new_text 78 | sentences = sent_tokenize(printable_text) 79 | if len(sentences) > 1: 80 | yield sentences[0], language_code 81 | printable_text = new_text 82 | self.chat.append({"role": "assistant", "content": generated_text}) 83 | # don't forget last sentence 84 | yield printable_text, language_code 85 | else: 86 | generated_text = response.choices[0].message.content 87 | self.chat.append({"role": "assistant", "content": generated_text}) 88 | yield generated_text, language_code 89 | 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |
5 | 6 | # Speech To Speech: an effort for an open-sourced and modular GPT4-o 7 | 8 | 9 | ## 📖 Quick Index 10 | * [Approach](#approach) 11 | - [Structure](#structure) 12 | - [Modularity](#modularity) 13 | * [Setup](#setup) 14 | * [Usage](#usage) 15 | - [Docker Server approach](#docker-server) 16 | - [Server/Client approach](#serverclient-approach) 17 | - [Local approach](#local-approach-running-on-mac) 18 | * [Command-line usage](#command-line-usage) 19 | - [Model parameters](#model-parameters) 20 | - [Generation parameters](#generation-parameters) 21 | - [Notable parameters](#notable-parameters) 22 | 23 | ## Approach 24 | 25 | ### Structure 26 | This repository implements a speech-to-speech cascaded pipeline with consecutive parts: 27 | 1. **Voice Activity Detection (VAD)**: [silero VAD v5](https://github.com/snakers4/silero-vad) 28 | 2. **Speech to Text (STT)**: Whisper checkpoints (including [distilled versions](https://huggingface.co/distil-whisper)) 29 | 3. **Language Model (LM)**: Any instruct model available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)! 🤗 30 | 4. **Text to Speech (TTS)**: [Parler-TTS](https://github.com/huggingface/parler-tts)🤗 31 | 32 | ### Modularity 33 | The pipeline aims to provide a fully open and modular approach, leveraging models available on the Transformers library via the Hugging Face hub. The level of modularity intended for each part is as follows: 34 | - **VAD**: Uses the implementation from [Silero's repo](https://github.com/snakers4/silero-vad). 35 | - **STT**: Uses Whisper models exclusively; however, any Whisper checkpoint can be used, enabling options like [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v3) and [French Distil-Whisper](https://huggingface.co/eustlb/distil-large-v3-fr). 36 | - **LM**: This part is fully modular and can be changed by simply modifying the Hugging Face hub model ID. Users need to select an instruct model since the usage here involves interacting with it. 37 | - **TTS**: The mini architecture of Parler-TTS is standard, but different checkpoints, including fine-tuned multilingual checkpoints, can be used. 38 | 39 | The code is designed to facilitate easy modification. Each component is implemented as a class and can be re-implemented to match specific needs. 40 | 41 | ## Setup 42 | 43 | Clone the repository: 44 | ```bash 45 | git clone https://github.com/huggingface/speech-to-speech.git 46 | cd speech-to-speech 47 | ``` 48 | 49 | Install the required dependencies using [uv](https://github.com/astral-sh/uv): 50 | ```bash 51 | uv pip install -r requirements.txt 52 | ``` 53 | 54 | For Mac users, use the `requirements_mac.txt` file instead: 55 | ```bash 56 | uv pip install -r requirements_mac.txt 57 | ``` 58 | 59 | If you want to use Melo TTS, you also need to run: 60 | ```bash 61 | python -m unidic download 62 | ``` 63 | 64 | 65 | ## Usage 66 | 67 | The pipeline can be run in two ways: 68 | - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client. 69 | - **Local approach**: Runs locally. 70 | 71 | ### Docker Server 72 | 73 | #### Install the NVIDIA Container Toolkit 74 | 75 | https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html 76 | 77 | #### Start the docker container 78 | ```docker compose up``` 79 | 80 | ### Server/Client Approach 81 | 82 | 1. 
Run the pipeline on the server: 83 | ```bash 84 | python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0 85 | ``` 86 | 87 | 2. Run the client locally to handle microphone input and receive generated audio: 88 | ```bash 89 | python listen_and_play.py --host <IP address of your server> 90 | ``` 91 | 92 | ### Local Approach (Mac) 93 | 94 | 1. For optimal settings on Mac: 95 | ```bash 96 | python s2s_pipeline.py --local_mac_optimal_settings 97 | ``` 98 | 99 | This setting: 100 | - Adds `--device mps` to use MPS for all models. 101 | - Sets LightningWhisperMLX for STT 102 | - Sets MLX LM for language model 103 | - Sets MeloTTS for TTS 104 | 105 | ### Recommended usage with CUDA 106 | 107 | Leverage Torch Compile for Whisper and Parler-TTS: 108 | 109 | ```bash 110 | python s2s_pipeline.py \ 111 | --recv_host 0.0.0.0 \ 112 | --send_host 0.0.0.0 \ 113 | --lm_model_name microsoft/Phi-3-mini-4k-instruct \ 114 | --init_chat_role system \ 115 | --stt_compile_mode reduce-overhead \ 116 | --tts_compile_mode default 117 | ``` 118 | 119 | For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`). 120 | 121 | 122 | ### Multi-language Support 123 | 124 | The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups: 125 | 126 | #### With the server version: 127 | 128 | 129 | For automatic language detection: 130 | 131 | ```bash 132 | python s2s_pipeline.py \ 133 | --stt_model_name large-v3 \ 134 | --language auto \ 135 | --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct 136 | ``` 137 | 138 | Or for one language in particular, Chinese in this example: 139 | 140 | ```bash 141 | python s2s_pipeline.py \ 142 | --stt_model_name large-v3 \ 143 | --language zh \ 144 | --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct 145 | ``` 146 | 147 | #### Local Mac Setup 148 | 149 | For automatic language detection: 150 | 151 | ```bash 152 | python s2s_pipeline.py \ 153 | --local_mac_optimal_settings \ 154 | --device mps \ 155 | --stt_model_name large-v3 \ 156 | --language auto \ 157 | --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 158 | ``` 159 | 160 | Or for one language in particular, Chinese in this example: 161 | 162 | ```bash 163 | python s2s_pipeline.py \ 164 | --local_mac_optimal_settings \ 165 | --device mps \ 166 | --stt_model_name large-v3 \ 167 | --language zh \ 168 | --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 169 | ``` 170 | 171 | 172 | ## Command-line Usage 173 | 174 | ### Model Parameters 175 | 176 | `model_name`, `torch_dtype`, and `device` are exposed for each part leveraging the Transformers' implementations: Speech to Text, Language Model, and Text to Speech. Specify the targeted pipeline part with the corresponding prefix: 177 | - `stt` (Speech to Text) 178 | - `lm` (Language Model) 179 | - `tts` (Text to Speech) 180 | 181 | For example: 182 | ```bash 183 | --lm_model_name google/gemma-2b-it 184 | ``` 185 | 186 | ### Generation Parameters 187 | 188 | Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed (see `LanguageModelHandlerArguments` for example). 189 | 190 | ### Notable Parameters 191 | 192 | #### VAD Parameters 193 | - `--thresh`: Threshold value to trigger voice activity detection.
194 | - `--min_speech_ms`: Minimum duration of detected voice activity to be considered speech. 195 | - `--min_silence_ms`: Minimum length of silence intervals for segmenting speech, balancing sentence cutting and latency reduction. 196 | 197 | #### Language Model 198 | - `--init_chat_role`: Defaults to `None`. Sets the initial role in the chat template, if applicable. Refer to the model's card to set this value (e.g. for [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) you have to set `--init_chat_role system`) 199 | - `--init_chat_prompt`: Defaults to `"You are a helpful AI assistant."` Required when setting `--init_chat_role`. 200 | 201 | #### Text to Speech 202 | - `--description`: Sets the description for Parler-TTS generated voice. Defaults to: `"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."` 203 | 204 | - `--play_steps_s`: Specifies the duration of the first chunk sent during streaming output from Parler-TTS, impacting readiness and decoding steps. 205 | 206 | ## Citations 207 | 208 | ### Silero VAD 209 | ```bibtex 210 | @misc{Silero VAD, 211 | author = {Silero Team}, 212 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, 213 | year = {2021}, 214 | publisher = {GitHub}, 215 | journal = {GitHub repository}, 216 | howpublished = {\url{https://github.com/snakers4/silero-vad}}, 217 | commit = {insert_some_commit_here}, 218 | email = {hello@silero.ai} 219 | } 220 | ``` 221 | 222 | ### Distil-Whisper 223 | ```bibtex 224 | @misc{gandhi2023distilwhisper, 225 | title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling}, 226 | author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush}, 227 | year={2023}, 228 | eprint={2311.00430}, 229 | archivePrefix={arXiv}, 230 | primaryClass={cs.CL} 231 | } 232 | ``` 233 | 234 | ### Parler-TTS 235 | ```bibtex 236 | @misc{lacombe-etal-2024-parler-tts, 237 | author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi}, 238 | title = {Parler-TTS}, 239 | year = {2024}, 240 | publisher = {GitHub}, 241 | journal = {GitHub repository}, 242 | howpublished = {\url{https://github.com/huggingface/parler-tts}} 243 | } 244 | ``` 245 | -------------------------------------------------------------------------------- /STT/lightning_whisper_mlx_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import perf_counter 3 | from baseHandler import BaseHandler 4 | from lightning_whisper_mlx import LightningWhisperMLX 5 | import numpy as np 6 | from rich.console import Console 7 | from copy import copy 8 | import torch 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | console = Console() 13 | 14 | SUPPORTED_LANGUAGES = [ 15 | "en", 16 | "fr", 17 | "es", 18 | "zh", 19 | "ja", 20 | "ko", 21 | ] 22 | 23 | 24 | class LightningWhisperSTTHandler(BaseHandler): 25 | """ 26 | Handles the Speech To Text generation using a Whisper model.
27 | """ 28 | 29 | def setup( 30 | self, 31 | model_name="distil-large-v3", 32 | device="mps", 33 | torch_dtype="float16", 34 | compile_mode=None, 35 | language=None, 36 | gen_kwargs={}, 37 | ): 38 | if len(model_name.split("/")) > 1: 39 | model_name = model_name.split("/")[-1] 40 | self.device = device 41 | self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None) 42 | self.start_language = language 43 | self.last_language = language 44 | 45 | self.warmup() 46 | 47 | def warmup(self): 48 | logger.info(f"Warming up {self.__class__.__name__}") 49 | 50 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 51 | n_steps = 1 52 | dummy_input = np.array([0] * 512) 53 | 54 | for _ in range(n_steps): 55 | _ = self.model.transcribe(dummy_input)["text"].strip() 56 | 57 | def process(self, spoken_prompt): 58 | logger.debug("infering whisper...") 59 | 60 | global pipeline_start 61 | pipeline_start = perf_counter() 62 | 63 | if self.start_language != 'auto': 64 | transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language) 65 | else: 66 | transcription_dict = self.model.transcribe(spoken_prompt) 67 | language_code = transcription_dict["language"] 68 | if language_code not in SUPPORTED_LANGUAGES: 69 | logger.warning(f"Whisper detected unsupported language: {language_code}") 70 | if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language 71 | transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language) 72 | else: 73 | transcription_dict = {"text": "", "language": "en"} 74 | else: 75 | self.last_language = language_code 76 | 77 | pred_text = transcription_dict["text"].strip() 78 | language_code = transcription_dict["language"] 79 | torch.mps.empty_cache() 80 | 81 | logger.debug("finished whisper inference") 82 | console.print(f"[yellow]USER: {pred_text}") 83 | logger.debug(f"Language Code Whisper: {language_code}") 84 | 85 | yield (pred_text, language_code) 86 | -------------------------------------------------------------------------------- /STT/paraformer_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import perf_counter 3 | 4 | from baseHandler import BaseHandler 5 | from funasr import AutoModel 6 | import numpy as np 7 | from rich.console import Console 8 | import torch 9 | 10 | logging.basicConfig( 11 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 12 | ) 13 | logger = logging.getLogger(__name__) 14 | 15 | console = Console() 16 | 17 | 18 | class ParaformerSTTHandler(BaseHandler): 19 | """ 20 | Handles the Speech To Text generation using a Paraformer model. 21 | The default for this model is set to Chinese. 22 | This model was contributed by @wuhongsheng. 
23 | """ 24 | 25 | def setup( 26 | self, 27 | model_name="paraformer-zh", 28 | device="cuda", 29 | gen_kwargs={}, 30 | ): 31 | print(model_name) 32 | if len(model_name.split("/")) > 1: 33 | model_name = model_name.split("/")[-1] 34 | self.device = device 35 | self.model = AutoModel(model=model_name, device=device) 36 | self.warmup() 37 | 38 | def warmup(self): 39 | logger.info(f"Warming up {self.__class__.__name__}") 40 | 41 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 42 | n_steps = 1 43 | dummy_input = np.array([0] * 512, dtype=np.float32) 44 | for _ in range(n_steps): 45 | _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "") 46 | 47 | def process(self, spoken_prompt): 48 | logger.debug("infering paraformer...") 49 | 50 | global pipeline_start 51 | pipeline_start = perf_counter() 52 | 53 | pred_text = ( 54 | self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "") 55 | ) 56 | torch.mps.empty_cache() 57 | 58 | logger.debug("finished paraformer inference") 59 | console.print(f"[yellow]USER: {pred_text}") 60 | 61 | yield pred_text 62 | -------------------------------------------------------------------------------- /STT/whisper_stt_handler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | from transformers import ( 3 | AutoProcessor, 4 | AutoModelForSpeechSeq2Seq 5 | ) 6 | import torch 7 | from copy import copy 8 | from baseHandler import BaseHandler 9 | from rich.console import Console 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | console = Console() 14 | 15 | SUPPORTED_LANGUAGES = [ 16 | "en", 17 | "fr", 18 | "es", 19 | "zh", 20 | "ja", 21 | "ko", 22 | ] 23 | 24 | 25 | class WhisperSTTHandler(BaseHandler): 26 | """ 27 | Handles the Speech To Text generation using a Whisper model. 
28 | """ 29 | 30 | def setup( 31 | self, 32 | model_name="distil-whisper/distil-large-v3", 33 | device="cuda", 34 | torch_dtype="float16", 35 | compile_mode=None, 36 | language=None, 37 | gen_kwargs={}, 38 | ): 39 | self.device = device 40 | self.torch_dtype = getattr(torch, torch_dtype) 41 | self.compile_mode = compile_mode 42 | self.gen_kwargs = gen_kwargs 43 | if language == 'auto': 44 | language = None 45 | self.last_language = language 46 | if self.last_language is not None: 47 | self.gen_kwargs["language"] = self.last_language 48 | 49 | self.processor = AutoProcessor.from_pretrained(model_name) 50 | self.model = AutoModelForSpeechSeq2Seq.from_pretrained( 51 | model_name, 52 | torch_dtype=self.torch_dtype, 53 | ).to(device) 54 | 55 | # compile 56 | if self.compile_mode: 57 | self.model.generation_config.cache_implementation = "static" 58 | self.model.forward = torch.compile( 59 | self.model.forward, mode=self.compile_mode, fullgraph=True 60 | ) 61 | self.warmup() 62 | 63 | def prepare_model_inputs(self, spoken_prompt): 64 | input_features = self.processor( 65 | spoken_prompt, sampling_rate=16000, return_tensors="pt" 66 | ).input_features 67 | input_features = input_features.to(self.device, dtype=self.torch_dtype) 68 | 69 | return input_features 70 | 71 | def warmup(self): 72 | logger.info(f"Warming up {self.__class__.__name__}") 73 | 74 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 75 | n_steps = 1 if self.compile_mode == "default" else 2 76 | dummy_input = torch.randn( 77 | (1, self.model.config.num_mel_bins, 3000), 78 | dtype=self.torch_dtype, 79 | device=self.device, 80 | ) 81 | if self.compile_mode not in (None, "default"): 82 | # generating more tokens than previously will trigger CUDA graphs capture 83 | # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation 84 | # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense 85 | warmup_gen_kwargs = { 86 | "min_new_tokens": self.gen_kwargs[ 87 | "max_new_tokens" 88 | ], # Yes, assign max_new_tokens to min_new_tokens 89 | "max_new_tokens": self.gen_kwargs["max_new_tokens"], 90 | **self.gen_kwargs, 91 | } 92 | else: 93 | warmup_gen_kwargs = self.gen_kwargs 94 | 95 | if self.device == "cuda": 96 | start_event = torch.cuda.Event(enable_timing=True) 97 | end_event = torch.cuda.Event(enable_timing=True) 98 | torch.cuda.synchronize() 99 | start_event.record() 100 | 101 | for _ in range(n_steps): 102 | _ = self.model.generate(dummy_input, **warmup_gen_kwargs) 103 | 104 | if self.device == "cuda": 105 | end_event.record() 106 | torch.cuda.synchronize() 107 | 108 | logger.info( 109 | f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 110 | ) 111 | 112 | def process(self, spoken_prompt): 113 | logger.debug("infering whisper...") 114 | 115 | global pipeline_start 116 | pipeline_start = perf_counter() 117 | 118 | input_features = self.prepare_model_inputs(spoken_prompt) 119 | pred_ids = self.model.generate(input_features, **self.gen_kwargs) 120 | language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" 121 | 122 | if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language 123 | logger.warning("Whisper detected unsupported language:", language_code) 124 | gen_kwargs = copy(self.gen_kwargs) 125 | gen_kwargs['language'] = self.last_language 126 | language_code = self.last_language 127 | pred_ids = self.model.generate(input_features, **gen_kwargs) 128 | else: 129 | self.last_language = language_code 130 | 131 | pred_text = self.processor.batch_decode( 132 | pred_ids, skip_special_tokens=True, decode_with_timestamps=False 133 | )[0] 134 | language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" 135 | 136 | logger.debug("finished whisper inference") 137 | console.print(f"[yellow]USER: {pred_text}") 138 | logger.debug(f"Language Code Whisper: {language_code}") 139 | 140 | yield (pred_text, language_code) 141 | -------------------------------------------------------------------------------- /TTS/chatTTS_handler.py: -------------------------------------------------------------------------------- 1 | import ChatTTS 2 | import logging 3 | from baseHandler import BaseHandler 4 | import librosa 5 | import numpy as np 6 | from rich.console import Console 7 | import torch 8 | 9 | logging.basicConfig( 10 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 11 | ) 12 | logger = logging.getLogger(__name__) 13 | 14 | console = Console() 15 | 16 | 17 | class ChatTTSHandler(BaseHandler): 18 | def setup( 19 | self, 20 | should_listen, 21 | device="cuda", 22 | gen_kwargs={}, # Unused 23 | stream=True, 24 | chunk_size=512, 25 | ): 26 | self.should_listen = should_listen 27 | self.device = device 28 | self.model = ChatTTS.Chat() 29 | self.model.load(compile=False) # Doesn't work for me with True 30 | self.chunk_size = chunk_size 31 | self.stream = stream 32 | rnd_spk_emb = self.model.sample_random_speaker() 33 | self.params_infer_code = ChatTTS.Chat.InferCodeParams( 34 | spk_emb=rnd_spk_emb, 35 | ) 36 | self.warmup() 37 | 38 | def warmup(self): 39 | logger.info(f"Warming up {self.__class__.__name__}") 40 | _ = self.model.infer("text") 41 | 42 | def process(self, llm_sentence): 43 | console.print(f"[green]ASSISTANT: {llm_sentence}") 44 | if self.device == "mps": 45 | import time 46 | 47 | start = time.time() 48 | torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete. 49 | torch.mps.empty_cache() # Frees all memory allocated by the MPS device. 50 | _ = ( 51 | time.time() - start 52 | ) # Removing this line makes it fail more often. I'm looking into it. 
53 | 54 | wavs_gen = self.model.infer( 55 | llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream 56 | ) 57 | 58 | if self.stream: 59 | wavs = [np.array([])] 60 | for gen in wavs_gen: 61 | if gen[0] is None or len(gen[0]) == 0: 62 | self.should_listen.set() 63 | return 64 | audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000) 65 | audio_chunk = (audio_chunk * 32768).astype(np.int16)[0] 66 | while len(audio_chunk) > self.chunk_size: 67 | yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据 68 | audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据 69 | yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk))) 70 | else: 71 | wavs = wavs_gen 72 | if len(wavs[0]) == 0: 73 | self.should_listen.set() 74 | return 75 | audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000) 76 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 77 | for i in range(0, len(audio_chunk), self.chunk_size): 78 | yield np.pad( 79 | audio_chunk[i : i + self.chunk_size], 80 | (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), 81 | ) 82 | self.should_listen.set() 83 | -------------------------------------------------------------------------------- /TTS/melo_handler.py: -------------------------------------------------------------------------------- 1 | from melo.api import TTS 2 | import logging 3 | from baseHandler import BaseHandler 4 | import librosa 5 | import numpy as np 6 | from rich.console import Console 7 | import torch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | console = Console() 12 | 13 | WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { 14 | "en": "EN", 15 | "fr": "FR", 16 | "es": "ES", 17 | "zh": "ZH", 18 | "ja": "JP", 19 | "ko": "KR", 20 | } 21 | 22 | WHISPER_LANGUAGE_TO_MELO_SPEAKER = { 23 | "en": "EN-BR", 24 | "fr": "FR", 25 | "es": "ES", 26 | "zh": "ZH", 27 | "ja": "JP", 28 | "ko": "KR", 29 | } 30 | 31 | 32 | class MeloTTSHandler(BaseHandler): 33 | def setup( 34 | self, 35 | should_listen, 36 | device="mps", 37 | language="en", 38 | speaker_to_id="en", 39 | gen_kwargs={}, # Unused 40 | blocksize=512, 41 | ): 42 | self.should_listen = should_listen 43 | self.device = device 44 | self.language = language 45 | self.model = TTS( 46 | language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device 47 | ) 48 | self.speaker_id = self.model.hps.data.spk2id[ 49 | WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id] 50 | ] 51 | self.blocksize = blocksize 52 | self.warmup() 53 | 54 | def warmup(self): 55 | logger.info(f"Warming up {self.__class__.__name__}") 56 | _ = self.model.tts_to_file("text", self.speaker_id, quiet=True) 57 | 58 | def process(self, llm_sentence): 59 | language_code = None 60 | 61 | if isinstance(llm_sentence, tuple): 62 | llm_sentence, language_code = llm_sentence 63 | 64 | console.print(f"[green]ASSISTANT: {llm_sentence}") 65 | 66 | if language_code is not None and self.language != language_code: 67 | try: 68 | self.model = TTS( 69 | language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code], 70 | device=self.device, 71 | ) 72 | self.speaker_id = self.model.hps.data.spk2id[ 73 | WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code] 74 | ] 75 | self.language = language_code 76 | except KeyError: 77 | console.print( 78 | f"[red]Language {language_code} not supported by Melo. Using {self.language} instead." 79 | ) 80 | 81 | if self.device == "mps": 82 | import time 83 | 84 | start = time.time() 85 | torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete. 
86 | torch.mps.empty_cache() # Frees all memory allocated by the MPS device. 87 | _ = ( 88 | time.time() - start 89 | ) # Removing this line makes it fail more often. I'm looking into it. 90 | 91 | try: 92 | audio_chunk = self.model.tts_to_file( 93 | llm_sentence, self.speaker_id, quiet=True 94 | ) 95 | except (AssertionError, RuntimeError) as e: 96 | logger.error(f"Error in MeloTTSHandler: {e}") 97 | audio_chunk = np.array([]) 98 | if len(audio_chunk) == 0: 99 | self.should_listen.set() 100 | return 101 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) 102 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 103 | for i in range(0, len(audio_chunk), self.blocksize): 104 | yield np.pad( 105 | audio_chunk[i : i + self.blocksize], 106 | (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), 107 | ) 108 | 109 | self.should_listen.set() 110 | -------------------------------------------------------------------------------- /TTS/parler_handler.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | from time import perf_counter 3 | from baseHandler import BaseHandler 4 | import numpy as np 5 | import torch 6 | from transformers import ( 7 | AutoTokenizer, 8 | ) 9 | from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer 10 | import librosa 11 | import logging 12 | from rich.console import Console 13 | from utils.utils import next_power_of_2 14 | from transformers.utils.import_utils import ( 15 | is_flash_attn_2_available, 16 | ) 17 | 18 | torch._inductor.config.fx_graph_cache = True 19 | # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS 20 | torch._dynamo.config.cache_size_limit = 15 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | console = Console() 25 | 26 | 27 | if not is_flash_attn_2_available() and torch.cuda.is_available(): 28 | logger.warn( 29 | """Parler TTS works best with flash attention 2, but is not installed 30 | Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`""" 31 | ) 32 | 33 | 34 | class ParlerTTSHandler(BaseHandler): 35 | def setup( 36 | self, 37 | should_listen, 38 | model_name="ylacombe/parler-tts-mini-jenny-30H", 39 | device="cuda", 40 | torch_dtype="float16", 41 | compile_mode=None, 42 | gen_kwargs={}, 43 | max_prompt_pad_length=8, 44 | description=( 45 | "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " 46 | "She speaks very fast." 
47 | ), 48 | play_steps_s=1, 49 | blocksize=512, 50 | ): 51 | self.should_listen = should_listen 52 | self.device = device 53 | self.torch_dtype = getattr(torch, torch_dtype) 54 | self.gen_kwargs = gen_kwargs 55 | self.compile_mode = compile_mode 56 | self.max_prompt_pad_length = max_prompt_pad_length 57 | self.description = description 58 | 59 | self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) 60 | self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) 61 | self.model = ParlerTTSForConditionalGeneration.from_pretrained( 62 | model_name, torch_dtype=self.torch_dtype 63 | ).to(device) 64 | 65 | framerate = self.model.audio_encoder.config.frame_rate 66 | self.play_steps = int(framerate * play_steps_s) 67 | self.blocksize = blocksize 68 | 69 | if self.compile_mode not in (None, "default"): 70 | logger.warning( 71 | "Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'" 72 | ) 73 | self.compile_mode = "default" 74 | 75 | if self.compile_mode: 76 | self.model.generation_config.cache_implementation = "static" 77 | self.model.forward = torch.compile( 78 | self.model.forward, mode=self.compile_mode, fullgraph=True 79 | ) 80 | 81 | self.warmup() 82 | 83 | def prepare_model_inputs( 84 | self, 85 | prompt, 86 | max_length_prompt=50, 87 | pad=False, 88 | ): 89 | pad_args_prompt = ( 90 | {"padding": "max_length", "max_length": max_length_prompt} if pad else {} 91 | ) 92 | 93 | tokenized_description = self.description_tokenizer( 94 | self.description, return_tensors="pt" 95 | ) 96 | input_ids = tokenized_description.input_ids.to(self.device) 97 | attention_mask = tokenized_description.attention_mask.to(self.device) 98 | 99 | tokenized_prompt = self.prompt_tokenizer( 100 | prompt, return_tensors="pt", **pad_args_prompt 101 | ) 102 | prompt_input_ids = tokenized_prompt.input_ids.to(self.device) 103 | prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) 104 | 105 | gen_kwargs = { 106 | "input_ids": input_ids, 107 | "attention_mask": attention_mask, 108 | "prompt_input_ids": prompt_input_ids, 109 | "prompt_attention_mask": prompt_attention_mask, 110 | **self.gen_kwargs, 111 | } 112 | 113 | return gen_kwargs 114 | 115 | def warmup(self): 116 | logger.info(f"Warming up {self.__class__.__name__}") 117 | 118 | if self.device == "cuda": 119 | start_event = torch.cuda.Event(enable_timing=True) 120 | end_event = torch.cuda.Event(enable_timing=True) 121 | 122 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 123 | n_steps = 1 if self.compile_mode == "default" else 2 124 | 125 | if self.device == "cuda": 126 | torch.cuda.synchronize() 127 | start_event.record() 128 | if self.compile_mode: 129 | pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] 130 | for pad_length in pad_lengths[::-1]: 131 | model_kwargs = self.prepare_model_inputs( 132 | "dummy prompt", max_length_prompt=pad_length, pad=True 133 | ) 134 | for _ in range(n_steps): 135 | _ = self.model.generate(**model_kwargs) 136 | logger.info(f"Warmed up length {pad_length} tokens!") 137 | else: 138 | model_kwargs = self.prepare_model_inputs("dummy prompt") 139 | for _ in range(n_steps): 140 | _ = self.model.generate(**model_kwargs) 141 | 142 | if self.device == "cuda": 143 | end_event.record() 144 | torch.cuda.synchronize() 145 | logger.info( 146 | f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 147 | ) 148 | 149 | def process(self, llm_sentence): 150 | if isinstance(llm_sentence, tuple): 151 | llm_sentence, _ = llm_sentence 152 | 153 | console.print(f"[green]ASSISTANT: {llm_sentence}") 154 | nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids) 155 | 156 | pad_args = {} 157 | if self.compile_mode: 158 | # pad to closest upper power of two 159 | pad_length = next_power_of_2(nb_tokens) 160 | logger.debug(f"padding to {pad_length}") 161 | pad_args["pad"] = True 162 | pad_args["max_length_prompt"] = pad_length 163 | 164 | tts_gen_kwargs = self.prepare_model_inputs( 165 | llm_sentence, 166 | **pad_args, 167 | ) 168 | 169 | streamer = ParlerTTSStreamer( 170 | self.model, device=self.device, play_steps=self.play_steps 171 | ) 172 | tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} 173 | torch.manual_seed(0) 174 | thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) 175 | thread.start() 176 | 177 | for i, audio_chunk in enumerate(streamer): 178 | global pipeline_start 179 | if i == 0 and "pipeline_start" in globals(): 180 | logger.info( 181 | f"Time to first audio: {perf_counter() - pipeline_start:.3f}" 182 | ) 183 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) 184 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 185 | for i in range(0, len(audio_chunk), self.blocksize): 186 | yield np.pad( 187 | audio_chunk[i : i + self.blocksize], 188 | (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), 189 | ) 190 | 191 | self.should_listen.set() 192 | -------------------------------------------------------------------------------- /VAD/vad_handler.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from VAD.vad_iterator import VADIterator 3 | from baseHandler import BaseHandler 4 | import numpy as np 5 | import torch 6 | from rich.console import Console 7 | 8 | from utils.utils import int2float 9 | from df.enhance import enhance, init_df 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | console = Console() 15 | 16 | 17 | class VADHandler(BaseHandler): 18 | """ 19 | Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed 20 | to the following part. 
21 | """ 22 | 23 | def setup( 24 | self, 25 | should_listen, 26 | thresh=0.3, 27 | sample_rate=16000, 28 | min_silence_ms=1000, 29 | min_speech_ms=500, 30 | max_speech_ms=float("inf"), 31 | speech_pad_ms=30, 32 | audio_enhancement=False, 33 | ): 34 | self.should_listen = should_listen 35 | self.sample_rate = sample_rate 36 | self.min_silence_ms = min_silence_ms 37 | self.min_speech_ms = min_speech_ms 38 | self.max_speech_ms = max_speech_ms 39 | self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") 40 | self.iterator = VADIterator( 41 | self.model, 42 | threshold=thresh, 43 | sampling_rate=sample_rate, 44 | min_silence_duration_ms=min_silence_ms, 45 | speech_pad_ms=speech_pad_ms, 46 | ) 47 | self.audio_enhancement = audio_enhancement 48 | if audio_enhancement: 49 | self.enhanced_model, self.df_state, _ = init_df() 50 | 51 | def process(self, audio_chunk): 52 | audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16) 53 | audio_float32 = int2float(audio_int16) 54 | vad_output = self.iterator(torch.from_numpy(audio_float32)) 55 | if vad_output is not None and len(vad_output) != 0: 56 | logger.debug("VAD: end of speech detected") 57 | array = torch.cat(vad_output).cpu().numpy() 58 | duration_ms = len(array) / self.sample_rate * 1000 59 | if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: 60 | logger.debug( 61 | f"audio input of duration: {len(array) / self.sample_rate}s, skipping" 62 | ) 63 | else: 64 | self.should_listen.clear() 65 | logger.debug("Stop listening") 66 | if self.audio_enhancement: 67 | if self.sample_rate != self.df_state.sr(): 68 | audio_float32 = torchaudio.functional.resample( 69 | torch.from_numpy(array), 70 | orig_freq=self.sample_rate, 71 | new_freq=self.df_state.sr(), 72 | ) 73 | enhanced = enhance( 74 | self.enhanced_model, 75 | self.df_state, 76 | audio_float32.unsqueeze(0), 77 | ) 78 | enhanced = torchaudio.functional.resample( 79 | enhanced, 80 | orig_freq=self.df_state.sr(), 81 | new_freq=self.sample_rate, 82 | ) 83 | else: 84 | enhanced = enhance( 85 | self.enhanced_model, self.df_state, audio_float32 86 | ) 87 | array = enhanced.numpy().squeeze() 88 | yield array 89 | 90 | @property 91 | def min_time_to_debug(self): 92 | return 0.00001 93 | -------------------------------------------------------------------------------- /VAD/vad_iterator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class VADIterator: 5 | def __init__( 6 | self, 7 | model, 8 | threshold: float = 0.5, 9 | sampling_rate: int = 16000, 10 | min_silence_duration_ms: int = 100, 11 | speech_pad_ms: int = 30, 12 | ): 13 | """ 14 | Mainly taken from https://github.com/snakers4/silero-vad 15 | Class for stream imitation 16 | 17 | Parameters 18 | ---------- 19 | model: preloaded .jit/.onnx silero VAD model 20 | 21 | threshold: float (default - 0.5) 22 | Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. 23 | It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. 
24 | 25 | sampling_rate: int (default - 16000) 26 | Currently silero VAD models support 8000 and 16000 sample rates 27 | 28 | min_silence_duration_ms: int (default - 100 milliseconds) 29 | In the end of each speech chunk wait for min_silence_duration_ms before separating it 30 | 31 | speech_pad_ms: int (default - 30 milliseconds) 32 | Final speech chunks are padded by speech_pad_ms each side 33 | """ 34 | 35 | self.model = model 36 | self.threshold = threshold 37 | self.sampling_rate = sampling_rate 38 | self.is_speaking = False 39 | self.buffer = [] 40 | 41 | if sampling_rate not in [8000, 16000]: 42 | raise ValueError( 43 | "VADIterator does not support sampling rates other than [8000, 16000]" 44 | ) 45 | 46 | self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 47 | self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 48 | self.reset_states() 49 | 50 | def reset_states(self): 51 | self.model.reset_states() 52 | self.triggered = False 53 | self.temp_end = 0 54 | self.current_sample = 0 55 | 56 | @torch.no_grad() 57 | def __call__(self, x): 58 | """ 59 | x: torch.Tensor 60 | audio chunk (see examples in repo) 61 | 62 | return_seconds: bool (default - False) 63 | whether return timestamps in seconds (default - samples) 64 | """ 65 | 66 | if not torch.is_tensor(x): 67 | try: 68 | x = torch.Tensor(x) 69 | except Exception: 70 | raise TypeError("Audio cannot be casted to tensor. Cast it manually") 71 | 72 | window_size_samples = len(x[0]) if x.dim() == 2 else len(x) 73 | self.current_sample += window_size_samples 74 | 75 | speech_prob = self.model(x, self.sampling_rate).item() 76 | 77 | if (speech_prob >= self.threshold) and self.temp_end: 78 | self.temp_end = 0 79 | 80 | if (speech_prob >= self.threshold) and not self.triggered: 81 | self.triggered = True 82 | return None 83 | 84 | if (speech_prob < self.threshold - 0.15) and self.triggered: 85 | if not self.temp_end: 86 | self.temp_end = self.current_sample 87 | if self.current_sample - self.temp_end < self.min_silence_samples: 88 | return None 89 | else: 90 | # end of speak 91 | self.temp_end = 0 92 | self.triggered = False 93 | spoken_utterance = self.buffer 94 | self.buffer = [] 95 | return spoken_utterance 96 | 97 | if self.triggered: 98 | self.buffer.append(x) 99 | 100 | return None 101 | -------------------------------------------------------------------------------- /arguments_classes/chat_tts_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class ChatTTSHandlerArguments: 6 | chat_tts_stream: bool = field( 7 | default=True, 8 | metadata={"help": "The tts mode is stream Default is 'stream'."}, 9 | ) 10 | chat_tts_device: str = field( 11 | default="cuda", 12 | metadata={ 13 | "help": "The device to be used for speech synthesis. Default is 'cuda'." 14 | }, 15 | ) 16 | chat_tts_chunk_size: int = field( 17 | default=512, 18 | metadata={ 19 | "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load.. Default is 512。." 
20 | }, 21 | ) 22 | -------------------------------------------------------------------------------- /arguments_classes/language_model_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class LanguageModelHandlerArguments: 6 | lm_model_name: str = field( 7 | default="HuggingFaceTB/SmolLM-360M-Instruct", 8 | metadata={ 9 | "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." 10 | }, 11 | ) 12 | lm_device: str = field( 13 | default="cuda", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | lm_torch_dtype: str = field( 19 | default="float16", 20 | metadata={ 21 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 22 | }, 23 | ) 24 | user_role: str = field( 25 | default="user", 26 | metadata={ 27 | "help": "Role assigned to the user in the chat context. Default is 'user'." 28 | }, 29 | ) 30 | init_chat_role: str = field( 31 | default="system", 32 | metadata={ 33 | "help": "Initial role for setting up the chat context. Default is 'system'." 34 | }, 35 | ) 36 | init_chat_prompt: str = field( 37 | default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", 38 | metadata={ 39 | "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" 40 | }, 41 | ) 42 | lm_gen_max_new_tokens: int = field( 43 | default=128, 44 | metadata={ 45 | "help": "Maximum number of new tokens to generate in a single completion. Default is 128." 46 | }, 47 | ) 48 | lm_gen_min_new_tokens: int = field( 49 | default=0, 50 | metadata={ 51 | "help": "Minimum number of new tokens to generate in a single completion. Default is 0." 52 | }, 53 | ) 54 | lm_gen_temperature: float = field( 55 | default=0.0, 56 | metadata={ 57 | "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." 58 | }, 59 | ) 60 | lm_gen_do_sample: bool = field( 61 | default=False, 62 | metadata={ 63 | "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." 64 | }, 65 | ) 66 | chat_size: int = field( 67 | default=2, 68 | metadata={ 69 | "help": "Number of interactions assitant-user to keep for the chat. None for no limitations." 70 | }, 71 | ) 72 | -------------------------------------------------------------------------------- /arguments_classes/melo_tts_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MeloTTSHandlerArguments: 6 | melo_language: str = field( 7 | default="en", 8 | metadata={ 9 | "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'." 10 | }, 11 | ) 12 | melo_device: str = field( 13 | default="auto", 14 | metadata={ 15 | "help": "The device to be used for speech synthesis. Default is 'auto'." 16 | }, 17 | ) 18 | melo_speaker_to_id: str = field( 19 | default="en", 20 | metadata={ 21 | "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']." 
22 | }, 23 | ) 24 | -------------------------------------------------------------------------------- /arguments_classes/mlx_language_model_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MLXLanguageModelHandlerArguments: 6 | mlx_lm_model_name: str = field( 7 | default="mlx-community/SmolLM-360M-Instruct", 8 | metadata={ 9 | "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." 10 | }, 11 | ) 12 | mlx_lm_device: str = field( 13 | default="mps", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | mlx_lm_torch_dtype: str = field( 19 | default="float16", 20 | metadata={ 21 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 22 | }, 23 | ) 24 | mlx_lm_user_role: str = field( 25 | default="user", 26 | metadata={ 27 | "help": "Role assigned to the user in the chat context. Default is 'user'." 28 | }, 29 | ) 30 | mlx_lm_init_chat_role: str = field( 31 | default="system", 32 | metadata={ 33 | "help": "Initial role for setting up the chat context. Default is 'system'." 34 | }, 35 | ) 36 | mlx_lm_init_chat_prompt: str = field( 37 | default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", 38 | metadata={ 39 | "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" 40 | }, 41 | ) 42 | mlx_lm_gen_max_new_tokens: int = field( 43 | default=128, 44 | metadata={ 45 | "help": "Maximum number of new tokens to generate in a single completion. Default is 128." 46 | }, 47 | ) 48 | mlx_lm_gen_temperature: float = field( 49 | default=0.0, 50 | metadata={ 51 | "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." 52 | }, 53 | ) 54 | mlx_lm_gen_do_sample: bool = field( 55 | default=False, 56 | metadata={ 57 | "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." 58 | }, 59 | ) 60 | mlx_lm_chat_size: int = field( 61 | default=2, 62 | metadata={ 63 | "help": "Number of interactions assitant-user to keep for the chat. None for no limitations." 64 | }, 65 | ) 66 | -------------------------------------------------------------------------------- /arguments_classes/module_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModuleArguments: 7 | device: Optional[str] = field( 8 | default=None, 9 | metadata={"help": "If specified, overrides the device for all handlers."}, 10 | ) 11 | mode: Optional[str] = field( 12 | default="socket", 13 | metadata={ 14 | "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'." 15 | }, 16 | ) 17 | local_mac_optimal_settings: bool = field( 18 | default=False, 19 | metadata={ 20 | "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used." 21 | }, 22 | ) 23 | stt: Optional[str] = field( 24 | default="whisper", 25 | metadata={ 26 | "help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'." 
27 | }, 28 | ) 29 | llm: Optional[str] = field( 30 | default="transformers", 31 | metadata={ 32 | "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'" 33 | }, 34 | ) 35 | tts: Optional[str] = field( 36 | default="parler", 37 | metadata={ 38 | "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'" 39 | }, 40 | ) 41 | log_level: str = field( 42 | default="info", 43 | metadata={ 44 | "help": "Provide logging level. Example --log_level debug, default=warning." 45 | }, 46 | ) 47 | -------------------------------------------------------------------------------- /arguments_classes/open_api_language_model_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class OpenApiLanguageModelHandlerArguments: 6 | open_api_model_name: str = field( 7 | # default="HuggingFaceTB/SmolLM-360M-Instruct", 8 | default="deepseek-chat", 9 | metadata={ 10 | "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'." 11 | }, 12 | ) 13 | open_api_user_role: str = field( 14 | default="user", 15 | metadata={ 16 | "help": "Role assigned to the user in the chat context. Default is 'user'." 17 | }, 18 | ) 19 | open_api_init_chat_role: str = field( 20 | default="system", 21 | metadata={ 22 | "help": "Initial role for setting up the chat context. Default is 'system'." 23 | }, 24 | ) 25 | open_api_init_chat_prompt: str = field( 26 | # default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", 27 | default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", 28 | metadata={ 29 | "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" 30 | }, 31 | ) 32 | 33 | open_api_chat_size: int = field( 34 | default=2, 35 | metadata={ 36 | "help": "Number of interactions assitant-user to keep for the chat. None for no limitations." 37 | }, 38 | ) 39 | open_api_api_key: str = field( 40 | default=None, 41 | metadata={ 42 | "help": "Is a unique code used to authenticate and authorize access to an API.Default is None" 43 | }, 44 | ) 45 | open_api_base_url: str = field( 46 | default=None, 47 | metadata={ 48 | "help": "Is the root URL for all endpoints of an API, serving as the starting point for constructing API request.Default is Non" 49 | }, 50 | ) 51 | open_api_stream: bool = field( 52 | default=False, 53 | metadata={ 54 | "help": "The stream parameter typically indicates whether data should be transmitted in a continuous flow rather" 55 | " than in a single, complete response, often used for handling large or real-time data.Default is False" 56 | }, 57 | ) -------------------------------------------------------------------------------- /arguments_classes/paraformer_stt_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class ParaformerSTTHandlerArguments: 6 | paraformer_stt_model_name: str = field( 7 | default="paraformer-zh", 8 | metadata={ 9 | "help": "The pretrained model to use. Default is 'paraformer-zh'. 
Can be choose from https://github.com/modelscope/FunASR" 10 | }, 11 | ) 12 | paraformer_stt_device: str = field( 13 | default="cuda", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | -------------------------------------------------------------------------------- /arguments_classes/parler_tts_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class ParlerTTSHandlerArguments: 6 | tts_model_name: str = field( 7 | default="ylacombe/parler-tts-mini-jenny-30H", 8 | metadata={ 9 | "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'." 10 | }, 11 | ) 12 | tts_device: str = field( 13 | default="cuda", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | tts_torch_dtype: str = field( 19 | default="float16", 20 | metadata={ 21 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 22 | }, 23 | ) 24 | tts_compile_mode: str = field( 25 | default=None, 26 | metadata={ 27 | "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" 28 | }, 29 | ) 30 | tts_gen_min_new_tokens: int = field( 31 | default=64, 32 | metadata={ 33 | "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs" 34 | }, 35 | ) 36 | tts_gen_max_new_tokens: int = field( 37 | default=512, 38 | metadata={ 39 | "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs" 40 | }, 41 | ) 42 | description: str = field( 43 | default=( 44 | "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " 45 | "She speaks very fast." 46 | ), 47 | metadata={ 48 | "help": "Description of the speaker's voice and speaking style to guide the TTS model." 49 | }, 50 | ) 51 | play_steps_s: float = field( 52 | default=1.0, 53 | metadata={ 54 | "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds." 55 | }, 56 | ) 57 | max_prompt_pad_length: int = field( 58 | default=8, 59 | metadata={ 60 | "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible." 61 | }, 62 | ) 63 | -------------------------------------------------------------------------------- /arguments_classes/socket_receiver_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class SocketReceiverArguments: 6 | recv_host: str = field( 7 | default="localhost", 8 | metadata={ 9 | "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all " 10 | "available interfaces on the host machine." 11 | }, 12 | ) 13 | recv_port: int = field( 14 | default=12345, 15 | metadata={ 16 | "help": "The port number on which the socket server listens. Default is 12346." 17 | }, 18 | ) 19 | chunk_size: int = field( 20 | default=1024, 21 | metadata={ 22 | "help": "The size of each data chunk to be sent or received over the socket. 
Default is 1024 bytes." 23 | }, 24 | ) 25 | -------------------------------------------------------------------------------- /arguments_classes/socket_sender_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class SocketSenderArguments: 6 | send_host: str = field( 7 | default="localhost", 8 | metadata={ 9 | "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all " 10 | "available interfaces on the host machine." 11 | }, 12 | ) 13 | send_port: int = field( 14 | default=12346, 15 | metadata={ 16 | "help": "The port number on which the socket server listens. Default is 12346." 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /arguments_classes/vad_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class VADHandlerArguments: 6 | thresh: float = field( 7 | default=0.3, 8 | metadata={ 9 | "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection." 10 | }, 11 | ) 12 | sample_rate: int = field( 13 | default=16000, 14 | metadata={ 15 | "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio." 16 | }, 17 | ) 18 | min_silence_ms: int = field( 19 | default=250, 20 | metadata={ 21 | "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms." 22 | }, 23 | ) 24 | min_speech_ms: int = field( 25 | default=500, 26 | metadata={ 27 | "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms." 28 | }, 29 | ) 30 | max_speech_ms: float = field( 31 | default=float("inf"), 32 | metadata={ 33 | "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments." 34 | }, 35 | ) 36 | speech_pad_ms: int = field( 37 | default=500, 38 | metadata={ 39 | "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms." 40 | }, 41 | ) 42 | audio_enhancement: bool = field( 43 | default=False, 44 | metadata={ 45 | "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False." 46 | }, 47 | ) 48 | -------------------------------------------------------------------------------- /arguments_classes/whisper_stt_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class WhisperSTTHandlerArguments: 7 | stt_model_name: str = field( 8 | default="distil-whisper/distil-large-v3", 9 | metadata={ 10 | "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'." 11 | }, 12 | ) 13 | stt_device: str = field( 14 | default="cuda", 15 | metadata={ 16 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 17 | }, 18 | ) 19 | stt_torch_dtype: str = field( 20 | default="float16", 21 | metadata={ 22 | "help": "The PyTorch data type for the model and input tensors. 
One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 23 | }, 24 | ) 25 | stt_compile_mode: str = field( 26 | default=None, 27 | metadata={ 28 | "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" 29 | }, 30 | ) 31 | stt_gen_max_new_tokens: int = field( 32 | default=128, 33 | metadata={ 34 | "help": "The maximum number of new tokens to generate. Default is 128." 35 | }, 36 | ) 37 | stt_gen_num_beams: int = field( 38 | default=1, 39 | metadata={ 40 | "help": "The number of beams for beam search. Default is 1, implying greedy decoding." 41 | }, 42 | ) 43 | stt_gen_return_timestamps: bool = field( 44 | default=False, 45 | metadata={ 46 | "help": "Whether to return timestamps with transcriptions. Default is False." 47 | }, 48 | ) 49 | stt_gen_task: str = field( 50 | default="transcribe", 51 | metadata={ 52 | "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." 53 | }, 54 | ) 55 | language: Optional[str] = field( 56 | default='en', 57 | metadata={ 58 | "help": """The language for the conversation. 59 | Choose between 'en' (english), 'fr' (french), 'es' (spanish), 60 | 'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'. 61 | If using 'auto', the language is automatically detected and can 62 | change during the conversation. Default is 'en'.""" 63 | }, 64 | ) -------------------------------------------------------------------------------- /baseHandler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class BaseHandler: 8 | """ 9 | Base class for pipeline parts. Each part of the pipeline has an input and an output queue. 10 | The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part. 11 | To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue. 12 | Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue. 13 | The cleanup method handles stopping the handler, and b"END" is placed in the output queue. 
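    Illustrative sketch of this contract (not repository code; `UpperCaseHandler` is a made-up handler used only to show the queue and sentinel behaviour):

        from queue import Queue
        from threading import Event

        class UpperCaseHandler(BaseHandler):
            def setup(self, suffix=""):
                self.suffix = suffix

            def process(self, text):
                # may yield several outputs per input; run() puts each on queue_out
                yield text.upper() + self.suffix

        stop, q_in, q_out = Event(), Queue(), Queue()
        handler = UpperCaseHandler(stop, q_in, q_out, setup_kwargs={"suffix": "!"})
        q_in.put("hello")
        q_in.put(b"END")      # sentinel: unblocks the queue and stops run()
        handler.run()         # normally started in a thread by ThreadManager
        q_out.get()           # "HELLO!"
        q_out.get()           # b"END", propagated to the next handler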
14 | """ 15 | 16 | def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}): 17 | self.stop_event = stop_event 18 | self.queue_in = queue_in 19 | self.queue_out = queue_out 20 | self.setup(*setup_args, **setup_kwargs) 21 | self._times = [] 22 | 23 | def setup(self): 24 | pass 25 | 26 | def process(self): 27 | raise NotImplementedError 28 | 29 | def run(self): 30 | while not self.stop_event.is_set(): 31 | input = self.queue_in.get() 32 | if isinstance(input, bytes) and input == b"END": 33 | # sentinelle signal to avoid queue deadlock 34 | logger.debug("Stopping thread") 35 | break 36 | start_time = perf_counter() 37 | for output in self.process(input): 38 | self._times.append(perf_counter() - start_time) 39 | if self.last_time > self.min_time_to_debug: 40 | logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s") 41 | self.queue_out.put(output) 42 | start_time = perf_counter() 43 | 44 | self.cleanup() 45 | self.queue_out.put(b"END") 46 | 47 | @property 48 | def last_time(self): 49 | return self._times[-1] 50 | 51 | @property 52 | def min_time_to_debug(self): 53 | return 0.001 54 | 55 | def cleanup(self): 56 | pass 57 | -------------------------------------------------------------------------------- /connections/local_audio_streamer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import sounddevice as sd 3 | import numpy as np 4 | 5 | import time 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class LocalAudioStreamer: 12 | def __init__( 13 | self, 14 | input_queue, 15 | output_queue, 16 | list_play_chunk_size=512, 17 | ): 18 | self.list_play_chunk_size = list_play_chunk_size 19 | 20 | self.stop_event = threading.Event() 21 | self.input_queue = input_queue 22 | self.output_queue = output_queue 23 | 24 | def run(self): 25 | def callback(indata, outdata, frames, time, status): 26 | if self.output_queue.empty(): 27 | self.input_queue.put(indata.copy()) 28 | outdata[:] = 0 * outdata 29 | else: 30 | outdata[:] = self.output_queue.get()[:, np.newaxis] 31 | 32 | logger.debug("Available devices:") 33 | logger.debug(sd.query_devices()) 34 | with sd.Stream( 35 | samplerate=16000, 36 | dtype="int16", 37 | channels=1, 38 | callback=callback, 39 | blocksize=self.list_play_chunk_size, 40 | ): 41 | logger.info("Starting local audio stream") 42 | while not self.stop_event.is_set(): 43 | time.sleep(0.001) 44 | print("Stopping recording") 45 | -------------------------------------------------------------------------------- /connections/socket_receiver.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from rich.console import Console 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | console = Console() 8 | 9 | 10 | class SocketReceiver: 11 | """ 12 | Handles reception of the audio packets from the client. 
13 | """ 14 | 15 | def __init__( 16 | self, 17 | stop_event, 18 | queue_out, 19 | should_listen, 20 | host="0.0.0.0", 21 | port=12345, 22 | chunk_size=1024, 23 | ): 24 | self.stop_event = stop_event 25 | self.queue_out = queue_out 26 | self.should_listen = should_listen 27 | self.chunk_size = chunk_size 28 | self.host = host 29 | self.port = port 30 | 31 | def receive_full_chunk(self, conn, chunk_size): 32 | data = b"" 33 | while len(data) < chunk_size: 34 | packet = conn.recv(chunk_size - len(data)) 35 | if not packet: 36 | # connection closed 37 | return None 38 | data += packet 39 | return data 40 | 41 | def run(self): 42 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 43 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 44 | self.socket.bind((self.host, self.port)) 45 | self.socket.listen(1) 46 | logger.info("Receiver waiting to be connected...") 47 | self.conn, _ = self.socket.accept() 48 | logger.info("receiver connected") 49 | 50 | self.should_listen.set() 51 | while not self.stop_event.is_set(): 52 | audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size) 53 | if audio_chunk is None: 54 | # connection closed 55 | self.queue_out.put(b"END") 56 | break 57 | if self.should_listen.is_set(): 58 | self.queue_out.put(audio_chunk) 59 | self.conn.close() 60 | logger.info("Receiver closed") 61 | -------------------------------------------------------------------------------- /connections/socket_sender.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from rich.console import Console 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | console = Console() 8 | 9 | 10 | class SocketSender: 11 | """ 12 | Handles sending generated audio packets to the clients. 13 | """ 14 | 15 | def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): 16 | self.stop_event = stop_event 17 | self.queue_in = queue_in 18 | self.host = host 19 | self.port = port 20 | 21 | def run(self): 22 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 23 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 24 | self.socket.bind((self.host, self.port)) 25 | self.socket.listen(1) 26 | logger.info("Sender waiting to be connected...") 27 | self.conn, _ = self.socket.accept() 28 | logger.info("sender connected") 29 | 30 | while not self.stop_event.is_set(): 31 | audio_chunk = self.queue_in.get() 32 | self.conn.sendall(audio_chunk) 33 | if isinstance(audio_chunk, bytes) and audio_chunk == b"END": 34 | break 35 | self.conn.close() 36 | logger.info("Sender closed") 37 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | --- 2 | services: 3 | 4 | pipeline: 5 | build: 6 | context: . 
7 | dockerfile: ${DOCKERFILE:-Dockerfile} 8 | command: 9 | - python3 10 | - s2s_pipeline.py 11 | - --recv_host 12 | - 0.0.0.0 13 | - --send_host 14 | - 0.0.0.0 15 | - --lm_model_name 16 | - microsoft/Phi-3-mini-4k-instruct 17 | - --init_chat_role 18 | - system 19 | - --init_chat_prompt 20 | - "You are a helpful assistant" 21 | - --stt_compile_mode 22 | - reduce-overhead 23 | - --tts_compile_mode 24 | - default 25 | expose: 26 | - 12345/tcp 27 | - 12346/tcp 28 | ports: 29 | - 12345:12345/tcp 30 | - 12346:12346/tcp 31 | volumes: 32 | - ./cache/:/root/.cache/ 33 | - ./s2s_pipeline.py:/usr/src/app/s2s_pipeline.py 34 | deploy: 35 | resources: 36 | reservations: 37 | devices: 38 | - driver: nvidia 39 | device_ids: ['0'] 40 | capabilities: [gpu] 41 | -------------------------------------------------------------------------------- /listen_and_play.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import threading 3 | from queue import Queue 4 | from dataclasses import dataclass, field 5 | import sounddevice as sd 6 | from transformers import HfArgumentParser 7 | 8 | 9 | @dataclass 10 | class ListenAndPlayArguments: 11 | send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) 12 | recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) 13 | list_play_chunk_size: int = field( 14 | default=1024, 15 | metadata={"help": "The size of data chunks (in bytes). Default is 1024."}, 16 | ) 17 | host: str = field( 18 | default="localhost", 19 | metadata={ 20 | "help": "The hostname or IP address for listening and playing. Default is 'localhost'." 21 | }, 22 | ) 23 | send_port: int = field( 24 | default=12345, 25 | metadata={"help": "The network port for sending data. Default is 12345."}, 26 | ) 27 | recv_port: int = field( 28 | default=12346, 29 | metadata={"help": "The network port for receiving data. 
Default is 12346."}, 30 | ) 31 | 32 | 33 | def listen_and_play( 34 | send_rate=16000, 35 | recv_rate=44100, 36 | list_play_chunk_size=1024, 37 | host="localhost", 38 | send_port=12345, 39 | recv_port=12346, 40 | ): 41 | send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 42 | send_socket.connect((host, send_port)) 43 | 44 | recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 45 | recv_socket.connect((host, recv_port)) 46 | 47 | print("Recording and streaming...") 48 | 49 | stop_event = threading.Event() 50 | recv_queue = Queue() 51 | send_queue = Queue() 52 | 53 | def callback_recv(outdata, frames, time, status): 54 | if not recv_queue.empty(): 55 | data = recv_queue.get() 56 | outdata[: len(data)] = data 57 | outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) 58 | else: 59 | outdata[:] = b"\x00" * len(outdata) 60 | 61 | def callback_send(indata, frames, time, status): 62 | if recv_queue.empty(): 63 | data = bytes(indata) 64 | send_queue.put(data) 65 | 66 | def send(stop_event, send_queue): 67 | while not stop_event.is_set(): 68 | data = send_queue.get() 69 | send_socket.sendall(data) 70 | 71 | def recv(stop_event, recv_queue): 72 | def receive_full_chunk(conn, chunk_size): 73 | data = b"" 74 | while len(data) < chunk_size: 75 | packet = conn.recv(chunk_size - len(data)) 76 | if not packet: 77 | return None # Connection has been closed 78 | data += packet 79 | return data 80 | 81 | while not stop_event.is_set(): 82 | data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) 83 | if data: 84 | recv_queue.put(data) 85 | 86 | try: 87 | send_stream = sd.RawInputStream( 88 | samplerate=send_rate, 89 | channels=1, 90 | dtype="int16", 91 | blocksize=list_play_chunk_size, 92 | callback=callback_send, 93 | ) 94 | recv_stream = sd.RawOutputStream( 95 | samplerate=recv_rate, 96 | channels=1, 97 | dtype="int16", 98 | blocksize=list_play_chunk_size, 99 | callback=callback_recv, 100 | ) 101 | threading.Thread(target=send_stream.start).start() 102 | threading.Thread(target=recv_stream.start).start() 103 | 104 | send_thread = threading.Thread(target=send, args=(stop_event, send_queue)) 105 | send_thread.start() 106 | recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue)) 107 | recv_thread.start() 108 | 109 | input("Press Enter to stop...") 110 | 111 | except KeyboardInterrupt: 112 | print("Finished streaming.") 113 | 114 | finally: 115 | stop_event.set() 116 | recv_thread.join() 117 | send_thread.join() 118 | send_socket.close() 119 | recv_socket.close() 120 | print("Connection closed.") 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = HfArgumentParser((ListenAndPlayArguments,)) 125 | (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses() 126 | listen_and_play(**vars(listen_and_play_kwargs)) 127 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eustlb/speech-to-speech/d5e460721e578fef286c7b64e68ad6a57a25cf1b/logo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.9.1 2 | parler_tts @ git+https://github.com/huggingface/parler-tts.git 3 | melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers 4 | torch==2.4.0 5 | sounddevice==0.5.0 6 | 
ChatTTS>=0.1.1 7 | funasr>=1.1.6 8 | modelscope>=1.17.1 9 | deepfilternet>=0.5.6 10 | openai>=1.40.1 -------------------------------------------------------------------------------- /requirements_mac.txt: -------------------------------------------------------------------------------- 1 | nltk==3.9.1 2 | parler_tts @ git+https://github.com/huggingface/parler-tts.git 3 | melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers 4 | torch==2.4.0 5 | sounddevice==0.5.0 6 | lightning-whisper-mlx>=0.0.10 7 | mlx-lm>=0.14.0 8 | ChatTTS>=0.1.1 9 | funasr>=1.1.6 10 | modelscope>=1.17.1 11 | deepfilternet>=0.5.6 12 | openai>=1.40.1 -------------------------------------------------------------------------------- /s2s_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from copy import copy 5 | from pathlib import Path 6 | from queue import Queue 7 | from threading import Event 8 | from typing import Optional 9 | from sys import platform 10 | from VAD.vad_handler import VADHandler 11 | from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments 12 | from arguments_classes.language_model_arguments import LanguageModelHandlerArguments 13 | from arguments_classes.mlx_language_model_arguments import ( 14 | MLXLanguageModelHandlerArguments, 15 | ) 16 | from arguments_classes.module_arguments import ModuleArguments 17 | from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments 18 | from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments 19 | from arguments_classes.socket_receiver_arguments import SocketReceiverArguments 20 | from arguments_classes.socket_sender_arguments import SocketSenderArguments 21 | from arguments_classes.vad_arguments import VADHandlerArguments 22 | from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments 23 | from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments 24 | from arguments_classes.open_api_language_model_arguments import OpenApiLanguageModelHandlerArguments 25 | import torch 26 | import nltk 27 | from rich.console import Console 28 | from transformers import ( 29 | HfArgumentParser, 30 | ) 31 | 32 | from utils.thread_manager import ThreadManager 33 | 34 | # Ensure that the necessary NLTK resources are available 35 | try: 36 | nltk.data.find("tokenizers/punkt_tab") 37 | except (LookupError, OSError): 38 | nltk.download("punkt_tab") 39 | try: 40 | nltk.data.find("tokenizers/averaged_perceptron_tagger_eng") 41 | except (LookupError, OSError): 42 | nltk.download("averaged_perceptron_tagger_eng") 43 | 44 | # caching allows ~50% compilation time reduction 45 | # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma 46 | CURRENT_DIR = Path(__file__).resolve().parent 47 | os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp") 48 | 49 | console = Console() 50 | logging.getLogger("numba").setLevel(logging.WARNING) # quiet down numba logs 51 | 52 | 53 | def rename_args(args, prefix): 54 | """ 55 | Rename arguments by removing the prefix and prepares the gen_kwargs. 
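    Illustrative transformation (a sketch, not repository code; the attribute values are placeholders):

        from types import SimpleNamespace
        args = SimpleNamespace(stt_model_name="distil-whisper/distil-large-v3",
                               stt_gen_max_new_tokens=128)
        rename_args(args, "stt")
        # args.model_name  -> "distil-whisper/distil-large-v3"
        # args.gen_kwargs  -> {"max_new_tokens": 128}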
56 | """ 57 | gen_kwargs = {} 58 | for key in copy(args.__dict__): 59 | if key.startswith(prefix): 60 | value = args.__dict__.pop(key) 61 | new_key = key[len(prefix) + 1 :] # Remove prefix and underscore 62 | if new_key.startswith("gen_"): 63 | gen_kwargs[new_key[4:]] = value # Remove 'gen_' and add to dict 64 | else: 65 | args.__dict__[new_key] = value 66 | 67 | args.__dict__["gen_kwargs"] = gen_kwargs 68 | 69 | 70 | def parse_arguments(): 71 | parser = HfArgumentParser( 72 | ( 73 | ModuleArguments, 74 | SocketReceiverArguments, 75 | SocketSenderArguments, 76 | VADHandlerArguments, 77 | WhisperSTTHandlerArguments, 78 | ParaformerSTTHandlerArguments, 79 | LanguageModelHandlerArguments, 80 | OpenApiLanguageModelHandlerArguments, 81 | MLXLanguageModelHandlerArguments, 82 | ParlerTTSHandlerArguments, 83 | MeloTTSHandlerArguments, 84 | ChatTTSHandlerArguments, 85 | ) 86 | ) 87 | 88 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 89 | # Parse configurations from a JSON file if specified 90 | return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 91 | else: 92 | # Parse arguments from command line if no JSON file is provided 93 | return parser.parse_args_into_dataclasses() 94 | 95 | 96 | def setup_logger(log_level): 97 | global logger 98 | logging.basicConfig( 99 | level=log_level.upper(), 100 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 101 | ) 102 | logger = logging.getLogger(__name__) 103 | 104 | # torch compile logs 105 | if log_level == "debug": 106 | torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) 107 | 108 | 109 | def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs): 110 | if mac_optimal_settings: 111 | for kwargs in handler_kwargs: 112 | if hasattr(kwargs, "device"): 113 | kwargs.device = "mps" 114 | if hasattr(kwargs, "mode"): 115 | kwargs.mode = "local" 116 | if hasattr(kwargs, "stt"): 117 | kwargs.stt = "whisper-mlx" 118 | if hasattr(kwargs, "llm"): 119 | kwargs.llm = "mlx-lm" 120 | if hasattr(kwargs, "tts"): 121 | kwargs.tts = "melo" 122 | 123 | 124 | def check_mac_settings(module_kwargs): 125 | if platform == "darwin": 126 | if module_kwargs.device == "cuda": 127 | raise ValueError( 128 | "Cannot use CUDA on macOS. Please set the device to 'cpu' or 'mps'." 129 | ) 130 | if module_kwargs.llm != "mlx-lm": 131 | logger.warning( 132 | "For macOS users, it is recommended to use mlx-lm. You can activate it by passing --llm mlx-lm." 133 | ) 134 | if module_kwargs.tts != "melo": 135 | logger.warning( 136 | "If you experiences issues generating the voice, considering setting the tts to melo." 
137 | ) 138 | 139 | 140 | def overwrite_device_argument(common_device: Optional[str], *handler_kwargs): 141 | if common_device: 142 | for kwargs in handler_kwargs: 143 | if hasattr(kwargs, "lm_device"): 144 | kwargs.lm_device = common_device 145 | if hasattr(kwargs, "tts_device"): 146 | kwargs.tts_device = common_device 147 | if hasattr(kwargs, "stt_device"): 148 | kwargs.stt_device = common_device 149 | if hasattr(kwargs, "paraformer_stt_device"): 150 | kwargs.paraformer_stt_device = common_device 151 | 152 | 153 | def prepare_module_args(module_kwargs, *handler_kwargs): 154 | optimal_mac_settings(module_kwargs.local_mac_optimal_settings, module_kwargs) 155 | if platform == "darwin": 156 | check_mac_settings(module_kwargs) 157 | overwrite_device_argument(module_kwargs.device, *handler_kwargs) 158 | 159 | 160 | def prepare_all_args( 161 | module_kwargs, 162 | whisper_stt_handler_kwargs, 163 | paraformer_stt_handler_kwargs, 164 | language_model_handler_kwargs, 165 | open_api_language_model_handler_kwargs, 166 | mlx_language_model_handler_kwargs, 167 | parler_tts_handler_kwargs, 168 | melo_tts_handler_kwargs, 169 | chat_tts_handler_kwargs, 170 | ): 171 | prepare_module_args( 172 | module_kwargs, 173 | whisper_stt_handler_kwargs, 174 | paraformer_stt_handler_kwargs, 175 | language_model_handler_kwargs, 176 | open_api_language_model_handler_kwargs, 177 | mlx_language_model_handler_kwargs, 178 | parler_tts_handler_kwargs, 179 | melo_tts_handler_kwargs, 180 | chat_tts_handler_kwargs, 181 | ) 182 | 183 | 184 | rename_args(whisper_stt_handler_kwargs, "stt") 185 | rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") 186 | rename_args(language_model_handler_kwargs, "lm") 187 | rename_args(mlx_language_model_handler_kwargs, "mlx_lm") 188 | rename_args(open_api_language_model_handler_kwargs, "open_api") 189 | rename_args(parler_tts_handler_kwargs, "tts") 190 | rename_args(melo_tts_handler_kwargs, "melo") 191 | rename_args(chat_tts_handler_kwargs, "chat_tts") 192 | 193 | 194 | def initialize_queues_and_events(): 195 | return { 196 | "stop_event": Event(), 197 | "should_listen": Event(), 198 | "recv_audio_chunks_queue": Queue(), 199 | "send_audio_chunks_queue": Queue(), 200 | "spoken_prompt_queue": Queue(), 201 | "text_prompt_queue": Queue(), 202 | "lm_response_queue": Queue(), 203 | } 204 | 205 | 206 | def build_pipeline( 207 | module_kwargs, 208 | socket_receiver_kwargs, 209 | socket_sender_kwargs, 210 | vad_handler_kwargs, 211 | whisper_stt_handler_kwargs, 212 | paraformer_stt_handler_kwargs, 213 | language_model_handler_kwargs, 214 | open_api_language_model_handler_kwargs, 215 | mlx_language_model_handler_kwargs, 216 | parler_tts_handler_kwargs, 217 | melo_tts_handler_kwargs, 218 | chat_tts_handler_kwargs, 219 | queues_and_events, 220 | ): 221 | stop_event = queues_and_events["stop_event"] 222 | should_listen = queues_and_events["should_listen"] 223 | recv_audio_chunks_queue = queues_and_events["recv_audio_chunks_queue"] 224 | send_audio_chunks_queue = queues_and_events["send_audio_chunks_queue"] 225 | spoken_prompt_queue = queues_and_events["spoken_prompt_queue"] 226 | text_prompt_queue = queues_and_events["text_prompt_queue"] 227 | lm_response_queue = queues_and_events["lm_response_queue"] 228 | if module_kwargs.mode == "local": 229 | from connections.local_audio_streamer import LocalAudioStreamer 230 | 231 | local_audio_streamer = LocalAudioStreamer( 232 | input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue 233 | ) 234 | comms_handlers = [local_audio_streamer] 235 | 
should_listen.set() 236 | else: 237 | from connections.socket_receiver import SocketReceiver 238 | from connections.socket_sender import SocketSender 239 | 240 | comms_handlers = [ 241 | SocketReceiver( 242 | stop_event, 243 | recv_audio_chunks_queue, 244 | should_listen, 245 | host=socket_receiver_kwargs.recv_host, 246 | port=socket_receiver_kwargs.recv_port, 247 | chunk_size=socket_receiver_kwargs.chunk_size, 248 | ), 249 | SocketSender( 250 | stop_event, 251 | send_audio_chunks_queue, 252 | host=socket_sender_kwargs.send_host, 253 | port=socket_sender_kwargs.send_port, 254 | ), 255 | ] 256 | 257 | vad = VADHandler( 258 | stop_event, 259 | queue_in=recv_audio_chunks_queue, 260 | queue_out=spoken_prompt_queue, 261 | setup_args=(should_listen,), 262 | setup_kwargs=vars(vad_handler_kwargs), 263 | ) 264 | 265 | stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) 266 | lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs) 267 | tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs) 268 | 269 | return ThreadManager([*comms_handlers, vad, stt, lm, tts]) 270 | 271 | 272 | def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs): 273 | if module_kwargs.stt == "whisper": 274 | from STT.whisper_stt_handler import WhisperSTTHandler 275 | return WhisperSTTHandler( 276 | stop_event, 277 | queue_in=spoken_prompt_queue, 278 | queue_out=text_prompt_queue, 279 | setup_kwargs=vars(whisper_stt_handler_kwargs), 280 | ) 281 | elif module_kwargs.stt == "whisper-mlx": 282 | from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler 283 | return LightningWhisperSTTHandler( 284 | stop_event, 285 | queue_in=spoken_prompt_queue, 286 | queue_out=text_prompt_queue, 287 | setup_kwargs=vars(whisper_stt_handler_kwargs), 288 | ) 289 | elif module_kwargs.stt == "paraformer": 290 | from STT.paraformer_handler import ParaformerSTTHandler 291 | return ParaformerSTTHandler( 292 | stop_event, 293 | queue_in=spoken_prompt_queue, 294 | queue_out=text_prompt_queue, 295 | setup_kwargs=vars(paraformer_stt_handler_kwargs), 296 | ) 297 | else: 298 | raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.") 299 | 300 | 301 | def get_llm_handler( 302 | module_kwargs, 303 | stop_event, 304 | text_prompt_queue, 305 | lm_response_queue, 306 | language_model_handler_kwargs, 307 | open_api_language_model_handler_kwargs, 308 | mlx_language_model_handler_kwargs 309 | ): 310 | if module_kwargs.llm == "transformers": 311 | from LLM.language_model import LanguageModelHandler 312 | return LanguageModelHandler( 313 | stop_event, 314 | queue_in=text_prompt_queue, 315 | queue_out=lm_response_queue, 316 | setup_kwargs=vars(language_model_handler_kwargs), 317 | ) 318 | elif module_kwargs.llm == "open_api": 319 | from LLM.openai_api_language_model import OpenApiModelHandler 320 | return OpenApiModelHandler( 321 | stop_event, 322 | queue_in=text_prompt_queue, 323 | queue_out=lm_response_queue, 324 | setup_kwargs=vars(open_api_language_model_handler_kwargs), 325 | ) 326 | 327 | elif module_kwargs.llm == "mlx-lm": 328 | from LLM.mlx_language_model import 
MLXLanguageModelHandler 329 | return MLXLanguageModelHandler( 330 | stop_event, 331 | queue_in=text_prompt_queue, 332 | queue_out=lm_response_queue, 333 | setup_kwargs=vars(mlx_language_model_handler_kwargs), 334 | ) 335 | 336 | else: 337 | raise ValueError("The LLM should be either transformers or mlx-lm") 338 | 339 | 340 | def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs): 341 | if module_kwargs.tts == "parler": 342 | from TTS.parler_handler import ParlerTTSHandler 343 | return ParlerTTSHandler( 344 | stop_event, 345 | queue_in=lm_response_queue, 346 | queue_out=send_audio_chunks_queue, 347 | setup_args=(should_listen,), 348 | setup_kwargs=vars(parler_tts_handler_kwargs), 349 | ) 350 | elif module_kwargs.tts == "melo": 351 | try: 352 | from TTS.melo_handler import MeloTTSHandler 353 | except RuntimeError as e: 354 | logger.error( 355 | "Error importing MeloTTSHandler. You might need to run: python -m unidic download" 356 | ) 357 | raise e 358 | return MeloTTSHandler( 359 | stop_event, 360 | queue_in=lm_response_queue, 361 | queue_out=send_audio_chunks_queue, 362 | setup_args=(should_listen,), 363 | setup_kwargs=vars(melo_tts_handler_kwargs), 364 | ) 365 | elif module_kwargs.tts == "chatTTS": 366 | try: 367 | from TTS.chatTTS_handler import ChatTTSHandler 368 | except RuntimeError as e: 369 | logger.error("Error importing ChatTTSHandler") 370 | raise e 371 | return ChatTTSHandler( 372 | stop_event, 373 | queue_in=lm_response_queue, 374 | queue_out=send_audio_chunks_queue, 375 | setup_args=(should_listen,), 376 | setup_kwargs=vars(chat_tts_handler_kwargs), 377 | ) 378 | else: 379 | raise ValueError("The TTS should be either parler, melo or chatTTS") 380 | 381 | 382 | def main(): 383 | ( 384 | module_kwargs, 385 | socket_receiver_kwargs, 386 | socket_sender_kwargs, 387 | vad_handler_kwargs, 388 | whisper_stt_handler_kwargs, 389 | paraformer_stt_handler_kwargs, 390 | language_model_handler_kwargs, 391 | open_api_language_model_handler_kwargs, 392 | mlx_language_model_handler_kwargs, 393 | parler_tts_handler_kwargs, 394 | melo_tts_handler_kwargs, 395 | chat_tts_handler_kwargs, 396 | ) = parse_arguments() 397 | 398 | setup_logger(module_kwargs.log_level) 399 | 400 | prepare_all_args( 401 | module_kwargs, 402 | whisper_stt_handler_kwargs, 403 | paraformer_stt_handler_kwargs, 404 | language_model_handler_kwargs, 405 | open_api_language_model_handler_kwargs, 406 | mlx_language_model_handler_kwargs, 407 | parler_tts_handler_kwargs, 408 | melo_tts_handler_kwargs, 409 | chat_tts_handler_kwargs, 410 | ) 411 | 412 | queues_and_events = initialize_queues_and_events() 413 | 414 | pipeline_manager = build_pipeline( 415 | module_kwargs, 416 | socket_receiver_kwargs, 417 | socket_sender_kwargs, 418 | vad_handler_kwargs, 419 | whisper_stt_handler_kwargs, 420 | paraformer_stt_handler_kwargs, 421 | language_model_handler_kwargs, 422 | open_api_language_model_handler_kwargs, 423 | mlx_language_model_handler_kwargs, 424 | parler_tts_handler_kwargs, 425 | melo_tts_handler_kwargs, 426 | chat_tts_handler_kwargs, 427 | queues_and_events, 428 | ) 429 | 430 | try: 431 | pipeline_manager.start() 432 | except KeyboardInterrupt: 433 | pipeline_manager.stop() 434 | 435 | 436 | if __name__ == "__main__": 437 | main() -------------------------------------------------------------------------------- /utils/thread_manager.py: 
-------------------------------------------------------------------------------- 1 | import threading 2 | 3 | 4 | class ThreadManager: 5 | """ 6 | Manages multiple threads used to execute given handler tasks. 7 | """ 8 | 9 | def __init__(self, handlers): 10 | self.handlers = handlers 11 | self.threads = [] 12 | 13 | def start(self): 14 | for handler in self.handlers: 15 | thread = threading.Thread(target=handler.run) 16 | self.threads.append(thread) 17 | thread.start() 18 | 19 | def stop(self): 20 | for handler in self.handlers: 21 | handler.stop_event.set() 22 | for thread in self.threads: 23 | thread.join() 24 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def next_power_of_2(x): 5 | return 1 if x == 0 else 2 ** (x - 1).bit_length() 6 | 7 | 8 | def int2float(sound): 9 | """ 10 | Taken from https://github.com/snakers4/silero-vad 11 | """ 12 | 13 | abs_max = np.abs(sound).max() 14 | sound = sound.astype("float32") 15 | if abs_max > 0: 16 | sound *= 1 / 32768 17 | sound = sound.squeeze() # depends on the use case 18 | return sound 19 | --------------------------------------------------------------------------------
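
A brief usage sketch for the two helpers above (illustrative, not repository code):

    import numpy as np
    next_power_of_2(37)                      # 64 -- used by the Parler handler to pick padded prompt lengths
    pcm = np.array([0, 16384, -32768], dtype=np.int16)
    int2float(pcm)                           # array([ 0. ,  0.5, -1. ], dtype=float32)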
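
And a hedged end-to-end sketch of running the pipeline over sockets, pieced together from s2s_pipeline.py, listen_and_play.py and docker-compose.yml (flags shown are the defaults; `<server-ip>` is a placeholder for your machine's address):

    # on the machine with the GPU
    python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0

    # on the machine with the microphone and speakers
    python listen_and_play.py --host <server-ip> --send_port 12345 --recv_port 12346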