├── .gitignore ├── .dockerignore ├── logo.png ├── Dockerfile ├── Dockerfile.arm64 ├── requirements.txt ├── requirements_mac.txt ├── utils ├── utils.py └── thread_manager.py ├── arguments_classes ├── socket_sender_arguments.py ├── paraformer_stt_arguments.py ├── chat_tts_arguments.py ├── melo_tts_arguments.py ├── socket_receiver_arguments.py ├── module_arguments.py ├── vad_arguments.py ├── bedrock_language_model_arguments.py ├── whisper_stt_arguments.py ├── parler_tts_arguments.py ├── mlx_language_model_arguments.py └── language_model_arguments.py ├── LLM ├── chat.py ├── mlx_language_model.py ├── bedrock_language_model.py └── language_model.py ├── docker-compose.yml ├── connections ├── socket_sender.py ├── local_audio_streamer.py └── socket_receiver.py ├── STT ├── paraformer_handler.py ├── lightning_whisper_mlx_handler.py └── whisper_stt_handler.py ├── baseHandler.py ├── TTS ├── chatTTS_handler.py ├── melo_handler.py └── parler_handler.py ├── VAD ├── vad_handler.py └── vad_iterator.py ├── listen_and_play.py ├── README.md ├── LICENSE ├── s2s_pipeline.py └── s2s_pipeline_bedrock.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | tmp 3 | cache -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | tmp 2 | cache 3 | Dockerfile 4 | docker-compose.yml 5 | .dockerignore 6 | .gitignore 7 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/machinelearnear/speech-to-speech-amazon-bedrock-sagemaker/main/logo.png -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel 2 | 3 | ENV PYTHONUNBUFFERED 1 4 | 5 | WORKDIR /usr/src/app 6 | 7 | # Install packages 8 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* 9 | 10 | COPY requirements.txt ./ 11 | RUN pip install --no-cache-dir -r requirements.txt 12 | 13 | COPY . . 14 | -------------------------------------------------------------------------------- /Dockerfile.arm64: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3 2 | 3 | ENV PYTHONUNBUFFERED 1 4 | 5 | WORKDIR /usr/src/app 6 | 7 | # Install packages 8 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* 9 | 10 | COPY requirements.txt ./ 11 | RUN pip install --no-cache-dir -r requirements.txt 12 | 13 | COPY . . 
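# Note: docker-compose.yml chooses between these two images through the DOCKERFILE build
# variable (it defaults to "Dockerfile"), so an ARM64/Jetson build would typically be started
# with something like `DOCKERFILE=Dockerfile.arm64 docker compose up --build`.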
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.9.1 2 | parler_tts @ git+https://github.com/huggingface/parler-tts.git 3 | melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers 4 | torch==2.4.0 5 | sounddevice==0.5.0 6 | ChatTTS>=0.1.1 7 | funasr>=1.1.6 8 | modelscope>=1.17.1 9 | deepfilternet>=0.5.6 10 | boto3==1.35.18 -------------------------------------------------------------------------------- /requirements_mac.txt: -------------------------------------------------------------------------------- 1 | nltk==3.9.1 2 | parler_tts @ git+https://github.com/huggingface/parler-tts.git 3 | melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers 4 | torch==2.4.0 5 | sounddevice==0.5.0 6 | lightning-whisper-mlx>=0.0.10 7 | mlx-lm>=0.14.0 8 | ChatTTS>=0.1.1 9 | funasr>=1.1.6 10 | modelscope>=1.17.1 11 | deepfilternet>=0.5.6 12 | boto3==1.35.18 -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def next_power_of_2(x): 5 | return 1 if x == 0 else 2 ** (x - 1).bit_length() 6 | 7 | 8 | def int2float(sound): 9 | """ 10 | Taken from https://github.com/snakers4/silero-vad 11 | """ 12 | 13 | abs_max = np.abs(sound).max() 14 | sound = sound.astype("float32") 15 | if abs_max > 0: 16 | sound *= 1 / 32768 17 | sound = sound.squeeze() # depends on the use case 18 | return sound 19 | -------------------------------------------------------------------------------- /arguments_classes/socket_sender_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class SocketSenderArguments: 6 | send_host: str = field( 7 | default="localhost", 8 | metadata={ 9 | "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all " 10 | "available interfaces on the host machine." 11 | }, 12 | ) 13 | send_port: int = field( 14 | default=12346, 15 | metadata={ 16 | "help": "The port number on which the socket server listens. Default is 12346." 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /arguments_classes/paraformer_stt_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class ParaformerSTTHandlerArguments: 6 | paraformer_stt_model_name: str = field( 7 | default="paraformer-zh", 8 | metadata={ 9 | "help": "The pretrained model to use. Default is 'paraformer-zh'. Can be choose from https://github.com/modelscope/FunASR" 10 | }, 11 | ) 12 | paraformer_stt_device: str = field( 13 | default="cuda", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | -------------------------------------------------------------------------------- /utils/thread_manager.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | 4 | class ThreadManager: 5 | """ 6 | Manages multiple threads used to execute given handler tasks. 
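    A typical wiring (sketch; the concrete handler instances are assumptions here — they are
    built in s2s_pipeline.py with a shared stop event and connecting queues):

        manager = ThreadManager([vad_handler, stt_handler, lm_handler, tts_handler])
        manager.start()   # one thread per handler.run()
        ...
        manager.stop()    # sets every handler's stop_event and joins the threads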
7 | """ 8 | 9 | def __init__(self, handlers): 10 | self.handlers = handlers 11 | self.threads = [] 12 | 13 | def start(self): 14 | for handler in self.handlers: 15 | thread = threading.Thread(target=handler.run) 16 | self.threads.append(thread) 17 | thread.start() 18 | 19 | def stop(self): 20 | for handler in self.handlers: 21 | handler.stop_event.set() 22 | for thread in self.threads: 23 | thread.join() 24 | -------------------------------------------------------------------------------- /arguments_classes/chat_tts_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class ChatTTSHandlerArguments: 6 | chat_tts_stream: bool = field( 7 | default=True, 8 | metadata={"help": "The tts mode is stream Default is 'stream'."}, 9 | ) 10 | chat_tts_device: str = field( 11 | default="cuda", 12 | metadata={ 13 | "help": "The device to be used for speech synthesis. Default is 'cuda'." 14 | }, 15 | ) 16 | chat_tts_chunk_size: int = field( 17 | default=512, 18 | metadata={ 19 | "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load.. Default is 512。." 20 | }, 21 | ) 22 | -------------------------------------------------------------------------------- /arguments_classes/melo_tts_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MeloTTSHandlerArguments: 6 | melo_language: str = field( 7 | default="en", 8 | metadata={ 9 | "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'." 10 | }, 11 | ) 12 | melo_device: str = field( 13 | default="auto", 14 | metadata={ 15 | "help": "The device to be used for speech synthesis. Default is 'auto'." 16 | }, 17 | ) 18 | melo_speaker_to_id: str = field( 19 | default="en", 20 | metadata={ 21 | "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']." 22 | }, 23 | ) 24 | -------------------------------------------------------------------------------- /LLM/chat.py: -------------------------------------------------------------------------------- 1 | class Chat: 2 | """ 3 | Handles the chat using to avoid OOM issues. 4 | """ 5 | 6 | def __init__(self, size): 7 | self.size = size 8 | self.init_chat_message = None 9 | # maxlen is necessary pair, since a each new step we add an prompt and assitant answer 10 | self.buffer = [] 11 | 12 | def append(self, item): 13 | self.buffer.append(item) 14 | if len(self.buffer) == 2 * (self.size + 1): 15 | self.buffer.pop(0) 16 | self.buffer.pop(0) 17 | 18 | def init_chat(self, init_chat_message): 19 | self.init_chat_message = init_chat_message 20 | 21 | def to_list(self): 22 | if self.init_chat_message: 23 | return [self.init_chat_message] + self.buffer 24 | else: 25 | return self.buffer 26 | -------------------------------------------------------------------------------- /arguments_classes/socket_receiver_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class SocketReceiverArguments: 6 | recv_host: str = field( 7 | default="localhost", 8 | metadata={ 9 | "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all " 10 | "available interfaces on the host machine." 
11 | }, 12 | ) 13 | recv_port: int = field( 14 | default=12345, 15 | metadata={ 16 | "help": "The port number on which the socket server listens. Default is 12346." 17 | }, 18 | ) 19 | chunk_size: int = field( 20 | default=1024, 21 | metadata={ 22 | "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes." 23 | }, 24 | ) 25 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | --- 2 | services: 3 | 4 | pipeline: 5 | build: 6 | context: . 7 | dockerfile: ${DOCKERFILE:-Dockerfile} 8 | command: 9 | - python3 10 | - s2s_pipeline.py 11 | - --recv_host 12 | - 0.0.0.0 13 | - --send_host 14 | - 0.0.0.0 15 | - --lm_model_name 16 | - microsoft/Phi-3-mini-4k-instruct 17 | - --init_chat_role 18 | - system 19 | - --init_chat_prompt 20 | - "You are a helpful assistant" 21 | - --stt_compile_mode 22 | - reduce-overhead 23 | - --tts_compile_mode 24 | - default 25 | expose: 26 | - 12345/tcp 27 | - 12346/tcp 28 | ports: 29 | - 12345:12345/tcp 30 | - 12346:12346/tcp 31 | volumes: 32 | - ./cache/:/root/.cache/ 33 | - ./s2s_pipeline.py:/usr/src/app/s2s_pipeline.py 34 | deploy: 35 | resources: 36 | reservations: 37 | devices: 38 | - driver: nvidia 39 | device_ids: ['0'] 40 | capabilities: [gpu] 41 | -------------------------------------------------------------------------------- /connections/socket_sender.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from rich.console import Console 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | console = Console() 8 | 9 | 10 | class SocketSender: 11 | """ 12 | Handles sending generated audio packets to the clients. 
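    Audio chunks taken from queue_in are forwarded to the connected client as raw bytes; the
    b"END" sentinel is forwarded as well, after which the connection is closed.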
13 | """ 14 | 15 | def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346): 16 | self.stop_event = stop_event 17 | self.queue_in = queue_in 18 | self.host = host 19 | self.port = port 20 | 21 | def run(self): 22 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 23 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 24 | self.socket.bind((self.host, self.port)) 25 | self.socket.listen(1) 26 | logger.info("Sender waiting to be connected...") 27 | self.conn, _ = self.socket.accept() 28 | logger.info("sender connected") 29 | 30 | while not self.stop_event.is_set(): 31 | audio_chunk = self.queue_in.get() 32 | self.conn.sendall(audio_chunk) 33 | if isinstance(audio_chunk, bytes) and audio_chunk == b"END": 34 | break 35 | self.conn.close() 36 | logger.info("Sender closed") 37 | -------------------------------------------------------------------------------- /connections/local_audio_streamer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import sounddevice as sd 3 | import numpy as np 4 | 5 | import time 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class LocalAudioStreamer: 12 | def __init__( 13 | self, 14 | input_queue, 15 | output_queue, 16 | list_play_chunk_size=512, 17 | ): 18 | self.list_play_chunk_size = list_play_chunk_size 19 | 20 | self.stop_event = threading.Event() 21 | self.input_queue = input_queue 22 | self.output_queue = output_queue 23 | 24 | def run(self): 25 | def callback(indata, outdata, frames, time, status): 26 | if self.output_queue.empty(): 27 | self.input_queue.put(indata.copy()) 28 | outdata[:] = 0 * outdata 29 | else: 30 | outdata[:] = self.output_queue.get()[:, np.newaxis] 31 | 32 | logger.debug("Available devices:") 33 | logger.debug(sd.query_devices()) 34 | with sd.Stream( 35 | samplerate=16000, 36 | dtype="int16", 37 | channels=1, 38 | callback=callback, 39 | blocksize=self.list_play_chunk_size, 40 | ): 41 | logger.info("Starting local audio stream") 42 | while not self.stop_event.is_set(): 43 | time.sleep(0.001) 44 | print("Stopping recording") 45 | -------------------------------------------------------------------------------- /arguments_classes/module_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModuleArguments: 7 | device: Optional[str] = field( 8 | default=None, 9 | metadata={"help": "If specified, overrides the device for all handlers."}, 10 | ) 11 | mode: Optional[str] = field( 12 | default="socket", 13 | metadata={ 14 | "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'." 15 | }, 16 | ) 17 | local_mac_optimal_settings: bool = field( 18 | default=False, 19 | metadata={ 20 | "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used." 21 | }, 22 | ) 23 | stt: Optional[str] = field( 24 | default="whisper", 25 | metadata={ 26 | "help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'." 27 | }, 28 | ) 29 | llm: Optional[str] = field( 30 | default="transformers", 31 | metadata={ 32 | "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'" 33 | }, 34 | ) 35 | tts: Optional[str] = field( 36 | default="parler", 37 | metadata={ 38 | "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. 
Default is 'parler'" 39 | }, 40 | ) 41 | log_level: str = field( 42 | default="info", 43 | metadata={ 44 | "help": "Provide logging level. Example: --log_level debug. Default is 'info'." 45 | }, 46 | ) 47 | -------------------------------------------------------------------------------- /arguments_classes/vad_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class VADHandlerArguments: 6 | thresh: float = field( 7 | default=0.3, 8 | metadata={ 9 | "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection." 10 | }, 11 | ) 12 | sample_rate: int = field( 13 | default=16000, 14 | metadata={ 15 | "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio." 16 | }, 17 | ) 18 | min_silence_ms: int = field( 19 | default=250, 20 | metadata={ 21 | "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms." 22 | }, 23 | ) 24 | min_speech_ms: int = field( 25 | default=500, 26 | metadata={ 27 | "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms." 28 | }, 29 | ) 30 | max_speech_ms: float = field( 31 | default=float("inf"), 32 | metadata={ 33 | "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments." 34 | }, 35 | ) 36 | speech_pad_ms: int = field( 37 | default=500, 38 | metadata={ 39 | "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 500 ms." 40 | }, 41 | ) 42 | audio_enhancement: bool = field( 43 | default=False, 44 | metadata={ 45 | "help": "Improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False." 46 | }, 47 | ) 48 | -------------------------------------------------------------------------------- /STT/paraformer_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import perf_counter 3 | 4 | from baseHandler import BaseHandler 5 | from funasr import AutoModel 6 | import numpy as np 7 | from rich.console import Console 8 | import torch 9 | 10 | logging.basicConfig( 11 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 12 | ) 13 | logger = logging.getLogger(__name__) 14 | 15 | console = Console() 16 | 17 | 18 | class ParaformerSTTHandler(BaseHandler): 19 | """ 20 | Handles the Speech To Text generation using a Paraformer model. 21 | The default for this model is set to Chinese. 22 | This model was contributed by @wuhongsheng.
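    The handler is a thin wrapper around FunASR's AutoModel; the underlying call is roughly
    (sketch, assuming `funasr` and the 'paraformer-zh' weights are available):

        from funasr import AutoModel
        model = AutoModel(model="paraformer-zh", device="cuda")
        text = model.generate(audio_float32)[0]["text"].strip()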
23 | """ 24 | 25 | def setup( 26 | self, 27 | model_name="paraformer-zh", 28 | device="cuda", 29 | gen_kwargs={}, 30 | ): 31 | print(model_name) 32 | if len(model_name.split("/")) > 1: 33 | model_name = model_name.split("/")[-1] 34 | self.device = device 35 | self.model = AutoModel(model=model_name, device=device) 36 | self.warmup() 37 | 38 | def warmup(self): 39 | logger.info(f"Warming up {self.__class__.__name__}") 40 | 41 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 42 | n_steps = 1 43 | dummy_input = np.array([0] * 512, dtype=np.float32) 44 | for _ in range(n_steps): 45 | _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "") 46 | 47 | def process(self, spoken_prompt): 48 | logger.debug("infering paraformer...") 49 | 50 | global pipeline_start 51 | pipeline_start = perf_counter() 52 | 53 | pred_text = ( 54 | self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "") 55 | ) 56 | torch.mps.empty_cache() 57 | 58 | logger.debug("finished paraformer inference") 59 | console.print(f"[yellow]USER: {pred_text}") 60 | 61 | yield pred_text 62 | -------------------------------------------------------------------------------- /connections/socket_receiver.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from rich.console import Console 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | console = Console() 8 | 9 | 10 | class SocketReceiver: 11 | """ 12 | Handles reception of the audio packets from the client. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | stop_event, 18 | queue_out, 19 | should_listen, 20 | host="0.0.0.0", 21 | port=12345, 22 | chunk_size=1024, 23 | ): 24 | self.stop_event = stop_event 25 | self.queue_out = queue_out 26 | self.should_listen = should_listen 27 | self.chunk_size = chunk_size 28 | self.host = host 29 | self.port = port 30 | 31 | def receive_full_chunk(self, conn, chunk_size): 32 | data = b"" 33 | while len(data) < chunk_size: 34 | packet = conn.recv(chunk_size - len(data)) 35 | if not packet: 36 | # connection closed 37 | return None 38 | data += packet 39 | return data 40 | 41 | def run(self): 42 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 43 | self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 44 | self.socket.bind((self.host, self.port)) 45 | self.socket.listen(1) 46 | logger.info("Receiver waiting to be connected...") 47 | self.conn, _ = self.socket.accept() 48 | logger.info("receiver connected") 49 | 50 | self.should_listen.set() 51 | while not self.stop_event.is_set(): 52 | audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size) 53 | if audio_chunk is None: 54 | # connection closed 55 | self.queue_out.put(b"END") 56 | break 57 | if self.should_listen.is_set(): 58 | self.queue_out.put(audio_chunk) 59 | self.conn.close() 60 | logger.info("Receiver closed") 61 | -------------------------------------------------------------------------------- /baseHandler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class BaseHandler: 8 | """ 9 | Base class for pipeline parts. Each part of the pipeline has an input and an output queue. 10 | The `setup` method along with `setup_args` and `setup_kwargs` can be used to address the specific requirements of the implemented pipeline part. 
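    A minimal subclass only needs a `setup` and a generator-style `process`, for example
    (hypothetical handler, for illustration only):

        class UpperCaseHandler(BaseHandler):
            def setup(self, suffix="!"):
                self.suffix = suffix

            def process(self, text):
                yield text.upper() + self.suffix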
11 | To stop a handler properly, set the stop_event and, to avoid queue deadlocks, place b"END" in the input queue. 12 | Objects placed in the input queue will be processed by the `process` method, and the yielded results will be placed in the output queue. 13 | The cleanup method handles stopping the handler, and b"END" is placed in the output queue. 14 | """ 15 | 16 | def __init__(self, stop_event, queue_in, queue_out, setup_args=(), setup_kwargs={}): 17 | self.stop_event = stop_event 18 | self.queue_in = queue_in 19 | self.queue_out = queue_out 20 | self.setup(*setup_args, **setup_kwargs) 21 | self._times = [] 22 | 23 | def setup(self): 24 | pass 25 | 26 | def process(self): 27 | raise NotImplementedError 28 | 29 | def run(self): 30 | while not self.stop_event.is_set(): 31 | input = self.queue_in.get() 32 | if isinstance(input, bytes) and input == b"END": 33 | # sentinelle signal to avoid queue deadlock 34 | logger.debug("Stopping thread") 35 | break 36 | start_time = perf_counter() 37 | for output in self.process(input): 38 | self._times.append(perf_counter() - start_time) 39 | if self.last_time > self.min_time_to_debug: 40 | logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s") 41 | self.queue_out.put(output) 42 | start_time = perf_counter() 43 | 44 | self.cleanup() 45 | self.queue_out.put(b"END") 46 | 47 | @property 48 | def last_time(self): 49 | return self._times[-1] 50 | 51 | @property 52 | def min_time_to_debug(self): 53 | return 0.001 54 | 55 | def cleanup(self): 56 | pass 57 | -------------------------------------------------------------------------------- /arguments_classes/bedrock_language_model_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | @dataclass 4 | class BedrockLanguageModelHandlerArguments: 5 | model_id: str = field( 6 | default="anthropic.claude-3-sonnet-20240229-v1:0", 7 | metadata={ 8 | "help": "The Amazon Bedrock model ID to use. Default is 'anthropic.claude-3-sonnet-20240229-v1:0'." 9 | }, 10 | ) 11 | temperature: float = field( 12 | default=0.5, 13 | metadata={ 14 | "help": "Controls randomness in the model's output. Lower values make the output more focused and deterministic. Default is 0.5." 15 | }, 16 | ) 17 | top_k: int = field( 18 | default=200, 19 | metadata={ 20 | "help": "Limits the number of top tokens considered for each step of text generation. Default is 200." 21 | }, 22 | ) 23 | user_role: str = field( 24 | default="user", 25 | metadata={ 26 | "help": "Role assigned to the user in the chat context. Default is 'user'." 27 | }, 28 | ) 29 | chat_size: int = field( 30 | default=10, 31 | metadata={ 32 | "help": "Number of interactions to keep in the chat history. Default is 10." 33 | }, 34 | ) 35 | init_chat_role: str = field( 36 | default=None, 37 | metadata={ 38 | "help": "Initial role for setting up the chat context. Default is None." 39 | }, 40 | ) 41 | init_chat_prompt: str = field( 42 | default="You are a helpful AI assistant.", 43 | metadata={ 44 | "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" 45 | }, 46 | ) 47 | aws_region: str = field( 48 | default="us-east-1", 49 | metadata={ 50 | "help": "The AWS region where the Bedrock service is located. Default is 'us-east-1'." 51 | }, 52 | ) 53 | aws_access_key_id: str = field( 54 | default=None, 55 | metadata={ 56 | "help": "AWS access key ID for authentication. 
If not provided, will use the default AWS configuration." 57 | }, 58 | ) 59 | aws_secret_access_key: str = field( 60 | default=None, 61 | metadata={ 62 | "help": "AWS secret access key for authentication. If not provided, will use the default AWS configuration." 63 | }, 64 | ) -------------------------------------------------------------------------------- /arguments_classes/whisper_stt_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class WhisperSTTHandlerArguments: 7 | stt_model_name: str = field( 8 | default="distil-whisper/distil-large-v3", 9 | metadata={ 10 | "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'." 11 | }, 12 | ) 13 | stt_device: str = field( 14 | default="cuda", 15 | metadata={ 16 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 17 | }, 18 | ) 19 | stt_torch_dtype: str = field( 20 | default="float16", 21 | metadata={ 22 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 23 | }, 24 | ) 25 | stt_compile_mode: str = field( 26 | default=None, 27 | metadata={ 28 | "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)" 29 | }, 30 | ) 31 | stt_gen_max_new_tokens: int = field( 32 | default=128, 33 | metadata={ 34 | "help": "The maximum number of new tokens to generate. Default is 128." 35 | }, 36 | ) 37 | stt_gen_num_beams: int = field( 38 | default=1, 39 | metadata={ 40 | "help": "The number of beams for beam search. Default is 1, implying greedy decoding." 41 | }, 42 | ) 43 | stt_gen_return_timestamps: bool = field( 44 | default=False, 45 | metadata={ 46 | "help": "Whether to return timestamps with transcriptions. Default is False." 47 | }, 48 | ) 49 | stt_gen_task: str = field( 50 | default="transcribe", 51 | metadata={ 52 | "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'." 53 | }, 54 | ) 55 | language: Optional[str] = field( 56 | default='en', 57 | metadata={ 58 | "help": """The language for the conversation. 59 | Choose between 'en' (english), 'fr' (french), 'es' (spanish), 60 | 'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'. 61 | If using 'auto', the language is automatically detected and can 62 | change during the conversation. Default is 'en'.""" 63 | }, 64 | ) -------------------------------------------------------------------------------- /arguments_classes/parler_tts_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class ParlerTTSHandlerArguments: 6 | tts_model_name: str = field( 7 | default="ylacombe/parler-tts-mini-jenny-30H", 8 | metadata={ 9 | "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'." 10 | }, 11 | ) 12 | tts_device: str = field( 13 | default="cuda", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | tts_torch_dtype: str = field( 19 | default="float16", 20 | metadata={ 21 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 
22 | }, 23 | ) 24 | tts_compile_mode: str = field( 25 | default=None, 26 | metadata={ 27 | "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)" 28 | }, 29 | ) 30 | tts_gen_min_new_tokens: int = field( 31 | default=64, 32 | metadata={ 33 | "help": "Minimum number of new tokens to generate in a single completion. Default is 64." 34 | }, 35 | ) 36 | tts_gen_max_new_tokens: int = field( 37 | default=512, 38 | metadata={ 39 | "help": "Maximum number of new tokens to generate in a single completion. Default is 512." 40 | }, 41 | ) 42 | description: str = field( 43 | default=( 44 | "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " 45 | "She speaks very fast." 46 | ), 47 | metadata={ 48 | "help": "Description of the speaker's voice and speaking style to guide the TTS model." 49 | }, 50 | ) 51 | play_steps_s: float = field( 52 | default=1.0, 53 | metadata={ 54 | "help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 second." 55 | }, 56 | ) 57 | max_prompt_pad_length: int = field( 58 | default=8, 59 | metadata={ 60 | "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 allowed. Default is 8." 61 | }, 62 | ) 63 | -------------------------------------------------------------------------------- /arguments_classes/mlx_language_model_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MLXLanguageModelHandlerArguments: 6 | mlx_lm_model_name: str = field( 7 | default="mlx-community/SmolLM-360M-Instruct", 8 | metadata={ 9 | "help": "The pretrained language model to use. Default is 'mlx-community/SmolLM-360M-Instruct'." 10 | }, 11 | ) 12 | mlx_lm_device: str = field( 13 | default="mps", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'mps' for Apple Silicon acceleration." 16 | }, 17 | ) 18 | mlx_lm_torch_dtype: str = field( 19 | default="float16", 20 | metadata={ 21 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 22 | }, 23 | ) 24 | mlx_lm_user_role: str = field( 25 | default="user", 26 | metadata={ 27 | "help": "Role assigned to the user in the chat context. Default is 'user'." 28 | }, 29 | ) 30 | mlx_lm_init_chat_role: str = field( 31 | default="system", 32 | metadata={ 33 | "help": "Initial role for setting up the chat context. Default is 'system'." 34 | }, 35 | ) 36 | mlx_lm_init_chat_prompt: str = field( 37 | default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", 38 | metadata={ 39 | "help": "The initial chat prompt to establish context for the language model. Default is a polite assistant prompt that asks for concise responses of fewer than 20 words." 40 | }, 41 | ) 42 | mlx_lm_gen_max_new_tokens: int = field( 43 | default=128, 44 | metadata={ 45 | "help": "Maximum number of new tokens to generate in a single completion. Default is 128." 46 | }, 47 | ) 48 | mlx_lm_gen_temperature: float = field( 49 | default=0.0, 50 | metadata={ 51 | "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
52 | }, 53 | ) 54 | mlx_lm_gen_do_sample: bool = field( 55 | default=False, 56 | metadata={ 57 | "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." 58 | }, 59 | ) 60 | mlx_lm_chat_size: int = field( 61 | default=2, 62 | metadata={ 63 | "help": "Number of user-assistant interactions to keep in the chat history. Set to None for no limit. Default is 2." 64 | }, 65 | ) 66 | -------------------------------------------------------------------------------- /arguments_classes/language_model_arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class LanguageModelHandlerArguments: 6 | lm_model_name: str = field( 7 | default="HuggingFaceTB/SmolLM-360M-Instruct", 8 | metadata={ 9 | "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'." 10 | }, 11 | ) 12 | lm_device: str = field( 13 | default="cuda", 14 | metadata={ 15 | "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration." 16 | }, 17 | ) 18 | lm_torch_dtype: str = field( 19 | default="float16", 20 | metadata={ 21 | "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)." 22 | }, 23 | ) 24 | user_role: str = field( 25 | default="user", 26 | metadata={ 27 | "help": "Role assigned to the user in the chat context. Default is 'user'." 28 | }, 29 | ) 30 | init_chat_role: str = field( 31 | default="system", 32 | metadata={ 33 | "help": "Initial role for setting up the chat context. Default is 'system'." 34 | }, 35 | ) 36 | init_chat_prompt: str = field( 37 | default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", 38 | metadata={ 39 | "help": "The initial chat prompt to establish context for the language model. Default is a polite assistant prompt that asks for concise responses of fewer than 20 words." 40 | }, 41 | ) 42 | lm_gen_max_new_tokens: int = field( 43 | default=128, 44 | metadata={ 45 | "help": "Maximum number of new tokens to generate in a single completion. Default is 128." 46 | }, 47 | ) 48 | lm_gen_min_new_tokens: int = field( 49 | default=0, 50 | metadata={ 51 | "help": "Minimum number of new tokens to generate in a single completion. Default is 0." 52 | }, 53 | ) 54 | lm_gen_temperature: float = field( 55 | default=0.0, 56 | metadata={ 57 | "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." 58 | }, 59 | ) 60 | lm_gen_do_sample: bool = field( 61 | default=False, 62 | metadata={ 63 | "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False." 64 | }, 65 | ) 66 | chat_size: int = field( 67 | default=2, 68 | metadata={ 69 | "help": "Number of user-assistant interactions to keep in the chat history. Set to None for no limit. Default is 2."
70 | }, 71 | ) 72 | -------------------------------------------------------------------------------- /STT/lightning_whisper_mlx_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import perf_counter 3 | from baseHandler import BaseHandler 4 | from lightning_whisper_mlx import LightningWhisperMLX 5 | import numpy as np 6 | from rich.console import Console 7 | from copy import copy 8 | import torch 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | console = Console() 13 | 14 | SUPPORTED_LANGUAGES = [ 15 | "en", 16 | "fr", 17 | "es", 18 | "zh", 19 | "ja", 20 | "ko", 21 | ] 22 | 23 | 24 | class LightningWhisperSTTHandler(BaseHandler): 25 | """ 26 | Handles the Speech To Text generation using a Whisper model. 27 | """ 28 | 29 | def setup( 30 | self, 31 | model_name="distil-large-v3", 32 | device="mps", 33 | torch_dtype="float16", 34 | compile_mode=None, 35 | language=None, 36 | gen_kwargs={}, 37 | ): 38 | if len(model_name.split("/")) > 1: 39 | model_name = model_name.split("/")[-1] 40 | self.device = device 41 | self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None) 42 | self.start_language = language 43 | self.last_language = language 44 | 45 | self.warmup() 46 | 47 | def warmup(self): 48 | logger.info(f"Warming up {self.__class__.__name__}") 49 | 50 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 51 | n_steps = 1 52 | dummy_input = np.array([0] * 512) 53 | 54 | for _ in range(n_steps): 55 | _ = self.model.transcribe(dummy_input)["text"].strip() 56 | 57 | def process(self, spoken_prompt): 58 | logger.debug("infering whisper...") 59 | 60 | global pipeline_start 61 | pipeline_start = perf_counter() 62 | 63 | if self.start_language != 'auto': 64 | transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language) 65 | else: 66 | transcription_dict = self.model.transcribe(spoken_prompt) 67 | language_code = transcription_dict["language"] 68 | if language_code not in SUPPORTED_LANGUAGES: 69 | logger.warning(f"Whisper detected unsupported language: {language_code}") 70 | if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language 71 | transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language) 72 | else: 73 | transcription_dict = {"text": "", "language": "en"} 74 | else: 75 | self.last_language = language_code 76 | 77 | pred_text = transcription_dict["text"].strip() 78 | language_code = transcription_dict["language"] 79 | torch.mps.empty_cache() 80 | 81 | logger.debug("finished whisper inference") 82 | console.print(f"[yellow]USER: {pred_text}") 83 | logger.debug(f"Language Code Whisper: {language_code}") 84 | 85 | yield (pred_text, language_code) 86 | -------------------------------------------------------------------------------- /TTS/chatTTS_handler.py: -------------------------------------------------------------------------------- 1 | import ChatTTS 2 | import logging 3 | from baseHandler import BaseHandler 4 | import librosa 5 | import numpy as np 6 | from rich.console import Console 7 | import torch 8 | 9 | logging.basicConfig( 10 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 11 | ) 12 | logger = logging.getLogger(__name__) 13 | 14 | console = Console() 15 | 16 | 17 | class ChatTTSHandler(BaseHandler): 18 | def setup( 19 | self, 20 | should_listen, 21 | device="cuda", 22 | gen_kwargs={}, # Unused 23 | stream=True, 24 | chunk_size=512, 25 | ): 26 | self.should_listen = 
should_listen 27 | self.device = device 28 | self.model = ChatTTS.Chat() 29 | self.model.load(compile=False) # Doesn't work for me with True 30 | self.chunk_size = chunk_size 31 | self.stream = stream 32 | rnd_spk_emb = self.model.sample_random_speaker() 33 | self.params_infer_code = ChatTTS.Chat.InferCodeParams( 34 | spk_emb=rnd_spk_emb, 35 | ) 36 | self.warmup() 37 | 38 | def warmup(self): 39 | logger.info(f"Warming up {self.__class__.__name__}") 40 | _ = self.model.infer("text") 41 | 42 | def process(self, llm_sentence): 43 | console.print(f"[green]ASSISTANT: {llm_sentence}") 44 | if self.device == "mps": 45 | import time 46 | 47 | start = time.time() 48 | torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete. 49 | torch.mps.empty_cache() # Frees all memory allocated by the MPS device. 50 | _ = ( 51 | time.time() - start 52 | ) # Removing this line makes it fail more often. I'm looking into it. 53 | 54 | wavs_gen = self.model.infer( 55 | llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream 56 | ) 57 | 58 | if self.stream: 59 | wavs = [np.array([])] 60 | for gen in wavs_gen: 61 | if gen[0] is None or len(gen[0]) == 0: 62 | self.should_listen.set() 63 | return 64 | audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000) 65 | audio_chunk = (audio_chunk * 32768).astype(np.int16)[0] 66 | while len(audio_chunk) > self.chunk_size: 67 | yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据 68 | audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据 69 | yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk))) 70 | else: 71 | wavs = wavs_gen 72 | if len(wavs[0]) == 0: 73 | self.should_listen.set() 74 | return 75 | audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000) 76 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 77 | for i in range(0, len(audio_chunk), self.chunk_size): 78 | yield np.pad( 79 | audio_chunk[i : i + self.chunk_size], 80 | (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])), 81 | ) 82 | self.should_listen.set() 83 | -------------------------------------------------------------------------------- /VAD/vad_handler.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from VAD.vad_iterator import VADIterator 3 | from baseHandler import BaseHandler 4 | import numpy as np 5 | import torch 6 | from rich.console import Console 7 | 8 | from utils.utils import int2float 9 | from df.enhance import enhance, init_df 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | console = Console() 15 | 16 | 17 | class VADHandler(BaseHandler): 18 | """ 19 | Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed 20 | to the following part. 
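    Inputs are int16 audio chunks (bytes from SocketReceiver, or int16 arrays from
    LocalAudioStreamer); once the end of speech is detected, the whole utterance is yielded as
    a single float32 numpy array, optionally denoised with DeepFilterNet when
    audio_enhancement is enabled.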
21 | """ 22 | 23 | def setup( 24 | self, 25 | should_listen, 26 | thresh=0.3, 27 | sample_rate=16000, 28 | min_silence_ms=1000, 29 | min_speech_ms=500, 30 | max_speech_ms=float("inf"), 31 | speech_pad_ms=30, 32 | audio_enhancement=False, 33 | ): 34 | self.should_listen = should_listen 35 | self.sample_rate = sample_rate 36 | self.min_silence_ms = min_silence_ms 37 | self.min_speech_ms = min_speech_ms 38 | self.max_speech_ms = max_speech_ms 39 | self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad") 40 | self.iterator = VADIterator( 41 | self.model, 42 | threshold=thresh, 43 | sampling_rate=sample_rate, 44 | min_silence_duration_ms=min_silence_ms, 45 | speech_pad_ms=speech_pad_ms, 46 | ) 47 | self.audio_enhancement = audio_enhancement 48 | if audio_enhancement: 49 | self.enhanced_model, self.df_state, _ = init_df() 50 | 51 | def process(self, audio_chunk): 52 | audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16) 53 | audio_float32 = int2float(audio_int16) 54 | vad_output = self.iterator(torch.from_numpy(audio_float32)) 55 | if vad_output is not None and len(vad_output) != 0: 56 | logger.debug("VAD: end of speech detected") 57 | array = torch.cat(vad_output).cpu().numpy() 58 | duration_ms = len(array) / self.sample_rate * 1000 59 | if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms: 60 | logger.debug( 61 | f"audio input of duration: {len(array) / self.sample_rate}s, skipping" 62 | ) 63 | else: 64 | self.should_listen.clear() 65 | logger.debug("Stop listening") 66 | if self.audio_enhancement: 67 | if self.sample_rate != self.df_state.sr(): 68 | audio_float32 = torchaudio.functional.resample( 69 | torch.from_numpy(array), 70 | orig_freq=self.sample_rate, 71 | new_freq=self.df_state.sr(), 72 | ) 73 | enhanced = enhance( 74 | self.enhanced_model, 75 | self.df_state, 76 | audio_float32.unsqueeze(0), 77 | ) 78 | enhanced = torchaudio.functional.resample( 79 | enhanced, 80 | orig_freq=self.df_state.sr(), 81 | new_freq=self.sample_rate, 82 | ) 83 | else: 84 | enhanced = enhance( 85 | self.enhanced_model, self.df_state, audio_float32 86 | ) 87 | array = enhanced.numpy().squeeze() 88 | yield array 89 | 90 | @property 91 | def min_time_to_debug(self): 92 | return 0.00001 93 | -------------------------------------------------------------------------------- /VAD/vad_iterator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class VADIterator: 5 | def __init__( 6 | self, 7 | model, 8 | threshold: float = 0.5, 9 | sampling_rate: int = 16000, 10 | min_silence_duration_ms: int = 100, 11 | speech_pad_ms: int = 30, 12 | ): 13 | """ 14 | Mainly taken from https://github.com/snakers4/silero-vad 15 | Class for stream imitation 16 | 17 | Parameters 18 | ---------- 19 | model: preloaded .jit/.onnx silero VAD model 20 | 21 | threshold: float (default - 0.5) 22 | Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. 23 | It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. 
24 | 25 | sampling_rate: int (default - 16000) 26 | Currently silero VAD models support 8000 and 16000 sample rates 27 | 28 | min_silence_duration_ms: int (default - 100 milliseconds) 29 | In the end of each speech chunk wait for min_silence_duration_ms before separating it 30 | 31 | speech_pad_ms: int (default - 30 milliseconds) 32 | Final speech chunks are padded by speech_pad_ms each side 33 | """ 34 | 35 | self.model = model 36 | self.threshold = threshold 37 | self.sampling_rate = sampling_rate 38 | self.is_speaking = False 39 | self.buffer = [] 40 | 41 | if sampling_rate not in [8000, 16000]: 42 | raise ValueError( 43 | "VADIterator does not support sampling rates other than [8000, 16000]" 44 | ) 45 | 46 | self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 47 | self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 48 | self.reset_states() 49 | 50 | def reset_states(self): 51 | self.model.reset_states() 52 | self.triggered = False 53 | self.temp_end = 0 54 | self.current_sample = 0 55 | 56 | @torch.no_grad() 57 | def __call__(self, x): 58 | """ 59 | x: torch.Tensor 60 | audio chunk (see examples in repo) 61 | 62 | return_seconds: bool (default - False) 63 | whether return timestamps in seconds (default - samples) 64 | """ 65 | 66 | if not torch.is_tensor(x): 67 | try: 68 | x = torch.Tensor(x) 69 | except Exception: 70 | raise TypeError("Audio cannot be casted to tensor. Cast it manually") 71 | 72 | window_size_samples = len(x[0]) if x.dim() == 2 else len(x) 73 | self.current_sample += window_size_samples 74 | 75 | speech_prob = self.model(x, self.sampling_rate).item() 76 | 77 | if (speech_prob >= self.threshold) and self.temp_end: 78 | self.temp_end = 0 79 | 80 | if (speech_prob >= self.threshold) and not self.triggered: 81 | self.triggered = True 82 | return None 83 | 84 | if (speech_prob < self.threshold - 0.15) and self.triggered: 85 | if not self.temp_end: 86 | self.temp_end = self.current_sample 87 | if self.current_sample - self.temp_end < self.min_silence_samples: 88 | return None 89 | else: 90 | # end of speak 91 | self.temp_end = 0 92 | self.triggered = False 93 | spoken_utterance = self.buffer 94 | self.buffer = [] 95 | return spoken_utterance 96 | 97 | if self.triggered: 98 | self.buffer.append(x) 99 | 100 | return None 101 | -------------------------------------------------------------------------------- /LLM/mlx_language_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from LLM.chat import Chat 3 | from baseHandler import BaseHandler 4 | from mlx_lm import load, stream_generate, generate 5 | from rich.console import Console 6 | import torch 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | console = Console() 11 | 12 | WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { 13 | "en": "english", 14 | "fr": "french", 15 | "es": "spanish", 16 | "zh": "chinese", 17 | "ja": "japanese", 18 | "ko": "korean", 19 | } 20 | 21 | class MLXLanguageModelHandler(BaseHandler): 22 | """ 23 | Handles the language model part. 
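    Runs on Apple Silicon through mlx-lm. The response is streamed and flushed sentence by
    sentence as (text, language_code) tuples, so that the TTS stage can start speaking before
    the full answer has been generated.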
24 | """ 25 | 26 | def setup( 27 | self, 28 | model_name="microsoft/Phi-3-mini-4k-instruct", 29 | device="mps", 30 | torch_dtype="float16", 31 | gen_kwargs={}, 32 | user_role="user", 33 | chat_size=1, 34 | init_chat_role=None, 35 | init_chat_prompt="You are a helpful AI assistant.", 36 | ): 37 | self.model_name = model_name 38 | self.model, self.tokenizer = load(self.model_name) 39 | self.gen_kwargs = gen_kwargs 40 | 41 | self.chat = Chat(chat_size) 42 | if init_chat_role: 43 | if not init_chat_prompt: 44 | raise ValueError( 45 | "An initial promt needs to be specified when setting init_chat_role." 46 | ) 47 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) 48 | self.user_role = user_role 49 | 50 | self.warmup() 51 | 52 | def warmup(self): 53 | logger.info(f"Warming up {self.__class__.__name__}") 54 | 55 | dummy_input_text = "Write me a poem about Machine Learning." 56 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] 57 | 58 | n_steps = 2 59 | 60 | for _ in range(n_steps): 61 | prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False) 62 | generate( 63 | self.model, 64 | self.tokenizer, 65 | prompt=prompt, 66 | max_tokens=self.gen_kwargs["max_new_tokens"], 67 | verbose=False, 68 | ) 69 | 70 | def process(self, prompt): 71 | logger.debug("infering language model...") 72 | language_code = None 73 | 74 | if isinstance(prompt, tuple): 75 | prompt, language_code = prompt 76 | prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt 77 | 78 | self.chat.append({"role": self.user_role, "content": prompt}) 79 | 80 | # Remove system messages if using a Gemma model 81 | if "gemma" in self.model_name.lower(): 82 | chat_messages = [ 83 | msg for msg in self.chat.to_list() if msg["role"] != "system" 84 | ] 85 | else: 86 | chat_messages = self.chat.to_list() 87 | 88 | prompt = self.tokenizer.apply_chat_template( 89 | chat_messages, tokenize=False, add_generation_prompt=True 90 | ) 91 | output = "" 92 | curr_output = "" 93 | for t in stream_generate( 94 | self.model, 95 | self.tokenizer, 96 | prompt, 97 | max_tokens=self.gen_kwargs["max_new_tokens"], 98 | ): 99 | output += t 100 | curr_output += t 101 | if curr_output.endswith((".", "?", "!", "<|end|>")): 102 | yield (curr_output.replace("<|end|>", ""), language_code) 103 | curr_output = "" 104 | generated_text = output.replace("<|end|>", "") 105 | torch.mps.empty_cache() 106 | 107 | self.chat.append({"role": "assistant", "content": generated_text}) -------------------------------------------------------------------------------- /TTS/melo_handler.py: -------------------------------------------------------------------------------- 1 | from melo.api import TTS 2 | import logging 3 | from baseHandler import BaseHandler 4 | import librosa 5 | import numpy as np 6 | from rich.console import Console 7 | import torch 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | console = Console() 12 | 13 | WHISPER_LANGUAGE_TO_MELO_LANGUAGE = { 14 | "en": "EN_NEWEST", 15 | "fr": "FR", 16 | "es": "ES", 17 | "zh": "ZH", 18 | "ja": "JP", 19 | "ko": "KR", 20 | } 21 | 22 | WHISPER_LANGUAGE_TO_MELO_SPEAKER = { 23 | "en": "EN-Newest", 24 | "fr": "FR", 25 | "es": "ES", 26 | "zh": "ZH", 27 | "ja": "JP", 28 | "ko": "KR", 29 | } 30 | 31 | 32 | class MeloTTSHandler(BaseHandler): 33 | def setup( 34 | self, 35 | should_listen, 36 | device="mps", 37 | language="en", 38 | speaker_to_id="en", 39 | gen_kwargs={}, # Unused 40 | blocksize=512, 41 | ): 42 | self.should_listen = 
should_listen 43 | self.device = device 44 | self.language = language 45 | self.model = TTS( 46 | language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device 47 | ) 48 | self.speaker_id = self.model.hps.data.spk2id[ 49 | WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id] 50 | ] 51 | self.blocksize = blocksize 52 | self.warmup() 53 | 54 | def warmup(self): 55 | logger.info(f"Warming up {self.__class__.__name__}") 56 | _ = self.model.tts_to_file("text", self.speaker_id, quiet=True) 57 | 58 | def process(self, llm_sentence): 59 | language_code = None 60 | 61 | if isinstance(llm_sentence, tuple): 62 | llm_sentence, language_code = llm_sentence 63 | 64 | console.print(f"[green]ASSISTANT: {llm_sentence}") 65 | 66 | if language_code is not None and self.language != language_code: 67 | try: 68 | self.model = TTS( 69 | language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code], 70 | device=self.device, 71 | ) 72 | self.speaker_id = self.model.hps.data.spk2id[ 73 | WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code] 74 | ] 75 | self.language = language_code 76 | except KeyError: 77 | console.print( 78 | f"[red]Language {language_code} not supported by Melo. Using {self.language} instead." 79 | ) 80 | 81 | if self.device == "mps": 82 | import time 83 | 84 | start = time.time() 85 | torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete. 86 | torch.mps.empty_cache() # Frees all memory allocated by the MPS device. 87 | _ = ( 88 | time.time() - start 89 | ) # Removing this line makes it fail more often. I'm looking into it. 90 | 91 | try: 92 | audio_chunk = self.model.tts_to_file( 93 | llm_sentence, self.speaker_id, quiet=True 94 | ) 95 | except (AssertionError, RuntimeError) as e: 96 | logger.error(f"Error in MeloTTSHandler: {e}") 97 | audio_chunk = np.array([]) 98 | if len(audio_chunk) == 0: 99 | self.should_listen.set() 100 | return 101 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) 102 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 103 | for i in range(0, len(audio_chunk), self.blocksize): 104 | yield np.pad( 105 | audio_chunk[i : i + self.blocksize], 106 | (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), 107 | ) 108 | 109 | self.should_listen.set() 110 | -------------------------------------------------------------------------------- /LLM/bedrock_language_model.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | from baseHandler import BaseHandler 4 | from threading import Thread 5 | from queue import Queue 6 | import logging 7 | from nltk import sent_tokenize 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class BedrockModelHandler(BaseHandler): 12 | def setup( 13 | self, 14 | model_id="anthropic.claude-3-sonnet-20240229-v1:0", 15 | temperature=0.5, 16 | top_k=200, 17 | user_role="user", 18 | chat_size=10, 19 | init_chat_role=None, 20 | init_chat_prompt="You are a helpful AI assistant.", 21 | ): 22 | self.model_id = model_id 23 | self.temperature = temperature 24 | self.top_k = top_k 25 | self.user_role = user_role 26 | self.chat = [] 27 | self.chat_size = chat_size 28 | 29 | if init_chat_role: 30 | if not init_chat_prompt: 31 | raise ValueError("An initial prompt needs to be specified when setting init_chat_role.") 32 | self.chat.append({"role": init_chat_role, "content": [{"text": init_chat_prompt}]}) 33 | 34 | self.bedrock_client = boto3.client(service_name='bedrock-runtime') 35 | 36 | 
self.warmup() 37 | 38 | def warmup(self): 39 | logger.info(f"Warming up {self.__class__.__name__}") 40 | dummy_input_text = "Repeat the word 'home'." 41 | self.process(dummy_input_text) 42 | logger.info(f"{self.__class__.__name__}: warmed up!") 43 | 44 | def process(self, prompt): 45 | logger.debug("inferring with language model...") 46 | language_code = None 47 | if isinstance(prompt, tuple): 48 | prompt, language_code = prompt 49 | prompt = f"Please reply to my message in {self._get_language_name(language_code)}. " + prompt 50 | 51 | self.chat.append({"role": self.user_role, "content": [{"text": prompt}]}) 52 | 53 | # Trim chat history if it exceeds chat_size 54 | if len(self.chat) > self.chat_size: 55 | self.chat = self.chat[-self.chat_size:] 56 | 57 | system_prompts = [{"text": "You are a helpful AI assistant."}] 58 | inference_config = {"temperature": self.temperature} 59 | additional_model_fields = {"top_k": self.top_k} 60 | 61 | try: 62 | response = self.bedrock_client.converse_stream( 63 | modelId=self.model_id, 64 | messages=self.chat, 65 | system=system_prompts, 66 | inferenceConfig=inference_config, 67 | additionalModelRequestFields=additional_model_fields 68 | ) 69 | 70 | stream = response.get('stream') 71 | if stream: 72 | generated_text = "" 73 | for event in stream: 74 | if 'contentBlockDelta' in event: 75 | new_text = event['contentBlockDelta']['delta']['text'] 76 | generated_text += new_text 77 | sentences = sent_tokenize(generated_text) 78 | if len(sentences) > 1: 79 | yield (sentences[0], language_code) 80 | generated_text = new_text 81 | 82 | # Don't forget the last sentence 83 | if generated_text: 84 | yield (generated_text, language_code) 85 | 86 | self.chat.append({"role": "assistant", "content": [{"text": generated_text}]}) 87 | 88 | except ClientError as err: 89 | message = err.response['Error']['Message'] 90 | logger.error("A client error occurred: %s", message) 91 | yield (f"An error occurred: {message}", language_code) 92 | 93 | def _get_language_name(self, language_code): 94 | language_map = { 95 | "en": "English", 96 | "fr": "French", 97 | "es": "Spanish", 98 | "zh": "Chinese", 99 | "ja": "Japanese", 100 | "ko": "Korean", 101 | } 102 | return language_map.get(language_code, "the same language as the input") -------------------------------------------------------------------------------- /listen_and_play.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import threading 3 | from queue import Queue 4 | from dataclasses import dataclass, field 5 | import sounddevice as sd 6 | from transformers import HfArgumentParser 7 | 8 | 9 | @dataclass 10 | class ListenAndPlayArguments: 11 | send_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) 12 | recv_rate: int = field(default=16000, metadata={"help": "In Hz. Default is 16000."}) 13 | list_play_chunk_size: int = field( 14 | default=1024, 15 | metadata={"help": "The size of data chunks (in bytes). Default is 1024."}, 16 | ) 17 | host: str = field( 18 | default="localhost", 19 | metadata={ 20 | "help": "The hostname or IP address for listening and playing. Default is 'localhost'." 21 | }, 22 | ) 23 | send_port: int = field( 24 | default=12345, 25 | metadata={"help": "The network port for sending data. Default is 12345."}, 26 | ) 27 | recv_port: int = field( 28 | default=12346, 29 | metadata={"help": "The network port for receiving data. 
Default is 12346."}, 30 | ) 31 | 32 | 33 | def listen_and_play( 34 | send_rate=16000, 35 | recv_rate=44100, 36 | list_play_chunk_size=1024, 37 | host="localhost", 38 | send_port=12345, 39 | recv_port=12346, 40 | ): 41 | send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 42 | send_socket.connect((host, send_port)) 43 | 44 | recv_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 45 | recv_socket.connect((host, recv_port)) 46 | 47 | print("Recording and streaming...") 48 | 49 | stop_event = threading.Event() 50 | recv_queue = Queue() 51 | send_queue = Queue() 52 | 53 | def callback_recv(outdata, frames, time, status): 54 | if not recv_queue.empty(): 55 | data = recv_queue.get() 56 | outdata[: len(data)] = data 57 | outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) 58 | else: 59 | outdata[:] = b"\x00" * len(outdata) 60 | 61 | def callback_send(indata, frames, time, status): 62 | if recv_queue.empty(): 63 | data = bytes(indata) 64 | send_queue.put(data) 65 | 66 | def send(stop_event, send_queue): 67 | while not stop_event.is_set(): 68 | data = send_queue.get() 69 | send_socket.sendall(data) 70 | 71 | def recv(stop_event, recv_queue): 72 | def receive_full_chunk(conn, chunk_size): 73 | data = b"" 74 | while len(data) < chunk_size: 75 | packet = conn.recv(chunk_size - len(data)) 76 | if not packet: 77 | return None # Connection has been closed 78 | data += packet 79 | return data 80 | 81 | while not stop_event.is_set(): 82 | data = receive_full_chunk(recv_socket, list_play_chunk_size * 2) 83 | if data: 84 | recv_queue.put(data) 85 | 86 | try: 87 | send_stream = sd.RawInputStream( 88 | samplerate=send_rate, 89 | channels=1, 90 | dtype="int16", 91 | blocksize=list_play_chunk_size, 92 | callback=callback_send, 93 | ) 94 | recv_stream = sd.RawOutputStream( 95 | samplerate=recv_rate, 96 | channels=1, 97 | dtype="int16", 98 | blocksize=list_play_chunk_size, 99 | callback=callback_recv, 100 | ) 101 | threading.Thread(target=send_stream.start).start() 102 | threading.Thread(target=recv_stream.start).start() 103 | 104 | send_thread = threading.Thread(target=send, args=(stop_event, send_queue)) 105 | send_thread.start() 106 | recv_thread = threading.Thread(target=recv, args=(stop_event, recv_queue)) 107 | recv_thread.start() 108 | 109 | input("Press Enter to stop...") 110 | 111 | except KeyboardInterrupt: 112 | print("Finished streaming.") 113 | 114 | finally: 115 | stop_event.set() 116 | recv_thread.join() 117 | send_thread.join() 118 | send_socket.close() 119 | recv_socket.close() 120 | print("Connection closed.") 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = HfArgumentParser((ListenAndPlayArguments,)) 125 | (listen_and_play_kwargs,) = parser.parse_args_into_dataclasses() 126 | listen_and_play(**vars(listen_and_play_kwargs)) 127 | -------------------------------------------------------------------------------- /LLM/language_model.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | from transformers import ( 3 | AutoModelForCausalLM, 4 | AutoTokenizer, 5 | pipeline, 6 | TextIteratorStreamer, 7 | ) 8 | import torch 9 | 10 | from LLM.chat import Chat 11 | from baseHandler import BaseHandler 12 | from rich.console import Console 13 | import logging 14 | from nltk import sent_tokenize 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | console = Console() 19 | 20 | 21 | WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { 22 | "en": "english", 23 | "fr": "french", 24 | "es": "spanish", 25 | "zh": 
"chinese", 26 | "ja": "japanese", 27 | "ko": "korean", 28 | } 29 | 30 | class LanguageModelHandler(BaseHandler): 31 | """ 32 | Handles the language model part. 33 | """ 34 | 35 | def setup( 36 | self, 37 | model_name="microsoft/Phi-3-mini-4k-instruct", 38 | device="cuda", 39 | torch_dtype="float16", 40 | gen_kwargs={}, 41 | user_role="user", 42 | chat_size=1, 43 | init_chat_role=None, 44 | init_chat_prompt="You are a helpful AI assistant.", 45 | ): 46 | self.device = device 47 | self.torch_dtype = getattr(torch, torch_dtype) 48 | 49 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 50 | self.model = AutoModelForCausalLM.from_pretrained( 51 | model_name, torch_dtype=torch_dtype, trust_remote_code=True 52 | ).to(device) 53 | self.pipe = pipeline( 54 | "text-generation", model=self.model, tokenizer=self.tokenizer, device=device 55 | ) 56 | self.streamer = TextIteratorStreamer( 57 | self.tokenizer, 58 | skip_prompt=True, 59 | skip_special_tokens=True, 60 | ) 61 | self.gen_kwargs = { 62 | "streamer": self.streamer, 63 | "return_full_text": False, 64 | **gen_kwargs, 65 | } 66 | 67 | self.chat = Chat(chat_size) 68 | if init_chat_role: 69 | if not init_chat_prompt: 70 | raise ValueError( 71 | "An initial promt needs to be specified when setting init_chat_role." 72 | ) 73 | self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) 74 | self.user_role = user_role 75 | 76 | self.warmup() 77 | 78 | def warmup(self): 79 | logger.info(f"Warming up {self.__class__.__name__}") 80 | 81 | dummy_input_text = "Repeat the word 'home'." 82 | dummy_chat = [{"role": self.user_role, "content": dummy_input_text}] 83 | warmup_gen_kwargs = { 84 | "min_new_tokens": self.gen_kwargs["min_new_tokens"], 85 | "max_new_tokens": self.gen_kwargs["max_new_tokens"], 86 | **self.gen_kwargs, 87 | } 88 | 89 | n_steps = 2 90 | 91 | if self.device == "cuda": 92 | start_event = torch.cuda.Event(enable_timing=True) 93 | end_event = torch.cuda.Event(enable_timing=True) 94 | torch.cuda.synchronize() 95 | start_event.record() 96 | 97 | for _ in range(n_steps): 98 | thread = Thread( 99 | target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs 100 | ) 101 | thread.start() 102 | for _ in self.streamer: 103 | pass 104 | 105 | if self.device == "cuda": 106 | end_event.record() 107 | torch.cuda.synchronize() 108 | 109 | logger.info( 110 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 111 | ) 112 | 113 | def process(self, prompt): 114 | logger.debug("infering language model...") 115 | language_code = None 116 | if isinstance(prompt, tuple): 117 | prompt, language_code = prompt 118 | prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. 
" + prompt 119 | 120 | self.chat.append({"role": self.user_role, "content": prompt}) 121 | thread = Thread( 122 | target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs 123 | ) 124 | thread.start() 125 | if self.device == "mps": 126 | generated_text = "" 127 | for new_text in self.streamer: 128 | generated_text += new_text 129 | printable_text = generated_text 130 | torch.mps.empty_cache() 131 | else: 132 | generated_text, printable_text = "", "" 133 | for new_text in self.streamer: 134 | generated_text += new_text 135 | printable_text += new_text 136 | sentences = sent_tokenize(printable_text) 137 | if len(sentences) > 1: 138 | yield (sentences[0], language_code) 139 | printable_text = new_text 140 | 141 | self.chat.append({"role": "assistant", "content": generated_text}) 142 | 143 | # don't forget last sentence 144 | yield (printable_text, language_code) 145 | -------------------------------------------------------------------------------- /STT/whisper_stt_handler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | from transformers import ( 3 | AutoProcessor, 4 | AutoModelForSpeechSeq2Seq 5 | ) 6 | import torch 7 | from copy import copy 8 | from baseHandler import BaseHandler 9 | from rich.console import Console 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | console = Console() 14 | 15 | SUPPORTED_LANGUAGES = [ 16 | "en", 17 | "fr", 18 | "es", 19 | "zh", 20 | "ja", 21 | "ko", 22 | ] 23 | 24 | 25 | class WhisperSTTHandler(BaseHandler): 26 | """ 27 | Handles the Speech To Text generation using a Whisper model. 28 | """ 29 | 30 | def setup( 31 | self, 32 | model_name="distil-whisper/distil-large-v3", 33 | device="cuda", 34 | torch_dtype="float16", 35 | compile_mode=None, 36 | language=None, 37 | gen_kwargs={}, 38 | ): 39 | self.device = device 40 | self.torch_dtype = getattr(torch, torch_dtype) 41 | self.compile_mode = compile_mode 42 | self.gen_kwargs = gen_kwargs 43 | if language == 'auto': 44 | language = None 45 | self.last_language = language 46 | if self.last_language is not None: 47 | self.gen_kwargs["language"] = self.last_language 48 | 49 | self.processor = AutoProcessor.from_pretrained(model_name) 50 | self.model = AutoModelForSpeechSeq2Seq.from_pretrained( 51 | model_name, 52 | torch_dtype=self.torch_dtype, 53 | ).to(device) 54 | 55 | # compile 56 | if self.compile_mode: 57 | self.model.generation_config.cache_implementation = "static" 58 | self.model.forward = torch.compile( 59 | self.model.forward, mode=self.compile_mode, fullgraph=True 60 | ) 61 | self.warmup() 62 | 63 | def prepare_model_inputs(self, spoken_prompt): 64 | input_features = self.processor( 65 | spoken_prompt, sampling_rate=16000, return_tensors="pt" 66 | ).input_features 67 | input_features = input_features.to(self.device, dtype=self.torch_dtype) 68 | 69 | return input_features 70 | 71 | def warmup(self): 72 | logger.info(f"Warming up {self.__class__.__name__}") 73 | 74 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 75 | n_steps = 1 if self.compile_mode == "default" else 2 76 | dummy_input = torch.randn( 77 | (1, self.model.config.num_mel_bins, 3000), 78 | dtype=self.torch_dtype, 79 | device=self.device, 80 | ) 81 | if self.compile_mode not in (None, "default"): 82 | # generating more tokens than previously will trigger CUDA graphs capture 83 | # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation 84 | # hence, having 
min_new_tokens < max_new_tokens in the future doesn't make sense 85 | warmup_gen_kwargs = { 86 | "min_new_tokens": self.gen_kwargs[ 87 | "max_new_tokens" 88 | ], # Yes, assign max_new_tokens to min_new_tokens 89 | "max_new_tokens": self.gen_kwargs["max_new_tokens"], 90 | **self.gen_kwargs, 91 | } 92 | else: 93 | warmup_gen_kwargs = self.gen_kwargs 94 | 95 | if self.device == "cuda": 96 | start_event = torch.cuda.Event(enable_timing=True) 97 | end_event = torch.cuda.Event(enable_timing=True) 98 | torch.cuda.synchronize() 99 | start_event.record() 100 | 101 | for _ in range(n_steps): 102 | _ = self.model.generate(dummy_input, **warmup_gen_kwargs) 103 | 104 | if self.device == "cuda": 105 | end_event.record() 106 | torch.cuda.synchronize() 107 | 108 | logger.info( 109 | f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 110 | ) 111 | 112 | def process(self, spoken_prompt): 113 | logger.debug("inferring with whisper...") 114 | 115 | global pipeline_start 116 | pipeline_start = perf_counter() 117 | 118 | input_features = self.prepare_model_inputs(spoken_prompt) 119 | pred_ids = self.model.generate(input_features, **self.gen_kwargs) 120 | language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" 121 | 122 | if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language 123 | logger.warning("Whisper detected unsupported language: %s", language_code) 124 | gen_kwargs = copy(self.gen_kwargs) 125 | gen_kwargs['language'] = self.last_language 126 | language_code = self.last_language 127 | pred_ids = self.model.generate(input_features, **gen_kwargs) 128 | else: 129 | self.last_language = language_code 130 | 131 | pred_text = self.processor.batch_decode( 132 | pred_ids, skip_special_tokens=True, decode_with_timestamps=False 133 | )[0] 134 | language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>" 135 | 136 | logger.debug("finished whisper inference") 137 | console.print(f"[yellow]USER: {pred_text}") 138 | logger.debug(f"Language Code Whisper: {language_code}") 139 | 140 | yield (pred_text, language_code) 141 | -------------------------------------------------------------------------------- /TTS/parler_handler.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | from time import perf_counter 3 | from baseHandler import BaseHandler 4 | import numpy as np 5 | import torch 6 | from transformers import ( 7 | AutoTokenizer, 8 | ) 9 | from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer 10 | import librosa 11 | import logging 12 | from rich.console import Console 13 | from utils.utils import next_power_of_2 14 | from transformers.utils.import_utils import ( 15 | is_flash_attn_2_available, 16 | ) 17 | 18 | torch._inductor.config.fx_graph_cache = True 19 | # mind about this parameter ! 
should be >= 2 * number of padded prompt sizes for TTS 20 | torch._dynamo.config.cache_size_limit = 15 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | console = Console() 25 | 26 | 27 | if not is_flash_attn_2_available() and torch.cuda.is_available(): 28 | logger.warning( 29 | """Parler TTS works best with flash attention 2, but it is not installed. 30 | Since CUDA is available on this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`""" 31 | ) 32 | 33 | 34 | class ParlerTTSHandler(BaseHandler): 35 | def setup( 36 | self, 37 | should_listen, 38 | model_name="ylacombe/parler-tts-mini-jenny-30H", 39 | device="cuda", 40 | torch_dtype="float16", 41 | compile_mode=None, 42 | gen_kwargs={}, 43 | max_prompt_pad_length=8, 44 | description=( 45 | "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. " 46 | "She speaks very fast." 47 | ), 48 | play_steps_s=1, 49 | blocksize=512, 50 | ): 51 | self.should_listen = should_listen 52 | self.device = device 53 | self.torch_dtype = getattr(torch, torch_dtype) 54 | self.gen_kwargs = gen_kwargs 55 | self.compile_mode = compile_mode 56 | self.max_prompt_pad_length = max_prompt_pad_length 57 | self.description = description 58 | 59 | self.description_tokenizer = AutoTokenizer.from_pretrained(model_name) 60 | self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name) 61 | self.model = ParlerTTSForConditionalGeneration.from_pretrained( 62 | model_name, torch_dtype=self.torch_dtype 63 | ).to(device) 64 | 65 | framerate = self.model.audio_encoder.config.frame_rate 66 | self.play_steps = int(framerate * play_steps_s) 67 | self.blocksize = blocksize 68 | 69 | if self.compile_mode not in (None, "default"): 70 | logger.warning( 71 | "Torch compilation modes that capture CUDA graphs are not yet compatible with the STT part. 
Reverting to 'default'" 72 | ) 73 | self.compile_mode = "default" 74 | 75 | if self.compile_mode: 76 | self.model.generation_config.cache_implementation = "static" 77 | self.model.forward = torch.compile( 78 | self.model.forward, mode=self.compile_mode, fullgraph=True 79 | ) 80 | 81 | self.warmup() 82 | 83 | def prepare_model_inputs( 84 | self, 85 | prompt, 86 | max_length_prompt=50, 87 | pad=False, 88 | ): 89 | pad_args_prompt = ( 90 | {"padding": "max_length", "max_length": max_length_prompt} if pad else {} 91 | ) 92 | 93 | tokenized_description = self.description_tokenizer( 94 | self.description, return_tensors="pt" 95 | ) 96 | input_ids = tokenized_description.input_ids.to(self.device) 97 | attention_mask = tokenized_description.attention_mask.to(self.device) 98 | 99 | tokenized_prompt = self.prompt_tokenizer( 100 | prompt, return_tensors="pt", **pad_args_prompt 101 | ) 102 | prompt_input_ids = tokenized_prompt.input_ids.to(self.device) 103 | prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device) 104 | 105 | gen_kwargs = { 106 | "input_ids": input_ids, 107 | "attention_mask": attention_mask, 108 | "prompt_input_ids": prompt_input_ids, 109 | "prompt_attention_mask": prompt_attention_mask, 110 | **self.gen_kwargs, 111 | } 112 | 113 | return gen_kwargs 114 | 115 | def warmup(self): 116 | logger.info(f"Warming up {self.__class__.__name__}") 117 | 118 | if self.device == "cuda": 119 | start_event = torch.cuda.Event(enable_timing=True) 120 | end_event = torch.cuda.Event(enable_timing=True) 121 | 122 | # 2 warmup steps for no compile or compile mode with CUDA graphs capture 123 | n_steps = 1 if self.compile_mode == "default" else 2 124 | 125 | if self.device == "cuda": 126 | torch.cuda.synchronize() 127 | start_event.record() 128 | if self.compile_mode: 129 | pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)] 130 | for pad_length in pad_lengths[::-1]: 131 | model_kwargs = self.prepare_model_inputs( 132 | "dummy prompt", max_length_prompt=pad_length, pad=True 133 | ) 134 | for _ in range(n_steps): 135 | _ = self.model.generate(**model_kwargs) 136 | logger.info(f"Warmed up length {pad_length} tokens!") 137 | else: 138 | model_kwargs = self.prepare_model_inputs("dummy prompt") 139 | for _ in range(n_steps): 140 | _ = self.model.generate(**model_kwargs) 141 | 142 | if self.device == "cuda": 143 | end_event.record() 144 | torch.cuda.synchronize() 145 | logger.info( 146 | f"{self.__class__.__name__}: warmed up! 
time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s" 147 | ) 148 | 149 | def process(self, llm_sentence): 150 | if isinstance(llm_sentence, tuple): 151 | llm_sentence, _ = llm_sentence 152 | 153 | console.print(f"[green]ASSISTANT: {llm_sentence}") 154 | nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids) 155 | 156 | pad_args = {} 157 | if self.compile_mode: 158 | # pad to closest upper power of two 159 | pad_length = next_power_of_2(nb_tokens) 160 | logger.debug(f"padding to {pad_length}") 161 | pad_args["pad"] = True 162 | pad_args["max_length_prompt"] = pad_length 163 | 164 | tts_gen_kwargs = self.prepare_model_inputs( 165 | llm_sentence, 166 | **pad_args, 167 | ) 168 | 169 | streamer = ParlerTTSStreamer( 170 | self.model, device=self.device, play_steps=self.play_steps 171 | ) 172 | tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs} 173 | torch.manual_seed(0) 174 | thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs) 175 | thread.start() 176 | 177 | for i, audio_chunk in enumerate(streamer): 178 | global pipeline_start 179 | if i == 0 and "pipeline_start" in globals(): 180 | logger.info( 181 | f"Time to first audio: {perf_counter() - pipeline_start:.3f}" 182 | ) 183 | audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000) 184 | audio_chunk = (audio_chunk * 32768).astype(np.int16) 185 | for i in range(0, len(audio_chunk), self.blocksize): 186 | yield np.pad( 187 | audio_chunk[i : i + self.blocksize], 188 | (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])), 189 | ) 190 | 191 | self.should_listen.set() 192 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
4 |
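A minimal client-side usage sketch (an illustration, not taken from the repository's README) based on the listen_and_play.py module shown above; the port, rate, and chunk-size values are the defaults from ListenAndPlayArguments, and the host is a placeholder for the machine running s2s_pipeline.py.

# Sketch: connect a local microphone and speakers to a running speech-to-speech server.
# Assumes the server side (e.g. s2s_pipeline.py) is already listening on these ports.
from listen_and_play import listen_and_play

listen_and_play(
    send_rate=16000,            # Hz of microphone audio sent to the STT stage
    recv_rate=16000,            # Hz of synthesized audio received from the TTS stage
    list_play_chunk_size=1024,  # bytes per audio chunk exchanged over the sockets
    host="localhost",           # placeholder: address of the machine running the pipeline
    send_port=12345,
    recv_port=12346,
)

# Equivalent shell invocation (HfArgumentParser exposes the dataclass fields as flags):
#   python listen_and_play.py --host localhost --send_port 12345 --recv_port 12346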