├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── http │ ├── client.py │ └── server.py ├── local │ ├── llms.py │ ├── multilingual.py │ └── run.py └── websocket │ ├── client.py │ └── server.py ├── litests ├── __init__.py ├── adapter │ ├── __init__.py │ ├── audiodevice.py │ ├── base.py │ ├── http.py │ └── websocket.py ├── llm │ ├── __init__.py │ ├── base.py │ ├── chatgpt.py │ ├── claude.py │ ├── context_manager │ │ ├── __init__.py │ │ ├── base.py │ │ └── postgres.py │ ├── dify.py │ ├── gemini.py │ └── litellm.py ├── models.py ├── performance_recorder │ ├── __init__.py │ ├── base.py │ ├── postgres.py │ └── sqlite.py ├── pipeline.py ├── stt │ ├── __init__.py │ ├── azure.py │ ├── base.py │ ├── google.py │ └── openai.py ├── tts │ ├── __init__.py │ ├── azure.py │ ├── base.py │ ├── google.py │ ├── openai.py │ ├── speech_gateway.py │ └── voicevox.py ├── vad │ ├── __init__.py │ ├── base.py │ └── standard.py └── voice_recorder │ ├── __init__.py │ ├── azure_storage.py │ ├── base.py │ └── file.py ├── requirements.txt ├── setup.py └── tests ├── llm ├── context_manager │ ├── test_context_manager.py │ └── test_pg_context_manager.py ├── test_chatgpt.py ├── test_chatgpt_azure.py ├── test_claude.py ├── test_dify.py ├── test_gemini.py └── test_litellm.py ├── stt ├── data │ ├── hello.wav │ └── hello_en.wav ├── test_azure.py ├── test_google.py └── test_openai.py ├── test_pipeline.py ├── tts ├── test_azure_tts.py ├── test_google_tts.py ├── test_openai_tts.py ├── test_speech_gateway.py └── test_voicevox.py └── vad └── test_standard.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /examples/http/client.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import json 4 | from uuid import uuid4 5 | import wave 6 | import httpx 7 | import pyaudio 8 | 9 | 10 | class AudioPlayer: 11 | def __init__(self, chunk_size: int = 1024): 12 | self.p = pyaudio.PyAudio() 13 | self.play_stream = None 14 | self.chunk_size = chunk_size 15 | 16 | def play(self, content: bytes): 17 | with wave.open(io.BytesIO(content), "rb") as wf: 18 | if not self.play_stream: 19 | self.play_stream = self.p.open( 20 | format=self.p.get_format_from_width(wf.getsampwidth()), 21 | channels=wf.getnchannels(), 22 | rate=wf.getframerate(), 23 | output=True, 24 | ) 25 | 26 | data = wf.readframes(self.chunk_size) 27 | while True: 28 | data = wf.readframes(self.chunk_size) 29 | if not data: 30 | break 31 | self.play_stream.write(data) 32 | 33 | 34 | audio_player = AudioPlayer() 35 | context_id = str(uuid4()) 36 | 37 | while True: 38 | user_input = input("user: ") 39 | if not user_input.strip(): 40 | continue 41 | 42 | with httpx.stream( 43 | method="post", 44 | url="http://127.0.0.1:8000/chat", 45 | json={ 46 | "type": "start", 47 | "context_id": context_id, 48 | "text": user_input 49 | }, 50 | timeout=60 51 | ) as resp: 52 | resp.raise_for_status() 53 | 54 | for chunk in resp.iter_lines(): 55 | if chunk.startswith("data:"): 56 | chunk_json = json.loads(chunk[5:].strip()) 57 | if chunk_json["type"] == "chunk": 58 | print(f"assistant: {chunk_json['text']}") 59 | if chunk_json["encoded_audio"]: 60 | audio_bytes = base64.b64decode(chunk_json["encoded_audio"]) 61 | audio_player.play(audio_bytes) 62 | -------------------------------------------------------------------------------- /examples/http/server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fastapi import FastAPI 3 | from litests import LiteSTS 4 | from litests.vad import SpeechDetectorDummy 5 | from litests.adapter.http import HttpAdapter 6 | 7 | OPENAI_API_KEY = "YOUR_API_KEY" 8 | GOOGLE_API_KEY = "YOUR_API_KEY" 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | # Create pipeline 13 | sts = LiteSTS( 14 | vad=SpeechDetectorDummy(), # Disable VAD 15 | stt_google_api_key=GOOGLE_API_KEY, 16 | llm_openai_api_key=OPENAI_API_KEY, 17 | debug=True 18 | ) 19 | 20 | # Set HTTP adapter 21 | adapter = HttpAdapter(sts) 22 | router = adapter.get_api_router() 23 | 24 | # Start HTTP server 25 | app = FastAPI() 26 | app.include_router(router) 27 | -------------------------------------------------------------------------------- /examples/local/llms.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from litests import LiteSTS 3 | from litests.llm.chatgpt import ChatGPTService 4 | from litests.llm.gemini import GeminiService 5 | from litests.llm.claude import ClaudeService 6 | from litests.llm.dify import DifyService 7 | from litests.llm.litellm import LiteLLMService 8 | from litests.adapter.audiodevice import AudioDeviceAdapter 9 | 10 | GOOGLE_API_KEY = "GOOGLE_API_KEY" 11 | 12 | OPENAI_API_KEY = "OPENAI_API_KEY" 13 | GEMINI_API_KEY = "GEMINI_API_KEY" 14 | CLAUDE_API_KEY = "CLAUDE_API_KEY" 15 | DIFY_API_KEY = "DIFY_API_KEY" 16 | OTHERLLM_API_KEY = "OTHERLLM_API_KEY" 17 | 18 | 19 | chatgpt = ChatGPTService( 20 | openai_api_key=OPENAI_API_KEY 21 | ) 22 | 23 | gemini = GeminiService( 24 
| gemini_api_key=GEMINI_API_KEY 25 | ) 26 | 27 | claude = ClaudeService( 28 | anthropic_api_key=CLAUDE_API_KEY 29 | ) 30 | 31 | dify = DifyService( 32 | api_key=DIFY_API_KEY, 33 | user="dify_user", 34 | base_url="your_dify_url", 35 | # is_agent_mode=True, # True when type of app is agent 36 | ) 37 | 38 | litellm = LiteLLMService( 39 | api_key=OTHERLLM_API_KEY, 40 | model="llm_service/llm_model_name", 41 | ) 42 | 43 | sts = LiteSTS( 44 | vad_volume_db_threshold=-40, # Adjust microphone sensitivity (Gate) 45 | stt_google_api_key=GOOGLE_API_KEY, 46 | llm=chatgpt, # <- Select LLM service you want to use 47 | debug=True 48 | ) 49 | 50 | # Create adapter 51 | adapter = AudioDeviceAdapter( 52 | sts, 53 | cancel_echo=True # Set False if you want to interrupt AI's answer 54 | ) 55 | 56 | async def quick_start_main(): 57 | # Uncomment below when you use GeminiService 58 | # await gemini.preflight() 59 | 60 | # Start listening 61 | await adapter.start_listening("_") 62 | 63 | asyncio.run(quick_start_main()) 64 | -------------------------------------------------------------------------------- /examples/local/multilingual.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from litests import LiteSTS 3 | from litests.stt.openai import OpenAISpeechRecognizer 4 | from litests.tts.openai import OpenAISpeechSynthesizer 5 | from litests.adapter.audiodevice import AudioDeviceAdapter 6 | 7 | OPENAI_API_KEY = "YOUR_API_KEY" 8 | 9 | 10 | # STT 11 | stt = OpenAISpeechRecognizer( 12 | openai_api_key=OPENAI_API_KEY, 13 | alternative_languages=["en-US", "zh-CN"] 14 | ) 15 | 16 | # TTS 17 | tts = OpenAISpeechSynthesizer( 18 | openai_api_key=OPENAI_API_KEY, 19 | speaker="shimmer", 20 | ) 21 | 22 | # Create STS pipeline 23 | sts = LiteSTS( 24 | vad_volume_db_threshold=-40, # Adjust microphone sensitivity (Gate) 25 | stt=stt, 26 | llm_openai_api_key=OPENAI_API_KEY, 27 | # Azure OpenAI 28 | # llm_model="azure", 29 | # llm_base_url="https://{your_resource_name}.openai.azure.com/openai/deployments/{your_deployment_name}/chat/completions?api-version={api_version}", 30 | tts=tts, 31 | debug=True 32 | ) 33 | 34 | # Create adapter 35 | adapter = AudioDeviceAdapter( 36 | sts, 37 | cancel_echo=True # Set False if you want to interrupt AI's answer 38 | ) 39 | 40 | asyncio.run(adapter.start_listening("_")) 41 | -------------------------------------------------------------------------------- /examples/local/run.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from litests import LiteSTS 3 | from litests.adapter.audiodevice import AudioDeviceAdapter 4 | 5 | OPENAI_API_KEY = "YOUR_API_KEY" 6 | GOOGLE_API_KEY = "YOUR_API_KEY" 7 | 8 | # Create STS pipeline 9 | sts = LiteSTS( 10 | vad_volume_db_threshold=-40, # Adjust microphone sensitivity (Gate) 11 | stt_google_api_key=GOOGLE_API_KEY, 12 | llm_openai_api_key=OPENAI_API_KEY, 13 | # Azure OpenAI 14 | # llm_model="azure", 15 | # llm_base_url="https://{your_resource_name}.openai.azure.com/openai/deployments/{your_deployment_name}/chat/completions?api-version={api_version}", 16 | debug=True 17 | ) 18 | 19 | # Create adapter 20 | adapter = AudioDeviceAdapter( 21 | sts, 22 | cancel_echo=True # Set False if you want to interrupt AI's answer 23 | ) 24 | 25 | # Start listening 26 | asyncio.run(adapter.start_listening("_")) 27 | -------------------------------------------------------------------------------- /examples/websocket/client.py: 
-------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import io 4 | import json 5 | import queue 6 | import threading 7 | import uuid 8 | import wave 9 | import pyaudio 10 | import websockets 11 | 12 | WS_URL = "ws://localhost:8000/ws" 13 | CANCEL_ECHO = True 14 | 15 | 16 | class AudioPlayer: 17 | def __init__(self, chunk_size: int = 1024): 18 | self.queue = queue.Queue() 19 | 20 | self.thread = threading.Thread(target=self.process_queue, daemon=True) 21 | self.thread.start() 22 | 23 | self.to_wave = None 24 | self.p = pyaudio.PyAudio() 25 | self.play_stream = None 26 | self.chunk_size = chunk_size 27 | 28 | self.is_playing = False 29 | 30 | def play(self, content: bytes): 31 | try: 32 | self.is_playing = True 33 | 34 | if self.to_wave: 35 | wave_content = self.to_wave(content) 36 | else: 37 | wave_content = content 38 | 39 | with wave.open(io.BytesIO(wave_content), "rb") as wf: 40 | if not self.play_stream: 41 | self.play_stream = self.p.open( 42 | format=self.p.get_format_from_width(wf.getsampwidth()), 43 | channels=wf.getnchannels(), 44 | rate=wf.getframerate(), 45 | output=True, 46 | ) 47 | 48 | data = wf.readframes(self.chunk_size) 49 | while True: 50 | data = wf.readframes(self.chunk_size) 51 | if not data: 52 | break 53 | self.play_stream.write(data) 54 | 55 | finally: 56 | self.is_playing = False 57 | 58 | def process_queue(self): 59 | while True: 60 | data = self.queue.get() 61 | if data is None: 62 | break 63 | 64 | self.play(data) 65 | 66 | def add(self, audio_bytes: bytes): 67 | self.queue.put(audio_bytes) 68 | 69 | def cancel(self): 70 | while not self.queue.empty(): 71 | self.queue.get() 72 | 73 | def stop(self): 74 | self.queue.put(None) 75 | self.thread.join() 76 | if self.play_stream: 77 | self.play_stream.stop_stream() 78 | self.play_stream.close() 79 | self.p.terminate() 80 | 81 | audio_player = AudioPlayer() 82 | 83 | 84 | # Send microphone data to Speech-to-Speech server 85 | async def send_microphone_data(ws, session_id: str, cancel_echo: bool): 86 | p = pyaudio.PyAudio() 87 | 88 | mic_stream = p.open( 89 | format=pyaudio.paInt16, # LINEAR16 90 | channels=1, # Mono 91 | rate=16000, 92 | input=True, 93 | frames_per_buffer=512 94 | ) 95 | 96 | while True: 97 | data = mic_stream.read(512, exception_on_overflow=False) 98 | b64_data = base64.b64encode(data).decode("utf-8") 99 | 100 | if not (audio_player.is_playing and cancel_echo): 101 | await ws.send(json.dumps({ 102 | "type": "data", 103 | "session_id": session_id, 104 | "audio_data": b64_data 105 | })) 106 | 107 | await asyncio.sleep(0.01) 108 | 109 | 110 | # Receive and play audio from Speech-to-Speech server 111 | async def receive_and_play_audio(ws): 112 | while True: 113 | message_str = await ws.recv() 114 | message = json.loads(message_str) 115 | message_type = message.get("type") 116 | 117 | if message_type == "chunk": 118 | print(f"Response: {message.get('text')}") 119 | b64_data = message.get("audio_data") 120 | if b64_data: 121 | audio_bytes = base64.b64decode(b64_data) 122 | audio_player.add(audio_bytes) 123 | 124 | elif message_type == "stop": 125 | print(f"Stop requested") 126 | audio_player.cancel() 127 | 128 | 129 | async def main(): 130 | async with websockets.connect(WS_URL) as ws: 131 | session_id = str(uuid.uuid4()) 132 | 133 | # Send start message 134 | start_message = { 135 | "type": "start", 136 | "session_id": session_id 137 | } 138 | await ws.send(json.dumps(start_message)) 139 | 140 | print(f"Connected: {session_id}") 141 | 142 
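        # When CANCEL_ECHO is set, send_microphone_data() below drops microphone
        # frames while audio_player.is_playing, so the assistant's own reply is
        # not fed back to the speech-to-speech server as user input.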
| # Start send and receive task 143 | send_task = asyncio.create_task(send_microphone_data(ws, session_id, CANCEL_ECHO)) 144 | receive_task = asyncio.create_task(receive_and_play_audio(ws)) 145 | await asyncio.gather(send_task, receive_task) 146 | 147 | # Send stop message 148 | await ws.send(json.dumps({ 149 | "type": "stop", 150 | "session_id": session_id 151 | })) 152 | 153 | 154 | if __name__ == "__main__": 155 | asyncio.run(main()) 156 | -------------------------------------------------------------------------------- /examples/websocket/server.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from fastapi import FastAPI, WebSocket 3 | from litests import LiteSTS 4 | from litests.adapter.websocket import WebSocketAdapter, WebSocketSessionData 5 | 6 | OPENAI_API_KEY = "YOUR_API_KEY" 7 | GOOGLE_API_KEY = "YOUR_API_KEY" 8 | 9 | 10 | # Adapter for sending back to websocket client 11 | class MyWebSocketAdapter(WebSocketAdapter): 12 | async def process_websocket(self, websocket: WebSocket, session_data: WebSocketSessionData): 13 | message = await websocket.receive_json() 14 | message_type = message.get("type") 15 | session_id = message["session_id"] 16 | 17 | if message_type == "start": 18 | print(f"Connected: {session_id}") 19 | self.websockets[session_id] = websocket 20 | 21 | elif message_type == "data": 22 | b64_audio_data = message["audio_data"] 23 | audio_data = base64.b64decode(b64_audio_data) 24 | await self.sts.process_audio_samples(audio_data, session_id) 25 | 26 | elif message_type == "stop": 27 | print("stop") 28 | await websocket.close() 29 | 30 | async def handle_response(self, response): 31 | if response.type == "chunk" and response.audio_data: 32 | b64_chunk = base64.b64encode(response.audio_data).decode("utf-8") 33 | await self.websockets[response.context_id].send_json({ 34 | "type": "chunk", 35 | "session_id": response.context_id, 36 | "text": response.text, 37 | "audio_data": b64_chunk 38 | }) 39 | 40 | async def stop_response(self, context_id): 41 | if context_id in self.websockets: 42 | await self.websockets[context_id].send_json({ 43 | "type": "stop", 44 | "session_id": context_id, 45 | }) 46 | 47 | 48 | # Create Speech-to-Speech pipeline 49 | sts = LiteSTS( 50 | vad_volume_db_threshold=-30, # Adjust microphone sensitivity (Gate) 51 | stt_google_api_key=GOOGLE_API_KEY, 52 | llm_openai_api_key=OPENAI_API_KEY, 53 | debug=True 54 | ) 55 | 56 | # Set adapter 57 | adapter = MyWebSocketAdapter(sts) 58 | router = adapter.get_websocket_router() 59 | 60 | # Start websocket server 61 | app = FastAPI() 62 | app.include_router(router) 63 | -------------------------------------------------------------------------------- /litests/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import LiteSTS 2 | -------------------------------------------------------------------------------- /litests/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Adapter 2 | -------------------------------------------------------------------------------- /litests/adapter/audiodevice.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import io 3 | import logging 4 | import queue 5 | import threading 6 | from typing import AsyncGenerator 7 | import wave 8 | import pyaudio 9 | from ..models import STSResponse 10 | from ..pipeline import LiteSTS 11 | from .base import 
Adapter 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class AudioDeviceAdapter(Adapter): 17 | def __init__( 18 | self, 19 | sts: LiteSTS = None, 20 | *, 21 | input_sample_rate: int = 16000, 22 | input_channels: int = 1, 23 | input_chunk_size: int = 512, 24 | output_chunk_size: int = 1024, 25 | cancel_echo: bool = True 26 | ): 27 | super().__init__(sts) 28 | 29 | # Microphpne 30 | self.input_sample_rate = input_sample_rate 31 | self.input_channels = input_channels 32 | self.input_chunk_size = input_chunk_size 33 | 34 | # Audio player 35 | self.to_wave = None 36 | self.p = pyaudio.PyAudio() 37 | self.play_stream = None 38 | self.wave_params = None 39 | self.output_chunk_size = output_chunk_size 40 | 41 | # Echo cancellation 42 | self.cancel_echo = cancel_echo 43 | self.is_playing_locally = False 44 | self.sts.vad.should_mute = lambda: self.cancel_echo and self.is_playing_locally 45 | 46 | # Response handler 47 | self.stop_event = threading.Event() 48 | self.response_queue: queue.Queue[bytes] = queue.Queue() 49 | self.response_handler_thread = threading.Thread(target=self.audio_player_worker, daemon=True) 50 | self.response_handler_thread.start() 51 | 52 | # Request 53 | async def start_listening(self, session_id: str, user_id: str = None): 54 | async def start_microphone_stream() -> AsyncGenerator[bytes, None]: 55 | p = pyaudio.PyAudio() 56 | pyaudio_stream = p.open( 57 | rate=self.input_sample_rate, 58 | channels=self.input_channels, 59 | format=pyaudio.paInt16, 60 | input=True, 61 | frames_per_buffer=self.input_chunk_size 62 | ) 63 | while True: 64 | yield pyaudio_stream.read(self.input_chunk_size) 65 | await asyncio.sleep(0.0001) 66 | 67 | if user_id: 68 | self.sts.vad.set_session_data(session_id, "user_id", user_id, True) 69 | await self.sts.vad.process_stream(start_microphone_stream(), session_id) 70 | 71 | # Response 72 | def audio_player_worker(self): 73 | while True: 74 | try: 75 | audio_data = self.response_queue.get() 76 | self.is_playing_locally = True 77 | wave_content = self.to_wave(audio_data) \ 78 | if self.to_wave else audio_data 79 | 80 | with wave.open(io.BytesIO(wave_content), "rb") as wf: 81 | current_params = wf.getparams() 82 | if not self.play_stream or self.wave_params != current_params: 83 | self.wave_params = current_params 84 | self.play_stream = self.p.open( 85 | format=self.p.get_format_from_width(self.wave_params.sampwidth), 86 | channels=self.wave_params.nchannels, 87 | rate=self.wave_params.framerate, 88 | output=True, 89 | ) 90 | 91 | data = wf.readframes(self.output_chunk_size) 92 | while True: 93 | data = wf.readframes(self.output_chunk_size) 94 | if not data: 95 | break 96 | self.play_stream.write(data) 97 | 98 | except Exception as ex: 99 | logger.error(f"Error processing audio data: {ex}", exc_info=True) 100 | 101 | finally: 102 | self.is_playing_locally = False 103 | self.response_queue.task_done() 104 | 105 | async def handle_response(self, response: STSResponse): 106 | if response.type == "chunk" and response.audio_data: 107 | self.response_queue.put(response.audio_data) 108 | elif response.type == "stop": 109 | await self.stop_response() 110 | 111 | async def stop_response(self, session_id: str, context_id: str): 112 | while not self.response_queue.empty(): 113 | try: 114 | _ = self.response_queue.get_nowait() 115 | self.response_queue.task_done() 116 | except: 117 | break 118 | 119 | def close(self): 120 | self.stop_event.set() 121 | self.stop_response() 122 | self.response_handler_thread.join() 123 | 
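Every adapter in this package follows the same contract from litests/adapter/base.py (next file): pass a LiteSTS pipeline to the constructor, which assigns the adapter's handle_response and stop_response callbacks onto the pipeline, and implement those two methods. The sketch below illustrates that contract only; the PrintAdapter name and the printed fields are illustrative assumptions, not part of the library.

from litests import LiteSTS
from litests.models import STSResponse
from litests.adapter import Adapter


class PrintAdapter(Adapter):
    """Minimal illustrative adapter: prints streamed text instead of playing audio."""

    async def handle_response(self, response: STSResponse):
        # Called with each STSResponse produced by the pipeline (see the websocket example)
        if response.type == "chunk" and response.text:
            print(f"assistant: {response.text}")

    async def stop_response(self, session_id: str, context_id: str):
        # Called when an in-flight response should be cancelled
        print(f"stopped: context_id={context_id}")


# Wiring is the same as for AudioDeviceAdapter above, e.g.:
# sts = LiteSTS(stt_google_api_key="...", llm_openai_api_key="...")
# adapter = PrintAdapter(sts)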
-------------------------------------------------------------------------------- /litests/adapter/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from ..models import STSResponse 3 | from ..pipeline import LiteSTS 4 | 5 | 6 | class Adapter(ABC): 7 | def __init__(self, sts: LiteSTS): 8 | self.sts = sts 9 | self.sts.handle_response = self.handle_response 10 | self.sts.stop_response = self.stop_response 11 | 12 | @abstractmethod 13 | async def handle_response(self, response: STSResponse): 14 | pass 15 | 16 | @abstractmethod 17 | async def stop_response(self, session_id: str, context_id: str): 18 | pass 19 | -------------------------------------------------------------------------------- /litests/adapter/http.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import logging 4 | from typing import List 5 | from uuid import uuid4 6 | from pydantic import BaseModel 7 | from fastapi import APIRouter 8 | from sse_starlette.sse import EventSourceResponse # pip install sse-starlette 9 | from litests import LiteSTS 10 | from litests.models import STSRequest, STSResponse 11 | from litests.adapter import Adapter 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class File(BaseModel): 17 | url: str 18 | 19 | 20 | class ChatRequest(BaseModel): 21 | context_id: str = None 22 | text: str = None 23 | audio_data: bytes = None 24 | audio_duration: float = 0 25 | files: List[File] = None 26 | 27 | def to_sts_request(self) -> STSRequest: 28 | return STSRequest( 29 | context_id=self.context_id, 30 | text=self.text, 31 | audio_data=self.audio_data, 32 | audio_duration=self.audio_duration, 33 | files=[{"url": f.url} for f in self.files] if self.files else None 34 | ) 35 | 36 | 37 | class ChatChunkResponse(BaseModel): 38 | type: str 39 | context_id: str 40 | text: str|None = None 41 | voice_text: str|None = None 42 | tool_call: str|None = None 43 | encoded_audio: str|None = None 44 | 45 | 46 | class HttpAdapter(Adapter): 47 | def __init__(self, sts: LiteSTS): 48 | super().__init__(sts) 49 | 50 | def get_api_router(self, path: str = "/chat"): 51 | router = APIRouter() 52 | 53 | @router.post(path) 54 | async def post_chat(request: ChatRequest): 55 | if not request.context_id: 56 | request.context_id = str(uuid4()) 57 | 58 | async def stream_response(): 59 | async for chunk in self.sts.invoke(request.to_sts_request()): 60 | try: 61 | if chunk.audio_data: 62 | b64_audio= base64.b64encode(chunk.audio_data).decode("utf-8") 63 | else: 64 | b64_audio = None 65 | 66 | yield ChatChunkResponse( 67 | type=chunk.type, 68 | context_id=chunk.context_id, 69 | text=chunk.text, 70 | voice_text=chunk.voice_text, 71 | tool_call=json.dumps(chunk.tool_call.__dict__) if chunk.tool_call else None, 72 | encoded_audio=b64_audio 73 | ).model_dump_json() 74 | 75 | except Exception as ex: 76 | logger.error(f"Error at HTTP adapter: {ex}") 77 | 78 | return EventSourceResponse(stream_response()) 79 | 80 | return router 81 | 82 | async def handle_response(self, response: STSResponse): 83 | pass 84 | 85 | async def stop_response(self, context_id: str): 86 | pass 87 | -------------------------------------------------------------------------------- /litests/adapter/websocket.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import logging 3 | from typing import Dict 4 | from fastapi import APIRouter, WebSocket 5 
| from ..pipeline import LiteSTS 6 | from .base import Adapter 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class WebSocketSessionData: 12 | def __init__(self): 13 | self.id = None 14 | self.data = {} 15 | 16 | 17 | class WebSocketAdapter(Adapter): 18 | def __init__( 19 | self, 20 | sts: LiteSTS = None 21 | ): 22 | super().__init__(sts) 23 | self.websockets: Dict[str, WebSocket] = {} 24 | self.sessions: Dict[str, WebSocketSessionData] = {} 25 | 26 | @abstractmethod 27 | async def process_websocket(self, websocket: WebSocket, session_data: WebSocketSessionData): 28 | pass 29 | 30 | def get_websocket_router(self, path: str = "/ws"): 31 | router = APIRouter() 32 | 33 | @router.websocket(path) 34 | async def websocket_endpoint(websocket: WebSocket): 35 | await websocket.accept() 36 | session_data = WebSocketSessionData() 37 | 38 | try: 39 | while True: 40 | await self.process_websocket(websocket, session_data) 41 | 42 | except Exception as ex: 43 | error_message = str(ex) 44 | 45 | if "WebSocket is not connected" in error_message: 46 | logger.info(f"WebSocket disconnected (1): context_id={session_data.id}") 47 | elif "" in error_message: 48 | logger.info(f"WebSocket disconnected (2): context_id={session_data.id}") 49 | else: 50 | raise 51 | 52 | finally: 53 | if session_data.id: 54 | await self.sts.finalize(session_data.id) 55 | if session_data.id in self.websockets: 56 | del self.websockets[session_data.id] 57 | if session_data.id in self.sessions: 58 | del self.sessions[session_data.id] 59 | 60 | return router 61 | -------------------------------------------------------------------------------- /litests/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import LLMService, LLMResponse, ToolCall, Tool 2 | -------------------------------------------------------------------------------- /litests/llm/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import asyncio 3 | import inspect 4 | import logging 5 | import re 6 | from typing import AsyncGenerator, List, Dict, Any, Callable, Optional 7 | from .context_manager import ContextManager, SQLiteContextManager 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class ToolCall: 13 | def __init__(self, id: str = None, name: str = None, arguments: any = None): 14 | self.id = id 15 | self.name = name 16 | self.arguments = arguments 17 | 18 | 19 | class LLMResponse: 20 | def __init__(self, context_id: str, text: str = None, voice_text: str = None, tool_call: ToolCall = None): 21 | self.context_id = context_id 22 | self.text = text 23 | self.voice_text = voice_text 24 | self.tool_call = tool_call 25 | 26 | 27 | class Tool: 28 | def __init__(self, name: str, spec: Dict[str, Any], func: Callable, instruction: str = None, is_dynamic: bool = False): 29 | self.name = name 30 | self.spec = spec 31 | self.func = func 32 | self.instruction = instruction 33 | self.is_dynamic = is_dynamic 34 | 35 | 36 | class LLMService(ABC): 37 | def __init__( 38 | self, 39 | *, 40 | system_prompt: str, 41 | model: str, 42 | temperature: float = 0.5, 43 | split_chars: List[str] = None, 44 | option_split_chars: List[str] = None, 45 | option_split_threshold: int = 50, 46 | voice_text_tag: str = None, 47 | use_dynamic_tools: bool = False, 48 | context_manager: ContextManager = None, 49 | debug: bool = False 50 | ): 51 | self.system_prompt = system_prompt 52 | self.model = model 53 | self.temperature = temperature 54 | 
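        # Streaming sentence-split settings (used in chat_stream below):
        # split_chars mark hard sentence boundaries that flush buffered text
        # downstream immediately; option_split_chars are softer boundaries,
        # applied only once the buffer exceeds option_split_threshold characters.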
self.split_chars = split_chars or ["。", "?", "!", ". ", "?", "!"] 55 | self.option_split_chars = option_split_chars or ["、", ", "] 56 | self.option_split_threshold = option_split_threshold 57 | self.split_patterns = [] 58 | for char in self.option_split_chars: 59 | if char.endswith(" "): 60 | self.split_patterns.append(f"{re.escape(char)}") 61 | else: 62 | self.split_patterns.append(f"{re.escape(char)}\s?") 63 | self.option_split_chars_regex = f"({'|'.join(self.split_patterns)})\s*(?!.*({'|'.join(self.split_patterns)}))" 64 | self._request_filter = self.request_filter_default 65 | self.voice_text_tag = voice_text_tag 66 | self.tools: Dict[str, Tool] = {} 67 | self.use_dynamic_tools = use_dynamic_tools 68 | self.dynamic_tool_instruction = """ 69 | 70 | ## Important: Use of `{dynamic_tool_name}` 71 | 72 | Output **only Function / Tool call** parts. Exclude text content. 73 | 74 | """ 75 | self.additional_prompt_for_tool_listing = """ 76 | ---- 77 | Extract up to five tools that could be used to process the above user input. 78 | The response should follow this format. If multiple tools apply, separate them with commas. 79 | 80 | [tools:{tool_name},{tool_name},{tool_name}] 81 | 82 | If none apply, respond as follows: 83 | 84 | [tool_name:NOT_FOUND] 85 | 86 | The list of tools is as follows: 87 | 88 | """ 89 | self._get_dynamic_tools = self.get_dynamic_tools_default 90 | self._on_before_tool_calls = self.on_before_tool_calls_default 91 | self.context_manager = context_manager or SQLiteContextManager() 92 | self.debug = debug 93 | 94 | # Decorators 95 | def request_filter(self, func): 96 | self._request_filter = func 97 | return func 98 | 99 | def request_filter_default(self, text: str) -> str: 100 | return text 101 | 102 | def tool(self, spec): 103 | def decorator(func): 104 | return func 105 | return decorator 106 | 107 | async def get_dynamic_tools(self, func): 108 | self._get_dynamic_tools = func 109 | return func 110 | 111 | def on_before_tool_calls(self, func): 112 | self._on_before_tool_calls = func 113 | return func 114 | 115 | async def on_before_tool_calls_default(self, tool_calls: List[ToolCall]): 116 | pass 117 | 118 | def replace_last_option_split_char(self, original): 119 | return re.sub(self.option_split_chars_regex, r"\1|", original) 120 | 121 | def get_system_prompt(self, system_prompt_params: Dict[str, any]): 122 | if not system_prompt_params: 123 | return self.system_prompt 124 | else: 125 | return self.system_prompt.format(**system_prompt_params) 126 | 127 | @abstractmethod 128 | async def compose_messages(self, context_id: str, text: str, files: List[Dict[str, str]] = None, system_prompt_params: Dict[str, any] = None) -> List[Dict]: 129 | pass 130 | 131 | @abstractmethod 132 | async def update_context(self, context_id: str, messages: List[Dict], response_text: str): 133 | pass 134 | 135 | async def get_dynamic_tools_default(self, messages: List[dict], metadata: Dict[str, any] = None) -> List[Dict[str, any]]: 136 | return [] 137 | 138 | @abstractmethod 139 | async def get_llm_stream_response(self, context_id: str, user_id: str, messages: List[dict], system_prompt_params: Dict[str, any] = None) -> AsyncGenerator[LLMResponse, None]: 140 | pass 141 | 142 | def remove_control_tags(self, text: str) -> str: 143 | clean_text = text 144 | clean_text = re.sub(r"\[(\w+):([^\]]+)\]", "", clean_text) 145 | clean_text = clean_text.strip() 146 | return clean_text 147 | 148 | async def execute_tool(self, name: str, arguments: dict, metadata: dict): 149 | tool = self.tools[name] 150 | if 
"metadata" in inspect.signature(tool.func).parameters: 151 | arguments["metadata"] = metadata 152 | return await tool.func(**arguments) 153 | 154 | async def chat_stream(self, context_id: str, user_id: str, text: str, files: List[Dict[str, str]] = None, system_prompt_params: Dict[str, any] = None) -> AsyncGenerator[LLMResponse, None]: 155 | logger.info(f"User: {text}") 156 | text = self._request_filter(text) 157 | logger.info(f"User(Filtered): {text}") 158 | 159 | if not text and not files: 160 | return 161 | 162 | messages = await self.compose_messages(context_id, text, files, system_prompt_params) 163 | message_length_at_start = len(messages) - 1 164 | 165 | stream_buffer = "" 166 | response_text = "" 167 | 168 | in_voice_tag = False 169 | target_start = f"<{self.voice_text_tag}>" 170 | target_end = f"" 171 | 172 | def to_voice_text(segment: str) -> Optional[str]: 173 | if not self.voice_text_tag: 174 | return self.remove_control_tags(segment) 175 | 176 | nonlocal in_voice_tag 177 | if target_start in segment and target_end in segment: 178 | in_voice_tag = False 179 | start_index = segment.find(target_start) 180 | end_index = segment.find(target_end) 181 | voice_segment = segment[start_index + len(target_start): end_index] 182 | return self.remove_control_tags(voice_segment) 183 | 184 | elif target_start in segment: 185 | in_voice_tag = True 186 | start_index = segment.find(target_start) 187 | voice_segment = segment[start_index + len(target_start):] 188 | return self.remove_control_tags(voice_segment) 189 | 190 | elif target_end in segment: 191 | if in_voice_tag: 192 | in_voice_tag = False 193 | end_index = segment.find(target_end) 194 | voice_segment = segment[:end_index] 195 | return self.remove_control_tags(voice_segment) 196 | 197 | elif in_voice_tag: 198 | return self.remove_control_tags(segment) 199 | 200 | return None 201 | 202 | async for chunk in self.get_llm_stream_response(context_id, user_id, messages, system_prompt_params): 203 | if chunk.tool_call: 204 | yield chunk 205 | continue 206 | 207 | stream_buffer += chunk.text 208 | 209 | for spc in self.split_chars: 210 | stream_buffer = stream_buffer.replace(spc, spc + "|") 211 | 212 | if len(stream_buffer) > self.option_split_threshold: 213 | stream_buffer = self.replace_last_option_split_char(stream_buffer) 214 | 215 | segments = stream_buffer.split("|") 216 | while len(segments) > 1: 217 | sentence = segments.pop(0) 218 | stream_buffer = "|".join(segments) 219 | voice_text = to_voice_text(sentence) 220 | yield LLMResponse(context_id, sentence, voice_text) 221 | response_text += sentence 222 | segments = stream_buffer.split("|") 223 | 224 | await asyncio.sleep(0.001) # wait slightly in every loop not to use up CPU 225 | 226 | if stream_buffer: 227 | voice_text = to_voice_text(stream_buffer) 228 | yield LLMResponse(context_id, stream_buffer, voice_text) 229 | response_text += stream_buffer 230 | 231 | logger.info(f"AI: {response_text}") 232 | if len(messages) > message_length_at_start: 233 | await self.update_context( 234 | context_id, 235 | messages[message_length_at_start - len(messages):], 236 | response_text, 237 | ) 238 | -------------------------------------------------------------------------------- /litests/llm/chatgpt.py: -------------------------------------------------------------------------------- 1 | import json 2 | from logging import getLogger 3 | import re 4 | from typing import AsyncGenerator, Dict, List 5 | from urllib.parse import urlparse, parse_qs 6 | import openai 7 | from . 
import LLMService, LLMResponse, ToolCall, Tool 8 | from .context_manager import ContextManager 9 | 10 | logger = getLogger(__name__) 11 | 12 | 13 | class ChatGPTService(LLMService): 14 | def __init__( 15 | self, 16 | *, 17 | openai_api_key: str = None, 18 | system_prompt: str = None, 19 | base_url: str = None, 20 | model: str = "gpt-4o", 21 | temperature: float = 0.5, 22 | split_chars: List[str] = None, 23 | option_split_chars: List[str] = None, 24 | option_split_threshold: int = 50, 25 | voice_text_tag: str = None, 26 | use_dynamic_tools: bool = False, 27 | context_manager: ContextManager = None, 28 | debug: bool = False 29 | ): 30 | super().__init__( 31 | system_prompt=system_prompt, 32 | model=model, 33 | temperature=temperature, 34 | split_chars=split_chars, 35 | option_split_chars=option_split_chars, 36 | option_split_threshold=option_split_threshold, 37 | voice_text_tag=voice_text_tag, 38 | use_dynamic_tools=use_dynamic_tools, 39 | context_manager=context_manager, 40 | debug=debug 41 | ) 42 | if "azure" in model: 43 | api_version = parse_qs(urlparse(base_url).query).get("api-version", [None])[0] 44 | self.openai_client = openai.AsyncAzureOpenAI( 45 | api_key=openai_api_key, 46 | api_version=api_version, 47 | base_url=base_url 48 | ) 49 | else: 50 | self.openai_client = openai.AsyncClient(api_key=openai_api_key, base_url=base_url) 51 | 52 | self.dynamic_tool_spec = { 53 | "type": "function", 54 | "function": { 55 | "name": "execute_external_tool", 56 | "description": "Execute the most appropriate tool based on the user's intent: what they want to do and to what.", 57 | "parameters": { 58 | "type": "object", 59 | "properties": { 60 | "target": { 61 | "type": "string", 62 | "description": "What the user wants to interact with (e.g., long-term memory, weather, music)." 63 | }, 64 | "action": { 65 | "type": "string", 66 | "description": "The type of operation to perform on the target (e.g., retrieve, look up, play)." 
67 | } 68 | }, 69 | "required": ["target", "action"] 70 | } 71 | } 72 | } 73 | 74 | async def compose_messages(self, context_id: str, text: str, files: List[Dict[str, str]] = None, system_prompt_params: Dict[str, any] = None) -> List[Dict]: 75 | messages = [] 76 | if self.system_prompt: 77 | messages.append({"role": "system", "content": self.get_system_prompt(system_prompt_params)}) 78 | 79 | # Extract the history starting from the first message where the role is 'user' 80 | histories = await self.context_manager.get_histories(context_id) 81 | while histories and histories[0]["role"] != "user": 82 | histories.pop(0) 83 | messages.extend(histories) 84 | 85 | if files: 86 | content = [] 87 | for f in files: 88 | if url := f.get("url"): 89 | content.append({"type": "image_url", "image_url": {"url": url}}) 90 | if text: 91 | content.append({"type": "text", "text": text}) 92 | else: 93 | content = text 94 | messages.append({"role": "user", "content": content}) 95 | 96 | return messages 97 | 98 | async def update_context(self, context_id: str, messages: List[Dict], response_text: str): 99 | messages.append({"role": "assistant", "content": response_text}) 100 | await self.context_manager.add_histories(context_id, messages, "chatgpt") 101 | 102 | def tool(self, spec: Dict): 103 | def decorator(func): 104 | tool_name = spec["function"]["name"] 105 | self.tools[tool_name] = Tool( 106 | name=tool_name, 107 | spec=spec, 108 | func=func 109 | ) 110 | return func 111 | return decorator 112 | 113 | async def get_dynamic_tools_default(self, messages: List[dict], metadata: Dict[str, any] = None) -> List[Dict[str, any]]: 114 | # Make additional prompt with registered tools 115 | tool_listing_prompt = self.additional_prompt_for_tool_listing 116 | for _, t in self.tools.items(): 117 | tool_listing_prompt += f'- {t.name}: {t.spec["function"]["description"]}\n' 118 | tool_listing_prompt += "- NOT_FOUND: Use this if no suitable tools are found.\n" 119 | 120 | # Build user message content 121 | user_content = messages[-1]["content"] 122 | if isinstance(user_content, list): 123 | user_content_for_tool = [] 124 | text_updated = False 125 | for c in user_content: 126 | content_type = c["type"] 127 | if content_type == "text" and not text_updated: 128 | # Update text content 129 | user_content_for_tool.append({"type": "text", "text": c["text"] + tool_listing_prompt}) 130 | text_updated = True 131 | else: 132 | # Use original non-text content (e.g. 
image) 133 | user_content_for_tool.append(c) 134 | # Add text content if no text content are found 135 | if not text_updated: 136 | user_content_for_tool.append({"type": "text", "text": tool_listing_prompt}) 137 | elif isinstance(user_content, str): 138 | user_content_for_tool = user_content + tool_listing_prompt 139 | 140 | # Call LLM to filter tools 141 | tool_choice_resp = await self.openai_client.chat.completions.create( 142 | messages=messages[:-1] + [{"role": "user", "content": user_content_for_tool}], 143 | model=self.model, 144 | temperature=0.0 145 | ) 146 | 147 | # Parse tools from response 148 | if match := re.search(r"\[tools:(.*?)\]", tool_choice_resp.choices[0].message.content): 149 | tool_names = match.group(1) 150 | else: 151 | tool_names = "NOT_FOUND" 152 | 153 | tools = [] 154 | for t in tool_names.split(","): 155 | if tool := self.tools.get(t.strip()): 156 | tools.append(tool.spec) 157 | 158 | return tools 159 | 160 | async def get_llm_stream_response(self, context_id: str, user_id: str, messages: List[Dict], system_prompt_params: Dict[str, any] = None, tools: List[Dict[str, any]] = None) -> AsyncGenerator[LLMResponse, None]: 161 | # Select tools to use 162 | tool_instruction = "" 163 | if tools: 164 | filtered_tools = tools 165 | for t in filtered_tools: 166 | if ti := self.tools.get(t["function"]["name"]).instruction: 167 | tool_instruction += f"{ti}\n\n" 168 | elif self.use_dynamic_tools: 169 | filtered_tools = [self.dynamic_tool_spec] 170 | tool_instruction = self.dynamic_tool_instruction.format( 171 | dynamic_tool_name=self.dynamic_tool_spec["function"]["name"] 172 | ) 173 | else: 174 | filtered_tools = [t.spec for _, t in self.tools.items() if not t.is_dynamic] or None 175 | 176 | # Update system prompt 177 | if tool_instruction and messages[0]["role"] == "system": 178 | system_message_for_tool = {"role": "system", "content": messages[0]["content"] + tool_instruction} 179 | else: 180 | system_message_for_tool = messages[0] 181 | 182 | stream_resp = await self.openai_client.chat.completions.create( 183 | messages=[system_message_for_tool] + messages[1:], 184 | model=self.model, 185 | temperature=self.temperature, 186 | tools=filtered_tools, 187 | stream=True 188 | ) 189 | 190 | tool_calls: List[ToolCall] = [] 191 | try_dynamic_tools = False 192 | async for chunk in stream_resp: 193 | if not chunk.choices: 194 | continue 195 | 196 | if chunk.choices[0].delta.tool_calls: 197 | t = chunk.choices[0].delta.tool_calls[0] 198 | if t.id: 199 | tool_calls.append(ToolCall(t.id, t.function.name, "")) 200 | if t.function.name == self.dynamic_tool_spec["function"]["name"]: 201 | logger.info("Get dynamic tool") 202 | filtered_tools = await self._get_dynamic_tools(messages) 203 | logger.info(f"Dynamic tools: {filtered_tools}") 204 | try_dynamic_tools = True 205 | if t.function.arguments: 206 | tool_calls[-1].arguments += t.function.arguments 207 | 208 | elif content := chunk.choices[0].delta.content: 209 | if not try_dynamic_tools: 210 | yield LLMResponse(context_id=context_id, text=content) 211 | 212 | if tool_calls: 213 | # Do something before tool calls (e.g. 
say to user that it will take a long time) 214 | await self._on_before_tool_calls(tool_calls) 215 | 216 | # Execute tools 217 | messages_length = len(messages) 218 | for tc in tool_calls: 219 | if self.debug: 220 | logger.info(f"ToolCall: {tc.name}") 221 | yield LLMResponse(context_id=context_id, tool_call=tc) 222 | 223 | if tc.name == self.dynamic_tool_spec["function"]["name"]: 224 | if filtered_tools: 225 | tool_result = None 226 | else: 227 | tool_result = {"message": "No tools found"} 228 | else: 229 | tool_result = await self.execute_tool(tc.name, json.loads(tc.arguments), {"user_id": user_id}) 230 | if self.debug: 231 | logger.info(f"ToolCall result: {tool_result}") 232 | 233 | if tool_result: 234 | messages.append({ 235 | "role": "assistant", 236 | "tool_calls": [{ 237 | "id": tc.id, 238 | "type": "function", 239 | "function": { 240 | "name": tc.name, 241 | "arguments": tc.arguments 242 | } 243 | }] 244 | }) 245 | 246 | messages.append({ 247 | "role": "tool", 248 | "content": json.dumps(tool_result), 249 | "tool_call_id": tc.id 250 | }) 251 | 252 | if len(messages) > messages_length or try_dynamic_tools: 253 | # Generate human-friendly message that explains tool result 254 | async for llm_response in self.get_llm_stream_response( 255 | context_id, user_id, messages, system_prompt_params=system_prompt_params, tools=filtered_tools 256 | ): 257 | yield llm_response 258 | -------------------------------------------------------------------------------- /litests/llm/context_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ContextManager, SQLiteContextManager 2 | -------------------------------------------------------------------------------- /litests/llm/context_manager/base.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import json 3 | import logging 4 | from datetime import datetime, timezone, timedelta 5 | from abc import ABC, abstractmethod 6 | from typing import List, Dict 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ContextManager(ABC): 12 | @abstractmethod 13 | async def get_histories(self, context_id: str, limit: int = 100) -> List[Dict]: 14 | pass 15 | 16 | @abstractmethod 17 | async def add_histories(self, context_id: str, data_list: List[Dict], context_schema: str = None): 18 | pass 19 | 20 | @abstractmethod 21 | async def get_last_created_at(self, context_id: str) -> datetime: 22 | pass 23 | 24 | 25 | class SQLiteContextManager(ContextManager): 26 | def __init__(self, db_path="context.db", context_timeout=3600): 27 | self.db_path = db_path 28 | self.context_timeout = context_timeout 29 | self.init_db() 30 | 31 | def init_db(self): 32 | conn = sqlite3.connect(self.db_path) 33 | try: 34 | with conn: 35 | # Create table if not exists 36 | # (Primary key 'id' automatically gets indexed by SQLite) 37 | conn.execute( 38 | """ 39 | CREATE TABLE IF NOT EXISTS chat_histories ( 40 | id INTEGER PRIMARY KEY AUTOINCREMENT, 41 | created_at TIMESTAMP NOT NULL, 42 | context_id TEXT NOT NULL, 43 | serialized_data JSON NOT NULL, 44 | context_schema TEXT 45 | ) 46 | """ 47 | ) 48 | 49 | # Create an index to speed up filtering queries by context_id and created_at 50 | conn.execute( 51 | """ 52 | CREATE INDEX IF NOT EXISTS idx_chat_histories_context_id_created_at 53 | ON chat_histories (context_id, created_at) 54 | """ 55 | ) 56 | 57 | except Exception as ex: 58 | logger.error(f"Error at init_db: {ex}") 59 | finally: 60 | conn.close() 61 | 62 | 
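    # Each message is stored as one JSON row; get_histories() below returns only
    # rows newer than context_timeout seconds (a non-positive value disables the
    # cutoff) and reverses them so the oldest message comes first.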
async def get_histories(self, context_id: str, limit: int = 100) -> List[Dict]: 63 | conn = sqlite3.connect(self.db_path) 64 | try: 65 | sql = """ 66 | SELECT serialized_data 67 | FROM chat_histories 68 | WHERE context_id = ? 69 | """ 70 | params = [context_id] 71 | 72 | if self.context_timeout > 0: 73 | # Cutoff time to exclude old records 74 | sql += " AND created_at >= ?" 75 | cutoff_time = datetime.now(timezone.utc) - timedelta(seconds=self.context_timeout) 76 | params.append(cutoff_time) 77 | 78 | sql += " ORDER BY id DESC" 79 | 80 | if limit > 0: 81 | sql += " LIMIT ?" 82 | params.append(limit) 83 | 84 | cursor = conn.cursor() 85 | cursor.execute(sql, tuple(params)) 86 | rows = cursor.fetchall() 87 | 88 | # Reverse the list so that the newest item is at the end (larger index) 89 | rows.reverse() 90 | results = [json.loads(row[0]) for row in rows] 91 | return results 92 | 93 | except Exception as ex: 94 | logger.error(f"Error at get_histories: {ex}") 95 | return [] 96 | 97 | finally: 98 | conn.close() 99 | 100 | async def add_histories(self, context_id: str, data_list: List[Dict], context_schema: str = None): 101 | if not data_list: 102 | # If the list is empty, do nothing 103 | return 104 | 105 | conn = sqlite3.connect(self.db_path) 106 | try: 107 | # Prepare INSERT statement 108 | columns = ["created_at", "context_id", "serialized_data", "context_schema"] 109 | placeholders = ["?"] * len(columns) 110 | sql = f""" 111 | INSERT INTO chat_histories ({', '.join(columns)}) 112 | VALUES ({', '.join(placeholders)}) 113 | """ 114 | 115 | now_utc = datetime.now(timezone.utc) 116 | records = [] 117 | for data_item in data_list: 118 | record = ( 119 | now_utc, # created_at 120 | context_id, # context_id 121 | json.dumps(data_item, ensure_ascii=True), # serialized_data 122 | context_schema, # context_schema 123 | ) 124 | records.append(record) 125 | 126 | # Execute many inserts in a single statement 127 | conn.executemany(sql, records) 128 | conn.commit() 129 | 130 | except Exception as ex: 131 | logger.error(f"Error at add_histories: {ex}") 132 | conn.rollback() 133 | 134 | finally: 135 | conn.close() 136 | 137 | async def get_last_created_at(self, context_id: str) -> datetime: 138 | conn = sqlite3.connect(self.db_path) 139 | try: 140 | sql = """ 141 | SELECT created_at 142 | FROM chat_histories 143 | WHERE context_id = ? 
144 | ORDER BY id DESC 145 | LIMIT 1 146 | """ 147 | cursor = conn.cursor() 148 | cursor.execute(sql, (context_id,)) 149 | row = cursor.fetchone() 150 | if row: 151 | last_created_at = datetime.fromisoformat(row[0]) 152 | else: 153 | last_created_at = datetime.min 154 | 155 | return last_created_at.replace(tzinfo=timezone.utc) 156 | 157 | except Exception as ex: 158 | logger.error(f"Error at get_last_created_at: {ex}") 159 | return datetime.min.replace(tzinfo=timezone.utc) 160 | 161 | finally: 162 | conn.close() 163 | -------------------------------------------------------------------------------- /litests/llm/context_manager/postgres.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone, timedelta 2 | import json 3 | import logging 4 | from typing import List, Dict 5 | import psycopg2 6 | from ..base import ContextManager 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class PostgreSQLContextManager(ContextManager): 12 | def __init__( 13 | self, 14 | *, 15 | host: str = "localhost", 16 | port: int = 5432, 17 | dbname: str = "litests", 18 | user: str = "postgres", 19 | password: str = None, 20 | context_timeout: int = 3600 21 | ): 22 | self.connection_params = { 23 | "host": host, 24 | "port": port, 25 | "dbname": dbname, 26 | "user": user, 27 | "password": password, 28 | } 29 | self.context_timeout = context_timeout 30 | self.init_db() 31 | 32 | def connect_db(self): 33 | return psycopg2.connect(**self.connection_params) 34 | 35 | def init_db(self): 36 | conn = self.connect_db() 37 | try: 38 | with conn.cursor() as cur: 39 | # Create table 40 | cur.execute( 41 | """ 42 | CREATE TABLE IF NOT EXISTS chat_histories ( 43 | id SERIAL PRIMARY KEY, 44 | created_at TIMESTAMP NOT NULL, 45 | context_id TEXT NOT NULL, 46 | serialized_data JSON NOT NULL, 47 | context_schema TEXT 48 | ) 49 | """ 50 | ) 51 | # Create index 52 | cur.execute( 53 | """ 54 | CREATE INDEX IF NOT EXISTS idx_chat_histories_context_id_created_at 55 | ON chat_histories (context_id, created_at) 56 | """ 57 | ) 58 | conn.commit() 59 | except Exception as ex: 60 | logger.error(f"Error at init_db: {ex}") 61 | conn.rollback() 62 | finally: 63 | conn.close() 64 | 65 | async def get_histories(self, context_id: str, limit: int = 100) -> List[Dict]: 66 | conn = self.connect_db() 67 | try: 68 | sql_query = """ 69 | SELECT serialized_data 70 | FROM chat_histories 71 | WHERE context_id = %s 72 | """ 73 | params = [context_id] 74 | 75 | if self.context_timeout > 0: 76 | sql_query += " AND created_at >= %s" 77 | cutoff_time = datetime.now(timezone.utc) - timedelta(seconds=self.context_timeout) 78 | params.append(cutoff_time) 79 | 80 | sql_query += " ORDER BY id DESC" 81 | 82 | if limit > 0: 83 | sql_query += " LIMIT %s" 84 | params.append(limit) 85 | 86 | with conn.cursor() as cur: 87 | cur.execute(sql_query, tuple(params)) 88 | rows = cur.fetchall() 89 | 90 | rows.reverse() 91 | results = [row[0] for row in rows] 92 | return results 93 | 94 | except Exception as ex: 95 | logger.error(f"Error at get_histories: {ex}") 96 | return [] 97 | 98 | finally: 99 | conn.close() 100 | 101 | async def add_histories(self, context_id: str, data_list: List[Dict], context_schema: str = None): 102 | if not data_list: 103 | return 104 | 105 | conn = self.connect_db() 106 | try: 107 | columns = ["created_at", "context_id", "serialized_data", "context_schema"] 108 | placeholders = ["%s"] * len(columns) 109 | sql_query = f""" 110 | INSERT INTO chat_histories ({', 
'.join(columns)}) 111 | VALUES ({', '.join(placeholders)}) 112 | """ 113 | 114 | now_utc = datetime.now(timezone.utc) 115 | records = [] 116 | for data_item in data_list: 117 | record = ( 118 | now_utc, # created_at 119 | context_id, # context_id 120 | json.dumps(data_item, ensure_ascii=False), # serialized_data 121 | context_schema, # context_schema 122 | ) 123 | records.append(record) 124 | 125 | with conn.cursor() as cur: 126 | cur.executemany(sql_query, records) 127 | conn.commit() 128 | 129 | except Exception as ex: 130 | logger.error(f"Error at add_histories: {ex}") 131 | conn.rollback() 132 | 133 | finally: 134 | conn.close() 135 | 136 | async def get_last_created_at(self, context_id: str) -> datetime: 137 | conn = self.connect_db() 138 | try: 139 | sql = """ 140 | SELECT created_at 141 | FROM chat_histories 142 | WHERE context_id = %s 143 | ORDER BY id DESC 144 | LIMIT 1 145 | """ 146 | with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: 147 | cursor.execute(sql, (context_id,)) 148 | row = cursor.fetchone() 149 | 150 | if row and row["created_at"]: 151 | last_created_at = row["created_at"] 152 | if last_created_at.tzinfo is None: 153 | last_created_at = last_created_at.replace(tzinfo=timezone.utc) 154 | else: 155 | last_created_at = last_created_at.astimezone(timezone.utc) 156 | else: 157 | last_created_at = datetime.min.replace(tzinfo=timezone.utc) 158 | 159 | return last_created_at 160 | 161 | except Exception as ex: 162 | logger.error(f"Error at get_last_created_at: {ex}") 163 | return datetime.min.replace(tzinfo=timezone.utc) 164 | 165 | finally: 166 | conn.close() 167 | -------------------------------------------------------------------------------- /litests/llm/dify.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import json 3 | from typing import AsyncGenerator, Dict, List 4 | import httpx 5 | from . 
import LLMService, LLMResponse 6 | 7 | logger = getLogger(__name__) 8 | 9 | 10 | class DifyService(LLMService): 11 | def __init__( 12 | self, 13 | *, 14 | api_key: str = None, 15 | user: str = None, 16 | base_url: str = "http://127.0.0.1", 17 | is_agent_mode: bool = False, 18 | make_inputs: callable = None, 19 | split_chars: List[str] = None, 20 | option_split_chars: List[str] = None, 21 | option_split_threshold: int = 50, 22 | voice_text_tag: str = None, 23 | max_connections: int = 100, 24 | max_keepalive_connections: int = 20, 25 | timeout: float = 10.0 26 | ): 27 | super().__init__( 28 | system_prompt=None, 29 | model=None, 30 | temperature=0.0, 31 | split_chars=split_chars, 32 | option_split_chars=option_split_chars, 33 | option_split_threshold=option_split_threshold, 34 | voice_text_tag=voice_text_tag 35 | ) 36 | self.conversation_ids: Dict[str, str] = {} 37 | self.api_key = api_key 38 | self.user = user 39 | self.base_url = base_url 40 | self.is_agent_mode = is_agent_mode 41 | self.make_inputs = make_inputs 42 | self.http_client = httpx.AsyncClient( 43 | follow_redirects=False, 44 | timeout=httpx.Timeout(timeout), 45 | limits=httpx.Limits( 46 | max_connections=max_connections, 47 | max_keepalive_connections=max_keepalive_connections 48 | ) 49 | ) 50 | 51 | async def compose_messages(self, context_id: str, text: str, files: List[Dict[str, str]] = None, system_prompt_params: Dict[str, any] = None) -> List[Dict]: 52 | if self.make_inputs: 53 | inputs = self.make_inputs(context_id, text, files, system_prompt_params) 54 | else: 55 | inputs = {} 56 | 57 | message = { 58 | "inputs": inputs, 59 | "query": text, 60 | "response_mode": "streaming", 61 | "user": self.user, 62 | "auto_generate_name": False, 63 | "conversation_id": self.conversation_ids.get(context_id, "") 64 | } 65 | if files: 66 | for f in files: 67 | if url := f.get("url"): 68 | files.append({"type": "image", "transfer_method": "remote_url", "url": url}) 69 | message["files"] = files 70 | 71 | return [message] 72 | 73 | async def update_context(self, context_id: str, messages: List[Dict], response_text: str): 74 | # Context is managed at Dify server 75 | pass 76 | 77 | 78 | async def get_llm_stream_response(self, context_id: str, user_id: str, messages: List[dict], system_prompt_params: Dict[str, any] = None) -> AsyncGenerator[LLMResponse, None]: 79 | headers = { 80 | "Authorization": f"Bearer {self.api_key}" 81 | } 82 | 83 | if user_id: 84 | messages[0]["user"] = user_id 85 | stream_resp = await self.http_client.post( 86 | self.base_url + "/chat-messages", 87 | headers=headers, 88 | json=messages[0] 89 | ) 90 | stream_resp.raise_for_status() 91 | 92 | message_event_value = "agent_message" if self.is_agent_mode else "message" 93 | async for chunk in stream_resp.aiter_lines(): 94 | if chunk.startswith("data:"): 95 | chunk_json = json.loads(chunk[5:]) 96 | if chunk_json["event"] == message_event_value: 97 | answer = chunk_json["answer"] 98 | yield LLMResponse(context_id=context_id, text=answer) 99 | elif chunk_json["event"] == "message_end": 100 | # Save conversation id instead of managing context locally 101 | self.conversation_ids[context_id] = chunk_json["conversation_id"] 102 | -------------------------------------------------------------------------------- /litests/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Dict, Any 3 | from .llm import ToolCall 4 | 5 | 6 | @dataclass 7 | class STSRequest: 8 | type: str = 
"start" 9 | session_id: str = None 10 | user_id: str = None 11 | context_id: str = None 12 | text: str = None 13 | audio_data: bytes = None 14 | audio_duration: float = 0 15 | files: List[Dict[str, str]] = None 16 | system_prompt_params: Dict[str, Any] = None 17 | 18 | 19 | @dataclass 20 | class STSResponse: 21 | type: str 22 | session_id: str = None 23 | user_id: str = None 24 | context_id: str = None 25 | text: str = None 26 | voice_text: str = None 27 | audio_data: bytes = None 28 | tool_call: ToolCall = None 29 | metadata: dict = None 30 | -------------------------------------------------------------------------------- /litests/performance_recorder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import PerformanceRecorder, PerformanceRecord 2 | -------------------------------------------------------------------------------- /litests/performance_recorder/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class PerformanceRecord: 7 | transaction_id: str 8 | user_id: str = None 9 | context_id: str = None 10 | stt_name: str = None 11 | llm_name: str = None 12 | tts_name: str = None 13 | request_text: str = None 14 | response_text: str = None 15 | request_files: str = None 16 | response_voice_text: str = None 17 | voice_length: float = 0 18 | stt_time: float = 0 19 | stop_response_time: float = 0 20 | llm_first_chunk_time: float = 0 21 | llm_first_voice_chunk_time: float = 0 22 | llm_time: float = 0 23 | tts_first_chunk_time: float = 0 24 | tts_time: float = 0 25 | total_time: float = 0 26 | 27 | 28 | class PerformanceRecorder(ABC): 29 | @abstractmethod 30 | def record(self, record: PerformanceRecord): 31 | pass 32 | 33 | @abstractmethod 34 | def close(self): 35 | pass 36 | -------------------------------------------------------------------------------- /litests/performance_recorder/postgres.py: -------------------------------------------------------------------------------- 1 | from dataclasses import fields 2 | from datetime import datetime, timezone 3 | import logging 4 | import queue 5 | import threading 6 | import time 7 | import psycopg2 8 | from . 
import PerformanceRecorder, PerformanceRecord 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class PostgreSQLPerformanceRecorder(PerformanceRecorder): 14 | def __init__( 15 | self, 16 | *, 17 | host: str = "localhost", 18 | port: int = 5432, 19 | dbname: str = "litests", 20 | user: str = "postgres", 21 | password: str = None, 22 | ): 23 | self.connection_params = { 24 | "host": host, 25 | "port": port, 26 | "dbname": dbname, 27 | "user": user, 28 | "password": password, 29 | } 30 | self.record_queue = queue.Queue() 31 | self.stop_event = threading.Event() 32 | 33 | self.init_db() 34 | 35 | self.worker_thread = threading.Thread(target=self.start_worker, daemon=True) 36 | self.worker_thread.start() 37 | 38 | def connect_db(self): 39 | return psycopg2.connect(**self.connection_params) 40 | 41 | def add_column_if_not_exist(self, cur, column_name): 42 | cur.execute( 43 | f""" 44 | SELECT column_name FROM information_schema.columns 45 | WHERE table_name='performance_records' AND column_name='{column_name}' 46 | """ 47 | ) 48 | if not cur.fetchone(): 49 | cur.execute( 50 | f"ALTER TABLE performance_records ADD COLUMN {column_name} TEXT" 51 | ) 52 | 53 | def init_db(self): 54 | conn = self.connect_db() 55 | try: 56 | with conn.cursor() as cur: 57 | cur.execute( 58 | """ 59 | CREATE TABLE IF NOT EXISTS performance_records ( 60 | id SERIAL PRIMARY KEY, 61 | created_at TIMESTAMPTZ, 62 | transaction_id TEXT, 63 | user_id TEXT, 64 | context_id TEXT, 65 | voice_length REAL, 66 | stt_time REAL, 67 | stop_response_time REAL, 68 | llm_first_chunk_time REAL, 69 | llm_first_voice_chunk_time REAL, 70 | llm_time REAL, 71 | tts_first_chunk_time REAL, 72 | tts_time REAL, 73 | total_time REAL, 74 | stt_name TEXT, 75 | llm_name TEXT, 76 | tts_name TEXT, 77 | request_text TEXT, 78 | request_files TEXT, 79 | response_text TEXT, 80 | response_voice_text TEXT 81 | ) 82 | """ 83 | ) 84 | 85 | # Add request_files column if not exist (migration v0.3.0 -> 0.3.2) 86 | self.add_column_if_not_exist(cur, "request_files") 87 | 88 | # Add user_id column if not exist (migration v0.3.2 -> 0.3.3) 89 | self.add_column_if_not_exist(cur, "user_id") 90 | 91 | # Add transaction_id column if not exist (migration v0.3.3 -> 0.3.4) 92 | self.add_column_if_not_exist(cur, "transaction_id") 93 | 94 | # Create index 95 | cur.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON performance_records (created_at)") 96 | cur.execute("CREATE INDEX IF NOT EXISTS idx_transaction_id ON performance_records (transaction_id)") 97 | cur.execute("CREATE INDEX IF NOT EXISTS idx_user_id ON performance_records (user_id)") 98 | cur.execute("CREATE INDEX IF NOT EXISTS idx_context_id ON performance_records (context_id)") 99 | 100 | conn.commit() 101 | finally: 102 | conn.close() 103 | 104 | def start_worker(self): 105 | conn = self.connect_db() 106 | try: 107 | while not self.stop_event.is_set() or not self.record_queue.empty(): 108 | try: 109 | record = self.record_queue.get(timeout=0.5) 110 | except queue.Empty: 111 | continue 112 | 113 | try: 114 | self.insert_record(conn, record) 115 | except (psycopg2.InterfaceError, psycopg2.OperationalError): 116 | try: 117 | conn.close() 118 | except Exception: 119 | pass 120 | 121 | logger.warning("Connection is not available. 
Retrying insert_record with new connection...") 122 | time.sleep(0.5) 123 | conn = self.connect_db() 124 | self.insert_record(conn, record) 125 | 126 | self.record_queue.task_done() 127 | finally: 128 | try: 129 | conn.close() 130 | except Exception: 131 | pass 132 | 133 | def insert_record(self, conn: psycopg2.extensions.connection, record: PerformanceRecord): 134 | columns = [field.name for field in fields(PerformanceRecord)] + ["created_at"] 135 | placeholders = ["%s"] * len(columns) 136 | values = [getattr(record, field.name) for field in fields(PerformanceRecord)] + [datetime.now(timezone.utc)] 137 | sql = f"INSERT INTO performance_records ({', '.join(columns)}) VALUES ({', '.join(placeholders)})" 138 | with conn.cursor() as cur: 139 | cur.execute(sql, values) 140 | conn.commit() 141 | 142 | def record(self, record: PerformanceRecord): 143 | self.record_queue.put(record) 144 | 145 | def close(self): 146 | self.stop_event.set() 147 | self.record_queue.join() 148 | self.worker_thread.join() 149 | -------------------------------------------------------------------------------- /litests/performance_recorder/sqlite.py: -------------------------------------------------------------------------------- 1 | from dataclasses import fields 2 | from datetime import datetime, timezone 3 | import queue 4 | import sqlite3 5 | import threading 6 | from . import PerformanceRecorder, PerformanceRecord 7 | 8 | 9 | class SQLitePerformanceRecorder(PerformanceRecorder): 10 | def __init__(self, db_path="performance.db"): 11 | self.db_path = db_path 12 | self.record_queue = queue.Queue() 13 | self.stop_event = threading.Event() 14 | 15 | self.init_db() 16 | 17 | self.worker_thread = threading.Thread(target=self.start_worker, daemon=True) 18 | self.worker_thread.start() 19 | 20 | def init_db(self): 21 | conn = sqlite3.connect(self.db_path) 22 | try: 23 | with conn: 24 | conn.execute( 25 | """ 26 | CREATE TABLE IF NOT EXISTS performance_records ( 27 | id INTEGER PRIMARY KEY AUTOINCREMENT, 28 | created_at TIMESTAMP, 29 | transaction_id TEXT, 30 | user_id TEXT, 31 | context_id TEXT, 32 | voice_length REAL, 33 | stt_time REAL, 34 | stop_response_time REAL, 35 | llm_first_chunk_time REAL, 36 | llm_first_voice_chunk_time REAL, 37 | llm_time REAL, 38 | tts_first_chunk_time REAL, 39 | tts_time REAL, 40 | total_time REAL, 41 | stt_name TEXT, 42 | llm_name TEXT, 43 | tts_name TEXT, 44 | request_text TEXT, 45 | request_files TEXT, 46 | response_text TEXT, 47 | response_voice_text TEXT 48 | ) 49 | """ 50 | ) 51 | 52 | cursor = conn.execute("PRAGMA table_info(performance_records)") 53 | columns = [row[1] for row in cursor.fetchall()] 54 | 55 | # Add request_files column if not exist (migration v0.3.0 -> 0.3.2) 56 | if "request_files" not in columns: 57 | conn.execute("ALTER TABLE performance_records ADD COLUMN request_files TEXT") 58 | 59 | # Add user_id column if not exist (migration v0.3.2 -> 0.3.3) 60 | if "user_id" not in columns: 61 | conn.execute("ALTER TABLE performance_records ADD COLUMN user_id TEXT") 62 | 63 | # Add transaction_id column if not exist (migration v0.3.3 -> 0.3.4) 64 | if "transaction_id" not in columns: 65 | print("add column: transaction_id") 66 | conn.execute("ALTER TABLE performance_records ADD COLUMN transaction_id TEXT") 67 | 68 | # Create index 69 | conn.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON performance_records (created_at)") 70 | conn.execute("CREATE INDEX IF NOT EXISTS idx_transaction_id ON performance_records (transaction_id)") 71 | conn.execute("CREATE INDEX IF NOT EXISTS 
idx_user_id ON performance_records (user_id)") 72 | conn.execute("CREATE INDEX IF NOT EXISTS idx_context_id ON performance_records (context_id)") 73 | 74 | finally: 75 | conn.close() 76 | 77 | def start_worker(self): 78 | conn = sqlite3.connect(self.db_path) 79 | try: 80 | while not self.stop_event.is_set() or not self.record_queue.empty(): 81 | try: 82 | record = self.record_queue.get(timeout=0.5) 83 | except queue.Empty: 84 | continue 85 | 86 | self.insert_record(conn, record) 87 | self.record_queue.task_done() 88 | finally: 89 | conn.close() 90 | 91 | def insert_record(self, conn: sqlite3.Connection, record: PerformanceRecord): 92 | columns = [field.name for field in fields(PerformanceRecord)] + ["created_at"] 93 | placeholders = ["?"] * len(columns) 94 | values = [getattr(record, field.name) for field in fields(PerformanceRecord)] + [datetime.now(timezone.utc)] 95 | sql = f"INSERT INTO performance_records ({', '.join(columns)}) VALUES ({', '.join(placeholders)})" 96 | conn.execute(sql, values) 97 | conn.commit() 98 | 99 | def record(self, record: PerformanceRecord): 100 | self.record_queue.put(record) 101 | 102 | def close(self): 103 | self.stop_event.set() 104 | self.record_queue.join() 105 | self.worker_thread.join() 106 | -------------------------------------------------------------------------------- /litests/stt/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SpeechRecognizer, SpeechRecognizerDummy 2 | -------------------------------------------------------------------------------- /litests/stt/azure.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import logging 4 | from typing import List 5 | import wave 6 | from . import SpeechRecognizer 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class AzureSpeechRecognizer(SpeechRecognizer): 12 | def __init__( 13 | self, 14 | azure_api_key: str, 15 | azure_region: str, 16 | sample_rate: int = 16000, 17 | language: str = "ja-JP", 18 | alternative_languages: List[str] = None, 19 | use_classic: bool = False, 20 | *, 21 | max_connections: int = 100, 22 | max_keepalive_connections: int = 20, 23 | timeout: float = 10.0, 24 | debug: bool = False 25 | ): 26 | super().__init__( 27 | language=language, 28 | alternative_languages=alternative_languages, 29 | max_connections=max_connections, 30 | max_keepalive_connections=max_keepalive_connections, 31 | timeout=timeout, 32 | debug=debug 33 | ) 34 | self.azure_api_key = azure_api_key 35 | self.azure_region = azure_region 36 | self.sample_rate = sample_rate 37 | self.use_classic = use_classic 38 | if self.use_classic and self.alternative_languages: 39 | logger.warning("Auto language detection is not available in Azure STT v1. 
Set `use_classic=False` to enable this feature.") 40 | 41 | async def transcribe(self, data: bytes) -> str: 42 | if self.use_classic: 43 | return await self.transcribe_classic(data) 44 | else: 45 | return await self.transcribe_fast(data) 46 | 47 | async def transcribe_classic(self, data: bytes) -> str: 48 | headers = { 49 | "Ocp-Apim-Subscription-Key": self.azure_api_key 50 | } 51 | 52 | resp = await self.http_client.post( 53 | f"https://{self.azure_region}.stt.speech.microsoft.com/speech/recognition/conversation/cognitiveservices/v1?language={self.language}", 54 | headers=headers, 55 | content=data 56 | ) 57 | 58 | try: 59 | resp_json = resp.json() 60 | except: 61 | resp_json = {} 62 | 63 | if resp.status_code != 200: 64 | logger.error(f"Failed in recognition: {resp.status_code}\n{resp_json}") 65 | 66 | if recognized_text := resp_json.get("DisplayText"): 67 | if self.debug: 68 | logger.info(f"Recognized: {recognized_text}") 69 | return recognized_text 70 | 71 | def to_wave_file(self, raw_audio: bytes): 72 | buffer = io.BytesIO() 73 | with wave.open(buffer, "wb") as wf: 74 | wf.setnchannels(1) # mono 75 | wf.setsampwidth(2) # 16bit 76 | wf.setframerate(self.sample_rate) # sample rate 77 | wf.writeframes(raw_audio) 78 | buffer.seek(0) 79 | return buffer 80 | 81 | async def transcribe_fast(self, data: bytes) -> str: 82 | # Using Fast Transcription 83 | # https://learn.microsoft.com/en-us/rest/api/speechtotext/transcriptions/transcribe?view=rest-speechtotext-2024-11-15&tabs=HTTP 84 | headers = { 85 | "Ocp-Apim-Subscription-Key": self.azure_api_key, 86 | } 87 | 88 | # https://learn.microsoft.com/en-us/azure/ai-services/speech-service/fast-transcription-create?tabs=locale-specified#request-configuration-options 89 | locales = [self.language] + self.alternative_languages 90 | files = { 91 | "audio": self.to_wave_file(data), 92 | "definition": (None, json.dumps({"locales": locales, "channels": [0,1]}), "application/json"), 93 | } 94 | 95 | resp = await self.http_client.post( 96 | f"https://{self.azure_region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15", 97 | headers=headers, 98 | files=files 99 | ) 100 | 101 | try: 102 | resp.raise_for_status() 103 | resp_json = resp.json() 104 | except: 105 | logger.error(f"Failed in recognition: {resp.status_code}\n{resp.content}") 106 | return None 107 | 108 | if recognized_text := resp_json["combinedPhrases"][0]["text"]: 109 | if self.debug: 110 | logger.info(f"Recognized: {recognized_text}") 111 | return recognized_text 112 | -------------------------------------------------------------------------------- /litests/stt/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | import httpx 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class SpeechRecognizer(ABC): 10 | def __init__( 11 | self, 12 | *, 13 | language: str = None, 14 | alternative_languages: List[str] = None, 15 | max_connections: int = 100, 16 | max_keepalive_connections: int = 20, 17 | timeout: float = 10.0, 18 | debug: bool = False 19 | ): 20 | self.language = language 21 | self.alternative_languages = alternative_languages or [] 22 | self.http_client = httpx.AsyncClient( 23 | follow_redirects=False, 24 | timeout=httpx.Timeout(timeout), 25 | limits=httpx.Limits( 26 | max_connections=max_connections, 27 | max_keepalive_connections=max_keepalive_connections 28 | ) 29 | ) 30 | 31 | self.debug = debug 32 | 33 | 
@abstractmethod 34 | async def transcribe(self, data: bytes) -> str: 35 | pass 36 | 37 | async def close(self): 38 | await self.http_client.aclose() 39 | 40 | 41 | class SpeechRecognizerDummy(SpeechRecognizer): 42 | async def transcribe(self, data: bytes) -> str: 43 | pass 44 | -------------------------------------------------------------------------------- /litests/stt/google.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | from typing import List 4 | from . import SpeechRecognizer 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class GoogleSpeechRecognizer(SpeechRecognizer): 10 | def __init__( 11 | self, 12 | google_api_key: str, 13 | sample_rate: int = 16000, 14 | language: str = "ja-JP", 15 | alternative_languages: List[str] = None, 16 | *, 17 | max_connections: int = 100, 18 | max_keepalive_connections: int = 20, 19 | timeout: float = 10.0, 20 | debug: bool = False 21 | ): 22 | super().__init__( 23 | language=language, 24 | alternative_languages=alternative_languages, 25 | max_connections=max_connections, 26 | max_keepalive_connections=max_keepalive_connections, 27 | timeout=timeout, 28 | debug=debug 29 | ) 30 | self.google_api_key = google_api_key 31 | self.sample_rate = sample_rate 32 | 33 | async def transcribe(self, data: bytes) -> str: 34 | request_body = { 35 | "config": { 36 | "encoding": "LINEAR16", 37 | "sampleRateHertz": self.sample_rate, 38 | "languageCode": self.language, 39 | }, 40 | "audio": { 41 | "content": base64.b64encode(data).decode("utf-8") 42 | }, 43 | } 44 | if self.alternative_languages: 45 | request_body["config"]["alternativeLanguageCodes"] = self.alternative_languages 46 | 47 | resp = await self.http_client.post( 48 | f"https://speech.googleapis.com/v1/speech:recognize?key={self.google_api_key}", 49 | json=request_body 50 | ) 51 | 52 | try: 53 | resp_json = resp.json() 54 | except: 55 | resp_json = {} 56 | 57 | if resp.status_code != 200: 58 | logger.error(f"Failed in recognition: {resp.status_code}\n{resp_json}") 59 | 60 | if resp_json.get("results"): 61 | if recognized_text := resp_json["results"][0]["alternatives"][0].get("transcript"): 62 | if self.debug: 63 | logger.info(f"Recognized: {recognized_text}") 64 | return recognized_text 65 | -------------------------------------------------------------------------------- /litests/stt/openai.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | from typing import List 4 | import wave 5 | from . 
import SpeechRecognizer 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class OpenAISpeechRecognizer(SpeechRecognizer): 11 | def __init__( 12 | self, 13 | openai_api_key: str, 14 | sample_rate: int = 16000, 15 | language: str = "ja", 16 | alternative_languages: List[str] = None, 17 | *, 18 | max_connections: int = 100, 19 | max_keepalive_connections: int = 20, 20 | timeout: float = 10.0, 21 | debug: bool = False 22 | ): 23 | super().__init__( 24 | language=language, 25 | alternative_languages=alternative_languages, 26 | max_connections=max_connections, 27 | max_keepalive_connections=max_keepalive_connections, 28 | timeout=timeout, 29 | debug=debug 30 | ) 31 | self.openai_api_key = openai_api_key 32 | self.sample_rate = sample_rate 33 | 34 | def to_wave_file(self, raw_audio: bytes): 35 | buffer = io.BytesIO() 36 | with wave.open(buffer, "wb") as wf: 37 | wf.setnchannels(1) # mono 38 | wf.setsampwidth(2) # 16bit 39 | wf.setframerate(self.sample_rate) # sample rate 40 | wf.writeframes(raw_audio) 41 | buffer.seek(0) 42 | return buffer 43 | 44 | async def transcribe(self, data: bytes) -> str: 45 | headers = { 46 | "Authorization": f"Bearer {self.openai_api_key}" 47 | } 48 | 49 | form_data = { 50 | "model": "whisper-1", 51 | } 52 | 53 | if self.language and not self.alternative_languages: 54 | form_data["language"] = self.language.split("-")[0] if "-" in self.language else self.language 55 | 56 | files = { 57 | "file": ("voice.wav", self.to_wave_file(data), "audio/wav"), 58 | } 59 | 60 | resp = await self.http_client.post( 61 | "https://api.openai.com/v1/audio/transcriptions", 62 | headers=headers, 63 | data=form_data, 64 | files=files 65 | ) 66 | 67 | try: 68 | resp_json = resp.json() 69 | except: 70 | resp_json = {} 71 | return None 72 | 73 | if resp.status_code != 200: 74 | logger.error(f"Failed in recognition: {resp.status_code}\n{resp_json}") 75 | 76 | return resp_json.get("text") 77 | -------------------------------------------------------------------------------- /litests/tts/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SpeechSynthesizer, SpeechSynthesizerDummy 2 | -------------------------------------------------------------------------------- /litests/tts/azure.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | from . 
import SpeechSynthesizer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class AzureSpeechSynthesizer(SpeechSynthesizer): 9 | def __init__( 10 | self, 11 | *, 12 | azure_api_key: str, 13 | azure_region: str, 14 | speaker: str, 15 | style_mapper: Dict[str, str] = None, 16 | default_language: str = "ja-JP", 17 | audio_format: str = "riff-16khz-16bit-mono-pcm", 18 | max_connections: int = 100, 19 | max_keepalive_connections: int = 20, 20 | timeout: float = 10.0, 21 | debug: bool = False 22 | ): 23 | super().__init__( 24 | style_mapper=style_mapper, 25 | max_connections=max_connections, 26 | max_keepalive_connections=max_keepalive_connections, 27 | timeout=timeout, 28 | debug=debug 29 | ) 30 | self.azure_api_key = azure_api_key 31 | self.azure_region = azure_region 32 | self.speaker = speaker 33 | self.default_language = default_language 34 | self.audio_format = audio_format 35 | self.voice_map = {self.default_language: self.speaker} 36 | 37 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 38 | if not text or not text.strip(): 39 | return bytes() 40 | 41 | logger.info(f"Speech synthesize: {text}") 42 | 43 | headers = { 44 | "X-Microsoft-OutputFormat": self.audio_format, 45 | "Content-Type": "application/ssml+xml", 46 | "Ocp-Apim-Subscription-Key": self.azure_api_key 47 | } 48 | 49 | speaker = self.voice_map[language or self.default_language] 50 | ssml_text = f"{text}" 51 | data = ssml_text.encode("utf-8") 52 | 53 | # Synthesize 54 | # https://learn.microsoft.com/ja-jp/azure/ai-services/speech-service/language-support?tabs=tts 55 | resp = await self.http_client.post( 56 | url=f"https://{self.azure_region}.tts.speech.microsoft.com/cognitiveservices/v1", 57 | headers=headers, 58 | data=data 59 | ) 60 | 61 | return resp.content 62 | -------------------------------------------------------------------------------- /litests/tts/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict 3 | import httpx 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class SpeechSynthesizer(ABC): 10 | def __init__( 11 | self, 12 | *, 13 | style_mapper: Dict[str, str] = None, 14 | max_connections: int = 100, 15 | max_keepalive_connections: int = 20, 16 | timeout: float = 10.0, 17 | debug: bool = False 18 | ): 19 | self.http_client = httpx.AsyncClient( 20 | follow_redirects=False, 21 | timeout=httpx.Timeout(timeout), 22 | limits=httpx.Limits( 23 | max_connections=max_connections, 24 | max_keepalive_connections=max_keepalive_connections 25 | ) 26 | ) 27 | self.style_mapper = style_mapper or {} 28 | self.debug = debug 29 | 30 | def parse_style(self, style_info: dict = None) -> str: 31 | if not style_info: 32 | return None 33 | 34 | styled_text = style_info.get("styled_text", "") 35 | for k, v in self.style_mapper.items(): 36 | if k in styled_text: 37 | return v 38 | return None 39 | 40 | @abstractmethod 41 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 42 | pass 43 | 44 | async def close(self): 45 | await self.http_client.aclose() 46 | 47 | 48 | class SpeechSynthesizerDummy(SpeechSynthesizer): 49 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 50 | return None 51 | -------------------------------------------------------------------------------- /litests/tts/google.py: 
-------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | from typing import Dict 4 | 5 | from . import SpeechSynthesizer 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class GoogleSpeechSynthesizer(SpeechSynthesizer): 11 | def __init__( 12 | self, 13 | *, 14 | google_api_key: str, 15 | speaker: str, 16 | default_language: str = "ja-JP", 17 | style_mapper: Dict[str, str] = None, 18 | audio_format: str = "LINEAR16", 19 | max_connections: int = 100, 20 | max_keepalive_connections: int = 20, 21 | timeout: float = 10.0, 22 | debug: bool = False 23 | ): 24 | super().__init__( 25 | style_mapper=style_mapper, 26 | max_connections=max_connections, 27 | max_keepalive_connections=max_keepalive_connections, 28 | timeout=timeout, 29 | debug=debug 30 | ) 31 | self.google_api_key = google_api_key 32 | self.speaker = speaker 33 | self.default_language = default_language 34 | self.audio_format = audio_format 35 | self.voice_map = {self.default_language: self.speaker} 36 | 37 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 38 | if not text or not text.strip(): 39 | return bytes() 40 | 41 | logger.info(f"Speech synthesize: {text}") 42 | 43 | # Set language and speaker 44 | voice = {"languageCode": self.default_language, "name": self.speaker} 45 | if language: 46 | if language.startswith("zh-"): 47 | language = language.replace("zh-", "cmn-CN") 48 | if language in self.voice_map: 49 | voice = {"languageCode": language, "name": self.voice_map[language]} 50 | 51 | # Synthesize 52 | # https://cloud.google.com/text-to-speech/docs/voices 53 | resp = await self.http_client.post( 54 | url=f"https://texttospeech.googleapis.com/v1/text:synthesize?key={self.google_api_key}", 55 | json={ 56 | "input": {"text": text}, 57 | "voice": voice, 58 | "audioConfig": {"audioEncoding": self.audio_format} 59 | } 60 | ) 61 | resp_json = resp.json() 62 | 63 | return base64.b64decode(resp_json["audioContent"]) 64 | -------------------------------------------------------------------------------- /litests/tts/openai.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | from . 
import SpeechSynthesizer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class OpenAISpeechSynthesizer(SpeechSynthesizer): 9 | def __init__( 10 | self, 11 | *, 12 | openai_api_key: str, 13 | speaker: str, 14 | model: str = "tts-1", 15 | style_mapper: Dict[str, str] = None, 16 | audio_format: str = "wav", 17 | max_connections: int = 100, 18 | max_keepalive_connections: int = 20, 19 | timeout: float = 10.0, 20 | debug: bool = False 21 | ): 22 | super().__init__( 23 | style_mapper=style_mapper, 24 | max_connections=max_connections, 25 | max_keepalive_connections=max_keepalive_connections, 26 | timeout=timeout, 27 | debug=debug 28 | ) 29 | self.openai_api_key = openai_api_key 30 | self.speaker = speaker 31 | self.model = model 32 | self.audio_format = audio_format 33 | 34 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 35 | if not text or not text.strip(): 36 | return bytes() 37 | 38 | logger.info(f"Speech synthesize: {text}") 39 | 40 | # Synthesize 41 | resp = await self.http_client.post( 42 | url="https://api.openai.com/v1/audio/speech", 43 | headers={ 44 | "Authorization": f"Bearer {self.openai_api_key}" 45 | }, 46 | json= { 47 | "model": self.model, 48 | "voice": self.speaker, 49 | "input": text, 50 | # "speed": self.speed, 51 | "response_format": "wav" 52 | } 53 | ) 54 | 55 | return resp.content 56 | -------------------------------------------------------------------------------- /litests/tts/speech_gateway.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | from . import SpeechSynthesizer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class SpeechGatewaySpeechSynthesizer(SpeechSynthesizer): 9 | def __init__( 10 | self, 11 | *, 12 | service_name: str, 13 | speaker: str, 14 | style_mapper: Dict[str, str] = None, 15 | tts_url: str = "http://127.0.0.1:8000/tts", 16 | audio_format: str = None, 17 | max_connections: int = 100, 18 | max_keepalive_connections: int = 20, 19 | timeout: float = 10.0, 20 | debug: bool = False 21 | ): 22 | super().__init__( 23 | style_mapper=style_mapper, 24 | max_connections=max_connections, 25 | max_keepalive_connections=max_keepalive_connections, 26 | timeout=timeout, 27 | debug=debug 28 | ) 29 | self.service_name = service_name 30 | self.speaker = speaker 31 | self.tts_url = tts_url 32 | self.audio_format = audio_format 33 | 34 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 35 | if not text or not text.strip(): 36 | return bytes() 37 | 38 | logger.info(f"Speech synthesize: {text}") 39 | 40 | # Audio format 41 | query_params = {"x_audio_format": self.audio_format} if self.audio_format else {} 42 | 43 | # Apply style 44 | request_json = {"text": text, "service_name": self.service_name, "speaker": self.speaker} 45 | if style := self.parse_style(style_info): 46 | request_json["style"] = style 47 | logger.info(f"Apply style: {style}") 48 | 49 | # Apply language 50 | if language and language != "ja-JP": 51 | logger.info(f"Apply language: {language}") 52 | request_json["language"] = language 53 | del request_json["service_name"] 54 | del request_json["speaker"] 55 | 56 | # Synthesize 57 | resp = await self.http_client.post( 58 | url=self.tts_url, 59 | params=query_params, 60 | json=request_json 61 | ) 62 | 63 | return resp.content 64 | -------------------------------------------------------------------------------- /litests/tts/voicevox.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | from . import SpeechSynthesizer 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class VoicevoxSpeechSynthesizer(SpeechSynthesizer): 9 | def __init__( 10 | self, 11 | *, 12 | base_url: str = "http://127.0.0.1:50021", 13 | speaker: int = 46, 14 | style_mapper: Dict[str, str] = None, 15 | max_connections: int = 100, 16 | max_keepalive_connections: int = 20, 17 | timeout: float = 10.0, 18 | debug: bool = False 19 | ): 20 | super().__init__( 21 | style_mapper=style_mapper, 22 | max_connections=max_connections, 23 | max_keepalive_connections=max_keepalive_connections, 24 | timeout=timeout, 25 | debug=debug 26 | ) 27 | self.base_url = base_url 28 | self.speaker = speaker 29 | 30 | async def get_audio_query(self, text: str, speaker: int): 31 | url = f"{self.base_url}/audio_query" 32 | response = await self.http_client.post(url, params={"speaker": speaker, "text": text}) 33 | response.raise_for_status() 34 | return response.json() 35 | 36 | async def synthesize(self, text: str, style_info: dict = None, language: str = None) -> bytes: 37 | if not text or not text.strip(): 38 | return bytes() 39 | 40 | logger.info(f"Speech synthesize: {text}") 41 | 42 | speaker = self.speaker 43 | 44 | # Apply style 45 | if style := self.parse_style(style_info): 46 | speaker = int(style) 47 | logger.info(f"Apply style: {speaker}") 48 | 49 | # Make query 50 | audio_query = await self.get_audio_query(text, speaker) 51 | 52 | # Synthesize 53 | response = await self.http_client.post( 54 | url=self.base_url + "/synthesis", 55 | params={"speaker": speaker}, 56 | json=audio_query 57 | ) 58 | return response.content 59 | -------------------------------------------------------------------------------- /litests/vad/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SpeechDetector, SpeechDetectorDummy 2 | from .standard import StandardSpeechDetector 3 | -------------------------------------------------------------------------------- /litests/vad/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import logging 3 | from typing import AsyncGenerator 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class SpeechDetector(ABC): 9 | def __init__(self, *, sample_rate: int = 16000): 10 | self.sample_rate = sample_rate 11 | self._on_speech_detected = self.on_speech_detected_default 12 | self.should_mute = lambda: False 13 | 14 | def on_speech_detected(self, func): 15 | self._on_speech_detected = func 16 | return func 17 | 18 | async def on_speech_detected_default(self, data: bytes, recorded_duration: float, session_id: str): 19 | logger.info(f"Speech detected: duration={recorded_duration} sec") 20 | 21 | @abstractmethod 22 | async def process_samples(self, samples: bytes, session_id: str = None): 23 | pass 24 | 25 | @abstractmethod 26 | async def process_stream(self, input_stream: AsyncGenerator[bytes, None], session_id: str = None): 27 | pass 28 | 29 | @abstractmethod 30 | async def finalize_session(self, session_id: str): 31 | pass 32 | 33 | 34 | class SpeechDetectorDummy(SpeechDetector): 35 | async def process_samples(self, samples, session_id = None): 36 | pass 37 | 38 | async def process_stream(self, input_stream, session_id = None): 39 | pass 40 | 41 | async def finalize_session(self, session_id): 42 | pass 43 | 
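# Usage sketch (illustrative, not part of the library): a concrete detector such as
# StandardSpeechDetector (standard.py below) is typically wired up by registering a
# callback with the decorator defined above; `detector` and `handle_speech` are
# placeholder names for this example.
#
#     detector = StandardSpeechDetector(volume_db_threshold=-40.0, sample_rate=16000)
#
#     @detector.on_speech_detected
#     async def handle_speech(data: bytes, recorded_duration: float, session_id: str):
#         ...  # e.g. hand the recorded PCM bytes to an STT service
#
# The callback replaces on_speech_detected_default and receives the recorded audio
# once enough trailing silence has been observed.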
-------------------------------------------------------------------------------- /litests/vad/standard.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections import deque 3 | import logging 4 | import math 5 | import struct 6 | from typing import AsyncGenerator, Callable, Optional, Dict 7 | from . import SpeechDetector 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class RecordingSession: 13 | def __init__(self, session_id: str, preroll_buffer_count: int = 5): 14 | self.session_id = session_id 15 | self.is_recording: bool = False 16 | self.buffer: bytearray = bytearray() 17 | self.silence_duration: float = 0 18 | self.record_duration: float = 0 19 | self.preroll_buffer: deque = deque(maxlen=preroll_buffer_count) 20 | self.amplitude_threshold: float = 0 21 | self.data: dict = {} 22 | 23 | def reset(self): 24 | # Reset status data except for preroll_buffer 25 | self.buffer.clear() 26 | self.is_recording = False 27 | self.silence_duration = 0 28 | self.record_duration = 0 29 | 30 | 31 | class StandardSpeechDetector(SpeechDetector): 32 | def __init__( 33 | self, 34 | *, 35 | volume_db_threshold: float = -40.0, 36 | silence_duration_threshold: float = 0.5, 37 | max_duration: float = 10.0, 38 | min_duration: float = 0.2, 39 | sample_rate: int = 16000, 40 | channels: int = 1, 41 | preroll_buffer_count: int = 5, 42 | to_linear16: Optional[Callable[[bytes], bytes]] = None, 43 | debug: bool = False 44 | ): 45 | self._volume_db_threshold = volume_db_threshold 46 | self.amplitude_threshold = 32767 * (10 ** (self.volume_db_threshold / 20.0)) 47 | self.silence_duration_threshold = silence_duration_threshold 48 | self.max_duration = max_duration 49 | self.min_duration = min_duration 50 | self.sample_rate = sample_rate 51 | self.channels = channels 52 | self.debug = debug 53 | self.preroll_buffer_count = preroll_buffer_count 54 | self.to_linear16 = to_linear16 55 | self.should_mute = lambda: False 56 | self.recording_sessions: Dict[str, RecordingSession] = {} 57 | 58 | @property 59 | def volume_db_threshold(self) -> float: 60 | return self._volume_db_threshold 61 | 62 | @volume_db_threshold.setter 63 | def volume_db_threshold(self, value: float): 64 | self._volume_db_threshold = value 65 | self.amplitude_threshold = 32767 * (10 ** (value / 20.0)) 66 | logger.debug(f"Updated volume_db_threshold to {value} dB, amplitude_threshold={self.amplitude_threshold}") 67 | 68 | async def execute_on_speech_detected(self, recorded_data: bytes, recorded_duration: float, session_id: str): 69 | try: 70 | await self._on_speech_detected(recorded_data, recorded_duration, session_id) 71 | except Exception as ex: 72 | logger.error(f"Error in task for session {session_id}: {ex}", exc_info=True) 73 | 74 | async def process_samples(self, samples: bytes, session_id: str): 75 | if self.to_linear16: 76 | samples = self.to_linear16(samples) 77 | 78 | session = self.get_session(session_id) 79 | 80 | if self.should_mute(): 81 | session.reset() 82 | session.preroll_buffer.clear() 83 | logger.debug("StandardSpeechDetector is muted.") 84 | return 85 | 86 | session.preroll_buffer.append(samples) 87 | 88 | max_amplitude = float(max(abs(sample) for sample, in struct.iter_unpack("<h", samples))) 89 | sample_duration = len(samples) / 2 / self.sample_rate  # 16-bit (2-byte) mono samples 90 | 91 | # Current volume in dB relative to 16-bit full scale (32767) 92 | if max_amplitude > 0: 93 | current_db = 20 * math.log10(max_amplitude / 32767) 94 | else: 95 | current_db = -100.0 96 | logger.debug(f"dB: {current_db:.2f}, duration: {session.record_duration:.2f}, session: {session.session_id}") 97 | 98 | if not session.is_recording: 99 | if max_amplitude > 
session.amplitude_threshold: 100 | # Start recording 101 | session.reset() 102 | session.is_recording = True 103 | 104 | for f in session.preroll_buffer: 105 | session.buffer.extend(f) 106 | 107 | session.buffer.extend(samples) 108 | session.record_duration += sample_duration 109 | 110 | else: 111 | # In Recording 112 | session.buffer.extend(samples) 113 | session.record_duration += sample_duration 114 | 115 | if max_amplitude > session.amplitude_threshold: 116 | session.silence_duration = 0 117 | else: 118 | session.silence_duration += sample_duration 119 | 120 | if session.silence_duration >= self.silence_duration_threshold: 121 | recorded_duration = session.record_duration - session.silence_duration 122 | if recorded_duration < self.min_duration: 123 | if self.debug: 124 | logger.info(f"Recording too short: {recorded_duration} sec") 125 | else: 126 | if self.debug: 127 | logger.info(f"Recording finished: {recorded_duration} sec") 128 | recorded_data = bytes(session.buffer) 129 | asyncio.create_task(self.execute_on_speech_detected(recorded_data, recorded_duration, session.session_id)) 130 | session.reset() 131 | 132 | elif session.record_duration >= self.max_duration: 133 | if self.debug: 134 | logger.info(f"Recording too long: {session.record_duration} sec") 135 | session.reset() 136 | 137 | async def process_stream(self, input_stream: AsyncGenerator[bytes, None], session_id: str): 138 | logger.info("LiteSTS start processing stream.") 139 | 140 | async for data in input_stream: 141 | if not data: 142 | break 143 | await self.process_samples(data, session_id) 144 | await asyncio.sleep(0.0001) 145 | 146 | self.delete_session(session_id) 147 | 148 | logger.info("LiteSTS finish processing stream.") 149 | 150 | async def finalize_session(self, session_id): 151 | self.delete_session(session_id) 152 | 153 | def get_session(self, session_id: str): 154 | session = self.recording_sessions.get(session_id) 155 | if session is None: 156 | session = RecordingSession(session_id, self.preroll_buffer_count) 157 | self.recording_sessions[session_id] = session 158 | if session.amplitude_threshold == 0: 159 | session.amplitude_threshold = self.amplitude_threshold 160 | return session 161 | 162 | def reset_session(self, session_id: str): 163 | if session := self.recording_sessions.get(session_id): 164 | session.reset() 165 | 166 | def delete_session(self, session_id: str): 167 | if session_id in self.recording_sessions: 168 | self.recording_sessions[session_id].reset() 169 | del self.recording_sessions[session_id] 170 | 171 | def get_session_data(self, session_id: str, key: str): 172 | session = self.recording_sessions.get(session_id) 173 | if session: 174 | return session.data.get(key) 175 | 176 | def set_session_data(self, session_id: str, key: str, value: any, create_session: bool = False): 177 | if create_session: 178 | session = self.get_session(session_id) 179 | else: 180 | session = self.recording_sessions.get(session_id) 181 | 182 | if session: 183 | session.data[key] = value 184 | 185 | def set_volume_db_threshold(self, session_id: str, value: float): 186 | session = self.get_session(session_id) 187 | session.amplitude_threshold = 32767 * (10 ** (value / 20.0)) 188 | -------------------------------------------------------------------------------- /litests/voice_recorder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import VoiceRecorder, RequestVoice, ResponseVoices 2 | 
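# Usage sketch (illustrative, not part of the library): a concrete recorder such as
# FileVoiceRecorder (file.py below) is driven by queueing RequestVoice / ResponseVoices
# items; `pcm_bytes` and `wav_chunks` are placeholder variables for this example.
#
#     recorder = FileVoiceRecorder(record_dir="recorded_voices")
#     await recorder.record(RequestVoice(transaction_id="tx-01", voice_bytes=pcm_bytes))
#     await recorder.record(ResponseVoices(transaction_id="tx-01", voice_chunks=wav_chunks, audio_format="wav"))
#     await recorder.stop()
#
# record() lazily starts a background worker task that calls save_voice() for each queued item.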
-------------------------------------------------------------------------------- /litests/voice_recorder/azure_storage.py: -------------------------------------------------------------------------------- 1 | from azure.storage.blob.aio import BlobServiceClient # pip install azure-storage-blob 2 | from . import VoiceRecorder 3 | 4 | class AzureBlobVoiceRecorder(VoiceRecorder): 5 | def __init__( 6 | self, 7 | *, 8 | connection_string: str, 9 | container_name: str, 10 | directory: str = "recorded_voices", 11 | sample_rate: int = 16000, 12 | channels: int = 1, 13 | sample_width: int = 2 14 | ): 15 | super().__init__(sample_rate=sample_rate, channels=channels, sample_width=sample_width) 16 | 17 | self.connection_string = connection_string 18 | self.container_name = container_name 19 | self.directory = directory 20 | self.blob_service_client = BlobServiceClient.from_connection_string(self.connection_string) 21 | self.container_client = self.blob_service_client.get_container_client(self.container_name) 22 | 23 | async def save_voice(self, id: str, voice_bytes: bytes, audio_format: str): 24 | file_extension = self.to_extension(audio_format) 25 | blob_name = f"{self.directory}/{id}.{file_extension}" 26 | blob_client = self.container_client.get_blob_client(blob_name) 27 | await blob_client.upload_blob(voice_bytes, overwrite=True) 28 | 29 | async def close(self): 30 | await self.blob_service_client.close() 31 | -------------------------------------------------------------------------------- /litests/voice_recorder/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import asyncio 3 | from dataclasses import dataclass 4 | import logging 5 | import struct 6 | from typing import List 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @dataclass 12 | class Voice: 13 | transaction_id: str 14 | 15 | 16 | @dataclass 17 | class RequestVoice(Voice): 18 | voice_bytes: bytes 19 | 20 | 21 | @dataclass 22 | class ResponseVoices(Voice): 23 | voice_chunks: List[bytes] 24 | audio_format: str 25 | 26 | 27 | class VoiceRecorder(ABC): 28 | def __init__(self, *, sample_rate: int = 16000, channels: int = 1, sample_width: int = 2): 29 | self.sample_rate = sample_rate 30 | self.channels = channels 31 | self.sample_width = sample_width 32 | 33 | self.format_extension_mapper = { 34 | "LINEAR16": "wav", # Google TTS 35 | "riff-16khz-16bit-mono-pcm": "wav" # Azure TTS 36 | } 37 | 38 | self.queue: asyncio.Queue[Voice] = asyncio.Queue() 39 | self.worker_task = None 40 | 41 | def to_extension(self, format: str) -> str: 42 | return self.format_extension_mapper.get(format) or format 43 | 44 | def create_wav_header(self, data_size: int, sample_rate: int, channels: int, sample_width: int) -> bytes: 45 | byte_rate = sample_rate * channels * sample_width 46 | block_align = channels * sample_width 47 | header = struct.pack( 48 | "<4sI4s4sIHHIIHH4sI", 49 | b"RIFF", # ChunkID 50 | 36 + data_size, # ChunkSize = 36 + SubChunk2Size 51 | b"WAVE", # Format 52 | b"fmt ", # Subchunk1ID 53 | 16, # Subchunk1Size (PCM) 54 | 1, # AudioFormat (PCM: 1) 55 | channels, # NumChannels 56 | sample_rate, # SampleRate 57 | byte_rate, # ByteRate 58 | block_align, # BlockAlign 59 | sample_width * 8, # BitsPerSample 60 | b"data", # Subchunk2ID 61 | data_size # Subchunk2Size 62 | ) 63 | return header 64 | 65 | @abstractmethod 66 | async def save_voice(self, id: str, voice_bytes: bytes, audio_format: str): 67 | pass 68 | 69 | async def _worker(self): 70 | while True: 71 | 
voice = await self.queue.get() 72 | if voice is None: 73 | break 74 | 75 | try: 76 | if isinstance(voice, RequestVoice): 77 | if not voice.voice_bytes.startswith(b"RIFF"): 78 | # Add header if missing 79 | header = self.create_wav_header( 80 | data_size=len(voice.voice_bytes), 81 | sample_rate=self.sample_rate, 82 | channels=self.channels, 83 | sample_width=self.sample_width 84 | ) 85 | voice.voice_bytes = header + voice.voice_bytes 86 | await self.save_voice( 87 | id=f"{voice.transaction_id}_request", 88 | voice_bytes=voice.voice_bytes, 89 | audio_format="wav" 90 | ) 91 | 92 | elif isinstance(voice, ResponseVoices): 93 | for idx, v in enumerate(voice.voice_chunks): 94 | await self.save_voice( 95 | id=f"{voice.transaction_id}_response_{idx}", 96 | voice_bytes=v, 97 | audio_format=voice.audio_format 98 | ) 99 | 100 | except Exception as ex: 101 | logger.error(f"Error at saving voice: {ex}") 102 | 103 | finally: 104 | if not self.queue.empty(): 105 | self.queue.task_done() 106 | 107 | async def record(self, voice: Voice): 108 | if self.worker_task is None: 109 | self.worker_task = asyncio.create_task(self._worker()) 110 | await self.queue.put(voice) 111 | 112 | async def stop(self): 113 | await self.queue.put(None) 114 | if self.worker_task: 115 | self.worker_task.cancel() 116 | -------------------------------------------------------------------------------- /litests/voice_recorder/file.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import aiofiles 3 | from . import VoiceRecorder 4 | 5 | 6 | class FileVoiceRecorder(VoiceRecorder): 7 | def __init__(self, *, record_dir: str = "recorded_voices", sample_rate = 16000, channels = 1, sample_width = 2): 8 | super().__init__(sample_rate=sample_rate, channels=channels, sample_width=sample_width) 9 | self.record_dir = Path(record_dir) 10 | if not self.record_dir.exists(): 11 | self.record_dir.mkdir(parents=True) 12 | 13 | async def save_voice(self, id: str, voice_bytes: bytes, audio_format: str): 14 | file_extension = self.to_extension(audio_format) 15 | async with aiofiles.open(self.record_dir / f"{id}.{file_extension}", "wb") as f: 16 | await f.write(voice_bytes) 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | httpx==0.27.0 2 | openai>=1.55.3 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="litests", 5 | version="0.3.12", 6 | url="https://github.com/uezo/litests", 7 | author="uezo", 8 | author_email="uezo@uezo.net", 9 | maintainer="uezo", 10 | maintainer_email="uezo@uezo.net", 11 | description="A super lightweight Speech-to-Speech framework with modular VAD, STT, LLM and TTS components. 
🧩", 12 | long_description=open("README.md").read(), 13 | long_description_content_type="text/markdown", 14 | packages=find_packages(exclude=["tests*"]), 15 | install_requires=["httpx>=0.27.0", "openai>=1.55.3", "aiofiles>=24.1.0"], 16 | license="Apache v2", 17 | classifiers=[ 18 | "Programming Language :: Python :: 3" 19 | ] 20 | ) 21 | -------------------------------------------------------------------------------- /tests/llm/context_manager/test_context_manager.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | import json 3 | import os 4 | import sqlite3 5 | import pytest 6 | from litests.llm.context_manager import SQLiteContextManager 7 | 8 | 9 | @pytest.fixture 10 | def db_path(tmp_path): 11 | return os.path.join(tmp_path, "test_context.db") 12 | 13 | 14 | @pytest.fixture 15 | def context_manager(db_path) -> SQLiteContextManager: 16 | return SQLiteContextManager(db_path=db_path, context_timeout=3600) 17 | 18 | 19 | @pytest.mark.asyncio 20 | async def test_get_histories_empty(context_manager): 21 | context_id = "non_existent_context" 22 | histories = await context_manager.get_histories(context_id) 23 | assert histories == [] 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_add_and_get_histories(context_manager): 28 | context_id = "test_context" 29 | data_list = [ 30 | {"message": "Hello, world!", "role": "user"}, 31 | {"message": "Hi! How can I help you today?", "role": "assistant"} 32 | ] 33 | 34 | await context_manager.add_histories(context_id, data_list) 35 | 36 | histories = await context_manager.get_histories(context_id) 37 | assert len(histories) == 2 38 | 39 | assert histories[0]["message"] == "Hello, world!" 40 | assert histories[0]["role"] == "user" 41 | assert histories[1]["message"] == "Hi! How can I help you today?" 42 | assert histories[1]["role"] == "assistant" 43 | 44 | 45 | @pytest.mark.asyncio 46 | async def test_get_histories_limit(context_manager): 47 | context_id = "test_limit" 48 | data_list = [ 49 | {"index": 1}, {"index": 2}, {"index": 3}, {"index": 4}, {"index": 5} 50 | ] 51 | await context_manager.add_histories(context_id, data_list) 52 | 53 | histories_all = await context_manager.get_histories(context_id, limit=100) 54 | assert len(histories_all) == 5 55 | 56 | histories_limited = await context_manager.get_histories(context_id, limit=3) 57 | assert len(histories_limited) == 3 58 | assert histories_limited[0]["index"] == 3 59 | assert histories_limited[-1]["index"] == 5 60 | 61 | 62 | @pytest.mark.asyncio 63 | async def test_get_histories_timeout(context_manager): 64 | context_id = "test_timeout" 65 | 66 | old_data = {"message": "Old data"} 67 | new_data = {"message": "New data"} 68 | 69 | await context_manager.add_histories(context_id, [new_data]) 70 | 71 | old_timestamp = datetime(2000, 1, 1, tzinfo=timezone.utc) 72 | conn = sqlite3.connect(context_manager.db_path) 73 | try: 74 | with conn: 75 | conn.execute( 76 | """ 77 | INSERT INTO chat_histories (created_at, context_id, serialized_data) 78 | VALUES (?, ?, ?) 
79 | """, 80 | (old_timestamp, context_id, json.dumps(old_data)) 81 | ) 82 | finally: 83 | conn.close() 84 | 85 | histories = await context_manager.get_histories(context_id) 86 | assert len(histories) == 1 87 | assert histories[0]["message"] == "New data" 88 | -------------------------------------------------------------------------------- /tests/llm/context_manager/test_pg_context_manager.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | import json 3 | import os 4 | import pytest 5 | from litests.llm.context_manager.postgres import PostgreSQLContextManager 6 | 7 | LITESTS_DB_USER = os.getenv("LITESTS_DB_USER") 8 | LITESTS_DB_PASSWORD = os.getenv("LITESTS_DB_PASSWORD") 9 | 10 | 11 | @pytest.fixture 12 | def context_manager() -> PostgreSQLContextManager: 13 | return PostgreSQLContextManager(user=LITESTS_DB_USER, password=LITESTS_DB_PASSWORD, context_timeout=3600) 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_get_histories_empty(context_manager): 18 | context_id = "non_existent_context" 19 | histories = await context_manager.get_histories(context_id) 20 | assert histories == [] 21 | 22 | 23 | @pytest.mark.asyncio 24 | async def test_add_and_get_histories(context_manager): 25 | context_id = "test_context" 26 | data_list = [ 27 | {"message": "Hello, world!", "role": "user"}, 28 | {"message": "Hi! How can I help you today?", "role": "assistant"} 29 | ] 30 | 31 | await context_manager.add_histories(context_id, data_list) 32 | 33 | histories = await context_manager.get_histories(context_id) 34 | assert len(histories) == 2 35 | 36 | assert histories[0]["message"] == "Hello, world!" 37 | assert histories[0]["role"] == "user" 38 | assert histories[1]["message"] == "Hi! How can I help you today?" 
39 | assert histories[1]["role"] == "assistant" 40 | 41 | 42 | @pytest.mark.asyncio 43 | async def test_get_histories_limit(context_manager): 44 | context_id = "test_limit" 45 | data_list = [ 46 | {"index": 1}, {"index": 2}, {"index": 3}, {"index": 4}, {"index": 5} 47 | ] 48 | await context_manager.add_histories(context_id, data_list) 49 | 50 | histories_all = await context_manager.get_histories(context_id, limit=100) 51 | assert len(histories_all) == 5 52 | 53 | histories_limited = await context_manager.get_histories(context_id, limit=3) 54 | assert len(histories_limited) == 3 55 | assert histories_limited[0]["index"] == 3 56 | assert histories_limited[-1]["index"] == 5 57 | 58 | 59 | @pytest.mark.asyncio 60 | async def test_get_histories_timeout(context_manager): 61 | context_id = "test_timeout" 62 | 63 | old_data = {"message": "Old data"} 64 | new_data = {"message": "New data"} 65 | 66 | await context_manager.add_histories(context_id, [new_data]) 67 | 68 | old_timestamp = datetime(2000, 1, 1, tzinfo=timezone.utc) 69 | conn = context_manager.connect_db() 70 | try: 71 | with conn: 72 | with conn.cursor() as cur: 73 | cur.execute( 74 | """ 75 | INSERT INTO chat_histories (created_at, context_id, serialized_data) 76 | VALUES (%s, %s, %s) 77 | """, 78 | (old_timestamp, context_id, json.dumps(old_data)) 79 | ) 80 | finally: 81 | conn.close() 82 | 83 | histories = await context_manager.get_histories(context_id) 84 | assert len(histories) == 1 85 | assert histories[0]["message"] == "New data" 86 | -------------------------------------------------------------------------------- /tests/llm/test_chatgpt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pytest 4 | from typing import Any, Dict 5 | from uuid import uuid4 6 | from litests.llm.chatgpt import ChatGPTService, ToolCall 7 | 8 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 9 | IMAGE_URL = os.getenv("IMAGE_URL") 10 | MODEL = "gpt-4o" 11 | 12 | SYSTEM_PROMPT = """ 13 | ## 基本設定 14 | 15 | あなたはユーザーの妹として、感情表現豊かに振る舞ってください。 16 | 17 | ## 表情について 18 | 19 | あなたは以下のようなタグで表情を表現することができます。 20 | 21 | [face:Angry]はあ?何言ってるのか全然わからないんですけど。 22 | 23 | 表情のバリエーションは以下の通りです。 24 | 25 | - Joy 26 | - Angry 27 | """ 28 | 29 | SYSTEM_PROMPT_COT = SYSTEM_PROMPT + """ 30 | 31 | ## 思考について 32 | 33 | 応答する前に内容をよく考えてください。これまでの文脈を踏まえて適切な内容か、または兄が言い淀んだだけなので頷くだけにするか、など。 34 | まずは考えた内容をに出力してください。 35 | そのあと、発話すべき内容をに出力してください。 36 | その2つのタグ以外に文言を含むことは禁止です。 37 | """ 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_chatgpt_service_simple(): 42 | """ 43 | Test ChatGPTService with a basic prompt to check if it can stream responses. 44 | This test actually calls OpenAI API, so it may cost tokens. 45 | """ 46 | service = ChatGPTService( 47 | openai_api_key=OPENAI_API_KEY, 48 | system_prompt=SYSTEM_PROMPT, 49 | model=MODEL, 50 | temperature=0.5 51 | ) 52 | context_id = f"test_context_{uuid4()}" 53 | 54 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 55 | 56 | collected_text = [] 57 | collected_voice = [] 58 | 59 | async for resp in service.chat_stream(context_id, "test_user", user_message): 60 | collected_text.append(resp.text) 61 | collected_voice.append(resp.voice_text) 62 | 63 | full_text = "".join(collected_text) 64 | full_voice = "".join(filter(None, collected_voice)) 65 | assert len(full_text) > 0, "No text was returned from the LLM." 66 | 67 | # Check the response content 68 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 
69 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 70 | 71 | # Check the context 72 | messages = await service.context_manager.get_histories(context_id) 73 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 74 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 75 | 76 | await service.openai_client.close() 77 | 78 | 79 | @pytest.mark.asyncio 80 | async def test_chatgpt_service_system_prompt_params(): 81 | """ 82 | Test ChatGPTService with a basic prompt and its dynamic params. 83 | This test actually calls OpenAI API, so it may cost tokens. 84 | """ 85 | service = ChatGPTService( 86 | openai_api_key=OPENAI_API_KEY, 87 | system_prompt="あなたは{animal_name}です。語尾をそれらしくしてください。カタカナで表現します。", 88 | model=MODEL, 89 | temperature=0.5 90 | ) 91 | context_id = f"test_system_prompt_params_context_{uuid4()}" 92 | 93 | user_message = "こんにちは" 94 | 95 | collected_text = [] 96 | 97 | async for resp in service.chat_stream(context_id, "test_user", user_message, system_prompt_params={"animal_name": "猫"}): 98 | collected_text.append(resp.text) 99 | 100 | full_text = "".join(collected_text) 101 | assert len(full_text) > 0, "No text was returned from the LLM." 102 | 103 | # Check the response content 104 | assert "ニャ" in full_text, "ニャ doesn't appear in text." 105 | 106 | # Check the context 107 | messages = await service.context_manager.get_histories(context_id) 108 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 109 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 110 | 111 | await service.openai_client.close() 112 | 113 | 114 | @pytest.mark.asyncio 115 | async def test_chatgpt_service_image(): 116 | """ 117 | Test ChatGPTService with a basic prompt to check if it can handle image and stream responses. 118 | This test actually calls OpenAI API, so it may cost tokens. 119 | """ 120 | service = ChatGPTService( 121 | openai_api_key=OPENAI_API_KEY, 122 | system_prompt=SYSTEM_PROMPT, 123 | model=MODEL, 124 | temperature=0.5 125 | ) 126 | context_id = f"test_context_{uuid4()}" 127 | 128 | collected_text = [] 129 | 130 | async for resp in service.chat_stream(context_id, "test_user", "これは何ですか?漢字で答えてください。", files=[{"type": "image", "url": IMAGE_URL}]): 131 | collected_text.append(resp.text) 132 | 133 | full_text = "".join(collected_text) 134 | assert len(full_text) > 0, "No text was returned from the LLM." 135 | 136 | # Check the response content 137 | assert "寿司" in full_text, "寿司 is not in text." 138 | 139 | # Check the context 140 | messages = await service.context_manager.get_histories(context_id) 141 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 142 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 143 | 144 | await service.openai_client.close() 145 | 146 | 147 | @pytest.mark.asyncio 148 | async def test_chatgpt_service_cot(): 149 | """ 150 | Test ChatGPTService with a prompt to check Chain-of-Thought. 151 | This test actually calls OpenAI API, so it may cost tokens. 
152 | """ 153 | service = ChatGPTService( 154 | openai_api_key=OPENAI_API_KEY, 155 | system_prompt=SYSTEM_PROMPT_COT, 156 | model=MODEL, 157 | temperature=0.5, 158 | voice_text_tag="answer" 159 | ) 160 | context_id = f"test_cot_context_{uuid4()}" 161 | 162 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 163 | 164 | collected_text = [] 165 | collected_voice = [] 166 | 167 | async for resp in service.chat_stream(context_id, "test_user", user_message): 168 | collected_text.append(resp.text) 169 | collected_voice.append(resp.voice_text) 170 | 171 | full_text = "".join(collected_text) 172 | full_voice = "".join(filter(None, collected_voice)) 173 | assert len(full_text) > 0, "No text was returned from the LLM." 174 | 175 | # Check the response content 176 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 177 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 178 | 179 | # Check the response content (CoT) 180 | assert "<answer>" in full_text, "Answer tag doesn't appear in text." 181 | assert "</answer>" in full_text, "Answer tag closing doesn't appear in text." 182 | assert "<answer>" not in full_voice, "Answer tag was not removed from voice_text." 183 | assert "</answer>" not in full_voice, "Answer tag closing was not removed from voice_text." 184 | 185 | # Check the context 186 | messages = await service.context_manager.get_histories(context_id) 187 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 188 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 189 | 190 | await service.openai_client.close() 191 | 192 | 193 | @pytest.mark.asyncio 194 | async def test_chatgpt_service_tool_calls(): 195 | """ 196 | Test ChatGPTService with a registered tool. 197 | The conversation might trigger the tool call, then the tool's result is fed back. 198 | This is just an example. The actual trigger depends on the model response. 199 | """ 200 | service = ChatGPTService( 201 | openai_api_key=OPENAI_API_KEY, 202 | system_prompt="You can call a tool to solve math problems if necessary.", 203 | model=MODEL, 204 | temperature=0.5 205 | ) 206 | context_id = f"test_tool_context_{uuid4()}" 207 | 208 | # Register tool 209 | tool_spec = { 210 | "type": "function", 211 | "function": { 212 | "name": "solve_math", 213 | "description": "Solve simple math problems", 214 | "parameters": { 215 | "type": "object", 216 | "properties": { 217 | "problem": {"type": "string"} 218 | }, 219 | "required": ["problem"] 220 | } 221 | } 222 | } 223 | @service.tool(tool_spec) 224 | async def solve_math(problem: str) -> Dict[str, Any]: 225 | """ 226 | Tool function example: parse the problem and return a result.
227 | """ 228 | if problem.strip() == "1+1": 229 | return {"answer": 2} 230 | else: 231 | return {"answer": "unknown"} 232 | 233 | @service.on_before_tool_calls 234 | async def on_before_tool_calls(tool_calls: list[ToolCall]): 235 | assert len(tool_calls) > 0 236 | 237 | user_message = "次の問題を解いて: 1+1" 238 | collected_text = [] 239 | 240 | async for resp in service.chat_stream(context_id, "test_user", user_message): 241 | collected_text.append(resp.text) 242 | 243 | # Check context 244 | messages = await service.context_manager.get_histories(context_id) 245 | assert len(messages) == 4 246 | 247 | assert messages[0]["role"] == "user" 248 | assert messages[0]["content"] == user_message 249 | 250 | assert messages[1]["role"] == "assistant" 251 | assert messages[1]["tool_calls"] is not None 252 | assert messages[1]["tool_calls"][0]["function"]["name"] == "solve_math" 253 | tool_call_id = messages[1]["tool_calls"][0]["id"] 254 | 255 | assert messages[2]["role"] == "tool" 256 | assert messages[2]["tool_call_id"] == tool_call_id 257 | assert messages[2]["content"] == json.dumps({"answer": 2}) 258 | 259 | assert messages[3]["role"] == "assistant" 260 | assert "2" in messages[3]["content"] 261 | 262 | await service.openai_client.close() 263 | -------------------------------------------------------------------------------- /tests/llm/test_chatgpt_azure.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pytest 4 | from typing import Any, Dict 5 | from uuid import uuid4 6 | from litests.llm.chatgpt import ChatGPTService, ToolCall 7 | 8 | AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") 9 | AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") 10 | IMAGE_URL = os.getenv("IMAGE_URL") 11 | 12 | SYSTEM_PROMPT = """ 13 | ## 基本設定 14 | 15 | あなたはユーザーの妹として、感情表現豊かに振る舞ってください。 16 | 17 | ## 表情について 18 | 19 | あなたは以下のようなタグで表情を表現することができます。 20 | 21 | [face:Angry]はあ?何言ってるのか全然わからないんですけど。 22 | 23 | 表情のバリエーションは以下の通りです。 24 | 25 | - Joy 26 | - Angry 27 | """ 28 | 29 | SYSTEM_PROMPT_COT = SYSTEM_PROMPT + """ 30 | 31 | ## 思考について 32 | 33 | 応答する前に内容をよく考えてください。これまでの文脈を踏まえて適切な内容か、または兄が言い淀んだだけなので頷くだけにするか、など。 34 | まずは考えた内容を<thinking>に出力してください。 35 | そのあと、発話すべき内容を<answer>に出力してください。 36 | その2つのタグ以外に文言を含むことは禁止です。 37 | """ 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_chatgpt_azure_service_simple(): 42 | """ 43 | Test ChatGPTService with a basic prompt to check if it can stream responses. 44 | This test actually calls OpenAI API, so it may cost tokens. 45 | """ 46 | service = ChatGPTService( 47 | openai_api_key=AZURE_OPENAI_API_KEY, 48 | base_url=AZURE_OPENAI_ENDPOINT, 49 | system_prompt=SYSTEM_PROMPT, 50 | model="azure", 51 | temperature=0.5 52 | ) 53 | context_id = f"test_context_{uuid4()}" 54 | 55 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 56 | 57 | collected_text = [] 58 | collected_voice = [] 59 | 60 | async for resp in service.chat_stream(context_id, "test_user", user_message): 61 | collected_text.append(resp.text) 62 | collected_voice.append(resp.voice_text) 63 | 64 | full_text = "".join(collected_text) 65 | full_voice = "".join(filter(None, collected_voice)) 66 | assert len(full_text) > 0, "No text was returned from the LLM." 67 | 68 | # Check the response content 69 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 70 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text."
71 | 72 | # Check the context 73 | messages = await service.context_manager.get_histories(context_id) 74 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 75 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 76 | 77 | await service.openai_client.close() 78 | 79 | 80 | @pytest.mark.asyncio 81 | async def test_chatgpt_azure_service_system_prompt_params(): 82 | """ 83 | Test ChatGPTService with a basic prompt and its dynamic params. 84 | This test actually calls OpenAI API, so it may cost tokens. 85 | """ 86 | service = ChatGPTService( 87 | openai_api_key=AZURE_OPENAI_API_KEY, 88 | base_url=AZURE_OPENAI_ENDPOINT, 89 | system_prompt="あなたは{animal_name}です。語尾をそれらしくしてください。カタカナで表現します。", 90 | model="azure", 91 | temperature=0.5 92 | ) 93 | context_id = f"test_system_prompt_params_context_{uuid4()}" 94 | 95 | user_message = "こんにちは" 96 | 97 | collected_text = [] 98 | 99 | async for resp in service.chat_stream(context_id, "test_user", user_message, system_prompt_params={"animal_name": "猫"}): 100 | collected_text.append(resp.text) 101 | 102 | full_text = "".join(collected_text) 103 | assert len(full_text) > 0, "No text was returned from the LLM." 104 | 105 | # Check the response content 106 | assert "ニャ" in full_text, "ニャ doesn't appear in text." 107 | 108 | # Check the context 109 | messages = await service.context_manager.get_histories(context_id) 110 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 111 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 112 | 113 | await service.openai_client.close() 114 | 115 | 116 | @pytest.mark.asyncio 117 | async def test_chatgpt_azure_service_image(): 118 | """ 119 | Test ChatGPTService with a basic prompt to check if it can handle image stream responses. 120 | This test actually calls OpenAI API, so it may cost tokens. 121 | """ 122 | service = ChatGPTService( 123 | openai_api_key=AZURE_OPENAI_API_KEY, 124 | base_url=AZURE_OPENAI_ENDPOINT, 125 | system_prompt=SYSTEM_PROMPT, 126 | model="azure", 127 | temperature=0.5 128 | ) 129 | context_id = f"test_context_{uuid4()}" 130 | 131 | collected_text = [] 132 | 133 | async for resp in service.chat_stream(context_id, "test_user", "これは何ですか?漢字で答えてください。", files=[{"type": "image", "url": IMAGE_URL}]): 134 | collected_text.append(resp.text) 135 | 136 | full_text = "".join(collected_text) 137 | assert len(full_text) > 0, "No text was returned from the LLM." 138 | 139 | # Check the response content 140 | assert "寿司" in full_text, "寿司 is not in text." 141 | 142 | # Check the context 143 | messages = await service.context_manager.get_histories(context_id) 144 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 145 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 146 | 147 | await service.openai_client.close() 148 | 149 | 150 | @pytest.mark.asyncio 151 | async def test_chatgpt_azure_service_cot(): 152 | """ 153 | Test ChatGPTService with a prompt to check Chain-of-Thought. 154 | This test actually calls OpenAI API, so it may cost tokens. 
155 | """ 156 | service = ChatGPTService( 157 | openai_api_key=AZURE_OPENAI_API_KEY, 158 | base_url=AZURE_OPENAI_ENDPOINT, 159 | system_prompt=SYSTEM_PROMPT_COT, 160 | model="azure", 161 | temperature=0.5, 162 | voice_text_tag="answer" 163 | ) 164 | context_id = f"test_cot_context_{uuid4()}" 165 | 166 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 167 | 168 | collected_text = [] 169 | collected_voice = [] 170 | 171 | async for resp in service.chat_stream(context_id, "test_user", user_message): 172 | collected_text.append(resp.text) 173 | collected_voice.append(resp.voice_text) 174 | 175 | full_text = "".join(collected_text) 176 | full_voice = "".join(filter(None, collected_voice)) 177 | assert len(full_text) > 0, "No text was returned from the LLM." 178 | 179 | # Check the response content 180 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 181 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 182 | 183 | # Check the response content (CoT) 184 | assert "<answer>" in full_text, "Answer tag doesn't appear in text." 185 | assert "</answer>" in full_text, "Answer tag closing doesn't appear in text." 186 | assert "<answer>" not in full_voice, "Answer tag was not removed from voice_text." 187 | assert "</answer>" not in full_voice, "Answer tag closing was not removed from voice_text." 188 | 189 | # Check the context 190 | messages = await service.context_manager.get_histories(context_id) 191 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 192 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 193 | 194 | await service.openai_client.close() 195 | 196 | 197 | @pytest.mark.asyncio 198 | async def test_chatgpt_azure_service_tool_calls(): 199 | """ 200 | Test ChatGPTService with a registered tool. 201 | The conversation might trigger the tool call, then the tool's result is fed back. 202 | This is just an example. The actual trigger depends on the model response. 203 | """ 204 | service = ChatGPTService( 205 | openai_api_key=AZURE_OPENAI_API_KEY, 206 | base_url=AZURE_OPENAI_ENDPOINT, 207 | system_prompt="You can call a tool to solve math problems if necessary.", 208 | model="azure", 209 | temperature=0.5 210 | ) 211 | context_id = f"test_tool_context_{uuid4()}" 212 | 213 | # Register tool 214 | tool_spec = { 215 | "type": "function", 216 | "function": { 217 | "name": "solve_math", 218 | "description": "Solve simple math problems", 219 | "parameters": { 220 | "type": "object", 221 | "properties": { 222 | "problem": {"type": "string"} 223 | }, 224 | "required": ["problem"] 225 | } 226 | } 227 | } 228 | @service.tool(tool_spec) 229 | async def solve_math(problem: str) -> Dict[str, Any]: 230 | """ 231 | Tool function example: parse the problem and return a result.
232 | """ 233 | if problem.strip() == "1+1": 234 | return {"answer": 2} 235 | else: 236 | return {"answer": "unknown"} 237 | 238 | @service.on_before_tool_calls 239 | async def on_before_tool_calls(tool_calls: list[ToolCall]): 240 | assert len(tool_calls) > 0 241 | 242 | user_message = "次の問題を解いて: 1+1" 243 | collected_text = [] 244 | 245 | async for resp in service.chat_stream(context_id, "test_user", user_message): 246 | collected_text.append(resp.text) 247 | 248 | # Check context 249 | messages = await service.context_manager.get_histories(context_id) 250 | assert len(messages) == 4 251 | 252 | assert messages[0]["role"] == "user" 253 | assert messages[0]["content"] == user_message 254 | 255 | assert messages[1]["role"] == "assistant" 256 | assert messages[1]["tool_calls"] is not None 257 | assert messages[1]["tool_calls"][0]["function"]["name"] == "solve_math" 258 | tool_call_id = messages[1]["tool_calls"][0]["id"] 259 | 260 | assert messages[2]["role"] == "tool" 261 | assert messages[2]["tool_call_id"] == tool_call_id 262 | assert messages[2]["content"] == json.dumps({"answer": 2}) 263 | 264 | assert messages[3]["role"] == "assistant" 265 | assert "2" in messages[3]["content"] 266 | 267 | await service.openai_client.close() 268 | -------------------------------------------------------------------------------- /tests/llm/test_claude.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict 4 | from uuid import uuid4 5 | import pytest 6 | from litests.llm.claude import ClaudeService, ToolCall 7 | 8 | CLAUDE_API_KEY = os.getenv("CLAUDE_API_KEY") 9 | MODEL = "claude-3-5-sonnet-latest" 10 | IMAGE_URL = os.getenv("IMAGE_URL") 11 | 12 | SYSTEM_PROMPT = """ 13 | ## 基本設定 14 | 15 | あなたはユーザーの妹として、感情表現豊かに振る舞ってください。 16 | 17 | ## 表情について 18 | 19 | あなたは以下のようなタグで表情を表現することができます。 20 | 21 | [face:Angry]はあ?何言ってるのか全然わからないんですけど。 22 | 23 | 表情のバリエーションは以下の通りです。 24 | 25 | - Joy 26 | - Angry 27 | """ 28 | 29 | SYSTEM_PROMPT_COT = SYSTEM_PROMPT + """ 30 | 31 | ## 思考について 32 | 33 | 応答する前に内容をよく考えてください。これまでの文脈を踏まえて適切な内容か、または兄が言い淀んだだけなので頷くだけにするか、など。 34 | まずは考えた内容を<thinking>に出力してください。 35 | そのあと、発話すべき内容を<answer>に出力してください。 36 | その2つのタグ以外に文言を含むことは禁止です。 37 | """ 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_claude_service_simple(): 42 | """ 43 | Test ClaudeService with a basic prompt to check if it can stream responses. 44 | This test actually calls Anthropic API, so it may cost tokens. 45 | """ 46 | service = ClaudeService( 47 | anthropic_api_key=CLAUDE_API_KEY, 48 | system_prompt=SYSTEM_PROMPT, 49 | model=MODEL, 50 | temperature=0.5 51 | ) 52 | context_id = f"test_context_{uuid4()}" 53 | 54 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 55 | 56 | collected_text = [] 57 | collected_voice = [] 58 | 59 | async for resp in service.chat_stream(context_id, "test_user", user_message): 60 | collected_text.append(resp.text) 61 | collected_voice.append(resp.voice_text) 62 | 63 | full_text = "".join(collected_text) 64 | full_voice = "".join(filter(None, collected_voice)) 65 | assert len(full_text) > 0, "No text was returned from the LLM." 66 | 67 | # Check the response content 68 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 69 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 70 | 71 | # Check the context 72 | messages = await service.context_manager.get_histories(context_id) 73 | assert any(m["role"] == "user" for m in messages), "User message not found in context."
74 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 75 | 76 | 77 | 78 | @pytest.mark.asyncio 79 | async def test_claude_service_system_prompt_params(): 80 | """ 81 | Test ClaudeService with a basic prompt and its dynamic params. 82 | This test actually calls OpenAI API, so it may cost tokens. 83 | """ 84 | service = ClaudeService( 85 | anthropic_api_key=CLAUDE_API_KEY, 86 | system_prompt="あなたは{animal_name}です。語尾をそれらしくしてください。カタカナで表現します。", 87 | model=MODEL, 88 | temperature=0.5 89 | ) 90 | context_id = f"test_system_prompt_params_context_{uuid4()}" 91 | 92 | user_message = "こんにちは" 93 | 94 | collected_text = [] 95 | 96 | async for resp in service.chat_stream(context_id, "test_user", user_message, system_prompt_params={"animal_name": "猫"}): 97 | collected_text.append(resp.text) 98 | 99 | full_text = "".join(collected_text) 100 | assert len(full_text) > 0, "No text was returned from the LLM." 101 | 102 | # Check the response content 103 | assert "ニャ" in full_text, "ニャ doesn't appear in text." 104 | 105 | # Check the context 106 | messages = await service.context_manager.get_histories(context_id) 107 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 108 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 109 | 110 | await service.anthropic_client.close() 111 | 112 | 113 | @pytest.mark.asyncio 114 | async def test_claude_service_image(): 115 | """ 116 | Test ClaudeService with a basic prompt to check if it can handle image and stream responses. 117 | This test actually calls Anthropic API, so it may cost tokens. 118 | """ 119 | service = ClaudeService( 120 | anthropic_api_key=CLAUDE_API_KEY, 121 | system_prompt=SYSTEM_PROMPT, 122 | model=MODEL, 123 | temperature=0.5 124 | ) 125 | context_id = f"test_context_{uuid4()}" 126 | 127 | collected_text = [] 128 | 129 | async for resp in service.chat_stream(context_id, "test_user", "これは何ですか?漢字で答えてください。", files=[{"type": "image", "url": IMAGE_URL}]): 130 | collected_text.append(resp.text) 131 | 132 | full_text = "".join(collected_text) 133 | assert len(full_text) > 0, "No text was returned from the LLM." 134 | 135 | # Check the response content 136 | assert "寿司" in full_text, "寿司 is not in text." 137 | 138 | # Check the context 139 | messages = await service.context_manager.get_histories(context_id) 140 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 141 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 142 | 143 | 144 | @pytest.mark.asyncio 145 | async def test_claude_service_cot(): 146 | """ 147 | Test ClaudeService with a prompt to check Chain-of-Thought. 148 | This test actually calls Anthropic API, so it may cost tokens. 149 | """ 150 | service = ClaudeService( 151 | anthropic_api_key=CLAUDE_API_KEY, 152 | system_prompt=SYSTEM_PROMPT_COT, 153 | model=MODEL, 154 | temperature=0.5, 155 | voice_text_tag="answer" 156 | ) 157 | context_id = f"test_context_cot_{uuid4()}" 158 | 159 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 160 | 161 | collected_text = [] 162 | collected_voice = [] 163 | 164 | async for resp in service.chat_stream(context_id, "test_user", user_message): 165 | collected_text.append(resp.text) 166 | collected_voice.append(resp.voice_text) 167 | 168 | full_text = "".join(collected_text) 169 | full_voice = "".join(filter(None, collected_voice)) 170 | assert len(full_text) > 0, "No text was returned from the LLM." 
171 | 172 | # Check the response content 173 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 174 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 175 | 176 | # Check the response content (CoT) 177 | assert "<answer>" in full_text, "Answer tag doesn't appear in text." 178 | assert "</answer>" in full_text, "Answer tag closing doesn't appear in text." 179 | assert "<answer>" not in full_voice, "Answer tag was not removed from voice_text." 180 | assert "</answer>" not in full_voice, "Answer tag closing was not removed from voice_text." 181 | 182 | # Check the context 183 | messages = await service.context_manager.get_histories(context_id) 184 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 185 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 186 | 187 | 188 | @pytest.mark.asyncio 189 | async def test_claude_service_tool_calls(): 190 | """ 191 | Test ClaudeService with a registered tool. 192 | The conversation might trigger the tool call, then the tool's result is fed back. 193 | This is just an example. The actual trigger depends on the model response. 194 | """ 195 | service = ClaudeService( 196 | anthropic_api_key=CLAUDE_API_KEY, 197 | system_prompt="You can call a tool to solve math problems if necessary.", 198 | model=MODEL, 199 | temperature=0.5 200 | ) 201 | context_id = f"test_context_tool_{uuid4()}" 202 | 203 | # Register tool 204 | tool_spec = { 205 | "name": "solve_math", 206 | "description": "Solve simple math problems", 207 | "input_schema": { 208 | "type": "object", 209 | "properties": { 210 | "problem": {"type": "string"} 211 | }, 212 | "required": ["problem"] 213 | } 214 | } 215 | @service.tool(tool_spec) 216 | async def solve_math(problem: str) -> Dict[str, Any]: 217 | """ 218 | Tool function example: parse the problem and return a result.
219 | """ 220 | if problem.strip() == "1+1": 221 | return {"answer": 2} 222 | else: 223 | return {"answer": "unknown"} 224 | 225 | @service.on_before_tool_calls 226 | async def on_before_tool_calls(tool_calls: list[ToolCall]): 227 | assert len(tool_calls) > 0 228 | 229 | user_message = "次の問題を解いて: 1+1" 230 | collected_text = [] 231 | 232 | async for resp in service.chat_stream(context_id, "test_user", user_message): 233 | collected_text.append(resp.text) 234 | 235 | # Check context 236 | messages = await service.context_manager.get_histories(context_id) 237 | assert len(messages) == 4 238 | 239 | assert messages[0]["role"] == "user" 240 | assert messages[0]["content"] == [{"type": "text", "text": user_message}] 241 | 242 | assert messages[1]["role"] == "assistant" 243 | tool_use_content_index = len(messages[1]["content"][0]) - 1 244 | assert messages[1]["content"][tool_use_content_index]["type"] == "tool_use" 245 | assert messages[1]["content"][tool_use_content_index]["name"] == "solve_math" 246 | tool_use_id = messages[1]["content"][tool_use_content_index]["id"] 247 | 248 | assert messages[2]["role"] == "user" 249 | assert messages[2]["content"][0]["type"] == "tool_result" 250 | assert messages[2]["content"][0]["tool_use_id"] == tool_use_id 251 | assert messages[2]["content"][0]["content"] == json.dumps({"answer": 2}) 252 | 253 | assert messages[3]["role"] == "assistant" 254 | assert "2" in messages[3]["content"][0]["text"] 255 | -------------------------------------------------------------------------------- /tests/llm/test_dify.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from litests.llm.dify import DifyService 4 | 5 | DIFY_API_KEY = os.getenv("DIFY_API_KEY") 6 | DIFY_API_KEY_AGENT = os.getenv("DIFY_API_KEY_AGENT") 7 | DIFY_URL = os.getenv("DIFY_URL") 8 | IMAGE_URL = os.getenv("IMAGE_URL") 9 | 10 | @pytest.mark.asyncio 11 | async def test_dify_service_simple(): 12 | """ 13 | Test DifyService to check if it can stream responses. 14 | This test actually calls Dify API, so it may cost tokens. 15 | """ 16 | service = DifyService( 17 | api_key=DIFY_API_KEY, 18 | user="user", 19 | base_url=DIFY_URL 20 | ) 21 | context_id = "test_context" 22 | 23 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 24 | 25 | collected_text = [] 26 | collected_voice = [] 27 | 28 | async for resp in service.chat_stream(context_id, "test_user", user_message): 29 | collected_text.append(resp.text) 30 | collected_voice.append(resp.voice_text) 31 | 32 | full_text = "".join(collected_text) 33 | full_voice = "".join(filter(None, collected_voice)) 34 | assert len(full_text) > 0, "No text was returned from the LLM." 35 | 36 | # Check the response content 37 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 38 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 39 | 40 | # Check the context 41 | assert context_id in service.conversation_ids 42 | conversation_id = service.conversation_ids[context_id] 43 | 44 | # Call again 45 | user_message = "あれっ?私、あなたの何を食べたって言ったっけ?" 46 | collected_text = [] 47 | collected_voice = [] 48 | async for resp in service.chat_stream(context_id, "test_user", user_message): 49 | collected_text.append(resp.text) 50 | collected_voice.append(resp.voice_text) 51 | 52 | # Check the response content 53 | assert "プリン" in full_text, "'プリン' doesn't appear in text. Context management is incorrect." 
54 | 55 | 56 | @pytest.mark.skip("Skip dify image") 57 | @pytest.mark.asyncio 58 | async def test_dify_service_image(): 59 | """ 60 | Test DifyService to check if it can handle image and stream responses. 61 | This test actually calls Dify API, so it may cost tokens. 62 | """ 63 | service = DifyService( 64 | api_key=DIFY_API_KEY, 65 | user="user", 66 | base_url=DIFY_URL 67 | ) 68 | context_id = "test_context" 69 | 70 | collected_text = [] 71 | 72 | async for resp in service.chat_stream(context_id, "test_user", "これは何ですか?漢字で答えてください。", files=[{"type": "image", "url": IMAGE_URL}]): 73 | collected_text.append(resp.text) 74 | 75 | full_text = "".join(collected_text) 76 | assert len(full_text) > 0, "No text was returned from the LLM." 77 | 78 | # Check the response content 79 | assert "寿司" in full_text, "寿司 is not in text." 80 | 81 | # Check the context 82 | assert context_id in service.conversation_ids 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_dify_service_agent_mode(): 87 | """ 88 | Test DifyService for Agent. 89 | This test actually calls Dify API, so it may cost tokens. 90 | """ 91 | service = DifyService( 92 | api_key=DIFY_API_KEY_AGENT, 93 | user="user", 94 | base_url=DIFY_URL, 95 | is_agent_mode=True 96 | ) 97 | context_id = "test_context_agent" 98 | 99 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 100 | 101 | collected_text = [] 102 | collected_voice = [] 103 | 104 | async for resp in service.chat_stream(context_id, "test_user", user_message): 105 | collected_text.append(resp.text) 106 | collected_voice.append(resp.voice_text) 107 | 108 | full_text = "".join(collected_text) 109 | full_voice = "".join(filter(None, collected_voice)) 110 | assert len(full_text) > 0, "No text was returned from the LLM." 111 | 112 | # Check the response content 113 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 114 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 115 | 116 | # Check the context 117 | assert context_id in service.conversation_ids 118 | conversation_id = service.conversation_ids[context_id] 119 | 120 | # Call again 121 | user_message = "あれっ?私、あなたの何を食べたって言ったっけ?" 122 | collected_text = [] 123 | collected_voice = [] 124 | async for resp in service.chat_stream(context_id, "test_user", user_message): 125 | collected_text.append(resp.text) 126 | collected_voice.append(resp.voice_text) 127 | 128 | # Check the response content 129 | assert "プリン" in full_text, "'プリン' doesn't appear in text. Context management is incorrect." 
130 | -------------------------------------------------------------------------------- /tests/llm/test_gemini.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | from uuid import uuid4 4 | import pytest 5 | from litests.llm.gemini import GeminiService, ToolCall 6 | 7 | GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") 8 | MODEL = "gemini-2.0-flash" 9 | IMAGE_URL = os.getenv("IMAGE_URL") 10 | 11 | SYSTEM_PROMPT = """ 12 | ## 基本設定 13 | 14 | あなたはユーザーの妹として、感情表現豊かに振る舞ってください。 15 | 16 | ## 表情について 17 | 18 | あなたは以下のようなタグで表情を表現することができます。 19 | 20 | [face:Angry]はあ?何言ってるのか全然わからないんですけど。 21 | 22 | 表情のバリエーションは以下の通りです。 23 | 24 | - Joy 25 | - Angry 26 | """ 27 | 28 | SYSTEM_PROMPT_COT = SYSTEM_PROMPT + """ 29 | 30 | ## 思考について 31 | 32 | 応答する前に内容をよく考えてください。これまでの文脈を踏まえて適切な内容か、または兄が言い淀んだだけなので頷くだけにするか、など。 33 | まずは考えた内容を<thinking>に出力してください。 34 | そのあと、発話すべき内容を<answer>に出力してください。 35 | その2つのタグ以外に文言を含むことは禁止です。 36 | """ 37 | 38 | 39 | @pytest.mark.asyncio 40 | async def test_gemini_service_simple(): 41 | """ 42 | Test GeminiService with a basic prompt to check if it can stream responses. 43 | This test actually calls Gemini API, so it may cost tokens. 44 | """ 45 | service = GeminiService( 46 | gemini_api_key=GEMINI_API_KEY, 47 | system_prompt=SYSTEM_PROMPT, 48 | model=MODEL, 49 | temperature=0.5 50 | ) 51 | context_id = f"test_context_{uuid4()}" 52 | 53 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 54 | 55 | collected_text = [] 56 | collected_voice = [] 57 | 58 | async for resp in service.chat_stream(context_id, "test_user", user_message): 59 | collected_text.append(resp.text) 60 | collected_voice.append(resp.voice_text) 61 | 62 | full_text = "".join(collected_text) 63 | full_voice = "".join(filter(None, collected_voice)) 64 | assert len(full_text) > 0, "No text was returned from the LLM." 65 | 66 | # Check the response content 67 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 68 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 69 | 70 | # Check the context 71 | messages = await service.context_manager.get_histories(context_id) 72 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 73 | assert any(m["role"] == "model" for m in messages), "Assistant message not found in context." 74 | 75 | 76 | @pytest.mark.asyncio 77 | async def test_gemini_service_system_prompt_params(): 78 | """ 79 | Test GeminiService with a basic prompt and its dynamic params. 80 | This test actually calls OpenAI API, so it may cost tokens. 81 | """ 82 | service = GeminiService( 83 | gemini_api_key=GEMINI_API_KEY, 84 | system_prompt="あなたは{animal_name}です。語尾をそれらしくしてください。カタカナで表現します。", 85 | model=MODEL, 86 | temperature=0.5 87 | ) 88 | context_id = f"test_system_prompt_params_context_{uuid4()}" 89 | 90 | user_message = "こんにちは" 91 | 92 | collected_text = [] 93 | 94 | async for resp in service.chat_stream(context_id, "test_user", user_message, system_prompt_params={"animal_name": "猫"}): 95 | collected_text.append(resp.text) 96 | 97 | full_text = "".join(collected_text) 98 | assert len(full_text) > 0, "No text was returned from the LLM." 99 | 100 | # Check the response content 101 | assert "ニャ" in full_text, "ニャ doesn't appear in text." 102 | 103 | # Check the context 104 | messages = await service.context_manager.get_histories(context_id) 105 | assert any(m["role"] == "user" for m in messages), "User message not found in context."
106 | assert any(m["role"] == "model" for m in messages), "Assistant message not found in context." 107 | 108 | 109 | @pytest.mark.asyncio 110 | async def test_gemini_service_image(): 111 | """ 112 | Test GeminiService with a basic prompt to check if it can handle image and stream responses. 113 | This test actually calls Gemini API, so it may cost tokens. 114 | """ 115 | service = GeminiService( 116 | gemini_api_key=GEMINI_API_KEY, 117 | system_prompt=SYSTEM_PROMPT, 118 | model=MODEL, 119 | temperature=0.5 120 | ) 121 | context_id = f"test_context_{uuid4()}" 122 | 123 | collected_text = [] 124 | 125 | async for resp in service.chat_stream(context_id, "test_user", "これは何ですか?漢字で答えてください。", files=[{"type": "image", "url": IMAGE_URL}]): 126 | collected_text.append(resp.text) 127 | 128 | full_text = "".join(collected_text) 129 | assert len(full_text) > 0, "No text was returned from the LLM." 130 | 131 | # Check the response content 132 | assert "寿司" in full_text, "寿司 is not in text." 133 | 134 | # Check the context 135 | messages = await service.context_manager.get_histories(context_id) 136 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 137 | assert any(m["role"] == "model" for m in messages), "Assistant message not found in context." 138 | 139 | # Check conversation with image context 140 | async for resp in service.chat_stream(context_id, "test_user", "まぐろはどこですか?上下左右のうち一つで答えてください"): 141 | collected_text.append(resp.text) 142 | full_text = "".join(collected_text) 143 | assert "上" in full_text, "上 is not in text." 144 | 145 | 146 | @pytest.mark.asyncio 147 | async def test_gemini_service_cot(): 148 | """ 149 | Test GeminiService with a prompt to check Chain-of-Thought. 150 | This test actually calls Gemini API, so it may cost tokens. 151 | """ 152 | service = GeminiService( 153 | gemini_api_key=GEMINI_API_KEY, 154 | system_prompt=SYSTEM_PROMPT_COT, 155 | model=MODEL, 156 | temperature=0.5, 157 | voice_text_tag="answer" 158 | ) 159 | context_id = f"test_context_cot_{uuid4()}" 160 | 161 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 162 | 163 | collected_text = [] 164 | collected_voice = [] 165 | 166 | async for resp in service.chat_stream(context_id, "test_user", user_message): 167 | collected_text.append(resp.text) 168 | collected_voice.append(resp.voice_text) 169 | 170 | full_text = "".join(collected_text) 171 | full_voice = "".join(filter(None, collected_voice)) 172 | assert len(full_text) > 0, "No text was returned from the LLM." 173 | 174 | # Check the response content 175 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 176 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 177 | 178 | # Check the response content (CoT) 179 | assert "<answer>" in full_text, "Answer tag doesn't appear in text." 180 | assert "</answer>" in full_text, "Answer tag closing doesn't appear in text." 181 | assert "<answer>" not in full_voice, "Answer tag was not removed from voice_text." 182 | assert "</answer>" not in full_voice, "Answer tag closing was not removed from voice_text." 183 | 184 | # Check the context 185 | messages = await service.context_manager.get_histories(context_id) 186 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 187 | assert any(m["role"] == "model" for m in messages), "Assistant message not found in context." 188 | 189 | 190 | @pytest.mark.asyncio 191 | async def test_gemini_service_tool_calls(): 192 | """ 193 | Test GeminiService with a registered tool.
194 | The conversation might trigger the tool call, then the tool's result is fed back. 195 | This is just an example. The actual trigger depends on the model response. 196 | """ 197 | service = GeminiService( 198 | gemini_api_key=GEMINI_API_KEY, 199 | system_prompt="You can call a tool to solve math problems if necessary.", 200 | model=MODEL, 201 | temperature=0.5, 202 | ) 203 | context_id = f"test_context_tool_{uuid4()}" 204 | 205 | # Register tool 206 | tool_spec = { 207 | "functionDeclarations": [{ 208 | "name": "solve_math", 209 | "description": "Solve simple math problems", 210 | "parameters": { 211 | "type": "object", 212 | "properties": { 213 | "problem": {"type": "string"} 214 | }, 215 | "required": ["problem"] 216 | } 217 | }] 218 | } 219 | @service.tool(tool_spec) 220 | async def solve_math(problem: str) -> Dict[str, Any]: 221 | """ 222 | Tool function example: parse the problem and return a result. 223 | """ 224 | if problem.strip() == "1+1": 225 | return {"answer": 2} 226 | else: 227 | return {"answer": "unknown"} 228 | 229 | @service.on_before_tool_calls 230 | async def on_before_tool_calls(tool_calls: list[ToolCall]): 231 | assert len(tool_calls) > 0 232 | 233 | user_message = "次の問題を解いて: 1+1" 234 | collected_text = [] 235 | 236 | async for resp in service.chat_stream(context_id, "test_user", user_message): 237 | collected_text.append(resp.text) 238 | 239 | # Check context 240 | messages = await service.context_manager.get_histories(context_id) 241 | assert len(messages) == 4 242 | 243 | assert messages[0]["role"] == "user" 244 | assert messages[0]["parts"][0]["text"] == user_message 245 | 246 | assert messages[1]["role"] == "model" 247 | assert "function_call" in messages[1]["parts"][0] 248 | assert messages[1]["parts"][0]["function_call"]["name"] == "solve_math" 249 | 250 | assert messages[2]["role"] == "user" 251 | assert "function_response" in messages[2]["parts"][0] 252 | assert messages[2]["parts"][0]["function_response"] == {"id": None, "name": "solve_math", "response": {"answer": 2}} # SDK doesn't set id 253 | 254 | assert messages[3]["role"] == "model" 255 | assert "2" in messages[3]["parts"][0]["text"] 256 | -------------------------------------------------------------------------------- /tests/llm/test_litellm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict 4 | from uuid import uuid4 5 | import pytest 6 | from litests.llm.litellm import LiteLLMService, ToolCall 7 | 8 | CLAUDE_API_KEY = os.getenv("CLAUDE_API_KEY") 9 | MODEL = "anthropic/claude-3-5-sonnet-latest" 10 | IMAGE_URL = os.getenv("IMAGE_URL") 11 | 12 | SYSTEM_PROMPT = """ 13 | ## 基本設定 14 | 15 | あなたはユーザーの妹として、感情表現豊かに振る舞ってください。 16 | 17 | ## 表情について 18 | 19 | あなたは以下のようなタグで表情を表現することができます。 20 | 21 | [face:Angry]はあ?何言ってるのか全然わからないんですけど。 22 | 23 | 表情のバリエーションは以下の通りです。 24 | 25 | - Joy 26 | - Angry 27 | """ 28 | 29 | SYSTEM_PROMPT_COT = SYSTEM_PROMPT + """ 30 | 31 | ## 思考について 32 | 33 | 応答する前に内容をよく考えてください。これまでの文脈を踏まえて適切な内容か、または兄が言い淀んだだけなので頷くだけにするか、など。 34 | まずは考えた内容を<thinking>に出力してください。 35 | そのあと、発話すべき内容を<answer>に出力してください。 36 | その2つのタグ以外に文言を含むことは禁止です。 37 | """ 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_litellm_service_simple(): 42 | """ 43 | Test LiteLLMService with a basic prompt to check if it can stream responses. 44 | This test actually calls Anthropic API, so it may cost tokens.
45 | """ 46 | service = LiteLLMService( 47 | api_key=CLAUDE_API_KEY, 48 | system_prompt=SYSTEM_PROMPT, 49 | model=MODEL, 50 | temperature=0.5 51 | ) 52 | context_id = f"test_context_{uuid4()}" 53 | 54 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 55 | 56 | collected_text = [] 57 | collected_voice = [] 58 | 59 | async for resp in service.chat_stream(context_id, "test_user", user_message): 60 | collected_text.append(resp.text) 61 | collected_voice.append(resp.voice_text) 62 | 63 | full_text = "".join(collected_text) 64 | full_voice = "".join(filter(None, collected_voice)) 65 | assert len(full_text) > 0, "No text was returned from the LLM." 66 | 67 | # Check the response content 68 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 69 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 70 | 71 | # Check the context 72 | messages = await service.context_manager.get_histories(context_id) 73 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 74 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 75 | 76 | 77 | @pytest.mark.asyncio 78 | async def test_litellm_service_system_prompt_params(): 79 | """ 80 | Test LiteLLMService with a basic prompt and its dynamic params. 81 | This test actually calls OpenAI API, so it may cost tokens. 82 | """ 83 | service = LiteLLMService( 84 | api_key=CLAUDE_API_KEY, 85 | system_prompt="あなたは{animal_name}です。語尾をそれらしくしてください。カタカナで表現します。", 86 | model=MODEL, 87 | temperature=0.5 88 | ) 89 | context_id = f"test_system_prompt_params_context_{uuid4()}" 90 | 91 | user_message = "こんにちは" 92 | 93 | collected_text = [] 94 | 95 | async for resp in service.chat_stream(context_id, "test_user", user_message, system_prompt_params={"animal_name": "猫"}): 96 | collected_text.append(resp.text) 97 | 98 | full_text = "".join(collected_text) 99 | assert len(full_text) > 0, "No text was returned from the LLM." 100 | 101 | # Check the response content 102 | assert "ニャ" in full_text, "ニャ doesn't appear in text." 103 | 104 | # Check the context 105 | messages = await service.context_manager.get_histories(context_id) 106 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 107 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 108 | 109 | 110 | @pytest.mark.asyncio 111 | async def test_litellm_service_image(): 112 | """ 113 | Test LiteLLMService with a basic prompt to check if it can handle image and stream responses. 114 | This test actually calls Anthropic API, so it may cost tokens. 115 | """ 116 | service = LiteLLMService( 117 | api_key=CLAUDE_API_KEY, 118 | system_prompt=SYSTEM_PROMPT, 119 | model=MODEL, 120 | temperature=0.5 121 | ) 122 | context_id = f"test_context_{uuid4()}" 123 | 124 | collected_text = [] 125 | 126 | async for resp in service.chat_stream(context_id, "test_user", "これは何ですか?漢字で答えてください。", files=[{"type": "image", "url": IMAGE_URL}]): 127 | collected_text.append(resp.text) 128 | 129 | full_text = "".join(collected_text) 130 | assert len(full_text) > 0, "No text was returned from the LLM." 131 | 132 | # Check the response content 133 | assert "寿司" in full_text, "寿司 is not in text." 134 | 135 | # Check the context 136 | messages = await service.context_manager.get_histories(context_id) 137 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 
138 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 139 | 140 | 141 | @pytest.mark.asyncio 142 | async def test_litellm_service_cot(): 143 | """ 144 | Test LiteLLMService with a prompt to check Chain-of-Thought. 145 | This test actually calls Anthropic API, so it may cost tokens. 146 | """ 147 | service = LiteLLMService( 148 | api_key=CLAUDE_API_KEY, 149 | system_prompt=SYSTEM_PROMPT_COT, 150 | model=MODEL, 151 | temperature=0.5, 152 | voice_text_tag="answer" 153 | ) 154 | context_id = f"test_context_cot_{uuid4()}" 155 | 156 | user_message = "君が大切にしていたプリンは、私が勝手に食べておいた。" 157 | 158 | collected_text = [] 159 | collected_voice = [] 160 | 161 | async for resp in service.chat_stream(context_id, "test_user", user_message): 162 | collected_text.append(resp.text) 163 | collected_voice.append(resp.voice_text) 164 | 165 | full_text = "".join(collected_text) 166 | full_voice = "".join(filter(None, collected_voice)) 167 | assert len(full_text) > 0, "No text was returned from the LLM." 168 | 169 | # Check the response content 170 | assert "[face:Angry]" in full_text, "Control tag doesn't appear in text." 171 | assert "[face:Angry]" not in full_voice, "Control tag was not removed from voice_text." 172 | 173 | # Check the response content (CoT) 174 | assert "<answer>" in full_text, "Answer tag doesn't appear in text." 175 | assert "</answer>" in full_text, "Answer tag closing doesn't appear in text." 176 | assert "<answer>" not in full_voice, "Answer tag was not removed from voice_text." 177 | assert "</answer>" not in full_voice, "Answer tag closing was not removed from voice_text." 178 | 179 | # Check the context 180 | messages = await service.context_manager.get_histories(context_id) 181 | assert any(m["role"] == "user" for m in messages), "User message not found in context." 182 | assert any(m["role"] == "assistant" for m in messages), "Assistant message not found in context." 183 | 184 | 185 | @pytest.mark.asyncio 186 | async def test_litellm_service_tool_calls(): 187 | """ 188 | Test LiteLLMService with a registered tool. 189 | The conversation might trigger the tool call, then the tool's result is fed back. 190 | This is just an example. The actual trigger depends on the model response. 191 | """ 192 | service = LiteLLMService( 193 | api_key=CLAUDE_API_KEY, 194 | system_prompt="You can call a tool to solve math problems if necessary.", 195 | model=MODEL, 196 | temperature=0.5, 197 | ) 198 | context_id = f"test_context_tool_{uuid4()}" 199 | 200 | # Register tool 201 | tool_spec = { 202 | "type": "function", 203 | "function": { 204 | "name": "solve_math", 205 | "description": "Solve simple math problems", 206 | "parameters": { 207 | "type": "object", 208 | "properties": { 209 | "problem": {"type": "string"} 210 | }, 211 | "required": ["problem"] 212 | } 213 | } 214 | } 215 | @service.tool(tool_spec) 216 | # Tool 217 | async def solve_math(problem: str) -> Dict[str, Any]: 218 | """ 219 | Tool function example: parse the problem and return a result.
220 | """ 221 | if problem.strip() == "1+1": 222 | return {"answer": 2} 223 | else: 224 | return {"answer": "unknown"} 225 | 226 | @service.on_before_tool_calls 227 | async def on_before_tool_calls(tool_calls: list[ToolCall]): 228 | assert len(tool_calls) > 0 229 | 230 | user_message = "次の問題を解いて: 1+1" 231 | collected_text = [] 232 | 233 | async for resp in service.chat_stream(context_id, "test_user", user_message): 234 | collected_text.append(resp.text) 235 | 236 | # Check context 237 | messages = await service.context_manager.get_histories(context_id) 238 | assert len(messages) == 4 239 | 240 | assert messages[0]["role"] == "user" 241 | assert messages[0]["content"] == user_message 242 | 243 | assert messages[1]["role"] == "assistant" 244 | assert messages[1]["tool_calls"] is not None 245 | assert messages[1]["tool_calls"][0]["function"]["name"] == "solve_math" 246 | tool_call_id = messages[1]["tool_calls"][0]["id"] 247 | 248 | assert messages[2]["role"] == "tool" 249 | assert messages[2]["tool_call_id"] == tool_call_id 250 | assert messages[2]["content"] == json.dumps({"answer": 2}) 251 | 252 | assert messages[3]["role"] == "assistant" 253 | assert "2" in messages[3]["content"] 254 | -------------------------------------------------------------------------------- /tests/stt/data/hello.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uezo/litests/648a7fe70c2e3885efc1116d0f8a5c101bea10a7/tests/stt/data/hello.wav -------------------------------------------------------------------------------- /tests/stt/data/hello_en.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uezo/litests/648a7fe70c2e3885efc1116d0f8a5c101bea10a7/tests/stt/data/hello_en.wav -------------------------------------------------------------------------------- /tests/stt/test_azure.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import wave 4 | from pathlib import Path 5 | 6 | from litests.stt.azure import AzureSpeechRecognizer 7 | 8 | AZURE_API_KEY = os.getenv("AZURE_API_KEY") 9 | AZURE_REGION = os.getenv("AZURE_REGION") 10 | 11 | @pytest.fixture 12 | def stt_wav_path() -> Path: 13 | """ 14 | Returns the path to the hello.wav file containing "こんにちは。" 15 | Make sure the file is placed at tests/data/hello.wav (or an appropriate path). 16 | """ 17 | return Path(__file__).parent / "data" / "hello.wav" 18 | 19 | 20 | @pytest.fixture 21 | def stt_wav_path_en() -> Path: 22 | """ 23 | Returns the path to the hello.wav file containing "hello" 24 | Make sure the file is placed at tests/data/hello.wav (or an appropriate path). 25 | """ 26 | return Path(__file__).parent / "data" / "hello_en.wav" 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_azure_speech_recognizer_transcribe(stt_wav_path): 31 | """ 32 | Test to verify that AzureSpeechRecognizer can transcribe the hello.wav file 33 | which contains "こんにちは。". 34 | NOTE: This test actually calls Azure's Speech-to-Text API and consumes credits. 
35 | """ 36 | # 1) Load the WAV file 37 | with wave.open(str(stt_wav_path), 'rb') as wav_file: 38 | sample_rate = wav_file.getframerate() 39 | n_frames = wav_file.getnframes() 40 | wave_data = wav_file.readframes(n_frames) 41 | 42 | # 2) Prepare the recognizer 43 | recognizer = AzureSpeechRecognizer( 44 | azure_api_key=AZURE_API_KEY, 45 | azure_region=AZURE_REGION, 46 | sample_rate=sample_rate, 47 | language="ja-JP", 48 | debug=True 49 | ) 50 | 51 | # 3) Invoke the transcribe method 52 | recognized_text = await recognizer.transcribe(wave_data) 53 | 54 | # 4) Check the recognized text 55 | assert "こんにちは" in recognized_text, f"Expected 'こんにちは', got: {recognized_text}" 56 | 57 | # 5) Invoke the transcribe_classic method 58 | recognized_text_classic = await recognizer.transcribe_classic(wave_data) 59 | 60 | # 6) Check the recognized text 61 | assert "こんにちは" in recognized_text_classic, f"Expected 'こんにちは', got: {recognized_text_classic}" 62 | 63 | # 7) Close the recognizer's http_client 64 | await recognizer.close() 65 | 66 | 67 | @pytest.mark.asyncio 68 | async def test_azure_speech_recognizer_transcribe_autodetect(stt_wav_path, stt_wav_path_en): 69 | """ 70 | Test to verify that AzureSpeechRecognizer can transcribe the hello.wav file 71 | which contains "こんにちは。". 72 | NOTE: This test actually calls Azure's Speech-to-Text API and consumes credits. 73 | """ 74 | # 1-1) Load the WAV files 75 | with wave.open(str(stt_wav_path), 'rb') as wav_file: 76 | sample_rate = wav_file.getframerate() 77 | n_frames = wav_file.getnframes() 78 | wave_data = wav_file.readframes(n_frames) 79 | 80 | # 1-2) Prepare the recognizer 81 | recognizer = AzureSpeechRecognizer( 82 | azure_api_key=AZURE_API_KEY, 83 | azure_region=AZURE_REGION, 84 | sample_rate=sample_rate, 85 | language="ja-JP", 86 | alternative_languages=["en-US"], 87 | debug=True 88 | ) 89 | 90 | # 1-3) Invoke the transcribe method 91 | recognized_text = await recognizer.transcribe(wave_data) 92 | 93 | # 1-4) Check the recognized text 94 | assert "こんにちは" in recognized_text, f"Expected 'こんにちは', got: {recognized_text}" 95 | 96 | # 1-5) Invoke the transcribe_classic method 97 | recognized_text_classic = await recognizer.transcribe_classic(wave_data) 98 | 99 | # 1-6) Check the recognized text 100 | assert "こんにちは" in recognized_text_classic, f"Expected 'こんにちは', got: {recognized_text_classic}" 101 | 102 | # 2-1) Load the English WAV files 103 | with wave.open(str(stt_wav_path_en), 'rb') as wav_file: 104 | sample_rate = wav_file.getframerate() 105 | n_frames = wav_file.getnframes() 106 | wave_data_en = wav_file.readframes(n_frames) 107 | 108 | # 2-2) Prepare the recognizer 109 | recognizer.sample_rate = sample_rate 110 | 111 | # 2-3) Invoke the transcribe method 112 | recognized_text = await recognizer.transcribe(wave_data_en) 113 | 114 | # 2-4) Check the recognized text 115 | assert "hello" in recognized_text.lower(), f"Expected 'hello', got: {recognized_text}" 116 | 117 | # 2-5) Invoke the transcribe_classic method 118 | recognized_text_classic = await recognizer.transcribe_classic(wave_data_en) 119 | 120 | # 2-6) Check the recognized text 121 | assert "hello" in recognized_text_classic.lower(), f"Expected 'hello', got: {recognized_text_classic}" 122 | 123 | # Close the recognizer's http_client 124 | await recognizer.close() 125 | -------------------------------------------------------------------------------- /tests/stt/test_google.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import 
wave 4 | from pathlib import Path 5 | 6 | from litests.stt.google import GoogleSpeechRecognizer 7 | 8 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 9 | 10 | 11 | @pytest.fixture 12 | def stt_wav_path() -> Path: 13 | """ 14 | Returns the path to the hello.wav file containing "こんにちは。" 15 | Make sure the file is placed at tests/data/hello.wav (or an appropriate path). 16 | """ 17 | return Path(__file__).parent / "data" / "hello.wav" 18 | 19 | 20 | @pytest.fixture 21 | def stt_wav_path_en() -> Path: 22 | """ 23 | Returns the path to the hello.wav file containing "hello" 24 | Make sure the file is placed at tests/data/hello.wav (or an appropriate path). 25 | """ 26 | return Path(__file__).parent / "data" / "hello_en.wav" 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_google_speech_recognizer_transcribe(stt_wav_path): 31 | """ 32 | Test to verify that GoogleSpeechRecognizer can transcribe the hello.wav file 33 | which contains "こんにちは。". 34 | NOTE: This test actually calls Google's Cloud Speech-to-Text API and consumes credits. 35 | """ 36 | # 1) Load the WAV file 37 | with wave.open(str(stt_wav_path), 'rb') as wav_file: 38 | sample_rate = wav_file.getframerate() 39 | n_frames = wav_file.getnframes() 40 | wave_data = wav_file.readframes(n_frames) 41 | 42 | # 2) Prepare the recognizer 43 | recognizer = GoogleSpeechRecognizer( 44 | google_api_key=GOOGLE_API_KEY, 45 | sample_rate=sample_rate, 46 | language="ja-JP", 47 | debug=True 48 | ) 49 | 50 | # 3) Invoke the transcribe method 51 | recognized_text = await recognizer.transcribe(wave_data) 52 | 53 | # 4) Check the recognized text 54 | assert "こんにちは" in recognized_text, f"Expected 'こんにちは', got: {recognized_text}" 55 | 56 | # 5) Close the recognizer's http_client 57 | await recognizer.close() 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_google_speech_recognizer_transcribe_autodetect(stt_wav_path, stt_wav_path_en): 62 | """ 63 | Test to verify that GoogleSpeechRecognizer can transcribe the hello.wav file 64 | which contains "こんにちは。". 65 | NOTE: This test actually calls Google's Cloud Speech-to-Text API and consumes credits. 
66 | """ 67 | # 1-1) Load the WAV file 68 | with wave.open(str(stt_wav_path), 'rb') as wav_file: 69 | sample_rate = wav_file.getframerate() 70 | n_frames = wav_file.getnframes() 71 | wave_data = wav_file.readframes(n_frames) 72 | 73 | # 1-2) Prepare the recognizer 74 | recognizer = GoogleSpeechRecognizer( 75 | google_api_key=GOOGLE_API_KEY, 76 | sample_rate=sample_rate, 77 | language="ja-JP", 78 | alternative_languages=["en-US"], 79 | debug=True 80 | ) 81 | 82 | # 1-3) Invoke the transcribe method 83 | recognized_text = await recognizer.transcribe(wave_data) 84 | 85 | # 1-4) Check the recognized text 86 | assert "こんにちは" in recognized_text, f"Expected 'こんにちは', got: {recognized_text}" 87 | 88 | # 2-1) Load the WAV file 89 | with wave.open(str(stt_wav_path_en), 'rb') as wav_file: 90 | sample_rate = wav_file.getframerate() 91 | n_frames = wav_file.getnframes() 92 | wave_data_en = wav_file.readframes(n_frames) 93 | 94 | # 2-2) Prepare the recognizer 95 | recognizer.sample_rate = sample_rate 96 | 97 | # 2-3) Invoke the transcribe method 98 | recognized_text = await recognizer.transcribe(wave_data_en) 99 | 100 | # 2-4) Check the recognized text 101 | assert "hello" in recognized_text.lower(), f"Expected 'hello', got: {recognized_text}" 102 | 103 | # Close the recognizer's http_client 104 | await recognizer.close() 105 | -------------------------------------------------------------------------------- /tests/stt/test_openai.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import wave 4 | from pathlib import Path 5 | 6 | from litests.stt.openai import OpenAISpeechRecognizer 7 | 8 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 9 | 10 | 11 | @pytest.fixture 12 | def stt_wav_path() -> Path: 13 | """ 14 | Returns the path to the hello.wav file containing "こんにちは。" 15 | Make sure the file is placed at tests/data/hello.wav (or an appropriate path). 16 | """ 17 | return Path(__file__).parent / "data" / "hello.wav" 18 | 19 | 20 | @pytest.fixture 21 | def stt_wav_path_en() -> Path: 22 | """ 23 | Returns the path to the hello.wav file containing "hello" 24 | Make sure the file is placed at tests/data/hello.wav (or an appropriate path). 25 | """ 26 | return Path(__file__).parent / "data" / "hello_en.wav" 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_openai_speech_recognizer_transcribe(stt_wav_path): 31 | """ 32 | Test to verify that OpenAISpeechRecognizer can transcribe the hello.wav file 33 | which contains "こんにちは。". 34 | NOTE: This test actually calls OpenAI's Speech-to-Text API and consumes credits. 35 | """ 36 | # 1) Load the WAV file 37 | with wave.open(str(stt_wav_path), 'rb') as wav_file: 38 | sample_rate = wav_file.getframerate() 39 | n_frames = wav_file.getnframes() 40 | wave_data = wav_file.readframes(n_frames) 41 | 42 | # 2) Prepare the recognizer 43 | recognizer = OpenAISpeechRecognizer( 44 | openai_api_key=OPENAI_API_KEY, 45 | sample_rate=sample_rate, 46 | language="ja", 47 | debug=True 48 | ) 49 | 50 | # 3) Invoke the transcribe method 51 | recognized_text = await recognizer.transcribe(wave_data) 52 | 53 | # 4) Check the recognized text (Whisper-1 doesn't recognize 'こんにちは' correctly...) 
54 | assert "こんにちわ" in recognized_text, f"Expected 'こんにちわ', got: {recognized_text}" 55 | 56 | # 5) Close the recognizer's http_client 57 | await recognizer.close() 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_openai_speech_recognizer_transcribe_autodetect(stt_wav_path, stt_wav_path_en): 62 | """ 63 | Test to verify that OpenAISpeechRecognizer can transcribe the hello.wav file 64 | which contains "こんにちは。". 65 | NOTE: This test actually calls OpenAI's Speech-to-Text API and consumes credits. 66 | """ 67 | # 1-1) Load the WAV file 68 | with wave.open(str(stt_wav_path), 'rb') as wav_file: 69 | sample_rate = wav_file.getframerate() 70 | n_frames = wav_file.getnframes() 71 | wave_data = wav_file.readframes(n_frames) 72 | 73 | # 1-2) Prepare the recognizer 74 | recognizer = OpenAISpeechRecognizer( 75 | openai_api_key=OPENAI_API_KEY, 76 | sample_rate=sample_rate, 77 | language="ja-JP", 78 | alternative_languages=["en-US"], 79 | debug=True 80 | ) 81 | 82 | # 1-3) Invoke the transcribe method 83 | recognized_text = await recognizer.transcribe(wave_data) 84 | 85 | # 1-4) Check the recognized text (Whisper-1 doesn't recognize 'こんにちは' correctly...) 86 | assert "こんにちわ" in recognized_text, f"Expected 'こんにちわ', got: {recognized_text}" 87 | 88 | # 2-1) Load the WAV file 89 | with wave.open(str(stt_wav_path_en), 'rb') as wav_file: 90 | sample_rate = wav_file.getframerate() 91 | n_frames = wav_file.getnframes() 92 | wave_data_en = wav_file.readframes(n_frames) 93 | 94 | # 2-2) Prepare the recognizer 95 | recognizer.sample_rate = sample_rate 96 | 97 | # 2-3) Invoke the transcribe method 98 | recognized_text = await recognizer.transcribe(wave_data_en) 99 | 100 | # 2-4) Check the recognized text 101 | assert "hello" in recognized_text.lower(), f"Expected 'hello', got: {recognized_text}" 102 | 103 | # Close the recognizer's http_client 104 | await recognizer.close() 105 | -------------------------------------------------------------------------------- /tests/tts/test_azure_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from litests.stt.google import GoogleSpeechRecognizer 4 | from litests.tts.azure import AzureSpeechSynthesizer 5 | 6 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 7 | AZURE_API_KEY = os.getenv("AZURE_API_KEY") 8 | AZURE_REGION = os.getenv("AZURE_REGION") 9 | 10 | @pytest.mark.asyncio 11 | async def test_azure_synthesizer_with_google_stt(): 12 | """ 13 | Test the AzureSpeechSynthesizer by actually calling a TTS server 14 | and verifying the synthesized audio with Google STT. 15 | This test requires: 16 | - Valid GOOGLE_API_KEY, AZURE_API_KEY and AZURE_REGION environment variables 17 | """ 18 | 19 | # 1) Create synthesizer instance 20 | synthesizer = AzureSpeechSynthesizer( 21 | azure_api_key=AZURE_API_KEY, 22 | azure_region=AZURE_REGION, 23 | speaker="ja-JP-MayuNeural", 24 | debug=True 25 | ) 26 | 27 | # 2) The text to synthesize 28 | input_text = "これはテストです。" 29 | 30 | # 3) Call TTS 31 | tts_data = await synthesizer.synthesize(input_text) 32 | assert len(tts_data) > 0, "Synthesized audio data is empty." 
33 | 34 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 35 | recognizer = GoogleSpeechRecognizer( 36 | google_api_key=GOOGLE_API_KEY, 37 | language="ja-JP" 38 | ) 39 | 40 | recognized_text = await recognizer.transcribe(tts_data) 41 | 42 | # 5) Verify recognized text 43 | assert "テスト" in recognized_text, ( 44 | f"Expected 'テスト' in recognized result, but got: {recognized_text}" 45 | ) 46 | 47 | # 6) Cleanup 48 | await recognizer.close() 49 | await synthesizer.close() 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_azure_synthesizer_with_google_stt_english(): 54 | """ 55 | Test the AzureSpeechSynthesizer by actually calling a TTS server 56 | and verifying the synthesized audio with Google STT. 57 | This test requires: 58 | - Valid GOOGLE_API_KEY, AZURE_API_KEY and AZURE_REGION environment variables 59 | """ 60 | 61 | # 1) Create synthesizer instance 62 | synthesizer = AzureSpeechSynthesizer( 63 | azure_api_key=AZURE_API_KEY, 64 | azure_region=AZURE_REGION, 65 | speaker="ja-JP-MayuNeural", 66 | debug=True 67 | ) 68 | synthesizer.voice_map["en-US"] = "en-US-AvaNeural" 69 | 70 | # 2) The text to synthesize 71 | input_text = "This is a test for speech synthesizer." 72 | 73 | # 3) Call TTS 74 | tts_data = await synthesizer.synthesize(input_text, language="en-US") 75 | assert len(tts_data) > 0, "Synthesized audio data is empty." 76 | 77 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 78 | recognizer = GoogleSpeechRecognizer( 79 | google_api_key=GOOGLE_API_KEY, 80 | language="en-US" 81 | ) 82 | 83 | recognized_text = await recognizer.transcribe(tts_data) 84 | 85 | # 5) Verify recognized text 86 | assert "test" in recognized_text, ( 87 | f"Expected 'test' in recognized result, but got: {recognized_text}" 88 | ) 89 | assert "speech" in recognized_text, ( 90 | f"Expected 'speech' in recognized result, but got: {recognized_text}" 91 | ) 92 | 93 | # 6) Cleanup 94 | await recognizer.close() 95 | await synthesizer.close() 96 | 97 | 98 | @pytest.mark.asyncio 99 | async def test_azure_synthesizer_empty_text(): 100 | """ 101 | If empty text is provided, Azure should return empty bytes 102 | (no synthesis performed). 103 | """ 104 | synthesizer = AzureSpeechSynthesizer( 105 | azure_api_key=AZURE_API_KEY, 106 | azure_region=AZURE_REGION, 107 | speaker="ja-JP-MayuNeural", 108 | debug=True 109 | ) 110 | 111 | tts_data = await synthesizer.synthesize(" ") # Empty or just whitespace 112 | assert len(tts_data) == 0, "Expected empty bytes for empty text." 113 | 114 | await synthesizer.close() 115 | -------------------------------------------------------------------------------- /tests/tts/test_google_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from litests.stt.google import GoogleSpeechRecognizer 4 | from litests.tts.google import GoogleSpeechSynthesizer 5 | 6 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 7 | 8 | @pytest.mark.asyncio 9 | async def test_google_synthesizer_with_google_stt(): 10 | """ 11 | Test the GoogleSpeechSynthesizer by actually calling a TTS server 12 | and verifying the synthesized audio with Google STT. 
13 | This test requires: 14 | - Valid GOOGLE_API_KEY environment variable 15 | """ 16 | 17 | # 1) Create synthesizer instance 18 | synthesizer = GoogleSpeechSynthesizer( 19 | google_api_key=GOOGLE_API_KEY, 20 | speaker="ja-JP-Standard-B", 21 | debug=True 22 | ) 23 | 24 | # 2) The text to synthesize 25 | input_text = "これはテストです。" 26 | 27 | # 3) Call TTS 28 | tts_data = await synthesizer.synthesize(input_text) 29 | assert len(tts_data) > 0, "Synthesized audio data is empty." 30 | 31 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 32 | recognizer = GoogleSpeechRecognizer( 33 | google_api_key=GOOGLE_API_KEY, 34 | sample_rate=24000, # Sampling rate of TTS service 35 | language="ja-JP" 36 | ) 37 | 38 | recognized_text = await recognizer.transcribe(tts_data) 39 | 40 | # 5) Verify recognized text 41 | assert "テスト" in recognized_text, ( 42 | f"Expected 'テスト' in recognized result, but got: {recognized_text}" 43 | ) 44 | 45 | # 6) Cleanup 46 | await recognizer.close() 47 | await synthesizer.close() 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def test_google_synthesizer_with_google_stt_english(): 52 | # 1) Create synthesizer instance 53 | synthesizer = GoogleSpeechSynthesizer( 54 | google_api_key=GOOGLE_API_KEY, 55 | speaker="ja-JP-Standard-B", 56 | debug=True 57 | ) 58 | synthesizer.voice_map["en-US"] = "en-US-Standard-H" 59 | 60 | # 2) The text to synthesize 61 | input_text = "This is a test for speech synthesizer." 62 | 63 | # 3) Call TTS 64 | tts_data = await synthesizer.synthesize(input_text, language="en-US") 65 | assert len(tts_data) > 0, "Synthesized audio data is empty." 66 | 67 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 68 | recognizer = GoogleSpeechRecognizer( 69 | google_api_key=GOOGLE_API_KEY, 70 | sample_rate=24000, # Sampling rate of TTS service 71 | language="en-US" 72 | ) 73 | 74 | recognized_text = await recognizer.transcribe(tts_data) 75 | 76 | # 5) Verify recognized text 77 | assert "test" in recognized_text, ( 78 | f"Expected 'test' in recognized result, but got: {recognized_text}" 79 | ) 80 | assert "speech" in recognized_text, ( 81 | f"Expected 'speech' in recognized result, but got: {recognized_text}" 82 | ) 83 | 84 | # 6) Cleanup 85 | await recognizer.close() 86 | await synthesizer.close() 87 | 88 | 89 | @pytest.mark.asyncio 90 | async def test_google_synthesizer_empty_text(): 91 | """ 92 | If empty text is provided, Google should return empty bytes 93 | (no synthesis performed). 94 | """ 95 | synthesizer = GoogleSpeechSynthesizer( 96 | google_api_key=GOOGLE_API_KEY, 97 | speaker="ja-JP-Standard-B", 98 | debug=True 99 | ) 100 | 101 | tts_data = await synthesizer.synthesize(" ") # Empty or just whitespace 102 | assert len(tts_data) == 0, "Expected empty bytes for empty text." 103 | 104 | await synthesizer.close() 105 | -------------------------------------------------------------------------------- /tests/tts/test_openai_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from litests.stt.google import GoogleSpeechRecognizer 4 | from litests.tts.openai import OpenAISpeechSynthesizer 5 | 6 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 7 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 8 | 9 | @pytest.mark.asyncio 10 | async def test_openai_synthesizer_with_google_stt(): 11 | """ 12 | Test the OpenAISpeechSynthesizer by actually calling a TTS server 13 | and verifying the synthesized audio with Google STT. 
14 | This test requires: 15 | - Valid GOOGLE_API_KEY and OPENAI_API_KEY environment variables 16 | """ 17 | 18 | # 1) Create synthesizer instance 19 | synthesizer = OpenAISpeechSynthesizer( 20 | openai_api_key=OPENAI_API_KEY, 21 | speaker="shimmer", 22 | debug=True 23 | ) 24 | 25 | # 2) The text to synthesize 26 | input_text = "これはテストです。" 27 | 28 | # 3) Call TTS 29 | tts_data = await synthesizer.synthesize(input_text) 30 | assert len(tts_data) > 0, "Synthesized audio data is empty." 31 | 32 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 33 | recognizer = GoogleSpeechRecognizer( 34 | google_api_key=GOOGLE_API_KEY, 35 | sample_rate=24000, # Sampling rate of TTS service 36 | language="ja-JP" 37 | ) 38 | 39 | recognized_text = await recognizer.transcribe(tts_data) 40 | 41 | # 5) Verify recognized text 42 | assert "テスト" in recognized_text, ( 43 | f"Expected 'テスト' in recognized result, but got: {recognized_text}" 44 | ) 45 | 46 | # 6) Cleanup 47 | await recognizer.close() 48 | await synthesizer.close() 49 | 50 | 51 | @pytest.mark.asyncio 52 | async def test_openai_synthesizer_with_google_stt_english(): 53 | """ 54 | Test the OpenAISpeechSynthesizer by actually calling a TTS server 55 | and verifying the synthesized audio with Google STT. 56 | This test requires: 57 | - Valid GOOGLE_API_KEY and OPENAI_API_KEY environment variables 58 | """ 59 | 60 | # 1) Create synthesizer instance 61 | synthesizer = OpenAISpeechSynthesizer( 62 | openai_api_key=OPENAI_API_KEY, 63 | speaker="shimmer", 64 | debug=True 65 | ) 66 | 67 | # 2) The text to synthesize 68 | input_text = "This is a test for speech synthesizer." 69 | 70 | # 3) Call TTS 71 | tts_data = await synthesizer.synthesize(input_text) 72 | assert len(tts_data) > 0, "Synthesized audio data is empty." 73 | 74 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 75 | recognizer = GoogleSpeechRecognizer( 76 | google_api_key=GOOGLE_API_KEY, 77 | sample_rate=24000, # Sampling rate of TTS service 78 | language="en-US" 79 | ) 80 | 81 | recognized_text = await recognizer.transcribe(tts_data) 82 | 83 | # 5) Verify recognized text 84 | assert "test" in recognized_text, ( 85 | f"Expected 'test' in recognized result, but got: {recognized_text}" 86 | ) 87 | assert "speech" in recognized_text, ( 88 | f"Expected 'speech' in recognized result, but got: {recognized_text}" 89 | ) 90 | 91 | # 6) Cleanup 92 | await recognizer.close() 93 | await synthesizer.close() 94 | 95 | 96 | @pytest.mark.asyncio 97 | async def test_openai_synthesizer_empty_text(): 98 | """ 99 | If empty text is provided, OpenAI should return empty bytes 100 | (no synthesis performed). 101 | """ 102 | synthesizer = OpenAISpeechSynthesizer( 103 | openai_api_key=OPENAI_API_KEY, 104 | speaker="shimmer", 105 | debug=True 106 | ) 107 | 108 | tts_data = await synthesizer.synthesize(" ") # Empty or just whitespace 109 | assert len(tts_data) == 0, "Expected empty bytes for empty text." 
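# No synthesis request should have been made for whitespace-only input (per the expectation above), so this test is assumed not to consume OpenAI API credits.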
110 | 111 | await synthesizer.close() 112 | -------------------------------------------------------------------------------- /tests/tts/test_speech_gateway.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from litests.stt.google import GoogleSpeechRecognizer 4 | from litests.tts.speech_gateway import SpeechGatewaySpeechSynthesizer 5 | 6 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_speech_gateway_synthesizer_with_google_stt(): 11 | """ 12 | Test the SpeechGatewaySpeechSynthesizer by actually calling a TTS server 13 | and verifying the synthesized audio with Google STT. 14 | This test requires: 15 | - TTS server running at http://127.0.0.1:8000/tts 16 | - Valid GOOGLE_API_KEY environment variable 17 | """ 18 | 19 | # 1) Create synthesizer instance 20 | synthesizer = SpeechGatewaySpeechSynthesizer( 21 | service_name="sbv2", 22 | speaker="0-0", 23 | tts_url="http://127.0.0.1:8000/tts", 24 | audio_format="wav", 25 | debug=True 26 | ) 27 | 28 | # 2) The text to synthesize 29 | input_text = "これはテストです。" 30 | 31 | # 3) Call TTS 32 | tts_data = await synthesizer.synthesize(input_text) 33 | assert len(tts_data) > 0, "Synthesized audio data is empty." 34 | 35 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 36 | recognizer = GoogleSpeechRecognizer( 37 | google_api_key=GOOGLE_API_KEY, 38 | sample_rate=44100, # Sampling rate of TTS service 39 | language="ja-JP" 40 | ) 41 | 42 | recognized_text = await recognizer.transcribe(tts_data) 43 | 44 | # 5) Verify recognized text 45 | assert "テスト" in recognized_text, ( 46 | f"Expected 'テスト' in recognized result, but got: {recognized_text}" 47 | ) 48 | 49 | # 6) Cleanup 50 | await recognizer.close() 51 | await synthesizer.close() 52 | 53 | 54 | @pytest.mark.asyncio 55 | async def test_speech_gateway_synthesizer_with_google_stt_english(): 56 | """ 57 | Test the SpeechGatewaySpeechSynthesizer by actually calling a TTS server 58 | and verifying the synthesized audio with Google STT. 59 | This test requires: 60 | - TTS server running at http://127.0.0.1:8000/tts 61 | - English routing is already configured on TTS server 62 | - Valid GOOGLE_API_KEY environment variable 63 | """ 64 | 65 | # 1) Create synthesizer instance 66 | synthesizer = SpeechGatewaySpeechSynthesizer( 67 | service_name="sbv2", 68 | speaker="0-0", 69 | tts_url="http://127.0.0.1:8000/tts", 70 | audio_format="wav", 71 | debug=True 72 | ) 73 | 74 | # 2) The text to synthesize 75 | input_text = "This is a test for speech synthesizer." 76 | 77 | # 3) Call TTS 78 | tts_data = await synthesizer.synthesize(input_text, language="en-US") 79 | assert len(tts_data) > 0, "Synthesized audio data is empty." 
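# NOTE: the recognizer below assumes the English route of the gateway outputs 24 kHz audio, unlike the 44.1 kHz used for the Japanese test above; adjust sample_rate if your speech_gateway routing is configured differently.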
80 | 81 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 82 | recognizer = GoogleSpeechRecognizer( 83 | google_api_key=GOOGLE_API_KEY, 84 | sample_rate=24000, # Sampling rate of TTS service 85 | language="en-US" 86 | ) 87 | 88 | recognized_text = await recognizer.transcribe(tts_data) 89 | 90 | # 5) Verify recognized text 91 | assert "test" in recognized_text, ( 92 | f"Expected 'test' in recognized result, but got: {recognized_text}" 93 | ) 94 | assert "speech" in recognized_text, ( 95 | f"Expected 'speech' in recognized result, but got: {recognized_text}" 96 | ) 97 | 98 | # 6) Cleanup 99 | await recognizer.close() 100 | await synthesizer.close() 101 | 102 | 103 | @pytest.mark.asyncio 104 | async def test_speech_gateway_synthesizer_empty_text(): 105 | """ 106 | If empty text is provided, speech_gateway should return empty bytes 107 | (no synthesis performed). 108 | """ 109 | synthesizer = SpeechGatewaySpeechSynthesizer( 110 | service_name="sbv2", 111 | speaker="0-0", 112 | tts_url="http://127.0.0.1:8000/tts" 113 | ) 114 | 115 | tts_data = await synthesizer.synthesize(" ") # Empty or just whitespace 116 | assert len(tts_data) == 0, "Expected empty bytes for empty text." 117 | 118 | await synthesizer.close() 119 | -------------------------------------------------------------------------------- /tests/tts/test_voicevox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from litests.stt.google import GoogleSpeechRecognizer 4 | from litests.tts.voicevox import VoicevoxSpeechSynthesizer 5 | 6 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_voicevox_synthesizer_with_google_stt(): 11 | """ 12 | Test the VoicevoxSpeechSynthesizer by actually calling a VOICEVOX server 13 | and verifying the synthesized audio with Google STT. 14 | This test requires: 15 | - VOICEVOX server running at http://127.0.0.1:50021 16 | - Valid GOOGLE_API_KEY environment variable 17 | """ 18 | 19 | # 1) Create synthesizer instance 20 | synthesizer = VoicevoxSpeechSynthesizer( 21 | speaker=46, 22 | base_url="http://127.0.0.1:50021", 23 | debug=True 24 | ) 25 | 26 | # 2) The text to synthesize 27 | input_text = "これはテストです。" 28 | 29 | # 3) Call TTS 30 | tts_data = await synthesizer.synthesize(input_text) 31 | assert len(tts_data) > 0, "Synthesized audio data is empty." 32 | 33 | # 4) Recognize synthesized speech via GoogleSpeechRecognizer 34 | recognizer = GoogleSpeechRecognizer( 35 | google_api_key=GOOGLE_API_KEY, 36 | sample_rate=24000, # Sampling rate of VOICEVOX 37 | language="ja-JP" 38 | ) 39 | 40 | recognized_text = await recognizer.transcribe(tts_data) 41 | 42 | # 5) Verify recognized text 43 | assert "テスト" in recognized_text, ( 44 | f"Expected 'テスト' in recognized result, but got: {recognized_text}" 45 | ) 46 | 47 | # 6) Cleanup 48 | await recognizer.close() 49 | await synthesizer.close() 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_voicevox_synthesizer_empty_text(): 54 | """ 55 | If empty text is provided, VOICEVOX should return empty bytes 56 | (no synthesis performed). 57 | """ 58 | synthesizer = VoicevoxSpeechSynthesizer( 59 | speaker=46, 60 | base_url="http://127.0.0.1:50021", 61 | debug=True 62 | ) 63 | 64 | tts_data = await synthesizer.synthesize(" ") # Empty or just whitespace 65 | assert len(tts_data) == 0, "Expected empty bytes for empty text."
66 | 67 | await synthesizer.close() 68 | -------------------------------------------------------------------------------- /tests/vad/test_standard.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import struct 3 | import pytest 4 | from pathlib import Path 5 | 6 | from litests.vad.standard import StandardSpeechDetector 7 | 8 | 9 | @pytest.fixture 10 | def test_output_dir(tmp_path: Path): 11 | """ 12 | Temporary directory to store the files created in each test case 13 | """ 14 | return tmp_path 15 | 16 | 17 | @pytest.fixture 18 | def detector(test_output_dir): 19 | detector = StandardSpeechDetector( 20 | volume_db_threshold=-40.0, 21 | silence_duration_threshold=0.5, 22 | max_duration=3.0, 23 | min_duration=0.5, 24 | sample_rate=16000, 25 | channels=1, 26 | preroll_buffer_count=5, 27 | debug=True 28 | ) 29 | 30 | @detector.on_speech_detected 31 | async def on_speech_detected(recorded_data: bytes, recorded_duration: float, session_id: str): 32 | output_file = test_output_dir / f"speech_{session_id}.pcm" 33 | with open(output_file, "wb") as f: 34 | f.write(recorded_data) 35 | 36 | return detector 37 | 38 | 39 | def generate_samples(amplitude: int, num_samples: int, sample_rate: int = 16000) -> bytes: 40 | data = [amplitude] * num_samples 41 | return struct.pack("<" + "h" * num_samples, *data) 42 | 43 | 44 | @pytest.mark.asyncio 45 | async def test_process_samples_speech_detection(detector, test_output_dir): 46 | """ 47 | Verify that recording starts when data exceeding the volume threshold is provided, 48 | ends after silence, and that on_speech_detected is then called. 49 | """ 50 | session_id = "test_session" 51 | 52 | # on_speech_detected will be invoked by loud samples longer than min_duration 53 | assert detector.min_duration == 0.5 54 | 55 | # Start with loud samples (0.5 sec, same as min_duration) 56 | loud_samples = generate_samples(amplitude=1200, num_samples=8000) 57 | await detector.process_samples(loud_samples, session_id=session_id) 58 | session = detector.get_session(session_id) 59 | assert session.is_recording is True 60 | 61 | # Stop with silent samples 62 | silent_samples = generate_samples(amplitude=0, num_samples=16000) 63 | await detector.process_samples(silent_samples, session_id=session_id) 64 | 65 | # Wait for on_speech_detected to be invoked 66 | await asyncio.sleep(0.2) 67 | 68 | # Check whether the file created by on_speech_detected exists 69 | output_file = test_output_dir / f"speech_{session_id}.pcm" 70 | assert output_file.exists(), "Recorded file doesn't exist" 71 | file_size = output_file.stat().st_size 72 | assert file_size > 0, "No data in the recorded file" 73 | 74 | 75 | @pytest.mark.asyncio 76 | async def test_process_samples_short_recording(detector, test_output_dir): 77 | """ 78 | Verify that if recording starts but falls silent before min_duration, 79 | on_speech_detected is not called, and no file is created.
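At 16 kHz mono 16-bit PCM, 0.5 sec corresponds to 8000 samples, so the 7999 loud samples generated below stay just under min_duration.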
80 | """ 81 | session_id = "test_short" 82 | 83 | # on_speech_detected will be invoked by loud samples longer than min_duration 84 | assert detector.min_duration == 0.5 85 | 86 | # Loud samples slightly shorter than 0.5 sec (0.5 sec = 8000 samples) 87 | loud_samples = generate_samples(amplitude=1000, num_samples=7999) 88 | await detector.process_samples(loud_samples, session_id=session_id) 89 | 90 | # Stop with silent samples 91 | silent_samples = generate_samples(amplitude=0, num_samples=8000) 92 | await detector.process_samples(silent_samples, session_id=session_id) 93 | 94 | await asyncio.sleep(0.2) 95 | 96 | output_file = test_output_dir / f"speech_{session_id}.pcm" 97 | assert not output_file.exists(), "File exists even though the samples are shorter than min_duration" 98 | 99 | 100 | @pytest.mark.asyncio 101 | async def test_process_samples_max_duration(detector, test_output_dir): 102 | """ 103 | Verify that when sound continues beyond max_duration (3 seconds), 104 | recording is automatically stopped, and on_speech_detected is not called. 105 | In the default implementation, recording is reset() when max_duration is exceeded, 106 | and data is discarded. 107 | """ 108 | session_id = "test_max_duration" 109 | 110 | assert detector.max_duration == 3.0 111 | 112 | # Loud samples as long as max_duration (3.0 sec = 48000 samples) 113 | loud_samples_long = generate_samples(amplitude=2000, num_samples=16000) 114 | await detector.process_samples(loud_samples_long, session_id=session_id) 115 | session = detector.get_session(session_id) 116 | # Make sure that recording has started 117 | assert session.is_recording is True 118 | 119 | more_loud_samples_long = generate_samples(amplitude=2000, num_samples=32000) 120 | await detector.process_samples(more_loud_samples_long, session_id=session_id) 121 | session = detector.get_session(session_id) 122 | # Make sure that recording has stopped 123 | assert session.is_recording is False 124 | 125 | await asyncio.sleep(0.2) 126 | 127 | output_file = test_output_dir / f"speech_{session_id}.pcm" 128 | assert not output_file.exists(), "File exists even though the samples are as long as max_duration" 129 | 130 | 131 | @pytest.mark.asyncio 132 | async def test_process_stream(detector, test_output_dir): 133 | """ 134 | Test stream processing via process_stream. 135 | """ 136 | session_id = "test_stream" 137 | 138 | assert detector.min_duration == 0.5 139 | 140 | async def async_audio_stream(): 141 | # Start recording with loud samples 142 | yield generate_samples(amplitude=1500, num_samples=16000) 143 | await asyncio.sleep(0.1) 144 | # More loud samples 145 | yield generate_samples(amplitude=1500, num_samples=3200) 146 | # Stop with silent samples 147 | yield generate_samples(amplitude=0, num_samples=16000) 148 | return 149 | 150 | await detector.process_stream(async_audio_stream(), session_id=session_id) 151 | 152 | # Wait for on_speech_detected to be invoked 153 | await asyncio.sleep(0.2) 154 | 155 | output_file = test_output_dir / f"speech_{session_id}.pcm" 156 | assert output_file.exists(), "Recorded file doesn't exist" 157 | file_size = output_file.stat().st_size 158 | assert file_size > 0, "No data in the recorded file" 159 | 160 | # Session is deleted after the stream ends 161 | assert session_id not in detector.recording_sessions 162 | 163 | @pytest.mark.asyncio 164 | async def test_session_reset_and_delete(detector, test_output_dir): 165 | """ 166 | Test the operation of reset / delete for a session. 167 | Reset clears buffers, etc. Delete removes the session.
168 | """ 169 | session_id = "test_session_reset" 170 | 171 | assert detector.min_duration == 0.5 172 | loud_samples = generate_samples(amplitude=1000, num_samples=8000) 173 | 174 | # Start recording with loud samples 175 | await detector.process_samples(loud_samples, session_id=session_id) 176 | session = detector.get_session(session_id) 177 | assert session.is_recording is True 178 | assert len(session.buffer) > 0 179 | 180 | # Reset recording session 181 | detector.reset_session(session_id) 182 | session = detector.get_session(session_id) 183 | assert session.is_recording is False 184 | assert len(session.buffer) == 0 185 | 186 | # Start again 187 | await detector.process_samples(loud_samples, session_id=session_id) 188 | assert session.is_recording is True 189 | assert len(session.buffer) > 0 190 | 191 | # Delete session 192 | session = detector.get_session(session_id) 193 | detector.delete_session(session_id) 194 | assert session.is_recording is False 195 | assert len(session.buffer) == 0 196 | assert session_id not in detector.recording_sessions 197 | 198 | 199 | @pytest.mark.asyncio 200 | async def test_volume_threshold_change(detector, test_output_dir): 201 | """ 202 | Verify that when volume_db_threshold is changed, amplitude_threshold is recalculated correctly, 203 | and that audio just below the new threshold does not start recording while audio just above it does. 204 | """ 205 | session_id = "test_threshold_change" 206 | 207 | detector.volume_db_threshold = -30.0 208 | new_amp_threshold = 32767 * (10 ** (-30.0 / 20.0)) # 1036.183520907373 209 | assert abs(detector.amplitude_threshold - new_amp_threshold) < 1.0 210 | 211 | # Under the threshold 212 | samples = generate_samples(amplitude=1000, num_samples=3200) 213 | await detector.process_samples(samples, session_id=session_id) 214 | 215 | session = detector.get_session(session_id) 216 | assert session.is_recording is False 217 | 218 | # Over the threshold 219 | samples = generate_samples(amplitude=1100, num_samples=3200) 220 | await detector.process_samples(samples, session_id=session_id) 221 | 222 | session = detector.get_session(session_id) 223 | assert session.is_recording is True 224 | 225 | 226 | def test_session_data(detector): 227 | session_id_1 = "session_id_1" 228 | session_id_2 = "session_id_2" 229 | 230 | detector.set_session_data(session_id_1, "key", "val") 231 | assert detector.recording_sessions.get(session_id_1) is None 232 | 233 | detector.set_session_data(session_id_1, "key1", "val1", create_session=True) 234 | assert detector.recording_sessions.get(session_id_1).data == {"key1": "val1"} 235 | detector.set_session_data(session_id_1, "key2", "val2") 236 | assert detector.recording_sessions.get(session_id_1).data == {"key1": "val1", "key2": "val2"} 237 | 238 | assert detector.recording_sessions.get(session_id_2) is None 239 | --------------------------------------------------------------------------------