├── .gitignore
├── LLM
│   ├── __pycache__
│   │   ├── chat.cpython-310.pyc
│   │   └── mlx_lm.cpython-310.pyc
│   ├── chat.py
│   └── mlx_lm.py
├── README.md
├── STT
│   ├── __pycache__
│   │   └── lightning_whisper_mlx_handler.cpython-310.pyc
│   └── lightning_whisper_mlx_handler.py
├── TTS
│   ├── __pycache__
│   │   └── melotts.cpython-310.pyc
│   └── melotts.py
├── baseHandler.py
├── listen_and_play.py
├── local_audio_streamer.py
├── logo.png
├── requirements.txt
├── s2s_pipeline.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | tmp
--------------------------------------------------------------------------------
/LLM/__pycache__/chat.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/LLM/__pycache__/chat.cpython-310.pyc
--------------------------------------------------------------------------------
/LLM/__pycache__/mlx_lm.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/LLM/__pycache__/mlx_lm.cpython-310.pyc
--------------------------------------------------------------------------------
/LLM/chat.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | class Chat:
5 | """
6 | Handles the chat history, keeping a bounded number of turns to avoid OOM issues.
7 | """
8 |
9 | def __init__(self, size):
10 | self.size = size
11 | self.init_chat_message = None
12 | # the buffer length must stay even, since each new step adds a user prompt and an assistant answer
13 | self.buffer = []
14 |
15 | def append(self, item):
16 | self.buffer.append(item)
17 | if len(self.buffer) == 2 * (self.size + 1):
18 | self.buffer.pop(0)
19 | self.buffer.pop(0)
20 |
21 | def init_chat(self, init_chat_message):
22 | self.init_chat_message = init_chat_message
23 |
24 | def to_list(self):
25 | if self.init_chat_message:
26 | return [self.init_chat_message] + self.buffer
27 | else:
28 | return self.buffer
29 |
--------------------------------------------------------------------------------
/LLM/mlx_lm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from LLM.chat import Chat
3 | from baseHandler import BaseHandler
4 | from mlx_lm import load, stream_generate, generate
5 | from rich.console import Console
6 | import torch
7 | logging.basicConfig(
8 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
9 | )
10 | logger = logging.getLogger(__name__)
11 |
12 | console = Console()
13 |
14 | class MLXLanguageModelHandler(BaseHandler):
15 | """
16 | Handles the language model part.
17 | """
18 |
19 | def setup(
20 | self,
21 | model_name,
22 | device="mps",
23 | torch_dtype="float16",
24 | gen_kwargs={},
25 | user_role="user",
26 | chat_size=1,
27 | init_chat_role=None,
28 | #init_chat_prompt="You are a helpful AI assistant.",
29 | init_chat_prompt="あなたは日本語に堪能な友人です。英語を使ってはいけません。全て日本語で回答します. You must answer in Japanese",
30 | ):
31 | #model_name="mlx-community/Mistral-7B-Instruct-v0.3-4bit"
32 | #model_name="mlx-community/mlx-community/Phi-3-mini-4k-instruct-4bit"
33 | #model_name="mlx-community/Llama-3-Swallow-8B-Instruct-v0.1-4bit"
34 | #model_name="mlx-community/Llama-3-Swallow-8B-Instruct-v0.1-8bit"
35 | model_name="mlx-community/Llama-3-Swallow-8B-Instruct-v0.1-8bit"  # hard-coded: overrides the model_name argument
36 | print(model_name)
37 | self.model_name = model_name
38 | model_id = model_name#'microsoft/Phi-3-mini-4k-instruct'
39 | self.model, self.tokenizer = load(model_id)
40 | self.gen_kwargs = gen_kwargs
41 |
42 | self.chat = Chat(chat_size)
43 | #init_chat_role=None
44 | init_chat_prompt="あなたは日本語に堪能な友人です。英語を使ってはいけません。全て日本語で回答します. You must answer in Japanese"
45 | if init_chat_role:
46 | if not init_chat_prompt:
47 | raise ValueError(
48 | "An initial promt needs to be specified when setting init_chat_role."
49 | )
50 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
51 | self.user_role = user_role
52 |
53 | self.warmup()
54 |
55 | def warmup(self):
56 | logger.info(f"Warming up {self.__class__.__name__}")
57 | return  # warmup disabled: the dummy generation below is never executed
58 |
59 | dummy_input_text = "Write me a poem about Machine Learning."
60 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
61 |
62 | n_steps = 2
63 |
64 | for _ in range(n_steps):
65 | prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
66 | generate(self.model, self.tokenizer, prompt=prompt, max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False)
67 |
68 |
69 | def process(self, prompt):
70 | logger.debug("infering language model...")
71 |
72 |
73 | self.chat.append({"role": self.user_role, "content": f"{prompt}, "})
74 | prompt = self.tokenizer.apply_chat_template(self.chat.to_list(), tokenize=False, add_generation_prompt=True)
75 | output = ""
76 | curr_output = ""
77 | print(self.chat.to_list())
78 | for t in stream_generate(self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]):
79 | output += t
80 | curr_output += t
81 | if curr_output.endswith(('.', '?','?',',','。','!','<|end|>','<|eot_id|>')):
82 | yield curr_output.replace('<|end|>', '')
83 | curr_output = ""
84 | print(f"cur:{curr_output}")
85 | generated_text = output.replace('<|end|>', '')
86 | print(f"generated_text:{generated_text}")
87 | torch.mps.empty_cache()
88 |
89 | self.chat.append({"role": "assistant", "content": generated_text})
90 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | # Speech To Speech: an effort for an open-sourced and modular GPT-4o
7 |
8 | Forked from: [https://github.com/eustlb/speech-to-speech](https://github.com/eustlb/speech-to-speech)
9 |
10 | # Japanese support
11 |
12 | Tested with Python 3.10.
13 |
14 | ```bash
15 | git clone https://github.com/shi3z/speech-to-speech-japanese.git
16 | cd speech-to-speech-japanese
17 | pip install git+https://github.com/nltk/nltk.git@3.8.2
18 | git clone https://github.com/reazon-research/ReazonSpeech
19 | pip install Cython
20 | pip install ReazonSpeech/pkg/nemo-asr
21 | git clone https://github.com/myshell-ai/MeloTTS
22 | cd MeloTTS
23 | pip install -e .
24 | python -m unidic download
25 | cd ..
26 | pip install -r requirements.txt
27 | pip install transformers==4.44.1
28 | pip install mlx-lm
29 | pip install protobuf --upgrade
30 | python s2s_pipeline.py --mode local --device mps
31 | ```
32 | Tested on a MacBook Pro M2 Max (32 GB).
33 | Also confirmed working on a MacBook M1 (16 GB) by @necobit (https://github.com/necobit).
34 |
35 | ## 📖 Quick Index
36 | * [Approach](#approach)
37 | - [Structure](#structure)
38 | - [Modularity](#modularity)
39 | * [Setup](#setup)
40 | * [Usage](#usage)
41 | - [Server/Client approach](#serverclient-approach)
42 | - [Local approach](#local-approach)
43 | * [Command-line usage](#command-line-usage)
44 | - [Model parameters](#model-parameters)
45 | - [Generation parameters](#generation-parameters)
46 | - [Notable parameters](#notable-parameters)
47 |
48 | ## Approach
49 |
50 | ### Structure
51 | This repository implements a speech-to-speech cascaded pipeline with consecutive parts:
52 | 1. **Voice Activity Detection (VAD)**: [silero VAD v5](https://github.com/snakers4/silero-vad)
53 | 2. **Speech to Text (STT)**: Whisper checkpoints (including [distilled versions](https://huggingface.co/distil-whisper))
54 | 3. **Language Model (LM)**: Any instruct model available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)! 🤗
55 | 4. **Text to Speech (TTS)**: [Parler-TTS](https://github.com/huggingface/parler-tts)🤗
56 |
57 | ### Modularity
58 | The pipeline aims to provide a fully open and modular approach, leveraging models available through the Transformers library on the Hugging Face Hub. The level of modularity intended for each part is as follows:
59 | - **VAD**: Uses the implementation from [Silero's repo](https://github.com/snakers4/silero-vad).
60 | - **STT**: Uses Whisper models exclusively; however, any Whisper checkpoint can be used, enabling options like [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v3) and [French Distil-Whisper](https://huggingface.co/eustlb/distil-large-v3-fr).
61 | - **LM**: This part is fully modular and can be changed by simply modifying the Hugging Face hub model ID. Users need to select an instruct model since the usage here involves interacting with it.
62 | - **TTS**: The mini architecture of Parler-TTS is standard, but different checkpoints, including fine-tuned multilingual checkpoints, can be used.
63 |
64 | The code is designed to facilitate easy modification. Each component is implemented as a class and can be re-implemented to match specific needs.
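For illustration, here is a minimal sketch of a custom pipeline part following the `baseHandler.BaseHandler` contract (the class name and its behavior are hypothetical, not part of the repository): `setup` receives the handler's `setup_kwargs`, and `process` is a generator whose yielded items are pushed onto the handler's output queue by `run`.

```python
from baseHandler import BaseHandler


class UppercaseEchoHandler(BaseHandler):
    """Toy handler: reads text items from its input queue and yields them back uppercased."""

    def setup(self, prefix="ECHO: "):
        # BaseHandler.__init__ forwards setup_args / setup_kwargs here
        self.prefix = prefix

    def process(self, text):
        # process() is a generator; each yielded item is placed on the output queue by run()
        yield self.prefix + text.upper()
```

Such a handler is wired between two queues and started alongside the others by the `ThreadManager` in `s2s_pipeline.py`.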
65 |
66 | ## Setup
67 |
68 | Clone the repository:
69 | ```bash
70 | git clone https://github.com/eustlb/speech-to-speech.git
71 | cd speech-to-speech
72 | ```
73 |
74 | Install the required dependencies:
75 | ```bash
76 | pip install -r requirements.txt
77 | ```
78 |
79 | ## Usage
80 |
81 | The pipeline can be run in two ways:
82 | - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
83 | - **Local approach**: Uses the same client/server method but with the loopback address.
84 |
85 | ### Server/Client Approach
86 |
87 | To run the pipeline on the server:
88 | ```bash
89 | python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
90 | ```
91 |
92 | Then run the client locally to handle sending microphone input and receiving generated audio:
93 | ```bash
94 | python listen_and_play.py --host <IP address of your server>
95 | ```
96 |
97 | ### Local Approach
98 | Simply use the loopback address:
99 | ```bash
100 | python s2s_pipeline.py --recv_host localhost --send_host localhost
101 | python listen_and_play.py --host localhost
102 | ```
103 |
104 | You can pass `--device mps` to run it locally on a Mac.
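For example, on Apple silicon the pipeline can be started entirely locally (this is the same command used in the Japanese setup section above):

```bash
python s2s_pipeline.py --mode local --device mps
```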
105 |
106 | ### Recommended usage
107 |
108 | Leverage Torch Compile for Whisper and Parler-TTS:
109 |
110 | ```bash
111 | python s2s_pipeline.py \
112 | --recv_host 0.0.0.0 \
113 | --send_host 0.0.0.0 \
114 | --lm_model_name microsoft/Phi-3-mini-4k-instruct \
115 | --init_chat_role system \
116 | --stt_compile_mode reduce-overhead \
117 | --tts_compile_mode default
118 | ```
119 |
120 | For the moment, compile modes that capture CUDA Graphs (`reduce-overhead`, `max-autotune`) are not compatible with streaming Parler-TTS.
121 |
122 | ## Command-line Usage
123 |
124 | ### Model Parameters
125 |
126 | `model_name`, `torch_dtype`, and `device` are exposed for each part that leverages a Transformers implementation: Speech to Text, Language Model, and Text to Speech. Specify the targeted pipeline part with the corresponding prefix:
127 | - `stt` (Speech to Text)
128 | - `lm` (Language Model)
129 | - `tts` (Text to Speech)
130 |
131 | For example:
132 | ```bash
133 | --lm_model_name google/gemma-2b-it
134 | ```
135 |
136 | ### Generation Parameters
137 |
138 | Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed (see `LanguageModelHandlerArguments` for example).
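For example, the following caps generation length for both the Whisper and language-model parts (the values shown are illustrative):

```bash
python s2s_pipeline.py --stt_gen_max_new_tokens 128 --lm_gen_max_new_tokens 128
```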
139 |
140 | ### Notable Parameters
141 |
142 | #### VAD Parameters
143 | - `--thresh`: Threshold value to trigger voice activity detection.
144 | - `--min_speech_ms`: Minimum duration of detected voice activity to be considered speech.
145 | - `--min_silence_ms`: Minimum length of silence intervals for segmenting speech, balancing sentence cutting and latency reduction.
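For example, the following makes the VAD defaults from `VADHandlerArguments` explicit:

```bash
python s2s_pipeline.py --thresh 0.3 --min_speech_ms 500 --min_silence_ms 250
```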
146 |
147 | #### Language Model
148 | - `--init_chat_role`: Defaults to `None`. Sets the initial role in the chat template, if applicable. Refer to the model's card to set this value (e.g. for [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) you have to set `--init_chat_role system`)
149 | - `--init_chat_prompt`: Defaults to `"You are a helpful AI assistant."` Required when setting `--init_chat_role`.
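For example, for a model whose chat template expects a system turn:

```bash
python s2s_pipeline.py --init_chat_role system --init_chat_prompt "You are a helpful AI assistant."
```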
150 |
151 | #### Text to Speech
152 | - `--description`: Sets the description for Parler-TTS generated voice. Defaults to: `"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."`
153 |
154 | - `--play_steps_s`: Specifies the duration of the first chunk sent during streaming output from Parler-TTS, impacting readiness and decoding steps.
155 |
156 | ## Citations
157 |
158 | ### Silero VAD
159 | ```bibtex
160 | @misc{SileroVAD,
161 | author = {Silero Team},
162 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
163 | year = {2021},
164 | publisher = {GitHub},
165 | journal = {GitHub repository},
166 | howpublished = {\url{https://github.com/snakers4/silero-vad}},
167 | commit = {insert_some_commit_here},
168 | email = {hello@silero.ai}
169 | }
170 | ```
171 |
172 | ### Distil-Whisper
173 | ```bibtex
174 | @misc{gandhi2023distilwhisper,
175 | title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
176 | author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
177 | year={2023},
178 | eprint={2311.00430},
179 | archivePrefix={arXiv},
180 | primaryClass={cs.CL}
181 | }
182 | ```
183 |
184 | ### Parler-TTS
185 | ```bibtex
186 | @misc{lacombe-etal-2024-parler-tts,
187 | author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},
188 | title = {Parler-TTS},
189 | year = {2024},
190 | publisher = {GitHub},
191 | journal = {GitHub repository},
192 | howpublished = {\url{https://github.com/huggingface/parler-tts}}
193 | }
194 | ```
195 |
--------------------------------------------------------------------------------
/STT/__pycache__/lightning_whisper_mlx_handler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/STT/__pycache__/lightning_whisper_mlx_handler.cpython-310.pyc
--------------------------------------------------------------------------------
/STT/lightning_whisper_mlx_handler.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from time import perf_counter
3 | from baseHandler import BaseHandler
4 | from lightning_whisper_mlx import LightningWhisperMLX
5 | import numpy as np
6 | from rich.console import Console
7 | import torch
8 | from reazonspeech.nemo.asr import load_model, transcribe, audio_from_path,audio_from_numpy
9 |
10 |
11 | logging.basicConfig(
12 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
13 | )
14 | logger = logging.getLogger(__name__)
15 |
16 | console = Console()
17 |
18 |
19 |
20 | class LightningWhisperSTTHandler(BaseHandler):
21 | """
22 | Handles the Speech To Text part. In this fork the Lightning Whisper MLX model is replaced by ReazonSpeech (NeMo ASR) for Japanese transcription.
23 | """
24 |
25 | def setup(
26 | self,
27 | model_name="distil-whisper/distil-large-v3",
28 | device="cuda",
29 | torch_dtype="float16",
30 | compile_mode=None,
31 | gen_kwargs={},
32 | ):
33 | self.device = device
34 | #self.model = LightningWhisperMLX(
35 | # model="distil-large-v3", batch_size=6, quant=None
36 | #)
37 | self.model=load_model()
38 |
39 | self.warmup()
40 |
41 | def warmup(self):
42 | logger.info(f"Warming up {self.__class__.__name__}")
43 | return  # warmup disabled: the dummy transcription below is never executed
44 |
45 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture
46 | n_steps = 1
47 | dummy_input = np.array([0] * 512)
48 |
49 | for _ in range(n_steps):
50 | _ = self.model.transcribe(dummy_input)["text"].strip()
51 |
52 | def process(self, spoken_prompt):
53 | logger.debug("infering whisper...")
54 |
55 | global pipeline_start
56 | pipeline_start = perf_counter()
57 | RATE = 16000
58 | audio = audio_from_numpy(spoken_prompt, RATE)
59 |
60 | #pred_text = self.model.transcribe(spoken_prompt,language="ja")["text"].strip()
61 | result = transcribe(self.model, audio)
62 | pred_text = result.text
63 | torch.mps.empty_cache()
64 |
65 | logger.debug("finished whisper inference")
66 | console.print(f"[yellow]USER: {pred_text}")
67 |
68 | yield pred_text
69 |
--------------------------------------------------------------------------------
/TTS/__pycache__/melotts.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/TTS/__pycache__/melotts.cpython-310.pyc
--------------------------------------------------------------------------------
/TTS/melotts.py:
--------------------------------------------------------------------------------
1 | from MeloTTS.melo.api import TTS
2 | import logging
3 | from baseHandler import BaseHandler
4 | import librosa
5 | import numpy as np
6 | from rich.console import Console
7 | import torch
8 |
9 | logging.basicConfig(
10 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
11 | )
12 | logger = logging.getLogger(__name__)
13 |
14 | console = Console()
15 |
16 |
17 | class MeloTTSHandler(BaseHandler):
18 | def setup(
19 | self,
20 | should_listen,
21 | device="mps",
22 | #language="EN_NEWEST",
23 | language="JP",
24 | blocksize=512,
25 | ):
26 | self.should_listen = should_listen
27 | self.device = device
28 | self.model = TTS(language=language, device=device)
29 | self.speaker_id = self.model.hps.data.spk2id[language]
30 | self.blocksize = blocksize
31 | self.warmup()
32 |
33 | def warmup(self):
34 | logger.info(f"Warming up {self.__class__.__name__}")
35 | _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
36 |
37 | def process(self, llm_sentence):
38 | llm_sentence = llm_sentence.replace("AI","エーアイ")
39 | llm_sentence = llm_sentence.replace("Python","パイソン")
40 | llm_sentence = llm_sentence.replace("A","エー")
41 | llm_sentence = llm_sentence.replace("B","ビー")
42 | llm_sentence = llm_sentence.replace("C","シー")
43 | llm_sentence = llm_sentence.replace("D","ディー")
44 | llm_sentence = llm_sentence.replace("E","イー")
45 | llm_sentence = llm_sentence.replace("F","エフ")
46 | llm_sentence = llm_sentence.replace("G","ジー")
47 | llm_sentence = llm_sentence.replace("H","エイチ")
48 | llm_sentence = llm_sentence.replace("I","アイ")
49 | llm_sentence = llm_sentence.replace("J","ジェイ")
50 | llm_sentence = llm_sentence.replace("K","ケイ")
51 | llm_sentence = llm_sentence.replace("L","エル")
52 | llm_sentence = llm_sentence.replace("M","エム")
53 | llm_sentence = llm_sentence.replace("N","エヌ")
54 | llm_sentence = llm_sentence.replace("O","オー")
55 | llm_sentence = llm_sentence.replace("P","ピー")
56 | llm_sentence = llm_sentence.replace("Q","キュー")
57 | llm_sentence = llm_sentence.replace("R","アール")
58 | llm_sentence = llm_sentence.replace("S","エス")
59 | llm_sentence = llm_sentence.replace("T","ティー")
60 | llm_sentence = llm_sentence.replace("U","ユー")
61 | llm_sentence = llm_sentence.replace("V","ブイ")
62 | llm_sentence = llm_sentence.replace("W","ダブリュー")
63 | llm_sentence = llm_sentence.replace("X","エックス")
64 | llm_sentence = llm_sentence.replace("Y","ワイ")
65 | llm_sentence = llm_sentence.replace("Z","ゼット")
66 | console.print(f"[green]ASSISTANT: {llm_sentence}")
67 | if self.device == "mps":
68 | import time
69 | start = time.time()
70 | torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
71 | torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
72 | time_it_took = time.time()-start # Removing this line makes it fail more often. I'm looking into it.
73 |
74 | audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
75 | if len(audio_chunk) == 0:
76 | self.should_listen.set()
77 | return
78 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
79 | audio_chunk = (audio_chunk * 32768).astype(np.int16)
80 | for i in range(0, len(audio_chunk), self.blocksize):
81 | yield np.pad(
82 | audio_chunk[i : i + self.blocksize],
83 | (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
84 | )
85 |
86 | self.should_listen.set()
87 |
--------------------------------------------------------------------------------
/baseHandler.py:
--------------------------------------------------------------------------------
1 | from time import perf_counter
2 | import logging
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 |
7 | class BaseHandler:
8 | """
9 | Base class for pipeline parts. Each part of the pipeline has an input and an output queue.
10 | The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part.
11 | To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue.
12 | Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue.
13 | The cleanup method handles stopping the handler, and b"END" is placed in the output queue.
14 | """
15 |
16 | def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}):
17 | self.stop_event = stop_event
18 | self.queue_in = queue_in
19 | self.queue_out = queue_out
20 | self.setup(*setup_args, **setup_kwargs)
21 | self._times = []
22 |
23 | def setup(self):
24 | pass
25 |
26 | def process(self):
27 | raise NotImplementedError
28 |
29 | def run(self):
30 | while not self.stop_event.is_set():
31 | input = self.queue_in.get()
32 | if isinstance(input, bytes) and input == b"END":
33 | # sentinel signal to avoid queue deadlock
34 | logger.debug("Stopping thread")
35 | break
36 | start_time = perf_counter()
37 | for output in self.process(input):
38 | self._times.append(perf_counter() - start_time)
39 | logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
40 | self.queue_out.put(output)
41 | start_time = perf_counter()
42 |
43 | self.cleanup()
44 | self.queue_out.put(b"END")
45 |
46 | @property
47 | def last_time(self):
48 | return self._times[-1]
49 |
50 | def cleanup(self):
51 | pass
52 |
--------------------------------------------------------------------------------
/listen_and_play.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import numpy as np
3 | import threading
4 | from queue import Queue
5 | from dataclasses import dataclass, field
6 | import sounddevice as sd
7 | from transformers import HfArgumentParser
8 |
9 |
10 | @dataclass
11 | class ListenAndPlayArguments:
12 | send_rate: int = field(
13 | default=16000,
14 | metadata={
15 | "help": "In Hz. Default is 16000."
16 | }
17 | )
18 | recv_rate: int = field(
19 | default=44100,
20 | metadata={
21 | "help": "In Hz. Default is 44100."
22 | }
23 | )
24 | list_play_chunk_size: int = field(
25 | default=1024,
26 | metadata={
27 | "help": "The size of data chunks (in bytes). Default is 1024."
28 | }
29 | )
30 | host: str = field(
31 | default="localhost",
32 | metadata={
33 | "help": "The hostname or IP address for listening and playing. Default is 'localhost'."
34 | }
35 | )
36 | send_port: int = field(
37 | default=12345,
38 | metadata={
39 | "help": "The network port for sending data. Default is 12345."
40 | }
41 | )
42 | recv_port: int = field(
43 | default=12346,
44 | metadata={
45 | "help": "The network port for receiving data. Default is 12346."
46 | }
47 | )
48 |
49 |
50 | def listen_and_play(
51 | send_rate=16000,
52 | recv_rate=44100,
53 | list_play_chunk_size=1024,
54 | host="localhost",
55 | send_port=12345,
56 | recv_port=12346,
57 | ):
58 |
59 | send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
60 | send_socket.connect((host, send_port))
61 |
62 | recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
63 | recv_socket.connect((host, recv_port))
64 |
65 | print("Recording and streaming...")
66 |
67 | stop_event = threading.Event()
68 | recv_queue = Queue()
69 | send_queue = Queue()
70 |
71 | def callback_recv(outdata, frames, time, status):
72 | if not recv_queue.empty():
73 | data = recv_queue.get()
74 | outdata[:len(data)] = data
75 | outdata[len(data):] = b'\x00' * (len(outdata) - len(data))
76 | else:
77 | outdata[:] = b'\x00' * len(outdata)
78 |
79 | def callback_send(indata, frames, time, status):
80 | if recv_queue.empty():  # only capture mic input while nothing is queued for playback
81 | data = bytes(indata)
82 | send_queue.put(data)
83 |
84 | def send(stop_event, send_queue):
85 | while not stop_event.is_set():
86 | data = send_queue.get()
87 | send_socket.sendall(data)
88 |
89 | def recv(stop_event, recv_queue):
90 |
91 | def receive_full_chunk(conn, chunk_size):
92 | data = b''
93 | while len(data) < chunk_size:
94 | packet = conn.recv(chunk_size - len(data))
95 | if not packet:
96 | return None # Connection has been closed
97 | data += packet
98 | return data
99 |
100 | while not stop_event.is_set():
101 | data = receive_full_chunk(recv_socket, list_play_chunk_size * 2)
102 | if data:
103 | recv_queue.put(data)
104 |
105 | try:
106 | send_stream = sd.RawInputStream(samplerate=send_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_send)
107 | recv_stream = sd.RawOutputStream(samplerate=recv_rate, channels=1, dtype='int16', blocksize=list_play_chunk_size, callback=callback_recv)
108 | threading.Thread(target=send_stream.start).start()
109 | threading.Thread(target=recv_stream.start).start()
110 |
111 | send_thread = threading.Thread(target=send, args=(stop_event, send_queue))
112 | send_thread.start()
113 | recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue))
114 | recv_thread.start()
115 |
116 | input("Press Enter to stop...")
117 |
118 | except KeyboardInterrupt:
119 | print("Finished streaming.")
120 |
121 | finally:
122 | stop_event.set()
123 | recv_thread.join()
124 | send_thread.join()
125 | send_socket.close()
126 | recv_socket.close()
127 | print("Connection closed.")
128 |
129 |
130 | if __name__ == "__main__":
131 | parser = HfArgumentParser((ListenAndPlayArguments,))
132 | listen_and_play_kwargs, = parser.parse_args_into_dataclasses()
133 | listen_and_play(**vars(listen_and_play_kwargs))
134 |
135 |
--------------------------------------------------------------------------------
/local_audio_streamer.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import sounddevice as sd
3 | import numpy as np
4 |
5 | import time
6 |
7 |
8 | class LocalAudioStreamer:
9 | def __init__(
10 | self,
11 | input_queue,
12 | output_queue,
13 | list_play_chunk_size=512,
14 | ):
15 | self.list_play_chunk_size = list_play_chunk_size
16 |
17 | self.stop_event = threading.Event()
18 | self.input_queue = input_queue
19 | self.output_queue = output_queue
20 |
21 | def run(self):
22 | def callback(indata, outdata, frames, time, status):
23 | if self.output_queue.empty():
24 | self.input_queue.put(indata.copy())
25 | outdata[:] = 0 * outdata
26 | else:
27 | outdata[:] = self.output_queue.get()[:, np.newaxis]
28 |
29 | with sd.Stream(
30 | samplerate=16000,
31 | dtype="int16",
32 | channels=1,
33 | callback=callback,
34 | blocksize=self.list_play_chunk_size,
35 | ):
36 | while not self.stop_event.is_set():
37 | time.sleep(0.001)
38 | print("Stopping recording")
39 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shi3z/speech-to-speech-japanese/d3b394287d07684a155905084df708e8d53fb7c6/logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk==3.8.1
2 | parler_tts @ git+https://github.com/huggingface/parler-tts.git
3 | torch==2.4.0
4 | sounddevice==0.5.0
5 |
--------------------------------------------------------------------------------
/s2s_pipeline.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import socket
4 | import sys
5 | import threading
6 | from copy import copy
7 | from dataclasses import dataclass, field
8 | from pathlib import Path
9 | from queue import Queue
10 | from threading import Event, Thread
11 | from time import perf_counter
12 | from typing import Optional
13 |
14 | from LLM.mlx_lm import MLXLanguageModelHandler
15 | from TTS.melotts import MeloTTSHandler
16 | from baseHandler import BaseHandler
17 | from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
18 | import numpy as np
19 | import torch
20 | from nltk.tokenize import sent_tokenize
21 | from rich.console import Console
22 | from transformers import (
23 | AutoModelForCausalLM,
24 | AutoModelForSpeechSeq2Seq,
25 | AutoProcessor,
26 | AutoTokenizer,
27 | HfArgumentParser,
28 | pipeline,
29 | TextIteratorStreamer,
30 | )
31 | #from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
32 | import librosa
33 |
34 | from local_audio_streamer import LocalAudioStreamer
35 | from utils import VADIterator, int2float, next_power_of_2
36 |
37 | # Ensure that the necessary NLTK resources are available
38 | # try:
39 | # nltk.data.find('tokenizers/punkt_tab')
40 | # except LookupError:
41 | # nltk.download('punkt_tab')
42 |
43 | # caching allows ~50% compilation time reduction
44 | # see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
45 | CURRENT_DIR = Path(__file__).resolve().parent
46 | os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
47 | # torch._inductor.config.fx_graph_cache = True
48 | # # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
49 | # torch._dynamo.config.cache_size_limit = 15
50 |
51 |
52 | console = Console()
53 |
54 |
55 | @dataclass
56 | class ModuleArguments:
57 | device: Optional[str] = field(
58 | default=None,
59 | metadata={"help": "If specified, overrides the device for all handlers."},
60 | )
61 | mode: Optional[str] = field(
62 | default="local",
63 | metadata={
64 | "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
65 | },
66 | )
67 | log_level: str = field(
68 | default="info",
69 | metadata={
70 | "help": "Provide logging level. Example --log_level debug, default=warning."
71 | },
72 | )
73 |
74 |
75 | class ThreadManager:
76 | """
77 | Manages multiple threads used to execute given handler tasks.
78 | """
79 |
80 | def __init__(self, handlers):
81 | self.handlers = handlers
82 | self.threads = []
83 |
84 | def start(self):
85 | for handler in self.handlers:
86 | thread = threading.Thread(target=handler.run)
87 | self.threads.append(thread)
88 | thread.start()
89 |
90 | def stop(self):
91 | for handler in self.handlers:
92 | handler.stop_event.set()
93 | for thread in self.threads:
94 | thread.join()
95 |
96 |
97 | @dataclass
98 | class SocketReceiverArguments:
99 | recv_host: str = field(
100 | default="localhost",
101 | metadata={
102 | "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
103 | "available interfaces on the host machine."
104 | },
105 | )
106 | recv_port: int = field(
107 | default=12345,
108 | metadata={
109 | "help": "The port number on which the socket server listens. Default is 12346."
110 | },
111 | )
112 | chunk_size: int = field(
113 | default=1024,
114 | metadata={
115 | "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
116 | },
117 | )
118 |
119 |
120 | class SocketReceiver:
121 | """
122 | Handles reception of the audio packets from the client.
123 | """
124 |
125 | def __init__(
126 | self,
127 | stop_event,
128 | queue_out,
129 | should_listen,
130 | host="0.0.0.0",
131 | port=12345,
132 | chunk_size=1024,
133 | ):
134 | self.stop_event = stop_event
135 | self.queue_out = queue_out
136 | self.should_listen = should_listen
137 | self.chunk_size = chunk_size
138 | self.host = host
139 | self.port = port
140 |
141 | def receive_full_chunk(self, conn, chunk_size):
142 | data = b""
143 | while len(data) < chunk_size:
144 | packet = conn.recv(chunk_size - len(data))
145 | if not packet:
146 | # connection closed
147 | return None
148 | data += packet
149 | return data
150 |
151 | def run(self):
152 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
153 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
154 | self.socket.bind((self.host, self.port))
155 | self.socket.listen(1)
156 | logger.info("Receiver waiting to be connected...")
157 | self.conn, _ = self.socket.accept()
158 | logger.info("receiver connected")
159 |
160 | self.should_listen.set()
161 | while not self.stop_event.is_set():
162 | audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
163 | if audio_chunk is None:
164 | # connection closed
165 | self.queue_out.put(b"END")
166 | break
167 | if self.should_listen.is_set():
168 | self.queue_out.put(audio_chunk)
169 | self.conn.close()
170 | logger.info("Receiver closed")
171 |
172 |
173 | @dataclass
174 | class SocketSenderArguments:
175 | send_host: str = field(
176 | default="localhost",
177 | metadata={
178 | "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
179 | "available interfaces on the host machine."
180 | },
181 | )
182 | send_port: int = field(
183 | default=12346,
184 | metadata={
185 | "help": "The port number on which the socket server listens. Default is 12346."
186 | },
187 | )
188 |
189 |
190 | class SocketSender:
191 | """
192 | Handles sending generated audio packets to the clients.
193 | """
194 |
195 | def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
196 | self.stop_event = stop_event
197 | self.queue_in = queue_in
198 | self.host = host
199 | self.port = port
200 |
201 | def run(self):
202 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
203 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
204 | self.socket.bind((self.host, self.port))
205 | self.socket.listen(1)
206 | logger.info("Sender waiting to be connected...")
207 | self.conn, _ = self.socket.accept()
208 | logger.info("sender connected")
209 |
210 | while not self.stop_event.is_set():
211 | audio_chunk = self.queue_in.get()
212 | self.conn.sendall(audio_chunk)
213 | if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
214 | break
215 | self.conn.close()
216 | logger.info("Sender closed")
217 |
218 |
219 | @dataclass
220 | class VADHandlerArguments:
221 | thresh: float = field(
222 | default=0.3,
223 | metadata={
224 | "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
225 | },
226 | )
227 | sample_rate: int = field(
228 | default=16000,
229 | metadata={
230 | "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
231 | },
232 | )
233 | min_silence_ms: int = field(
234 | default=250,
235 | metadata={
236 | "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
237 | },
238 | )
239 | min_speech_ms: int = field(
240 | default=500,
241 | metadata={
242 | "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
243 | },
244 | )
245 | max_speech_ms: float = field(
246 | default=float("inf"),
247 | metadata={
248 | "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
249 | },
250 | )
251 | speech_pad_ms: int = field(
252 | default=250,
253 | metadata={
254 | "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
255 | },
256 | )
257 |
258 |
259 | class VADHandler(BaseHandler):
260 | """
261 | Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
262 | to the following part.
263 | """
264 |
265 | def setup(
266 | self,
267 | should_listen,
268 | thresh=0.3,
269 | sample_rate=16000,
270 | min_silence_ms=1000,
271 | min_speech_ms=500,
272 | max_speech_ms=float("inf"),
273 | speech_pad_ms=30,
274 | ):
275 | self.should_listen = should_listen
276 | self.sample_rate = sample_rate
277 | self.min_silence_ms = min_silence_ms
278 | self.min_speech_ms = min_speech_ms
279 | self.max_speech_ms = max_speech_ms
280 | self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
281 | self.iterator = VADIterator(
282 | self.model,
283 | threshold=thresh,
284 | sampling_rate=sample_rate,
285 | min_silence_duration_ms=min_silence_ms,
286 | speech_pad_ms=speech_pad_ms,
287 | )
288 |
289 | def process(self, audio_chunk):
290 | audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
291 | audio_float32 = int2float(audio_int16)
292 | vad_output = self.iterator(torch.from_numpy(audio_float32))
293 | if vad_output is not None and len(vad_output) != 0:
294 | logger.debug("VAD: end of speech detected")
295 | array = torch.cat(vad_output).cpu().numpy()
296 | duration_ms = len(array) / self.sample_rate * 1000
297 | if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
298 | logger.debug(
299 | f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
300 | )
301 | else:
302 | self.should_listen.clear()
303 | logger.debug("Stop listening")
304 | yield array
305 |
306 |
307 | @dataclass
308 | class WhisperSTTHandlerArguments:
309 | stt_model_name: str = field(
310 | default="distil-whisper/distil-large-v3",
311 | metadata={
312 | "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
313 | },
314 | )
315 | stt_device: str = field(
316 | default="cuda",
317 | metadata={
318 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
319 | },
320 | )
321 | stt_torch_dtype: str = field(
322 | default="float16",
323 | metadata={
324 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
325 | },
326 | )
327 | stt_compile_mode: str = field(
328 | default=None,
329 | metadata={
330 | "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
331 | },
332 | )
333 | stt_gen_max_new_tokens: int = field(
334 | default=128,
335 | metadata={
336 | "help": "The maximum number of new tokens to generate. Default is 128."
337 | },
338 | )
339 | stt_gen_num_beams: int = field(
340 | default=1,
341 | metadata={
342 | "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
343 | },
344 | )
345 | stt_gen_return_timestamps: bool = field(
346 | default=False,
347 | metadata={
348 | "help": "Whether to return timestamps with transcriptions. Default is False."
349 | },
350 | )
351 | # stt_gen_task: str = field(
352 | # default="transcribe",
353 | # metadata={
354 | # "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
355 | # },
356 | # )
357 | stt_gen_language: str = field(
358 | default="jp",
359 | metadata={
360 | "help": "The language of the speech to transcribe. Default is 'en' for English."
361 | },
362 | )
363 |
364 |
365 | class WhisperSTTHandler(BaseHandler):
366 | """
367 | Handles the Speech To Text generation using a Whisper model.
368 | """
369 |
370 | def setup(
371 | self,
372 | model_name="distil-whisper/distil-large-v3",
373 | device="cuda",
374 | torch_dtype="float16",
375 | compile_mode=None,
376 | gen_kwargs={},
377 | ):
378 | self.device = device
379 | self.torch_dtype = getattr(torch, torch_dtype)
380 | self.compile_mode = compile_mode
381 | self.gen_kwargs = gen_kwargs
382 |
383 | self.processor = AutoProcessor.from_pretrained(model_name)
384 | self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
385 | model_name,
386 | torch_dtype=self.torch_dtype,
387 | ).to(device)
388 |
389 | # compile
390 | if self.compile_mode:
391 | self.model.generation_config.cache_implementation = "static"
392 | self.model.forward = torch.compile(
393 | self.model.forward, mode=self.compile_mode, fullgraph=True
394 | )
395 | self.warmup()
396 |
397 | def prepare_model_inputs(self, spoken_prompt):
398 | input_features = self.processor(
399 | spoken_prompt, sampling_rate=16000, return_tensors="pt"
400 | ).input_features
401 | input_features = input_features.to(self.device, dtype=self.torch_dtype)
402 |
403 | return input_features
404 |
405 | def warmup(self):
406 | logger.info(f"Warming up {self.__class__.__name__}")
407 |
408 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture
409 | n_steps = 1 if self.compile_mode == "default" else 2
410 | dummy_input = torch.randn(
411 | (1, self.model.config.num_mel_bins, 3000),
412 | dtype=self.torch_dtype,
413 | device=self.device,
414 | )
415 | if self.compile_mode not in (None, "default"):
416 | # generating more tokens than previously will trigger CUDA graphs capture
417 | # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
418 | warmup_gen_kwargs = {
419 | "min_new_tokens": self.gen_kwargs["max_new_tokens"],
420 | "max_new_tokens": self.gen_kwargs["max_new_tokens"],
421 | **self.gen_kwargs,
422 | }
423 | else:
424 | warmup_gen_kwargs = self.gen_kwargs
425 |
426 | if self.device == "cuda":
427 | start_event = torch.cuda.Event(enable_timing=True)
428 | end_event = torch.cuda.Event(enable_timing=True)
429 | torch.cuda.synchronize()
430 | start_event.record()
431 |
432 | for _ in range(n_steps):
433 | _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
434 |
435 | if self.device == "cuda":
436 | end_event.record()
437 | torch.cuda.synchronize()
438 |
439 | logger.info(
440 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
441 | )
442 |
443 | def process(self, spoken_prompt):
444 | logger.debug("infering whisper...")
445 |
446 | global pipeline_start
447 | pipeline_start = perf_counter()
448 |
449 | input_features = self.prepare_model_inputs(spoken_prompt)
450 | pred_ids = self.model.generate(input_features, **self.gen_kwargs)
451 | pred_text = self.processor.batch_decode(
452 | pred_ids, skip_special_tokens=True, decode_with_timestamps=False
453 | )[0]
454 |
455 | logger.debug("finished whisper inference")
456 | console.print(f"[yellow]USER: {pred_text}")
457 |
458 | yield pred_text
459 |
460 |
461 | @dataclass
462 | class LanguageModelHandlerArguments:
463 | lm_model_name: str = field(
464 | default="HuggingFaceTB/SmolLM-360M-Instruct",
465 | metadata={
466 | "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
467 | },
468 | )
469 | lm_device: str = field(
470 | default="cuda",
471 | metadata={
472 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
473 | },
474 | )
475 | lm_torch_dtype: str = field(
476 | default="float16",
477 | metadata={
478 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
479 | },
480 | )
481 | user_role: str = field(
482 | default="user",
483 | metadata={
484 | "help": "Role assigned to the user in the chat context. Default is 'user'."
485 | },
486 | )
487 | init_chat_role: str = field(
488 | default='system',
489 | metadata={
490 | "help": "Initial role for setting up the chat context. Default is 'system'."
491 | },
492 | )
493 | init_chat_prompt: str = field(
494 | default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
495 | metadata={
496 | "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
497 | },
498 | )
499 | lm_gen_max_new_tokens: int = field(
500 | default=128,
501 | metadata={
502 | "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
503 | },
504 | )
505 | lm_gen_temperature: float = field(
506 | default=0.0,
507 | metadata={
508 | "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
509 | },
510 | )
511 | lm_gen_do_sample: bool = field(
512 | default=False,
513 | metadata={
514 | "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
515 | },
516 | )
517 | chat_size: int = field(
518 | default=2,
519 | metadata={
520 | "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
521 | },
522 | )
523 |
524 |
525 | class Chat:
526 | """
527 | Handles the chat history, keeping a bounded number of turns to avoid OOM issues.
528 | """
529 |
530 | def __init__(self, size):
531 | self.size = size
532 | self.init_chat_message = None
533 | # the buffer length must stay even, since each new step adds a user prompt and an assistant answer
534 | self.buffer = []
535 |
536 | def append(self, item):
537 | self.buffer.append(item)
538 | if len(self.buffer) == 2 * (self.size + 1):
539 | self.buffer.pop(0)
540 | self.buffer.pop(0)
541 |
542 | def init_chat(self, init_chat_message):
543 | self.init_chat_message = init_chat_message
544 |
545 | def to_list(self):
546 | if self.init_chat_message:
547 | return [self.init_chat_message] + self.buffer
548 | else:
549 | return self.buffer
550 |
551 |
552 | class LanguageModelHandler(BaseHandler):
553 | """
554 | Handles the language model part.
555 | """
556 |
557 | def setup(
558 | self,
559 | model_name="microsoft/Phi-3-mini-4k-instruct",
560 | device="cuda",
561 | torch_dtype="float16",
562 | gen_kwargs={},
563 | user_role="user",
564 | chat_size=1,
565 | init_chat_role=None,
566 | init_chat_prompt="You are a helpful AI assistant.",
567 | ):
568 | self.device = device
569 | self.torch_dtype = getattr(torch, torch_dtype)
570 |
571 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
572 | self.model = AutoModelForCausalLM.from_pretrained(
573 | model_name, torch_dtype=torch_dtype, trust_remote_code=True
574 | ).to(device)
575 | self.pipe = pipeline(
576 | "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
577 | )
578 | self.streamer = TextIteratorStreamer(
579 | self.tokenizer,
580 | skip_prompt=True,
581 | skip_special_tokens=True,
582 | )
583 | self.gen_kwargs = {
584 | "streamer": self.streamer,
585 | "return_full_text": False,
586 | **gen_kwargs,
587 | }
588 |
589 | self.chat = Chat(chat_size)
590 | if init_chat_role:
591 | if not init_chat_prompt:
592 | raise ValueError(
593 | "An initial promt needs to be specified when setting init_chat_role."
594 | )
595 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
596 | self.user_role = user_role
597 |
598 | self.warmup()
599 |
600 | def warmup(self):
601 | logger.info(f"Warming up {self.__class__.__name__}")
602 |
603 | dummy_input_text = "Write me a poem about Machine Learning."
604 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
605 | warmup_gen_kwargs = {
606 | "min_new_tokens": self.gen_kwargs["max_new_tokens"],
607 | "max_new_tokens": self.gen_kwargs["max_new_tokens"],
608 | **self.gen_kwargs,
609 | }
610 |
611 | n_steps = 2
612 |
613 | if self.device == "cuda":
614 | start_event = torch.cuda.Event(enable_timing=True)
615 | end_event = torch.cuda.Event(enable_timing=True)
616 | torch.cuda.synchronize()
617 | start_event.record()
618 |
619 | for _ in range(n_steps):
620 | thread = Thread(
621 | target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
622 | )
623 | thread.start()
624 | for _ in self.streamer:
625 | pass
626 |
627 | if self.device == "cuda":
628 | end_event.record()
629 | torch.cuda.synchronize()
630 |
631 | logger.info(
632 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
633 | )
634 |
635 | def process(self, prompt):
636 | logger.debug("infering language model...")
637 |
638 | self.chat.append({"role": self.user_role, "content": prompt})
639 | thread = Thread(
640 | target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
641 | )
642 | thread.start()
643 | if self.device == "mps":
644 | generated_text = ""
645 | for new_text in self.streamer:
646 | generated_text += new_text
647 | printable_text = generated_text
648 | torch.mps.empty_cache()
649 | else:
650 | generated_text, printable_text = "", ""
651 | for new_text in self.streamer:
652 | generated_text += new_text
653 | printable_text += new_text
654 | sentences = sent_tokenize(printable_text)
655 | if len(sentences) > 1:
656 | yield (sentences[0])
657 | printable_text = new_text
658 |
659 | self.chat.append({"role": "assistant", "content": generated_text})
660 |
661 | # don't forget last sentence
662 | yield printable_text
663 |
664 |
665 | @dataclass
666 | class ParlerTTSHandlerArguments:
667 | tts_model_name: str = field(
668 | default="ylacombe/parler-tts-mini-jenny-30H",
669 | metadata={
670 | "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
671 | },
672 | )
673 | tts_device: str = field(
674 | default="cuda",
675 | metadata={
676 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
677 | },
678 | )
679 | tts_torch_dtype: str = field(
680 | default="float16",
681 | metadata={
682 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
683 | },
684 | )
685 | tts_compile_mode: str = field(
686 | default=None,
687 | metadata={
688 | "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
689 | },
690 | )
691 | tts_gen_min_new_tokens: int = field(
692 | default=64,
693 | metadata={
694 | "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
695 | },
696 | )
697 | tts_gen_max_new_tokens: int = field(
698 | default=512,
699 | metadata={
700 | "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
701 | },
702 | )
703 | description: str = field(
704 | default=(
705 | "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
706 | "She speaks very fast."
707 | ),
708 | metadata={
709 | "help": "Description of the speaker's voice and speaking style to guide the TTS model."
710 | },
711 | )
712 | play_steps_s: float = field(
713 | default=1.0,
714 | metadata={
715 | "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
716 | },
717 | )
718 | max_prompt_pad_length: int = field(
719 | default=8,
720 | metadata={
721 | "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
722 | },
723 | )
724 |
725 |
726 | class ParlerTTSHandler(BaseHandler):
727 | def setup(
728 | self,
729 | should_listen,
730 | model_name="ylacombe/parler-tts-mini-jenny-30H",
731 | device="cuda",
732 | torch_dtype="float16",
733 | compile_mode=None,
734 | gen_kwargs={},
735 | max_prompt_pad_length=8,
736 | description=(
737 | "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
738 | "She speaks very fast."
739 | ),
740 | play_steps_s=1,
741 | blocksize=512,
742 | ):
743 | self.should_listen = should_listen
744 | self.device = device
745 | self.torch_dtype = getattr(torch, torch_dtype)
746 | self.gen_kwargs = gen_kwargs
747 | self.compile_mode = compile_mode
748 | self.max_prompt_pad_length = max_prompt_pad_length
749 | self.description = description
750 |
751 | self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
752 | self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
753 | self.model = ParlerTTSForConditionalGeneration.from_pretrained(
754 | model_name, torch_dtype=self.torch_dtype
755 | ).to(device)
756 |
757 | framerate = self.model.audio_encoder.config.frame_rate
758 | self.play_steps = int(framerate * play_steps_s)
759 | self.blocksize = blocksize
760 |
761 | if self.compile_mode not in (None, "default"):
762 | logger.warning(
763 | "Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
764 | )
765 | self.compile_mode = "default"
766 |
767 | if self.compile_mode:
768 | self.model.generation_config.cache_implementation = "static"
769 | self.model.forward = torch.compile(
770 | self.model.forward, mode=self.compile_mode, fullgraph=True
771 | )
772 |
773 | self.warmup()
774 |
775 | def prepare_model_inputs(
776 | self,
777 | prompt,
778 | max_length_prompt=50,
779 | pad=False,
780 | ):
781 | pad_args_prompt = (
782 | {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
783 | )
784 |
785 | tokenized_description = self.description_tokenizer(
786 | self.description, return_tensors="pt"
787 | )
788 | input_ids = tokenized_description.input_ids.to(self.device)
789 | attention_mask = tokenized_description.attention_mask.to(self.device)
790 |
791 | tokenized_prompt = self.prompt_tokenizer(
792 | prompt, return_tensors="pt", **pad_args_prompt
793 | )
794 | prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
795 | prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
796 |
797 | gen_kwargs = {
798 | "input_ids": input_ids,
799 | "attention_mask": attention_mask,
800 | "prompt_input_ids": prompt_input_ids,
801 | "prompt_attention_mask": prompt_attention_mask,
802 | **self.gen_kwargs,
803 | }
804 |
805 | return gen_kwargs
806 |
807 | def warmup(self):
808 | logger.info(f"Warming up {self.__class__.__name__}")
809 |
810 | if self.device == "cuda":
811 | start_event = torch.cuda.Event(enable_timing=True)
812 | end_event = torch.cuda.Event(enable_timing=True)
813 |
814 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture
815 | n_steps = 1 if self.compile_mode == "default" else 2
816 |
817 | if self.device == "cuda":
818 | torch.cuda.synchronize()
819 | start_event.record()
820 | if self.compile_mode:
821 | pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
822 | for pad_length in pad_lengths[::-1]:
823 | model_kwargs = self.prepare_model_inputs(
824 | "dummy prompt", max_length_prompt=pad_length, pad=True
825 | )
826 | for _ in range(n_steps):
827 | _ = self.model.generate(**model_kwargs)
828 | logger.info(f"Warmed up length {pad_length} tokens!")
829 | else:
830 | model_kwargs = self.prepare_model_inputs("dummy prompt")
831 | for _ in range(n_steps):
832 | _ = self.model.generate(**model_kwargs)
833 |
834 | if self.device == "cuda":
835 | end_event.record()
836 | torch.cuda.synchronize()
837 | logger.info(
838 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
839 | )
840 |
841 | def process(self, llm_sentence):
842 | console.print(f"[green]ASSISTANT: {llm_sentence}")
843 | nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
844 |
845 | pad_args = {}
846 | if self.compile_mode:
847 | # pad to closest upper power of two
848 | pad_length = next_power_of_2(nb_tokens)
849 | logger.debug(f"padding to {pad_length}")
850 | pad_args["pad"] = True
851 | pad_args["max_length_prompt"] = pad_length
852 |
853 | tts_gen_kwargs = self.prepare_model_inputs(
854 | llm_sentence,
855 | **pad_args,
856 | )
857 |
858 | streamer = ParlerTTSStreamer(
859 | self.model, device=self.device, play_steps=self.play_steps
860 | )
861 | tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
862 | torch.manual_seed(0)
863 | thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
864 | thread.start()
865 |
866 | for i, audio_chunk in enumerate(streamer):
867 | if i == 0:
868 | logger.info(
869 | f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
870 | )
871 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
872 | audio_chunk = (audio_chunk * 32768).astype(np.int16)
873 |             for j in range(0, len(audio_chunk), self.blocksize):
874 |                 yield np.pad(
875 |                     audio_chunk[j : j + self.blocksize],
876 |                     (0, self.blocksize - len(audio_chunk[j : j + self.blocksize])),
877 |                 )
878 |
879 | self.should_listen.set()
880 |
881 |
882 | def prepare_args(args, prefix):
883 | """
884 |     Renames arguments by stripping the given prefix and collects generation arguments into gen_kwargs.
885 | """
886 |
887 | gen_kwargs = {}
888 | for key in copy(args.__dict__):
889 | if key.startswith(prefix):
890 | value = args.__dict__.pop(key)
891 | new_key = key[len(prefix) + 1 :] # Remove prefix and underscore
892 | if new_key.startswith("gen_"):
893 | gen_kwargs[new_key[4:]] = value # Remove 'gen_' and add to dict
894 | else:
895 | args.__dict__[new_key] = value
896 |
897 | args.__dict__["gen_kwargs"] = gen_kwargs
898 |
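# Illustrative example (added comment; field names are hypothetical): given parsed
# fields such as stt_device="mps" and lm_gen_temperature=0.7,
# prepare_args(stt_kwargs, "stt") leaves stt_kwargs.device == "mps", while
# prepare_args(lm_kwargs, "lm") moves the "gen_"-prefixed field into
# lm_kwargs.gen_kwargs == {"temperature": 0.7}.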
899 |
900 | def main():
901 | parser = HfArgumentParser(
902 | (
903 | ModuleArguments,
904 | SocketReceiverArguments,
905 | SocketSenderArguments,
906 | VADHandlerArguments,
907 | WhisperSTTHandlerArguments,
908 | LanguageModelHandlerArguments,
909 | ParlerTTSHandlerArguments,
910 | )
911 | )
912 |
913 | # 0. Parse CLI arguments
914 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
915 | # Parse configurations from a JSON file if specified
916 | (
917 | module_kwargs,
918 | socket_receiver_kwargs,
919 | socket_sender_kwargs,
920 | vad_handler_kwargs,
921 | whisper_stt_handler_kwargs,
922 | language_model_handler_kwargs,
923 | parler_tts_handler_kwargs,
924 | ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
925 | else:
926 | # Parse arguments from command line if no JSON file is provided
927 | (
928 | module_kwargs,
929 | socket_receiver_kwargs,
930 | socket_sender_kwargs,
931 | vad_handler_kwargs,
932 | whisper_stt_handler_kwargs,
933 | language_model_handler_kwargs,
934 | parler_tts_handler_kwargs,
935 | ) = parser.parse_args_into_dataclasses()
936 |
937 | # 1. Handle logger
938 | global logger
939 | logging.basicConfig(
940 | level=module_kwargs.log_level.upper(),
941 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
942 | )
943 | logger = logging.getLogger(__name__)
944 |
945 | # torch compile logs
946 | if module_kwargs.log_level == "debug":
947 | torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
948 |
949 | # 2. Prepare each part's arguments
950 | def overwrite_device_argument(common_device: Optional[str], *handler_kwargs):
951 | if common_device:
952 | for kwargs in handler_kwargs:
953 | if hasattr(kwargs, "lm_device"):
954 | kwargs.lm_device = common_device
955 | if hasattr(kwargs, "tts_device"):
956 | kwargs.tts_device = common_device
957 | if hasattr(kwargs, "stt_device"):
958 | kwargs.stt_device = common_device
959 |
960 | # Call this function with the common device and all the handlers
961 | overwrite_device_argument(
962 | module_kwargs.device,
963 | language_model_handler_kwargs,
964 | parler_tts_handler_kwargs,
965 | whisper_stt_handler_kwargs,
966 | )
967 |
968 | prepare_args(whisper_stt_handler_kwargs, "stt")
969 | prepare_args(language_model_handler_kwargs, "lm")
970 | prepare_args(parler_tts_handler_kwargs, "tts")
971 |
972 | # 3. Build the pipeline
973 | stop_event = Event()
974 |     # used to stop queueing received audio chunks until the TTS has finished processing all sentences
975 | should_listen = Event()
976 | recv_audio_chunks_queue = Queue()
977 | send_audio_chunks_queue = Queue()
978 | spoken_prompt_queue = Queue()
979 | text_prompt_queue = Queue()
980 | lm_response_queue = Queue()
981 |
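    # Added note: data flow is received audio -> VAD -> STT (Whisper) -> LLM (MLX)
    # -> TTS (MeloTTS) -> sent audio; the handlers below run as threads connected
    # by these queues.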
982 | if module_kwargs.mode == "local":
983 | local_audio_streamer = LocalAudioStreamer(
984 | input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
985 | )
986 | comms_handlers = [local_audio_streamer]
987 | should_listen.set()
988 | else:
989 | comms_handlers = [
990 | SocketReceiver(
991 | stop_event,
992 | recv_audio_chunks_queue,
993 | should_listen,
994 | host=socket_receiver_kwargs.recv_host,
995 | port=socket_receiver_kwargs.recv_port,
996 | chunk_size=socket_receiver_kwargs.chunk_size,
997 | ),
998 | SocketSender(
999 | stop_event,
1000 | send_audio_chunks_queue,
1001 | host=socket_sender_kwargs.send_host,
1002 | port=socket_sender_kwargs.send_port,
1003 | ),
1004 | ]
1005 |
1006 | vad = VADHandler(
1007 | stop_event,
1008 | queue_in=recv_audio_chunks_queue,
1009 | queue_out=spoken_prompt_queue,
1010 | setup_args=(should_listen,),
1011 | setup_kwargs=vars(vad_handler_kwargs),
1012 | )
1013 | stt = LightningWhisperSTTHandler(
1014 | stop_event,
1015 | queue_in=spoken_prompt_queue,
1016 | queue_out=text_prompt_queue,
1017 | setup_kwargs=vars(whisper_stt_handler_kwargs),
1018 | )
1019 | lm = MLXLanguageModelHandler(
1020 | stop_event,
1021 | queue_in=text_prompt_queue,
1022 | queue_out=lm_response_queue,
1023 | setup_kwargs=vars(language_model_handler_kwargs),
1024 | )
1025 | # tts = ParlerTTSHandler(
1026 | # stop_event,
1027 | # queue_in=lm_response_queue,
1028 | # queue_out=send_audio_chunks_queue,
1029 | # setup_args=(should_listen,),
1030 | # setup_kwargs=vars(parler_tts_handler_kwargs),
1031 | # )
1032 | tts = MeloTTSHandler(
1033 | stop_event,
1034 | queue_in=lm_response_queue,
1035 | queue_out=send_audio_chunks_queue,
1036 | setup_args=(should_listen,),
1037 | )
1038 |
1039 | # 4. Run the pipeline
1040 | try:
1041 | pipeline_manager = ThreadManager([*comms_handlers, vad, stt, lm, tts])
1042 | pipeline_manager.start()
1043 |
1044 | except KeyboardInterrupt:
1045 | pipeline_manager.stop()
1046 |
1047 |
1048 | if __name__ == "__main__":
1049 | main()
1050 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def next_power_of_2(x):
6 | return 1 if x == 0 else 2**(x - 1).bit_length()
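# e.g. next_power_of_2(0) == 1, next_power_of_2(5) == 8, next_power_of_2(8) == 8;
# the TTS handler uses this to pad prompts to a small set of lengths when compiling.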
7 |
8 |
9 | def int2float(sound):
10 | """
11 | Taken from https://github.com/snakers4/silero-vad
12 | """
13 |
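    # Convert int16 PCM (range [-32768, 32767]) to float32 roughly within [-1.0, 1.0].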
14 | abs_max = np.abs(sound).max()
15 | sound = sound.astype('float32')
16 | if abs_max > 0:
17 | sound *= 1/32768
18 | sound = sound.squeeze() # depends on the use case
19 | return sound
20 |
21 |
22 | class VADIterator:
23 | def __init__(self,
24 | model,
25 | threshold: float = 0.5,
26 | sampling_rate: int = 16000,
27 | min_silence_duration_ms: int = 100,
28 | speech_pad_ms: int = 30
29 | ):
30 |
31 | """
32 | Mainly taken from https://github.com/snakers4/silero-vad
33 |         Class that imitates streamed (chunk-by-chunk) processing of audio
34 |
35 | Parameters
36 | ----------
37 | model: preloaded .jit/.onnx silero VAD model
38 |
39 | threshold: float (default - 0.5)
40 | Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
41 | It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
42 |
43 | sampling_rate: int (default - 16000)
44 | Currently silero VAD models support 8000 and 16000 sample rates
45 |
46 | min_silence_duration_ms: int (default - 100 milliseconds)
47 |             At the end of each speech chunk, wait for min_silence_duration_ms before separating it
48 |
49 | speech_pad_ms: int (default - 30 milliseconds)
50 |             Final speech chunks are padded by speech_pad_ms on each side
51 | """
52 |
53 | self.model = model
54 | self.threshold = threshold
55 | self.sampling_rate = sampling_rate
56 | self.is_speaking = False
57 | self.buffer = []
58 |
59 | if sampling_rate not in [8000, 16000]:
60 | raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
61 |
62 | self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
63 | self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
64 | self.reset_states()
65 |
66 | def reset_states(self):
67 |
68 | self.model.reset_states()
69 | self.triggered = False
70 | self.temp_end = 0
71 | self.current_sample = 0
72 |
73 | @torch.no_grad()
74 | def __call__(self, x):
75 | """
76 | x: torch.Tensor
77 | audio chunk (see examples in repo)
78 |
79 |         Returns the buffered speech chunks once the end of an utterance is
80 |         detected, otherwise None.
81 | """
82 |
83 | if not torch.is_tensor(x):
84 | try:
85 | x = torch.Tensor(x)
86 |             except Exception:
87 |                 raise TypeError("Audio cannot be cast to a tensor. Cast it manually.")
88 |
89 | window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
90 | self.current_sample += window_size_samples
91 |
92 | speech_prob = self.model(x, self.sampling_rate).item()
93 |
94 | if (speech_prob >= self.threshold) and self.temp_end:
95 | self.temp_end = 0
96 |
97 | if (speech_prob >= self.threshold) and not self.triggered:
98 | self.triggered = True
99 | return None
100 |
101 | if (speech_prob < self.threshold - 0.15) and self.triggered:
102 | if not self.temp_end:
103 | self.temp_end = self.current_sample
104 | if self.current_sample - self.temp_end < self.min_silence_samples:
105 | return None
106 | else:
107 |                 # end of speech
108 | self.temp_end = 0
109 | self.triggered = False
110 | spoken_utterance = self.buffer
111 | self.buffer = []
112 | return spoken_utterance
113 |
114 | if self.triggered:
115 | self.buffer.append(x)
116 |
117 | return None
118 |
--------------------------------------------------------------------------------
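
A minimal usage sketch for VADIterator (illustrative, not part of the repository): it assumes the silero-vad torch.hub entry point, feeds 512-sample chunks of 16 kHz audio, and collects the buffered chunks returned at the end of each utterance. In the pipeline itself this wiring lives inside VADHandler in s2s_pipeline.py.

import numpy as np
import torch

from utils import VADIterator, int2float

# Load the silero VAD model via torch.hub (assumed entry point).
model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
vad_iterator = VADIterator(model, threshold=0.5, sampling_rate=16000)

# Stand-in input: 3 seconds of int16 silence; real chunks would come from a microphone.
pcm = np.zeros(16000 * 3, dtype=np.int16)
audio = int2float(pcm)

for start in range(0, len(audio), 512):
    chunk = audio[start : start + 512]
    if len(chunk) < 512:
        break
    utterance = vad_iterator(chunk)  # None until an end of speech is detected
    if utterance is not None:
        speech = torch.cat(utterance)  # concatenate the buffered chunks
        print(f"Detected an utterance of {speech.shape[-1]} samples")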