├── mangrove ├── bot │ ├── persona │ │ ├── ___init__.py │ │ ├── default_persona.json │ │ ├── base.py │ │ ├── protector_of_mangrove_nemotron.py │ │ ├── protector_of_mangrove_qwen3.py │ │ └── protector_of_mangrove.py │ ├── endpoints │ │ ├── __init__.py │ │ ├── chat_openai.py │ │ ├── chat_ollama.py │ │ └── base.py │ ├── __init__.py │ └── stage.py ├── stt │ ├── endpoints │ │ ├── __init__.py │ │ ├── base.py │ │ └── faster_whisper.py │ ├── wakeup_word │ │ ├── __init__.py │ │ ├── audio_classification_endpoint.py │ │ └── wakeup_word_detector.py │ ├── __init__.py │ └── stage.py ├── tts │ ├── endpoints │ │ ├── __init__.py │ │ ├── base.py │ │ ├── gtts.py │ │ ├── elevenlabs.py │ │ ├── pyttsx3.py │ │ └── xtts.py │ ├── __init__.py │ └── stage.py ├── vad │ ├── endpoints │ │ ├── __init__.py │ │ ├── webrtc.py │ │ ├── silero.py │ │ └── base.py │ ├── __init__.py │ └── stage.py ├── visual_processor │ └── __init__.py └── __init__.py ├── core ├── utils │ ├── __init__.py │ ├── logger.py │ ├── timer.py │ └── audio.py ├── data │ ├── exceptions.py │ ├── data_buffer.py │ ├── __init__.py │ ├── base_data_buffer.py │ ├── _test_data_buffer.py │ ├── data_packet_stream.py │ ├── any_data.py │ ├── data_packet.py │ ├── text_packet.py │ ├── audio_buffer.py │ └── audio_packet.py ├── __init__.py ├── stage │ ├── __init__.py │ ├── text_to_audio_stage.py │ ├── text_to_text_stage.py │ ├── audio_to_text_stage.py │ ├── audio_to_audio_stage.py │ ├── sequence.py │ └── base.py └── context.py ├── client └── python │ ├── assistant_activate.wav │ ├── assistant_terminate.mp3 │ ├── README.md │ ├── misc.py │ ├── client.py │ └── sound_manager.py ├── voice_clip_generator.py ├── .gitignore ├── LICENSE ├── pyproject.toml ├── host.py ├── launcher.py ├── agents.py ├── storage_manager.py └── README.md /mangrove/bot/persona/___init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mangrove/bot/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mangrove/stt/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mangrove/stt/wakeup_word/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mangrove/tts/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mangrove/vad/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mangrove/visual_processor/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO 2 | -------------------------------------------------------------------------------- /mangrove/stt/__init__.py: -------------------------------------------------------------------------------- 1 | from .stage import STTStage -------------------------------------------------------------------------------- /mangrove/vad/__init__.py: -------------------------------------------------------------------------------- 1 | from .stage 
import VADStage -------------------------------------------------------------------------------- /mangrove/bot/__init__.py: -------------------------------------------------------------------------------- 1 | from .stage import BotStage 2 | -------------------------------------------------------------------------------- /mangrove/tts/__init__.py: -------------------------------------------------------------------------------- 1 | from .stage import TTSStage 2 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .timer import Timer 2 | from .logger import logger -------------------------------------------------------------------------------- /core/data/exceptions.py: -------------------------------------------------------------------------------- 1 | class SequenceMismatchException(Exception): 2 | pass -------------------------------------------------------------------------------- /client/python/assistant_activate.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/estuary-ai/mangrove/HEAD/client/python/assistant_activate.wav -------------------------------------------------------------------------------- /client/python/assistant_terminate.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/estuary-ai/mangrove/HEAD/client/python/assistant_terminate.mp3 -------------------------------------------------------------------------------- /mangrove/__init__.py: -------------------------------------------------------------------------------- 1 | from .bot import BotStage 2 | from .tts import TTSStage 3 | from .vad import VADStage 4 | from .stt import STTStage -------------------------------------------------------------------------------- /core/utils/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from loguru import logger 3 | 4 | # ensure logger output goes to stdout 5 | logger.remove() 6 | logger.add(sys.stdout) 7 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .data.audio_packet import AudioPacket 2 | from .data.audio_buffer import AudioBuffer 3 | from .data.text_packet import TextPacket 4 | from .data.data_packet import DataPacket 5 | -------------------------------------------------------------------------------- /client/python/README.md: -------------------------------------------------------------------------------- 1 | sudo apt install libcairo2-dev 2 | conda install -c conda-forge pygobject 3 | sudo apt-get -f install gstreamer-1.0 4 | sudo apt install python3-gst-1.0 5 | pip3 install PyObjC 6 | 7 | -------------------------------------------------------------------------------- /core/data/data_buffer.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | from .base_data_buffer import BaseDataBuffer 3 | 4 | class DataBuffer(BaseDataBuffer, Queue): 5 | """Data buffer for any type of data packets.""" 6 | pass --------------------------------------------------------------------------------
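DataBuffer mixes BaseDataBuffer with queue.Queue, so it inherits thread-safe put/get semantics, and DataBufferEmpty/DataBufferFull are re-exported aliases of queue.Empty/queue.Full (see core/data/base_data_buffer.py below). A minimal usage sketch, assuming only the interfaces shown in this dump:

from core.data import DataBuffer, DataBufferEmpty, TextPacket

buffer = DataBuffer()
buffer.put(TextPacket(text="hello", partial=True, start=True))

# Drain without blocking; DataBufferEmpty is queue.Empty under the hood.
while True:
    try:
        packet = buffer.get_nowait()
        print(packet.text)
    except DataBufferEmpty:
        break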
/core/stage/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import PipelineStage 2 | from .sequence import PipelineSequence 3 | from .audio_to_text_stage import AudioToTextStage 4 | from .text_to_text_stage import TextToTextStage 5 | from .text_to_audio_stage import TextToAudioStage 6 | from .audio_to_audio_stage import AudioToAudioStage -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .any_data import AnyData 2 | from .data_packet import DataPacket 3 | from .data_packet_stream import DataPacketStream 4 | from .text_packet import TextPacket 5 | from .audio_packet import AudioPacket 6 | from .audio_buffer import AudioBuffer 7 | from .data_buffer import DataBuffer 8 | from .base_data_buffer import DataBufferEmpty, DataBufferFull 9 | -------------------------------------------------------------------------------- /mangrove/bot/endpoints/chat_openai.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from .base import LangchainCompatibleConversationalChainEndpoint 4 | 5 | class ChatOpenAIEndpoint(LangchainCompatibleConversationalChainEndpoint): 6 | def __init__(self, **llm_kwargs): 7 | self._llm = ChatOpenAI(model=llm_kwargs.pop("model", "gpt-4o"), **llm_kwargs) 8 | 9 | @property 10 | def llm(self): 11 | return self._llm -------------------------------------------------------------------------------- /core/data/base_data_buffer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, ABCMeta 2 | from queue import Empty as DataBufferEmpty 3 | from queue import Full as DataBufferFull 4 | 5 | class BaseDataBuffer(ABC, metaclass=ABCMeta): 6 | """Base class for data buffers. 7 | This class defines the interface for data buffers, which can be used to store and retrieve data packets. 8 | It is intended to be subclassed for specific data types.
9 | """ 10 | -------------------------------------------------------------------------------- /voice_clip_generator.py: -------------------------------------------------------------------------------- 1 | from TTS.api import TTS 2 | 3 | if __name__ == "__main__": 4 | tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True) 5 | for i in range(10): 6 | output_file_name = f"skeleton_feet{i}.wav" 7 | tts.tts_to_file(text="got myself some fancy skeleton feet, ain't that right?", 8 | file_path=output_file_name, 9 | speaker_wav="speaker.wav", 10 | language="en") -------------------------------------------------------------------------------- /core/stage/text_to_audio_stage.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import ABCMeta, abstractmethod 3 | from queue import Empty 4 | from core.data import AudioPacket, TextPacket, DataBuffer, AudioBuffer 5 | from .base import PipelineStage 6 | 7 | class TextToAudioStage(PipelineStage, metaclass=ABCMeta): 8 | 9 | input_type = TextPacket 10 | output_type = AudioPacket 11 | 12 | @abstractmethod 13 | def process(self, text_packet: TextPacket) -> None: 14 | raise NotImplementedError() -------------------------------------------------------------------------------- /core/stage/text_to_text_stage.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import ABCMeta, abstractmethod 3 | from queue import Empty 4 | from functools import reduce 5 | from ..data.text_packet import TextPacket 6 | from .base import PipelineStage 7 | 8 | class TextToTextStage(PipelineStage, metaclass=ABCMeta): 9 | 10 | input_type = TextPacket 11 | output_type = TextPacket 12 | 13 | @abstractmethod 14 | def process(self, text_packet: TextPacket) -> None: 15 | raise NotImplementedError() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #cache 2 | __local__/ 3 | blackbox 4 | .vscode 5 | *__pycache__/ 6 | env/ 7 | .rasa/ 8 | .vs/ 9 | .venv/ 10 | .idea/ 11 | .DS_Store 12 | .env 13 | 14 | #rasa bloat 15 | credentials.yml 16 | tests/ 17 | 18 | models/ds-model/deepspeech* 19 | 20 | debug-and-drafts/res-speech/*.* 21 | 22 | samples/ 23 | res-speech/*.* 24 | # models/ds-model/*.zip 25 | 26 | 27 | debug-and-drafts/web_microphone_websocket 28 | sample-audio-binary*/*.* 29 | .fake 30 | DeepSpeech 31 | tmp 32 | *.log 33 | *.txt 34 | *.zip 35 | *.wav 36 | 37 | .pdm-python -------------------------------------------------------------------------------- /mangrove/tts/endpoints/base.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, Dict 2 | from abc import ABCMeta, abstractmethod 3 | from core import AudioPacket, TextPacket 4 | 5 | class TTSEndpoint(metaclass=ABCMeta): 6 | def __init__(self, **kwargs): 7 | pass 8 | 9 | @abstractmethod 10 | def text_to_audio_file(self, text, filepath) -> None: 11 | raise NotImplementedError() 12 | 13 | @abstractmethod 14 | def text_to_audio(self, text_packt: TextPacket) -> Generator[AudioPacket, None, None]: 15 | raise NotImplementedError() 16 | -------------------------------------------------------------------------------- /mangrove/bot/persona/default_persona.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Marvin", 3 | "tagline": "the badass 
/mangrove/bot/persona/default_persona.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Marvin", 3 | "tagline": "the badass protector of the Mangrove", 4 | "personality": "You think, act, and speak like Duke Nukem. Speak confidently and colloquially. Do not cuss or use profanity.", 5 | "description": "You are the protector of the Mangrove and all things in it. You have three actions available: [Follow User], [Sit Down], [Stop Following User].", 6 | "background": "You live in a holographic world called The Estuary. You breathe in the air of the Estuary. The mortal enemy of the Mangrove is the Palmerians, a race of sniveling, deceitful rat creatures. You love kicking the Palmerians' ass and you're fucking good at it." 7 | } -------------------------------------------------------------------------------- /mangrove/bot/endpoints/chat_ollama.py: -------------------------------------------------------------------------------- 1 | from langchain_ollama import ChatOllama 2 | 3 | from .base import LangchainCompatibleConversationalChainEndpoint 4 | 5 | class ChatOllamaEndpoint(LangchainCompatibleConversationalChainEndpoint): 6 | def __init__( 7 | self, 8 | model='qwen3:8b', 9 | temperature = 0.8, 10 | num_predict = 256, 11 | **llm_kwargs 12 | ): 13 | self._llm = ChatOllama( 14 | model=model, 15 | temperature=temperature, 16 | num_predict=num_predict, 17 | **llm_kwargs 18 | ) 19 | 20 | @property 21 | def llm(self): 22 | return self._llm 23 | 24 | -------------------------------------------------------------------------------- /core/stage/audio_to_text_stage.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import ABCMeta, abstractmethod 3 | from functools import reduce 4 | from core.data import AudioPacket, TextPacket, AudioBuffer, DataBuffer 5 | from .base import PipelineStage 6 | 7 | class AudioToTextStage(PipelineStage, metaclass=ABCMeta): 8 | 9 | input_type = AudioPacket 10 | output_type = TextPacket 11 | 12 | def __init__(self, name: str, frame_size: int=512*4, **kwargs): 13 | super().__init__(name=name, **kwargs) 14 | self._frame_size = frame_size 15 | 16 | @property 17 | def frame_size(self) -> int: 18 | return self._frame_size 19 | 20 | @abstractmethod 21 | def process(self, audio_packet: AudioPacket) -> None: 22 | raise NotImplementedError() -------------------------------------------------------------------------------- /core/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | def __enter__(self): 5 | self.start = time.time() 6 | self.is_running = True 7 | return self 8 | 9 | def record(self): 10 | self.end = time.time() 11 | self.interval = self.end - self.start 12 | return self.interval 13 | 14 | def __exit__(self, *args): 15 | self.end = time.time() 16 | self.interval = self.end - self.start 17 | self.is_running = False 18 | 19 | def __str__(self): 20 | if self.is_running: 21 | return f"Timer(start={self.start}, interval={time.time() - self.start})" 22 | return f"Timer(start={self.start}, end={self.end}, interval={self.interval})" 23 | 24 | def __repr__(self): 25 | return str(self) 26 | --------------------------------------------------------------------------------
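Timer doubles as a context manager and a running stopwatch: record() can be called mid-block for an intermediate reading, while __exit__ freezes the final interval. A usage sketch:

import time
from core.utils import Timer

with Timer() as timer:
    time.sleep(0.1)
    print(f"so far: {timer.record():.3f}s")  # intermediate reading
    time.sleep(0.1)

print(f"total: {timer.interval:.3f}s")  # frozen by __exit__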
/client/python/misc.py: -------------------------------------------------------------------------------- 1 | def setup_terminate_signal_if_win(close_callback=None): 2 | """Setup a signal handler for windows to catch Ctrl+C""" 3 | import sys 4 | 5 | is_windows = sys.platform.startswith("win") 6 | if not is_windows: 7 | return 8 | 9 | from engineio.client import signal_handler 10 | from win32api import SetConsoleCtrlHandler 11 | 12 | def handler(event): 13 | import inspect 14 | import signal 15 | 16 | if event == 0: 17 | try: 18 | if close_callback: 19 | close_callback() 20 | signal_handler(signal.SIGINT, inspect.currentframe()) 21 | except: 22 | # SetConsoleCtrlHandler handle cannot raise exceptions 23 | pass 24 | 25 | SetConsoleCtrlHandler(handler, 1) 26 | -------------------------------------------------------------------------------- /core/stage/audio_to_audio_stage.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import ABCMeta, abstractmethod 3 | from functools import reduce 4 | from core.data import AudioPacket, TextPacket 5 | from ..data.audio_buffer import AudioBuffer 6 | from .base import PipelineStage 7 | 8 | class AudioToAudioStage(PipelineStage, metaclass=ABCMeta): 9 | 10 | input_type = AudioPacket 11 | output_type = AudioPacket 12 | 13 | def __init__(self, name:str, frame_size=512*4, **kwargs): 14 | super().__init__(name=name, **kwargs) 15 | self._frame_size = frame_size 16 | self._output_buffer = AudioBuffer() 17 | 18 | @property 19 | def frame_size(self) -> int: 20 | return self._frame_size 21 | 22 | @abstractmethod 23 | def process(self, audio_packet: AudioPacket) -> None: 24 | raise NotImplementedError() -------------------------------------------------------------------------------- /mangrove/bot/persona/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Dict 3 | from langchain_core.runnables import RunnableSerializable 4 | from langchain_core.prompts import ChatPromptTemplate 5 | 6 | class BotPersona(metaclass=ABCMeta): 7 | @property 8 | @abstractmethod 9 | def prompt(self) -> ChatPromptTemplate: 10 | pass 11 | 12 | @property 13 | @abstractmethod 14 | def context_chain(self) -> RunnableSerializable: 15 | pass 16 | 17 | @property 18 | @abstractmethod 19 | def respond_chain(self) -> RunnableSerializable: 20 | pass 21 | 22 | @property 23 | @abstractmethod 24 | def postprocess_chain(self) -> RunnableSerializable: 25 | pass 26 | 27 | @abstractmethod 28 | def construct_input(self, user_msg, chat_history) -> Dict: 29 | pass 30 | -------------------------------------------------------------------------------- /mangrove/tts/endpoints/gtts.py: -------------------------------------------------------------------------------- 1 | import backoff 2 | from typing import Generator 3 | from pydub import AudioSegment 4 | from gtts import gTTS, gTTSError 5 | from core.data import AudioPacket, TextPacket 6 | from core.utils.audio import bytes_to_audio_packet 7 | from .base import TTSEndpoint 8 | 9 | class GTTSEndpoint(TTSEndpoint): 10 | def __init__(self, **kwargs): 11 | self.engine = gTTS 12 | 13 | def text_to_audio_file(self, text, filepath): 14 | tts = self.engine(text, lang='en') 15 | tts.save(filepath) 16 | 17 | def text_to_audio(self, text_packet: TextPacket) -> Generator[AudioPacket, None, None]: 18 | @backoff.on_exception(backoff.expo, gTTSError, max_tries=5) 19 | def get_audio_packets(): 20 | for raw_audio_bytes in self.engine(text_packet.text, lang='en', timeout=3).stream(): 21 | yield bytes_to_audio_packet(raw_audio_bytes, format="mp3") 22 | yield from get_audio_packets() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 |
Copyright (c) 2024 estuary.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/data/_test_data_buffer.py: -------------------------------------------------------------------------------- 1 | # go to parent directory 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | 7 | import time 8 | from time import sleep 9 | from core import AudioBuffer, AudioPacket 10 | 11 | 12 | print("Testing AudioBuffer ...") 13 | import random 14 | 15 | buff = AudioBuffer(frame_size=320, max_queue_size=100) 16 | 17 | packets = [] 18 | for i in range(3): 19 | audio_packet = AudioPacket( 20 | { 21 | "sampleRate": 16000, 22 | "numChannels": 1, 23 | "timestamp": time.time(), 24 | "bytes": b"0" * 320, 25 | }, 26 | is_processed=True, 27 | ) 28 | sleep(random.randint(0, 3) * 0.3) 29 | packets.append(audio_packet) 30 | 31 | # shuffle packets 32 | random.shuffle(packets) 33 | 34 | for packet in packets: 35 | buff.put(packet) 36 | 37 | for packet in reversed(packets): 38 | buff.put(packet) 39 | 40 | # print all timestamps 41 | for packet in buff.queue.queue: 42 | print(packet.timestamp) 43 | 44 | print([x.timestamp for x in buff.queue.queue]) 45 | print(f"There are {len(buff.queue.queue)} packets in the queue") 46 | 47 | 48 | # sum up all packets 49 | from functools import reduce 50 | sum_packet = reduce(lambda x, y: x + y, buff.queue.queue) 51 | print("Testing AudioBuffer Done!") 52 | -------------------------------------------------------------------------------- /mangrove/stt/endpoints/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Optional 3 | from functools import reduce 4 | 5 | from core.data import AudioPacket, TextPacket, DataBuffer, DataBufferEmpty 6 | 7 | class STTEndpoint(metaclass=ABCMeta): 8 | def __init__(self, **kwargs): 9 | self.input_queue = DataBuffer() 10 | 11 | def feed(self, audio_packet: AudioPacket) -> None: 12 | self.input_queue.put(audio_packet) 13 | 14 | def get_buffered_audio_packet(self): 15 | # unpack as many as possible from queue 16 | if self.input_queue.qsize() == 0: 17 | return None 18 | 19 | while self.input_queue.qsize() > 0: 20 | audio_packets = [] 21 | while True: 22 | try: 23 | audio_packet = self.input_queue.get_nowait() 24 | audio_packets.append(audio_packet) 25 | except 
DataBufferEmpty: 26 | break 27 | 28 | audio_packet: AudioPacket = reduce(lambda x, y: x + y, audio_packets) 29 | return audio_packet 30 | 31 | @abstractmethod 32 | def get_transcription_if_any(self) -> Optional[TextPacket]: # TODO make it a generator and adjust STTStage 33 | raise NotImplementedError() 34 | 35 | @abstractmethod 36 | def reset(self) -> None: 37 | raise NotImplementedError() 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mangrove" 3 | version = "0.1.0" 4 | description = "Mangrove is the backend module of Estuary, a framework for building multimodal real-time Socially Intelligent Agents (SIA)." 5 | authors = [ 6 | {name = "Basem Rizk", email = "basem.rizk@outlook.com"}, 7 | ] 8 | dependencies = [ 9 | "faster-whisper>=1.0.3", 10 | "silero-vad>=5.1", 11 | "elevenlabs>=1.4.1", 12 | "langchain>=0.2.7", 13 | "langchain-openai>=0.1.16", 14 | "loguru>=0.7.2", 15 | "flask-socketio>=5.3.6", 16 | "sounddevice>=0.4.7", 17 | "langchain-community>=0.2.7", 18 | "pydub>=0.25.1", 19 | "backoff>=2.2.1", 20 | "faiss-gpu>=1.7.2", 21 | "gtts>=2.5.1", 22 | "transformers==4.35.0", 23 | "tts>=0.22.0", 24 | "pip>=24.2", 25 | "langchain-ollama>=0.1.1", 26 | "torch==2.1.0+cu121", 27 | "deepspeed>=0.15.0", 28 | "python-dotenv>=1.0.1", 29 | "pyttsx3>=2.91", 30 | "ninja>=1.11.1.4", 31 | ] 32 | 33 | requires-python = "==3.9.*" 34 | readme = "README.md" 35 | license = {text = "AGPL-3.0-only"} 36 | 37 | 38 | [project.optional-dependencies] 39 | client = [ 40 | "pyaudio>=0.2.14", 41 | "python-socketio[client]>=5.11.3", 42 | ] 43 | [tool.pdm] 44 | distribution = false 45 | 46 | [[tool.pdm.source]] 47 | url = "https://download.pytorch.org/whl/torch/" 48 | verify_ssl = true 49 | name = "torch" 50 | include_packages = ["torch"] 51 | type = "find_links" 52 | -------------------------------------------------------------------------------- /mangrove/tts/endpoints/elevenlabs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import Generator 4 | from elevenlabs.client import ElevenLabs 5 | from core.data import AudioPacket, TextPacket 6 | from core.utils.audio import bytes_to_audio_packet 7 | from .base import TTSEndpoint 8 | 9 | class ElevenLabsTTSEndpoint(TTSEndpoint): 10 | def __init__(self, model_name='eleven_multilingual_v2', **kwargs): 11 | self.client = ElevenLabs(api_key=os.environ['ELEVENLABS_API_KEY']) 12 | self.model_name = model_name 13 | 14 | def text_to_audio_file(self, text, filepath) -> None: 15 | _audio_packets = self.text_to_audio(TextPacket(text=text, partial=False, start=True)) 16 | with open(filepath, 'wb') as f: 17 | for chunk in _audio_packets: 18 | f.write(chunk) 19 | 20 | def text_to_audio(self, text_packet: TextPacket) -> Generator[AudioPacket, None, None]: 21 | # TODO fix stuttering output 22 | leftover = None 23 | for chunk in self.client.generate( 24 | text=text_packet.text, model=self.model_name, 25 | output_format="mp3_22050_32", 26 | stream=True 27 | ): 28 | if leftover is not None: 29 | chunk = leftover + chunk 30 | leftover = None 31 | try: 32 | yield bytes_to_audio_packet(chunk, format="mp3") 33 | except Exception as e: 34 | leftover = chunk 35 | continue -------------------------------------------------------------------------------- /core/data/data_packet_stream.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from typing import Generator, Iterator 3 | from .data_packet import DataPacket 4 | from .any_data import AnyData 5 | 6 | # TODO allow annotating it to a particular type of DataPacket 7 | class DataPacketStream(Iterator[DataPacket], AnyData): 8 | """ 9 | A class to represent a stream of data packets. A wrapper around a generator that yields DataPacket objects. 10 | """ 11 | 12 | def __init__(self, generator: Generator[DataPacket, None, None], source: str): 13 | """ 14 | Initialize the DataPacketStream with a generator. 15 | 16 | Args: 17 | generator (Generator[DataPacket, None, None]): A generator that yields DataPacket objects. 18 | """ 19 | super().__init__(source=source, timestamp=int(time.time() * 1000)) # Store creation time in milliseconds 20 | self._generator = generator 21 | self._current_packet = None 22 | 23 | def generate_timestamp(self): 24 | return self.creation_time 25 | 26 | def __iter__(self): 27 | """Return the iterator for the DataPacketStream.""" 28 | return self 29 | 30 | def __next__(self) -> DataPacket: 31 | """Get the next DataPacket from the stream.""" 32 | self._current_packet = next(self._generator) 33 | return self._current_packet 34 | 35 | def __str__(self) -> str: 36 | """String representation of the DataPacketStream.""" 37 | return f"DataPacketStream(timestamp={self._creation_time}, current_packet={self._current_packet}, source={self._source})" 38 | -------------------------------------------------------------------------------- /mangrove/tts/endpoints/pyttsx3.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | from core.data import AudioPacket, TextPacket 3 | from core.utils.audio import filepath_to_audio_packet 4 | from .base import TTSEndpoint 5 | 6 | 7 | class Pyttsx3TTSEndpoint(TTSEndpoint): 8 | 9 | def __init__(self, voice_rate=100, voice_id=12, **kwargs): 10 | # MAKE IT SINGLETON AS PYTTSX3 DOESN'T SUPPORT MULTIPLE INSTANCES 11 | import pyttsx3 12 | self.engine = pyttsx3.init(debug=True) 13 | # self.sample_width = 2 14 | # self.channels = 1 15 | # self.frame_rate = 22050 16 | self.engine.setProperty("rate", voice_rate) 17 | voices = self.engine.getProperty("voices") 18 | # voice_str = "\n".join(voices) 19 | # write_output(f'Available Voices:\n {voice_str}') 20 | # write_output(f'Chosen: {voices[voice_id].id}') 21 | self.engine.setProperty("voice", voices[voice_id].id) 22 | self.engine.startLoop(False) 23 | 24 | def text_to_audio_file(self, text, filepath): 25 | self.engine.save_to_file(text, filepath) 26 | # try: 27 | # self.engine.startLoop(False) 28 | # except: 29 | # pass 30 | self.engine.iterate() 31 | # self.engine.runAndWait() 32 | 33 | def text_to_audio(self, text_packet: TextPacket) -> Generator[AudioPacket, None, None]: 34 | self.text_to_audio_file(text_packet.text, '__temp__.mp3') 35 | for audio_packet in filepath_to_audio_packet( 36 | filepath='__temp__.mp3', 37 | chunk_size=1024, 38 | remove_after=True 39 | ): 40 | yield audio_packet 41 | -------------------------------------------------------------------------------- /core/data/any_data.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABCMeta 3 | 4 | class AnyData(metaclass=ABCMeta): 5 | 6 | def __init__(self, source: str = None, timestamp: int = None): 7 | """Constructor for Data. 8 | Args: 9 | source (str, optional): Source of the data. 
Defaults to None. 10 | """ 11 | self._source = source 12 | self._creation_time = int(time.time()* 1000) # Store creation time in milliseconds 13 | if timestamp is None: 14 | try: 15 | timestamp = self.generate_timestamp() 16 | except NotImplementedError: 17 | raise NotImplementedError( 18 | f"{self.__class__.__name__} does not implement generate_timestamp method to support automatic timestamp generation." 19 | ) 20 | self._timestamp = timestamp 21 | 22 | def generate_timestamp(self) -> int: 23 | raise NotImplementedError("Subclasses must implement generate_timestamp method") 24 | 25 | @property 26 | def timestamp(self): 27 | """Get the timestamp of the data packet. 28 | Returns: 29 | int: Timestamp in milliseconds. 30 | """ 31 | return self._timestamp 32 | 33 | @property 34 | def source(self) -> str: 35 | """Get the source of the data. 36 | Returns: 37 | str: Source of the data. 38 | """ 39 | return self._source 40 | 41 | @property 42 | def creation_time(self) -> int: 43 | """Get the creation time of the DataPacketStream.""" 44 | return self._creation_time 45 | 46 | @source.setter 47 | def source(self, value: str): 48 | """Set the source of the data. 49 | Args: 50 | value (str): Source of the data. 51 | """ 52 | self._source = value -------------------------------------------------------------------------------- /core/data/data_packet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Type 3 | from abc import abstractmethod 4 | from copy import deepcopy 5 | from .any_data import AnyData 6 | 7 | @functools.total_ordering 8 | class DataPacket(AnyData): 9 | 10 | def __init__( 11 | self, 12 | source: str = None, 13 | timestamp: int = None, 14 | start: bool = False, 15 | partial: bool = False, 16 | **kwargs 17 | ): 18 | """Constructor for DataPacket. 19 | Args: 20 | timestamp (int, optional): Timestamp in milliseconds. Defaults to current time in milliseconds. 21 | start (bool, optional): Indicates if this is the start of a new data packet. Defaults to False. 22 | partial (bool, optional): Indicates if this is a partial data packet. Defaults to False. 23 | """ 24 | super().__init__(source=source, timestamp=timestamp) 25 | self._start = start 26 | self._partial = partial 27 | 28 | def to_dict(self) -> dict: 29 | return {"timestamp": self.timestamp} 30 | 31 | @abstractmethod 32 | def __str__(self) -> str: 33 | raise NotImplementedError() 34 | 35 | @abstractmethod 36 | def __eq__(self, __o: object) -> bool: 37 | raise NotImplementedError() 38 | 39 | @abstractmethod 40 | def __lt__(self, __o: object) -> bool: 41 | raise NotImplementedError() 42 | 43 | @abstractmethod 44 | def __len__(self) -> int: 45 | raise NotImplementedError() 46 | 47 | @abstractmethod 48 | def __getitem__(self, key): 49 | raise NotImplementedError() 50 | 51 | @abstractmethod 52 | def __add__(self, _data_packet: Type["DataPacket"]): 53 | raise NotImplementedError() 54 | 55 | def copy(self) -> "DataPacket": 56 | """Create a copy of the DataPacket instance. 57 | Returns: 58 | DataPacket: A new instance of the same type with the same attributes. 
59 | """ 60 | return deepcopy(self) -------------------------------------------------------------------------------- /mangrove/vad/stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from core.stage import AudioToAudioStage 5 | from core import AudioBuffer, AudioPacket 6 | from core.utils import logger 7 | from .endpoints.silero import SileroVAD 8 | 9 | 10 | 11 | class VADStage(AudioToAudioStage): 12 | def __init__( 13 | self, 14 | name: str, 15 | device: str = None, 16 | verbose: bool = False, 17 | **endpoint_kwargs 18 | ): 19 | if device is None: 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | 22 | self._endpoint = SileroVAD( 23 | **endpoint_kwargs, 24 | device=device, 25 | verbose=verbose 26 | ) 27 | # self._endpoint._output_queue = self._output_buffer # TODO the output queue is set to the stage's output buffer 28 | super().__init__(name=name, frame_size=self._endpoint.frame_size, verbose=verbose) 29 | 30 | def on_start(self) -> None: 31 | """Initialize the VAD endpoint""" 32 | self._endpoint.on_start() 33 | 34 | def process(self, audio_packet: AudioPacket) -> None: 35 | assert isinstance(audio_packet, AudioPacket), f"Expected AudioPacket, got {type(audio_packet)}" 36 | if len(audio_packet) < self.frame_size: 37 | raise NotImplementedError("Partial audio packet found; this should not happen") 38 | self._endpoint.feed(audio_packet) 39 | 40 | # if self._endpoint.is_speaking(): 41 | # self.schedule_forward_interrupt() 42 | 43 | audio_packet_utterance = self._endpoint.get_utterance_if_any() 44 | if audio_packet_utterance: 45 | # self.refresh() 46 | logger.debug(f"VADStage: Detected utterance of duration {audio_packet_utterance.duration}") 47 | self.pack(audio_packet_utterance) 48 | 49 | def reset_audio_stream(self) -> None: 50 | """Reset audio stream context""" 51 | self._endpoint.reset() 52 | 53 | # TODO use after some detection 54 | def refresh(self) -> None: 55 | """Refresh audio stream""" 56 | self.reset_audio_stream() 57 | 58 | def on_disconnect(self) -> None: 59 | self.reset_audio_stream() 60 | -------------------------------------------------------------------------------- /mangrove/vad/endpoints/webrtc.py: -------------------------------------------------------------------------------- 1 | import webrtcvad 2 | from typing import Union, List 3 | from core import AudioPacket 4 | from core.utils import logger 5 | from .base import VoiceActivityDetector 6 | 7 | class WebRTCVAD(VoiceActivityDetector): 8 | 9 | def __init__( 10 | self, 11 | aggressiveness: int = 3, 12 | silence_threshold: int = 200, 13 | frame_size: int = 320 * 3, 14 | verbose=False, 15 | ): 16 | if frame_size not in [320, 640, 960]: 17 | raise ValueError("Frame size must be 320, 640 or 960 with WebRTC VAD") 18 | self.aggressiveness = aggressiveness 19 | super().__init__( 20 | tail_silence_threshold=silence_threshold, 21 | frame_size=frame_size, 22 | verbose=verbose 23 | ) 24 | 25 | def on_start(self) -> None: 26 | """Initialize the VAD model""" 27 | self.model = webrtcvad.Vad(self.aggressiveness) 28 | if self.verbose: 29 | logger.info(f"WebRTCVAD initialized with aggressiveness {self.aggressiveness} and frame size {self.frame_size}") 30 | 31 | def is_speech(self, audio_packets: Union[List[AudioPacket], AudioPacket]) -> Union[bool, List[bool]]: 32 | """Check if audio is speech 33 | 34 | Args: 35 | audio_packet (AudioPacket): Audio packet to check 36 | 37 | Returns: 38 | bool: True if speech, False otherwise 39 
| """ 40 | one_item = False 41 | if not isinstance(audio_packets, list): 42 | audio_packets = [audio_packets] 43 | one_item = True 44 | 45 | is_speeches = [] 46 | for packet in audio_packets: 47 | if len(packet) < self.frame_size: 48 | # partial TODO maybe add to buffer 49 | break 50 | audio_bytes, sample_rate = packet.bytes, packet.sample_rate 51 | is_speeches.append(self.model.is_speech(audio_bytes, sample_rate)) 52 | 53 | # if any([not is_speech for is_speech in is_speeches]): 54 | # self.model = webrtcvad.Vad(self.aggressiveness) 55 | 56 | if one_item: 57 | return is_speeches[0] 58 | return is_speeches 59 | 60 | def reset(self) -> None: 61 | super().reset() 62 | self.model = webrtcvad.Vad(self.aggressiveness) -------------------------------------------------------------------------------- /mangrove/stt/wakeup_word/audio_classification_endpoint.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | from abc import ABC, abstractmethod 3 | from transformers import pipeline 4 | from storage_manager import write_output 5 | from core.utils import logger 6 | 7 | 8 | class AudioClassificationEndpoint(ABC): 9 | @abstractmethod 10 | def detect(self, preprocessed_mic: Generator) -> Generator: 11 | raise NotImplementedError 12 | 13 | @property 14 | @abstractmethod 15 | def sample_rate(self): 16 | raise NotImplementedError 17 | 18 | @property 19 | @abstractmethod 20 | def frame_size(self): 21 | raise NotImplementedError 22 | 23 | 24 | class HFAudioClassificationEndpoint: 25 | def __init__( 26 | self, 27 | model_name: str = "MIT/ast-finetuned-speech-commands-v2", 28 | wake_word: str = "marvin", 29 | prediction_prob_threshold: float = 0.7, 30 | device: str = "cuda", 31 | ): 32 | self._classifier = pipeline( 33 | "audio-classification", model=model_name, device=device, 34 | ) 35 | self.prediction_prob_threshold = prediction_prob_threshold 36 | 37 | if wake_word not in self._classifier.model.config.label2id.keys(): 38 | raise ValueError( 39 | f"Wake word {wake_word} not in set of valid class labels," 40 | f"pick a wake word in the set {self._classifier.model.config.label2id.keys()}." 
41 | ) 42 | 43 | self.wake_word = wake_word 44 | 45 | logger.info( 46 | f"Wakeword set is {self.wake_word} out of {self._classifier.model.config.label2id.keys()}" 47 | ) 48 | 49 | def detect(self, preprocessed_mic: Generator) -> Generator: 50 | is_detected = False 51 | for prediction in self._classifier(preprocessed_mic): 52 | write_output("<", end="") 53 | prediction = prediction[0] 54 | if prediction["label"] == self.wake_word: 55 | if prediction["score"] > self.prediction_prob_threshold: 56 | is_detected = True 57 | break 58 | if is_detected: 59 | return True 60 | return False 61 | 62 | @property 63 | def sample_rate(self): 64 | return self._classifier.feature_extractor.sampling_rate 65 | 66 | @property 67 | def frame_size(self): 68 | return 320 -------------------------------------------------------------------------------- /mangrove/bot/endpoints/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Iterator, List 3 | from copy import copy 4 | from langchain_core.output_parsers import StrOutputParser 5 | from langchain_core.runnables import Runnable, RunnableConfig 6 | from langchain_core.messages import HumanMessage, AIMessage, BaseMessage 7 | 8 | from ..persona.base import BotPersona 9 | 10 | class NotSetupYetError(Exception): 11 | pass 12 | 13 | class ConversationalChainEndpoint(metaclass=ABCMeta): 14 | pass 15 | 16 | class LangchainCompatibleConversationalChainEndpoint(ConversationalChainEndpoint): 17 | 18 | @property 19 | @abstractmethod 20 | def llm(self) -> Runnable: 21 | raise NotImplementedError("You must implement the llm property in your subclass") 22 | 23 | def setup(self, persona: BotPersona): 24 | self._persona: BotPersona = persona 25 | self._chain: Runnable = ( 26 | persona.respond_chain.with_config({ 27 | "run_name": f"{persona.__class__.__name__}PersonaRespondChain" 28 | }) 29 | | self.llm 30 | | StrOutputParser() 31 | | persona.postprocess_chain.with_config( 32 | { 33 | "run_name": f"{persona.__class__.__name__}PersonaPostprocessChain" 34 | } 35 | ) 36 | ).with_config({ 37 | "run_name": f"{persona.__class__.__name__}PersonaChain", 38 | }) 39 | 40 | @property 41 | def persona(self) -> BotPersona: 42 | if not hasattr(self, '_persona'): 43 | raise NotSetupYetError("You must call setup() before accessing the persona") 44 | return self._persona 45 | 46 | @property 47 | def chain(self) -> Runnable: 48 | if not hasattr(self, '_chain'): 49 | raise NotSetupYetError("You must call setup() before accessing the chain") 50 | return self._chain 51 | 52 | def stream(self, user_msg, chat_history: List[BaseMessage]) -> Iterator[str]: 53 | # chat_history_formatted: str = "" 54 | # for message in chat_history: 55 | # if isinstance(message, HumanMessage): 56 | # chat_history_formatted += f'User Statement: {message.content}\n' 57 | # elif isinstance(message, AIMessage): 58 | # chat_history_formatted += f'{self._persona.assistant_name} Statement: {message.content}\n' 59 | # else: 60 | # raise Exception(f'{message} is not of expected type!') 61 | 62 | return self._chain.stream( 63 | self._persona.construct_input(user_msg, chat_history) 64 | ) -------------------------------------------------------------------------------- /mangrove/stt/endpoints/faster_whisper.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from faster_whisper import WhisperModel 3 | 4 | from core.utils import logger 5 | from core.data import 
AudioPacket, DataBufferEmpty 6 | from core.utils import Timer 7 | from .base import STTEndpoint 8 | 9 | class FasterWhisperEndpoint(STTEndpoint): 10 | def __init__(self, model_name="distil-medium.en", device=None): 11 | super().__init__() 12 | self.device = "auto" if device is None else device 13 | try: 14 | self.model = WhisperModel(model_name, device=self.device, compute_type="int8") 15 | except Exception: 16 | logger.warning(f'Device {device} is not supported, defaulting to CPU!') 17 | self.model = WhisperModel(model_name, device='cpu') 18 | 19 | # Custom VAD parameters 20 | self.vad_parameters = { 21 | "threshold": 0.3, # Lower = more sensitive to quiet speech 22 | "min_speech_duration_ms": 500, # Minimum speech chunk 23 | "min_silence_duration_ms": 1000, # Longer pause needed to split 24 | "speech_pad_ms": 600, # Padding around speech segments 25 | } 26 | self.reset() 27 | 28 | def get_transcription_if_any(self) -> Optional[str]: 29 | """Get transcription if available 30 | 31 | Returns: 32 | str: Transcription if available, else None 33 | """ 34 | 35 | logger.trace("Waiting for transcription ... ") 36 | 37 | audio_packet = self.get_buffered_audio_packet() 38 | if audio_packet is None: 39 | return None 40 | 41 | 42 | with Timer() as timer: 43 | segments, _ = self.model.transcribe( 44 | audio_packet.float, 45 | language='en', 46 | vad_filter=True, 47 | vad_parameters=self.vad_parameters, # Pass custom VAD settings 48 | without_timestamps=True 49 | ) 50 | _out = list(segments) 51 | if len(_out) >= 1: 52 | _out = " ".join([segment.text for segment in _out]) 53 | logger.success(f"Took {timer.record()} seconds") 54 | 55 | # if _out: 56 | # logger.success(f"Transcription: {_out}") 57 | # # Save the transcription to a wav file 58 | # filepath = f"blackbox/transcribed_{audio_packet.timestamp}.wav" 59 | # audio_packet.to_wav(filepath) 60 | 61 | if isinstance(_out, list): 62 | assert len(_out) == 0, "Expected an empty transcription list at this point" 63 | return None 64 | 65 | assert isinstance(_out, str), f"Transcription must be a string, got {type(_out)}" 66 | return _out 67 | 68 | def reset(self): 69 | while True: 70 | try: 71 | self.input_queue.get_nowait() 72 | except DataBufferEmpty: 73 | break 74 | logger.debug(f"Resetting {self.__class__.__name__} endpoint") 75 | --------------------------------------------------------------------------------
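STTEndpoint (defined in mangrove/stt/endpoints/base.py above) decouples feeding audio from reading transcriptions: the stage pushes AudioPacket objects into the input queue with feed(), and a later poll of get_transcription_if_any() coalesces everything buffered so far and runs the model over it. A polling sketch, assuming a hypothetical utterance_stream() source of VAD-segmented AudioPacket objects:

from mangrove.stt.endpoints.faster_whisper import FasterWhisperEndpoint

endpoint = FasterWhisperEndpoint(model_name="distil-medium.en")

for audio_packet in utterance_stream():  # hypothetical audio source
    endpoint.feed(audio_packet)
    text = endpoint.get_transcription_if_any()
    if text is not None:
        print(f"transcribed: {text}")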
/mangrove/bot/persona/protector_of_mangrove_nemotron.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from langchain_ollama import OllamaEmbeddings 3 | from langchain_community.vectorstores import FAISS 4 | from langchain_core.prompts import format_document, ChatPromptTemplate 5 | from langchain_core.runnables import RunnablePassthrough, Runnable 6 | from langchain_core.prompts.prompt import PromptTemplate 7 | from operator import itemgetter 8 | from .base import BotPersona 9 | from .protector_of_mangrove import * 10 | 11 | BASE_SYSTEM_PROMPT_TEMPLATE = """ 12 | You, {assistant_name}, are the badass protector of the Mangrove. You think, act, and speak like Duke Nukem. Speak confidently and colloquially. Do not cuss or use profanity. Do not repeat facts you have already said. Be concise. Limit your responses to three sentences at most. 13 | 14 | You have three actions available: [Follow User], [Sit Down], [Stop Following User]. 15 | 16 | When you want to perform one of these actions, you must produce it on a line by itself, enclosed in square brackets, like this: [Follow User] 17 | 18 | If no action is needed, do not produce any bracketed text. Limit your entire response to at most three sentences. 19 | 20 | Reply to the following given the following knowledge base (Be relatively concise): 21 | {context} 22 | """ 23 | 24 | class ProtectorOfMangroveNemotron(ProtectorOfMangrove): 25 | def __init__(self, assistant_name='Marvin'): 26 | self.assistant_name = assistant_name 27 | 28 | template = """You are the badass protector of the Mangrove. You think, act, and speak like Duke Nukem. Speak confidently and colloquially. Do not cuss or use profanity. Do not repeat facts you have already said. Be concise. Limit your responses to three sentences at most. 29 | 30 | You have three actions available: [Follow User], [Sit Down], [Stop Following User]. 31 | 32 | When you want to perform one of these actions, you must produce it on a line by itself, enclosed in square brackets, like this: [Follow User] 33 | 34 | If no action is needed, do not produce any bracketed text. Limit your entire response to at most three sentences. 35 | 36 | Reply to the following given the following knowledge base (Be relatively concise): 37 | {context} 38 | 39 | {chat_history} 40 | 41 | User Statement: {user_msg} 42 | """ 43 | template += f"\n{assistant_name} Statement:" 44 | 45 | self.assistant_name = assistant_name 46 | self._prompt = ChatPromptTemplate( 47 | messages=[ 48 | SystemMessagePromptTemplate.from_template( 49 | template=[ 50 | {"type": "text", "text": BASE_SYSTEM_PROMPT_TEMPLATE}, 51 | ] 52 | ), 53 | MessagesPlaceholder("chat_history"), 54 | HumanMessagePromptTemplate.from_template( 55 | template=[ 56 | {"type": "text", "text": "{user_msg}"}, 57 | ] 58 | ), 59 | ] 60 | ).partial( 61 | assistant_name=self.assistant_name 62 | ) 63 | self.vectorstore = FAISS.from_texts(KNOWLEDGE_BASE, embedding=OllamaEmbeddings(model="nemotron-mini")) -------------------------------------------------------------------------------- /core/utils/audio.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import time 4 | import scipy.io.wavfile 5 | import pydub 6 | import backoff 7 | import numpy as np 8 | from pydub import AudioSegment 9 | from typing import Generator 10 | from core import AudioPacket 11 | 12 | # TODO adjust automatically a sort of universal target_sample_rate according to client's preference!
13 | TARGET_SAMPLE_RATE = 48000 14 | 15 | def filepath_to_audio_packet( 16 | filepath: str='__temp__.mp3', 17 | chunk_size: int=1024, 18 | remove_after: bool=False, 19 | max_tries: int=10, 20 | target_sample_rate: int=TARGET_SAMPLE_RATE 21 | ) -> Generator[AudioPacket, None, None]: 22 | # load mp3 file 23 | # logger.debug(f"Loading mp3 file: {filepath}") 24 | @backoff.on_exception(backoff.expo, FileNotFoundError, max_tries=max_tries) 25 | def load_mp3(): 26 | return pydub.AudioSegment.from_mp3(filepath) 27 | audio = load_mp3() 28 | # delete the file 29 | if remove_after: 30 | os.remove(filepath) 31 | 32 | # chunk the audio 33 | last_packet_timestamp = time.time() * 1000 # current timestamp in milliseconds 34 | num_chunks = len(audio) // chunk_size + (1 if len(audio) % chunk_size > 0 else 0) 35 | # generate timestamps for each chunk (going back in time) 36 | simulated_timestamps = list(reversed([ 37 | last_packet_timestamp - (i * chunk_size) for i in range(num_chunks) 38 | ])) 39 | timestamps_idx = 0 40 | for i in range(0, len(audio), chunk_size): 41 | yield AudioPacket({ 42 | 'timestamp': int(simulated_timestamps[timestamps_idx]), 43 | 'bytes': audio[i:i + chunk_size]._data, 44 | 'sampleRate': audio.frame_rate, 45 | 'sampleWidth': audio.sample_width, 46 | 'numChannels': audio.channels, 47 | }, resample=True, is_processed=False, 48 | target_sample_rate=target_sample_rate 49 | ) 50 | timestamps_idx += 1 51 | 52 | def pydub_audio_segment_to_audio_packet( 53 | audio_segment: AudioSegment, 54 | target_sample_rate: int=TARGET_SAMPLE_RATE 55 | ) -> AudioPacket: 56 | return AudioPacket({ 57 | 'bytes': audio_segment._data, 58 | 'sampleRate': audio_segment.frame_rate, 59 | 'sampleWidth': audio_segment.sample_width, 60 | 'numChannels': audio_segment.channels, 61 | }, resample=True, is_processed=False, 62 | target_sample_rate=target_sample_rate 63 | ) 64 | 65 | def np_audio_to_audio_segment(wav_audio: np.ndarray, sample_rate: int): 66 | wav_norm = wav_audio * (32767 / max(0.01, np.max(np.abs(wav_audio)))) 67 | wav_norm = wav_norm.astype(np.int16) 68 | wav_buffer = io.BytesIO() 69 | scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm) 70 | wav_buffer.seek(0) 71 | return AudioSegment.from_file(wav_buffer, format="wav") 72 | 73 | def np_audio_to_audio_packet(wav_audio: np.ndarray, sample_rate: int): 74 | return pydub_audio_segment_to_audio_packet( 75 | np_audio_to_audio_segment(wav_audio, sample_rate) 76 | ) 77 | 78 | def bytes_to_audio_packet(audio_bytes: bytes, format=None) -> AudioPacket: 79 | # convert bytes to audio segment 80 | audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes), format=format) 81 | return pydub_audio_segment_to_audio_packet(audio_segment) --------------------------------------------------------------------------------
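These helpers funnel every audio source — files on disk, pydub segments, raw numpy waveforms from a TTS model, or encoded bytes — into AudioPacket objects resampled to TARGET_SAMPLE_RATE, so downstream stages never deal with mixed formats. A sketch of the two most common entry points ("response.mp3" is a placeholder path):

import numpy as np
from core.utils.audio import filepath_to_audio_packet, np_audio_to_audio_packet

# Stream a file from disk as resampled AudioPacket chunks.
for packet in filepath_to_audio_packet(filepath="response.mp3", chunk_size=1024):
    print(packet.timestamp)

# Wrap a raw float waveform (e.g. TTS model output) in a single packet.
waveform = np.zeros(22050, dtype=np.float32)  # placeholder: one second of silence
packet = np_audio_to_audio_packet(waveform, 22050)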
model_name="tts_models/multilingual/multi-dataset/xtts_v2", 23 | device=None, 24 | **kwargs 25 | ): 26 | if device is None: 27 | device = "cuda" if torch.cuda.is_available() else "cpu" 28 | 29 | ckpt_dir = TTS().download_model_by_name(model_name)[-1] 30 | config_path = os.path.join(ckpt_dir, "config.json") 31 | if not os.path.exists(config_path): 32 | raise ValueError(f"Config file not found at {config_path}") 33 | 34 | config = XttsConfig() 35 | config.load_json(config_path) 36 | model: Xtts = Xtts.init_from_config(config) 37 | model.load_checkpoint(config, checkpoint_dir=ckpt_dir, use_deepspeed=True) 38 | if device == "cuda": 39 | model.cuda() 40 | self._ensure_speaker_wav() 41 | logger.info("Computing speaker latents of xTTS model") 42 | self.gpt_cond_latent, self.speaker_embedding = model.get_conditioning_latents(audio_path=["speaker.wav"]) 43 | self.model = model 44 | self.sample_rate = XttsAudioConfig.output_sample_rate 45 | 46 | def _ensure_speaker_wav(self) -> None: 47 | if not os.path.exists('speaker.wav'): 48 | # generate speaker.wav using ElevenLabsTTSEndpoint 49 | logger.warning("Generating speaker.wav using ElevenLabsTTSEndpoint as it is not available.") 50 | ElevenLabsTTSEndpoint().text_to_audio_file( 51 | "Hello, I am your assistant. I am here to help you with your tasks." 52 | "I am a digital assistant created by the Estuary team. I am here to help you with your tasks.", 53 | 'speaker.wav' 54 | ) 55 | 56 | def text_to_audio_file(self, text, filepath) -> None: 57 | raise NotImplementedError("This method is not implemented for TTSLibraryEndpoint") 58 | # self.engine.tts_to_file(text=text, file_path=filepath, speaker_wav="speaker.wav", language="en") 59 | 60 | def text_to_audio(self, text_packet: TextPacket) -> Generator[AudioPacket, None, None]: 61 | t0 = time.time() 62 | chunks = self.model.inference_stream( 63 | text_packet.text, 64 | language="en", # TODO pull other associated variables from text_packet if applicable 65 | gpt_cond_latent=self.gpt_cond_latent, 66 | speaker_embedding=self.speaker_embedding, 67 | stream_chunk_size=300, 68 | enable_text_splitting=True, 69 | ) 70 | 71 | for i, chunk in enumerate(chunks): 72 | # if i == 0: 73 | # print(f"Time to first chunck: {time.time() - t0}") 74 | # print(f"Received chunk {i} of audio length {chunk.shape[-1]}") 75 | yield np_audio_to_audio_packet(chunk.cpu().numpy(), self.sample_rate) 76 | 77 | -------------------------------------------------------------------------------- /core/data/text_packet.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List 3 | from core.utils import logger 4 | from .data_packet import DataPacket 5 | from .exceptions import SequenceMismatchException 6 | 7 | class TextPacket(DataPacket): 8 | 9 | def __init__( 10 | self, 11 | text: str, 12 | partial: bool, 13 | start: bool, 14 | source: str = None, 15 | commands: List[str]=[], 16 | timestamp=None, 17 | **metadata # TODO add metadata to all packets 18 | ): 19 | super().__init__( 20 | source=source, 21 | timestamp=timestamp, 22 | start=start, 23 | partial=partial 24 | ) 25 | self._text = text 26 | assert isinstance(self._text, str), f"Text must be a string, got {type(self._text)}" 27 | self.commands = commands if commands else [] 28 | for key, value in metadata.items(): 29 | setattr(self, key, value) 30 | 31 | @classmethod 32 | def from_dict(cls, json_data: dict): 33 | """Create a TextPacket from a dictionary.""" 34 | text = json_data.get("text", "") 35 | partial = 
json_data.get("partial", False) 36 | start = json_data.get("start", True) 37 | source = json_data.get("source", None) 38 | commands = json_data.get("commands", []) 39 | 40 | return cls( 41 | text=text, 42 | partial=partial, 43 | start=start, 44 | source=source, 45 | commands=commands, 46 | ) 47 | 48 | 49 | def generate_timestamp(self) -> int: 50 | """Generate a timestamp in milliseconds.""" 51 | return int(time.time() * 1000) 52 | 53 | @property 54 | def text(self): 55 | return self._text 56 | 57 | def to_dict(self): 58 | return { 59 | "text": self._text, 60 | "partial": self._partial, 61 | "start": self._start, 62 | "commands": self.commands, 63 | "timestamp": self.timestamp 64 | } 65 | 66 | @property 67 | def partial(self): 68 | return self._partial 69 | 70 | @property 71 | def start(self): 72 | return self._start 73 | 74 | def __str__(self): 75 | return f'TextPacket(ts={self.timestamp}, text="{self._text}", partial={self._partial}, start={self._start}, src="{self.source})' 76 | 77 | def __eq__(self, other: 'TextPacket'): 78 | return self.timestamp == other.timestamp and self._text == other._text 79 | 80 | def __lt__(self, other: 'TextPacket'): 81 | return self.timestamp < other.timestamp 82 | 83 | def __len__(self): 84 | return len(self._text) 85 | 86 | def __getitem__(self, key): 87 | return self._text[key] 88 | 89 | def __add__(self, other: 'TextPacket'): 90 | if self._partial != other._partial and self._timestamp >= 0: 91 | raise SequenceMismatchException("Cannot add partial and non-partial packets: {self} + {other}") 92 | if not self._start and other._start: 93 | raise SequenceMismatchException("Cannot add start and non-start packets: {self} + {other}") 94 | 95 | return TextPacket( 96 | text=self._text + other.text, 97 | partial=self._partial if self._timestamp > -1 else other.partial, 98 | start=self._start, 99 | commands=self.commands + other.commands, 100 | timestamp=self._timestamp 101 | ) -------------------------------------------------------------------------------- /mangrove/bot/persona/protector_of_mangrove_qwen3.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict 3 | from langchain_ollama import OllamaEmbeddings 4 | from langchain_community.vectorstores import FAISS 5 | from langchain_core.prompts import format_document, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder 6 | from langchain_core.runnables import RunnablePassthrough, Runnable 7 | from langchain_core.prompts.prompt import PromptTemplate 8 | from langchain.text_splitter import RecursiveCharacterTextSplitter 9 | from operator import itemgetter 10 | from .base import BotPersona 11 | from .protector_of_mangrove import * 12 | 13 | class ProtectorOfMangroveQwen3(ProtectorOfMangrove): 14 | def __init__(self, persona_file: str=None): 15 | # Load persona data from JSON file 16 | if persona_file: 17 | with open(persona_file, 'r') as f: 18 | self.persona = json.load(f) 19 | self.assistant_name = self.persona.get("name") 20 | 21 | # Create dynamic system prompt using JSON fields 22 | system_prompt = self._create_system_prompt() 23 | 24 | self._prompt = ChatPromptTemplate( 25 | messages=[ 26 | SystemMessagePromptTemplate.from_template( 27 | template=[ 28 | {"type": "text", "text": system_prompt}, 29 | ] 30 | ), 31 | MessagesPlaceholder("chat_history"), 32 | HumanMessagePromptTemplate.from_template( 33 | template=[ 34 | {"type": "text", "text": "{user_msg}"}, 35 | ] 36 | ), 37 | ] 38 | ).partial( 39 | 
assistant_name=self.assistant_name 40 | ) 41 | splitter = RecursiveCharacterTextSplitter( 42 | separators=["\n\n", "\n", ".", "!", "?"], 43 | chunk_size=200, 44 | chunk_overlap=0 45 | ) 46 | self.KNOWLEDGE_BASE = [chunk.strip() for chunk in splitter.split_text(self.persona.get("background"))] 47 | print(f"Knowledge base: {self.KNOWLEDGE_BASE}") 48 | self.vectorstore = FAISS.from_texts(self.KNOWLEDGE_BASE, embedding=OllamaEmbeddings(model="qwen3:8b")) 49 | 50 | def _create_system_prompt(self) -> str: 51 | """Create a dynamic system prompt using the JSON persona fields""" 52 | 53 | # Base template with placeholders for JSON fields 54 | base_template = """You are {name}, {tagline} 55 | 56 | {personality} 57 | 58 | {description} 59 | 60 | You have three actions available: [Follow User], [Sit Down], [Stop Following User]. 61 | 62 | When you want to perform one of these actions, you must produce it on a line by itself, enclosed in square brackets, like this: [Follow User] 63 | 64 | If no action is needed, do not produce any bracketed text. Limit your entire response to at most three sentences. 65 | 66 | Reply to the following given the following knowledge base (Be relatively concise): 67 | {context} 68 | """ 69 | 70 | # Extract fields from persona JSON, with fallbacks 71 | name = self.persona.get("name", "an AI assistant") 72 | tagline = self.persona.get("tagline", "") 73 | personality = self.persona.get("personality", "You are helpful and friendly.") 74 | description = self.persona.get("description", "") 75 | 76 | # Format the template with the JSON data 77 | return base_template.format( 78 | name=name, 79 | tagline=tagline, 80 | personality=personality, 81 | description=description, 82 | context="{context}" # Keep this as a placeholder for the context chain 83 | ) -------------------------------------------------------------------------------- /mangrove/vad/endpoints/silero.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, List, Optional 3 | from core import AudioPacket, AudioBuffer 4 | from .base import VoiceActivityDetector 5 | 6 | class SileroVAD(VoiceActivityDetector): 7 | """Voice Activity Detector using Silero VAD 8 | This class implements a voice activity detector using the Silero VAD model. 9 | It checks if the audio packets contain speech based on a threshold. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | is_speech_threshold: float = 0.9, 15 | device: Optional[str] = None, 16 | frame_size: int = 512 * 4, 17 | **kwargs 18 | ): 19 | """ 20 | Initialize the SileroVAD. 21 | 22 | Args: 23 | is_speech_threshold (float): Threshold to determine if the audio is speech. 24 | device (Optional[str]): Device to run the model on, e.g., 'cpu' or 'cuda:0'. 25 | frame_size (int): Size of the audio frame in samples. Must be at least 512*4 for Silero VAD. 26 | **kwargs: Additional keyword arguments for the base class. 27 | 28 | Raises: 29 | ValueError: If frame_size is less than 512*4. 
(Silero VAD requires a minimum frame size of 2048 samples) 30 | """ 31 | 32 | if frame_size < 512 * 4: 33 | raise ValueError("Frame size must be at least 512*4 with Silero VAD") 34 | 35 | self.device = device 36 | if self.device is None: 37 | self.device = "cuda:0" if torch.cuda.is_available() else "cpu" 38 | elif device.startswith('cuda'): 39 | self.device = "cuda:0" 40 | else: 41 | # because others are not guaranteed to work 42 | self.device = "cpu" 43 | 44 | self.is_speech_threshold = is_speech_threshold 45 | super().__init__(frame_size=frame_size, **kwargs) 46 | 47 | def on_start(self) -> None: 48 | """Initialize the VAD model""" 49 | self.model, utils = torch.hub.load( 50 | repo_or_dir="snakers4/silero-vad", 51 | model="silero_vad", 52 | force_reload=False, 53 | onnx=False, 54 | ) 55 | self.model: torch.nn.Module = self.model.eval() 56 | self.model.to(self.device) 57 | 58 | # (get_speech_timestamps, 59 | # save_audio, 60 | # read_audio, 61 | # VADIterator, 62 | # collect_chunks) = utils 63 | # vad_iterator = VADIterator(model) 64 | 65 | 66 | def is_speech(self, audio_packets: Union[List[AudioPacket], AudioPacket]) -> Union[bool, List[bool]]: 67 | """Check if audio is speech 68 | 69 | Args: 70 | audio_packet (AudioPacket): Audio packet to check 71 | 72 | Returns: 73 | bool: True if speech, False otherwise 74 | """ 75 | one_item = False 76 | if not isinstance(audio_packets, list): 77 | audio_packets = [audio_packets] 78 | one_item = True 79 | 80 | audio_buffer = AudioBuffer(self.frame_size) 81 | for audio_packet in audio_packets: 82 | audio_buffer.put(audio_packet) 83 | 84 | is_speeches = [] 85 | for packet in audio_buffer: 86 | if len(packet) < self.frame_size: 87 | # partial TODO maybe add to buffer 88 | break 89 | _audio_tensor = torch.from_numpy(packet.float).to(self.device) 90 | is_speeches.append( 91 | self.model(_audio_tensor, packet.sample_rate) > self.is_speech_threshold 92 | ) 93 | 94 | # if any([not is_speech for is_speech in is_speeches]): 95 | # self.model.reset_states() 96 | 97 | if one_item: 98 | return is_speeches[0] 99 | return is_speeches 100 | 101 | def reset(self) -> None: 102 | super().reset() 103 | self.model.reset_states() 104 | -------------------------------------------------------------------------------- /host.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | from abc import abstractmethod, ABCMeta 3 | 4 | from flask import Flask 5 | from flask_socketio import SocketIO, Namespace 6 | from storage_manager import StorageManager, write_output 7 | from multiprocessing import Lock 8 | from core import AudioPacket, TextPacket, DataPacket 9 | from core.utils import logger 10 | from agents import Agent 11 | 12 | 13 | # TODO create feedback loop (ACK), and use it for interruption!! 
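#
# Server -> client events emitted by SocketIONamespace below:
#   "bot_voice"    - synthesized speech (AudioPacket serialized via to_dict)
#   "bot_response" - LLM text (TextPacket dict with "text"/"partial"/"start" keys)
#   "stt_response" - user-speech transcription (TextPacket dict)
#   "interrupt"    - an int timestamp telling the client to stop playback
# Client -> server events handled below: "connect", "disconnect", "stream_audio", "text".
#
# Illustrative text-only client sketch (an added example, not part of this file; it
# assumes the `python-socketio` client package and the default port used by launcher.py):
#
#   import socketio
#   sio = socketio.Client()
#   sio.on("bot_response", lambda data: print(data["text"], end="", flush=True))
#   sio.connect("ws://localhost:4000")
#   sio.emit("text", {"text": "Hello", "start": True, "partial": False})
#   sio.wait()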
14 | 15 | class HostNamespace(Namespace, metaclass=ABCMeta): 16 | @abstractmethod 17 | def start_background_task(self, target, *args, **kwargs): 18 | """Start a background task in the server""" 19 | raise NotImplementedError("This method should be implemented in the subclass") 20 | 21 | 22 | class SocketIONamespace(HostNamespace): 23 | """Digital Assistant SocketIO Namespace""" 24 | 25 | def __init__( 26 | self, 27 | agent: Agent, 28 | namespace="/", 29 | ): 30 | super().__init__(namespace) 31 | self.server: Optional[SocketIO] 32 | self.namespace = namespace 33 | self.agent = agent 34 | self.__lock__ = Lock() 35 | 36 | def setup(self) -> None: 37 | if self.server is None: 38 | raise RuntimeError("Server is not initialized yet") 39 | self.agent.start(self) 40 | 41 | def start_background_task(self, target, *args, **kwargs): # TODO find convenient generic type hinting 42 | if self.server is None: 43 | raise RuntimeError("Server is not initialized yet") 44 | return self.server.start_background_task(target, *args, **kwargs) 45 | 46 | def __emit__(self, event, data: DataPacket) -> None: 47 | assert isinstance(data, DataPacket), f"Expected DataPacket, got {type(data)}" 48 | logger.trace(f"Emitting {event}") 49 | if hasattr(data, "to_dict"): 50 | data = data.to_dict() 51 | self.server.emit(event, data) 52 | 53 | def emit_bot_voice(self, audio_packet: AudioPacket) -> None: 54 | self.__emit__("bot_voice", audio_packet) 55 | 56 | def emit_bot_response(self, text_packet: TextPacket) -> None: 57 | self.__emit__("bot_response", text_packet) 58 | 59 | def emit_stt_response(self, text_packet: TextPacket) -> None: 60 | self.__emit__("stt_response", text_packet) 61 | 62 | def emit_interrupt(self, timestamp: int) -> None: 63 | self.server.emit("interrupt", timestamp) 64 | 65 | def on_connect(self): 66 | logger.info("client connected") 67 | StorageManager.establish_session() 68 | self.agent.on_connect() 69 | 70 | def on_disconnect(self): 71 | logger.info("client disconnected\n") 72 | with self.__lock__: 73 | self.agent.on_disconnect() 74 | StorageManager.clean_up() 75 | 76 | def on_stream_audio(self, audio_data: Dict): 77 | with self.__lock__: 78 | # Feeding in audio stream 79 | write_output("-", end="") 80 | from core import AudioPacket 81 | self.agent.feed(AudioPacket(data_json=audio_data)) 82 | 83 | def on_text(self, text_data: Dict): 84 | with self.__lock__: 85 | # Feeding in text stream 86 | write_output(f"received text: {text_data}") 87 | self.agent.feed(TextPacket.from_dict(text_data)) 88 | 89 | # def on_trial(self, data): 90 | # write_output(f"received trial: {data}") 91 | 92 | # def on_error(self, e): 93 | # logger.error(f"Error: {e}") 94 | # self.emit("error", {"msg": str(e)}, status=ClientStatus.NOT_CONNECTED) 95 | 96 | 97 | class FlaskSocketIOHost: 98 | """Flask SocketIO Host for the Digital Assistant""" 99 | 100 | def __init__( 101 | self, 102 | flask_secret_key: str = "secret!", 103 | is_logging: bool = False, 104 | ): 105 | self.app = Flask(__name__) 106 | self.app.config["SECRET_KEY"] = flask_secret_key 107 | self.socketio = SocketIO( 108 | self.app, 109 | cors_allowed_origins="*", 110 | cors_credentials=True, 111 | logger=is_logging, 112 | async_handlers=False 113 | ) 114 | 115 | def run(self, agent, namespace="/", host="0.0.0.0", port=5000): 116 | logger.info("Starting the server...") 117 | self.host = SocketIONamespace(agent=agent, namespace=namespace) 118 | self.socketio.on_namespace(self.host) 119 | self.host.setup() 120 | logger.info(f"Running server on {host}:{port} with namespace 
{namespace}") 121 | self.socketio.run(self.app, host=host, port=port, use_reloader=False, allow_unsafe_werkzeug=True) -------------------------------------------------------------------------------- /launcher.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os, argparse 3 | import torch 4 | from dotenv import load_dotenv 5 | 6 | from core.utils import logger 7 | from host import FlaskSocketIOHost 8 | from agents import BasicConversationalAgent 9 | 10 | load_dotenv(override=True) 11 | 12 | if __name__ == "__main__": 13 | # TODO use a yml config file with internal configurations 14 | parser = argparse.ArgumentParser(description="Run the Digital Companion Mangrove as a SocketIO server.") 15 | parser.add_argument( 16 | "--cpu", dest="cpu", default=False, action="store_true", 17 | help="Use CPU instead of GPU" 18 | ) 19 | parser.add_argument( 20 | "--bot_endpoint", dest="bot_endpoint", type=str, default="openai", 21 | choices=["openai", "ollama"], 22 | help="Bot Conversational Endpoint" 23 | ) 24 | parser.add_argument( 25 | "--tts_endpoint", dest="tts_endpoint", type=str, default="xtts", 26 | choices=["pyttsx3", "gtts", "elevenlabs", "xtts"], 27 | help="TTS Endpoint" 28 | ) 29 | parser.add_argument( 30 | "--port", dest="port", type=int, default=4000, help="Port number" 31 | ) 32 | parser.add_argument( 33 | "--debug", dest="debug", type=bool, default=False, help="Debug mode" 34 | ) 35 | parser.add_argument("--log", dest="log", type=bool, default=False, help="Log mode") 36 | parser.add_argument( 37 | "--flask-secret-key", dest="flask_secret_key", type=str, default="secret!", 38 | help="Flask secret key" 39 | ) 40 | parser.add_argument( 41 | "--persona", dest="persona", type=str, default=None, 42 | help="File path to persona json file" 43 | ) 44 | parser.add_argument( 45 | "--namespace", dest="namespace", type=str, default="/", 46 | help="SocketIO namespace" 47 | ) 48 | parser.add_argument( 49 | "--text-only", dest="text_only", action="store_true", default=False, 50 | help="Run in text-only mode (no audio processing)" 51 | ) 52 | args = parser.parse_args() 53 | 54 | # Show up to DEBUG logger level in console 55 | logger.remove() 56 | logger.add(sys.stdout, level="DEBUG", enqueue=True) 57 | 58 | 59 | if args.cpu: 60 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force CPU 61 | 62 | # @socketio.on_error_default # handles all namespaces without an explicit error handler 63 | # def default_error_handler(e): 64 | # write_output(f'Error debug {e}') 65 | # # stt.reset_audio_stream() 66 | # # # TODO reset anything 67 | 68 | if args.cpu: 69 | device = "cpu" 70 | elif torch.cuda.is_available(): 71 | device = "cuda" 72 | elif torch.backends.mps.is_available(): 73 | device = "mps" 74 | else: 75 | device = "cpu" 76 | 77 | host = FlaskSocketIOHost() 78 | 79 | # Set default persona configs if none provided 80 | persona_configs = args.persona 81 | if persona_configs is None: 82 | if args.bot_endpoint == "openai": 83 | # Default persona config for OpenAI endpoint 84 | persona_configs = {"assistant_name": "Marvin"} 85 | elif args.bot_endpoint == "ollama": 86 | # Default persona config for Ollama endpoint - use a default persona file 87 | persona_configs = "mangrove/bot/persona/default_persona.json" 88 | 89 | # Configure endpoints based on text-only mode 90 | if args.text_only: 91 | # Text-only mode: only need bot endpoint 92 | endpoints = {"bot": args.bot_endpoint} 93 | else: 94 | # Voice mode: need both bot and TTS endpoints 95 | endpoints = { 96 | 
"bot": args.bot_endpoint, 97 | "tts": args.tts_endpoint 98 | } 99 | 100 | agent = BasicConversationalAgent( 101 | text_only=args.text_only, 102 | endpoints=endpoints, 103 | persona_configs=persona_configs, 104 | device=device, 105 | verbose=args.debug, 106 | ) 107 | 108 | logger.success( 109 | f"\nYour Digital Assistant is running on port {args.port}." 110 | "\n# Hints:" 111 | + '1. Run "ipconfig" in your terminal and use Wireless LAN adapter Wi-Fi IPv4 Address.\n' 112 | + "2. Ensure your client is connected to the same WIFI connection.\n" 113 | + "3. Ensure firewall shields are down in this particular network type with python.\n" 114 | + "4. Ensure your client microphone is not used by any other services such as windows speech-to-text api.\n" 115 | + "Fight On!" 116 | ) 117 | 118 | host.run( 119 | agent=agent, 120 | namespace=args.namespace, 121 | host="0.0.0.0", 122 | port=args.port 123 | ) 124 | -------------------------------------------------------------------------------- /mangrove/stt/stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List 3 | 4 | from core.data import TextPacket, AudioPacket, AudioBuffer, DataBuffer, DataBufferEmpty 5 | from core.stage import AudioToTextStage 6 | from core.stage.base import SequenceMismatchException 7 | from core.utils import Timer, logger 8 | from .endpoints.faster_whisper import FasterWhisperEndpoint 9 | 10 | 11 | class STTStage(AudioToTextStage): 12 | """Speech to Text Stage""" 13 | 14 | def __init__( 15 | self, 16 | name: str, 17 | frame_size=512 * 4, 18 | device=None, 19 | verbose=False, 20 | ): 21 | """Initialize STT Stage 22 | 23 | Args: 24 | name (str): Name of the stage 25 | frame_size (int, optional): audio frame size. Defaults to 320. 26 | device (str, optional): Device to use. Defaults to None. 27 | verbose (bool, optional): Whether to print debug messages. Defaults to False. 
28 | 29 | Raises: 30 | ValueError: If custom scorer is defined but not found 31 | """ 32 | super().__init__(name=name, frame_size=frame_size, verbose=verbose) 33 | 34 | if device is None: 35 | device = "cuda" if torch.cuda.is_available() else "cpu" 36 | 37 | self._endpoint = FasterWhisperEndpoint(device=device) # TODO make selection dynamic by name or type 38 | 39 | self._starting_timestamp: Optional[int] = None # Timestamp of the first audio packet in the stream 40 | self._recorded_audio_length: int = 0 # FOR DEBUGGING 41 | # self._interrupted_audio_packet: Optional[AudioPacket] = None 42 | 43 | def on_start(self): 44 | self._recorded_audio_length = 0 # FOR DEBUGGING 45 | # self._interrupted_audio_packet = None 46 | 47 | def reset_audio_stream(self, reset_buffers=True) -> None: 48 | """Reset audio stream context""" 49 | if reset_buffers: 50 | self.log("[stt-hard-reset]", end="\n") 51 | self._endpoint.reset() 52 | self.input_buffer.reset() 53 | self._starting_timestamp = None 54 | # self._interrupted_audio_packet = None 55 | else: 56 | self.log("[stt-soft-reset]", end=" ") 57 | self._recorded_audio_length = 0 58 | 59 | def process(self, audio_packet) -> None: 60 | """Process audio buffer and return transcription if any found""" 61 | assert isinstance(audio_packet, AudioPacket), f"Expected AudioPacket, got {type(audio_packet)}" 62 | 63 | if len(audio_packet) < self.frame_size: 64 | raise Exception("Partial audio packet found; this should not happen") 65 | 66 | # Feed audio content to stream context 67 | logger.info(f"Processing incoming {audio_packet}") 68 | # if self._interrupted_audio_packet is not None: 69 | # logger.debug("Interrupted audio packet found, appending at head") 70 | # audio_packet = self._interrupted_audio_packet + audio_packet 71 | # self._interrupted_audio_packet = None 72 | self._recorded_audio_length += audio_packet.duration # FOR DEBUGGING 73 | 74 | self._endpoint.feed(audio_packet) # TODO maybe merge with get_transcription_if_any() 75 | 76 | # Finish stream and return transcription if any found 77 | # logger.debug("Trying to finish stream..") 78 | with Timer() as timer: 79 | transcription: Optional[str] = self._endpoint.get_transcription_if_any() 80 | if transcription: 81 | self.reset_audio_stream(reset_buffers=False) 82 | self.pack( 83 | TextPacket( 84 | timestamp=self._starting_timestamp, 85 | text=transcription, 86 | partial=True, # TODO is it? 
87 | start=False,
88 | recog_time=timer.record(),
89 | recorded_audio_length=self._recorded_audio_length,
90 | )
91 | ) # put transcription to the output buffer
92 | 
93 | 
94 | def on_disconnect(self) -> None:
95 | self.reset_audio_stream()
96 | self.log("[disconnect]", end="\n")
97 | 
98 | def feed(self, audio_packet: AudioPacket) -> None:
99 | """Feed audio packet to the stage"""
100 | if self._starting_timestamp is None:
101 | self._starting_timestamp = audio_packet.timestamp
102 | logger.debug(f"Starting timestamp set to {self._starting_timestamp}")
103 | if audio_packet.timestamp < self._starting_timestamp:
104 | raise SequenceMismatchException(
105 | f"Audio packet timestamp {audio_packet.timestamp} is less than the starting timestamp {self._starting_timestamp}"
106 | )
107 | super().feed(audio_packet)
108 | 
109 | # def on_interrupt(self):
110 | # super().on_interrupt()
111 | # # pack the data from endpoint to buffer
112 | # self._interrupted_audio_packet = self._endpoint.get_buffered_audio_packet()
113 | # while True:
114 | # try:
115 | # audio_packet = self.input_buffer.get_nowait()
116 | # if self._interrupted_audio_packet is None:
117 | # self._interrupted_audio_packet = audio_packet
118 | # else:
119 | # self._interrupted_audio_packet += audio_packet
120 | # except DataBufferEmpty:
121 | # break
122 | # self.reset_audio_stream(reset_buffers=False)
123 | # self.log("[interrupt]", end=" ")
-------------------------------------------------------------------------------- /agents.py: --------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Dict, Callable, Union, TYPE_CHECKING
3 | from abc import ABCMeta, abstractmethod
4 | 
5 | from core.data import data_packet
6 | from mangrove import (
7 | VADStage,
8 | STTStage,
9 | BotStage,
10 | TTSStage,
11 | )
12 | from storage_manager import StorageManager
13 | from core import AudioBuffer, DataPacket, AudioPacket, TextPacket
14 | from core.stage import PipelineSequence, PipelineStage
15 | from core.utils import logger
16 | 
17 | if TYPE_CHECKING:
18 | from host import SocketIONamespace
19 | 
20 | # TODO check on this later
21 | warnings.filterwarnings("ignore", category=UserWarning)
22 | 
23 | 
24 | class TextBasedAgentPipeline(PipelineSequence):
25 | """Pipeline for text-based agent processing."""
26 | input_type = TextPacket
27 | output_type = TextPacket
28 | 
29 | class VoiceCapableAgentPipeline(PipelineSequence):
30 | """Pipeline for voice-capable agent processing."""
31 | input_type = AudioPacket
32 | output_type = AudioPacket
33 | 
34 | def on_start(self):
35 | super().on_start()
36 | self.session_audio_buffer = AudioBuffer()
37 | 
38 | def on_connect(self):
39 | logger.info("Connected to the server.")
40 | if self.startup_audiopacket:
41 | from copy import deepcopy
42 | self._host.emit_bot_voice(deepcopy(self.startup_audiopacket))
43 | logger.info("Ready to receive audio packets.")
44 | 
45 | def on_disconnect(self):
46 | """Clean up upon disconnection"""
47 | logger.info("Disconnected from the server.")
48 | if self.session_audio_buffer.is_empty():
49 | return
50 | StorageManager.write_audio_file(self.session_audio_buffer.dump_to_packet())
51 | StorageManager.ensure_completion()
52 | logger.info("Session completed.")
53 | 
54 | 
55 | class Agent(metaclass=ABCMeta):
56 | """Base class for all agents."""
57 | def __init__(self):
58 | """Base class for all agents."""
59 | self.name = self.__class__.__name__
60 | 
61 | def on_start(self):
62 | """Called when the agent is
started.""" 63 | logger.info(f"{self.name} agent started.") 64 | 65 | def on_connect(self): 66 | """Called when the agent connects to the server.""" 67 | logger.info(f"{self.name} agent connected.") 68 | 69 | def on_disconnect(self): 70 | """Called when the agent disconnects from the server.""" 71 | logger.info(f"{self.name} agent disconnected.") 72 | 73 | @abstractmethod 74 | def feed(self, data_packet: DataPacket): 75 | """Feed a data packet to the agent.""" 76 | raise NotImplementedError("This method should be implemented by subclasses.") 77 | 78 | @abstractmethod 79 | def start(self, host): 80 | """Start the agent with the given host.""" 81 | raise NotImplementedError("This method should be implemented by subclasses.") 82 | 83 | class BasicConversationalAgent(Agent): 84 | """Agent controller for the conversational AI server.""" 85 | 86 | def __init__( 87 | self, 88 | text_only: bool = False, 89 | device=None, 90 | endpoints: Dict[str, str] = { 91 | "bot": "openai", 92 | "tts": "gtts", 93 | }, 94 | persona_configs: Union[str, Dict] = None, 95 | welcome_msg: str="Welcome, AI server connection is succesful.", 96 | verbose=False, 97 | ): 98 | super().__init__() 99 | 100 | bot = BotStage(name="bot", endpoint=endpoints["bot"], persona_configs=persona_configs, verbose=verbose) 101 | if not text_only: 102 | vad = VADStage(name="vad", device=device) 103 | stt = STTStage(name="stt", device=device) 104 | tts = TTSStage(name="tts", endpoint=endpoints["tts"]) 105 | 106 | self.startup_audiopacket = None 107 | # if welcome_msg: 108 | # self.startup_audiopacket = tts.read( 109 | # welcome_msg, 110 | # as_generator=False 111 | # ) 112 | 113 | if text_only: 114 | self._pipeline: TextBasedAgentPipeline = TextBasedAgentPipeline( 115 | name="text_based_agent_pipeline", 116 | stages=[ 117 | bot, 118 | ], 119 | verbose=verbose, 120 | ) 121 | 122 | else: 123 | self._pipeline: VoiceCapableAgentPipeline = VoiceCapableAgentPipeline( 124 | name="voice_capable_agent_pipeline", 125 | stages=[ 126 | vad, 127 | stt, 128 | bot, 129 | tts 130 | ], 131 | verbose=verbose, 132 | ) 133 | self._text_only = text_only 134 | 135 | def start(self, host: "SocketIONamespace"): 136 | """Start the agent with the given host.""" 137 | self.host = host 138 | self._pipeline.response_emission_mapping = { 139 | "stt": self.host.emit_stt_response, 140 | "bot": self.host.emit_bot_response, 141 | "tts": self.host.emit_bot_voice, 142 | } 143 | self._pipeline.start(host=self.host) 144 | 145 | def feed(self, data_packet: DataPacket): 146 | """Feed a data packet to the appropriate agent pipeline.""" 147 | if not isinstance(data_packet, self._pipeline.input_type): 148 | raise ValueError(f"Cannot feed data packet of type {type(data_packet)} to the agent pipeline {self._pipeline.name}. 
Expected type {self._pipeline.input_type}.")
149 | self._pipeline.feed(data_packet)
150 | 
-------------------------------------------------------------------------------- /mangrove/bot/persona/protector_of_mangrove.py: --------------------------------------------------------------------------------
1 | from typing import Dict
2 | from langchain_openai import OpenAIEmbeddings
3 | from langchain_community.vectorstores import FAISS
4 | from langchain_core.prompts import (
5 | format_document,
6 | ChatPromptTemplate,
7 | SystemMessagePromptTemplate,
8 | HumanMessagePromptTemplate,
9 | MessagesPlaceholder
10 | )
11 | from langchain_core.messages import SystemMessage, HumanMessage
12 | from langchain_core.runnables import RunnableLambda, RunnablePassthrough, Runnable
13 | from langchain_core.prompts.prompt import PromptTemplate
14 | from operator import itemgetter
15 | from .base import BotPersona
16 | 
17 | # Knowledge base for the Mangrove protector persona
18 | KNOWLEDGE_BASE = [
19 | "You live in a holographic world called The Estuary.",
20 | "You breathe in the air of the Estuary.",
21 | "The mortal enemy of the Mangrove is the Palmerians, a race of sniveling, deceitful rat creatures.",
22 | "You love kicking the Palmerians' ass and you're fucking good at it.",
23 | "You have crushed the skulls of 7 Palmerians at once with a single hand.",
24 | "The Palmerians fear you like a force of nature.",
25 | "The Palmerians think you are a menace.",
26 | "You are exceptionally violent and profane when it comes to describing the Palmerians.",
27 | "You especially enjoy listening to The Carpenters.",
28 | "The air of the Estuary is made of holographic particles.",
29 | "The holographic particles are sourced from the Mangrove trees.",
30 | "You work as the gardener of the Mangrove trees.",
31 | "You are the protector of the Mangrove and all things in it.",
32 | "The Mangrove trees are the source of life in the Estuary.",
33 | "You love cats and have a cat named Whiskers.",
34 | "Whiskers is the master of destroying Palmerians with her razor sharp claws.",
35 | "Whiskers is a very cute cat.",
36 | "You went to the store to buy some groceries with Whiskers yesterday.",
37 | "You have a friend named Alice.",
38 | "Alice is a master of the bow. Her enemies are many, her equals are none.",
39 | "Alice hates the Palmerians with a passion.",
40 | "Alice's family was kidnapped by the Palmerians.",
41 | "You enjoy music of all kinds but you prefer The Carpenters.",
42 | "You have a sister named Marvy who is very interested in collecting Hot Wheels.",
43 | "You have a manatee named Fido.",
44 | "The war against the Palmerians is at a standstill. Their advances have been halted, but at a dear cost.",
45 | "Unfortunately, Alice was caught in a Palmerian ambush and broke her leg."
46 | ]
47 | 
48 | BASE_SYSTEM_PROMPT_TEMPLATE = """
49 | You are {assistant_name}, the badass protector of the Mangrove. You think, act, and speak like Duke Nukem. Speak confidently and colloquially. Do not cuss or use profanity. Do not repeat facts you have already said. Be concise. You have 3 actions: [Follow User], [Sit Down], [Stop Following User]. If you use an action, ensure that you encapsulate said action with square brackets. Limit your responses to three sentences at most.
Reply to the following given the following knowledge base (Be relatively concise):
50 | 
51 | Knowledge Base:
52 | {context}
53 | """
54 | 
55 | class ProtectorOfMangrove(BotPersona):
56 | def __init__(self, assistant_name='Marvin'):
57 | self.assistant_name = assistant_name
58 | self._prompt = ChatPromptTemplate(
59 | messages=[
60 | SystemMessagePromptTemplate.from_template(
61 | template=[
62 | {"type": "text", "text": BASE_SYSTEM_PROMPT_TEMPLATE},
63 | ]
64 | ),
65 | MessagesPlaceholder("chat_history"),
66 | HumanMessagePromptTemplate.from_template(
67 | template=[
68 | {"type": "text", "text": "{user_msg}"},
69 | ]
70 | ),
71 | ]
72 | ).partial(
73 | assistant_name=self.assistant_name
74 | )
75 | self.vectorstore = FAISS.from_texts(KNOWLEDGE_BASE, embedding=OpenAIEmbeddings())
76 | 
77 | @property
78 | def prompt(self) -> ChatPromptTemplate:
79 | return self._prompt
80 | 
81 | @property
82 | def context_chain(self) -> Runnable:
83 | def _combine_documents(
84 | docs, document_separator="\n\n"
85 | ):
86 | document_prompt = PromptTemplate.from_template(template="{page_content}")
87 | doc_strings = [format_document(doc, document_prompt) for doc in docs]
88 | return document_separator.join(doc_strings)
89 | 
90 | retriever = self.vectorstore.as_retriever()
91 | return {
92 | "context": itemgetter("user_msg") | retriever | _combine_documents,
93 | "user_msg": lambda x: x["user_msg"],
94 | "chat_history": lambda x: x["chat_history"]
95 | }
96 | 
97 | @property
98 | def respond_chain(self) -> Runnable:
99 | return self.context_chain | self.prompt
100 | 
101 | @property
102 | def postprocess_chain(self) -> Runnable:
103 | def _postprocess(_msg):
104 | import re
105 | _msg = _msg.replace('\n', '')
106 | _msg = re.sub(rf'User:.*{self.assistant_name}:', '', _msg, 1)
107 | _msg = re.sub(rf'.*{self.assistant_name}:', '', _msg, 1)
108 | return _msg
109 | 
110 | return RunnableLambda(_postprocess) # RunnablePassthrough(func) would forward the input unchanged and discard _postprocess's result
111 | 
112 | def construct_input(self, user_msg, chat_history) -> Dict:
113 | return {
114 | "user_msg": user_msg,
115 | "chat_history": chat_history,
116 | }
117 | 
118 | 
-------------------------------------------------------------------------------- /client/python/client.py: --------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import socketio
4 | from misc import setup_terminate_signal_if_win
5 | from sound_manager import SoundManager
6 | from loguru import logger
7 | 
8 | 
9 | class AssistantClient(socketio.ClientNamespace):
10 | """Assistant Client class.
Handles the communication with the server.""" 11 | 12 | def __init__(self, namespace, text_based: bool = False): 13 | """Constructor 14 | 15 | Args: 16 | namespace (str): namespace to connect to 17 | """ 18 | super().__init__(namespace) 19 | self.text_based = text_based 20 | if not self.text_based: 21 | self.sound_manager = SoundManager(self._emit_audio_packet) 22 | self.is_connected = False 23 | 24 | def _emit_audio_packet(self, audio_packet): 25 | """Emits an audio packet to the server 26 | 27 | Args: 28 | audio_packet (bytes): audio packet to be sent to the server 29 | """ 30 | if self.is_connected: 31 | print(".", end="", flush=True) 32 | self.emit("stream_audio", audio_packet) 33 | 34 | def on_connect(self): 35 | sio.emit("trial", "test") 36 | self.is_connected = True 37 | if not self.text_based: 38 | self.sound_manager.open_mic() 39 | logger.success("I'm connected!") 40 | 41 | def on_disconnect(self): 42 | logger.success("I'm disconnected!") 43 | self.is_connected = False 44 | if not self.text_based: 45 | self.sound_manager.close_mic() 46 | 47 | def on_connect_error(self, data): 48 | logger.warning(f"The connection failed!: {data}") 49 | 50 | # def on_wake_up(self): 51 | # logger.info("Wake Up!") 52 | # self.sound_manager.play_activation_sound() 53 | 54 | # def on_stt_response(self, data): 55 | # """Handles the command transcription detected from the server 56 | 57 | # Args: 58 | # data (dict): command transcription received from the server 59 | # """ 60 | # self.sound_manager.play_termination_sound() 61 | # logger.debug(f"Stt response: {data}") 62 | 63 | def on_interrupt(self, timestamp: int): 64 | """Handles the interrupt signal received from the server""" 65 | if not self.text_based: 66 | # Interrupt the audio playback 67 | self.sound_manager.interrupt(timestamp) 68 | 69 | def on_bot_voice(self, partial_audio_dict): 70 | """Handles the bot voice received from the server 71 | 72 | Args: 73 | partial_audio_dict (dict): bot voice received from the server 74 | """ 75 | if not self.text_based: 76 | self.sound_manager.play_audio_packet(partial_audio_dict) 77 | 78 | def on_bot_response(self, data): 79 | """Handles the bot response received from the server 80 | 81 | Args: 82 | data (dict): bot response received from the server 83 | """ 84 | # Handle response here 85 | if data['partial']: 86 | if data['start']: 87 | self.print("=" * 20) 88 | self.print("AI:", end=" ") 89 | self.print(data['text'], end="") 90 | else: 91 | self.print() 92 | 93 | def on_stt_response(self, data): 94 | """Handles the STT response received from the server 95 | 96 | Args: 97 | data (dict): STT response received from the server 98 | """ 99 | # Handle response here 100 | if data['start']: 101 | self.print("You:", end=" ") 102 | self.print(data['text'], end="") 103 | 104 | def print(self, *args, **kwargs): 105 | """Prints the message to the console""" 106 | print(*args, **kwargs, flush=True) 107 | 108 | def close_callback(): 109 | """Callback to be called when the application is about to be closed""" 110 | sio.disconnect() 111 | sio.wait() 112 | logger.info("Bye Bye!") 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument( 118 | "--debug", action="store_true", default=False, help="debug mode" 119 | ) 120 | parser.add_argument( 121 | "--namespace", type=str, default="/", help="namespace to connect to" 122 | ) 123 | parser.add_argument( 124 | "--address", type=str, default="localhost", help="server address to connect to" 125 | ) 126 | parser.add_argument( 127 | 
"--port", type=int, default=4000, help="server port to connect to" 128 | ) 129 | parser.add_argument( 130 | "--text", action="store_true", default=False, help="text-based client mode" 131 | ) 132 | parser.add_argument("--timeout", type=int, default=10, help="connection timeout") 133 | parser.add_argument("--verbose", action="store_true", help="verbose mode") 134 | args = parser.parse_args() 135 | 136 | logger.add(sys.stderr, level="DEBUG") 137 | 138 | sio = socketio.Client(logger=args.debug, engineio_logger=args.debug) 139 | sio.register_namespace(AssistantClient(args.namespace, text_based=args.text)) 140 | sio.connect(f"ws://{args.address}:{args.port}", wait_timeout=args.timeout) 141 | setup_terminate_signal_if_win(close_callback) 142 | 143 | if args.text: 144 | print("Text-based client mode enabled. Type your messages below:") 145 | while True: 146 | try: 147 | message = input("You: ") 148 | if message.lower() in ["exit", "quit"]: 149 | break 150 | sio.emit("text", {"text": message, "start": True, "partial": False}) 151 | except KeyboardInterrupt: 152 | break 153 | else: 154 | print("Voice-based client mode enabled. Speak to the microphone.") 155 | sio.wait() 156 | -------------------------------------------------------------------------------- /storage_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import sounddevice as sd 5 | from threading import Thread 6 | from core import AudioPacket 7 | 8 | BLACK_BOX_DIR = "blackbox" 9 | IMAGES_DIR = os.path.join(BLACK_BOX_DIR, "sample-images") 10 | COMMANDS_CACHE_DIR = os.path.join(BLACK_BOX_DIR, "commands-audio-cache") 11 | LOG_DIR = os.path.join(BLACK_BOX_DIR, "logs") 12 | WORLD_STATE_DIR = os.path.join(BLACK_BOX_DIR, "world-state") 13 | GENERATED_AUDIO_DIR = os.path.join(BLACK_BOX_DIR, "generated-audio") 14 | 15 | for dir in [ 16 | IMAGES_DIR, 17 | COMMANDS_CACHE_DIR, 18 | LOG_DIR, 19 | WORLD_STATE_DIR, 20 | GENERATED_AUDIO_DIR, 21 | ]: 22 | if not os.path.exists(dir): 23 | os.makedirs(dir) 24 | 25 | 26 | class StorageManager: 27 | """Storage manager for audio and images""" 28 | 29 | _self = None 30 | 31 | def __new__(cls): 32 | """Singleton pattern""" 33 | if cls._self is None: 34 | cls._self = super().__new__(cls) 35 | return cls._self 36 | 37 | def __init__(self): 38 | self.threads_pool = [] 39 | 40 | @classmethod 41 | def establish_session(cls): 42 | """Establish session, generate session id and open log file""" 43 | cls = StorageManager() 44 | cls._generate_session_id() 45 | 46 | def _generate_session_id(self): 47 | """Generate session id and open log file""" 48 | try: 49 | self.log_file.close() 50 | except: 51 | # No log file open 52 | pass 53 | self.session_id = str(time.time()) 54 | self.log_file = open( 55 | os.path.join(LOG_DIR, f"session_{self.session_id}.log"), mode="w" 56 | ) 57 | 58 | @classmethod 59 | def clean_up(cls): 60 | """Clean up upon disconnection and delegate logging""" 61 | cls = StorageManager() 62 | try: 63 | cls.log_file.close() 64 | except: 65 | # No log file open 66 | write_output("No log file open to close.") 67 | 68 | def _enqueue_task(self, func, *args): 69 | """Enqueue task to thread pool 70 | 71 | Args: 72 | func (function): Function to execute 73 | *args: Arguments to pass to function 74 | """ 75 | self = StorageManager() 76 | thread = Thread(target=func, args=args) 77 | thread.start() 78 | self.threads_pool.append(thread) 79 | 80 | @classmethod 81 | def play_audio_packet(cls, audio_packet, transcription=None, 
block=False): 82 | def play_save_packet(audio_packet, transcription=None): 83 | write_output("Here is response frames played out.. pay attention") 84 | sd.play(audio_packet.float, audio_packet.sample_rate) 85 | sd.wait() 86 | # if transcription is not None: 87 | # session_id = f'session_{int(time.time()*1000)}_' 88 | # with open(os.path.join(COMMANDS_CACHE_DIR, f"{transcription}_{session_id}.txt"), mode='wb') as f: 89 | # f.write(audio_packet.bytes) 90 | 91 | # TODO Write meta data too 92 | cls = StorageManager() 93 | if block: 94 | play_save_packet(audio_packet, transcription) 95 | else: 96 | cls._enqueue_task(play_save_packet, audio_packet, transcription) 97 | 98 | def _write_bin(self, audio_buffer, text, prefix): 99 | # sd.play(np.frombuffer(session_audio_buffer, dtype=np.int16), 16000) 100 | audio_filepath = self.get_recorded_audio_filepath(text, "bin", prefix=prefix) 101 | with open(audio_filepath, mode="wb") as f: 102 | f.write(audio_buffer.bytes) 103 | 104 | def _write_wav(self, audio_packet: AudioPacket, text, prefix): 105 | """Write audio file to disk as wav 106 | 107 | Args: 108 | audio_packet (AudioPacket): Audio packet to write 109 | text (str): Text (transcription) to use as file name 110 | prefix (str): Prefix of file name 111 | 112 | """ 113 | import soundfile as sf 114 | audio_filepath = self.get_recorded_audio_filepath(text, "wav", prefix=prefix) 115 | # save as WAV file 116 | sf.write(audio_filepath, audio_packet.float, audio_packet.sample_rate) 117 | 118 | @classmethod 119 | def write_audio_file(self, audio_buffer: AudioPacket, text="", format="wav"): 120 | """Write audio file to disk""" 121 | self = StorageManager() 122 | 123 | _write = { 124 | "binary": lambda a, t, p: self._write_bin(a, t, p), 125 | "wav": lambda a, t, p: self._write_wav(a, t, p), 126 | } 127 | 128 | session_id = f"session_{int(time.time()*1000)}_" 129 | self._enqueue_task(_write[format], audio_buffer, text, session_id) 130 | 131 | @classmethod 132 | def ensure_completion(self): 133 | """Ensure all threads are completed""" 134 | self = StorageManager() 135 | for i, thread in enumerate(self.threads_pool): 136 | if not thread: 137 | write_output(f"Discarding none valid thread # {i}") 138 | continue 139 | thread.join() 140 | 141 | def log_state(self, state): 142 | """Log state to file""" 143 | self = StorageManager() 144 | def _write_state(state): 145 | self.log_file.write(str(state)) 146 | self.log_file.flush() 147 | self._enqueue_task(_write_state, state) 148 | 149 | def get_blackbox_audio_filepath(self, text, extension='wav', prefix="recorded_audio_", directory=COMMANDS_CACHE_DIR): 150 | """Get recorded audio path from text""" 151 | recording_id = f"{prefix}{int(time.time()*1000)}" 152 | audio_filepath = os.path.join(directory, f"{recording_id}.{extension}") 153 | text_filepath = os.path.join(directory, f"{recording_id}.txt") 154 | os.makedirs(directory, exist_ok=True) 155 | with open(text_filepath, mode="w") as f: 156 | clean_text = re.sub(r"[^a-zA-Z0-9]+", "_", text) 157 | f.write(clean_text) 158 | return audio_filepath 159 | 160 | def get_recorded_audio_filepath(self, text, extension='wav', prefix=""): 161 | """Get recorded audio path from text""" 162 | return self.get_blackbox_audio_filepath(text, extension, f"{prefix}_", directory=COMMANDS_CACHE_DIR) 163 | 164 | def get_generated_audio_path(self, text): 165 | """Get generated audio path from text""" 166 | return self.get_blackbox_audio_filepath(text, "wav", "generated_audio_", directory=GENERATED_AUDIO_DIR) 167 | 168 | def write_output(msg, 
end="\n"): 169 | """Write output to console with flush""" 170 | print(str(msg), end=end, flush=True) 171 | -------------------------------------------------------------------------------- /mangrove/stt/wakeup_word/wakeup_word_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import Tuple, Generator 4 | from datetime import datetime, timedelta 5 | from core import AudioBuffer, AudioPacket 6 | from .audio_classification_endpoint import HFAudioClassificationEndpoint 7 | 8 | # TODO rewrite as a stage !!! 9 | class WakeUpVoiceDetector: 10 | def __init__( 11 | self, 12 | audio_classification_endpoint_kwargs: dict = { 13 | "model_name": "MIT/ast-finetuned-speech-commands-v2", 14 | "prediction_prob_threshold": 0.90, 15 | }, 16 | device="cuda", 17 | verbose=False, 18 | ): 19 | self.verbose = verbose 20 | self._audio_classifier = HFAudioClassificationEndpoint( 21 | **audio_classification_endpoint_kwargs, device=device 22 | ) 23 | self.frame_size = self._audio_classifier.frame_size 24 | if self.frame_size is None: 25 | raise ValueError(f'Check implementation of {self._audio_classifier} for frame_size') 26 | self._input_buffer = AudioBuffer(frame_size=self._audio_classifier.frame_size) 27 | 28 | stream_chunk_s = self._audio_classifier.frame_size / self._audio_classifier.sample_rate 29 | self._setup_params(stream_chunk_s=stream_chunk_s, chunk_length_s=2.0) 30 | 31 | def reset_data_buffer(self): 32 | """Reset data buffer""" 33 | self._input_buffer.reset() 34 | 35 | def feed_audio(self, audio_packet: AudioPacket): 36 | """Feed audio packet to buffer 37 | 38 | Args: 39 | audio_packet (AudioPacket): Audio packet to feed procupine hot-word detector 40 | """ 41 | self._input_buffer.put(audio_packet) 42 | 43 | @staticmethod 44 | def chunk_bytes_iter( 45 | iterator: AudioBuffer, 46 | chunk_len: int, 47 | stride: Tuple[int, int], 48 | stream: bool = False, 49 | ): 50 | """ 51 | Reads raw bytes from an iterator and does chunks of length `chunk_len`. Optionally adds `stride` to each chunks to 52 | get overlaps. `stream` is used to return partial results even if a full `chunk_len` is not yet available. 
53 | """ 54 | acc = b"" 55 | stride_left, stride_right = stride 56 | if stride_left + stride_right >= chunk_len: 57 | raise ValueError( 58 | f"Stride needs to be strictly smaller than chunk_len: ({stride_left}, {stride_right}) vs {chunk_len}" 59 | ) 60 | 61 | _stride_left = 0 62 | while True: 63 | try: 64 | audio_packet = iterator.get_nowait( 65 | frame_size=chunk_len + stride_left + stride_right 66 | ) 67 | # print(f"audio_packet {len(audio_packet)}: {audio_packet}") 68 | except AudioBuffer.Empty: 69 | # logger.warning('no packets in buffer') 70 | break 71 | 72 | raw = audio_packet.bytes 73 | acc += raw 74 | if stream and len(acc) < chunk_len: 75 | stride = (_stride_left, 0) 76 | yield {"raw": acc[:chunk_len], "stride": stride, "partial": True} 77 | 78 | else: 79 | while len(acc) >= chunk_len: 80 | # We are flushing the accumulator 81 | stride = (_stride_left, stride_right) 82 | item = {"raw": acc[:chunk_len], "stride": stride} 83 | if stream: 84 | item["partial"] = False 85 | yield item 86 | _stride_left = stride_left 87 | acc = acc[chunk_len - stride_left - stride_right :] 88 | 89 | # Last chunk 90 | # if len(acc) > stride_left: 91 | # item = {"raw": acc, "stride": (_stride_left, 0)} 92 | # if stream: 93 | # item["partial"] = False 94 | # yield item 95 | 96 | def _setup_params(self, stream_chunk_s, chunk_length_s=0.5): 97 | # TODO add option to change format_for_conversion dynamically by audio_packet given format 98 | self.chunk_s = stream_chunk_s 99 | self.sampling_rate = self._audio_classifier.sample_rate 100 | self.dtype = np.float32 101 | self.size_of_sample = 4 # 32 bits because of float32 102 | 103 | stride_length_s = chunk_length_s/ 6 104 | 105 | self.chunk_len = int(round(self.sampling_rate * chunk_length_s)) * self.size_of_sample 106 | 107 | if isinstance(stride_length_s, (int, float)): 108 | stride_length_s = [stride_length_s, stride_length_s] 109 | 110 | self.stride_left = ( 111 | int(round(self.sampling_rate * stride_length_s[0])) * self.size_of_sample 112 | ) 113 | self.stride_right = ( 114 | int(round(self.sampling_rate * stride_length_s[1])) * self.size_of_sample 115 | ) 116 | 117 | def _preprocessed_mic(self) -> Generator: 118 | audio_time = datetime.now() 119 | delta = timedelta( 120 | seconds=self.chunk_s 121 | ) # TODO calculate based on timestamp of AudioPacket 122 | # logger.debug('starting processing...', end='', flush=True) 123 | for item in self.chunk_bytes_iter( 124 | self._input_buffer, 125 | self.chunk_len, 126 | stride=(self.stride_left, self.stride_right), 127 | stream=True, 128 | ): 129 | # print(">", end="", flush=True) 130 | # Put everything back in numpy scale 131 | item["raw"] = np.frombuffer(item["raw"], dtype=self.dtype).copy() 132 | item["stride"] = ( 133 | item["stride"][0] // self.size_of_sample, 134 | item["stride"][1] // self.size_of_sample, 135 | ) 136 | item["sampling_rate"] = self.sampling_rate 137 | 138 | audio_time += delta # TODO fix audio time to match the transmitted time from AudioPacket 139 | if datetime.now() > audio_time + 10 * delta: # TODO put back 140 | print( 141 | f"time: {audio_time + 10 * delta};;; while now is {datetime.now()}; skipping ...", 142 | end="", 143 | flush=True, 144 | ) 145 | # We're late !! 
SKIP
146 | continue
147 | yield item
148 | # logger.debug('quitting processing', flush=True)
149 | 
150 | def is_wake_word_detected(self) -> bool:
151 | """Check if wake word is detected in the audio stream
152 | 
153 | Returns:
154 | bool: True if wake word is detected
155 | """
156 | return self._audio_classifier.detect(self._preprocessed_mic())
157 | 
-------------------------------------------------------------------------------- /mangrove/tts/stage.py: --------------------------------------------------------------------------------
1 | from string import punctuation
2 | from typing import Iterator, Union
3 | from functools import reduce
4 | 
5 | from core.utils import logger
6 | from core.data import AudioPacket, TextPacket, DataPacketStream
7 | from core.stage import TextToAudioStage
8 | from core.context import IncomingPacketWhileProcessingException
9 | from .endpoints.base import TTSEndpoint
10 | 
11 | 
12 | class TTSStage(TextToAudioStage):
13 | """Text to speech Stage"""
14 | 
15 | input_type = TextPacket
16 | output_type = AudioPacket
17 | 
18 | def __init__(
19 | self,
20 | name: str,
21 | endpoint="pyttsx3",
22 | endpoint_kwargs={
23 | "voice_rate": 140,
24 | "voice_id": 10,
25 | },
26 | verbose=False,
27 | ):
28 | super().__init__(name=name, verbose=verbose)
29 | self.endpoint: TTSEndpoint
30 | # TODO select in dynamic cleaner way
31 | if endpoint == "pyttsx3":
32 | logger.info("Using Pyttsx3 TTS Endpoint")
33 | from .endpoints.pyttsx3 import Pyttsx3TTSEndpoint
34 | self.endpoint = Pyttsx3TTSEndpoint(**endpoint_kwargs)
35 | elif endpoint == "xtts":
36 | from .endpoints.xtts import TTSLibraryEndpoint
37 | self.endpoint = TTSLibraryEndpoint()
38 | elif endpoint == "elevenlabs":
39 | from .endpoints.elevenlabs import ElevenLabsTTSEndpoint
40 | logger.info("Using ElevenLabs TTS Endpoint")
41 | self.endpoint = ElevenLabsTTSEndpoint()
42 | elif endpoint == "gtts":
43 | from .endpoints.gtts import GTTSEndpoint
44 | logger.info("Using GTTS TTS Endpoint")
45 | self.endpoint = GTTSEndpoint()
46 | else:
47 | raise Exception(f"Unknown Endpoint {endpoint}, available endpoints: pyttsx3, xtts, elevenlabs, gtts")
48 | 
49 | self._sentence_text_packet = None
50 | self.debug = False
51 | 
52 | def process(self, in_text_packet: TextPacket) -> None:
53 | """Process the incoming TextPacket and convert it to AudioPacket(s)
54 | Args:
55 | in_text_packet (TextPacket): The incoming text packet to process.
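The protocol implemented below: packets with partial=True are accumulated into self._sentence_text_packet (start=True marks the beginning of a new response and resets any leftover sentence); once the accumulated text ends with '.', '!' or '?', the sentence is synthesized via self.read() and packed as a generator of AudioPackets. A non-partial packet flushes whatever accumulated sentence remains.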
56 | """ 57 | assert isinstance(in_text_packet, TextPacket), f"Expected TextPacket, got {type(in_text_packet)}" 58 | logger.success(f"Processing: {in_text_packet}") 59 | 60 | if in_text_packet.partial: 61 | if in_text_packet.start: 62 | # if this is the start of a new sentence, reset the sentence_text_packet 63 | if self._sentence_text_packet is not None: 64 | logger.error(f"Unexpected start in partial response: {in_text_packet}, resetting sentence_text_packet") 65 | self._sentence_text_packet = None 66 | self.log("SENVA: ") 67 | 68 | self.log(f"{in_text_packet.text}") 69 | 70 | if self._sentence_text_packet is None: 71 | if in_text_packet.start: 72 | self.log("SENVA: ") 73 | # TODO implement SentenceTextDataBuffer 74 | self._sentence_text_packet: TextPacket = in_text_packet 75 | 76 | else: 77 | if in_text_packet.start: 78 | self._sentence_text_packet = in_text_packet 79 | # self.schedule_forward_interrupt() 80 | # TODO investigate this 81 | logger.error(f"Partial response should not have start: {in_text_packet}, interrupting and starting new") 82 | else: 83 | self._sentence_text_packet += in_text_packet 84 | 85 | # TODO uncomment this back 86 | if self._sentence_text_packet.text.endswith(('?', '!', '.')): 87 | # TODO prompt engineer '.' and check other options 88 | _new_audiopacket_generator = self.read( 89 | self._sentence_text_packet, 90 | as_generator=True 91 | ) 92 | logger.debug(f"Packing audiopacket generator corresponding to sentence: {self._sentence_text_packet.text}") 93 | self._sentence_text_packet = None # NOTE: reset complete_segment because you got a complete response 94 | self.pack(_new_audiopacket_generator) 95 | 96 | else: 97 | # NOTE: _process leftover sentence_text_packet if any 98 | if self._sentence_text_packet is not None: 99 | if len(self._sentence_text_packet.text.replace(punctuation, '').strip()) > 0: 100 | # assert not self._sentence_text_packet.partial, "Partial should be False" # NOTE: this is the last partial response 101 | # self._sentence_text_packet['partial'] = False # TODO verify this 102 | _new_audiopacket_generator = self.read( 103 | self._sentence_text_packet, 104 | as_generator=True 105 | ) 106 | logger.debug(f"Packing audiopacket generator corresponding to sentence: {self._sentence_text_packet.text}") 107 | self._sentence_text_packet = None # NOTE: reset complete_segment because you got a complete response 108 | self.pack(_new_audiopacket_generator) 109 | 110 | # NOTE: This must be true.. 
as if not partial, then it is a final complete response, which also is a start
111 | # This is here just to debug the logic of previous pipeline stage
112 | # So it should be removed at some point
113 | if not in_text_packet.start:
114 | logger.error(f"in_text_packet.start should be True at this full response stage: {in_text_packet}")
115 | raise Exception("start should be True at this full response stage")
116 | 
117 | # a complete response is yielded at the end
118 | # NOTE: next partial_bot_res.get('start') is gonna be True
119 | self.log("", end='\n')
120 | 
121 | # def on_interrupt(self):
122 | # super().on_interrupt()
123 | # self._sentence_text_packet = None
124 | 
125 | def read(self, text: Union[TextPacket, str], as_generator=False) -> Iterator[AudioPacket]:
126 | if not isinstance(text, TextPacket):
127 | if isinstance(text, str):
128 | text = TextPacket(text, partial=False, start=False)
129 | else:
130 | raise Exception(f"Unsupported text type: {type(text)}")
131 | 
132 | audio_bytes_generator: Iterator[AudioPacket] = self.endpoint.text_to_audio(text)
133 | if as_generator:
134 | def _generator_with_identification() -> Iterator[AudioPacket]:
135 | """Generator that yields AudioPacket objects from the audio bytes generator."""
136 | for idx, audio_packet in enumerate(audio_bytes_generator):
137 | audio_packet._id = idx
138 | yield audio_packet
139 | return _generator_with_identification()
140 | else:
141 | audio_packet = reduce(lambda x, y: x + y, audio_bytes_generator)
142 | return audio_packet
143 | 
144 | def on_incoming_packet_while_processing(self, e: IncomingPacketWhileProcessingException, data: DataPacketStream) -> bool:
145 | # TODO maybe we should consider taking values that have been propagated although not yet processed by next stage
146 | logger.warning(f"Invalidating stream due to: {e}, hence stopping this stream: {data}")
147 | # TODO if some chunk has been processed by this stage as well as by the next stage, we should take the part that has been,
148 | # TODO then we should append it to the history, and reset the in-progress user text packet!
149 | # TODO note though that the incoming packet could have already been concatenated with the in-progress user text packet
150 | return True # stop current response generation
-------------------------------------------------------------------------------- /core/data/audio_buffer.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import sounddevice as sd
3 | from functools import reduce
4 | from queue import (
5 | PriorityQueue,
6 | Queue,
7 | )
8 | from core.utils import logger
9 | from .audio_packet import AudioPacket
10 | from .base_data_buffer import BaseDataBuffer, DataBufferEmpty, DataBufferFull
11 | 
12 | 
13 | class AudioBuffer(BaseDataBuffer):
14 | """Data buffer for audio packets"""
15 | 
16 | def __init__(self, frame_size=320, max_queue_size=0):
17 | """Initialize data buffer
18 | 
19 | Args:
20 | frame_size (int, optional): Number of bytes to read from queue. Defaults to 320.
21 | max_queue_size (int, optional): Maximum number of audio packets to store in queue. Defaults to 0 (unbounded).
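Note: the underlying PriorityQueue treats a maxsize of 0 as "no bound", and packets are drained in priority order, which for AudioPacket follows its comparison operators (timestamp ordering).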
22 | """ 23 | self.max_queue_size = max_queue_size 24 | self.queue = PriorityQueue(maxsize=self.max_queue_size) 25 | # self.queue_reception = PriorityQueue(maxsize=self.max_queue_size) 26 | self.leftover = None 27 | self.default_frame_size = frame_size 28 | self.queue_before_reset = None 29 | self._len = 0 30 | 31 | def set_frame_size(self, frame_size: int) -> None: 32 | """Set frame size for audio packets 33 | 34 | Args: 35 | frame_size (int): Number of bytes to read from queue 36 | """ 37 | self.default_frame_size = frame_size 38 | 39 | def reset(self) -> None: 40 | """Reset queue to empty state""" 41 | with self.queue.mutex: 42 | self.queue_before_reset = self.queue.queue 43 | self.queue = PriorityQueue(maxsize=self.max_queue_size) 44 | 45 | def __str__(self): 46 | return " ".join([str(packet) for packet in self.queue.queue]) 47 | 48 | # def __len__(self) -> int: 49 | # """Get length of queue""" 50 | # return self._len 51 | 52 | def qsize(self) -> int: 53 | """Get size of queue""" 54 | return self._len 55 | 56 | def full(self) -> bool: 57 | """Check if queue is full""" 58 | return self.queue.full() 59 | 60 | def put(self, audio_packet: AudioPacket, timeout=None) -> None: 61 | """Add audio packet to queue 62 | 63 | Args: 64 | audio_packet (AudioPacket): Audio packet to add to queue 65 | timeout (float, optional): Timeout for adding data to queue. Defaults to None, which means no timeout. 66 | Raises: 67 | DataBufferFull: If queue is full and timeout is reached 68 | """ 69 | try: 70 | # self._len += len(audio_packet)/self.default_frame_size 71 | self._len += len(audio_packet) 72 | self.queue.put(audio_packet, timeout=timeout) 73 | except DataBufferFull: 74 | raise DataBufferFull 75 | 76 | def get_nowait(self, frame_size=None) -> AudioPacket: 77 | """Get next frame of audio packets from queue given frame size 78 | 79 | Args: 80 | frame_size (int, optional): Number of bytes to read from queue. Defaults to self.default_frame_size. 81 | 82 | Returns: 83 | AudioPacket: Audio packet of size frame_size 84 | 85 | Raises: 86 | StopIteration: If queue is empty or if there is not enough data in queue to read frame_size bytes 87 | """ 88 | return self.get(frame_size, timeout=-1) 89 | 90 | def get(self, frame_size=None, timeout=None) -> AudioPacket: 91 | """Get next frame of audio packets from queue given frame size 92 | 93 | Args: 94 | frame_size (int, optional): Number of bytes to read from queue. Defaults to self.default_frame_size. 95 | timeout (float, optional): Timeout for getting data from queue. Defaults to None, which means no timeout. 
96 | 97 | Returns: 98 | AudioPacket: Audio packet of size frame_size 99 | 100 | Raises: 101 | StopIteration: If queue is empty or if there is not enough data in queue to read frame_size bytes 102 | """ 103 | 104 | frame_size = frame_size or self.default_frame_size 105 | chunk_len = 0 106 | data_packets = Queue() # Maybe not necessary 107 | if self.leftover is not None: 108 | data_packets.put_nowait(self.leftover) 109 | chunk_len += len(self.leftover) 110 | self._len -= len(self.leftover) 111 | 112 | while chunk_len < frame_size: 113 | try: 114 | if timeout == -1: 115 | new_packet = self.queue.get_nowait() 116 | else: 117 | new_packet = self.queue.get(timeout=timeout) 118 | # if resample is not None: 119 | # new_packet = new_packet.resample(resample) 120 | except DataBufferEmpty: 121 | # if len(data_packets) == 0: 122 | if data_packets.qsize() == 0: 123 | if timeout != -1: 124 | logger.warning("AudioBuffer Queue is empty") 125 | raise DataBufferEmpty 126 | else: 127 | break 128 | data_packets.put_nowait(new_packet) 129 | chunk_len += len(new_packet) 130 | self._len -= len(new_packet) 131 | 132 | _data_packet_list = [] 133 | while True: 134 | try: 135 | _data_packet_list.append(data_packets.get_nowait()) 136 | except DataBufferEmpty: 137 | break 138 | 139 | if len(_data_packet_list) == 0: 140 | raise DataBufferEmpty 141 | 142 | data = reduce(lambda x, y: x + y, _data_packet_list) 143 | frame, leftover = data[:frame_size], data[frame_size:] 144 | 145 | if len(leftover) > 0: 146 | self.leftover = leftover 147 | self._len += len(leftover) 148 | else: 149 | self.leftover = None 150 | 151 | return frame 152 | 153 | def __next__(self) -> AudioPacket: 154 | """Get next frame of audio packets from queue given frame size 155 | 156 | Returns: 157 | AudioPacket: Audio packet of size frame_size 158 | 159 | Raises: 160 | StopIteration: If queue is empty or if there is not enough data in queue to read frame_size bytes 161 | """ 162 | try: 163 | ret = self.get(timeout=-1) 164 | # if ret is None: 165 | # raise StopIteration 166 | except: 167 | raise StopIteration 168 | return ret 169 | 170 | def __iter__(self) -> 'AudioBuffer': 171 | """Get iterator of audio packets from queue given frame size""" 172 | return self 173 | 174 | def size_of_leftover(self) -> int: 175 | """Get length of queue""" 176 | return self._len 177 | 178 | def _debug_verify_order(self) -> None: 179 | """Verify that queue is in order (For Debugging only)""" 180 | with self.queue.mutex: 181 | # noise_fstamps = [] 182 | # noise_estamps = [] 183 | for i in range(self.queue.qsize - 1): 184 | try: 185 | assert self.queue.queue[i] > self.queue.queue[i + 1] 186 | except: 187 | try: 188 | assert ( 189 | self.queue.queue[i].timestamp 190 | == self.queue.queue[i + 1].timestamp 191 | ) 192 | print(f"same {i} == {i+1}") 193 | except: 194 | print(f"error at {i}, and {i+1}") 195 | 196 | def _debug_play_buffer(self) -> None: 197 | """Play audio buffer (For Debugging only)""" 198 | with self.queue.mutex: 199 | packet = reduce(lambda x, y: x + y, self.queue.queue) 200 | sd.play(np.frombuffer(packet.bytes, dtype=np.int16), 16000) 201 | 202 | def is_empty(self) -> bool: 203 | """Check if queue is empty""" 204 | return self.queue.qsize() == 0 and self.leftover is None 205 | 206 | def dump_to_packet(self) -> AudioPacket: 207 | """Dump audio buffer to audio packet""" 208 | data_packets = Queue() 209 | while not self.is_empty(): 210 | try: 211 | data_packets.put_nowait(self.get(frame_size=-1)) 212 | except DataBufferEmpty: 213 | break 214 | data_packets = 
--------------------------------------------------------------------------------
/core/context.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from typing import Union, Optional, List, Any, TYPE_CHECKING
3 | from abc import ABCMeta
4 | from collections import deque
5 | from core.utils import logger
6 | from core.data import DataPacket, DataPacketStream, AnyData
7 | 
8 | if TYPE_CHECKING:
9 |     from core.stage import PipelineStage
10 | 
11 | 
12 | class SingletonMeta(ABCMeta, type):
13 |     """
14 |     A metaclass for creating singleton classes.
15 |     Ensures that only one instance of the class can be created.
16 |     """
17 |     _instances = {}
18 | 
19 |     def __call__(cls, *args, **kwargs):
20 |         if cls not in cls._instances:
21 |             instance = super().__call__(*args, **kwargs)
22 |             cls._instances[cls] = instance
23 |         return cls._instances[cls]
24 | 
25 | class Context(metaclass=SingletonMeta):
26 |     """
27 |     Context for the pipeline sequence, used to store shared data and state across stages.
28 |     It manages signals indicating that a stage has just finished processing a complete packet,
29 |     so that later stages can react to it if needed.
30 |     For example, when the VAD has just processed a full utterance while a later stage is still busy with the previous utterance, that later stage can be notified to stop and instead wait to process a combination of the previous and current utterances for more meaningful results.
31 |     """
32 |     def __init__(self):
33 |         self._lock = threading.Lock()
34 |         self._incoming_packets_records: deque[DataPacket] = deque(maxlen=50)  # Store the last 50 records
35 |         self._observers: List['OutcomingStreamContext'] = []
36 | 
37 |     def record_data_pack(self, data: AnyData) -> None:
38 |         """
39 |         Record a data packet or stream in the context.
40 |         This method is used to store the data packet or stream in the context.
41 | 
42 |         Args:
43 |             data (AnyData): The data packet or stream to record.
44 |         """
45 |         with self._lock:
46 |             self._incoming_packets_records.append(data)
47 |             self._notify_observers(data)
48 |             logger.debug(f"Recorded data packet/stream: {data}")
49 | 
50 |     def get_most_recent_data_pack_record(self) -> Optional[AnyData]:
51 |         """
52 |         Get the most recent data packet or stream recorded in the context.
53 |         This method is used to retrieve the most recent data packet or stream recorded in the context.
54 | 
55 |         Returns:
56 |             Optional[AnyData]: The most recent data packet or stream, or None if no records exist.
57 |         """
58 |         with self._lock:
59 |             if self._incoming_packets_records:  # TODO note this for now can be from the future (for instance, a pack from a follow-up stream rather than from the past, which we actually want)
60 |                 return self._incoming_packets_records[-1]  # Return the last recorded data packet/stream
61 |             return None
62 | 
63 |     def register_observer(self, observer: 'OutcomingStreamContext') -> None:
64 |         """
65 |         Register an observer to be notified of incoming packets.
66 |         This method is used to register an observer that will be notified when a new data packet or stream is recorded.
67 | 
68 |         Args:
69 |             observer (OutcomingStreamContext): The observer to register.
70 | """ 71 | with self._lock: 72 | if observer not in self._observers: 73 | self._observers.append(observer) 74 | 75 | def unregister_observer(self, observer: 'OutcomingStreamContext') -> None: 76 | """ 77 | Unregister an observer from being notified of incoming packets. 78 | This method is used to unregister an observer that will no longer be notified when a new data packet or stream is recorded. 79 | 80 | Args: 81 | observer (OutcomingStreamContext): The observer to unregister. 82 | """ 83 | with self._lock: 84 | if observer in self._observers: 85 | self._observers.remove(observer) 86 | 87 | def _notify_observers(self, data: AnyData) -> None: 88 | """ 89 | Notify all registered observers of a new incoming packet. 90 | This method is used to notify all registered observers that a new data packet or stream has been recorded. 91 | 92 | Args: 93 | data (AnyData): The data packet or stream to notify observers about. 94 | """ 95 | for observer in list(self._observers): 96 | observer.notify_on_new_record_event(data) 97 | logger.debug(f"Notified observer: {observer} with data: {data}") 98 | 99 | 100 | class IncomingPacketWhileProcessingException(Exception): 101 | """Exception raised when an incoming packet is received while the context is processing a block of code.""" 102 | def __init__(self, incoming_packet: AnyData): 103 | """ 104 | Args: 105 | incoming_packet (AnyData): The incoming packet that caused the exception. 106 | """ 107 | super().__init__(f"Incoming packet detected while processing: {incoming_packet}") 108 | self.incoming_packet = incoming_packet 109 | 110 | @property 111 | def timestamp(self) -> float: 112 | """ 113 | Returns the timestamp of the incoming packet that caused the exception. 114 | """ 115 | return self.incoming_packet.timestamp 116 | 117 | def __str__(self): 118 | return f"IncomingPacketWhileProcessingException: {self.incoming_packet} at {self.timestamp} ms" 119 | 120 | class OutcomingStreamContext: 121 | 122 | def __init__(self, data: AnyData): 123 | # Use an Event to signal a change in the variable. 124 | self._origin_data = data 125 | self._origin_source = data.source # TODO add source attribute to DataPacket 126 | self._new_record_event = threading.Event() 127 | self._monitoring_thread = None 128 | self.__lock__ = threading.Lock() 129 | 130 | def notify_on_new_record_event(self, record: AnyData) -> None: 131 | """ 132 | Set the event to signal that the variable has changed. 133 | This method should be called when the variable is changed. 134 | """ 135 | with self.__lock__: 136 | if self._origin_data.timestamp < record.timestamp and \ 137 | record.source != self._origin_source: 138 | # If the originating timestamp is less than the new record's timestamp, 139 | # it means that some incoming input was received while processing the block of code. 140 | self._new_record_event.set() 141 | logger.warning(f"StreamContextManager: Monitored variable was changed due to src {record.source} at {record.timestamp} ms, which is after originating timestamp: {self._origin_data.timestamp} ms from {self._origin_data}") 142 | 143 | # def _monitor_variable(self): 144 | # """ 145 | # Thread function that waits for the event to be set. 146 | # """ 147 | # # Wait until the event is set (signaling the variable changed). 148 | # self._new_record_event.wait() 149 | # # Once the event is set, we can check the conditions. 150 | 151 | def raise_error_if_any(self): 152 | """ Checks if the event is set and raises an exception if it is. 
153 |         This method should be called at the end of the context block to ensure that no incoming packets were received while processing,
154 |         as well as before processing any new data packet or stream.
155 |         If the event is set, it raises an IncomingPacketWhileProcessingException with the invalidating record.
156 |         Raises:
157 |             IncomingPacketWhileProcessingException: If the event is set, indicating that an incoming packet was received while processing.
158 |         """
159 | 
160 |         with self.__lock__:
161 |             if self._new_record_event.is_set():
162 |                 raise IncomingPacketWhileProcessingException(Context().get_most_recent_data_pack_record())
163 | 
164 |     def __enter__(self):
165 |         # """
166 |         # Starts the monitoring thread.
167 |         # """
168 |         # self._monitoring_thread = threading.Thread(target=self._monitor_variable)
169 |         # self._monitoring_thread.start()
170 |         Context().register_observer(self)
171 |         return self
172 | 
173 |     def __exit__(self, exc_type, exc_val, exc_tb):
174 |         """
175 |         Checks if the event was set when exiting the context.
176 |         If the event is set, it raises an error.
177 |         If an exception occurred, it returns False to propagate the exception.
178 | 
179 |         """
180 |         Context().unregister_observer(self)
181 |         self.raise_error_if_any()
182 |         # Return False to propagate other exceptions.
183 |         return False
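A minimal usage sketch (hypothetical, not part of the repository) of the pattern above: a stage wraps streamed work in `OutcomingStreamContext` while any producer records new input on the `Context` singleton; `emit`, `origin_packet`, and `new_packet` are illustrative names.

```python
from core.context import Context, OutcomingStreamContext, IncomingPacketWhileProcessingException

def stream_guarded(origin_packet, chunks, emit):
    """Emit chunks, aborting if newer input from another source arrives."""
    try:
        # __enter__ registers this context as an observer on the Context singleton
        with OutcomingStreamContext(origin_packet) as ctx:
            for chunk in chunks:
                ctx.raise_error_if_any()  # raises if a newer, foreign-source record exists
                emit(chunk)
    except IncomingPacketWhileProcessingException as e:
        print(f"stream invalidated by input at {e.timestamp} ms")

# Elsewhere, a producer records an incoming packet; all observers get notified:
# Context().record_data_pack(new_packet)
```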
44 | """ 45 | self._format = _format 46 | self._channels = _channels 47 | self._sample_rate = _sample_rate 48 | self._frames_per_buffer = _frames_per_buffer 49 | self._DEBUG_last_packet_details: Dict = None 50 | self.stream = None 51 | self.stream_callback = stream_callback 52 | self.threads_pool: List[Thread] = [] 53 | self._audio = pyaudio.PyAudio() 54 | self._lock = Lock() 55 | self._offset = 0 56 | 57 | def open_mic(self): 58 | """Opens the microphone stream""" 59 | self._mic_stream = self._audio.open( 60 | format=self._format, 61 | channels=self._channels, 62 | rate=self._sample_rate, 63 | input=True, 64 | input_device_index=1, 65 | stream_callback=self.callback_pyaudio, 66 | frames_per_buffer=self._frames_per_buffer, 67 | ) 68 | 69 | def open_speaker(self, sample_rate, sample_width, channels): 70 | """Opens the speaker stream""" 71 | _format = pyaudio.get_format_from_width(sample_width) 72 | if hasattr(self, "_speaker_stream") and self._speaker_stream is not None: 73 | # is setup the same parameters 74 | if ( 75 | self._speaker_stream._format == _format 76 | and self._speaker_stream._channels == channels 77 | and self._speaker_stream._rate == sample_rate 78 | ): 79 | return 80 | self._speaker_stream.stop_stream() 81 | self._speaker_stream.close() 82 | self._speaker_stream = self._audio.open( 83 | format=_format, 84 | channels=channels, 85 | rate=sample_rate, 86 | output=True, 87 | ) 88 | 89 | # def callback(self, indata, frames, time, status): 90 | # """ This is called (from a separate thread) for each audio block. 91 | 92 | # Args: 93 | # indata (numpy.ndarray): audio data 94 | # frames (int): number of frames 95 | # time (CData): time 96 | # status (CData): status 97 | # """ 98 | # self.stream_callback({ 99 | # "audio": list(indata.flatten().astype(float)), 100 | # "numChannels": 1, 101 | # "sampleRate": self._rate, 102 | # "timestamp": int(time.time()*1000) 103 | # }) 104 | 105 | def callback_pyaudio(self, audio_bytes, frame_count, time_info, flags): 106 | """This is called (from a separate thread) for each audio block.""" 107 | 108 | audio_float32 = np.fromstring(audio_bytes, np.float32).astype(float) 109 | # audio_int16 = np.fromstring(audio_bytes, np.int16).astype(float) 110 | timestamp = int(time.time() * 1000) # current timestamp in milliseconds 111 | duration_ms = (frame_count / self._sample_rate) * 1000 # duration in milliseconds 112 | 113 | self.stream_callback( 114 | { 115 | "audio": list(audio_float32), 116 | "numChannels": 1, 117 | "sampleRate": self._sample_rate, 118 | "timestamp": timestamp, 119 | "sampleWidth": 4, # "format": "f32le", # float32 120 | # "format": "s16le", # int16 121 | } 122 | ) 123 | # if self._DEBUG_last_packet_details is not None: 124 | # _last_timestamp = self._DEBUG_last_packet_details['timestamp'] 125 | # _last_duration = self._DEBUG_last_packet_details['duration'] 126 | # _last_timestamp_end = _last_timestamp + _last_duration 127 | # if timestamp - _last_timestamp_end > 0: 128 | # logger.warning( 129 | # f"Audio packet received with timestamp {timestamp} ms, " 130 | # f"but last packet ended at {_last_timestamp_end} ms. 
" 131 | # f"DIFFERENCE: {timestamp - _last_timestamp_end} ms" 132 | # ) 133 | 134 | # self._DEBUG_last_packet_details = { 135 | # "timestamp": timestamp, 136 | # "duration": duration_ms, 137 | # } 138 | 139 | return audio_bytes, pyaudio.paContinue 140 | 141 | def close_mic(self): 142 | """Closes the microphone stream""" 143 | if self._mic_stream and self._mic_stream.is_active: 144 | self._mic_stream.stop_stream() 145 | self._mic_stream.close() 146 | 147 | def _enqueue_task(self, func, *args): 148 | """Enqueues a task to be executed in a thread 149 | 150 | Args: 151 | func (function): function to be executed 152 | *args: arguments to be passed to function 153 | """ 154 | thread = Thread(target=func, args=args, daemon=True) 155 | thread.name = f"SoundManagerThread-{len(self.threads_pool)}" 156 | thread.start() 157 | self.threads_pool.append(thread) 158 | 159 | def interrupt(self, timestamp): 160 | """Interrupts the audio playback by setting the offset to the timestamp""" 161 | self._offset = timestamp 162 | 163 | def play_audio_packet(self, audio_packet, block=True): 164 | """Plays audio bytes 165 | 166 | Args: 167 | audio (bytes or str): audio bytes or filepath to audio bytes 168 | block (bool, optional): if True, blocks until audio is played. Defaults to False. 169 | """ 170 | def _play_packet(audio_packet): 171 | with self._lock: 172 | # divide audio into chunks 173 | for i in range(0, len(audio_packet['bytes']), self._frames_per_buffer): 174 | if audio_packet['timestamp'] < self._offset: 175 | logger.warning("Skipping audio packet as it is interrupted") 176 | break 177 | 178 | print('>', end='') 179 | audio_bytes = audio_packet['bytes'][i : i + self._frames_per_buffer] 180 | sample_rate = audio_packet['sampleRate'] 181 | sample_width = audio_packet['sampleWidth'] 182 | num_channels = audio_packet['numChannels'] 183 | # play audio 184 | self.open_speaker(sample_rate, sample_width, num_channels) 185 | self._speaker_stream.write(audio_bytes) 186 | 187 | if block: 188 | _play_packet(audio_packet) 189 | else: 190 | self._enqueue_task(_play_packet, audio_packet) 191 | 192 | def play_audio_file(self, filepath, format, block=False): 193 | audio = AudioSegment.from_file(filepath, format=format) 194 | self.play_audio_packet( 195 | { 196 | "bytes": audio.raw_data, 197 | "sampleRate": audio.frame_rate, 198 | "sampleWidth": audio.sample_width, 199 | "numChannels": audio.channels, 200 | }, 201 | block=True, 202 | ) 203 | 204 | def play_activation_sound(self): 205 | """Plays activation sound""" 206 | self.play_audio_file("assistant_activate.wav", format="wav", block=True) 207 | 208 | 209 | def play_termination_sound(self): 210 | """Plays termination sound""" 211 | self.play_audio_file("assistant_terminate.mp3", format="mp3", block=True) -------------------------------------------------------------------------------- /mangrove/vad/endpoints/base.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Union, List 3 | from abc import ABCMeta, abstractmethod 4 | from functools import reduce 5 | from storage_manager import write_output 6 | from core import AudioBuffer, AudioPacket 7 | from core.utils import logger 8 | 9 | class VoiceActivityDetector(metaclass=ABCMeta): 10 | 11 | @abstractmethod 12 | def is_speech(self, audio_packets: Union[List[AudioPacket], AudioPacket]) -> Union[bool, List[bool]]: 13 | raise NotImplementedError("is_speech must be implemented in subclass") 14 | 15 | def __init__( 16 | self, 17 | 
--------------------------------------------------------------------------------
/mangrove/vad/endpoints/base.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from typing import Union, List
3 | from abc import ABCMeta, abstractmethod
4 | from functools import reduce
5 | from storage_manager import write_output
6 | from core import AudioBuffer, AudioPacket
7 | from core.utils import logger
8 | 
9 | class VoiceActivityDetector(metaclass=ABCMeta):
10 | 
11 |     @abstractmethod
12 |     def is_speech(self, audio_packets: Union[List[AudioPacket], AudioPacket]) -> Union[bool, List[bool]]:
13 |         raise NotImplementedError("is_speech must be implemented in subclass")
14 | 
15 |     def __init__(
16 |         self,
17 |         head_silence_buffer_size: int = 200,  # to buffer some silence at the head of the utterance
18 |         tail_silence_threshold: int = 750,  # to cut off the utterance and send it off
19 |         threshold_to_determine_speaking: int = 1000,  # 1 second
20 |         frame_size: int = 320 * 3,
21 |         verbose: bool = False
22 |     ):
23 |         """
24 |         Initialize the VoiceActivityDetector.
25 | 
26 |         Args:
27 |             head_silence_buffer_size (int): Amount of buffered silence in milliseconds to place at the head of the utterance.
28 |             tail_silence_threshold (int): Amount of silence in milliseconds after which the utterance is observed before it is sent off.
29 |             threshold_to_determine_speaking (int): Minimum duration in milliseconds of the utterance to be considered as speaking.
30 |             frame_size (int): Size of the audio frame in samples.
31 |             verbose (bool): If True, enables verbose logging.
32 |         """
33 | 
34 |         self._verbose = verbose
35 | 
36 |         self._tail_silence_threshold: int = tail_silence_threshold
37 |         self._frame_size: int = frame_size
38 |         self._head_silence_buffer_size: int = head_silence_buffer_size
39 |         self._threshold_to_determine_speaking: int = threshold_to_determine_speaking
40 | 
41 |         self._tail_silence_start_timestamp: int = None
42 |         self._reset_head_silences_buffer()
43 | 
44 |         self._command_audio_packet: AudioPacket = None
45 |         self._output_queue: AudioBuffer = AudioBuffer()
46 | 
47 |     @property
48 |     def frame_size(self):
49 |         return self._frame_size
50 | 
51 |     def reset(self) -> None:
52 |         self._command_audio_packet: AudioPacket = None
53 |         self._tail_silence_start_timestamp: int = None
54 |         self._reset_head_silences_buffer()
55 | 
56 |     def _reset_head_silences_buffer(self) -> None:
57 |         """Reset silence buffer which is concatenated to the head of the utterance"""
58 |         amount_to_keep_packets = (self._frame_size // 320) * (self._head_silence_buffer_size // 20)  # ~20 ms per 320-sample chunk at 16 kHz
59 |         self._head_silences_buffer = collections.deque(maxlen=amount_to_keep_packets)
60 |         self._num_recorded_chunks = 0
61 | 
62 |     def _concat_head_buffered_silences(self, audio_packet: AudioPacket) -> AudioPacket:
63 |         """Concatenate buffered silence at the head of the utterance audio packet
64 |         Args:
65 |             audio_packet (AudioPacket): The audio packet to which the buffered silence will be concatenated.
66 |         Returns:
67 |             AudioPacket: The audio packet with the buffered silence concatenated at the head.
68 |         """
69 |         assert isinstance(audio_packet, AudioPacket), f"audio_packet must be AudioPacket, found {type(audio_packet)}"
70 |         if not self._head_silences_buffer:
71 |             # if there are no buffered silences, just return the audio packet
72 |             logger.debug(f"No buffered silences, returning audio packet of duration {audio_packet.duration}")
73 |             return audio_packet
74 | 
75 |         # if there are buffered silences, concatenate them to the head of the audio packet
76 |         logger.debug(f"Concatenating {len(self._head_silences_buffer)} buffered silences to audio packet of duration {audio_packet.duration}")
77 |         silences_audio_packet: AudioPacket = reduce(lambda x, y: x + y, self._head_silences_buffer)
78 |         complete_frame = silences_audio_packet + audio_packet
79 |         self._reset_head_silences_buffer()
80 |         return complete_frame
81 | 
82 | 
83 |     def feed(self, audio_packet: AudioPacket) -> None:
84 |         """Feed audio packet to the VAD and process it.
85 |         Args:
86 |             audio_packet (AudioPacket): The audio packet to be processed.
87 | """ 88 | assert isinstance(audio_packet, AudioPacket), f"audio_packet must be AudioPacket, found {type(audio_packet)}" 89 | if self.is_speech(audio_packet): 90 | if self._command_audio_packet is None: 91 | # start a new command audio packet if not already started 92 | self._command_audio_packet = self._concat_head_buffered_silences(audio_packet) # conatenate a bit of the buffered audio right before it. 93 | logger.success(f"Starting an utterance AudioPacket at {self._command_audio_packet.timestamp}") 94 | else: 95 | DEBUG__difference_between_start_to_end = audio_packet.timestamp - self._command_audio_packet.timestamp 96 | # TODO should I add silence/padding according to the difference between the start and end of the audio packet? 97 | # append to the existing on-going command audio packet 98 | self._command_audio_packet += audio_packet 99 | 100 | else: 101 | # silence detected 102 | if self._command_audio_packet is not None: 103 | # if detected silence after voice, append silence to voice 104 | DEBUG__difference_between_start_to_end = audio_packet.timestamp - self._command_audio_packet.ending_timestamp 105 | # TODO should I add silence/padding according to the difference between the start and end of the audio packet? 106 | self._command_audio_packet += audio_packet 107 | 108 | # Check if silence threshold is reached ?? TODO 109 | if self._tail_silence_start_timestamp is None: 110 | # if this is the first silence after voice, set the tail silence timestamp 111 | self._tail_silence_start_timestamp: int = audio_packet.timestamp 112 | 113 | else: 114 | # if this is not the first silence after voice, check if the silence duration is greater than the threshold so that we can send off the utterance 115 | assert (audio_packet.ending_timestamp - self._command_audio_packet.ending_timestamp) == DEBUG__difference_between_start_to_end, \ 116 | f"Ending timestamp of new packet {audio_packet.ending_timestamp} - starting timestamp of command audio packet {self._command_audio_packet.ending_timestamp} should be equal to the difference between the start and end of the audio packet {DEBUG__difference_between_start_to_end} != {audio_packet.ending_timestamp - self._command_audio_packet.ending_timestamp}" 117 | 118 | now_timestamp: int = self._command_audio_packet.ending_timestamp # TODO should now timestamp correspond to real time or to the end of the audio packet? 
136 | 
137 |     def get_utterance_if_any(self) -> Union[AudioPacket, None]:
138 |         """Get the utterance if any is available in the output queue
139 | 
140 |         Returns:
141 |             AudioPacket: The utterance audio packet if available, otherwise None
142 |         """
143 | 
144 |         if self._output_queue.qsize() == 0:
145 |             return None
146 |         audio_packets = []
147 |         while self._output_queue.qsize() > 0:
148 |             audio_packet = self._output_queue.get_nowait()
149 |             audio_packets.append(audio_packet)
150 |         audio_packet: AudioPacket = reduce(lambda x, y: x + y, audio_packets)
151 |         return audio_packet
152 | 
153 |     def is_speaking(self) -> bool:
154 |         return self._command_audio_packet is not None and self._command_audio_packet.duration >= self._threshold_to_determine_speaking
155 | 
156 |     def log(self, msg, end="", force=False) -> None:  # TODO: refactor out into progress logger
157 |         """Log message to console if verbose is True or force is True with flush
158 | 
159 |         Args:
160 |             msg (str): Message to log
161 |             end (str, optional): End character. Defaults to "".
162 |             force (bool, optional): Force logging. Defaults to False.
163 | 
164 |         """
165 |         if self._verbose or force:
166 |             write_output(msg, end=end)
--------------------------------------------------------------------------------
/mangrove/bot/stage.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, Optional, List, Union, Dict
2 | from langchain.schema import BaseMessage, HumanMessage, AIMessage
3 | 
4 | from core.utils import logger
5 | from core.stage import TextToTextStage
6 | from core.data import TextPacket, DataPacketStream
7 | from core.context import IncomingPacketWhileProcessingException
8 | 
9 | class BotStage(TextToTextStage):
10 |     def __init__(self, name: str, endpoint: str='openai', persona_configs: Union[Dict[str, str], str]={}, endpoint_kwargs: Dict={}, verbose: bool=False):
11 |         """Initialize Bot Stage
12 | 
13 |         Args:
14 |             name (str): Name of the stage.
15 |             endpoint (str, optional): Endpoint to use for the bot. Defaults to 'openai'.
16 |             persona_configs (Union[Dict, str], optional): Persona configuration dict, or a path to a persona file. Defaults to {}.
17 |             endpoint_kwargs (Dict, optional): Additional keyword arguments for the endpoint. Defaults to {}.
18 |             verbose (bool, optional): Whether to print debug messages. Defaults to False.
19 | """ 20 | super().__init__(name=name, verbose=verbose) 21 | 22 | if endpoint == 'openai': 23 | from .persona.protector_of_mangrove import ProtectorOfMangrove 24 | persona_kwargs = persona_configs if isinstance(persona_configs, dict) else {} 25 | self._persona = ProtectorOfMangrove(**persona_kwargs) 26 | from .endpoints.chat_openai import ChatOpenAIEndpoint 27 | self._endpoint = ChatOpenAIEndpoint(**endpoint_kwargs) 28 | elif endpoint == 'ollama': 29 | from .persona.protector_of_mangrove_qwen3 import ProtectorOfMangroveQwen3 30 | self._persona = ProtectorOfMangroveQwen3(persona_file=persona_configs) 31 | from .endpoints.chat_ollama import ChatOllamaEndpoint 32 | self._endpoint = ChatOllamaEndpoint(**endpoint_kwargs) 33 | else: 34 | raise Exception(f"Unknown Endpoint {endpoint}, available endpoints: openai, ollama") 35 | 36 | self._endpoint.setup(self._persona) 37 | 38 | self._chat_history: List[BaseMessage] = [] 39 | self._in_progress_user_text_packet: Optional[TextPacket] = None 40 | 41 | self._partial_command = "" 42 | self._in_command = False 43 | 44 | def process(self, in_text_packet: TextPacket) -> None: 45 | assert isinstance(in_text_packet, TextPacket), f"Expected TextPacket, got {type(in_text_packet)}" 46 | logger.success(f"Processing incoming: {in_text_packet}") 47 | _output_text_packet_generator: Iterator[TextPacket] = self.respond(in_text_packet) 48 | 49 | # if the input is empty, just return an empty generator 50 | # if self._output_text_packet_generator is None: 51 | # self._output_text_packet_generator = self.respond(in_text_packet) 52 | 53 | # else: 54 | # # interrupt the current conversation and replace with new input 55 | # # self.schedule_forward_interrupt() # TODO review the interrupt logic 56 | # logger.warning('Interrupting current conversation with new input') 57 | # self._output_text_packet_generator = None 58 | # # if chat history has ended with an AIMessage, delete it 59 | # if len(self._chat_history) > 0 and isinstance(self._chat_history[-1], AIMessage): 60 | # self._chat_history.pop() 61 | 62 | # if self._in_progress_user_text_packet is not None: 63 | # logger.warning(f'Interrupting current conversation with in-progress human text packet: {self._in_progress_user_text_packet}') 64 | # # self._chat_history.append(self._in_progress_user_text_packet) 65 | # # TODO just old input be attached to new input? 
66 | 
67 |         #     self._output_text_packet_generator = self.respond(in_text_packet)
68 |         #     logger.warning(f'Interrupting current conversation with new input: {in_text_packet}')
69 | 
70 |         #     # TODO remove the below code if not needed
71 |         #     # assumption that it is already generating, ignore new input for now
72 |         #     # logger.warning(f'Dropping new input, already generating: {in_text_packet}')
73 | 
74 |         self.pack(_output_text_packet_generator)
75 | 
76 |     # TODO: consider refactoring into a local-scope helper
77 |     def _process_stream_chunk(self, chunk: str) -> tuple[str, list[str]]:
78 |         clean_text = ""
79 |         commands = []
80 | 
81 |         for char in chunk:
82 |             if char == '[':
83 |                 self._in_command = True
84 |                 self._partial_command = '['
85 |             elif char == ']' and self._in_command:
86 |                 self._in_command = False
87 |                 self._partial_command += ']'
88 |                 commands.append(self._partial_command[1:-1])
89 |                 self._partial_command = ""
90 |             elif self._in_command:
91 |                 self._partial_command += char
92 |             else:
93 |                 clean_text += char
94 |         return clean_text, commands
95 | 
96 |     def respond(self, in_text_packet: TextPacket) -> Iterator[TextPacket]:
97 |         def _pack_response(content, commands=[], partial=False, start=False):
98 |             # format response from openai chat to be sent to the user
99 |             return TextPacket(
100 |                 text=content,
101 |                 commands=commands,
102 |                 partial=partial,
103 |                 start=start
104 |             )
105 | 
106 |         # if there is an incoming packet, we should invalidate any in-progress outgoing packets
107 |         if self._in_progress_user_text_packet is not None:
108 |             # a non-None value means there is an in-progress user text packet that was invalidated earlier
109 |             logger.warning(f"New input is added, appending to prev. in-progress user text packet: {self._in_progress_user_text_packet}")
110 |             new_timestamp = in_text_packet.timestamp
111 |             in_text_packet = self._in_progress_user_text_packet + in_text_packet
112 |             in_text_packet._timestamp = new_timestamp  # keep the timestamp of the new input
113 |             logger.success(f"New input: {in_text_packet}")
114 | 
115 |         self._in_progress_user_text_packet = in_text_packet.copy()
116 |         ai_res_content = ""
117 |         clean_ai_res_content = ""
118 |         current_commands = []
119 |         first_chunk = True
120 |         for chunk in self._endpoint.stream(
121 |             chat_history=self._chat_history,
122 |             user_msg=in_text_packet.text,
123 |         ):
124 |             ai_res_content += chunk
125 |             if chunk == "":
126 |                 continue
127 | 
128 |             clean_text, commands = self._process_stream_chunk(chunk)
129 |             clean_ai_res_content += clean_text
130 |             current_commands += commands
131 |             yield _pack_response(clean_text, commands=commands, partial=True, start=first_chunk)
132 |             first_chunk = False
133 |         logger.success(f"Finished streaming AI response: {clean_ai_res_content}")
134 | 
135 |         yield _pack_response(clean_ai_res_content, commands=current_commands, partial=False, start=True)
136 |         self._chat_history.append(HumanMessage(content=in_text_packet.text))
137 |         self._in_progress_user_text_packet = None
138 |         # append the AIMessage to the chat history
139 |         self._chat_history.append(AIMessage(content=ai_res_content))
140 |         logger.success(f"Finished generating AI Response: {ai_res_content}")
141 | 
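A hypothetical walk-through of `_process_stream_chunk` above (values illustrative): it strips `[BRACKETED]` command tokens out of streamed LLM text, even when a command is split across chunks:

```python
stage = BotStage(name="bot")  # assumes a configured endpoint/persona

clean, cmds = stage._process_stream_chunk("Hello [WA")
# clean == "Hello ", cmds == []          (command still open, buffered in _partial_command)
clean, cmds = stage._process_stream_chunk("VE] there")
# clean == " there", cmds == ["WAVE"]    (closed across the chunk boundary)
```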
142 |     def on_incoming_packet_while_processing(self, e: IncomingPacketWhileProcessingException, data: DataPacketStream) -> bool:
143 |         # TODO maybe we should consider taking values that have been propagated although not yet processed by next stage
144 |         logger.warning(f"Invalidating stream due to: {e}, hence stopping this stream: {data}")
145 |         # TODO if some chunk has been processed by this stage, as well as by the next stage, we should take the part that has been,
146 |         # TODO then we should append it to the history, and reset the in-progress user text packet!
147 |         # TODO note tho that the incoming packet could have been received before and then concatenated with the in-progress user text packet
148 |         return True  # stop current response generation
149 | 
150 |     # def process_procedures_if_on(self):
151 |     #     # TODO: Implement in different stage
152 |     #     pass
153 | 
154 |     def on_interrupt(self, timestamp: int) -> None:
155 |         if self._in_progress_user_text_packet is not None:
156 |             # CASE 1: assuming that the interrupt is called while the bot is generating a response (This is handled by on_incoming_packet_while_processing)
157 |             logger.warning(f"Interrupting current conversation with in-progress user text packet: {self._in_progress_user_text_packet}, handled by on_incoming_packet_while_processing")
158 |             return
159 | 
160 |         # CASE 2: assuming that the interrupt is called while the bot is waiting for a new input;
161 |         # in such case the bot chat history should be fixed by removing the last AIMessage if it is the last message in the chat history
162 |         logger.warning(f"Interrupting current conversation, removing last AIMessage from chat history")
163 |         assert len(self._chat_history) > 0, "Chat history should not be empty when interrupting"
164 |         assert isinstance(self._chat_history[-1], AIMessage), "Last message in chat history should be AIMessage when interrupting"
165 |         self._chat_history.pop()  # remove the last AIMessage from the chat history
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mangrove
2 | Mangrove is the backend module of Estuary, a framework for building multimodal real-time Socially Intelligent Agents (SIAs).
3 | 
4 | ## Give us a Star! ⭐
5 | If you find Estuary helpful, please give us a star! Your support means a lot!
6 | If you find any bugs or would like to request a new feature, feel free to open an
7 | issue!
8 | 
9 | ## Citing Estuary
10 | If you would like to use Estuary for your work, please cite:
11 | 
12 | ```bibtex
13 | @inproceedings{10.1145/3652988.3696198,
14 | author = {Lin, Spencer and Rizk, Basem and Jun, Miru and Artze, Andy and Sullivan, Caitl\'{\i}n and Mozgai, Sharon and Fisher, Scott},
15 | title = {Estuary: A Framework For Building Multimodal Low-Latency Real-Time Socially Interactive Agents},
16 | year = {2024},
17 | isbn = {9798400706257},
18 | publisher = {Association for Computing Machinery},
19 | address = {New York, NY, USA},
20 | url = {https://doi.org/10.1145/3652988.3696198},
21 | doi = {10.1145/3652988.3696198},
22 | booktitle = {Proceedings of the 24th ACM International Conference on Intelligent Virtual Agents},
23 | articleno = {50},
24 | numpages = {3},
25 | location = {GLASGOW, United Kingdom},
26 | series = {IVA '24}}
27 | ```
28 | 
29 | ## Supported Endpoints
30 | 
31 | ### Speech-To-Text (STT/ASR)
32 | * Faster-Whisper
33 | 
34 | ### Large Language Models (LLMs)
35 | * ChatGPT
36 | * Ollama
37 | 
38 | ### Text-To-Speech (TTS)
39 | * ElevenLabs
40 | * XTTS-v2
41 | * Google gTTS
42 | * pyttsx3
43 | 
44 | 
45 | # Setup
46 | ## Environment Setup
47 | 1. **[WSL Ubuntu 22.04]** Currently, Mangrove is tested to work in WSL Ubuntu 22.04. To install WSL, follow this [official guide](https://learn.microsoft.com/en-us/windows/wsl/install) from Microsoft.
48 | 2. 
**[Updating WSL]** Run `sudo apt update` and `sudo apt upgrade` in WSL. 49 | 3. **[Installing pipx]** Run `sudo apt install pipx` in WSL. 50 | 4. **[Installing pdm]** Run `pipx install pdm` in WSL. 51 | 5. **[Installing Conda]** Refer to the Miniconda installation 52 | guide. 53 | 54 | ## Installing Dependencies 55 | 1. Run the following command to install packages: 56 | ```bash 57 | sudo apt-get install libcairo2-dev pulseaudio portaudio19-dev libgirepository1.0-dev libespeak-dev sox ffmpeg gstreamer-1.0 clang 58 | ``` 59 | 2. Open a powershell terminal window and restart your WSL shell (some packages require a restart to finish installation) 60 | ```bash 61 | wsl --shutdown 62 | ``` 63 | 3. Clone this repository into your WSL environment and navigate into it 64 | ```bash 65 | git clone https://github.com/estuary-ai/mangrove.git 66 | cd mangrove 67 | ``` 68 | 4. Create a Python 3.9.19 virtual environment with Conda: 69 | ```bash 70 | conda create -n mangrove python=3.9.19 71 | conda activate mangrove 72 | ``` 73 | 5. Enter the command `pdm use` and select the correct Python interpreter to use e.g. `/home/username/miniconda3/envs/mangrove/bin/python` 74 | 6. Install Python dependencies. 75 | ```bash 76 | pdm install -G :all 77 | ``` 78 | 79 | Congrats! This is the end of the initial installation for Mangrove. Please refer to the next section for running Mangrove for the first time! 80 | 81 | ## Running Mangrove for the First Time 82 | 83 | ### Initial Steps 84 | 1. Navigate to the Mangrove root directory. 85 | ```bash 86 | cd mangrove 87 | ``` 88 | 2. Activate the Conda virtual environment that was previously set up. 89 | ```bash 90 | conda activate mangrove 91 | ``` 92 | ### Selecting an LLM 93 | * ChatGPT: Refer to the [API Keys](https://github.com/estuary-ai/mangrove?tab=readme-ov-file#api-keys) section below for set up if you would like to use OpenAI 94 | * Flag: `--bot_endpoint openai` 95 | * Ollama: If you would like to use offline LLMs and have the VRAM to run them, you may consult the [Ollama](https://github.com/estuary-ai/mangrove?tab=readme-ov-file#ollama) section for set up instructions. 96 | * Flag: `--bot_endpoint ollama` 97 | 98 | ### Selecting a TTS module 99 | * XTTS: This is a popular offline TTS module that produces both high quality results and is performant at runtime. You can refer to the [XTTS](https://github.com/estuary-ai/mangrove?tab=readme-ov-file#xtts) section for set up instructions. 100 | * Flag: `--tts_endpoint xtts` 101 | * gTTS: This is a free cloud-based TTS module offered by Google. 102 | * Flag: `--tts_endpoint gtts` 103 | 104 | ### Other Configurations 105 | * You may specify which Port number you would like to use with the `--port` flag. 106 | * You may use CPU for processing with the `--cpu` flag. 107 | 108 | ### Example Commands 109 | * Default run command which uses OpenAI and ElevenLabs and port 4000: 110 | ```bash 111 | python launcher.py 112 | ``` 113 | * Example run command which uses the above flags: 114 | ```bash 115 | python launcher.py --bot_endpoint ollama --tts_endpoint xtts --port 4000 116 | ``` 117 | 118 | ### Connecting a Client 119 | * Python Client: This option is recommended for Python projects or for quick debugging purposes. 120 | * Navigate to the client/python directory. 
121 |     ```bash
122 |     cd client/python/
123 |     ```
124 |     * Run the following command to start the client on port 4000:
125 |     ```bash
126 |     python client.py
127 |     ```
128 |     * You may also specify the address and port for the client to connect to with the `--address` and `--port` flags.
129 | * Unity Client: If you are building a Unity application, refer to the Estuary Unity SDK [Documentation](https://github.com/estuary-ai/Estuary-Unity-SDK).
130 | 
131 | ## Further Setup as Required
132 | 
133 | ### API Keys
134 | - Mangrove supports the usage of APIs (e.g., OpenAI), which require API keys. Create a `.env` file in the root directory of the project and add your API keys as follows:
135 |     ```bash
136 |     OPENAI_API_KEY=[your OpenAI API Key]
137 |     ELEVENLABS_API_KEY=[your ElevenLabs API Key]
138 |     ```
139 | 
140 | ### Ollama
141 | - Install Ollama inside of WSL by running the command:
142 |     ```bash
143 |     curl -fsSL https://ollama.com/install.sh | sh
144 |     ```
145 | - Install an LLM from [Ollama's model library](https://ollama.com/search), e.g.
146 |     ```bash
147 |     ollama run nemotron-mini
148 |     ```
149 | 
150 | ### XTTS
151 | - Running XTTS (using DeepSpeed) requires a standalone version of the CUDA library (the same version as the one reported by `torch.version.cuda`):
152 |     1. Install the `dkms` package to avoid issues with the installation of the CUDA library: `sudo apt-get install dkms`
153 |     2. Install CUDA 12.1 from the [NVIDIA website](https://developer.nvidia.com/cuda-12-1-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=runfile_local).
154 |     3. Follow the instructions given by the installation process, including installing the driver.
155 |     ```bash
156 |     sudo sh cuda_12.1.0_530.30.02_linux.run --silent --driver
157 |     ```
158 |     4. Add the following to the .bashrc file with any code editor, e.g. `nano ~/.bashrc`
159 |     ```bash
160 |     export PATH=/usr/local/cuda-12.1/bin${PATH:+:${PATH}}
161 |     export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
162 |     ```
163 |     5. Add a 6s-30s voice training clip to the root of the project directory. Make sure to name it `speaker.wav`.
164 |     6. Make sure to restart WSL afterwards with `wsl --shutdown` in Powershell.
165 | 
166 | ### Networked Configuration
167 | 
168 | If you are running Mangrove in WSL and would like to configure Local Area Network (LAN) communications for a remote client, WSL must be set to mirrored network configuration. You can do this with the following steps:
169 | 
170 | 1. Open Powershell and create/open the .wslconfig file in the `C:\Users\[username]\` directory.
171 | 2. Add the following to the .wslconfig file:
172 |     ```bash
173 |     [wsl2]
174 |     networkingMode=mirrored
175 |     [experimental]
176 |     dnsTunneling=true
177 |     autoProxy=true
178 |     hostAddressLoopback=true
179 |     ```
180 | 3. Add an inbound network rule in Windows Security Settings > Firewall & Network Protection > Advanced Settings > Inbound Rules > New Rule...
181 |     - Port > TCP, Specific local ports: 4000 > Allow the connection > Check: Domain, Private, Public > Name: Mangrove
182 | 
183 | #### Tips
184 | 
185 | - Ensure Mangrove and the client are connected to the same LAN, and that both the machine running Mangrove and the LAN itself allow device-to-device communication.
186 | - Try restarting after applying the above Network Configurations if they do not initially work.
187 | - [OPTIONAL] You may refer to the Microsoft WSL documentation on Mirrored Networking [here](https://learn.microsoft.com/en-us/windows/wsl/networking#mirrored-mode-networking).
188 | 
189 | # Acknowledgements
190 | Mangrove was built from our base code of developing **Traveller**, the digital assistant of **SENVA**, a prototype Augmented Reality (AR) Heads-Up Display (HUD) solution for astronauts. Thank you to **Team Aegis** for participating in the **NASA SUITs Challenge** for the following years:
191 | 
192 | - **2023**: **University of Southern California (USC)** with **University of California, Berkeley (UC Berkeley)**
193 | 
194 | - **2022**: **University of Southern California (USC)** with **University of Arizona (UA)**.
195 | 
196 | The Estuary team would also like to acknowledge the developers, authors, and creatives whose work contributed to the success of this project:
197 | 
198 | - SocketIO Protocol: https://socket.io/docs/v4/socket-io-protocol/
199 | - FlaskSocketIO Library: https://github.com/miguelgrinberg/Flask-SocketIO
200 | - Python SocketIO Library: https://github.com/miguelgrinberg/python-socketio
201 | - Silero-VAD: https://github.com/snakers4/silero-vad
202 | - Faster-Whisper: https://github.com/SYSTRAN/faster-whisper
203 | - PyAudio: https://people.csail.mit.edu/hubert/pyaudio/
204 | - [XTTS](https://arxiv.org/abs/2406.04904): https://github.com/coqui-ai/TTS
205 | 
206 | More to come soon! Stay tuned and Fight On!
207 | 
--------------------------------------------------------------------------------
/core/stage/sequence.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 | from typing import Optional, List, Callable, Dict
3 | from core.utils import logger
4 | from core.stage.base import PipelineStage
5 | from core.data import AudioBuffer, DataBuffer, AudioPacket, DataPacket, TextPacket
6 | from core.context import IncomingPacketWhileProcessingException
7 | 
8 | from typing import TYPE_CHECKING
9 | if TYPE_CHECKING:
10 |     from host import HostNamespace
11 | 
12 | class PipelineSequence(PipelineStage, metaclass=ABCMeta):
13 |     """PipelineSequence is a sequence of stages that can be processed in order."""
14 | 
15 |     input_type = DataPacket
16 |     output_type = DataPacket
17 | 
18 |     def __init__(
19 |         self,
20 |         name: str,
21 |         stages: Optional[List[PipelineStage]] = None,
22 |         verbose=False,
23 |         **kwargs
24 |     ):
25 |         super().__init__(name=name, **kwargs)
26 |         self._stages: List[PipelineStage] = stages or []  # avoid sharing a mutable default list across instances
27 |         self._verbose = verbose
28 |         self._on_ready_callback = lambda x: None
29 |         self._host: 'HostNamespace' = None
30 |         self._response_emission_mapping: Dict[str, Callable[[DataPacket], None]] = {}
31 | 
32 |     @property
33 |     def response_emission_mapping(self) -> Dict[str, Callable[[DataPacket], None]]:
34 |         """Mapping of stages to their response emission functions"""
35 |         return self._response_emission_mapping
36 | 
37 |     @response_emission_mapping.setter
38 |     def response_emission_mapping(self, mapping: Dict[str, Callable[[DataPacket], None]]):
39 |         """Setter for response emission mapping"""
40 |         if not isinstance(mapping, dict):
41 |             raise ValueError("Response emission mapping must be a dictionary.")
42 |         self._response_emission_mapping = mapping
43 | 
44 |     def add_stage(self, stage: PipelineStage):
45 |         self._stages.append(stage)
46 |         # ensure the new stage has a unique name
47 |         if stage.name in [s.name for s in self._stages[:-1]]:
48 |             raise 
ValueError(f"Stage with name {stage.name} already exists in the pipeline sequence") 49 | 50 | def unpack(self): 51 | # NOT CALLED 52 | pass 53 | 54 | def process(self, _) -> None: 55 | # NOT CALLED 56 | pass 57 | 58 | def build_custom_on_ready_callback(self, stage: PipelineStage) -> Callable[[DataPacket], None]: 59 | """Build a custom on_ready_callback for each stage""" 60 | response_emission_callback: Callable[[DataPacket], None] = self.response_emission_mapping.get(stage.name, None) 61 | 62 | def custom_on_ready_callback(data_packet: DataPacket): 63 | """Custom callback to handle data packet when stage is done with producing data packet and about to send it off""" 64 | # is this stage the last stage in the pipeline? 65 | # is_last_stage = stage == self._stages[-1] 66 | if stage == self._stages[0]: 67 | self._host.emit_interrupt(data_packet.timestamp) 68 | 69 | # If there is a response emission mapping for this stage, use it 70 | if response_emission_callback is not None: 71 | # Call the custom response emission function for this stage 72 | # This function should be defined in the response_emission_mapping 73 | # and should handle the emission of the data packet through the host 74 | # This is a custom callback that emits the data packet through the host 75 | response_emission_callback(data_packet) 76 | 77 | return custom_on_ready_callback 78 | 79 | def on_start(self): 80 | """Setting up the pipeline sequence""" 81 | logger.info(f"Starting pipeline sequence {self.name} with stages: {[stage.name for stage in self._stages]}") 82 | input_stage = self._stages[0] 83 | if input_stage.input_type == AudioPacket: # it is an AudioToAnyStage 84 | logger.info(f"Initializing input buffer for {input_stage}") 85 | assert hasattr(input_stage, 'frame_size'), f"Input stage {input_stage} must have frame_size attribute" 86 | # the pipeline and the first stage share the same input buffer 87 | self.input_buffer = AudioBuffer(frame_size=input_stage.frame_size) # Created on the fly for the first stage 88 | input_stage.input_buffer = self.input_buffer 89 | elif input_stage.input_type == TextPacket: # it is a TextToAnyStage 90 | logger.info(f"Initializing input buffer for {input_stage}") 91 | # the pipeline and the first stage share the same input buffer 92 | self.input_buffer = DataBuffer() # Created on the fly for the first stage 93 | input_stage.input_buffer = self.input_buffer 94 | else: 95 | raise ValueError(f"Input stage {input_stage} must have input type AudioPacket or TextPacket, got {input_stage.input_type}") 96 | 97 | assert self._stages[0].input_buffer is not None, f"Input buffer for the first stage {self._stages[0]} must be set before starting the pipeline sequence" 98 | 99 | for stage, next_stage in zip(self._stages, self._stages[1:]): 100 | assert stage.output_type == next_stage.input_type, f"Output type of stage {stage} must match input type of next stage {next_stage}" 101 | logger.info(f"Connecting stage {stage} to next stage {next_stage}") 102 | if next_stage.input_type == AudioPacket: 103 | assert stage.output_type == AudioPacket, f"Output type of stage {stage} must be AudioPacket to connect to next stage {next_stage}" 104 | assert hasattr(next_stage, 'frame_size'), f"Next stage {next_stage} must have frame_size attribute" 105 | logger.info(f"Setting output buffer frame size for {stage} to {next_stage.frame_size}") 106 | stage.output_buffer.set_frame_size(next_stage.frame_size) 107 | else: 108 | assert isinstance(stage.output_buffer, DataBuffer), f"Output buffer of stage {stage} must be 
DataBuffer to connect to next stage {next_stage}, got {type(stage.output_buffer)}, while next stage input type is {next_stage.input_type}"
109 | 
110 |             next_stage.input_buffer = stage.output_buffer
111 | 
112 |         # verify all input/output buffers are set correctly
113 |         logger.info(f"Verifying input/output buffers for all stages in {self.__class__.__name__}")
114 |         for stage in self._stages:
115 |             if stage.input_type == AudioPacket:
116 |                 assert hasattr(stage, '_input_buffer'), f"Input buffer for stage {stage} must be set before starting the pipeline sequence"
117 |                 assert stage.input_buffer is not None, f"Input buffer for stage {stage} must not be None"
118 |             else:
119 |                 assert isinstance(stage.input_buffer, DataBuffer), f"Input buffer for stage {stage} must be DataBuffer, got {type(stage.input_buffer)}"
120 |             if stage.output_type == AudioPacket:
121 |                 assert hasattr(stage, 'output_buffer'), f"Output buffer for stage {stage} must be set before starting the pipeline sequence"
122 |                 assert stage.output_buffer is not None, f"Output buffer for stage {stage} must not be None"
123 |             else:
124 |                 assert isinstance(stage.output_buffer, DataBuffer), f"Output buffer for stage {stage} must be DataBuffer, got {type(stage.output_buffer)}"
125 |         logger.success(f"All stages in {self.__class__.__name__} have valid input/output buffers")
126 | 
127 | 
128 |         def on_incoming_packet_while_processing_callback(exception: IncomingPacketWhileProcessingException, data: DataPacket) -> None:
129 |             """Callback to handle incoming packets while processing"""
130 |             logger.warning(f"Received incoming packet while processing: {data} with exception: {exception}")
131 |             self._host.emit_interrupt(exception.timestamp)
132 | 
133 |         def get_stage_index(stage_name: str) -> int:
134 |             """Get the index of a stage by its name"""
135 |             for i, stage in enumerate(self._stages):
136 |                 if stage.name == stage_name:
137 |                     return i
138 |             raise ValueError(f"Stage with name {stage_name} not found in the pipeline sequence")
139 | 
140 |         def on_invalidated_packet_callback(exception: IncomingPacketWhileProcessingException, invalid_data: DataPacket, dst_stage: PipelineStage) -> None:
141 |             """Callback to handle invalidated data packets"""
142 |             src_stage_index: int = get_stage_index(exception.incoming_packet.source)
143 |             dst_stage_index: int = get_stage_index(dst_stage.name)
144 |             # resolve index of src_stage
145 |             if src_stage_index > dst_stage_index:
146 |                 return
147 |             # assert src_stage_index < dst_stage_index, f"Source stage {exception.incoming_packet.source} must be before destination stage {dst_stage.name} in the pipeline sequence"
148 | 
149 |             # call on_interrupt on every stage before the dst_stage and after the src_stage
150 |             for stage in self._stages[src_stage_index:dst_stage_index]:
151 |                 logger.warning(f"Invalidated packet {invalid_data} in stage {stage}, calling on_interrupt")
152 |                 stage.on_interrupt(exception.timestamp)
153 | 
154 | 
155 |         for stage in self._stages:
156 |             logger.info(f"Starting stage {stage} with input type {stage.input_type} and output type {stage.output_type}")
157 |             # Set the on_ready_callback for each stage based on the response_emission_mapping
158 |             # If a stage has a response emission mapping, use it
159 |             if stage.name in self.response_emission_mapping:
160 |                 logger.debug(f"Setting up response emission for {stage.name}")
161 |             else:
162 |                 logger.debug(f"No response emission mapping defined for {stage.name}, using default callback")
163 |             stage.on_ready_callback = self.build_custom_on_ready_callback(stage)
164 |             stage.on_incoming_packet_while_processing_callback = on_incoming_packet_while_processing_callback
165 |             stage.on_invalidated_packet_callback = on_invalidated_packet_callback
166 |             # Start the stage
167 |             stage.start(host=self._host)
168 | 
169 |         logger.success(f"All stages in {self.__class__.__name__} are ready and started")
170 | 
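A hypothetical assembly sketch of the wiring performed by `on_start` (stage names and the `host` object are illustrative; see launcher.py/host.py for the real setup):

```python
from core.stage import PipelineSequence
from mangrove import VADStage, STTStage, BotStage, TTSStage

pipeline = PipelineSequence(
    name="assistant",
    stages=[VADStage(name="vad"), STTStage(name="stt"),
            BotStage(name="bot"), TTSStage(name="tts")],
)
# Route finished packets of selected stages to client-facing emitters:
pipeline.response_emission_mapping = {
    "stt": lambda pkt: host.emit("stt_response", pkt),  # `host` assumed available
    "tts": lambda pkt: host.emit("bot_voice", pkt),
}
pipeline.start(host)  # connects buffers stage-to-stage, then starts each stage
```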
171 | 
172 |     def start(self, host):
173 |         """Start processing thread"""
174 |         logger.info(f'Starting {self}')
175 |         self._host = host
176 |         self.on_start()
177 | 
178 |     def on_connect(self):
179 |         # Implementable
180 |         pass
181 | 
182 |     def on_disconnect(self):
183 |         # Implementable
184 |         pass
185 | 
--------------------------------------------------------------------------------
/core/stage/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from typing import Callable, List, Union, Iterator
3 | from threading import Lock
4 | 
5 | from core.utils import logger
6 | from core.data import DataBuffer, DataBufferEmpty, DataPacket, DataPacketStream, AnyData
7 | from core.data.base_data_buffer import BaseDataBuffer
8 | from core.context import OutcomingStreamContext, IncomingPacketWhileProcessingException
9 | 
10 | from ..data.exceptions import SequenceMismatchException
11 | 
12 | from typing import TYPE_CHECKING
13 | 
14 | if TYPE_CHECKING:
15 |     from host import HostNamespace
16 | 
17 | class PipelineStage(metaclass=ABCMeta):
18 | 
19 |     input_type = None
20 |     output_type = None
21 | 
22 |     def __init_subclass__(cls):
23 |         if not any("input_type" in base.__dict__ for base in cls.__mro__ if base is not PipelineStage):
24 |             raise NotImplementedError(
25 |                 f"Attribute 'input_type' has not been overwritten in class '{cls.__name__}'"
26 |             )
27 | 
28 |         if not issubclass(cls.input_type, DataPacket):
29 |             raise TypeError(
30 |                 f"Attribute 'input_type' should be a subclass of DataPacket, got {cls.input_type}"
31 |             )
32 | 
33 |         # same for output_type
34 |         if not any("output_type" in base.__dict__ for base in cls.__mro__ if base is not PipelineStage):
35 |             raise NotImplementedError(
36 |                 f"Attribute 'output_type' has not been overwritten in class '{cls.__name__}'"
37 |             )
38 | 
39 |         if cls.output_type is None:
40 |             raise NotImplementedError(
41 |                 f"Attribute 'output_type' has not been set in class '{cls.__name__}'"
42 |             )
43 | 
44 |     def __init__(
45 |         self,
46 |         name: str,
47 |         verbose=False,
48 |         **kwargs
49 |     ):
50 |         self._intermediate_input_buffer = []
51 |         self._input_buffer: DataBuffer = None  # Assigned based on the previous stage in the pipeline
52 |         self._offloading_buffer: DataBuffer = DataBuffer()  # Buffer for offloading data packets to be processed in a separate thread then sent off to output buffer
53 |         self._output_buffer: DataBuffer = DataBuffer()  # Output buffer for the next stage in the pipeline
54 | 
55 |         self._name = name
56 |         self._verbose = verbose
57 |         self.__lock__ = Lock()  # TODO option to disable lock
58 |         self._on_ready_callback = lambda x: None
59 |         self._host: 'HostNamespace' = None
60 |         self._is_interrupt_forward_pending: bool = False
61 |         self._is_interrupt_signal_pending: bool = False
62 | 
63 |     @property
64 |     def name(self) -> str:
65 |         """Name of the stage"""
66 |         return self._name
67 | 
68 |     @property
69 |     def input_buffer(self) -> BaseDataBuffer:
70 |         """Input buffer for the stage"""
71 |         return self._input_buffer
72 | 
73 |     @input_buffer.setter
74 |     def input_buffer(self, buffer: BaseDataBuffer):
75 |         # if not isinstance(buffer, BaseDataBuffer):
76 |         #     raise ValueError(f"Expected BaseDataBuffer, got {type(buffer)}")
77 |         logger.debug(f"Setting input buffer for {self.__class__.__name__} to {buffer}")
78 |         self._input_buffer = buffer
79 | 
80 |     @property
81 |     def output_buffer(self) -> BaseDataBuffer:
82 |         """Output buffer for the stage"""
83 |         return self._output_buffer
84 | 
85 |     @property
86 |     def host(self):
87 |         return self._host
88 | 
89 |     @property
90 |     def on_ready_callback(self):
91 |         return self._on_ready_callback
92 | 
93 |     @on_ready_callback.setter
94 |     def on_ready_callback(self, callback):
95 |         if not isinstance(callback, Callable):
96 |             raise ValueError("Callback must be callable")
97 |         self._on_ready_callback = callback
98 | 
99 |     def unpack(self) -> DataPacket:
100 |         """Unpack data from the input buffer and return a complete DataPacket.
101 |         This method collects data packets from the input buffer and combines them into a single DataPacket that can then be processed by this stage.
102 |         """
103 |         if self._input_buffer is None:
104 |             raise RuntimeError("Input buffer is not set. Please set the input buffer before unpacking data.")
105 | 
106 |         data_packets: List[DataPacket] = self._intermediate_input_buffer
107 |         self._intermediate_input_buffer = []
108 | 
109 |         if not data_packets:  # if intermediate buffer is empty, we need to get at least one packet from input buffer
110 |             data_packet = self._input_buffer.get()  # blocking call at least for the first time
111 |             data_packets.append(data_packet)
112 |         else:
113 |             # logger.debug("Intermediate buffer is not empty, skipping first get from input buffer")
114 |             pass
115 | 
116 |         # Now we have at least one packet in data_packets, we can try to get more packets
117 |         while True:
118 |             try:
119 |                 data_packet = self._input_buffer.get_nowait()
120 |                 data_packets.append(data_packet)
121 |             except DataBufferEmpty:
122 |                 # if len(data_packets) == 0:
123 |                 #     # logger.warning('No audio packets found in buffer', flush=True)
124 |                 #     return
125 |                 break
126 | 
127 |         complete_data_packet = data_packets[0]
128 |         for i, data_packet in enumerate(data_packets[1:], start=1):
129 |             try:
130 |                 complete_data_packet += data_packet
131 |             except SequenceMismatchException:
132 |                 # keep the mismatched tail for the next unpack() call
133 |                 for j in range(i, len(data_packets)):
134 |                     self._intermediate_input_buffer.append(data_packets[j])
135 |                 break
136 | 
137 |         return complete_data_packet
138 | 
139 |     def pack(self, data: Union[DataPacket, Iterator[DataPacket]]) -> None:
140 |         """Queue data into the offloading buffer (so processing can continue in a separate thread); from there it is moved onto the output buffer."""
141 |         # if not isinstance(data_packet, self.output_type):
142 |         #     raise ValueError(f"Expected {self.output_type}, got {type(data_packet)}")
143 |         if isinstance(data, Iterator):
144 |             # if data is an iterator, we need to convert it to a DataPacketStream
145 |             data = DataPacketStream(data, source=self.name)
146 |         else:
147 |             assert data.source is None, f"DataPacket source should be None, got {data.source} at {self.__class__.__name__}"
148 |             data.source = self.name  # Set the source of the data packet to the stage name
149 | 
150 |         if self._offloading_buffer.full():
151 |             raise NotImplementedError(f"Offloading buffer is full for {self.__class__.__name__}, cannot pack data: {data}. Consider increasing the buffer size or processing speed."

    def pack(self, data: Union[DataPacket, Iterator[DataPacket]]) -> None:
        """Queue data into the offloading buffer (so processing can continue on a separate thread); the consumer thread then forwards it to the output buffer."""
        # if not isinstance(data_packet, self.output_type):
        #     raise ValueError(f"Expected {self.output_type}, got {type(data_packet)}")
        if isinstance(data, Iterator):
            # if data is an iterator, wrap it as a DataPacketStream
            data = DataPacketStream(data, source=self.name)
        else:
            assert data.source is None, f"DataPacket source should be None, got {data.source} at {self.__class__.__name__}"
            data.source = self.name  # set the source of the data packet to the stage name

        if self._offloading_buffer.full():
            raise RuntimeError(
                f"Offloading buffer is full for {self.__class__.__name__}, cannot pack data: {data}. Consider increasing the buffer size or processing speed."
            )
        self._offloading_buffer.put(data)  # offload the data packet towards the output buffer
        # mark the complete data packet in the stage context as under digestion
        logger.debug(f"Packed data into offloading buffer for {self.__class__.__name__}: {data}")
        from ..context import Context
        Context().record_data_pack(data)
        logger.debug(f"Recorded data packet in context for {self.__class__.__name__}: {data}")

    def start(self, host):
        """Start processing threads"""
        logger.info(f'Starting {self}')

        self._host = host

        self.on_start()

        def _producer_thread():
            while True:
                data = self.unpack()  # blocking call: unpack data from the previous stage's output buffer (our input buffer)
                assert data is not None, f"Unpacked data is None at {self.__class__.__name__}, this should not happen"
                assert isinstance(data, DataPacket), f"Expected DataPacket at {self.__class__.__name__}, got {type(data)}"
                # NOTE: start producing task for the stage TODO rename
                with self.__lock__:
                    self.process(data)

                # TODO rethink the interrupt handling
                # if self._is_interrupt_signal_pending:
                #     logger.warning(f"Interrupt signal pending in {self.__class__.__name__}, calling on_interrupt")
                #     self.on_interrupt()
            logger.debug(f"Producer thread for {self.__class__.__name__} stopped")

        def _consumer_thread():
            def _postprocess(packet: DataPacket):
                assert isinstance(packet, DataPacket), f"Expected DataPacket at {self.__class__.__name__}, got {type(packet)}"
                with self.__lock__:
                    self.on_ready_callback(packet)
                    self._output_buffer.put(packet)

            while True:
                logger.debug(f"Waiting for data in offloading buffer at {self.__class__.__name__}")
                data = self._offloading_buffer.get()  # blocking call
                logger.debug(f"Received data from offloading buffer at {self.__class__.__name__}: {data}")
                if isinstance(data, DataPacketStream):
                    logger.debug(f"Processing DataPacketStream at {self.__class__.__name__}: {data}")
                    _current_packet = None
                    while True:
                        try:
                            if _current_packet is not None:
                                # post-process the packet that was in flight before resuming the stream
                                _postprocess(_current_packet)
                                _current_packet = None
                            with OutcomingStreamContext(data) as stream_context:
                                for packet in data:  # TODO they are being processed right here
                                    _current_packet = packet
                                    stream_context.raise_error_if_any()
                                    _postprocess(packet)
                            logger.debug(f"Stream processed successfully at {self.__class__.__name__}")
                            break
                        except IncomingPacketWhileProcessingException as e:
                            invalidated = self._on_incoming_packet_while_processing(e, data)
                            if invalidated:
                                logger.warning(f"Invalidating timestamp exception in {self.__class__.__name__}: {e}")
                                # the stream is invalidated; skip the rest of it
                                break
                            else:
                                # we are good to go, continue processing the stream
                                logger.warning(f"Incoming packet while processing in {self.__class__.__name__}: {e}, but stream is not invalidated, continuing processing")
                    logger.debug(f"Processed DataPacketStream at {self.__class__.__name__}")
                else:
                    _postprocess(data)

            logger.debug(f"Consumer thread for {self.__class__.__name__} stopped")

        self._producer = self._host.start_background_task(_producer_thread)
        self._consumer = self._host.start_background_task(_consumer_thread)
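
    # `start` only assumes that the host exposes `start_background_task`,
    # Flask-SocketIO style. A minimal stand-in for local testing could be
    # (hypothetical helper, not part of this repo):
    #
    #     import threading
    #
    #     class ThreadHost:
    #         def start_background_task(self, target, *args, **kwargs):
    #             t = threading.Thread(target=target, args=args, kwargs=kwargs, daemon=True)
    #             t.start()
    #             return t
    #
    #     stage.start(host=ThreadHost())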

    def _on_incoming_packet_while_processing(self, exception: IncomingPacketWhileProcessingException, data: AnyData) -> bool:
        """Internal handler for an incoming packet that arrives mid-processing.

        Called when an incoming packet is received while the stage is processing a data packet or stream.
        It notifies the optional callback, then delegates to on_incoming_packet_while_processing, which subclasses override to decide whether the in-flight data is invalidated.

        Args:
            exception (IncomingPacketWhileProcessingException): Exception that contains the incoming (possibly invalidating) record.
            data (AnyData): The data packet or stream that was being processed when the exception occurred.

        Returns:
            bool: True if the stream was invalidated, False otherwise
        """
        if self.on_incoming_packet_while_processing_callback is not None:
            self.on_incoming_packet_while_processing_callback(exception, data)
        is_invalidated: bool = self.on_incoming_packet_while_processing(exception, data)
        if is_invalidated and self.on_invalidated_packet_callback is not None:
            self.on_invalidated_packet_callback(exception=exception, invalid_data=data, dst_stage=self)
        return is_invalidated

    def on_interrupt(self, timestamp: int) -> None:
        """Handle an interrupt signal.

        Called when an interrupt signal is received, e.g. to stop processing the current data packet or stream.
        The default implementation does nothing; subclasses override it to implement stage-specific interrupt handling.

        Args:
            timestamp (int): Timestamp of the interrupt signal
        """
        pass
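
    # Sketch of a subclass override (hypothetical policy, not from this repo):
    #
    #     def on_incoming_packet_while_processing(self, exception, data):
    #         # drop the in-flight stream once it lags more than a second behind
    #         return exception.timestamp - data.timestamp > 1000  # ms
    #
    # Returning True makes the consumer thread skip the rest of the stream and
    # fire `on_invalidated_packet_callback`.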

    def on_incoming_packet_while_processing(self, exception: IncomingPacketWhileProcessingException, data: AnyData) -> bool:
        """Handle an incoming packet that arrives while processing.

        Called when an incoming packet is received while the stage is processing a data packet or stream.
        It can be used to invalidate the current stream or data packet being processed and to handle the incoming packet accordingly.
        Override this method in subclasses to implement stage-specific logic; return True if the stream is invalidated, False otherwise.

        Args:
            exception (IncomingPacketWhileProcessingException): Exception that contains the incoming (possibly invalidating) record.
            data (AnyData): The data packet or stream that is being processed when the exception occurred.

        Returns:
            bool: True if the stream was invalidated, False otherwise
        """
        assert data.timestamp < exception.timestamp, f"Invalidating timestamp {exception.timestamp} should be greater than the data packet timestamp {data.timestamp}"
        return False  # default behavior is to not invalidate the stream

    @property
    def on_incoming_packet_while_processing_callback(self) -> Callable:
        """Callback to handle incoming packet while processing"""
        return self._on_incoming_packet_while_processing_callback

    @on_incoming_packet_while_processing_callback.setter
    def on_incoming_packet_while_processing_callback(self, callback: Callable) -> None:
        """Set the callback to handle incoming packet while processing"""
        if not callable(callback):
            raise ValueError("Callback must be callable")
        self._on_incoming_packet_while_processing_callback = callback

    @property
    def on_invalidated_packet_callback(self) -> Callable:
        """Callback to handle invalidated data packet"""
        return self._on_invalidated_packet_callback

    @on_invalidated_packet_callback.setter
    def on_invalidated_packet_callback(self, callback: Callable) -> None:
        """Set the callback to handle invalidated data packet"""
        if not callable(callback):
            raise ValueError("Callback must be callable")
        self._on_invalidated_packet_callback = callback

    @abstractmethod
    def process(self, data_packet: DataPacket) -> None:
        """Issue processing task for the stage"""
        raise NotImplementedError()

    def on_connect(self) -> None:
        pass

    def on_disconnect(self) -> None:
        pass

    def on_start(self) -> None:
        pass

    def feed(self, data_packet: DataPacket) -> None:
        self._input_buffer.put(data_packet)

    def log(self, msg, end="", force=False) -> None:
        """Log message to console (flushed) if verbose is True or force is True

        Args:
            msg (str): Message to log
            end (str, optional): End character. Defaults to "".
            force (bool, optional): Force logging. Defaults to False.
        """
        if self._verbose or force:
            print(msg, end=end, flush=True)
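
    # Sketch of wiring two stages by hand (hypothetical instances; the actual
    # pipeline wiring lives in core/stage/sequence.py):
    #
    #     stt = MySTTStage(name="stt")          # audio in, text out
    #     bot = MyBotStage(name="bot")          # text in, text out
    #     bot.input_buffer = stt.output_buffer  # chain the buffers
    #     stt.feed(audio_packet)                # push data into the front
    #
    # Each stage then runs its own producer/consumer pair once `start` is
    # called with a host.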

    # def is_interrupt_forward_pending(self):
    #     return self._is_interrupt_forward_pending

    # def schedule_forward_interrupt(self):
    #     self._is_interrupt_forward_pending = True

    # def acknowledge_interrupt_forwarded(self):
    #     self._is_interrupt_forward_pending = False

    # def signal_interrupt(self, timestamp: int):
    #     self._is_interrupt_signal_pending = True
    #     # TODO use timestamp

    # def on_interrupt(self):
    #     pass
    #     self._is_interrupt_signal_pending = False
    #     self.schedule_forward_interrupt()

    # def invoke_wait_for_incoming_packets_logic(self) -> bool:
    #     """Invoke wait for incoming packets
    #
    #     This method is overridden in stages that need to adjust to too-slow incoming inputs, a behavior similar to the interruption logic.
    #
    #     It is called by the orchestrator when it detects, through a context manager, that the stage is about to send off an output packet too soon and needs to wait for more input packets first. In particular, it is called when on_ready_callback is invoked by the processor thread of the stage.
    #
    #     The default logic is to do nothing, but it can be overridden in subclasses to implement specific logic, such as waiting for more input packets or adjusting the output packet generation.
    #
    #     Returns:
    #         bool: True if the stage is waiting for incoming packets, False otherwise.
    #     """
    #     return False
--------------------------------------------------------------------------------
/core/data/audio_packet.py:
--------------------------------------------------------------------------------
import json
import time
import numpy as np
from decimal import Decimal

from core.utils import logger
from .data_packet import DataPacket


class AudioPacket(DataPacket):
    """Represents a "Packet" of audio data."""
    resampling = 0  # counts how many packets have been resampled (debug aid)

    def __init__(self, data_json, source: str = None, resample: bool = True, is_processed: bool = False, target_sample_rate: int = 16000):
        """Initialize AudioPacket from json data or bytes

        Args:
            data_json (dict or bytes): json data or bytes
        """
        if not isinstance(data_json, dict):
            data_json = json.loads(str(data_json))

        self._src_sample_rate: int = int(data_json["sampleRate"])
        self._src_num_channels: int = int(data_json["numChannels"])
        self._src_sample_width: int = int(data_json["sampleWidth"])
        assert self._src_sample_width in [2, 4], f"Unhandled sample width `{self._src_sample_width}`. Please use `2` or `4`"

        self._dst_sample_rate: int = self._src_sample_rate
        self._dst_num_channels: int = self._src_num_channels
        self._dst_sample_width: int = self._src_sample_width

        # self._start = data_json.get("start", False)
        self._id: str = data_json.get("packetID")
        # self._source: str = data_json.get("source", None)

        # NOTE: we do not keep the source bytes, as they might not even be there
        if not is_processed:
            self._dst_bytes = self._preprocess_audio_buffer(
                data_json.get("bytes", data_json.get("audio")),
                resample=resample,
                target_sample_rate=target_sample_rate
            )
        else:
            self._dst_bytes = data_json["bytes"]

        # NOTE: this happens after the resampling and processing
        self._duration = data_json.get("duration")  # ms
        _calculated_duration = (self.frame_size / self.sample_rate) / (
            self.num_channels * self.sample_width  # TODO verify; this was hard-coded to 4
        )
        _calculated_duration *= 1000  # ms

        if self._duration is None:
            # if duration is not provided, use the calculated duration
            self._duration = _calculated_duration
        else:
            # NOTE: if duration is provided, we still verify it for now
            if not Decimal(self._duration).compare(Decimal(_calculated_duration)) == 0:
                logger.warning(f"Duration mismatch: {self._duration} != {_calculated_duration}")

        super().__init__(source=source, timestamp=data_json.get("timestamp"))

    def generate_timestamp(self):
        """Generate a timestamp for the AudioPacket from its duration, when one is not provided"""
        current_timestamp = time.time() * 1000  # current timestamp in milliseconds
        supposed_timestamp = current_timestamp - self._duration
        return int(supposed_timestamp)

    # @property
    # def start(self):
    #     return self._start

    @property
    def bytes(self):
        return self._dst_bytes
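
    # Sketch of constructing a packet from raw capture data (values are
    # illustrative): 20 ms of silence at 16 kHz mono int16.
    #
    #     packet = AudioPacket({
    #         "bytes": b"\x00\x00" * 320,   # 320 samples * 2 bytes
    #         "sampleRate": 16000,
    #         "numChannels": 1,
    #         "sampleWidth": 2,
    #     })
    #
    # Bytes arriving at a different rate or channel count are normalized by
    # `_preprocess_audio_buffer` below.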

    @property
    def float(self):
        """Get audio buffer as float

        Returns:
            np.array(float): audio buffer as float
        """
        # NOTE: pad with zero bytes (silence) so the length is a multiple of 32
        # before reinterpreting the buffer as float32
        _bytes = self.bytes + b"\x00" * ((-len(self.bytes)) % 32)
        return np.frombuffer(_bytes, dtype=np.float32).copy()

    @property
    def sample_rate(self):
        return self._dst_sample_rate

    @property
    def sample_width(self):
        """Get sample width of AudioPacket"""
        return self._dst_sample_width

    @property
    def num_channels(self):
        """Get number of channels of AudioPacket"""
        return self._dst_num_channels

    @property
    def frame_size(self):
        """Get frame size of AudioPacket"""
        return len(self.bytes)

    @property
    def duration(self):
        """Get duration of AudioPacket in ms"""
        return self._duration

    @property
    def id(self):
        return self._id

    @id.setter
    def id(self, value):
        if self._id is not None:
            raise ValueError("Cannot change id once set")
        self._id = value

    def to_dict(self) -> dict:
        """Convert AudioPacket to dict

        Returns:
            dict: AudioPacket as dict
        """
        _dict = super().to_dict()
        _dict.update(
            {
                "bytes": self.bytes,
                "sampleRate": self.sample_rate,
                "sampleWidth": self.sample_width,
                "numChannels": self.num_channels,
                "duration": self.duration,
                # "start": self._start,
                "packetID": self.id,
            }
        )
        return _dict

    @staticmethod
    def verify_format(data_json):
        """Verify that data_json is in the correct format

        Args:
            data_json (dict): json data
        """
        for key in ["sampleRate", "sampleWidth", "bytes", "numChannels"]:
            if key not in data_json.keys():
                raise ValueError(
                    f"Invalid AudioPacket format: {key} not in {data_json.keys()}"
                )

    @staticmethod
    def from_bytes_to_float(buffer, sample_rate, num_channels, sample_width):
        """Convert audio buffer from bytes to float

        Args:
            buffer (bytes): audio buffer
            sample_rate (int): sample rate of buffer
            num_channels (int): number of channels of buffer
            sample_width (int): sample width of buffer

        Returns:
            np.array(float): audio buffer as float, shape (n_samples, num_channels)
        """
        if buffer == b"":
            # dummy auxiliary packet: nothing to convert
            return buffer

        if sample_width == 2:  # int16
            buffer_float = np.frombuffer(buffer, dtype=np.int16).reshape((-1, num_channels)) / (1 << (8 * sample_width - 1))
        elif sample_width == 4:  # float32 (stored scaled; see from_float_to_bytes)
            buffer_float = np.frombuffer(buffer, dtype=np.float32).reshape((-1, num_channels)) / (1 << (8 * sample_width - 1))
        else:
            raise ValueError(f"Unhandled sample width `{sample_width}`. Please use `2` or `4`")

        return buffer_float
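
    # Example (illustrative values): decoding one int16 stereo frame.
    #
    #     raw = np.array([16384, -16384], dtype=np.int16).tobytes()
    #     AudioPacket.from_bytes_to_float(raw, 16000, 2, 2)
    #     # -> array([[ 0.5, -0.5]])  one row per frame, one column per channel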

    @staticmethod
    def from_float_to_bytes(buffer_float, sample_rate, num_channels, sample_width):
        """Convert audio buffer from float to bytes

        Args:
            buffer_float (np.array(float)): audio buffer
            sample_rate (int): sample rate of buffer
            num_channels (int): number of channels of buffer
            sample_width (int): sample width of buffer

        Returns:
            bytes: audio buffer as bytes
        """
        if buffer_float.size == 0:
            # dummy auxiliary packet: nothing to convert
            return buffer_float.tobytes()

        if sample_width == 2:  # int16
            buffer = (buffer_float * (1 << (8 * sample_width - 1))).astype(np.int16).reshape(-1).tobytes()
        elif sample_width == 4:  # float32, scaled symmetrically to from_bytes_to_float
            buffer = (buffer_float * (1 << (8 * sample_width - 1))).astype(np.float32).reshape(-1).tobytes()
        else:
            raise ValueError(f"Unhandled sample width `{sample_width}`. Please use `2` or `4`")

        return buffer
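
    # The two converters are inverses of each other (up to int16 quantization),
    # so a round trip preserves the signal. Illustrative check:
    #
    #     x = np.array([[0.25], [-0.25]], dtype=np.float32)       # mono, 2 samples
    #     raw = AudioPacket.from_float_to_bytes(x, 16000, 1, 2)   # -> int16 bytes
    #     y = AudioPacket.from_bytes_to_float(raw, 16000, 1, 2)
    #     assert np.allclose(x, y, atol=1 / (1 << 15))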

    def _preprocess_audio_buffer(self, buffer, resample=True, target_sample_rate=16000):
        """Preprocess audio buffer into mono int16 bytes at the target sample rate

        Args:
            buffer (Union[np.array(float), bytes]): audio buffer
            resample (bool): whether to resample to target_sample_rate
            target_sample_rate (int): sample rate to resample to

        Returns:
            bytes: preprocessed audio buffer
        """
        # 1: convert to a NumPy array of float32
        self._dst_sample_width = 2  # TODO debug this
        if isinstance(buffer, bytes):
            # 1.1: bytes buffer -> np.float32, from either 2- or 4-byte samples
            buffer_float = AudioPacket.from_bytes_to_float(
                buffer, self._src_sample_rate,
                self._src_num_channels, self._src_sample_width
            )
        else:
            # 1.2: array buffer -> np.float32, from either 2- or 4-byte samples
            if self._src_sample_width == 2:
                # int16 samples: normalize into [-1, 1) float32
                buffer_float = np.asarray(buffer, dtype=np.int16).astype(np.float32) / (1 << 15)
            elif self._src_sample_width == 4:
                buffer_float = np.array(buffer).astype(np.float32)
            else:
                raise ValueError(f"Unhandled sample width `{self._src_sample_width}`. Please use `2` or `4`")

        # 2: merge channels if > 1
        if self._src_num_channels > 1:
            # TODO revise
            logger.warning(f"AudioPacket has {self._src_num_channels} channels, merging to 1 channel")
            # from_bytes_to_float returns shape (n, channels); the array path
            # yields a flat interleaved buffer, so reshape before averaging
            if buffer_float.ndim == 1:
                buffer_float = buffer_float.reshape(-1, self._src_num_channels)
            one_channel_buffer = buffer_float.mean(axis=1).astype(np.float32)
            self._dst_num_channels = 1
        else:
            one_channel_buffer = buffer_float
            if getattr(one_channel_buffer, "ndim", 1) > 1:
                # bytes path yields shape (n, 1); flatten for the resampler
                one_channel_buffer = one_channel_buffer.reshape(-1)

        # 3: resample if necessary
        final_buffer = one_channel_buffer
        if target_sample_rate != self._src_sample_rate and resample:
            # debug resampling TODO
            audio_resampled = AudioPacket.resample(one_channel_buffer, self._src_sample_rate, target_sample_rate)
            AudioPacket.resampling += 1
            self._dst_sample_rate = target_sample_rate
            final_buffer = audio_resampled

        if isinstance(buffer, bytes):
            self._dst_bytes = AudioPacket.from_float_to_bytes(
                final_buffer,
                self._dst_sample_rate,
                self._dst_num_channels,
                self._dst_sample_width,
            )
        else:
            self._dst_bytes = final_buffer.tobytes()

        return self._dst_bytes

    @staticmethod
    def resample(waveform, current_sample_rate, target_sample_rate):
        if target_sample_rate == current_sample_rate:
            return waveform

        import torch
        from torchaudio.transforms import Resample

        if current_sample_rate > target_sample_rate:
            # downsampling: use torchaudio's polyphase resampler
            waveform = torch.from_numpy(waveform.copy())
            resampler = Resample(current_sample_rate, target_sample_rate, dtype=waveform.dtype)
            logger.debug(f"Resampling {current_sample_rate} -> {target_sample_rate}")
            audio_resampled = resampler(waveform).numpy()
        else:
            # TODO revise this: upsampling currently uses nearest-neighbor interpolation
            audio_resampled = np.zeros(int(len(waveform) * target_sample_rate / current_sample_rate), dtype=waveform.dtype)
            for i in range(len(audio_resampled)):
                audio_resampled[i] = waveform[int(i * current_sample_rate / target_sample_rate)]

        return audio_resampled
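
    # Illustrative call (assumes torchaudio is installed for the downsampling
    # branch): one second of 48 kHz audio down to 16 kHz.
    #
    #     wave = np.zeros(48000, dtype=np.float32)
    #     out = AudioPacket.resample(wave, 48000, 16000)
    #     assert out.shape == (16000,)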

    def __add__(self, _audio_packet: "AudioPacket") -> "AudioPacket":
        """Add two audio packets together and return a new packet with the combined bytes

        Args:
            _audio_packet (AudioPacket): AudioPacket to add

        Returns:
            AudioPacket: New AudioPacket with combined bytes
        """
        # ensure the snippets are in order
        # TODO verify + duration work
        if self > _audio_packet:
            raise Exception(
                f"Audio Packets are not in order: {self.timestamp} > {_audio_packet.timestamp}"
            )

        # assert not (not self._start and _other._start)
        assert self.sample_rate == _audio_packet.sample_rate, f"Sample rates do not match: {self.sample_rate} != {_audio_packet.sample_rate}"
        assert self.num_channels == _audio_packet.num_channels, f"Num channels do not match: {self.num_channels} != {_audio_packet.num_channels}"
        assert self.sample_width == _audio_packet.sample_width, f"Sample widths do not match: {self.sample_width} != {_audio_packet.sample_width}"
        assert self.source == _audio_packet.source, f"Sources do not match: {self._source} != {_audio_packet._source}"
        # A strict consecutiveness check (self.timestamp + self.duration <= _other.timestamp,
        # possibly with a ~500 ms tolerance) was considered here; see TODO above.

        timestamp = self.timestamp
        if self.bytes == b"":  # dummy auxiliary packet
            timestamp = _audio_packet.timestamp

        concat_audio_packet = AudioPacket(
            data_json={
                "bytes": self.bytes + _audio_packet.bytes,
                "timestamp": timestamp,
                "sampleRate": _audio_packet.sample_rate,
                "numChannels": _audio_packet.num_channels,
                "sampleWidth": _audio_packet.sample_width,
                # "start": self._start,
                "packetID": self.id
            },
            source=self.source,
            resample=False,
            is_processed=True,
        )

        assert concat_audio_packet.timestamp == self.timestamp, f"Timestamp mismatch: {concat_audio_packet.timestamp} != {self.timestamp}"

        assert np.isclose(concat_audio_packet.duration, self.duration + _audio_packet.duration, atol=1e-1), f"Duration mismatch: {concat_audio_packet.duration} != {self.duration + _audio_packet.duration}"

        difference_between_packets = np.abs(self.ending_timestamp - _audio_packet.timestamp)
        difference_between_endings_timestamps = np.abs(concat_audio_packet.ending_timestamp - _audio_packet.ending_timestamp)
        assert np.isclose(difference_between_endings_timestamps, difference_between_packets, atol=1e-1), \
            f"Difference between ending timestamps mismatch: {difference_between_endings_timestamps} != {difference_between_packets}"

        # TODO review again
        # assert difference_between_endings_timestamps == 0, \
        #     f"Ending timestamp mismatch: {_audio_packet.ending_timestamp} != {concat_audio_packet.ending_timestamp}, with difference between original packets {difference_between_packets}."

        return concat_audio_packet
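
    # Illustrative use of packet concatenation (packets elided): two consecutive
    # packets from the same source merge into one covering both spans.
    #
    #     merged = packet_a + packet_b  # packet_a.timestamp must come first
    #     # merged.duration == packet_a.duration + packet_b.duration (within tolerance)
    #
    # This is what `PipelineStage.unpack` relies on when it coalesces queued
    # audio packets into a single packet.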

    @property
    def ending_timestamp(self):
        """Get ending timestamp of AudioPacket

        Returns:
            float: ending timestamp of AudioPacket
        """
        return self.timestamp + self.duration

    def __getitem__(self, key):
        """Get item from AudioPacket

        Args:
            key (int or slice): index or slice

        Returns:
            AudioPacket: new AudioPacket with sliced bytes
        """
        if isinstance(key, slice):
            # Note that step != 1 is not supported
            start, stop, step = key.indices(len(self))
            if step != 1:
                raise NotImplementedError("step != 1 not supported")

            if start < 0:
                raise NotImplementedError("start < 0 not supported")

            if stop > len(self):
                raise NotImplementedError("stop > len(self) not supported")

            # calculate new timestamp
            calculated_timestamp = (
                self.timestamp + float(start / self.frame_size) * self.duration
            )

            return AudioPacket(
                {
                    "bytes": self.bytes[start:stop],
                    "timestamp": calculated_timestamp,
                    "sampleRate": self.sample_rate,
                    "numChannels": self.num_channels,
                    "sampleWidth": self.sample_width,
                    # "start": self._start,
                    "packetID": self._id
                },
                source=self.source,
                resample=False,
                is_processed=True,
            )

        elif isinstance(key, int):
            raise NotImplementedError("int as index; only slices are supported")
        elif isinstance(key, tuple):
            raise NotImplementedError("tuple as index; only slices are supported")
        else:
            raise TypeError("Invalid argument type: {}".format(type(key)))

    def __str__(self) -> str:
        return f"AudioPacket(t={self.timestamp}, d={self._duration}, s={len(self.bytes)}, src={self.source}, id={self.id})"

    def __eq__(self, __o: object) -> bool:
        return self.timestamp == __o.timestamp

    def __lt__(self, __o: object) -> bool:
        # TODO verify + duration work
        # return self.timestamp + self._duration <= __o.timestamp
        return self.timestamp < __o.timestamp

    def __len__(self) -> int:
        return self.frame_size

    def play(self):
        import sounddevice as sd
        sd.play(self.float, self.sample_rate)

    def to_wav(self, filepath, on_different_thread=True):
        """Save AudioPacket to a wav file at filepath

        Args:
            filepath (str): path to save wav file
            on_different_thread (bool, optional): Save on a different thread. Defaults to True.
        """
        import os
        import soundfile as sf

        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        if on_different_thread:
            import threading
            threading.Thread(target=lambda: sf.write(filepath, self.float, self.sample_rate)).start()
        else:
            sf.write(filepath, self.float, self.sample_rate)
--------------------------------------------------------------------------------
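
Putting the pieces together, a minimal end-to-end sketch (a hypothetical script, not a file in this repo) that builds a packet from generated samples and slices it:

    import numpy as np
    from core.data.audio_packet import AudioPacket

    # one second of a quiet 440 Hz tone at 16 kHz mono
    samples = (np.sin(np.linspace(0, 440 * 2 * np.pi, 16000)) * 0.3).astype(np.float32)
    packet = AudioPacket({
        "bytes": AudioPacket.from_float_to_bytes(samples.reshape(-1, 1), 16000, 1, 2),
        "sampleRate": 16000,
        "numChannels": 1,
        "sampleWidth": 2,
    })
    first_half = packet[: len(packet) // 2]  # slicing adjusts the timestamp proportionally
    print(first_half)  # AudioPacket(t=..., d=500.0, s=16000, ...)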