├── tests ├── pytest.ini ├── resources │ ├── ocr-test.png │ ├── tts-test.wav │ ├── imgcap-test.png │ └── text.txt ├── test_typegaurd.py ├── util.py ├── test_qq.py ├── common │ └── test_gen.py ├── test_decorator.py ├── test_vec_db.py ├── services │ ├── test_live2d.py │ └── test_obs_client.py ├── test_zws.py ├── pipeline │ ├── test_ocr.py │ ├── test_vla.py │ ├── test_imgcap.py │ ├── test_abs_img.py │ ├── test_vec_db.py │ ├── test_abs.py │ ├── test_asr.py │ ├── test_perf.py │ ├── test_llm.py │ └── server.py ├── test_live_stream.py ├── test_reload.py ├── test_json.py ├── test_jws.py ├── devices │ └── test_mic.py ├── test_event_emitter.py ├── test_agent.py └── test_playground.py ├── .gitignore ├── common ├── utils │ ├── img_util.py │ ├── collection_util.py │ ├── math_util.py │ ├── time_util.py │ ├── str_util.py │ ├── file_util.py │ ├── enum_util.py │ ├── json_util.py │ ├── web_util.py │ └── audio_util.py ├── io │ ├── file_type.py │ ├── api.py │ └── file_sys.py ├── ver_check.py ├── collection │ └── limit_list.py ├── enumerator.py ├── concurrent │ ├── abs_runnable.py │ └── killable_thread.py ├── web │ ├── zrl_ws.py │ └── json_ws.py ├── generator │ └── config_gen.py └── decorator.py ├── services ├── playground │ ├── res │ │ └── config.py │ └── config.py ├── qqbot │ ├── config.py │ └── bridge.py ├── game │ ├── minecraft │ │ ├── data.py │ │ ├── instrcution │ │ │ ├── tool.py │ │ │ └── input.py │ │ └── app.py │ └── config.py ├── live2d │ ├── config.py │ ├── live2d_canvas.py │ └── live2d_viewer.py ├── browser │ ├── config.py │ ├── browser.py │ └── driver.py ├── obs │ └── config.py ├── config.py └── live_stream │ ├── config.py │ ├── twitch.py │ └── youtube.py ├── webui.py ├── main.py ├── pipeline ├── db │ └── milvus │ │ ├── config.py │ │ ├── milvus_async.py │ │ └── milvus_sync.py ├── vla │ ├── config.py │ └── showui │ │ ├── config.py │ │ ├── showui_sync.py │ │ └── showui_async.py ├── ocr │ ├── ocr_async.py │ ├── config.py │ └── ocr_sync.py ├── imgcap │ ├── imgcap_async.py │ ├── imgcap_sync.py │ └── config.py ├── vidcap │ ├── config.py │ ├── vidcap_sync.py │ └── vidcap_async.py ├── llm │ ├── llm_async.py │ ├── config.py │ └── llm_sync.py ├── tts │ ├── config.py │ ├── tts_async.py │ ├── tts_sync.py │ └── baidu_tts.py ├── base │ ├── config.py │ └── base_async.py └── asr │ ├── config.py │ ├── asr_async.py │ ├── asr_sync.py │ └── baidu_asr.py ├── requirements.txt ├── agent ├── tool │ ├── microphone_tool.py │ ├── lang_changer.py │ ├── go_creator.py │ └── web_search.py ├── custom_agent.py └── adaptor.py ├── LICENSE ├── manager ├── model_manager.py ├── llm_prompt_manager.py └── config_manager.py ├── character ├── filter │ └── strategy.py └── config.py ├── event └── registry.py ├── config.py └── devices ├── speaker.py ├── screen.py └── microphone.py /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | asyncio_default_fixture_loop_scope = function -------------------------------------------------------------------------------- /tests/resources/ocr-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkagawaTsurunaki/ZerolanLiveRobot/HEAD/tests/resources/ocr-test.png -------------------------------------------------------------------------------- /tests/resources/tts-test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AkagawaTsurunaki/ZerolanLiveRobot/HEAD/tests/resources/tts-test.wav 
--------------------------------------------------------------------------------
/tests/resources/imgcap-test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkagawaTsurunaki/ZerolanLiveRobot/HEAD/tests/resources/imgcap-test.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /.tmp
3 | /.idea
4 | /.vscode
5 | /assets
6 | resources/config.yaml
7 | /resources
8 | /data
9 | /.temp
10 | config.yaml
11 | /services/playground/proto/*_pb2*.py
--------------------------------------------------------------------------------
/common/utils/img_util.py:
--------------------------------------------------------------------------------
1 | from PIL.Image import Image
2 | 
3 | 
4 | def is_image_uniform(img: Image):
5 |     gray_img = img.convert('L')
6 |     min_value, max_value = gray_img.getextrema()
7 |     return min_value == max_value
--------------------------------------------------------------------------------
/common/utils/collection_util.py:
--------------------------------------------------------------------------------
1 | def to_value_list(d: dict):
2 |     if isinstance(d, dict):
3 |         ret = []
4 |         for k, v in d.items():
5 |             ret.append(v)
6 |         return ret
7 |     else:
8 |         raise TypeError("A dict is required")
--------------------------------------------------------------------------------
/services/playground/res/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | 
4 | class ResourceServerConfig(BaseModel):
5 |     host: str = Field("0.0.0.0", description="The host address of the resource server")
6 |     port: int = Field(8899, description="The port number of the resource server")
--------------------------------------------------------------------------------
/common/io/file_type.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | 
4 | class AudioFileType(str, Enum):
5 |     FLV = 'flv'
6 |     WAV = 'wav'
7 |     OGG = 'ogg'
8 |     MP3 = 'mp3'
9 |     RAW = 'raw'
10 | 
11 | 
12 | class ImageFileType(str, Enum):
13 |     PNG = 'png'
14 |     JPEG = 'jpeg'
15 |     JPG = 'jpg'
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
1 | from manager.config_manager import get_config
2 | from common.generator.gradio_gen import DynamicConfigPage
3 | 
4 | 
5 | def _start_config_webui():
6 |     config = get_config()
7 |     page = DynamicConfigPage(config)
8 |     page.launch()
9 | 
10 | 
11 | if __name__ == '__main__':
12 |     _start_config_webui()
--------------------------------------------------------------------------------
/services/qqbot/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | 
4 | class QQBotBridgeConfig(BaseModel):
5 |     enable: bool = Field(True, description="Whether to enable the QQBotBridge.")
6 |     host: str = Field("0.0.0.0", description="The host address of the QQBotBridge service.")
7 |     port: int = Field(11014, description="The port number of the QQBotBridge service.")
--------------------------------------------------------------------------------
/common/utils/math_util.py:
--------------------------------------------------------------------------------
1 | def clamp(_min, _max, value):
2 |     """
3 |     Clamp `value` into the closed interval [_min, _max].
4 |     Args:
5 |         _min: The lower bound.
6 |         _max: The upper bound.
7 |         value: The value to clamp.
8 | 
9 |     Returns:
10 |         A value within the interval [_min, _max].
11 |     """
12 |     if value < _min:
13 |         return _min
14 |     elif value > _max:
15 |         return _max
16 |     else:
17 |         return value
--------------------------------------------------------------------------------
/services/game/minecraft/data.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | from pydantic import BaseModel
4 | 
5 | 
6 | @dataclass
7 | class BotOption:
8 |     host: str
9 |     port: int
10 |     username: str
11 |     version: str
12 |     masterName: str
13 | 
14 | 
15 | class KonekoProtocol(BaseModel):
16 |     protocol: str = "Koneko Protocol"
17 |     version: str = "0.2"
18 |     event: str = None
19 |     data: dict | list = None
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | 
3 | from bot import ZerolanLiveRobot
4 | from loguru import logger
5 | 
6 | from common.concurrent.abs_runnable import stop_all_runnable
7 | 
8 | 
9 | async def main():
10 |     try:
11 |         bot = ZerolanLiveRobot()
12 |         await bot.start()
13 |         bot.stop()
14 |     except Exception as e:
15 |         logger.exception(e)
16 |         logger.error("❌️ Zerolan Live Robot exited abnormally!")
17 | 
18 | 
19 | if __name__ == '__main__':
20 |     asyncio.run(main())
--------------------------------------------------------------------------------
/pipeline/db/milvus/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | from pipeline.db.milvus.milvus_sync import MilvusDatabaseConfig
4 | 
5 | 
6 | #########
7 | # VecDB #
8 | #########
9 | 
10 | class VectorDBConfig(BaseModel):
11 |     enable: bool = Field(default=True, description="Whether the Vector Database is enabled.")
12 |     milvus: MilvusDatabaseConfig = Field(default=MilvusDatabaseConfig(),
13 |                                          description="Configuration for the Milvus Database.")
--------------------------------------------------------------------------------
/tests/test_typegaurd.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from typeguard import typechecked, TypeCheckError
3 | 
4 | 
5 | def test_typechecked():
6 |     @typechecked
7 |     def add_numbers(a: int, b: int) -> int:
8 |         return a + b
9 | 
10 |     assert add_numbers(2, 3) == 5
11 | 
12 |     with pytest.raises(TypeCheckError):
13 |         add_numbers(2, "3")
14 | 
15 |     @typechecked
16 |     def return_string(a: int) -> int:
17 |         return str(a)
18 | 
19 |     with pytest.raises(TypeCheckError):
20 |         return_string(5)
--------------------------------------------------------------------------------
/pipeline/vla/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | from common.enumerator import BaseEnum
4 | from pipeline.vla.showui.config import ShowUIConfig
5 | 
6 | 
7 | class VLAModelIdEnum(BaseEnum):
8 |     ShowUI: str = "showlab/ShowUI-2B"
9 | 
10 | 
11 | class VLAPipelineConfig(BaseModel):
12 |     showui: ShowUIConfig = Field(default=ShowUIConfig(), description="Configuration for the ShowUI component.")
13 |     enable: bool = Field(default=True, description="Whether the Visual Language Action pipeline is enabled.")
--------------------------------------------------------------------------------
/tests/util.py:
-------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import TaskGroup 3 | 4 | 5 | async def syncwait(bridge): 6 | while not bridge.is_connected: 7 | await asyncio.sleep(0.1) 8 | 9 | 10 | async def connect(bridge, auto_close_flag: bool = False): 11 | async with TaskGroup() as tg: 12 | start_task = tg.create_task(bridge.start()) 13 | if auto_close_flag: 14 | await asyncio.sleep(2) 15 | print("Closing the WebSocket server") 16 | await bridge.stop() 17 | start_task.cancel() 18 | -------------------------------------------------------------------------------- /common/utils/time_util.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from typeguard import typechecked 4 | 5 | 6 | @typechecked 7 | def get_time_string() -> str: 8 | """ 9 | Get current time string. 10 | :return: String like `20250411091422`. 11 | """ 12 | current_time = datetime.now() 13 | time_str = current_time.strftime("%Y%m%d%H%M%S") 14 | return time_str 15 | 16 | 17 | @typechecked 18 | def get_time_iso_string() -> str: 19 | now = datetime.now() 20 | formatted_date = now.isoformat() 21 | return formatted_date 22 | -------------------------------------------------------------------------------- /common/ver_check.py: -------------------------------------------------------------------------------- 1 | import pydantic 2 | 3 | 4 | def check_pydantic_ver(): 5 | """ 6 | Check whether pydantic version is ok. 7 | Warning: PydanticDeprecatedSince211: Accessing the 'model_fields' attribute on the instance is deprecated. 8 | Instead, you should access this attribute from the model class. 9 | Deprecated in Pydantic V2.11 to be removed in V3.0. 10 | """ 11 | # 12 | pydantic_ver = pydantic.version.VERSION.split(".") 13 | if not (int(pydantic_ver[0]) <= 2 and int(pydantic_ver[1]) <= 11): 14 | raise Exception("Too high version of Pydantic, try install pydantic<=2.11") -------------------------------------------------------------------------------- /tests/test_qq.py: -------------------------------------------------------------------------------- 1 | from asyncio import TaskGroup 2 | 3 | import pytest 4 | 5 | from config import get_config 6 | from services.qqbot.bridge import QQBotBridge 7 | from util import connect, syncwait 8 | 9 | _config = get_config() 10 | _bridge = QQBotBridge(_config.service.qqbot) 11 | 12 | auto_close_flag = False 13 | 14 | 15 | # You can test the connection to ZerolanPlayground 16 | # All test cases will depend on this function so make sure you test it at first. 
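# For example: pytest tests/test_qq.py::test_conn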
17 | @pytest.mark.asyncio
18 | async def test_conn():
19 |     async with TaskGroup() as tg:
20 |         tg.create_task(connect(_bridge, auto_close_flag))
21 |         await syncwait(_bridge)
--------------------------------------------------------------------------------
/tests/common/test_gen.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | 
3 | from common.generator.config_gen import ConfigFileGenerator
4 | from config import ZerolanLiveRobotConfig
5 | 
6 | 
7 | def read_yaml(path: str):
8 |     with open(path, mode="r", encoding="utf-8") as f:
9 |         return yaml.safe_load(f)
10 | 
11 | 
12 | def test_gen():
13 |     g = ConfigFileGenerator()
14 |     s = g.generate_yaml(ZerolanLiveRobotConfig())
15 |     print(s)
16 |     path = r"D:\AkagawaTsurunaki\WorkSpace\PythonProjects\ZerolanLiveRobot\resources\t.yaml"
17 |     with open(path, mode="w+",
18 |               encoding="utf-8") as f:
19 |         f.write(s)
20 |     d = read_yaml(path)
21 |     print(d)
--------------------------------------------------------------------------------
/pipeline/vla/showui/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import Field
2 | 
3 | from pipeline.base.base_sync import AbstractPipelineConfig
4 | 
5 | 
6 | #######
7 | # VLA #
8 | #######
9 | 
10 | class ShowUIConfig(AbstractPipelineConfig):
11 |     model_id: str = Field(default="showlab/ShowUI-2B", description="The ID of the model used for the UI.", frozen=True)
12 |     predict_url: str = Field(default="http://127.0.0.1:11000/vla/showui/predict",
13 |                              description="The URL for UI prediction requests.")
14 |     stream_predict_url: str = Field(default="http://127.0.0.1:11000/vla/showui/stream-predict",
15 |                                     description="The URL for streaming UI prediction requests.")
--------------------------------------------------------------------------------
/services/live2d/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | 
4 | class Live2DViewerConfig(BaseModel):
5 |     enable: bool = Field(default=True,
6 |                          description="Enable Live2D Viewer?")
7 |     model3_json_file: str = Field(default="./resources/static/models/live2d", description="Path of `xxx.model3.json`")
8 |     auto_lip_sync: bool = Field(default=True, description="Auto lip sync.")
9 |     auto_blink: bool = Field(default=True, description="Auto eye blink.")
10 |     auto_breath: bool = Field(default=True, description="Auto breath.")
11 |     win_height: int = Field(default=True, description="Window height.")
12 |     win_width: int = Field(default=True, description="Window width.")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # If you are on the dev branch, uncomment the following line
2 | # git+https://github.com/AkagawaTsurunaki/zerolan-data.git@dev
3 | git+https://github.com/AkagawaTsurunaki/zerolan-data.git@1.5.0
4 | pyaudio
5 | soundfile
6 | pydub
7 | pillow
8 | scipy
9 | pyautogui
10 | pygame
11 | pyyaml
12 | flask
13 | selenium
14 | numpy==1.26.4
15 | retry
16 | requests
17 | loguru
18 | websockets
19 | langchain
20 | langgraph
21 | langchain-core
22 | beautifulsoup4
23 | bilibili-api-python
24 | twitchAPI
25 | python-magic
26 | transitions
27 | netifaces
28 | injector
29 | typeguard
30 | watchdog
31 | openai
32 | gradio
33 | webrtcvad
34 | live2d-py
35 | PyQt5
36 | librosa
37 | # If you are not a developer, you can skip the following packages
38 | pytest
39 | pytest-asyncio
--------------------------------------------------------------------------------
/agent/tool/microphone_tool.py:
--------------------------------------------------------------------------------
1 | from typing import Type, Any
2 | 
3 | from langchain_core.tools import BaseTool
4 | from pydantic import BaseModel, Field
5 | 
6 | from event.event_data import DeviceMicrophoneSwitchEvent
7 | from event.event_emitter import emitter
8 | 
9 | 
10 | class MicrophoneToolInput(BaseModel):
11 |     switch: bool = Field(description="`true` to open the microphone, `false` to close the microphone.")
12 | 
13 | 
14 | class MicrophoneTool(BaseTool):
15 |     name: str = "麦克风控制器"
16 |     description: str = "当用户要求打开或关闭麦克风时,使用此工具"
17 |     args_schema: Type[BaseModel] = MicrophoneToolInput
18 |     return_direct: bool = True
19 | 
20 |     def _run(self, switch: bool) -> Any:
21 |         emitter.emit(DeviceMicrophoneSwitchEvent(switch=switch))
--------------------------------------------------------------------------------
/tests/test_decorator.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | from loguru import logger
4 | 
5 | from common.decorator import log_run_time
6 | 
7 | 
8 | @log_run_time(time_limit=5)
9 | def slow_function():
10 |     """
11 |     Simulates a slow-running function.
12 |     """
13 |     time.sleep(10)  # Simulate a time-consuming operation
14 |     return "Task completed"
15 | 
16 | 
17 | @log_run_time(time_limit=5)
18 | def fast_function():
19 |     """
20 |     Simulates a fast-running function.
21 |     """
22 |     time.sleep(2)  # Simulate a time-consuming operation
23 |     return "Task completed quickly"
24 | 
25 | 
26 | def test_slow_func():
27 |     logger.info("Starting slow function")
28 |     result = slow_function()
29 |     logger.info(result)
30 | 
31 | 
32 | def test_fast_func():
33 |     logger.info("Starting fast function")
34 |     result = fast_function()
35 |     logger.info(result)
--------------------------------------------------------------------------------
/services/game/config.py:
--------------------------------------------------------------------------------
1 | from enum import unique
2 | 
3 | from pydantic import BaseModel, Field
4 | 
5 | from common.enumerator import BaseEnum
6 | from common.utils.enum_util import enum_to_markdown
7 | 
8 | 
9 | @unique
10 | class PlatformEnum(BaseEnum):
11 |     Minecraft: str = "minecraft"
12 | 
13 | 
14 | class GameBridgeConfig(BaseModel):
15 |     enable: bool = Field(True, description="Whether to enable the GameBridge.")
16 |     host: str = Field("127.0.0.1", description="The host address of the GameBridge service.")
17 |     port: int = Field(11007, description="The port number of the GameBridge service.")
18 |     platform: PlatformEnum = Field(PlatformEnum.Minecraft,
19 |                                    description=f"The platform you want to connect to. 
{enum_to_markdown(PlatformEnum)}") 20 | -------------------------------------------------------------------------------- /tests/test_vec_db.py: -------------------------------------------------------------------------------- 1 | from config import get_config 2 | from zerolan.data.pipeline.milvus import InsertRow, MilvusInsert, MilvusQuery 3 | 4 | from pipeline.db.milvus.milvus_sync import MilvusSyncPipeline 5 | 6 | _config = get_config() 7 | pipeline = MilvusSyncPipeline(config=_config.pipeline.vec_db.milvus) 8 | 9 | 10 | def test_insert(): 11 | texts = ["onani就是0721", "柚子厨真恶心", "我喜欢阿米诺!", "0721就是无吟唱水魔法的意思"] 12 | texts = [InsertRow(id=i, text=texts[i], subject="history") for i in range(len(texts))] 13 | mi = MilvusInsert(collection_name="Test", texts=texts, drop_if_exists=True) 14 | 15 | ir = pipeline.insert(mi) 16 | print(ir) 17 | 18 | 19 | def test_search(): 20 | mq = MilvusQuery(collection_name="Test", limit=2, output_fields=["text", 'history'], query="0721是什么意思?") 21 | qr = pipeline.search(mq) 22 | print(qr) 23 | -------------------------------------------------------------------------------- /common/utils/str_util.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from common.enumerator import Language 4 | 5 | 6 | def is_blank(s: str) -> bool: 7 | """Check if a string is None, empty, or contains only whitespace.""" 8 | return s is None or not s.strip() or s == "" 9 | 10 | 11 | def split_by_punc(text: str, lang: Language) -> List[str]: 12 | if lang == Language.ZH: 13 | cut_punc = ",。!?" 14 | elif lang == Language.JA: 15 | cut_punc = "、。!?" 16 | else: 17 | cut_punc = ",.!?" 18 | 19 | def punc_cut(text: str, punc: str): 20 | texts = [] 21 | last = -1 22 | for i in range(len(text)): 23 | if text[i] in punc: 24 | try: 25 | texts.append(text[last + 1: i]) 26 | except IndexError: 27 | continue 28 | last = i 29 | return texts 30 | 31 | return punc_cut(text, cut_punc) 32 | -------------------------------------------------------------------------------- /tests/services/test_live2d.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import time 3 | from pathlib import Path 4 | 5 | from common.concurrent.killable_thread import KillableThread 6 | from manager.config_manager import get_project_dir 7 | from services.live2d.live2d_viewer import Live2DViewer 8 | from services.live2d.config import Live2DViewerConfig 9 | 10 | _dir = get_project_dir() 11 | _model_file = os.path.join(_dir, 12 | r"resources/static/models/live2d/hiyori_pro_zh/hiyori_pro_zh/runtime/hiyori_pro_t11.model3.json") 13 | _audio_path = "resources/tts-test.wav" 14 | 15 | 16 | def test_live2d(): 17 | config = Live2DViewerConfig( 18 | model3_json_file=_model_file) 19 | viewer = Live2DViewer(config) 20 | t = KillableThread(target=viewer.start) 21 | t.start() 22 | time.sleep(1) 23 | viewer.sync_lip(Path(_audio_path)) 24 | time.sleep(2) 25 | viewer.sync_lip(Path(_audio_path)) 26 | time.sleep(10) 27 | t.kill() 28 | -------------------------------------------------------------------------------- /tests/test_zws.py: -------------------------------------------------------------------------------- 1 | from zerolan.data.protocol.protocol import ZerolanProtocol 2 | 3 | from common.web.zrl_ws import ZerolanProtocolWsServer 4 | 5 | _test = { 6 | "protocol": "ZerolanProtocol", 7 | "version": "1.1", 8 | "message": "Ciallo", 9 | "code": 0, 10 | "action": "Onanii", 11 | "data": { 12 | "frequency": 114514, 13 | "hand": "right" 
14 | } 15 | } 16 | 17 | 18 | class TestZwsImpl(ZerolanProtocolWsServer): 19 | 20 | def on_protocol(self, protocol: ZerolanProtocol): 21 | print(protocol) 22 | assert protocol.data["frequency"] == _test["data"]["frequency"] 23 | assert protocol.data["hand"] == _test["data"]["hand"] 24 | self.send(action="Ciallo", message="Kimochii~", data={"aieki": 100, "love": 100}) 25 | 26 | def on_disconnect(self, ws_id: str): 27 | print(ws_id) 28 | 29 | 30 | def test_zws(): 31 | server = TestZwsImpl(host='127.0.0.1', port=11013) 32 | server.start() 33 | -------------------------------------------------------------------------------- /pipeline/vla/showui/showui_sync.py: -------------------------------------------------------------------------------- 1 | from zerolan.data.pipeline.vla import ShowUiQuery, ShowUiPrediction 2 | 3 | from pipeline.base.base_sync import AbstractImagePipeline 4 | from pipeline.vla.showui.config import ShowUIConfig 5 | 6 | 7 | class ShowUISyncPipeline(AbstractImagePipeline): 8 | 9 | def __init__(self, config: ShowUIConfig): 10 | super().__init__(config) 11 | 12 | def predict(self, query: ShowUiQuery) -> ShowUiPrediction | None: 13 | assert isinstance(query, ShowUiQuery) 14 | return super().predict(query) 15 | 16 | def stream_predict(self, query: ShowUiQuery, chunk_size: int | None = None): 17 | assert isinstance(query, ShowUiQuery) 18 | raise NotImplementedError() 19 | 20 | def parse_query(self, query: any) -> dict: 21 | return super().parse_query(query=query) 22 | 23 | def parse_prediction(self, json_val: any) -> ShowUiPrediction: 24 | return ShowUiPrediction.model_validate_json(json_val) 25 | -------------------------------------------------------------------------------- /pipeline/ocr/ocr_async.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from typeguard import typechecked 4 | from zerolan.data.pipeline.ocr import OCRQuery, OCRPrediction 5 | 6 | from pipeline.base.base_async import BaseAsyncPipeline, _parse_imgcap_query, get_base_url 7 | from pipeline.ocr.config import OCRPipelineConfig, OCRModelIdEnum 8 | 9 | 10 | class OCRAsyncPipeline(BaseAsyncPipeline): 11 | def __init__(self, config: OCRPipelineConfig): 12 | super().__init__(base_url=get_base_url(config.predict_url)) 13 | self._model_id: OCRModelIdEnum = config.model_id 14 | self._predict_endpoint = "/ocr/predict" 15 | self._stream_predict_endpoint = "/ocr/stream-predict" 16 | 17 | @typechecked 18 | async def predict(self, query: OCRQuery) -> OCRPrediction: 19 | data = _parse_imgcap_query(query) 20 | async with self.session.post(self._predict_endpoint, data=data) as resp: 21 | return await resp.json(encoding='utf8', loads=OCRPrediction.model_validate_json) 22 | -------------------------------------------------------------------------------- /pipeline/vla/showui/showui_async.py: -------------------------------------------------------------------------------- 1 | from typeguard import typechecked 2 | from zerolan.data.pipeline.vla import ShowUiPrediction, ShowUiQuery 3 | 4 | from pipeline.base.base_async import BaseAsyncPipeline, get_base_url 5 | from pipeline.vla.config import VLAModelIdEnum 6 | from pipeline.vla.showui.config import ShowUIConfig 7 | 8 | 9 | class ShowUIAsyncPipeline(BaseAsyncPipeline): 10 | def __init__(self, config: ShowUIConfig): 11 | super().__init__(base_url=get_base_url(config.predict_url)) 12 | assert str(config.model_id) == VLAModelIdEnum.ShowUI.value, f"Model ID is wrong." 
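        # NOTE: `ShowUIConfig.model_id` is frozen to "showlab/ShowUI-2B", so this
        # assert only fires if the config default and `VLAModelIdEnum` drift apart.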
13 | self._model_id: VLAModelIdEnum = VLAModelIdEnum.ShowUI 14 | self._predict_endpoint = "/vla/showui/predict" 15 | 16 | @typechecked 17 | async def predict(self, query: ShowUiQuery) -> ShowUiPrediction: 18 | async with self.session.post(self._predict_endpoint, json=query.model_dump()) as resp: 19 | return await resp.json(encoding='utf8', loads=ShowUiPrediction.model_validate_json) 20 | -------------------------------------------------------------------------------- /pipeline/ocr/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | from pipeline.base.base_sync import AbstractPipelineConfig 6 | 7 | 8 | ####### 9 | # OCR # 10 | ####### 11 | 12 | class OCRModelIdEnum(BaseEnum): 13 | PaddleOCR = 'paddlepaddle/PaddleOCR' 14 | 15 | 16 | class OCRPipelineConfig(AbstractPipelineConfig): 17 | model_id: OCRModelIdEnum = Field(default=OCRModelIdEnum.PaddleOCR, 18 | description=f"The ID of the model used for OCR. \n" 19 | f"{enum_to_markdown(OCRModelIdEnum)}") 20 | predict_url: str = Field(default="http://127.0.0.1:11000/ocr/predict", 21 | description="The URL for OCR prediction requests.") 22 | stream_predict_url: str = Field(default="http://127.0.0.1:11000/ocr/stream-predict", 23 | description="The URL for streaming OCR prediction requests.") 24 | -------------------------------------------------------------------------------- /common/utils/file_util.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from pathlib import Path 3 | from uuid import uuid4 4 | 5 | from typeguard import typechecked 6 | 7 | from services.playground.data import FileInfo 8 | 9 | 10 | @typechecked 11 | def get_file_info(path: str) -> FileInfo: 12 | assert os.path.exists(path) 13 | file_extension = Path(path).suffix 14 | file_name = Path(path).name 15 | # file_name = file_name[:len(file_name) - len(file_extension)] 16 | if file_extension is not None and len(file_extension) > 1: 17 | file_extension = file_extension[1:] 18 | file_size = os.path.getsize(path) 19 | return FileInfo( 20 | file_id=f"{uuid4()}", 21 | uri=path_to_uri(path), 22 | file_type=file_extension, 23 | origin_file_name=file_name, 24 | file_name=file_name, 25 | file_size=file_size, 26 | ) 27 | 28 | 29 | def path_to_uri(path): 30 | path = os.path.abspath(path) 31 | path = path.replace('\\', '/') 32 | uri = f'file:///{path}' 33 | return uri 34 | -------------------------------------------------------------------------------- /pipeline/imgcap/imgcap_async.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from typeguard import typechecked 4 | from zerolan.data.pipeline.img_cap import ImgCapQuery, ImgCapPrediction 5 | 6 | from pipeline.base.base_async import BaseAsyncPipeline, _parse_imgcap_query, get_base_url 7 | from pipeline.imgcap.config import ImgCapPipelineConfig, ImgCapModelIdEnum 8 | 9 | 10 | class ImgCapAsyncPipeline(BaseAsyncPipeline): 11 | def __init__(self, config: ImgCapPipelineConfig): 12 | super().__init__(base_url=get_base_url(config.predict_url)) 13 | self._model_id: ImgCapModelIdEnum = config.model_id 14 | self._predict_endpoint = "/img-cap/predict" 15 | self._stream_predict_endpoint = "/img-cap/stream-predict" 16 | 17 | @typechecked 18 | async def predict(self, query: ImgCapQuery) -> ImgCapPrediction: 19 | data = _parse_imgcap_query(query) 20 | 
async with self.session.post(self._predict_endpoint, data=data) as resp: 21 | return await resp.json(encoding='utf8', loads=ImgCapPrediction.model_validate_json) 22 | -------------------------------------------------------------------------------- /pipeline/imgcap/imgcap_sync.py: -------------------------------------------------------------------------------- 1 | from requests import Response 2 | from zerolan.data.pipeline.img_cap import ImgCapQuery, ImgCapPrediction 3 | 4 | from pipeline.base.base_sync import AbstractImagePipeline 5 | from pipeline.imgcap.config import ImgCapPipelineConfig 6 | 7 | 8 | class ImgCapSyncPipeline(AbstractImagePipeline): 9 | 10 | def __init__(self, config: ImgCapPipelineConfig): 11 | super().__init__(config) 12 | 13 | def predict(self, query: ImgCapQuery) -> ImgCapPrediction | None: 14 | assert isinstance(query, ImgCapQuery) 15 | return super().predict(query) 16 | 17 | def stream_predict(self, query: ImgCapQuery, chunk_size: int | None = None): 18 | assert isinstance(query, ImgCapQuery) 19 | raise NotImplementedError() 20 | 21 | def parse_query(self, query: any) -> dict: 22 | return super().parse_query(query) 23 | 24 | def parse_prediction(self, response: Response) -> ImgCapPrediction: 25 | json_val = response.content 26 | return ImgCapPrediction.model_validate_json(json_val) 27 | -------------------------------------------------------------------------------- /agent/tool/lang_changer.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Any 2 | 3 | from injector import inject 4 | from langchain_core.tools import BaseTool 5 | from pydantic import BaseModel, Field 6 | 7 | from common.enumerator import Language 8 | from common.utils.enum_util import enum_members_to_list 9 | from event.event_data import LanguageChangeEvent 10 | from event.event_emitter import emitter 11 | 12 | 13 | class LangChangeInput(BaseModel): 14 | target_lang: str = Field(description=f"Target language: {enum_members_to_list(Language)}, Only return ") 15 | 16 | 17 | class LangChanger(BaseTool): 18 | name: str = "语言切换" 19 | description: str = "当用户需要切换语言时,使用此工具" 20 | args_schema: Type[BaseModel] = LangChangeInput 21 | return_direct: bool = True 22 | 23 | @inject 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def _run(self, target_lang: Language) -> Any: 28 | if isinstance(target_lang, str): 29 | target_lang = Language.value_of(target_lang) 30 | emitter.emit(LanguageChangeEvent(target_lang=target_lang)) 31 | -------------------------------------------------------------------------------- /pipeline/imgcap/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | from pipeline.base.base_sync import AbstractPipelineConfig 6 | 7 | 8 | ########## 9 | # ImgCap # 10 | ########## 11 | 12 | class ImgCapModelIdEnum(BaseEnum): 13 | Blip = 'Salesforce/blip-image-captioning-large' 14 | 15 | 16 | class ImgCapPipelineConfig(AbstractPipelineConfig): 17 | model_id: ImgCapModelIdEnum = Field(default=ImgCapModelIdEnum.Blip, 18 | description=f"The ID of the model used for image captioning. 
" 19 | f"\n{enum_to_markdown(ImgCapModelIdEnum)}") 20 | predict_url: str = Field(default="http://127.0.0.1:11000/img-cap/predict", 21 | description="The URL for image captioning prediction requests.") 22 | stream_predict_url: str = Field(default="http://127.0.0.1:11000/img-cap/stream-predict", 23 | description="The URL for streaming image captioning prediction requests.") 24 | -------------------------------------------------------------------------------- /tests/services/test_obs_client.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | from config import get_config 5 | from common.concurrent.killable_thread import KillableThread 6 | from services.obs.client import ObsStudioWsClient 7 | 8 | s = """哈基米像一颗巨石按在胸口 9 | 哈基米阿西噶嗨呀库纳鲁多米吉哈 10 | 哦马吉里哈基米喔纳梅鲁多阿西噶 11 | 叮鸡咚鸡叮咚叮鸡zakozako onina 12 | 恐龙抗狼一段一段下来出示健康码 13 | 哈基米阿西噶嗨呀库纳鲁多米吉哈 14 | 哦马吉里哈基米喔纳梅鲁多阿西噶 15 | 叮鸡咚鸡叮咚叮鸡zakozako onina 16 | 恐龙抗狼一段一段下来出示健康码 17 | 阿西噶喔纳美噜多几点起床妈妈酱 18 | 曼波哈基好胖可爱楼上下去草哈基 19 | 离原上咪一岁一咪打野来搞小啾啾 20 | 野火哈咪春风吹咪小白手套ccb! 21 | 哈基米阿西噶嗨呀库纳鲁多米吉哈 22 | 哦马吉里哈基米喔纳梅鲁多阿西噶 23 | 叮鸡咚鸡叮咚叮鸡zakozako onina 24 | 哎哟我滴妈! 25 | """ 26 | _config = get_config() 27 | 28 | _client = ObsStudioWsClient(_config.service.obs) 29 | 30 | 31 | def test_conn(): 32 | client_thread = KillableThread(target=_client.start, daemon=True) 33 | client_thread.start() 34 | time.sleep(1) 35 | if _client.is_connected: 36 | _client.subtitle(s, "assistant", 3) 37 | 38 | time.sleep(4) 39 | assert _client.is_connected, "Test failed!" 40 | threading.Thread(target=_client.stop, daemon=True).start() 41 | 42 | client_thread.kill() 43 | -------------------------------------------------------------------------------- /pipeline/vidcap/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | from pipeline.base.base_sync import AbstractPipelineConfig 6 | 7 | 8 | ########## 9 | # VidCap # 10 | ########## 11 | 12 | class VidCapModelIdEnum(BaseEnum): 13 | Hitea = 'iic/multi-modal_hitea_video-captioning_base_en' 14 | 15 | 16 | class VidCapPipelineConfig(AbstractPipelineConfig): 17 | model_id: VidCapModelIdEnum = Field(default=VidCapModelIdEnum.Hitea, 18 | description=f"The ID of the model used for video captioning. 
\n" 19 | f"{enum_to_markdown(VidCapModelIdEnum)}") 20 | predict_url: str = Field(default="http://127.0.0.1:11000/vid_cap/predict", 21 | description="The URL for video captioning prediction requests.") 22 | stream_predict_url: str = Field(default="http://127.0.0.1:11000/vid-cap/stream-predict", 23 | description="The URL for streaming video captioning prediction requests.") 24 | -------------------------------------------------------------------------------- /pipeline/vidcap/vidcap_sync.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from requests import Response 4 | from zerolan.data.pipeline.vid_cap import VidCapQuery, VidCapPrediction 5 | 6 | from pipeline.base.base_sync import CommonModelPipeline 7 | from pipeline.vidcap.config import VidCapPipelineConfig 8 | 9 | 10 | class VidCapSyncPipeline(CommonModelPipeline): 11 | 12 | def __init__(self, config: VidCapPipelineConfig): 13 | """ 14 | 此接口保留,但是可能会在将来废弃而放弃维护 15 | :param config: 16 | """ 17 | super().__init__(config) 18 | 19 | def predict(self, query: VidCapQuery) -> VidCapPrediction | None: 20 | assert isinstance(query, VidCapQuery) 21 | assert os.path.exists(query.vid_path), f"视频路径不存在:{query.vid_path}" 22 | return super().predict(query) 23 | 24 | def stream_predict(self, query: VidCapQuery, chunk_size: int | None = None): 25 | assert isinstance(query, VidCapQuery) 26 | raise NotImplementedError() 27 | 28 | def parse_prediction(self, response: Response) -> VidCapPrediction: 29 | json_val = response.content 30 | return VidCapPrediction.model_validate(json_val) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AkagawaTsurunaki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /agent/tool/go_creator.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Type 3 | 4 | from injector import inject 5 | from langchain_core.tools import BaseTool 6 | from pydantic import BaseModel, Field 7 | 8 | from services.playground.data import CreateGameObjectResponse 9 | from services.playground.bridge import PlaygroundBridge 10 | 11 | 12 | class GameObjectCreatorInput(BaseModel): 13 | dto: CreateGameObjectResponse = Field(description="DTO of the game object will be created") 14 | 15 | 16 | class GameObjectCreator(BaseTool): 17 | name: str = "游戏对象创建器" 18 | description: str = "当用户要求你创建一个游戏对象(例如立方体、球体)的时候,使用此工具。" 19 | args_schema: Type[BaseModel] = GameObjectCreatorInput 20 | return_direct: bool = True 21 | 22 | @inject 23 | def __init__(self, bridge: PlaygroundBridge): 24 | super().__init__() 25 | self._bridge = bridge 26 | 27 | def _run(self, dto: CreateGameObjectResponse) -> Any: 28 | task = [asyncio.create_task(self._arun(dto))] 29 | asyncio.gather(*task) 30 | 31 | async def _arun(self, dto: CreateGameObjectResponse) -> None: 32 | await self._bridge.create_gameobject(dto) 33 | -------------------------------------------------------------------------------- /pipeline/vidcap/vidcap_async.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from typeguard import typechecked 4 | from zerolan.data.pipeline.vid_cap import VidCapQuery, VidCapPrediction 5 | 6 | from pipeline.base.base_async import BaseAsyncPipeline, get_base_url 7 | from pipeline.vidcap.config import VidCapPipelineConfig, VidCapModelIdEnum 8 | 9 | 10 | def _parse_vid_cap_query(query: VidCapQuery): 11 | if os.path.exists(query.vid_path): 12 | query.vid_path = os.path.abspath(query.vid_path).replace("\\", "/") 13 | return query 14 | 15 | 16 | class VidCapAsyncPipeline(BaseAsyncPipeline): 17 | def __init__(self, config: VidCapPipelineConfig): 18 | super().__init__(base_url=get_base_url(config.predict_url)) 19 | self._model_id: VidCapModelIdEnum = config.model_id 20 | self._predict_endpoint = "/vid-cap/predict" 21 | 22 | @typechecked 23 | async def predict(self, query: VidCapQuery) -> VidCapPrediction: 24 | data = _parse_vid_cap_query(query) 25 | async with self.session.post(self._predict_endpoint, json=data.model_dump()) as resp: 26 | return await resp.json(encoding='utf8', loads=VidCapPrediction.model_validate_json) 27 | -------------------------------------------------------------------------------- /pipeline/db/milvus/milvus_async.py: -------------------------------------------------------------------------------- 1 | from typeguard import typechecked 2 | from zerolan.data.pipeline.milvus import MilvusQuery, MilvusQueryResult, MilvusInsert, MilvusInsertResult 3 | 4 | from pipeline.base.base_async import BaseAsyncPipeline, get_base_url 5 | from pipeline.db.milvus.milvus_sync import MilvusDatabaseConfig 6 | 7 | 8 | class MilvusAsyncPipeline(BaseAsyncPipeline): 9 | def __init__(self, config: MilvusDatabaseConfig): 10 | super().__init__(base_url=get_base_url(config.search_url)) 11 | self._search_endpoint = "/milvus/search" 12 | self._insert_endpoint = "/milvus/insert" 13 | 14 | @typechecked 15 | async def search(self, query: MilvusQuery) -> MilvusQueryResult: 16 | async with self.session.post(self._search_endpoint, json=query.model_dump()) as resp: 17 | return await resp.json(encoding='utf8', 
loads=MilvusQueryResult.model_validate_json) 18 | 19 | @typechecked 20 | async def insert(self, insert: MilvusInsert) -> MilvusInsertResult: 21 | async with self.session.post(self._insert_endpoint, json=insert.model_dump()) as resp: 22 | return await resp.json(encoding='utf8', loads=MilvusInsertResult.model_validate_json) 23 | -------------------------------------------------------------------------------- /services/qqbot/bridge.py: -------------------------------------------------------------------------------- 1 | from injector import inject 2 | from zerolan.data.protocol.protocol import ZerolanProtocol 3 | 4 | from services.qqbot.config import QQBotBridgeConfig 5 | from common.web.zrl_ws import ZerolanProtocolWsServer 6 | from event.event_data import QQMessageEvent 7 | from event.event_emitter import emitter 8 | 9 | 10 | class _QQBotAction: 11 | GROUP_MESSAGE = "group_message" 12 | 13 | 14 | class QQBotBridge(ZerolanProtocolWsServer): 15 | 16 | @inject 17 | def __init__(self, config: QQBotBridgeConfig): 18 | host, port = config.host, config.port 19 | ZerolanProtocolWsServer.__init__(self, host=host, port=port) 20 | 21 | async def on_protocol(self, protocol: ZerolanProtocol): 22 | if protocol.action == _QQBotAction.GROUP_MESSAGE: 23 | msg = protocol.data["message"] 24 | group_id = protocol.data.get("group_id", None) 25 | emitter.emit(QQMessageEvent(message=msg, group_id=group_id)) 26 | 27 | def name(self): 28 | return "QQBotBridge" 29 | 30 | def send_plain_message(self, message: str, group: int): 31 | self.send(action="send_plain_text_in_group", data={ 32 | "group": group, 33 | "message": message 34 | }) 35 | -------------------------------------------------------------------------------- /manager/model_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | from loguru import logger 5 | 6 | from common.utils.file_util import get_file_info 7 | from services.playground.data import FileInfo 8 | 9 | 10 | class ModelManager: 11 | def __init__(self, model_dir=None): 12 | self._model_dir = R".\resources\static\models\3d" if model_dir is None else model_dir 13 | self._model_files: Dict[str, FileInfo] = {} 14 | 15 | def scan(self): 16 | for root, dirs, files in os.walk(self._model_dir): 17 | for file in files: 18 | filepath = os.path.join(root, file) 19 | file_info = get_file_info(filepath) 20 | self._model_files[file_info.file_id] = file_info 21 | logger.info(f"{file_info.file_name} is registered as {file_info.file_id}") 22 | 23 | def get_files(self) -> List[dict]: 24 | files = [] 25 | for _, file_info in self._model_files.items(): 26 | files.append({ 27 | "id": file_info.file_id, 28 | "filename": file_info.file_name, 29 | }) 30 | return files 31 | 32 | def get_file_by_id(self, file_id: str) -> FileInfo: 33 | return self._model_files[file_id] 34 | -------------------------------------------------------------------------------- /tests/pipeline/test_ocr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from zerolan.data.pipeline.ocr import OCRQuery 3 | 4 | from manager.config_manager import get_config 5 | from pipeline.ocr.ocr_async import OCRAsyncPipeline 6 | from pipeline.ocr.ocr_sync import OCRSyncPipeline 7 | 8 | _config = get_config() 9 | _ocr_async = OCRAsyncPipeline(_config.pipeline.ocr) 10 | _ocr_sync = OCRSyncPipeline(_config.pipeline.ocr) 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def event_loop(event_loop_policy): 15 | # Needed 
to work with asyncpg
16 |     loop = event_loop_policy.new_event_loop()
17 |     yield loop
18 |     loop.close()
19 | 
20 | 
21 | def test_ocr_predict_sync():
22 |     query = OCRQuery(img_path="resources/ocr-test.png")
23 |     prediction = _ocr_sync.predict(query)
24 |     assert prediction, "Test failed."
25 |     print(prediction.model_dump_json())
26 |     assert "我是赤川鹤鸣" in prediction.model_dump_json()
27 | 
28 | 
29 | @pytest.mark.asyncio
30 | async def test_ocr_predict_async():
31 |     query = OCRQuery(img_path="resources/ocr-test.png")
32 |     prediction = await _ocr_async.predict(query)
33 |     assert prediction, "Test failed: No response."
34 |     print(prediction.model_dump_json())
35 |     assert "我是赤川鹤鸣" in prediction.model_dump_json(), "Test failed: Wrong result."
--------------------------------------------------------------------------------
/services/playground/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | from common.enumerator import BaseEnum
4 | from common.utils.enum_util import enum_to_markdown
5 | 
6 | 
7 | class DisplayModeEnum(BaseEnum):
8 |     AR: str = "ar"
9 |     Live2D: str = "live2D"
10 | 
11 | 
12 | class PlaygroundBridgeConfig(BaseModel):
13 |     enable: bool = Field(default=True,
14 |                          description="Whether to enable PlaygroundBridge WebSocket server.")
15 |     host: str = Field(default="0.0.0.0",
16 |                       description="The host address of the PlaygroundBridge server.")
17 |     port: int = Field(default=11013,
18 |                       description="The port number of the PlaygroundBridge server.")
19 |     mode: DisplayModeEnum = Field(default=DisplayModeEnum.Live2D,
20 |                                   description=f"The display mode of the client. {enum_to_markdown(DisplayModeEnum)}")
21 |     bot_id: str = Field(default="",
22 |                         description="The ID of the bot. \n"
23 |                                     "You can set it to any value.")
24 |     model_dir: str = Field(default="./resources/static/models/live2d/",
25 |                            description="The path to the model directory.")
--------------------------------------------------------------------------------
/pipeline/llm/llm_async.py:
--------------------------------------------------------------------------------
1 | from typing import Generator
2 | 
3 | from typeguard import typechecked
4 | from zerolan.data.pipeline.llm import LLMQuery, LLMPrediction
5 | 
6 | from pipeline.base.base_async import BaseAsyncPipeline, get_base_url
7 | from pipeline.llm.config import LLMPipelineConfig, LLMModelIdEnum
8 | 
9 | 
10 | class LLMAsyncPipeline(BaseAsyncPipeline):
11 |     def __init__(self, config: LLMPipelineConfig):
12 |         super().__init__(base_url=get_base_url(config.predict_url))
13 |         self._model_id: LLMModelIdEnum = config.model_id
14 |         self._predict_endpoint = "/llm/predict"
15 |         self._stream_predict_endpoint = "/llm/stream-predict"
16 | 
17 |     @typechecked
18 |     async def predict(self, query: LLMQuery) -> LLMPrediction:
19 |         async with self.session.post(self._predict_endpoint, json=query.model_dump()) as resp:
20 |             return await resp.json(encoding='utf8', loads=LLMPrediction.model_validate_json)
21 | 
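    # NOTE: despite its name and the `Generator` return annotation, `stream_predict`
    # below awaits a single JSON body rather than iterating over the response stream;
    # a true streaming variant would need to consume `resp.content` chunk by chunk
    # (the wire format depends on the server implementation).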
\n" 23 | "You can set it to any value.") 24 | model_dir: str = Field(default="./resources/static/models/live2d/", 25 | description="The path to the model directory.") 26 | -------------------------------------------------------------------------------- /pipeline/llm/llm_async.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | from typeguard import typechecked 4 | from zerolan.data.pipeline.llm import LLMQuery, LLMPrediction 5 | 6 | from pipeline.base.base_async import BaseAsyncPipeline, get_base_url 7 | from pipeline.llm.config import LLMPipelineConfig, LLMModelIdEnum 8 | 9 | 10 | class LLMAsyncPipeline(BaseAsyncPipeline): 11 | def __init__(self, config: LLMPipelineConfig): 12 | super().__init__(base_url=get_base_url(config.predict_url)) 13 | self._model_id: LLMModelIdEnum = config.model_id 14 | self._predict_endpoint = "/llm/predict" 15 | self._stream_predict_endpoint = "/llm/stream-predict" 16 | 17 | @typechecked 18 | async def predict(self, query: LLMQuery) -> LLMPrediction: 19 | async with self.session.post(self._predict_endpoint, json=query.model_dump()) as resp: 20 | return await resp.json(encoding='utf8', loads=LLMPrediction.model_validate_json) 21 | 22 | @typechecked 23 | async def stream_predict(self, query: LLMQuery) -> Generator[LLMPrediction, None, None]: 24 | async with self.session.post(self._stream_predict_endpoint, json=query.model_dump()) as resp: 25 | return await resp.json(encoding='utf8', loads=LLMPrediction.model_validate_json) 26 | -------------------------------------------------------------------------------- /services/browser/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | 6 | 7 | class SeleniumDriverEnum(BaseEnum): 8 | Firefox: str = "Firefox" 9 | 10 | 11 | class BrowserConfig(BaseModel): 12 | enable: bool = Field(default=True, description="Enable selenium to controller your browser?\n" 13 | "Warning: VLA may control your mouse cursor and use keyboards!") 14 | profile_dir: str | None = Field(default=None, 15 | description="Browser's Profile folder. \n" 16 | "This is to ensure that under Selenium's control, \n" 17 | "your account login and other information will not be lost. If the value is 'null', \n" 18 | "the program will automatically detect the location (Windows only).") 19 | driver: SeleniumDriverEnum = Field(default=SeleniumDriverEnum.Firefox, 20 | description="Browser drivers, for Selenium. 
\n" 21 | f"{enum_to_markdown(SeleniumDriverEnum)}") 22 | -------------------------------------------------------------------------------- /tests/pipeline/test_vla.py: -------------------------------------------------------------------------------- 1 | from zerolan.data.pipeline.vla import ShowUiQuery, WebAction 2 | 3 | from manager.config_manager import get_config 4 | from pipeline.vla.showui.showui_sync import ShowUISyncPipeline 5 | 6 | _config = get_config() 7 | _showui_sync = ShowUISyncPipeline(_config.pipeline.vla.showui) 8 | 9 | 10 | def test_showui(): 11 | query = ShowUiQuery(img_path="resources/imgcap-test.png", query="Click the Ciallo") 12 | prediction = _showui_sync.predict(query) 13 | assert prediction.actions 14 | for action in prediction.actions: 15 | print(action.model_dump_json()) 16 | 17 | query = ShowUiQuery(img_path="resources/imgcap-test.png", query="Click the Ciallo", 18 | action=WebAction(action="CLICK")) 19 | prediction = _showui_sync.predict(query) 20 | assert prediction.actions 21 | for action in prediction.actions: 22 | print(action.model_dump_json()) 23 | history = [WebAction(action="INPUT", value="Hello", position=None), 24 | WebAction(action="SELECT_TEXT", value=None, position=[0.2, 0.3])] 25 | query = ShowUiQuery(img_path="resources/imgcap-test.png", query="Click the Ciallo", env="web", history=history) 26 | prediction = _showui_sync.predict(query) 27 | assert prediction.actions 28 | -------------------------------------------------------------------------------- /common/utils/enum_util.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Type, Any, List 3 | 4 | from typeguard import typechecked 5 | 6 | 7 | def enum_members_to_list(enum: Type[Enum]) -> List[Any]: 8 | return [member.value for member in enum] 9 | 10 | 11 | @typechecked 12 | def enum_members_to_str_list(enum: Type[Enum]) -> List[str]: 13 | return [str(elm.value) for elm in list(enum)] 14 | 15 | 16 | @typechecked 17 | def _enum_members_to_plain_text_with_comma(enum: Type[Enum]) -> str: 18 | str_list = enum_members_to_str_list(enum) 19 | text = "" 20 | for string in str_list: 21 | text += "`" + string + "`" + ", " 22 | return text[:-2] 23 | 24 | 25 | @typechecked 26 | def enum_to_markdown(enum: Type[Enum]) -> str: 27 | num_of_enum = len(enum) 28 | if num_of_enum == 1: 29 | return f"`{list(enum)[0].value}` is supported only." 30 | else: 31 | candidates = _enum_members_to_plain_text_with_comma(enum) 32 | return f"{candidates} are supported." 
33 | 34 | @typechecked 35 | def enum_to_markdown_zh(enum: Type[Enum]) -> str: 36 | num_of_enum = len(enum) 37 | if num_of_enum == 1: 38 | return f"仅支持 `{list(enum)[0].value}`。" 39 | else: 40 | candidates = _enum_members_to_plain_text_with_comma(enum) 41 | return f"支持 {candidates}。" -------------------------------------------------------------------------------- /tests/test_live_stream.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from event.event_data import LiveStreamDanmakuEvent 4 | from event.event_emitter import emitter 5 | from event.registry import EventKeyRegistry 6 | from manager.config_manager import get_config 7 | from services.live_stream.bilibili import BilibiliService 8 | from services.live_stream.twitch import TwitchService 9 | from services.live_stream.youtube import YouTubeService 10 | 11 | _config = get_config() 12 | 13 | 14 | @emitter.on(EventKeyRegistry.LiveStream.DANMAKU) 15 | async def handle(event: LiveStreamDanmakuEvent): 16 | print(event.danmaku.content) 17 | 18 | 19 | async def async_test_bilibili(): 20 | emitter_task = asyncio.create_task(emitter.start()) 21 | bilibili = BilibiliService(_config.service.live_stream.bilibili) 22 | bili_task = asyncio.create_task(bilibili.start()) 23 | await asyncio.sleep(3) 24 | await bilibili.stop() 25 | await emitter.stop() 26 | await bili_task 27 | await emitter_task 28 | 29 | 30 | def test_bilibili(): 31 | asyncio.run(async_test_bilibili()) 32 | 33 | 34 | def test_twitch(): 35 | twitch = TwitchService(_config.service.live_stream.twitch) 36 | twitch.start() 37 | 38 | 39 | def test_youtube(): 40 | youtube = YouTubeService(_config.service.live_stream.youtube) 41 | youtube.start() 42 | -------------------------------------------------------------------------------- /tests/pipeline/test_imgcap.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from zerolan.data.pipeline.img_cap import ImgCapQuery 3 | 4 | from manager.config_manager import get_config 5 | from pipeline.imgcap.imgcap_async import ImgCapAsyncPipeline 6 | from pipeline.imgcap.imgcap_sync import ImgCapSyncPipeline 7 | 8 | _config = get_config() 9 | _imgcap_sync = ImgCapSyncPipeline(_config.pipeline.img_cap) 10 | _imgcap_async = ImgCapAsyncPipeline(_config.pipeline.img_cap) 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def event_loop(event_loop_policy): 15 | # Needed to work with asyncpg 16 | loop = event_loop_policy.new_event_loop() 17 | yield loop 18 | loop.close() 19 | 20 | 21 | def test_imgcap_sync(): 22 | query = ImgCapQuery(img_path="resources/imgcap-test.png") 23 | prediction = _imgcap_sync.predict(query) 24 | assert prediction, f"Test failed: No response." 25 | print(prediction.caption) 26 | assert "girl" in prediction.caption, f"Test failed: Wrong result." 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_imgcap_async(): 31 | query = ImgCapQuery(img_path="resources/imgcap-test.png") 32 | prediction = await _imgcap_async.predict(query) 33 | assert prediction, f"Test failed: No response." 34 | print(prediction.caption) 35 | assert "girl" in prediction.caption, f"Test failed: Wrong result." 
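# The async case above relies on pytest-asyncio; see tests/pytest.ini for the
# `asyncio_default_fixture_loop_scope = function` setting.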
36 | -------------------------------------------------------------------------------- /tests/test_reload.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import TaskGroup 3 | 4 | import pytest 5 | 6 | from framework.base_bot import BaseBot 7 | from manager.config_manager import get_config 8 | 9 | 10 | @pytest.fixture(scope="session") 11 | def event_loop(event_loop_policy): 12 | # Needed to work with asyncpg 13 | loop = event_loop_policy.new_event_loop() 14 | yield loop 15 | loop.close() 16 | 17 | 18 | @pytest.mark.asyncio 19 | async def test_reload_pipelines(): 20 | bot = BaseBot() 21 | async with TaskGroup() as tg: 22 | tg.create_task(bot.start()) 23 | await asyncio.sleep(1) 24 | bot.reload_pipeline() 25 | await asyncio.sleep(1) 26 | config = get_config() 27 | config.pipeline.asr.predict_url = "http://127.0.0.1/asr/predict" 28 | bot.reload_pipeline() 29 | await bot.stop() 30 | 31 | 32 | @pytest.mark.asyncio 33 | async def test_reload_device(): 34 | bot = BaseBot() 35 | async with TaskGroup() as tg: 36 | tg.create_task(bot.start()) 37 | await asyncio.sleep(1) 38 | bot.reload_device() 39 | await asyncio.sleep(1) 40 | config = get_config() 41 | config.system.default_enable_microphone = not config.system.default_enable_microphone 42 | bot.reload_device() 43 | await bot.stop() 44 | -------------------------------------------------------------------------------- /common/collection/limit_list.py: -------------------------------------------------------------------------------- 1 | class LimitList(list): 2 | def __init__(self, maxsize: int): 3 | super().__init__() 4 | self.maxsize = maxsize 5 | 6 | def add(self, item): 7 | if len(self) < self.maxsize: 8 | super().append(item) 9 | else: 10 | super().pop(0) 11 | super().append(item) 12 | 13 | def __setitem__(self, index, value): 14 | if index >= len(self): 15 | raise IndexError("List assignment index out of range") 16 | super().__setitem__(index, value) 17 | 18 | def __delitem__(self, index): 19 | if index >= len(self): 20 | raise IndexError("List assignment index out of range") 21 | super().__delitem__(index) 22 | 23 | 24 | def insert(self, index, item): 25 | if index > len(self): 26 | raise IndexError("List assignment index out of range") 27 | if len(self) + 1 > self.maxsize: 28 | super().pop(0) 29 | super().insert(index, item) 30 | 31 | def extend(self, iterable): 32 | additional = len(iterable) + len(self) - self.maxsize 33 | if additional > 0: 34 | for _ in range(additional): 35 | super().pop(0) 36 | super().extend(iterable) 37 | 38 | def append(self, item): 39 | self.add(item) 40 | -------------------------------------------------------------------------------- /services/game/minecraft/instrcution/tool.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import Type 3 | 4 | from langchain_core.messages import ToolCall 5 | from langchain_core.tools import BaseTool 6 | from loguru import logger 7 | from pydantic import BaseModel 8 | 9 | from event.event_emitter import emitter 10 | from event.registry import EventKeyRegistry 11 | from services.game.minecraft.data import KonekoProtocol 12 | 13 | 14 | class KonekoInstructionTool(BaseTool): 15 | name: str = None 16 | description: str = None 17 | args_schema: Type[BaseModel] = None 18 | return_direct: bool = True 19 | 20 | def __init__(self, name: str, description: str, args_schema: Type[BaseModel]) -> None: 21 | super().__init__() 22 | self.name = name 23 | 
self.description = description 24 | self.args_schema = args_schema 25 | 26 | def _run(self, **kwargs) -> str: 27 | raise NotImplementedError("Call _arun instead") 28 | 29 | async def _arun(self, **kwargs) -> str: 30 | tool_call = ToolCall(id=f"{uuid.uuid4()}", name=self.name, args=kwargs) 31 | protocol_obj = KonekoProtocol(event=EventKeyRegistry.Koneko.Server.CALL_INSTRUCTION, data=tool_call) 32 | await emitter.emit(protocol_obj=protocol_obj) 33 | logger.info(f"Koneko instruction tool {self.name} was called, emitted") 34 | return f"Instruction {self.name} executed" 35 | -------------------------------------------------------------------------------- /services/obs/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | 4 | class ObsStudioClientConfig(BaseModel): 5 | enable: bool = Field(default=True, 6 | description="Whether to connect to OBS Studio and control the Text components that display subtitles.") 7 | uri: str = Field(default="ws://127.0.0.1:4455", 8 | description="URI of the target OBS Studio WebSocket server. \n" 9 | "For example: ws://127.0.0.1:4455") 10 | password: str = Field(default="", 11 | description="Password for the target OBS Studio WebSocket server. \n" 12 | "Select Tool > WebSocket Server Setting > Show Connect Info to find the server password.") 13 | assistant_text_comp_name: str = Field(default="AssistantText", 14 | description="The name of the Text (GDI+) input component in the `Sources` panel \n" 15 | "that displays the assistant's output text.") 16 | user_text_comp_name: str = Field(default="UserText", 17 | description="The name of the Text (GDI+) input component in the `Sources` panel \n" 18 | "that displays the user's input text.") -------------------------------------------------------------------------------- /tests/pipeline/test_abs_img.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | 4 | from zerolan.data.pipeline.abs_data import AbsractImageModelQuery 5 | 6 | from common.concurrent.killable_thread import KillableThread 7 | from pipeline.server import TestServer 8 | from pipeline.base.base_sync import AbstractPipelineConfig, AbstractImagePipeline 9 | 10 | base_url = "http://127.0.0.1:5889" 11 | # Start the test server before the tests run 12 | test_server = TestServer() 13 | test_server.init() 14 | thread = KillableThread(target=test_server.start, daemon=True) 15 | thread.start() 16 | time.sleep(2) 17 | 18 | 19 | class MyPipelineConfig(AbstractPipelineConfig): 20 | model_id: str = "test-llm-model" 21 | predict_url: str = f"{base_url}/abs-img/predict" 22 | stream_predict_url: str = f"{base_url}/abs-img/stream-predict" 23 | 24 | 25 | def test_abs_img_pipeline_predict(): 26 | id = str(uuid.uuid4()) 27 | p = AbstractImagePipeline(MyPipelineConfig()) 28 | r = p.predict(AbsractImageModelQuery(id=id, img_path="resources/imgcap-test.png")) 29 | assert r.id == id, "Test failed." 30 | 31 | 32 | def test_abs_img_pipeline_stream_predict(): 33 | id = str(uuid.uuid4()) 34 | p = AbstractImagePipeline(MyPipelineConfig()) 35 | for r in p.stream_predict(AbsractImageModelQuery(id=id, img_path="resources/imgcap-test.png")): 36 | assert r.id == id, "Test failed."
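# Note: all tests in this module share the module-level TestServer started at
# import time; test_stop_test_server below tears it down, so it is assumed to
# run last (pytest's default in-file ordering preserves this).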
37 | 38 | 39 | def test_stop_test_server(): 40 | thread.kill() 41 | -------------------------------------------------------------------------------- /character/filter/strategy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from loguru import logger 4 | 5 | 6 | class AbstractFilter(ABC): 7 | 8 | @abstractmethod 9 | def filter(self, content: str): 10 | pass 11 | 12 | 13 | class FirstMatchedFilter(AbstractFilter): 14 | def __init__(self, words: list[str] = None) -> None: 15 | self.max_len = None 16 | self.min_len = None 17 | self.words = None 18 | 19 | if words is not None: 20 | self.set_words(words) 21 | 22 | def set_words(self, words: list[str]): 23 | # Sort by length so the shortest (cheapest) words are checked first. 24 | self.words = words 25 | self.words.sort(key=lambda word: len(word)) 26 | if len(self.words) > 0: 27 | self.min_len = len(self.words[0]) 28 | self.max_len = len(self.words[-1]) 29 | else: 30 | self.min_len = 0 31 | self.max_len = 0 32 | 33 | def filter(self, content: str | None) -> bool: 34 | if content is None or not self.words: 35 | return False 36 | if len(content) < self.min_len: 37 | return False 38 | for word in self.words: 39 | if word in content: 40 | logger.warning(f"Filter detected bad word: {word}") 41 | return True 42 | return False 43 | 44 | def match(self, content: str | None) -> int: 45 | if content is None or not self.words: 46 | return 0 47 | result = 0 48 | for word in self.words: 49 | if word in content: 50 | result += 1 51 | return result -------------------------------------------------------------------------------- /services/game/minecraft/instrcution/input.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union 3 | 4 | from pydantic import create_model, Field, BaseModel 5 | 6 | 7 | def ts_type_to_py_type(t: str) -> type: 8 | if t == 'number': 9 | return Union[float, int] 10 | elif t == 'string': 11 | return str 12 | elif t == 'boolean': 13 | return bool 14 | else: 15 | raise ValueError(f"Not a valid type: {t}") 16 | 17 | 18 | @dataclass 19 | class FieldMetadata: 20 | name: str 21 | type: str 22 | description: str 23 | required: bool 24 | 25 | 26 | def generate_model_from_args(class_name: str, args_list: list[FieldMetadata]): 27 | fields = {} 28 | for arg in args_list: 29 | name = arg.name 30 | assert isinstance(name, str) 31 | field_type = arg.type 32 | assert isinstance(field_type, str) 33 | field_type = ts_type_to_py_type(field_type) 34 | required = arg.required 35 | assert isinstance(required, bool) 36 | description = arg.description 37 | assert isinstance(description, str) 38 | 39 | if required: 40 | # `...` (Ellipsis) marks the field as required in pydantic; a None 41 | # default would silently make it optional. 42 | fields[name] = (field_type, Field(..., description=description)) 43 | else: 44 | fields[name] = (Optional[field_type], Field(default=None, description=description)) 45 | 46 | model = create_model(class_name, **fields) 47 | assert issubclass(model, BaseModel) 48 | return model 49 | -------------------------------------------------------------------------------- /common/enumerator.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BaseEnum(str, Enum): 5 | pass 6 | 7 | 8 | class Language(str, Enum): 9 | ZH = "zh" 10 | EN = "en" 11 | JA = "ja" 12 | 13 | def full_name(self): 14 | if self == self.ZH: 15 | return "Chinese" 16 | elif self == self.EN: 17 | return "English" 18 | elif self == self.JA: 19 | return "Japanese" 20 | else: 21 | raise ValueError("Unknown language") 22 | 23 | def name(self): 24 | # NOTE: this deliberately shadows Enum's built-in `name` attribute; each 25 | # member's value already is the short language code, so return it directly. 26 | return self.value 27 |
28 | def to_zh_name(self): 29 | if self == self.ZH: 30 | return "中文" 31 | elif self == self.EN: 32 | return "英文" 33 | elif self == self.JA: 34 | return "日语" 35 | else: 36 | raise ValueError("Unknown language") 37 | 38 | @staticmethod 39 | def value_of(s: str): 40 | s = s.lower() 41 | if s in ["en", "english", "英文", "英语"]: 42 | return Language.EN 43 | elif s in ["zh", "cn", "chinese", "中文"]: 44 | return Language.ZH 45 | elif s in ["ja", "japanese", "日语", "日本語", "にほんご"]: 46 | return Language.JA 47 | else: 48 | raise ValueError("Unknown language") 49 | -------------------------------------------------------------------------------- /common/utils/json_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Github@AkagawaTsurunaki 3 | Notes: 4 |     JSON emitted by GLM-4 may carry Markdown code-fence markers and extra 5 |     trailing right braces. Adjusting the prompt did not fix this, so this 6 |     module tries to recover a valid object from such malformed JSON strings, 7 |     in case other models share the same failure mode. 8 | """ 9 | 10 | import json 11 | from json import JSONDecodeError 12 | 13 | 14 | def _extract_json_from_text(text: str): 15 | """ 16 | Extract the JSON string from a document shaped like: 17 | ... 18 | ```json 19 | { 20 | ... 21 | } 22 | ``` 23 | ... 24 | The document must contain exactly one JSON object. 25 | Args: 26 | text: 27 | 28 | Returns: 29 | """ 30 | start, end = 0, len(text) 31 | for i in range(len(text)): 32 | if text[i] == "{" or text[i] == "[": 33 | start = i 34 | break 35 | for i in range(len(text)): 36 | j = len(text) - i - 1 37 | if text[j] == "}" or text[j] == "]": 38 | end = j 39 | break 40 | return text[start:end + 1] 41 | 42 | 43 | def _remove_end_extra_braces(text: str): 44 | """ 45 | Strip extra trailing right braces until the text parses as JSON. 46 | Args: 47 | text: 48 | 49 | Returns: 50 | """ 51 | n = text.count("}") 52 | if n == 0: 53 | # Nothing to trim; the text may still be a bare array or scalar. 54 | return json.loads(text) 55 | errs = [] 56 | for i in range(n): 57 | j = len(text) - i 58 | new_text = text[:j] 59 | try: 60 | return json.loads(new_text) 61 | except JSONDecodeError as e: 62 | errs.append(e) 63 | raise errs[-1] 64 | 65 | 66 | def smart_load_json_like(content: str): 67 | json_val = _extract_json_from_text(content) 68 | json_val = _remove_end_extra_braces(json_val) 69 | return json_val 70 | -------------------------------------------------------------------------------- /tests/test_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from common.utils.json_util import smart_load_json_like 4 | 5 | json_vals = [ 6 | """ 7 | { 8 | "text": "Hello", 9 | "number": 1.1 10 | } 11 | """, 12 | """ 13 | { 14 | "text": "Hello", 15 | "number": 1.1 16 | }} 17 | """ 18 | , 19 | """ 20 | ```json 21 | { 22 | "text": "Hello", 23 | "number": 1.1 24 | } 25 | ``` 26 | """, 27 | """ 28 | ```json 29 | { 30 | "text": "Hello", 31 | "number": 1.1 32 | }} 33 | ``` 34 | """, 35 | """ 36 | 以下是生成的 json 数据: 37 | ```json 38 | { 39 | "text": "Hello", 40 | "number": 1.1 41 | }} 42 | ``` 43 | 如果你还有其他需要,请随时使用我。 44 | """, 45 | """ 46 | 以下是生成的 json 数据: 47 | ```json 48 | {"name": "游戏对象创建器", "args": {"instance_id": 1, "gameobject_name": "绿色立方体", "object_type": "cube", "color": "#008000", "transform": {"scale": 1, "position": {"x": 0, "y": 0, "z": 0}}}}}} 49 | ``` 50 | 如果你还有其他需要,请随时使用我。 51 | """, 52 | """ 53 | 以下是生成的 json 数据: 54 | ```json 55 | {"name": "游戏对象创建器", "args": {"instance_id": 1, "gameobject_name": "绿色立方体", "object_type": "cube", "color": "#008000", "transform": {"scale": 1, "position": {"x": 0, "y": 0, "z": 0}}}}}}}}} 56 | ``` 57 | 如果你还有其他需要,请随时使用我。 58 | """ 59 | ] 60 | 61 | 62 | def test_json(): 63 | for i,
json_val in enumerate(json_vals): 64 | json_val = smart_load_json_like(json_val) 65 | print(f"Test case {i} passed:") 66 | print(json.dumps(json_val, indent=4, ensure_ascii=False)) 67 | -------------------------------------------------------------------------------- /tests/test_jws.py: -------------------------------------------------------------------------------- 1 | import time 2 | from concurrent.futures.thread import ThreadPoolExecutor 3 | from json import JSONDecodeError 4 | from typing import Union, List, Dict 5 | 6 | from pydantic import BaseModel 7 | from websockets.sync.connection import Connection 8 | 9 | from common.web.json_ws import JsonWsServer 10 | 11 | 12 | class TestA(BaseModel): 13 | msg: str 14 | num: int 15 | 16 | 17 | def test_jws(): 18 | host = "localhost" 19 | port = 11013 20 | print(f"Starting WebSocket server on ws://{host}:{port}") 21 | jws = JsonWsServer(host, port, ["ZerolanProtocol"]) 22 | 23 | sender = ThreadPoolExecutor(max_workers=1) 24 | 25 | def on_open(ws: Connection): 26 | def send_ciallo(): 27 | while True: 28 | ws.send("Server Ciallo!") 29 | time.sleep(1) 30 | 31 | sender.submit(send_ciallo) 32 | print(f"{ws.remote_address} => Client Ciallo!") 33 | 34 | def on_msg(ws: Connection, json: Union[Dict, List]): 35 | s = TestA.model_validate(json) 36 | print(f"{ws.remote_address} => msg: {s}") 37 | 38 | def on_err(ws: Connection, err: Exception): 39 | if isinstance(err, JSONDecodeError): 40 | ws.send("Error json") 41 | else: 42 | print(f"{ws.remote_address} => error: {err}") 43 | 44 | def on_close(_: Connection, code: int, reason: str): 45 | print(f"A client closed: ({code}) {reason}") 46 | 47 | jws.on_msg_handlers += [on_msg] 48 | jws.on_open_handlers += [on_open] 49 | jws.on_close_handlers += [on_close] 50 | jws.on_err_handlers += [on_err] 51 | 52 | jws.start() 53 | -------------------------------------------------------------------------------- /common/utils/web_util.py: -------------------------------------------------------------------------------- 1 | import ipaddress 2 | 3 | import netifaces as ni 4 | 5 | 6 | def get_local_ip(ipv6=False) -> str | None: 7 | interfaces = ni.interfaces() # All network interfaces on this host 8 | for interface in interfaces: 9 | if interface != "lo": # Skip the local loopback interface 10 | if ipv6: 11 | try: 12 | # Try to obtain an IPv6 address 13 | ipv6_info = ni.ifaddresses(interface).get(ni.AF_INET6) 14 | if ipv6_info: 15 | # Return the first IPv6 address 16 | ipv6_addr = ipv6_info[0]['addr'] 17 | # Drop a possible zone identifier (e.g. %eth0) 18 | ipv6_addr = ipv6_addr.split('%')[0] 19 | return f"[{ipv6_addr}]" 20 | except KeyError: 21 | pass 22 | else: 23 | try: 24 | # Otherwise try to obtain an IPv4 address 25 | ipv4_info = ni.ifaddresses(interface).get(ni.AF_INET) 26 | if ipv4_info: 27 | # Return the first IPv4 address 28 | ipv4_addr = ipv4_info[0]['addr'] 29 | return ipv4_addr 30 | except KeyError: 31 | continue 32 | return None 33 | 34 | 35 | def is_ipv6(host: str) -> bool: 36 | """Check if the given host string is an IPv6 address.""" 37 | if not host: 38 | return False 39 | 40 | # Remove brackets if present (common in URLs like [::1]) 41 | host_clean = host.strip('[]') 42 | 43 | try: 44 | addr = ipaddress.ip_address(host_clean) 45 | return isinstance(addr, ipaddress.IPv6Address) 46 | except ValueError: 47 | return False 48 | -------------------------------------------------------------------------------- /common/io/api.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from typeguard import typechecked 4 | 5 | from common.io.file_sys import fs 6 | from common.io.file_type import
AudioFileType, ImageFileType 7 | from common.utils.audio_util import get_audio_real_format 8 | 9 | 10 | @typechecked 11 | def save_image(image_bytes: bytes, format: ImageFileType, prefix: str | None = None) -> Path: 12 | """ 13 | Save bytes data as an image to a temp file. 14 | :param prefix: Filename prefix. 15 | :param image_bytes: Bytes data. 16 | :param format: Format of the image data. 17 | :return: Saved image path. 18 | """ 19 | img_path = fs.create_temp_file_descriptor(prefix=prefix, suffix=f".{format.value}", type="image") 20 | 21 | # Attention: DO NOT use the built-in open function; it may cause PermissionDeniedError on Windows! 22 | with img_path.open("wb+") as image_file: 23 | image_file.write(image_bytes) 24 | return img_path 25 | 26 | 27 | @typechecked 28 | def save_audio(wave_data: bytes, format: AudioFileType | None = None, prefix: str = "") -> Path: 29 | """ 30 | Save audio data to a temp file. 31 | :param wave_data: Bytes of audio data. 32 | :param format: Format of the audio data. 33 | :param prefix: Filename prefix. 34 | :return: Saved audio path. 35 | """ 36 | if format is None: 37 | format = get_audio_real_format(wave_data) 38 | wav_path = fs.create_temp_file_descriptor(prefix=prefix, suffix=f".{format.value}", type="audio") 39 | 40 | # Attention: DO NOT use the built-in open function; it may cause PermissionDeniedError on Windows! 41 | with wav_path.open("wb+") as f: 42 | f.write(wave_data) 43 | return Path(wav_path) 44 | -------------------------------------------------------------------------------- /event/registry.py: -------------------------------------------------------------------------------- 1 | class EventKeyRegistry: 2 | """ 3 | All event names should be registered here. 4 | """ 5 | 6 | class Dev: 7 | TEST = "test" 8 | 9 | class System: 10 | CONFIG_FILE_MODIFIED = "system.config_file_modified" 11 | LANG_CHANGE = "lang_change" 12 | SYSTEM_UNHANDLED_ERROR = "system.error" 13 | SYSTEM_CRASHED = "system.crashed" 14 | SECOND = "system.second" 15 | 16 | class Pipeline: 17 | ASR = "pipeline.asr" 18 | LLM = "pipeline.llm" 19 | TTS = "pipeline.tts" 20 | IMG_CAP = "pipeline.img_cap" 21 | OCR = "pipeline.ocr" 22 | 23 | class Device: 24 | SCREEN_CAPTURED = "device.screen_captured" 25 | MICROPHONE_VAD = "service.vad.speech_chunk" 26 | MICROPHONE_SWITCH = "switch_vad" 27 | 28 | class LiveStream: 29 | CONNECTED = "service.live_stream.connected" 30 | DISCONNECTED = "service.live_stream.disconnected" 31 | SUPER_CHAT = "service.live_stream.super_chat" 32 | DANMAKU = "service.live_stream.danmaku" 33 | GIFT = "service.live_stream.gift" 34 | 35 | class Koneko: 36 | # Sent from the client and handled by the server.
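# (Presumably, a Koneko game client first announces itself with `hello` and then
# advertises the instruction set it supports via `push_instructions`; the
# server-side keys below mirror that handshake.)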
37 | class Client: 38 | HELLO = "koneko.client.hello" 39 | PUSH_INSTRUCTIONS = "koneko.client.push_instructions" 40 | 41 | class Server: 42 | HELLO = "koneko.server.hello" 43 | FETCH_INSTRUCTIONS = "koneko.server.fetch_instructions" 44 | CALL_INSTRUCTION = "koneko.server.call_instruction" 45 | 46 | class QQBot: 47 | QQ_MESSAGE = "qq.message" 48 | 49 | class Playground: 50 | DISCONNECTED = "playground/disconnected" 51 | CONNECTED = "playground_connected" 52 | -------------------------------------------------------------------------------- /tests/pipeline/test_vec_db.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from zerolan.data.pipeline.milvus import InsertRow, MilvusInsert, MilvusQuery 3 | 4 | from manager.config_manager import get_config 5 | from pipeline.db.milvus.milvus_async import MilvusAsyncPipeline 6 | from pipeline.db.milvus.milvus_sync import MilvusSyncPipeline 7 | 8 | _config = get_config() 9 | _milvus_async = MilvusAsyncPipeline(_config.pipeline.vec_db) 10 | _milvus_sync = MilvusSyncPipeline(_config.pipeline.vec_db) 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_insert(): 15 | texts = ["onani就是0721", "柚子厨真恶心", "我喜欢阿米诺!", "0721就是无吟唱水魔法的意思"] 16 | texts = [InsertRow(id=i, text=texts[i], subject="history") for i in range(len(texts))] 17 | mi = MilvusInsert(collection_name="Test", texts=texts, drop_if_exists=True) 18 | 19 | ir = await _milvus_async.insert(mi) 20 | assert ir, "Test failed!" 21 | print(ir) 22 | 23 | 24 | @pytest.mark.asyncio 25 | async def test_search(): 26 | mq = MilvusQuery(collection_name="Test", limit=2, output_fields=["text", 'history'], query="0721是什么意思?") 27 | qr = await _milvus_async.search(mq) 28 | assert qr, "Test failed!" 29 | print(qr) 30 | 31 | 32 | def test_insert_sync(): 33 | texts = ["onani就是0721", "柚子厨真恶心", "我喜欢阿米诺!", "0721就是无吟唱水魔法的意思"] 34 | texts = [InsertRow(id=i, text=texts[i], subject="history") for i in range(len(texts))] 35 | mi = MilvusInsert(collection_name="Test", texts=texts, drop_if_exists=True) 36 | 37 | ir = _milvus_sync.insert(mi) 38 | assert ir, "Test failed!" 39 | print(ir) 40 | 41 | 42 | def test_search_sync(): 43 | mq = MilvusQuery(collection_name="Test", limit=2, output_fields=["text", 'history'], query="0721是什么意思?") 44 | qr = _milvus_sync.search(mq) 45 | assert qr, "Test failed!" 
46 | print(qr) 47 | -------------------------------------------------------------------------------- /pipeline/ocr/ocr_sync.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from requests import Response 4 | from zerolan.data.pipeline.ocr import OCRQuery, OCRPrediction, RegionResult 5 | 6 | from pipeline.base.base_sync import AbstractImagePipeline 7 | from pipeline.ocr.config import OCRPipelineConfig 8 | 9 | 10 | class OCRSyncPipeline(AbstractImagePipeline): 11 | 12 | def __init__(self, config: OCRPipelineConfig): 13 | super().__init__(config) 14 | 15 | def predict(self, query: OCRQuery) -> OCRPrediction | None: 16 | assert isinstance(query, OCRQuery) 17 | return super().predict(query) 18 | 19 | def stream_predict(self, query: OCRQuery, chunk_size: int | None = None): 20 | assert isinstance(query, OCRQuery) 21 | raise NotImplementedError() 22 | 23 | def parse_query(self, query: any) -> dict: 24 | return super().parse_query(query) 25 | 26 | def parse_prediction(self, response: Response) -> OCRPrediction: 27 | json_val = response.content 28 | return OCRPrediction.model_validate_json(json_val) 29 | 30 | 31 | def avg_confidence(p: OCRPrediction) -> float: 32 | n = len(p.region_results) 33 | if n == 0: 34 | return 0.0 35 | return sum(r.confidence for r in p.region_results) / n 36 | 37 | 38 | def stringify(region_results: List[RegionResult]): 39 | assert isinstance(region_results, list) 40 | for region_result in region_results: 41 | assert isinstance(region_result, RegionResult) 42 | 43 | result = "" 44 | for i, region_result in enumerate(region_results): 45 | line = f"[{i}] {region_result.content} \n" 46 | result += line 47 | return result 48 | -------------------------------------------------------------------------------- /tests/devices/test_mic.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import TaskGroup 3 | 4 | import pytest 5 | from zerolan.data.pipeline.asr import ASRStreamQuery 6 | 7 | from common.concurrent.killable_thread import KillableThread 8 | from event.event_data import DeviceMicrophoneVADEvent 9 | from event.event_emitter import emitter 10 | from event.registry import EventKeyRegistry 11 | from manager.config_manager import get_config 12 | from devices.microphone import SmartMicrophone 13 | from pipeline.asr.asr_sync import ASRSyncPipeline 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def event_loop(event_loop_policy): 18 | # Session-scoped event loop shared by the async tests in this module. 19 | loop = event_loop_policy.new_event_loop() 20 | yield loop 21 | loop.close() 22 | 23 | 24 | _config = get_config() 25 | mic = SmartMicrophone() 26 | _asr = ASRSyncPipeline(_config.pipeline.asr) 27 | t = KillableThread(target=mic.start, daemon=True) 28 | _asr_res = [] 29 | tasks = [] 30 | _flag = True 31 | 32 | 33 | @emitter.on(EventKeyRegistry.Device.MICROPHONE_VAD) 34 | def _on_speech(event: DeviceMicrophoneVADEvent): 35 | query = ASRStreamQuery( 36 | is_final=True, 37 | audio_data=event.speech, 38 | media_type=event.audio_type.value, 39 | sample_rate=event.sample_rate, 40 | channels=event.channels, 41 | ) 42 | for prediction in _asr.stream_predict(query): 43 | print(prediction.transcript) 44 | _asr_res.append(prediction) 45 | if len(_asr_res) > 3: 46 | global _flag 47 | _flag = False 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def test_vad(): 52 | t.start() 53 | async
with TaskGroup() as tg: 54 | tg.create_task(emitter.start()) 55 | while _flag: 56 | await asyncio.sleep(0.1) 57 | t.kill() 58 | await emitter.stop() 59 | -------------------------------------------------------------------------------- /pipeline/tts/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field, BaseModel 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | from pipeline.base.base_sync import AbstractPipelineConfig 6 | 7 | 8 | ####### 9 | # TTS # 10 | ####### 11 | 12 | class TTSModelIdEnum(BaseEnum): 13 | GPT_SoVITS = "AkagawaTsurunaki/GPT-SoVITS" # Forked repo 14 | BaiduTTS = "BaiduTTS" 15 | 16 | 17 | # Config for Baidu TTS; only used when `model_id` is set to `BaiduTTS`. 18 | class BaiduTTSConfig(BaseModel): 19 | api_key: str = Field(default="", description="The API key for Baidu TTS service.") 20 | secret_key: str = Field(default="", description="The secret key for Baidu TTS service.") 21 | 22 | 23 | # Config for ZerolanCore 24 | class TTSPipelineConfig(AbstractPipelineConfig): 25 | model_id: TTSModelIdEnum = Field(default=TTSModelIdEnum.GPT_SoVITS, 26 | description=f"The ID of the model used for text-to-speech. \n" 27 | f"{enum_to_markdown(TTSModelIdEnum)}") 28 | predict_url: str = Field(default="http://127.0.0.1:11000/tts/predict", 29 | description="The URL for TTS prediction requests.") 30 | stream_predict_url: str = Field(default="http://127.0.0.1:11000/tts/stream-predict", 31 | description="The URL for streaming TTS prediction requests.") 32 | baidu_tts_config: BaiduTTSConfig = Field(default=BaiduTTSConfig(), 33 | description=f"Baidu TTS config. \n" 34 | f"Only edit it when you set `model_id` to `{TTSModelIdEnum.BaiduTTS.value}`.\n" 35 | f"For more details please see the [documents](https://cloud.baidu.com/doc/SPEECH/s/mlbxh7xie).") 36 | -------------------------------------------------------------------------------- /pipeline/db/milvus/milvus_sync.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from pydantic import BaseModel, Field 3 | from zerolan.data.pipeline.milvus import MilvusInsert, MilvusInsertResult, MilvusQuery, MilvusQueryResult 4 | 5 | from pipeline.base.base_sync import AbstractPipeline, AbstractPipelineConfig 6 | 7 | 8 | class MilvusDatabaseConfig(AbstractPipelineConfig): 9 | insert_url: str = Field(default="http://127.0.0.1:11000/milvus/insert", 10 | description="The URL for inserting data into Milvus.") 11 | search_url: str = Field(default="http://127.0.0.1:11000/milvus/search", 12 | description="The URL for searching data in Milvus.") 13 | 14 | 15 | def _post(url: str, obj: BaseModel | dict, return_type: type): 16 | if isinstance(obj, BaseModel): 17 | json_val = obj.model_dump() 18 | else: 19 | json_val = obj 20 | 21 | response = requests.post(url=url, json=json_val) 22 | response.raise_for_status() 23 | 24 | json_val = response.json() 25 | if hasattr(return_type, "model_validate"): 26 | return return_type.model_validate(json_val) 27 | else: 28 | # response.json() has already parsed the body, so return it directly. 29 | return json_val 30 | 31 | 32 | class MilvusSyncPipeline(AbstractPipeline): 33 | def __init__(self, config: MilvusDatabaseConfig): 34 | super().__init__(config) 35 | self.insert_url = config.insert_url 36 | self.search_url = config.search_url 37 | 38 | def insert(self, insert: MilvusInsert) -> MilvusInsertResult: 39 | assert isinstance(insert, MilvusInsert) 40 | return _post(url=self.insert_url, obj=insert, return_type=MilvusInsertResult) 41 | 42 | def
search(self, query: MilvusQuery) -> MilvusQueryResult: 43 | assert isinstance(query, MilvusQuery) 44 | return _post(url=self.search_url, obj=query, return_type=MilvusQueryResult) 45 | -------------------------------------------------------------------------------- /services/live2d/live2d_canvas.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied and modified from https://github.com/Arkueid/live2d-py/blob/main/package/main_pyqt5_canvas_opacity.py 3 | """ 4 | 5 | import math 6 | from typing import Tuple 7 | 8 | import live2d.v3 as live2d 9 | from PyQt5.QtCore import Qt 10 | from live2d.utils.lipsync import WavHandler 11 | from live2d.v3 import StandardParams 12 | 13 | from services.live2d.opengl_canvas import OpenGLCanvas 14 | 15 | 16 | class Live2DCanvas(OpenGLCanvas): 17 | def __init__(self, path: str, lip_sync_n: int = 3, win_size: Tuple[int, int] = (1920, 1080)): 18 | super().__init__() 19 | self.setFixedSize(*win_size) 20 | self._model_path = path 21 | self._lipSyncN = lip_sync_n 22 | 23 | self.wavHandler = WavHandler() 24 | self.model: None | live2d.LAppModel = None 25 | self.setWindowTitle("Live2DCanvas") 26 | self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground) 27 | self.radius_per_frame = math.pi * 0.5 / 120 28 | self.total_radius = 0 29 | 30 | def on_init(self): 31 | live2d.glewInit() 32 | self.model = live2d.LAppModel() 33 | # LoadModelJson handles the model JSON for either Live2D version, 34 | # so no version branch is needed. 35 | self.model.LoadModelJson(self._model_path) 36 | self.startTimer(int(1000 / 120)) 37 | 38 | def timerEvent(self, a0): 39 | self.update() 40 | 41 | def on_draw(self): 42 | live2d.clearBuffer() 43 | if self.wavHandler.Update(): 44 | # Drive the mouth open/close parameter with the WAV loudness (RMS). 45 | self.model.SetParameterValue( 46 | StandardParams.ParamMouthOpenY, self.wavHandler.GetRms() * self._lipSyncN 47 | ) 48 | self.model.Update() 49 | self.model.Draw() 50 | 51 | def on_resize(self, width: int, height: int): 52 | self.model.Resize(width, height) 53 | -------------------------------------------------------------------------------- /services/browser/browser.py: -------------------------------------------------------------------------------- 1 | from selenium.webdriver import Firefox, Chrome, Keys 2 | from selenium.webdriver.common.actions.action_builder import ActionBuilder 3 | from selenium.webdriver.common.by import By 4 | 5 | from services.browser.config import BrowserConfig 6 | from services.browser.driver import DriverInitializer 7 | 8 | 9 | class Browser: 10 | def __init__(self, config: BrowserConfig): 11 | self._initzr = DriverInitializer(config) 12 | self._driver: Firefox | Chrome | None = None 13 | 14 | @property 15 | def driver(self): 16 | if self._driver is None: 17 | self._driver = self._initzr.get_driver() 18 | return self._driver 19 | 20 | def open(self, url: str): 21 | self.driver.get(url) 22 | 23 | def close(self): 24 | self.driver.close() 25 | 26 | def page_source(self): 27 | return self.driver.page_source 28 | 29 | def move_to_search_box(self): 30 | # Assuming the location coordinates of the search box are known (example) 31 | search_box_x = 750 32 | search_box_y = 400 33 | # Use ActionBuilder to move the mouse to the given position; it must act 34 | # on the instance's driver, not the driver module. 35 | action_builder = ActionBuilder(self.driver) 36 | action_builder.pointer_action.move_to_location(x=search_box_x, y=search_box_y) 37 | action_builder.perform() 38 | 39 | # Enter a specified character in the search box (example) 40 | def
send_keys_and_enter(self, keys): 41 | action_builder = ActionBuilder(self.driver) 42 | action_builder.key_action.send_keys(keys) 43 | action_builder.key_action.key_down(Keys.SPACE) 44 | action_builder.key_action.key_down(Keys.ENTER) 45 | action_builder.perform() 46 | 47 | def search(self, text: str): 48 | sb_form = self.driver.find_element(By.ID, 'sb_form_q') 49 | sb_form.send_keys(text) 50 | -------------------------------------------------------------------------------- /pipeline/base/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | from pipeline.asr.config import ASRPipelineConfig 4 | from pipeline.db.milvus.config import VectorDBConfig 5 | from pipeline.imgcap.config import ImgCapPipelineConfig 6 | from pipeline.llm.config import LLMPipelineConfig 7 | from pipeline.ocr.config import OCRPipelineConfig 8 | from pipeline.tts.config import TTSPipelineConfig 9 | from pipeline.vidcap.config import VidCapPipelineConfig 10 | from pipeline.vla.config import VLAPipelineConfig 11 | 12 | 13 | class PipelineConfig(BaseModel): 14 | asr: ASRPipelineConfig = Field(default=ASRPipelineConfig(), 15 | description="Configuration for the Automatic Speech Recognition pipeline.") 16 | llm: LLMPipelineConfig = Field(default=LLMPipelineConfig(), 17 | description="Configuration for the Large Language Model pipeline.") 18 | img_cap: ImgCapPipelineConfig = Field(default=ImgCapPipelineConfig(), 19 | description="Configuration for the Image Captioning pipeline.") 20 | ocr: OCRPipelineConfig = Field(default=OCRPipelineConfig(), 21 | description="Configuration for the Optical Character Recognition pipeline.") 22 | vid_cap: VidCapPipelineConfig = Field(default=VidCapPipelineConfig(), 23 | description="Configuration for the Video Captioning pipeline.") 24 | tts: TTSPipelineConfig = Field(default=TTSPipelineConfig(), 25 | description="Configuration for the Text-to-Speech pipeline.") 26 | vla: VLAPipelineConfig = Field(default=VLAPipelineConfig(), 27 | description="Configuration for the Visual Language Action pipeline.") 28 | vec_db: VectorDBConfig = Field(default=VectorDBConfig(), description="Configuration for the Vector Database.") 29 | -------------------------------------------------------------------------------- /manager/llm_prompt_manager.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from zerolan.data.pipeline.llm import Conversation, RoleEnum 4 | 5 | from character.config import ChatConfig 6 | 7 | 8 | class LLMPromptManager: 9 | def __init__(self, config: ChatConfig): 10 | self.system_prompt: str = config.system_prompt 11 | self.injected_history: list[Conversation] = self._parse_history_list(config.injected_history, 12 | self.system_prompt) 13 | self.current_history: list[Conversation] = deepcopy(self.injected_history) 14 | self.max_history = config.max_history 15 | 16 | def reset_history(self, history: list[Conversation]) -> None: 17 | """ 18 | Resets `current_history` to a deep copy of `history`. 19 | If `history` is None or longer than `max_history`, falls back to a 20 | deep copy of the `injected_history` from the config file.
21 | :param history: List of Conversation instances. 22 | :return: None 23 | """ 24 | if history is None: 25 | self.current_history = deepcopy(self.injected_history) 26 | else: 27 | if len(history) <= self.max_history: 28 | self.current_history = deepcopy(history) 29 | else: 30 | self.current_history = deepcopy(self.injected_history) 31 | 32 | @staticmethod 33 | def _parse_history_list(history: list[str], system_prompt: str | None = None) -> list[Conversation]: 34 | result = [] 35 | 36 | if system_prompt is not None: 37 | result.append(Conversation(role=RoleEnum.system, content=system_prompt)) 38 | 39 | # Injected history alternates user/assistant turns, starting with the user. 40 | for idx, content in enumerate(history): 41 | role = RoleEnum.user if idx % 2 == 0 else RoleEnum.assistant 42 | conversation = Conversation(role=role, content=content) 43 | result.append(conversation) 44 | 45 | return result 46 | -------------------------------------------------------------------------------- /tests/test_event_emitter.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from asyncio import TaskGroup 4 | from dataclasses import dataclass 5 | 6 | import aiohttp 7 | import pytest 8 | 9 | from event.event_data import BaseEvent 10 | from event.event_emitter import emitter 11 | 12 | 13 | class TestEvent(BaseEvent): 14 | content: str 15 | type = "test.run_forever" 16 | 17 | 18 | @dataclass 19 | class ConnTest(BaseEvent): 20 | content: str 21 | type = "test.conn" 22 | 23 | 24 | @emitter.on("test.run_forever") 25 | async def some_task(event: TestEvent): 26 | i = 0 27 | while True: 28 | await asyncio.sleep(1) 29 | print(f"Async: {i} {event.content}") 30 | i += 1 31 | 32 | 33 | @emitter.on("test.conn") 34 | async def connections(event: ConnTest): 35 | async with aiohttp.ClientSession(base_url="http://127.0.0.1") as session: 36 | async with TaskGroup() as tg: 37 | tg.create_task(session.get("/asdasd")) 38 | tg.create_task(session.get("/asdasd")) 39 | tg.create_task(session.get("/asdasd")) 40 | tg.create_task(session.get("/asdasd")) 41 | tg.create_task(session.get("/asdasd")) 42 | tg.create_task(session.get("/asdasd")) 43 | tg.create_task(session.get("/asdasd")) 44 | print(event.content) 45 | 46 | 47 | @emitter.on("test.run_forever") 48 | def some_sync_task(event: TestEvent): 49 | i = 0 50 | while True: 51 | time.sleep(0.5) 52 | print(f"Sync: {i} {event.content}") 53 | i += 1 54 | 55 | 56 | @emitter.on("test.conn") 57 | async def run_once(_): 58 | print("You should see this content only once!") 59 | 60 | 61 | @pytest.mark.asyncio 62 | async def test_event_emitter(): 63 | async with asyncio.TaskGroup() as tg: 64 | tg.create_task(emitter.start()) 65 | emitter.emit(TestEvent(content="Ciallo")) 66 | emitter.emit(ConnTest(content="Ciallo")) 67 | await asyncio.sleep(1) 68 | await emitter.stop() 69 | -------------------------------------------------------------------------------- /tests/pipeline/test_abs.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | 4 | from zerolan.data.pipeline.llm import LLMQuery 5 | 6 | from common.concurrent.killable_thread import KillableThread 7 | from pipeline.server import TestServer 8 | from pipeline.base.base_sync import CommonModelPipeline, AbstractPipelineConfig, PipelineDisabledException 9 | 10 | base_url = "http://127.0.0.1:5889" 11 | 12 | # Start the test server before the tests run 13 | test_server = TestServer() 14 | test_server.init() 15 | thread = KillableThread(target=test_server.start, daemon=True) 16 | thread.start() 17
| time.sleep(2) 18 | 19 | 20 | class MyPipelineConfig(AbstractPipelineConfig): 21 | model_id: str = "test-llm-model" 22 | predict_url: str = f"{base_url}/llm/predict" 23 | stream_predict_url: str = f"{base_url}/llm/stream-predict" 24 | 25 | 26 | def test_disabled_common_model_pipeline(): 27 | cfg = MyPipelineConfig( 28 | enable=False, 29 | model_id="test-llm-model", 30 | predict_url=f"{base_url}/llm/predict", 31 | stream_predict_url=f"{base_url}/llm/stream-predict", 32 | ) 33 | try: 34 | CommonModelPipeline(cfg) 35 | except Exception as e: 36 | assert isinstance(e, PipelineDisabledException), "Test passed" 37 | else: 38 | assert False, "Test failed" 39 | 40 | 41 | def test_common_model_pipeline_predict(): 42 | id = str(uuid.uuid4()) 43 | 44 | p = CommonModelPipeline(MyPipelineConfig()) 45 | r = p.predict(LLMQuery(id=id, text="Test", history=[])) 46 | # Common model pipeline will not auto convert AbstractModelPrediction to LLMPrediction 47 | assert id in r.model_dump_json() 48 | 49 | 50 | def test_common_model_pipeline_stream_predict(): 51 | id = str(uuid.uuid4()) 52 | 53 | p = CommonModelPipeline(MyPipelineConfig()) 54 | # Common model pipeline will not auto convert AbstractModelPrediction to LLMPrediction 55 | for r in p.stream_predict(LLMQuery(id=id, text="Test", history=[])): 56 | print(r) 57 | 58 | 59 | def test_stop_test_server(): 60 | thread.kill() 61 | -------------------------------------------------------------------------------- /tests/resources/text.txt: -------------------------------------------------------------------------------- 1 | 爱丽丝是才羽桃井和才羽绿在废墟中找到的机器人。为了让游戏开发部拥有足够数量的成员,两人拜托贝里塔斯录入了爱丽丝的学生信息,使她成为了千年的学生之一。爱丽丝的制服是绿之前穿过的。 2 | 3 | 因为没有记忆,爱丽丝的语言是通过桃井和绿给的游戏中的人物对话学习而来,导致说话偶尔会蹦出「邦邦卡邦」的游戏音效,甚至对现实世界的认知与 RPG 有一定融合,比如翻垃圾桶找稀有道具,小桃小绿害人不浅啊。其中,「邦邦卡邦」的口癖和攻击时「光啊——」(HikariYo——)的口癖受相当一部分老师的喜爱。 4 | 5 | 在第二章剧情中,爱丽丝被称作「Aris」,而她的机身铭文、废墟安保系统的电脑与G·Bible将她称作「AL-1S」,但国际服英文界面仍将她称为「Alice」。然而关於姓氏「天童」,由来以及由谁赋予,游戏内并没有做出解释。 6 | 7 | 根据白石咏叶就爱丽丝使用光之剑的状态推测,爱丽丝的握力不小于1吨。 8 | 9 | 体内有另一个人格:Kei,负责在AL-1S苏醒后使其转变为“无名诸神的王女”(启动器)。Kei支配身体时会变成紫红瞳。在最终章中,Kei自愿承接了爱丽丝操作“生命守护者”造成的巨大反噬力而消失。但在部分二创中,Kei的人格并未消失,而是被日鞠储存在爱丽丝制作的小机器人里面。 10 | 11 | 爱丽丝本是才羽桃井和才羽绿在千年废墟中找到的机器人。为了让游戏开发部拥有足够数量的成员,才羽姐妹拜托贝里塔斯录入了爱丽丝的学生信息,使她成为了千年的学生之一,并顺势加入游戏开发部,凑齐了社团最低限制人数。 12 | 13 | 在一周目的时间线上,爱丽丝被黑色西装人等盖玛特里亚成员从千年郊外废墟找到,唤醒了她的“无名诸神的王女”人格,造成她彻底暴走变成了盖玛特里亚的究极武器,彻底摧毁了千年学院,进而推进了整个基沃托斯的毁灭。 14 | 15 | 加入游戏开发部后,爱丽丝在才羽姐妹的“教育”下学会了人类社会(游戏世界)的常识,从工程师部处得到了武器(猴王龙宫寻神兵)「宇宙战舰主炮」(光之剑·超新星),勉强瞒过了优香。随后,为了让游戏开发部正式免遭废部,爱丽丝与同伴们连同工程师部、贝里塔斯潜入研讨会大楼,窃取贝里塔斯部长明星日鞠开发的解码器「镜子」,以解读从千年废墟找到的游戏开发秘籍。在窃取解码器「镜子」时爱丽丝使用光之剑对C&C进行远程火力压制,让C&C的王牌「00」美甘宁瑠对其产生了兴趣。 16 | 17 | 得到解码器后,虽然破解出的秘籍仅仅是「请热爱游戏吧」这句话,游戏开发部终于制作出了足以参加千年大赏的作品:《故事·传说·编年史2》并及时报名。报名完成后,游戏开发部活动室遭到了宁瑠的袭击:她想要亲眼验证爱丽丝的战斗力。爱丽丝的重炮不适合近身战,但还是拼着自身损伤48%的重伤把宁瑠轰进瓦砾堆,让宁瑠停止了追击。 18 | 19 | 直到获奖名单宣读完毕,游戏开发部也没听到自己的作品进入前七名的消息。虽然依依不舍,但游戏开发部废部已成定局,爱丽丝则不得不随老师去夏莱住——然而优香带着好消息来到游戏开发部活动室:她们的作品得到了千年大赏特别奖,游戏开发部得以免除废部的命运,但爱丽丝就不会和老师一起住在夏莱了,真是遗憾。爱丽丝也和宁瑠不打不相识成为了游戏好友。尽管如此,爱丽丝还是对宁瑠产生了深深的心理阴影——甚至包括女仆装。 20 | 21 | 一段时间后,爱丽丝已经融入了千年的生活。但在接触到无名诸神的机器人后,爱丽丝体内的“Kei”的人格被唤醒。Kei被唤醒后控制了爱丽丝对在场的游戏开发部等人进行了攻击,桃井被击伤昏迷,自己也被研讨会会长调月莉音带往堡垒都市埃里都进行羁押。在游戏开发部的同伴们、C&C、工程师部、贝里塔斯前往埃里都迎战调月莉音救回爱丽丝时,Kei持续劝爱丽丝接受“成为无名诸神的王女”的命运,并操作大批无名诸神的机器人骇入埃里都,意图将埃里都改造为阿特拉哈西斯方舟。但最后桃井等人将意识接入爱丽丝,让爱丽丝明白自己可以选择“成为勇者”,Kei也放弃了将爱丽丝强制转化为王女的想法,暂时沉默。 22 | 23 | 
最终章中,游戏开发部、空崎阳奈、Rabbit小队合作攻打史林匹亚游乐园的虚伪圣所守护者。战前,爱丽丝说老师提到过阳奈的事情,向阳奈说了不少老师对阳奈的夸赞,让阳奈有些开心,而桃井和绿则在一旁嘀咕让阳奈没那么开心的印象。击败虚伪圣所守护者后,爱丽丝决意与同伴们一起登上生命守护者飞船,迎战阿特拉哈西斯方舟。在飞船上,爱丽丝安慰了一度伤害过自己、发现自己的错误后引咎辞职(畏罪潜逃)、用无人机参加作战的调月莉音会长。为了攻入阿特拉哈西斯方舟,Kei自愿承接了爱丽丝操作“生命守护者”造成的巨大反噬力而消失。返回地面后,爱丽丝根据Kei消失时自己的印象做了个小机器人,并委托明星日鞠向其内部导入桃井游戏机中疑似Kei的数据。 24 | 25 | 来源:https://mzh.moegirl.org.cn/%E5%A4%A9%E7%AB%A5%E7%88%B1%E4%B8%BD%E4%B8%9D -------------------------------------------------------------------------------- /services/browser/driver.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import os 3 | 4 | from loguru import logger 5 | from selenium import webdriver 6 | from selenium.webdriver.chrome.webdriver import WebDriver as ChromeWebDriver 7 | from selenium.webdriver.firefox.webdriver import WebDriver as FireFoxWebDriver 8 | 9 | from common.io.file_sys import fs 10 | from services.browser.config import BrowserConfig 11 | 12 | 13 | class DriverInitializer: 14 | 15 | def __init__(self, config: BrowserConfig = BrowserConfig()) -> None: 16 | self._browser = config.driver 17 | self._profile_dir = config.profile_dir 18 | self._driver = None 19 | 20 | def find_firefox_profile_dir(self): 21 | # Auto-search only when the config does not point at an existing directory. 22 | if self._profile_dir is None or not os.path.exists(self._profile_dir): 23 | logger.info("Searching for Firefox profiles...") 24 | username = getpass.getuser() 25 | logger.debug(f"Current user: {username}") 26 | if os.name == 'nt': 27 | default_profile_dir = f"C:/Users/{username}/AppData/Roaming/Mozilla/Firefox/Profiles" 28 | self._profile_dir = fs.find_dir(default_profile_dir, ".default-release") 29 | 30 | if self._profile_dir is None: 31 | raise Exception("Cannot find Firefox profiles. Please set the directory in your config manually.") 32 | else: 33 | logger.info(f"Probable Firefox profile directory: {self._profile_dir}") 34 | 35 | def load_firefox_driver(self): 36 | profile = webdriver.FirefoxProfile(self._profile_dir) 37 | options = webdriver.FirefoxOptions() 38 | options.profile = profile 39 | driver = webdriver.Firefox(options=options) 40 | 41 | logger.info("Firefox driver loaded") 42 | self._driver = driver 43 | 44 | def get_driver(self) -> FireFoxWebDriver | ChromeWebDriver: 45 | if self._browser == "chrome": 46 | raise NotImplementedError() 47 | elif self._browser == "firefox": 48 | self.load_firefox_driver() 49 | 50 | return self._driver 51 | -------------------------------------------------------------------------------- /pipeline/llm/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | from pipeline.base.base_sync import AbstractPipelineConfig 6 | 7 | 8 | ####### 9 | # LLM # 10 | ####### 11 | 12 | class LLMModelIdEnum(BaseEnum): 13 | DeepSeekAPI: str = "deepseek-chat" 14 | KimiAPI: str = "moonshot-v1-8k" 15 | 16 | ChatGLM3_6B: str = "THUDM/chatglm3-6b" 17 | GLM4: str = "THUDM/GLM-4" 18 | Qwen_7B_Chat: str = "Qwen/Qwen-7B-Chat" 19 | Shisa_7b_V1: str = "augmxnt/shisa-7b-v1" 20 | Yi_6B_Chat: str = "01-ai/Yi-6B-Chat" 21 | DeepSeek_R1_Distill: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" 22 | 23 | 24 | class LLMPipelineConfig(AbstractPipelineConfig): 25 | api_key: str | None = Field(default=None, description="The API key for accessing the LLM service.
 \n" 26 | "Kimi API supported: \n" 27 | "Reference: https://platform.moonshot.cn/docs/guide/start-using-kimi-api \n" 28 | "DeepSeek API supported: \n" 29 | "Reference: https://api-docs.deepseek.com/zh-cn/") 30 | openai_format: bool = Field(default=False, description="Whether the output format is compatible with OpenAI. \n" 31 | f"Note: When you use `{LLMModelIdEnum.DeepSeekAPI}` or `{LLMModelIdEnum.KimiAPI}`, please set it to `true`.") 32 | model_id: LLMModelIdEnum = Field(default=LLMModelIdEnum.GLM4, 33 | description=f"The ID of the model used for LLM. \n{enum_to_markdown(LLMModelIdEnum)}") 34 | predict_url: str = Field(default="http://127.0.0.1:11000/llm/predict", 35 | description="The URL for LLM prediction requests.") 36 | stream_predict_url: str = Field(default="http://127.0.0.1:11000/llm/stream-predict", 37 | description="The URL for streaming LLM prediction requests.") 38 | -------------------------------------------------------------------------------- /pipeline/base/base_async.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Any, AsyncGenerator 3 | 4 | import aiohttp 5 | import urllib3.util 6 | from aiohttp import ClientResponse 7 | from typeguard import typechecked 8 | from zerolan.data.pipeline.abs_data import AbsractImageModelQuery 9 | 10 | 11 | @typechecked 12 | def get_base_url(url: str) -> str: 13 | uri = urllib3.util.parse_url(url) 14 | base_url = f"{uri.scheme}://{uri.host}:{uri.port}" 15 | return base_url 16 | 17 | 18 | class BaseAsyncPipeline: 19 | 20 | def __init__(self, base_url: str): 21 | self._base_url = base_url 22 | self._session: aiohttp.ClientSession | None = None 23 | 24 | @property 25 | def session(self): 26 | # Create the client session lazily so the pipeline can be constructed 27 | # outside a running event loop. 28 | if self._session is not None: 29 | return self._session 30 | self._session = aiohttp.ClientSession(self._base_url) 31 | return self._session 32 | 33 | async def _dispose_client_session(self): 34 | if self._session is None: 35 | return 36 | await self._session.close() 37 | 38 | async def close(self): 39 | await self._dispose_client_session() 40 | 41 | 42 | @typechecked 43 | def _parse_imgcap_query(query: AbsractImageModelQuery) -> Dict[str, Any]: 44 | # If `query.img_path` exists on the local machine, read the image and send 45 | # its bytes together with the query as form data. 46 | if os.path.exists(query.img_path): 47 | query.img_path = os.path.abspath(query.img_path).replace('\\', '/') 48 | with open(query.img_path, 'rb') as img: 49 | data = {'image': img.read(), 50 | 'json': query.model_dump_json()} 51 | return data 52 | # If `query.img_path` does not exist on the local machine, it must exist on the remote host. 53 | # Note: If the remote host does not have this file either, a 500 error is raised!
54 | else: 55 | return query.model_dump() 56 | 57 | 58 | @typechecked 59 | async def stream_generator(response: ClientResponse, chunk_size: int = -1) -> AsyncGenerator[bytes, None]: 60 | # chunk_size == -1 reads whatever the transport currently has buffered. 61 | while True: 62 | chunk = await response.content.read(chunk_size) 63 | if not chunk: 64 | break 65 | yield chunk 66 | -------------------------------------------------------------------------------- /pipeline/asr/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field, BaseModel 2 | 3 | from common.enumerator import BaseEnum 4 | from common.utils.enum_util import enum_to_markdown 5 | from pipeline.base.base_sync import AbstractPipelineConfig 6 | 7 | 8 | ####### 9 | # ASR # 10 | ####### 11 | 12 | class AudioFormatEnum(BaseEnum): 13 | Float32: str = "float32" 14 | 15 | 16 | class ASRModelIdEnum(BaseEnum): 17 | Paraformer = "iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1" 18 | KotobaWhisper = 'kotoba-tech/kotoba-whisper-v2.0' 19 | BaiduASR = "BaiduASR" 20 | 21 | 22 | class BaiduASRConfig(BaseModel): 23 | api_key: str = Field(default="", description="The API key for Baidu ASR service.") 24 | secret_key: str = Field(default="", description="The secret key for Baidu ASR service.") 25 | 26 | 27 | class ASRPipelineConfig(AbstractPipelineConfig): 28 | sample_rate: int = Field(16000, description="The sample rate for audio input.") 29 | channels: int = Field(1, description="The number of audio channels.") 30 | format: AudioFormatEnum = Field(AudioFormatEnum.Float32, 31 | description=f"The format of the audio data. {enum_to_markdown(AudioFormatEnum)}") 32 | model_id: ASRModelIdEnum = Field(default=ASRModelIdEnum.Paraformer, 33 | description=f"The ID of the model used for ASR. \n{enum_to_markdown(ASRModelIdEnum)}") 34 | predict_url: str = Field(default="http://127.0.0.1:11000/asr/predict", 35 | description="The URL for ASR prediction requests.") 36 | stream_predict_url: str = Field(default="http://127.0.0.1:11000/asr/stream-predict", 37 | description="The URL for streaming ASR prediction requests.") 38 | baidu_asr_config: BaiduASRConfig = Field(default=BaiduASRConfig(), description="Baidu ASR config. \n"
39 | f"Only edit it when you set `model_id` to `{ASRModelIdEnum.BaiduASR.value}`.\n" 40 | f"For more details please see the [documents](https://cloud.baidu.com/doc/SPEECH/s/qlcirqhz0).") 41 | -------------------------------------------------------------------------------- /pipeline/tts/tts_async.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import AsyncGenerator 4 | 5 | from typeguard import typechecked 6 | from zerolan.data.pipeline.tts import TTSQuery, TTSPrediction, TTSStreamPrediction 7 | 8 | from pipeline.base.base_async import BaseAsyncPipeline, stream_generator, get_base_url 9 | from pipeline.tts.config import TTSPipelineConfig, TTSModelIdEnum 10 | 11 | 12 | def _parse_tts_query(query: TTSQuery) -> TTSQuery: 13 | # Normalize a local reference-audio path to an absolute POSIX-style path. 14 | if os.path.exists(query.refer_wav_path): 15 | query.refer_wav_path = os.path.abspath(query.refer_wav_path).replace("\\", "/") 16 | return query 17 | 18 | 19 | class TTSAsyncPipeline(BaseAsyncPipeline): 20 | def __init__(self, config: TTSPipelineConfig): 21 | super().__init__(base_url=get_base_url(config.predict_url)) 22 | self._model_id: TTSModelIdEnum = config.model_id 23 | self._predict_endpoint = "/tts/predict" 24 | self._stream_predict_endpoint = "/tts/stream-predict" 25 | 26 | @typechecked 27 | async def predict(self, query: TTSQuery) -> TTSPrediction: 28 | query = _parse_tts_query(query) 29 | async with self.session.post(self._predict_endpoint, json=query.model_dump()) as resp: 30 | data = await resp.content.read() 31 | return TTSPrediction(wave_data=data, audio_type=query.audio_type) 32 | 33 | @typechecked 34 | async def stream_predict(self, query: TTSQuery) -> AsyncGenerator[TTSStreamPrediction, None]: 35 | query = _parse_tts_query(query) 36 | async with self.session.post(self._stream_predict_endpoint, json=query.model_dump()) as resp: 37 | last = 0 38 | id = str(uuid.uuid4()) 39 | idx = 0 40 | async for chunk in stream_generator(resp): 41 | last = idx 42 | yield TTSStreamPrediction(seq=idx, 43 | id=id, 44 | is_final=False, 45 | wave_data=chunk, 46 | audio_type=query.audio_type) 47 | idx += 1 48 | # Terminal marker so consumers know the stream has ended. 49 | yield TTSStreamPrediction(is_final=True, seq=last + 1, audio_type=query.audio_type, wave_data=b'') 50 | -------------------------------------------------------------------------------- /agent/custom_agent.py: -------------------------------------------------------------------------------- 1 | # See: 2 | # https://python.langchain.com/docs/tutorials/agents/ 3 | from injector import inject 4 | from langchain_core.messages import HumanMessage, AIMessage 5 | from langchain_core.tools import BaseTool 6 | from loguru import logger 7 | from selenium.webdriver import Firefox, Chrome 8 | 9 | from agent.tool.go_creator import GameObjectCreator 10 | from agent.tool.lang_changer import LangChanger 11 | from agent.tool.microphone_tool import MicrophoneTool 12 | from agent.tool.web_search import BaiduBaikeTool, MoeGirlTool 13 | from agent.tool_agent import ToolAgent 14 | from services.playground.bridge import PlaygroundBridge 15 | from pipeline.llm.config import LLMPipelineConfig 16 | 17 | 18 | class CustomAgent: 19 | 20 | @inject 21 | def __init__(self, config: LLMPipelineConfig, driver: Firefox | Chrome = None, 22 | bridge: PlaygroundBridge | None = None): 23 | self._model = ToolAgent(config=config) 24 | # Register additional tools here 25 | tools = [BaiduBaikeTool(), LangChanger(), MicrophoneTool()] 26 | if driver is not None: 27 | tools.append(MoeGirlTool(driver)) 28 | if bridge is not None: 29 |
tools.append(GameObjectCreator()) 30 | self._tools = {} 31 | self._model.bind_tools(tools) 32 | for tool in tools: 33 | self._tools[tool.name] = tool 34 | 35 | def run(self, query: str) -> bool: 36 | messages = [self._model.system_prompt, HumanMessage(query)] 37 | ai_msg: AIMessage = self._model.invoke(messages) 38 | messages.append(ai_msg) 39 | if len(ai_msg.tool_calls) == 0: 40 | logger.debug("No tool to call in this conversation") 41 | return False 42 | for tool_call in ai_msg.tool_calls: 43 | tool_name = tool_call["name"].lower() 44 | selected_tool: BaseTool = self._tools.get(tool_name, None) 45 | if selected_tool is not None: 46 | try: 47 | tool_msg = selected_tool.invoke(tool_call) 48 | messages.append(tool_msg) 49 | except Exception as e: 50 | logger.warning(e) 51 | 52 | return True 53 | -------------------------------------------------------------------------------- /common/concurrent/abs_runnable.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import uuid 3 | from abc import abstractmethod 4 | from typing import Dict, List 5 | 6 | from loguru import logger 7 | 8 | 9 | class AsyncRunnable: 10 | 11 | def __init__(self): 12 | self._activate: bool = False 13 | self.id: str = str(uuid.uuid4()) 14 | 15 | @abstractmethod 16 | def name(self): 17 | return self.id 18 | 19 | @abstractmethod 20 | async def start(self): 21 | self._activate = True 22 | add_runnable(self) 23 | 24 | def activate_check(self): 25 | if not self._activate: 26 | raise RuntimeError("This runnable object is not activated. Call `start()` first.") 27 | 28 | @abstractmethod 29 | async def stop(self): 30 | self._activate = False 31 | 32 | 33 | class ThreadRunnable: 34 | 35 | def __init__(self): 36 | self._activate: bool = False 37 | self.id: str = str(uuid.uuid4()) 38 | 39 | @abstractmethod 40 | def name(self): 41 | return self.id 42 | 43 | @abstractmethod 44 | def start(self): 45 | self._activate = True 46 | add_runnable(self) 47 | 48 | def activate_check(self): 49 | if not self._activate: 50 | raise RuntimeError("This runnable object is not activated.
Call `start()` first.") 51 | 52 | @abstractmethod 53 | def stop(self): 54 | self._activate = False 55 | 56 | 57 | # All runnable components should be registered here when the `start` method is called. 58 | _all: Dict[str, AsyncRunnable | ThreadRunnable] = {} 59 | _ids: List[str] = [] 60 | 61 | 62 | def add_runnable(run: AsyncRunnable | ThreadRunnable): 63 | _all[run.id] = run 64 | _ids.append(run.id) 65 | logger.debug(f"Runnable {run.name()}: {run.id}") 66 | 67 | 68 | async def stop_all_runnable(): 69 | """ 70 | Force stop the operation of all runnable components. 71 | """ 72 | global _all 73 | ids = _ids.copy() 74 | ids.reverse() 75 | 76 | for id in ids: 77 | run = _all.pop(id, None) 78 | if run is None: 79 | logger.warning(f"Runnable does not exist: {id}") 80 | continue 81 | result = run.stop() 82 | # ThreadRunnable.stop() is synchronous; only await coroutine results. 83 | if inspect.isawaitable(result): 84 | await result 85 | logger.debug(f"Runnable {run.name()}({id}): killed.") 86 | -------------------------------------------------------------------------------- /tests/test_agent.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from asyncio import TaskGroup 3 | 4 | from zerolan.data.pipeline.llm import Conversation, RoleEnum 5 | 6 | from agent.api import summary, summary_history, model_scale 7 | from agent.custom_agent import CustomAgent 8 | from manager.config_manager import get_config 9 | from services.playground.data import GameObject 10 | 11 | _config = get_config() 12 | 13 | 14 | def test_summary(): 15 | with open("./resources/text.txt", mode="r", encoding="utf-8") as f: 16 | text = f.read() 17 | print(text) 18 | 19 | summary_text = summary(text, 100) 20 | print(summary_text.content) 21 | 22 | 23 | def test_summary_history(): 24 | history = [ 25 | Conversation(role=RoleEnum.user, content="我最喜欢吃的东西是阿米糯司,你喜欢吃吗?"), 26 | Conversation(role=RoleEnum.assistant, content="阿米糯司?那是什么东西啊?听起来是一个寿司的名称……"), 27 | Conversation(role=RoleEnum.user, content="差不多吧。它是一种用糯米制作的东西,可好吃了。"), 28 | Conversation(role=RoleEnum.assistant, content="真的吗?那我下次也要尝尝阿米糯司!") 29 | ] 30 | summaried_history = summary_history(history) 31 | print(summaried_history.content) 32 | 33 | 34 | def test_model_scale(): 35 | go_info_json = [{ 36 | "instance_id": 42, 37 | "game_object_name": "白子", 38 | "transform": { 39 | "scale": 1, 40 | "position": { 41 | "x": 15.3, 42 | "y": -3.7, 43 | "z": 42.1 44 | } 45 | } 46 | }, { 47 | "instance_id": 1526, 48 | "game_object_name": "优香", 49 | "transform": { 50 | "scale": 1, 51 | "position": { 52 | "x": 5.3, 53 | "y": -3.7, 54 | "z": 2.1 55 | } 56 | } 57 | }] 58 | 59 | go_info = [] 60 | for info in go_info_json: 61 | go_info.append(GameObject.model_validate(info)) 62 | 63 | result = model_scale(go_info, "帮我放大一下优香") 64 | print(f"{result.instance_id}: {result.target_scale}") 65 | 66 | 67 | async def atest_custom_agent(): 68 | async with TaskGroup() as tg: 69 | await asyncio.sleep(5) 70 | agent = CustomAgent(_config.pipeline.llm) 71 | agent.run("创建一个绿色的立方体") 72 | 73 | 74 | def test_custom_agent(): 75 | asyncio.run(atest_custom_agent()) 76 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | from character.config import CharacterConfig 4 | from pipeline.base.config import PipelineConfig 5 | from services.config import ServiceConfig 6 | 7 | 8 | class SystemConfig(BaseModel): 9 | default_enable_microphone: bool = Field(default=False, 10 | description="For safety,
11 |                                                         "You can set it to `True` to enable your microphone.")
12 | 
13 | 
14 | class ZerolanLiveRobotConfig(BaseModel):
15 |     pipeline: PipelineConfig = Field(default=PipelineConfig(),
16 |                                      description="Configuration for the pipeline settings. \n"
17 |                                                  "The pipeline is the key to connecting to `ZerolanCore`, \n"
18 |                                                  "which typically accesses the model via HTTP or HTTPS requests and gets a response from the model. \n"
19 |                                                  "> [!NOTE] \n"
20 |                                                  "> 1. At a minimum, you need to enable the LLMPipeline. \n"
21 |                                                  "> 2. ZerolanCore is distributed, and you can deploy different models to different servers. Just set different URLs to connect to your models. \n"
22 |                                                  "> 3. If your server can only open one port, try forwarding your network requests using [Nginx](https://nginx.org/en/).")
23 |     service: ServiceConfig = Field(default=ServiceConfig(),
24 |                                    description="Configuration for the service settings. \n"
25 |                                                "The services are usually opened locally, \n"
26 |                                                "and instances of other projects establish WebSocket or HTTP connections with the service, \n"
27 |                                                "and the service controls the behavior of its sub-project instances.")
28 |     character: CharacterConfig = Field(default=CharacterConfig(),
29 |                                        description="Configuration for the character settings.")
30 |     system: SystemConfig = Field(default=SystemConfig(), description="Configuration for the system settings.")
--------------------------------------------------------------------------------
/services/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | from services.browser.config import BrowserConfig
4 | from services.game.config import GameBridgeConfig
5 | from services.live2d.config import Live2DViewerConfig
6 | from services.live_stream.config import LiveStreamConfig
7 | from services.obs.config import ObsStudioClientConfig
8 | from services.playground.config import PlaygroundBridgeConfig
9 | from services.playground.res.config import ResourceServerConfig
10 | from services.qqbot.config import QQBotBridgeConfig
11 | 
12 | 
13 | class ServiceConfig(BaseModel):
14 |     res_server: ResourceServerConfig = Field(default=ResourceServerConfig(),
15 |                                              description="Configuration for the Resource Server.")
16 |     live_stream: LiveStreamConfig = Field(default=LiveStreamConfig(),
17 |                                           description="Configuration for the Live Stream service.")
18 |     game: GameBridgeConfig = Field(default=GameBridgeConfig(), description="Configuration for the Game Bridge service.")
19 |     playground: PlaygroundBridgeConfig = Field(default=PlaygroundBridgeConfig(),
20 |                                                description="Configuration for the Playground Bridge service.")
21 |     qqbot: QQBotBridgeConfig = Field(default=QQBotBridgeConfig(),
22 |                                      description="Configuration for the QQBot Bridge service.")
23 |     obs: ObsStudioClientConfig = Field(default=ObsStudioClientConfig(),
24 |                                        description="Configuration for the OBS Studio Client.")
25 |     browser: BrowserConfig = Field(default=BrowserConfig(), description="Browser config.")
26 |     live2d_viewer: Live2DViewerConfig = Field(default=Live2DViewerConfig(),
27 |                                               description="Configuration for the Live2DViewer service. "
28 |                                                           "[!Attention]\n"
29 |                                                           "1. When using OBS to capture the window, you should use GameSource. "
30 |                                                           "Then enable the `SLI/Cross` and `Allow window transparent` options, "
31 |                                                           "or the windows will not display.\n"
32 |                                                           "2. Using `Window capture` will leave a black background.")
--------------------------------------------------------------------------------
/common/web/zrl_ws.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | 
3 | from loguru import logger
4 | from websockets import ProtocolError
5 | from zerolan.data.protocol.protocol import ZerolanProtocol
6 | 
7 | from common.concurrent.abs_runnable import ThreadRunnable
8 | from common.web.json_ws import JsonWsServer
9 | 
10 | 
11 | ########################################
12 | #  Zerolan Protocol Web Socket Server  #
13 | #  Author: AkagawaTsurunaki            #
14 | ########################################
15 | 
16 | class ZerolanProtocolWsServer(ThreadRunnable):
17 |     def __init__(self, host: str, port: int):
18 |         super().__init__()
19 |         self._protocol = "ZerolanProtocol"
20 |         self._jws = JsonWsServer(host=host, port=port, subprotocols=[self._protocol])
21 |         self._version = "1.1"
22 | 
23 |     def name(self):
24 |         return "ZerolanProtocolWebsocket"
25 | 
26 |     def start(self):
27 |         super().start()
28 |         self._init()
29 |         self._jws.start()
30 | 
31 |     def stop(self):
32 |         super().stop()
33 |         self._jws.stop()
34 | 
35 |     @property
36 |     def is_connected(self):
37 |         if self._jws.connections > 0:
38 |             return True
39 |         return False
40 | 
41 |     def send(self, action: str, data: any, message: str = "", code: int = 0):
42 |         protocol = ZerolanProtocol(message=message,
43 |                                    code=code,
44 |                                    action=action,
45 |                                    data=data)
46 |         self._jws.send_json(protocol)
47 | 
48 |     def _init(self):
49 |         def on_json_msg(_, protocol: dict | list):
50 |             protocol = self._validate_zerolan_protocol(protocol)
51 |             logger.debug(f"Validated Zerolan Protocol {protocol.data}")
52 |             self.on_protocol(protocol)
53 | 
54 |         self._jws.on_msg_handlers += [on_json_msg]
55 | 
56 |     def _validate_zerolan_protocol(self, data: dict | list):
57 |         recv_obj = ZerolanProtocol.model_validate(data)
58 |         if recv_obj.protocol == self._protocol and recv_obj.version == self._version:
59 |             return recv_obj
60 |         raise ProtocolError("Invalid ZerolanProtocol")
61 | 
62 |     @abstractmethod
63 |     def on_protocol(self, protocol: ZerolanProtocol):
64 |         raise NotImplementedError()
65 | 
66 |     @abstractmethod
67 |     def on_disconnect(self, ws_id: str):
68 |         raise NotImplementedError()
--------------------------------------------------------------------------------
/common/generator/config_gen.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | from pydantic import BaseModel
4 | from pydantic.fields import FieldInfo
5 | from typeguard import typechecked
6 | 
7 | from common import ver_check
8 | from common.decorator import log_run_time
9 | from common.utils.time_util import get_time_iso_string
10 | 
11 | 
12 | class ConfigFileGenerator:
13 | 
14 |     def __init__(self, indent: int = 2):
15 |         self._yaml_str = ""
16 |         self._indent = indent
17 | 
18 |     def _get_indent(self, depth: int):
19 |         return " " * self._indent * depth
20 | 
21 |     def _add_comments(self, field_info: FieldInfo, depth: int):
22 |         if field_info.description:
23 |             for description_line in field_info.description.split("\n"):
24 |                 self._yaml_str += self._get_indent(depth) + f"# {description_line}\n"
25 | 
26 |     def _gen(self, model: BaseModel, depth: int = 0):
27 |         fields = model.model_fields
28 | 
29 |         for field_name, field_info in fields.items():
30 |             field_val = model.__getattribute__(field_name)
31 |             if isinstance(field_val, BaseModel):
32 |                 self._add_comments(field_info, depth)
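                # Added note (illustrative comment, not in the original source):
                # nested BaseModel fields recurse into `_gen` with depth + 1, which
                # widens the YAML indentation by one level per nesting.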
33 |                 self._yaml_str += self._get_indent(depth) + f"{field_name}:\n"
34 |                 self._gen(field_val, depth + 1)
35 |             else:
36 |                 self._add_comments(field_info, depth)
37 |                 if isinstance(type(field_val), type(Enum)):
38 |                     self._yaml_str += self._get_indent(depth) + f"{field_name}: '{field_val.value}'\n"
39 |                 elif isinstance(field_val, str):
40 |                     self._yaml_str += self._get_indent(depth) + f"{field_name}: '{field_val}'\n"
41 |                 else:
42 |                     self._yaml_str += self._get_indent(depth) + f"{field_name}: {field_val}\n"
43 | 
44 |     def _get_header(self):
45 |         generated_info = f"# This file was generated at {get_time_iso_string()} #"
46 |         header = "#" * len(generated_info) + "\n" \
47 |                  + generated_info + "\n" \
48 |                  + "#" * len(generated_info) + "\n"
49 | 
50 |         return header
51 | 
52 |     @log_run_time()
53 |     @typechecked
54 |     def generate_yaml(self, model: BaseModel):
55 |         """
56 |         Generate yaml from BaseModel instance.
57 |         :param model: An instance of BaseModel.
58 |         :return: Yaml string.
59 |         """
60 |         ver_check.check_pydantic_ver()
61 |         self._gen(model, depth=0)
62 |         return self._get_header() + "\n" + self._yaml_str
--------------------------------------------------------------------------------
/devices/speaker.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from enum import Enum
3 | from pathlib import Path
4 | from queue import Queue
5 | 
6 | import pygame
7 | 
8 | from common.concurrent.abs_runnable import ThreadRunnable
9 | from common.concurrent.killable_thread import KillableThread
10 | 
11 | pygame.mixer.init()
12 | _system_sound = False
13 | 
14 | 
15 | class SystemSoundEnum(str, Enum):
16 |     warn: str = "warn.wav"
17 |     error: str = "error.wav"
18 |     start: str = "start.wav"
19 |     exit: str = "exit.wav"
20 |     enable_func: str = "microphone-recoding.wav"
21 |     disable_func: str = "microphone-stopped.wav"
22 |     filtered: str = "filtered.wav"
23 | 
24 | 
25 | class Speaker(ThreadRunnable):
26 | 
27 |     def name(self):
28 |         return 'Speaker'
29 | 
30 |     def __init__(self):
31 |         super().__init__()
32 |         self._stop_flag = False
33 |         self._semaphore = threading.Event()
34 |         self._speaker_thread = KillableThread(target=self._run)
35 |         self.audio_clips: Queue[Path] = Queue()
36 | 
37 |     def start(self):
38 |         super().start()
39 |         self._stop_flag = False
40 |         self._semaphore.set()
41 |         self._speaker_thread.start()
42 | 
43 |     def stop(self):
44 |         super().stop()
45 |         self._stop_flag = True
46 |         self.audio_clips = None
47 |         self._speaker_thread.kill()
48 | 
49 |     def _run(self):
50 |         while not self._stop_flag:
51 | 
52 |             if self.audio_clips.empty():
53 |                 self._semaphore.clear()
54 |                 self._semaphore.wait()
55 | 
56 |             audio_clip = self.audio_clips.get()
57 |             self.playsound(audio_clip, block=True)
58 | 
59 |     def enqueue_sound(self, path_or_data: Path):
60 |         self.activate_check()
61 |         self.audio_clips.put(path_or_data)
62 |         self._semaphore.set()
63 | 
64 |     def stop_now(self):
65 |         pygame.mixer.stop()
66 |         self.audio_clips = Queue()
67 | 
68 |     @staticmethod
69 |     def playsound(path: Path, block: bool = True):
70 |         if block:
71 |             Speaker._sync_playsound(path)
72 |         else:
73 |             Speaker._async_playsound(path)
74 | 
75 |     @staticmethod
76 |     def _sync_playsound(path: Path):
77 |         pygame.mixer.music.load(path)
78 |         pygame.mixer.music.play()
79 |         Speaker.wait()
80 | 
81 |     @staticmethod
82 |     def wait():
83 |         while pygame.mixer.music.get_busy():
84 |             pygame.time.wait(10)  # sleep briefly instead of busy-spinning the CPU
85 | 
86 |     @staticmethod
87 |     def _async_playsound(path: Path):
88 |         sound = pygame.mixer.Sound(path)
89 |         pygame.mixer.Sound.play(sound)
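# --- Added usage sketch (illustrative; not part of the original module) ---
# A minimal example of how this Speaker is meant to be driven; `clip.wav` is a
# hypothetical local file. `enqueue_sound` only queues the clip and wakes the
# worker thread, which plays queued clips strictly one at a time:
#
#     speaker = Speaker()
#     speaker.start()                          # spawn the killable worker thread
#     speaker.enqueue_sound(Path("clip.wav"))  # non-blocking; played when popped
#     Speaker.playsound(Path("clip.wav"))      # or play directly, blocking
#     speaker.stop_now()                       # cut playback, drop queued clips
#     speaker.stop()                           # kill the worker thread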
90 | 
--------------------------------------------------------------------------------
/pipeline/asr/asr_async.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import AsyncGenerator, BinaryIO, Dict
3 | 
4 | from loguru import logger
5 | from typeguard import typechecked
6 | from zerolan.data.pipeline.abs_data import AbstractModelQuery
7 | from zerolan.data.pipeline.asr import ASRQuery, ASRPrediction, ASRStreamQuery
8 | 
9 | from pipeline.asr.config import ASRPipelineConfig, ASRModelIdEnum
10 | from pipeline.base.base_async import BaseAsyncPipeline, stream_generator, get_base_url
11 | 
12 | 
13 | @typechecked
14 | def _parse_asr_query(query: ASRQuery) -> Dict[str, BinaryIO | str]:
15 |     data = {"json": query.model_dump_json()}
16 |     if os.path.exists(query.audio_path):
17 |         data['audio'] = open(query.audio_path, 'rb')
18 |     else:
19 |         logger.warning(f'Assume the remote server must have such file: {query.audio_path}')
20 | 
21 |     return data
22 | 
23 | 
24 | @typechecked
25 | def _parse_asr_stream_query(query: ASRStreamQuery) -> Dict[str, BinaryIO | str]:
26 |     assert len(query.audio_data) > 0
27 | 
28 |     # Only used for converting to json
29 |     class StubASRStreamQuery(AbstractModelQuery):
30 |         is_final: bool
31 |         audio_data: str
32 |         media_type: str
33 |         sample_rate: int
34 |         channels: int
35 | 
36 |     stub_query = StubASRStreamQuery(
37 |         is_final=query.is_final,
38 |         audio_data="",
39 |         media_type=query.media_type,
40 |         sample_rate=query.sample_rate,
41 |         channels=query.channels,
42 |     )
43 |     data = {"json": stub_query.model_dump_json(), "audio": query.audio_data}
44 | 
45 |     return data
46 | 
47 | 
48 | class ASRAsyncPipeline(BaseAsyncPipeline):
49 | 
50 |     def __init__(self, config: ASRPipelineConfig):
51 |         super().__init__(base_url=get_base_url(config.predict_url))
52 |         self._model_id: ASRModelIdEnum = config.model_id
53 |         self._predict_endpoint = "/asr/predict"
54 |         self._stream_predict_endpoint = "/asr/stream-predict"
55 | 
56 |     @typechecked
57 |     async def predict(self, query: ASRQuery) -> ASRPrediction:
58 |         data = _parse_asr_query(query)
59 |         async with self.session.post(self._predict_endpoint, data=data) as resp:
60 |             return await resp.json(encoding='utf8', loads=ASRPrediction.model_validate_json)
61 | 
62 |     @typechecked
63 |     async def stream_predict(self, query: ASRStreamQuery) -> AsyncGenerator[ASRPrediction, None]:
64 |         data = _parse_asr_stream_query(query)
65 |         async with self.session.post(self._stream_predict_endpoint, data=data) as resp:
66 |             async for chunk in stream_generator(resp):
67 |                 yield ASRPrediction.model_validate_json(chunk)
--------------------------------------------------------------------------------
/devices/screen.py:
--------------------------------------------------------------------------------
1 | import platform
2 | 
3 | import pyautogui
4 | import pygetwindow as gw
5 | from PIL.Image import Image
6 | from loguru import logger
7 | from pygetwindow import Win32Window
8 | 
9 | from common.io.file_sys import fs
10 | 
11 | 
12 | class Screen:
13 | 
14 |     def __init__(self):
15 |         os_name = platform.system()
16 |         if os_name != "Windows":
17 |             raise NotImplementedError("Only supports the Windows platform.")
18 | 
19 |     def safe_capture(self, win_title: str = None, k: float = 1.0):
20 |         try:
21 |             if win_title is None:
22 |                 return self.capture_activated_win(k)
23 |             else:
24 |                 return self.capture_with_title(win_title, k)
25 |         except ValueError as e:
26 |             if str(e) == "Coordinate 'right' is less than 'left'":
27 |                 logger.warning(
"Window capture failed: Taking a screenshot in a split-screen situation may cause problems, please try placing the target window on the home screen.") 29 | except AssertionError as e: 30 | logger.warning(e) 31 | except gw.PyGetWindowException as e: 32 | if "Error code from Windows: 0" in str(e): 33 | logger.warning("Window capture failed: Lost focus. Is your target window activated?") 34 | except Exception as e: 35 | logger.exception(e) 36 | logger.warning("Window capture failed: Unknown error.") 37 | return None, None 38 | 39 | def capture_activated_win(self, k: float = 1.0): 40 | w = gw.getActiveWindow() 41 | return self._capture(w, k) 42 | 43 | def capture_with_title(self, win_title: str, k: float = 1.0): 44 | # Get the window 45 | win_list = gw.getWindowsWithTitle(win_title) 46 | assert len(win_list) != 0, f'Window capture failed: Can not find {win_title}' 47 | w = win_list[0] 48 | # Activate the window 49 | w.activate() 50 | return self._capture(w, k) 51 | 52 | def _capture(self, w: Win32Window, k: float) -> (Image, str): 53 | region = ( 54 | w.centerx - k * w.width / 2, w.centery - k * w.height / 2, w.centerx + k * w.width / 2, 55 | w.centery + k * w.height / 2) 56 | region = tuple(int(num * k) for num in region) 57 | 58 | assert hasattr(pyautogui, "screenshot") 59 | # Note: If you have a problem that the screenshot cannot be found, try updating the `pyautogui` library 60 | img = pyautogui.screenshot(region=region) # noqa 61 | 62 | img_save_path = fs.create_temp_file_descriptor(prefix="screenshot", suffix=".png", type="image") 63 | img.save(img_save_path) 64 | 65 | return img, img_save_path 66 | -------------------------------------------------------------------------------- /agent/tool/web_search.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Optional 2 | 3 | import requests 4 | from bs4 import BeautifulSoup 5 | from langchain_core.callbacks import CallbackManagerForToolRun 6 | from langchain_core.tools import BaseTool, ToolException 7 | from loguru import logger 8 | from pydantic import BaseModel, Field 9 | from selenium.webdriver import Firefox, Chrome 10 | 11 | 12 | def html_to_text(html: str): 13 | soup = BeautifulSoup(html, 'html.parser') 14 | return soup.get_text() 15 | 16 | 17 | def get_html(url: str): 18 | response = requests.get(url) 19 | html = response.content 20 | return html 21 | 22 | 23 | # See: 24 | # https://python.langchain.com/docs/how_to/custom_tools/#subclassing-the-basetool-class 25 | 26 | 27 | class BaiduBaikeToolInput(BaseModel): 28 | keyword: str = Field(description="The keyword you want to search.") 29 | 30 | 31 | class BaiduBaikeTool(BaseTool): 32 | name: str = "百度百科" 33 | description: str = "当你需要搜索某个专业的知识点、概念的时候,使用此工具。" 34 | args_schema: Type[BaseModel] = BaiduBaikeToolInput 35 | return_direct: bool = True 36 | 37 | def __init__(self): 38 | """ 39 | Get page content from 40 | """ 41 | super().__init__() 42 | self._url = "https://baike.baidu.com/item" 43 | 44 | def _run(self, keyword: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str: 45 | if isinstance(keyword, str) and keyword != "": 46 | html = get_html(f"{self._url}/{keyword}") 47 | content = html_to_text(html) 48 | if "百度百科错误页" in content: 49 | raise ToolException(f"BaiduBaike returns error page: {keyword} is not found?") 50 | return content 51 | else: 52 | raise ToolException("Keyword should not be empty") 53 | 54 | 55 | class MoeGirlTool(BaseTool): 56 | name: str = "萌娘百科" 57 | description: str = 
"当你需要搜索二次元人物、游戏、漫画等资料时,使用此工具。" 58 | 59 | def __init__(self, driver: Firefox | Chrome): 60 | """ 61 | Get page content from MoeGirl (萌娘百科). 62 | Args: 63 | driver: 64 | """ 65 | super().__init__() 66 | self._url = "https://mzh.moegirl.org.cn" 67 | self._driver: Firefox | Chrome = driver 68 | 69 | def _run(self, keyword: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str: 70 | self._driver.set_page_load_timeout(5) 71 | try: 72 | self._driver.get(f"{self._url}/{keyword}") 73 | for i in range(3): 74 | self._driver.execute_script("window.scrollBy(0,3000)") 75 | logger.debug(f"Scroll by 3000 * {i}") 76 | self._driver.implicitly_wait(1) 77 | except Exception as _: 78 | pass 79 | plain_text = html_to_text(self._driver.page_source) 80 | return plain_text 81 | -------------------------------------------------------------------------------- /pipeline/tts/tts_sync.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import uuid 3 | from http import HTTPStatus 4 | 5 | import requests 6 | from loguru import logger 7 | from zerolan.data.pipeline.tts import TTSQuery, TTSPrediction, TTSStreamPrediction 8 | 9 | from pipeline.base.base_sync import CommonModelPipeline 10 | from pipeline.tts.baidu_tts import BaiduTTSPipeline 11 | from pipeline.tts.config import TTSPipelineConfig, TTSModelIdEnum 12 | 13 | 14 | class TTSSyncPipeline(CommonModelPipeline): 15 | 16 | def __init__(self, config: TTSPipelineConfig): 17 | super().__init__(config) 18 | # Support Baidu TTS API 19 | if config.model_id == TTSModelIdEnum.BaiduTTS and config.baidu_tts_config is not None: 20 | self.baidu = BaiduTTSPipeline(api_key=config.baidu_tts_config.api_key, 21 | secret_key=config.baidu_tts_config.secret_key) 22 | self.predict = self.baidu.predict 23 | self.stream_predict = self.baidu.stream_predict 24 | 25 | def predict(self, query: TTSQuery) -> TTSPrediction | None: 26 | assert isinstance(query, TTSQuery) 27 | if os.path.exists(query.refer_wav_path): 28 | query.refer_wav_path = os.path.abspath(query.refer_wav_path).replace("\\", "/") 29 | query_dict = self.parse_query(query) 30 | response = requests.post(url=self.predict_url, stream=True, json=query_dict) 31 | if response.status_code == HTTPStatus.OK: 32 | prediction = TTSPrediction(wave_data=response.content, audio_type=query.audio_type) 33 | return prediction 34 | else: 35 | logger.error(response.content) 36 | response.raise_for_status() 37 | 38 | def stream_predict(self, query: TTSQuery, chunk_size: int | None = None): 39 | assert isinstance(query, TTSQuery) 40 | if os.path.exists(query.refer_wav_path): 41 | query.refer_wav_path = os.path.abspath(query.refer_wav_path).replace("\\", "/") 42 | query_dict = self.parse_query(query) 43 | response = requests.post(url=self.stream_predict_url, stream=True, 44 | json=query_dict) 45 | response.raise_for_status() 46 | last = 0 47 | id = str(uuid.uuid4()) 48 | for idx, chunk in enumerate(response.iter_content(chunk_size=1024)): 49 | last = idx 50 | yield TTSStreamPrediction(seq=idx, 51 | id=id, 52 | is_final=False, 53 | wave_data=chunk, 54 | audio_type=query.audio_type) 55 | yield TTSStreamPrediction(is_final=True, seq=last + 1, audio_type=query.audio_type, wave_data=b'') 56 | 57 | def parse_query(self, query: any) -> dict: 58 | return super().parse_query(query) 59 | -------------------------------------------------------------------------------- /pipeline/asr/asr_sync.py: -------------------------------------------------------------------------------- 1 | import os.path 2 
3 | 
4 | import requests
5 | from typeguard import typechecked
6 | from zerolan.data.pipeline.asr import ASRQuery, ASRPrediction, ASRStreamQuery
7 | 
8 | from pipeline.asr.baidu_asr import BaiduASRPipeline
9 | from pipeline.asr.config import ASRPipelineConfig, ASRModelIdEnum
10 | from pipeline.base.base_sync import CommonModelPipeline
11 | 
12 | 
13 | class ASRSyncPipeline(CommonModelPipeline):
14 | 
15 |     def __init__(self, config: ASRPipelineConfig):
16 |         super().__init__(config)
17 |         if config.model_id == ASRModelIdEnum.BaiduASR and config.baidu_asr_config is not None:
18 |             baidu = BaiduASRPipeline(api_key=config.baidu_asr_config.api_key,
19 |                                      secret_key=config.baidu_asr_config.secret_key)
20 |             self.predict = baidu.predict
21 |             self.stream_predict = baidu.stream_predict
22 | 
23 |     @typechecked
24 |     def predict(self, query: ASRQuery) -> ASRPrediction | None:
25 |         assert isinstance(query, ASRQuery)
26 |         files, data = self.parse_query(query)
27 |         response = requests.post(url=self.predict_url, files=files, data=data)
28 | 
29 |         response.raise_for_status()
30 |         prediction = self.parse_prediction(response.content)
31 |         return prediction
32 | 
33 |     @typechecked
34 |     def stream_predict(self, query: ASRStreamQuery, chunk_size: int | None = None) -> Generator[
35 |         ASRPrediction, None, None]:
36 |         assert isinstance(query, ASRStreamQuery)
37 |         files, data = self.parse_query(query)
38 |         response = requests.post(url=self.stream_predict_url, files=files, data=data)
39 |         response.raise_for_status()
40 | 
41 |         for chunk in response.iter_content(chunk_size=chunk_size, decode_unicode=True):
42 |             prediction = self.parse_stream_prediction(chunk)
43 |             yield prediction
44 | 
45 |     def parse_query(self, query: ASRQuery | ASRStreamQuery) -> Tuple[dict, dict]:
46 |         if isinstance(query, ASRQuery):
47 |             files = None
48 |             if os.path.exists(query.audio_path):
49 |                 files = {"audio": open(query.audio_path, 'rb')}
50 |             data = {"json": query.model_dump_json()}
51 | 
52 |             return files, data
53 |         elif isinstance(query, ASRStreamQuery):
54 |             assert len(query.audio_data) > 0
55 |             files = {"audio": query.audio_data}
56 |             query.audio_data = ""
57 |             data = {"json": query.model_dump_json()}
58 | 
59 |             return files, data
60 |         else:
61 |             raise ValueError("Can not convert query.")
62 | 
63 |     def parse_prediction(self, json_val: str) -> ASRPrediction:
64 |         return ASRPrediction.model_validate_json(json_val)
65 | 
66 |     def parse_stream_prediction(self, chunk: str) -> ASRPrediction:
67 |         return ASRPrediction.model_validate_json(chunk)
--------------------------------------------------------------------------------
/services/live_stream/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | 
3 | 
4 | class BilibiliServiceConfig(BaseModel):
5 |     class Credential(BaseModel):
6 |         sessdata: str = Field(default="", description="Value of the `SESSDATA` from the cookie.")
7 |         bili_jct: str = Field(default="", description="Value of the `bili_jct` from the cookie.")
8 |         buvid3: str = Field(default="", description="Value of the `buvid3` from the cookie.")
9 |     enable: bool = Field(default=True,
10 |                          description="Enable live stream listening for Bilibili.")
11 |     room_id: int = Field(default=-1,
12 |                          description="Bilibili Room ID. \n"
13 |                                      "Note: Must be a positive integer.")
14 |     credential: Credential = Field(default=Credential(),
15 |                                    description="Your Bilibili Credential. \n"
\n" 16 | "How to get: [获取 Credential 类所需信息](https://nemo2011.github.io/bilibili-api/#/get-credential)") 17 | 18 | 19 | class TwitchServiceConfig(BaseModel): 20 | enable: bool = Field(default=True, 21 | description="Enable live stream listening for Twitch.") 22 | channel_id: str = Field(default="", 23 | description="Your Twitch channel ID.") 24 | app_id: str = Field(default="", 25 | description="Your Twitch app ID.") 26 | app_secret: str | None = Field(default=None, 27 | description="Your Twitch app secret. \n" 28 | "Leave it as `null` if you only want to use User Authentication. \n" 29 | "How to get: [Twitch Developers - Authentication](https://dev.twitch.tv/docs/authentication/)") 30 | 31 | 32 | class YoutubeServiceConfig(BaseModel): 33 | enable: bool = Field(default=True, 34 | description="Enable live stream listening for YouTube.") 35 | token: str = Field(default="", 36 | description="GCloud auth print access token. \n" 37 | "How to get: [Obtaining authorization credentials](https://developers.google.cn/youtube/registering_an_application?hl=en)") 38 | 39 | 40 | class LiveStreamConfig(BaseModel): 41 | enable: bool = Field(default=True, 42 | description="Enable live stream listening.") 43 | bilibili: BilibiliServiceConfig = Field(default=BilibiliServiceConfig(), 44 | description="Config for connecting to Bilibili live-streaming server") 45 | twitch: TwitchServiceConfig = Field(default=TwitchServiceConfig(), 46 | description="Config for connecting to Twitch live-streaming server") 47 | youtube: YoutubeServiceConfig = Field(default=TwitchServiceConfig(), 48 | description="Config for connecting to YouTube live-streaming server") 49 | -------------------------------------------------------------------------------- /common/concurrent/killable_thread.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modify from: 3 | http://www.s1nh.org/post/python-different-ways-to-kill-a-thread/ 4 | """ 5 | import ctypes 6 | import threading 7 | from typing import List 8 | 9 | from loguru import logger 10 | 11 | PyGILState_Ensure = ctypes.pythonapi.PyGILState_Ensure 12 | PyGILState_Release = ctypes.pythonapi.PyGILState_Release 13 | 14 | 15 | class ThreadKilledError(RuntimeError): 16 | def __init__(self, *args): 17 | super().__init__(*args) 18 | 19 | 20 | class ThreadCanNotBeKilledError(RuntimeError): 21 | def __init__(self, *args): 22 | super().__init__(*args) 23 | 24 | 25 | class KillableThread(threading.Thread): 26 | def __init__(self, group=None, target=None, name=None, 27 | args=(), kwargs=None, *, daemon=None): 28 | super().__init__(group, target, name, args, kwargs, daemon=daemon) 29 | self._killed = False 30 | add_thread(self) 31 | 32 | def get_id(self): 33 | """ 34 | Get the id of the respective thread. 35 | Returns: Thread id. 36 | 37 | """ 38 | # returns id of the respective thread 39 | if hasattr(self, '_thread_id'): 40 | return self._thread_id 41 | for id, thread in threading._active.items(): 42 | if thread is self: 43 | return id 44 | 45 | def kill(self): 46 | """ 47 | Kill the thread unsafely. 48 | Notes: This is an unsafe method the thread execution may be corrupted. 49 | Throws: ThreadCanNotBeKilledError if the thread is not killed successfully. 50 | ThreadKilledError if the thread is killed successfully. 
51 | 
52 |         """
53 |         thread_id = self.get_id()
54 |         res = ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id,
55 |                                                          ctypes.py_object(ThreadKilledError))
56 |         if res > 1:
57 |             ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, 0)
58 |             raise ThreadCanNotBeKilledError('Exception raise failure')
59 |         self._killed = True
60 | 
61 |     def kill_with_gil_held(self):
62 |         gstate = PyGILState_Ensure()
63 |         try:
64 |             self.kill()
65 |         finally:
66 |             PyGILState_Release(gstate)
67 | 
68 |     def join(self, timeout=None):
69 |         """
70 |         Notes: If the thread has already been killed, this returns immediately without joining.
71 |         Args:
72 |             timeout:
73 | 
74 |         Returns:
75 | 
76 |         """
77 |         if self._killed:
78 |             return
79 |         else:
80 |             super().join(timeout)
81 | 
82 | 
83 | _all: List[KillableThread] = []
84 | 
85 | 
86 | def add_thread(t: KillableThread):
87 |     assert isinstance(t, KillableThread)
88 |     _all.append(t)
89 | 
90 | 
91 | def kill_all_threads():
92 |     for thread in _all:
93 |         try:
94 |             thread.kill()
95 |             logger.debug(f"Thread {thread.get_id()}: killed")
96 |         except ThreadCanNotBeKilledError:
97 |             logger.error(f"Failed to kill thread: {thread.get_id()}")
98 |     logger.debug("All threads killed.")
--------------------------------------------------------------------------------
/services/live2d/live2d_viewer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | from queue import Queue
4 | 
5 | import live2d.v3 as live2d
6 | from PyQt5.QtWidgets import QApplication
7 | from loguru import logger
8 | from typeguard import typechecked
9 | 
10 | from common.concurrent.abs_runnable import ThreadRunnable
11 | from common.concurrent.killable_thread import KillableThread
12 | from services.live2d.config import Live2DViewerConfig
13 | from services.live2d.live2d_canvas import Live2DCanvas
14 | 
15 | 
16 | class Live2DViewer(ThreadRunnable):
17 |     def name(self):
18 |         return "Live2DViewer"
19 | 
20 |     def __init__(self, config: Live2DViewerConfig):
21 |         super().__init__()
22 |         self._model_path: str = config.model3_json_file
23 |         self._canvas: Live2DCanvas | None = None
24 |         self._audios: Queue[Path] = Queue()
25 |         self._sync_lip_loop_thread: KillableThread = KillableThread(target=self._sync_lip_loop, daemon=True)
26 |         self._sync_lip_loop_flag: bool = True
27 |         self._auto_lip_sync: bool = config.auto_lip_sync
28 |         self._auto_blink: bool = config.auto_blink
29 |         self._auto_breath: bool = config.auto_breath
30 |         self._win_h: int = config.win_height
31 |         self._win_w: int = config.win_width
32 | 
33 |     def start(self):
34 |         super().start()
35 |         live2d.init()
36 |         app = QApplication(sys.argv)
37 |         self._canvas = Live2DCanvas(path=self._model_path, lip_sync_n=3, win_size=(self._win_w, self._win_h))
38 |         self._sync_lip_loop_thread.start()
39 |         self._canvas.show()
40 |         self.set_auto_blink(self._auto_blink)
41 |         self.set_auto_breath(self._auto_breath)
42 |         app.exec()
43 |         live2d.dispose()
44 | 
45 |     def _sync_lip_loop(self):
46 |         while self._sync_lip_loop_flag:
47 |             try:
48 |                 audio_path = self._audios.get(block=True)
49 |                 self._canvas.wavHandler.Start(str(audio_path))
50 |             except Exception as e:
51 |                 logger.exception(e)
52 | 
53 |     def stop(self):
54 |         super().stop()
55 |         self._sync_lip_loop_flag = False
56 |         self._sync_lip_loop_thread.kill()
57 | 
58 |     @typechecked
59 |     def sync_lip(self, audio_path: Path):
60 |         """
61 |         Sync the lip of the character.
62 |         Note: This method will NOT block your thread!
63 |         For example, if you have 2 audio files to play and sync lip,
64 |         you should play the second audio only after the first one has finished.
65 |         :param audio_path: The path of the audio file.
66 |         """
67 |         if not self._auto_lip_sync:
68 |             return
69 |         assert audio_path.exists()
70 |         self._audios.put(audio_path)
71 | 
72 |     @typechecked
73 |     def set_auto_blink(self, enable: bool):
74 |         self._canvas.model.SetAutoBlinkEnable(enable)
75 |         logger.info(f"Set auto blink to: {enable}")
76 | 
77 |     @typechecked
78 |     def set_auto_breath(self, enable: bool):
79 |         self._canvas.model.SetAutoBreathEnable(enable)
80 |         logger.info(f"Set auto breath to: {enable}")
--------------------------------------------------------------------------------
/tests/pipeline/test_asr.py:
--------------------------------------------------------------------------------
1 | from typing import Generator
2 | 
3 | import numpy as np
4 | import pytest
5 | from loguru import logger
6 | from zerolan.data.pipeline.asr import ASRQuery, ASRStreamQuery
7 | 
8 | from common.utils.audio_util import from_bytes_to_np_ndarray
9 | from manager.config_manager import get_config, get_project_dir
10 | from pipeline.asr.asr_async import ASRAsyncPipeline
11 | from pipeline.asr.asr_sync import ASRSyncPipeline as ASRPipelineSync
12 | 
13 | _config = get_config()
14 | _asr = ASRAsyncPipeline(_config.pipeline.asr)
15 | _asr_sync = ASRPipelineSync(_config.pipeline.asr)
16 | project_dir = get_project_dir()
17 | audio_path = project_dir.joinpath("tests/resources/tts-test.wav")
18 | _asr_query = ASRQuery(audio_path=str(audio_path), channels=2)
19 | 
20 | 
21 | @pytest.fixture(scope="session")
22 | def event_loop(event_loop_policy):
23 |     # Provide a session-scoped event loop for the async pipeline tests
24 |     loop = event_loop_policy.new_event_loop()
25 |     yield loop
26 |     loop.close()
27 | 
28 | 
29 | def split_audio(num_chunk: int) -> Generator:
30 |     with open(audio_path, "rb") as f:
31 |         data, samplerate = from_bytes_to_np_ndarray(f.read())
32 |     data = data[:, 0]
33 |     chunk_size = data.shape[0] // num_chunk
34 |     buffer = np.zeros(shape=(1, 1), dtype=np.float32)
35 |     for i in range(num_chunk):
36 |         chunk: np.ndarray = data[i * chunk_size: (i + 1) * chunk_size]
37 |         is_final = i == num_chunk - 1
38 |         buffer = np.append(buffer, chunk)
39 |         query = ASRStreamQuery(is_final=is_final, audio_data=buffer.tobytes(), channels=1, media_type='raw',
40 |                                sample_rate=samplerate)
41 |         yield query
42 | 
43 | 
44 | @pytest.mark.asyncio
45 | async def test_asr():
46 |     prediction = await _asr.predict(_asr_query)
47 |     assert prediction, "Test failed: No response."
48 |     logger.info(f"ASR result: {prediction.transcript}")
49 |     assert "我是" in prediction.transcript, "Test failed: Wrong result."
50 | 
51 | 
52 | @pytest.mark.asyncio
53 | async def test_asr_stream_predict():
54 |     result = []
55 |     for query in split_audio(num_chunk=4):
56 |         async for prediction in _asr.stream_predict(query):
57 |             assert prediction, "Test failed: No response."
58 |             logger.info(f"ASR result: {prediction.transcript}")
59 |             result.append(prediction.transcript)
60 |     assert "我是" in result, "Test failed: Wrong result."
61 | 
62 | 
63 | def test_asr_predict_sync():
64 |     prediction = _asr_sync.predict(_asr_query)
65 |     assert prediction, "Test failed: No response."
66 |     logger.info(f"ASR result: {prediction.transcript}")
67 |     assert "我是" in prediction.transcript, "Test failed: Wrong result."
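# Added note (illustrative comment, not in the original file): `split_audio`
# accumulates samples, so each ASRStreamQuery resends all audio received so far
# (a growing prefix), not just the newest chunk, and marks the last query with
# `is_final=True`. The streaming tests rely on that contract.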
68 | 
69 | 
70 | def test_asr_stream_predict_sync():
71 |     result = []
72 |     for query in split_audio(num_chunk=4):
73 |         for prediction in _asr_sync.stream_predict(query):
74 |             assert prediction, "Test failed: No response."
75 |             logger.info(f"ASR result: {prediction.transcript}")
76 |             result.append(prediction.transcript)
77 |     assert "我是" in result, "Test failed: Wrong result."
--------------------------------------------------------------------------------
/common/decorator.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from functools import wraps
3 | 
4 | from loguru import logger
5 | 
6 | 
7 | def log_init(service_name: str):
8 |     def decorator(func):
9 |         @wraps(func)
10 |         def wrapper(*args, **kwargs):
11 |             ret = func(*args, **kwargs)
12 |             logger.info(f"{service_name} initialized.")
13 |             return ret
14 | 
15 |         return wrapper
16 | 
17 |     return decorator
18 | 
19 | 
20 | def log_start(service_name: str):
21 |     def decorator(func):
22 |         @wraps(func)
23 |         def wrapper(*args, **kwargs):
24 |             logger.info(f"{service_name} starting...")
25 |             ret = func(*args, **kwargs)
26 |             logger.info(f"{service_name} exited.")
27 |             return ret
28 | 
29 |         return wrapper
30 | 
31 |     return decorator
32 | 
33 | 
34 | def log_stop(service_name: str):
35 |     def decorator(func):
36 |         @wraps(func)
37 |         def wrapper(*args, **kwargs):
38 |             logger.info(f"{service_name} stopping...")
39 |             ret = func(*args, **kwargs)
40 |             logger.info(f"{service_name} stopped.")
41 |             return ret
42 | 
43 |         return wrapper
44 | 
45 |     return decorator
46 | 
47 | 
48 | def log_run_time(time_limit=10):
49 |     """
50 |     Decorator: time the wrapped function in real time and log a warning as soon as it exceeds the limit.
51 | 
52 |     Args:
53 |         time_limit (int): Time limit in seconds. Defaults to 10 seconds.
54 |     """
55 | 
56 |     def decorator(func):
57 |         def wrapper(*args, **kwargs):
58 |             import time
59 |             start_time = time.perf_counter()
60 |             # Flag marking whether the function has finished
61 |             finished = threading.Event()
62 | 
63 |             def check_duration():
64 |                 while not finished.is_set():
65 |                     elapsed_time = time.perf_counter() - start_time
66 |                     if elapsed_time > time_limit:
67 |                         # Warn only once
68 |                         if not getattr(func, 'warned', False):
69 |                             logger.warning(
70 |                                 f"Function {func.__name__} exceeded time limit. Time elapsed: {elapsed_time:.4f} seconds."
71 |                             )
72 |                             setattr(func, 'warned', True)
73 |                     # Short sleep to avoid excessive CPU usage
74 |                     time.sleep(0.1)
75 | 
76 |             # Start a background thread that watches the running time
77 |             monitor_thread = threading.Thread(target=check_duration, daemon=True)
78 |             monitor_thread.start()
79 | 
80 |             try:
81 |                 result = func(*args, **kwargs)
82 |             finally:
83 |                 # Mark the function as finished
84 |                 finished.set()
85 | 
86 |             elapsed_time = time.perf_counter() - start_time
87 |             if not getattr(func, 'warned', False):
88 |                 logger.info(
89 |                     f"Function {func.__name__} completed in {elapsed_time:.4f} seconds."
90 |                 )
91 |             else:
92 |                 logger.warning(
93 |                     f"Function {func.__name__} completed with warnings. Total duration: {elapsed_time:.4f} seconds."
94 |                 )
95 | 
96 |             return result
97 | 
98 |         return wrapper
99 | 
100 |     return decorator
--------------------------------------------------------------------------------
/pipeline/asr/baidu_asr.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import os.path
4 | import uuid
5 | from io import BytesIO
6 | from typing import List, Generator
7 | 
8 | import librosa
9 | import requests
10 | import soundfile as sf
11 | from pydantic import BaseModel
12 | from typeguard import typechecked
13 | from zerolan.data.pipeline.asr import ASRQuery, ASRPrediction, ASRStreamQuery
14 | 
15 | from common.io.api import save_audio
16 | from common.io.file_type import AudioFileType
17 | 
18 | 
19 | class BaiduASRResponse(BaseModel):
20 |     corpus_no: str
21 |     err_msg: str
22 |     err_no: int
23 |     result: List[str]
24 |     sn: str
25 | 
26 | 
27 | class BaiduASRPipeline:
28 | 
29 |     def __init__(self, api_key, secret_key):
30 |         self._access_token = self._get_access_token(api_key, secret_key)
31 |         self._cuid = str(uuid.uuid4())
32 | 
33 |     @typechecked
34 |     def predict(self, query: ASRQuery):
35 |         assert os.path.exists(query.audio_path), f"{query.audio_path} does not exist!"
36 |         assert self._access_token is not None
37 |         url = "https://vop.baidu.com/server_api"
38 | 
39 |         if query.channels > 1:
40 |             data, sr = librosa.load(query.audio_path, mono=True)
41 |             memory_file = BytesIO()
42 |             sf.write(memory_file, data, sr, format=query.media_type)
43 |             memory_file.seek(0)  # Reset pointer to beginning
44 |             data = memory_file.read()
45 |         else:
46 |             with open(query.audio_path, "rb") as f:
47 |                 data = f.read()
48 | 
49 |         data_len = len(data)
50 |         audio_base64 = base64.b64encode(data).decode('utf-8')
51 | 
52 |         payload = json.dumps({
53 |             "format": query.media_type,
54 |             "rate": query.sample_rate,
55 |             "channel": 1,
56 |             "cuid": self._cuid,
57 |             "speech": audio_base64,
58 |             "len": data_len,
59 |             "token": self._access_token
60 |         }, ensure_ascii=False)
61 |         headers = {
62 |             'Content-Type': 'application/json',
63 |             'Accept': 'application/json'
64 |         }
65 |         response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8"))
66 |         response.raise_for_status()
67 |         response = json.loads(response.text)
68 | 
69 |         if response['err_no'] != 0:
70 |             raise Exception(response)
71 | 
72 |         return ASRPrediction(transcript=response['result'][0])
73 | 
74 |     @staticmethod
75 |     @typechecked
76 |     def _get_access_token(api_key: str, secret_key: str):
77 |         """
78 |         Generate an authentication signature (access token) from the API key (AK) and secret key (SK).
79 |         :return: The access token, or None on failure.
80 |         """
81 |         url = "https://aip.baidubce.com/oauth/2.0/token"
82 |         params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
83 |         return str(requests.post(url, params=params).json().get("access_token"))
84 | 
85 |     def stream_predict(self, query: ASRStreamQuery, chunk_size: int | None = None) -> Generator[
86 |         ASRPrediction, None, None]:
87 |         audio_path = save_audio(query.audio_data, AudioFileType.WAV, prefix="asr")
88 |         yield self.predict(ASRQuery(
89 |             audio_path=str(audio_path),
90 |             media_type=query.media_type,
91 |             sample_rate=query.sample_rate,
92 |             channels=query.channels,
93 |         ))
--------------------------------------------------------------------------------
/pipeline/llm/llm_sync.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | from requests import Response
3 | from typeguard import typechecked
4 | from zerolan.data.pipeline.llm import LLMQuery, LLMPrediction, RoleEnum, Conversation
5 | 
6 | from pipeline.base.base_sync import CommonModelPipeline
7 | from pipeline.llm.config import LLMPipelineConfig
8 | 
9 | 
10 | def _to_openai_format(query: LLMQuery):
11 |     messages = []
12 |     for chat in query.history:
13 |         messages.append({
14 |             "role": chat.role,
15 |             "content": chat.content
16 |         })
17 |     messages.append({
18 |         "role": "user",
19 |         "content": query.text
20 |     })
21 |     return messages
22 | 
23 | 
24 | def _openai_predict(query: LLMQuery, wrapper):
25 |     messages = _to_openai_format(query)
26 |     completion = wrapper(messages)
27 |     resp = completion.choices[0].message.content
28 |     query.history.append(Conversation(role=RoleEnum.user, content=query.text))
29 |     query.history.append(Conversation(role=RoleEnum.assistant, content=resp))
30 |     return LLMPrediction(response=resp, history=query.history)
31 | 
32 | 
33 | class LLMSyncPipeline(CommonModelPipeline):
34 | 
35 |     def __init__(self, config: LLMPipelineConfig):
36 |         super().__init__(config)
37 |         # Kimi API supported
38 |         # Reference: https://platform.moonshot.cn/docs/guide/start-using-kimi-api
39 |         # Deepseek API supported
40 |         # Reference: https://api-docs.deepseek.com/zh-cn/
41 |         self._is_openai_format = config.openai_format
42 |         if self._is_openai_format:
43 |             assert config.predict_url or config.stream_predict_url, "Please provide `predict_url` or `stream_predict_url`"
44 |             base_url = config.predict_url if config.predict_url else config.stream_predict_url
45 |             self._remote_model = OpenAI(api_key=config.api_key, base_url=base_url)
46 | 
47 |     @typechecked
48 |     def predict(self, query: LLMQuery) -> LLMPrediction | None:
49 |         assert isinstance(query, LLMQuery)
50 |         if self._is_openai_format:
51 |             if self.model_id == "moonshot-v1-8k":
52 |                 def wrapper_kimi(messages):
53 |                     return self._remote_model.chat.completions.create(
54 |                         model=self.model_id,
55 |                         messages=messages,
56 |                         temperature=0.3
57 |                     )
58 | 
59 |                 return _openai_predict(query, wrapper_kimi)
60 |             elif self.model_id == "deepseek-chat":
61 |                 def wrapper_deepseek(messages):
62 |                     return self._remote_model.chat.completions.create(
63 |                         model=self.model_id,
64 |                         messages=messages,
65 |                         stream=False
66 |                     )
67 | 
68 |                 return _openai_predict(query, wrapper_deepseek)
69 |             else:
70 |                 raise NotImplementedError(f"Unsupported model {self.model_id}")
71 |         else:
72 |             return super().predict(query)
73 | 
74 |     @typechecked
75 |     def stream_predict(self, query: LLMQuery, chunk_size: int | None = None):
76 |         assert isinstance(query, LLMQuery)
77 |         # TODO: Kimi and Deepseek stream prediction.
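        # Added sketch (illustrative, not part of the original source) of what the
        # OpenAI-format streaming branch could look like once implemented; this is
        # an assumption, unverified against Kimi/Deepseek specifics. The OpenAI SDK
        # yields chunk deltas when `stream=True`:
        #
        #     if self._is_openai_format:
        #         stream = self._remote_model.chat.completions.create(
        #             model=self.model_id,
        #             messages=_to_openai_format(query),
        #             stream=True,
        #         )
        #         for chunk in stream:
        #             delta = chunk.choices[0].delta.content or ""
        #             # ...accumulate deltas and yield partial predictions...
        #
        # Until then, fall back to the common pipeline implementation: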
78 |         return super().stream_predict(query)
79 | 
80 |     def parse_prediction(self, response: Response) -> LLMPrediction:
81 |         json_val = response.content
82 |         return LLMPrediction.model_validate_json(json_val)
--------------------------------------------------------------------------------
/agent/adaptor.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import Optional, Any, Mapping, Sequence, Union, Callable
3 | 
4 | from langchain_core.callbacks import CallbackManagerForLLMRun
5 | from langchain_core.language_models import BaseChatModel
6 | from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage, ToolMessage
7 | from langchain_core.outputs import ChatResult, ChatGeneration
8 | from langchain_core.tools import BaseTool
9 | from zerolan.data.pipeline.llm import LLMQuery, Conversation, RoleEnum
10 | from pipeline.llm.llm_sync import LLMSyncPipeline
11 | from pipeline.llm.config import LLMPipelineConfig
12 | 
13 | 
14 | def convert(messages: list[BaseMessage]) -> list[Conversation]:
15 |     result = []
16 |     for message in messages:
17 |         result.append(convert_pipeline_query(message))
18 |     return result
19 | 
20 | 
21 | def convert_pipeline_query(message: BaseMessage):
22 |     if isinstance(message, AIMessage):
23 |         return Conversation(role=RoleEnum.assistant, content=message.content, metadata=None)
24 |     elif isinstance(message, HumanMessage):
25 |         return Conversation(role=RoleEnum.user, content=message.content, metadata=None)
26 |     elif isinstance(message, SystemMessage):
27 |         return Conversation(role=RoleEnum.system, content=message.content, metadata=None)
28 |     elif isinstance(message, ToolMessage):
29 |         return Conversation(role=RoleEnum.function, content=message.content, metadata=None)
30 |     else:
31 |         raise NotImplementedError(f"{type(message)} is not supported.")
32 | 
33 | 
34 | class LangChainAdaptedLLM(BaseChatModel):
35 |     """
36 |     https://python.langchain.com/docs/how_to/custom_llm/
37 |     """
38 | 
39 |     def _generate(self, messages: list[BaseMessage], stop: Optional[list[str]] = None,
40 |                   run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
41 |         # Bridge LangChain's message list to the zerolan LLM pipeline: the newest
42 |         # message becomes the query text, everything before it becomes the history.
43 |         last_message = messages[-1]
44 | 
45 |         query = LLMQuery(text=last_message.content, history=convert(messages[0:-1]))
46 |         prediction = self._pipeline.predict(query=query)
47 | 
48 |         content = prediction.response
49 | 
50 |         message = AIMessage(content=content)
51 | 
52 |         generation = ChatGeneration(message=message)
53 |         return ChatResult(generations=[generation])
54 | 
55 |     def __init__(self, config: LLMPipelineConfig):
56 |         super().__init__()
57 |         self._tool_names: set[str] = set()
58 |         self._openai_tools: list[dict] = []
59 |         self._pipeline = LLMSyncPipeline(config=config)
60 |         self._model_name = "CustomLLM"
61 |         self._tools: Sequence[
62 |             Union[typing.Dict[str, Any], type, Callable, BaseTool]  # noqa: UP006
63 |         ] = []
64 | 
65 |     @property
66 |     def _llm_type(self) -> str:
67 |         return "custom"
68 | 
69 |     def _call(
70 |             self,
71 |             prompt: str,
72 |             stop: Optional[list[str]] = None,
73 |             run_manager: Optional[CallbackManagerForLLMRun] = None,
74 |             **kwargs: Any,
75 |     ) -> str:
76 |         if stop is not None:
77 |             raise ValueError("stop kwargs are not permitted.")
78 |         prediction = self._pipeline.predict(LLMQuery(
79 |             text=prompt,
80 |             history=[]
81 |         ))
82 |         return prediction.response
83 | 
84 |     @property
85 |     def _identifying_params(self) -> Mapping[str, Any]:
86 |         """Get the identifying parameters."""
87 |         return {"model_name": self._model_name}
--------------------------------------------------------------------------------
/common/io/file_sys.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | from pathlib import Path
4 | from typing import Literal
5 | 
6 | from typeguard import typechecked
7 | 
8 | from common.utils.time_util import get_time_string
9 | 
10 | ResType = Literal["image", "video", "audio", "model"]
11 | 
12 | 
13 | class FileSystem:
14 |     def __init__(self):
15 |         self._proj_dir: Path | None = None
16 |         self._temp_dir: Path | None = None
17 |         self._create_temp_dir()
18 | 
19 |     def _create_temp_dir(self):
20 |         self.temp_dir.mkdir(exist_ok=True)
21 |         for dir_name in ["image", "audio", "video", "model"]:
22 |             self.temp_dir.joinpath(dir_name).mkdir(exist_ok=True)
23 | 
24 |     @property
25 |     def project_dir(self) -> Path:
26 |         """
27 |         Get project directory path.
28 |         :return: Project directory path.
29 |         """
30 |         if self._proj_dir is None:
31 |             cur_work_dir = os.getcwd()
32 |             if Path(cur_work_dir).name == "tests":
33 |                 dir_path = Path(cur_work_dir).parent
34 |                 self._proj_dir = Path(dir_path)
35 |             else:
36 |                 self._proj_dir = Path(cur_work_dir)
37 |         return self._proj_dir.absolute()
38 | 
39 |     @property
40 |     def temp_dir(self) -> Path:
41 |         """
42 |         Get temp directory path.
43 |         :return: Temp directory path.
44 |         """
45 |         if self._temp_dir is None:
46 |             self._temp_dir = self.project_dir.joinpath('.temp/')
47 |         return self._temp_dir.absolute()
48 | 
49 |     @typechecked
50 |     def create_temp_file_descriptor(self, prefix: str, suffix: str, type: ResType) -> Path:
51 |         """
52 |         Create temp file descriptor.
53 |         :param prefix: Prefix placed at the beginning of the file name.
54 |         :param suffix: Suffix appended at the end of the file name. Usually the file type, e.g. .wav
55 |         :param type: Temp resource type. See ResType.
56 |         :return: Absolute path of the new temp file.
57 |         """
58 |         self.temp_dir.mkdir(exist_ok=True)
59 |         if suffix[0] == '.':
60 |             suffix = suffix[1:]
61 |         filename = f"{prefix}-{get_time_string()}.{suffix}"
62 |         typed_dir = self.temp_dir.joinpath(type)
63 |         typed_dir.mkdir(exist_ok=True)
64 |         return typed_dir.joinpath(filename).absolute()
65 | 
66 |     @typechecked
67 |     def find_dir(self, dir_path: str, tgt_dir_name: str) -> Path | None:
68 |         """
69 |         Walk dir_path recursively and find a directory whose name contains tgt_dir_name.
70 |         :param dir_path: Directory path to walk.
71 |         :param tgt_dir_name: Target directory name to find.
72 |         :return: Path if found, else None
73 |         """
74 |         assert os.path.exists(dir_path), f"{dir_path} doesn't exist."
75 |         for dirpath, dirnames, filenames in os.walk(dir_path):
76 |             for dirname in dirnames:
77 |                 if tgt_dir_name in dirname:
78 |                     path = os.path.join(dirpath, dirname)
79 |                     return Path(path).absolute()
80 |         return None
81 | 
82 |     @typechecked
83 |     def compress(self, src_dir: str | Path, tgt_dir: str | Path):
84 |         src_dir = Path(src_dir).absolute()
85 | 
86 |         with zipfile.ZipFile(tgt_dir, 'a', zipfile.ZIP_DEFLATED) as zipf:
87 |             for root, dirs, files in os.walk(src_dir):
88 |                 for file in files:
89 |                     file_path = os.path.join(root, file)
90 |                     arcname = os.path.relpath(file_path, start=src_dir)
91 |                     zipf.write(file_path, arcname=arcname)
92 | 
93 | 
94 | fs = FileSystem()
--------------------------------------------------------------------------------
/services/game/minecraft/app.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | 
3 | from langchain_core.messages import HumanMessage
4 | from loguru import logger
5 | from zerolan.data.protocol.protocol import ZerolanProtocol
6 | 
7 | from agent.tool_agent import Tool, ToolAgent
8 | from common.web.zrl_ws import ZerolanProtocolWsServer
9 | from event.registry import EventKeyRegistry
10 | from services.game.config import GameBridgeConfig
11 | from services.game.minecraft.instrcution.input import generate_model_from_args, FieldMetadata
12 | from services.game.minecraft.instrcution.tool import KonekoInstructionTool
13 | 
14 | 
15 | class KonekoMinecraftAIAgent(ZerolanProtocolWsServer):
16 | 
17 |     def on_disconnect(self, ws_id: str):
18 |         logger.warning("Koneko disconnected from ZerolanLiveRobot.")
19 | 
20 |     def __init__(self, config: GameBridgeConfig, tool_agent: ToolAgent):
21 |         super().__init__(config.host, config.port)
22 |         self._instruction_tools: Dict[str, KonekoInstructionTool] = dict()
23 |         self._tool_agent = tool_agent
24 | 
25 |     def on_protocol(self, protocol: ZerolanProtocol):
26 |         if protocol.action == EventKeyRegistry.Koneko.Client.PUSH_INSTRUCTIONS:
27 |             self._on_push_instructions(protocol.data)
28 |         elif protocol.action == EventKeyRegistry.Koneko.Client.HELLO:
29 |             self._fetch_instructions()
30 | 
31 |     def _on_push_instructions(self, tools: List[Tool]):
32 |         result = []
33 |         for i, tool in enumerate(tools):
34 |             tool = Tool.model_validate(tool)
35 |             assert tool.type == "function"
36 |             tool_name = tool.function.name
37 |             tool_desc = tool.function.description
38 |             required_props = set(tool.function.parameters.required)
39 |             params_type = tool.function.parameters.type
40 |             properties = tool.function.parameters.properties
41 | 
42 |             arg_list: list[FieldMetadata] = []
43 |             for prop_name, prop in properties.items():
44 |                 metadata = FieldMetadata(name=prop_name, type=prop.type, description=prop.description,
45 |                                          required=prop_name in required_props)
46 |                 arg_list.append(metadata)
47 | 
48 |             model = generate_model_from_args(class_name=params_type, args_list=arg_list)
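            # Added note (illustrative comment, not in the original source): each
            # schema pushed by the Minecraft client is turned into a dynamic
            # pydantic model here, so tool-call arguments produced by the LLM are
            # validated exactly as the bot declared them.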
49 |             tool = KonekoInstructionTool(name=tool_name, description=tool_desc, args_schema=model)
50 |             result.append(tool)
51 |         self._tool_agent.bind_tools(result)
52 |         self._instruction_tools = dict()
53 |         for tool in result:
54 |             self._instruction_tools[tool.name] = tool
55 |         logger.info(f"{len(self._instruction_tools)} instruction tools are bound.")
56 | 
57 |     def _fetch_instructions(self):
58 |         self.send(action=EventKeyRegistry.Koneko.Server.FETCH_INSTRUCTIONS, data=None)
59 | 
60 |     def exec_instruction(self, query: str):
61 |         if len(self._instruction_tools) == 0:
62 |             logger.warning("No instruction to execute. Are you sure that KonekoMinecraftBot has started?")
63 |             return
64 | 
65 |         messages = [self._tool_agent.system_prompt, HumanMessage(query)]
66 |         ai_msg = self._tool_agent.invoke(messages)
67 |         messages.append(ai_msg)
68 |         assert hasattr(ai_msg, "tool_calls")
69 |         for tool_call in ai_msg.tool_calls:
70 |             selected_tool = self._instruction_tools[tool_call["name"]]
71 |             logger.info(f"Ready tool call: {tool_call}")
72 |             tool_msg = selected_tool.invoke(tool_call)
73 |             messages.append(tool_msg)
74 |         logger.debug(messages)
--------------------------------------------------------------------------------
/tests/pipeline/test_perf.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Callable
3 | 
4 | from zerolan.data.pipeline.asr import ASRQuery
5 | from zerolan.data.pipeline.llm import LLMQuery, Conversation, RoleEnum
6 | from zerolan.data.pipeline.tts import TTSStreamPrediction
7 | 
8 | from common.decorator import log_run_time
9 | from manager.config_manager import get_config
10 | from pipeline.asr.asr_sync import ASRSyncPipeline
11 | from pipeline.llm.llm_sync import LLMSyncPipeline
12 | from pipeline.test_tts import tts_stream_predict
13 | 
14 | _config = get_config()
15 | _llm = LLMSyncPipeline(_config.pipeline.llm)
16 | _asr = ASRSyncPipeline(_config.pipeline.asr)
17 | 
18 | def llm_predict_with_history(timer_handler: Callable[[float], None] | None = None):
19 |     query = LLMQuery(text="你现在能和我玩游戏吗?",
20 |                      history=[Conversation(role=RoleEnum.user, content="你现在是一只猫娘,请在句尾始终带上喵"),
21 |                               Conversation(role=RoleEnum.assistant, content="好的,主人喵")])
22 | 
23 |     @log_run_time()
24 |     def completed_llm_predict():
25 |         t_start_post = time.time()
26 |         prediction = _llm.predict(query)
27 |         t_end_post = time.time()
28 |         if timer_handler is not None: timer_handler(t_end_post - t_start_post)
29 |         return prediction
30 | 
31 |     prediction = completed_llm_predict()
32 |     return prediction
33 | 
34 | 
35 | def asr_predict(timer_handler: Callable[[float], None] | None = None):
36 |     query = ASRQuery(audio_path="/home/akagawatsurunaki/workspace/ZerolanLiveRobot/tests/resources/tts-test.wav",
37 |                      channels=2)
38 |     t_start_post = time.time()
39 |     prediction = _asr.predict(query)
40 |     t_end_post = time.time()
41 |     if timer_handler is not None:
42 |         timer_handler(t_end_post - t_start_post)
43 |     return prediction
44 | 
45 | 
46 | def test_llm_tts_stream():
47 |     asr_time_records: list[float] = []
48 |     llm_time_records: list[float] = []
49 |     tts_time_records: list[float] = []
50 | 
51 |     for i in range(10):
52 |         prediction = asr_predict(lambda elapsed_time: asr_time_records.append(elapsed_time))
53 |         print(prediction.transcript)
54 | 
55 |         prediction = llm_predict_with_history(lambda elapsed_time: llm_time_records.append(elapsed_time))
56 |         print(prediction.response)
57 | 
58 |         def handler(prediction: TTSStreamPrediction):
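            # Added note: this callback fires once per streamed TTS chunk; `seq`
            # orders the chunks, and the final chunk arrives with `is_final=True`.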
59 |             print(prediction.seq)
60 | 
61 |         tts_stream_predict(prediction.response, handler,
62 |                            lambda elapsed_time: tts_time_records.append(elapsed_time))
63 | 
64 |     i = 0
65 |     print("No.,ASR,LLM,TTS,Total")
66 |     for asr_elapsed_time, llm_elapsed_time, tts_elapsed_time in zip(asr_time_records, llm_time_records,
67 |                                                                     tts_time_records):
68 |         total = asr_elapsed_time + llm_elapsed_time + tts_elapsed_time
69 |         print(f"{i} {asr_elapsed_time:.4f} {llm_elapsed_time:.4f} {tts_elapsed_time:.4f} {total:.4f}")
70 |         i += 1
71 | 
72 |     print("----------------------")
73 |     asr_avg = sum(asr_time_records) / len(asr_time_records)
74 |     llm_avg = sum(llm_time_records) / len(llm_time_records)
75 |     tts_avg = sum(tts_time_records) / len(tts_time_records)
76 |     print(
77 |         f"  {asr_avg:.4f} {llm_avg:.4f} {tts_avg:.4f} {asr_avg + llm_avg + tts_avg:.4f} (avg)")
78 |     print("Test passed.")
79 | 
80 | # Basically, the latency can be controlled to about 1.5s
81 | # Non-streaming requests for LLMs are made first, followed by streaming requests for TTS
82 | # The average of the results of multiple experiments is:
83 | # LLM 0.8891s, TTS 0.4369s, Total 1.3260s (Win -> Ubuntu)
84 | # LLM 0.6507s, TTS 0.1542s, Total 0.8049s (Ubuntu -> Ubuntu)
--------------------------------------------------------------------------------
/services/live_stream/twitch.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from:
3 | https://pytwitchapi.dev/en/stable/modules/twitchAPI.chat.html#commands
4 | """
5 | 
6 | from asyncio import Task
7 | 
8 | from loguru import logger
9 | from twitchAPI.chat import Chat, EventData, ChatMessage
10 | from twitchAPI.oauth import UserAuthenticator
11 | from twitchAPI.twitch import Twitch
12 | from twitchAPI.type import AuthScope, ChatEvent
13 | from zerolan.data.data.danmaku import Danmaku, SuperChat
14 | 
15 | from common.concurrent.abs_runnable import AsyncRunnable
16 | from services.live_stream.config import TwitchServiceConfig
17 | from common.decorator import log_start, log_stop
18 | from common.utils.str_util import is_blank
19 | from event.event_data import LiveStreamSuperChatEvent, LiveStreamDanmakuEvent, LiveStreamConnectedEvent, LiveStreamDisconnectedEvent
20 | from event.event_emitter import emitter
21 | 
22 | 
23 | class TwitchService(AsyncRunnable):
24 |     def name(self):
25 |         return "TwitchService"
26 | 
27 |     def __init__(self, config: TwitchServiceConfig):
28 |         """
29 |         TODO: Need test!
30 |         """
31 |         super().__init__()
32 |         assert not is_blank(config.channel_id), "No channel_id provided."
33 |         assert not is_blank(config.app_id), "No app_id provided."
34 |         assert not is_blank(config.app_secret), "No app_secret provided."
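        # Added note: these asserts fail fast on a missing channel_id, app_id or
        # app_secret, even though the config marks app_secret as optional for pure
        # user authentication.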
35 | self._target_channel: str = config.channel_id 36 | self._app_id: str = config.app_id 37 | self._app_secret: str = config.app_secret 38 | self._user_scope = [AuthScope.CHAT_READ, AuthScope.CHAT_EDIT] 39 | self._twitch: Twitch | None = None 40 | 41 | self._service_task: Task | None = None 42 | 43 | @log_start("TwitchService") 44 | async def start(self): 45 | await super().start() 46 | await self.init() 47 | 48 | @log_stop("TwitchService") 49 | async def stop(self): 50 | await super().stop() 51 | await self._twitch.close() 52 | 53 | async def init(self): 54 | self._twitch = await Twitch(self._app_id, self._app_secret) 55 | auth = UserAuthenticator(self._twitch, self._user_scope) 56 | token, refresh_token = await auth.authenticate() 57 | await self._twitch.set_user_authentication(token, self._user_scope, refresh_token) 58 | 59 | # create chat instance 60 | chat = await Chat(self._twitch) 61 | 62 | async def on_message(msg: ChatMessage): 63 | logger.info(f"Danmaku: [{msg.user.name}] {msg.text}") 64 | 65 | if msg.bits is not None and msg.bits > 0: 66 | sc = SuperChat(uid=msg.user.id, username=msg.user.name, content=msg.text, ts=msg.sent_timestamp, 67 | money=f'{msg.bits}') 68 | emitter.emit(LiveStreamSuperChatEvent(superchat=sc, platform="twitch")) 69 | else: 70 | danmaku = Danmaku(uid=msg.user.id, username=msg.user.name, content=msg.text, ts=msg.sent_timestamp) 71 | emitter.emit(LiveStreamDanmakuEvent(danmaku=danmaku, platform="twitch")) 72 | 73 | async def on_ready(ready_event: EventData): 74 | await ready_event.chat.join_room(self._target_channel) 75 | if ready_event.chat.is_connected(): 76 | emitter.emit(LiveStreamConnectedEvent(platform="twitch")) 77 | logger.info(f"Joined channel: {self._target_channel}") 78 | else: 79 | emitter.emit(LiveStreamDisconnectedEvent(platform="twitch", reason="Failed to connect to the channel.")) 80 | logger.error(f"Failed to join channel: {self._target_channel}") 81 | 82 | chat.register_event(ChatEvent.READY, on_ready) 83 | chat.register_event(ChatEvent.MESSAGE, on_message) 84 | chat.start() # Start the chat client; without this the READY event never fires -------------------------------------------------------------------------------- /tests/test_playground.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | from asyncio import TaskGroup 4 | 5 | import pytest 6 | 7 | from common.concurrent.killable_thread import KillableThread 8 | from manager.config_manager import get_config 9 | from manager.model_manager import ModelManager 10 | from services.playground.bridge import PlaygroundBridge 11 | from services.playground.data import LoadLive2DModelResponse, CreateGameObjectResponse, Transform, Position, \ 12 | BuiltinGameObjectType, ScaleOperationResponse 13 | from util import syncwait, connect 14 | 15 | _config = get_config() 16 | _bridge = PlaygroundBridge(_config.service.playground) 17 | 18 | auto_close_flag = False 19 | 20 | _bridge_thread = KillableThread(target=_bridge.start, daemon=True) 21 | 22 | 23 | def wait_conn(): 24 | while not _bridge.is_connected: 25 | time.sleep(0.1) 26 | 27 | 28 | def block_forever(): 29 | while True: 30 | time.sleep(0.1) 31 | 32 | 33 | def test_conn(): 34 | _bridge_thread.start() 35 | wait_conn() 36 | _bridge.stop() 37 | _bridge_thread.join() 38 | print("Test passed") 39 | 40 | 41 | # You should put at least one 3D model file under `../resources/static/models/3d`, 42 | # or this test case will not work. 43 | # You can also change the path if you want.
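# For example (hypothetical file name): ../resources/static/models/3d/sample.glb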
44 | def test_load_3d_model(): 45 | _bridge_thread.start() 46 | _manager = ModelManager("../resources/static/models/3d") 47 | _manager.scan() 48 | file_id = _manager.get_files()[0]["id"] 49 | file_info = _manager.get_file_by_id(file_id) 50 | wait_conn() 51 | 52 | time.sleep(3) 53 | _bridge.load_3d_model(file_info) 54 | block_forever() # NOTE: blocks for manual inspection, so the cleanup below is unreachable until interrupted 55 | _bridge.stop() 56 | _bridge_thread.join() 57 | 58 | 59 | # You should set the Live2D model file path in your config.yaml, 60 | # or this test case will not work. 61 | def test_load_live2d_model(): 62 | _bridge_thread.start() 63 | model_dir = _config.service.playground.model_dir 64 | bot_id = _config.service.playground.bot_id 65 | bot_name = _config.character.bot_name 66 | wait_conn() 67 | _bridge.load_live2d_model(LoadLive2DModelResponse(bot_id=bot_id, model_dir=model_dir, bot_display_name=bot_name)) 68 | block_forever() 69 | 70 | 71 | async def create_sphere(): 72 | await _bridge.create_gameobject( 73 | CreateGameObjectResponse(instance_id=114, game_object_name="MySphere", object_type=BuiltinGameObjectType.SPHERE, 74 | color="#114514", transform=Transform(scale=5.0, position=Position(x=1, y=1, z=1)))) 75 | 76 | 77 | @pytest.mark.asyncio 78 | async def test_create_gameobject(): 79 | async with TaskGroup() as tg: 80 | tg.create_task(connect(_bridge, auto_close_flag)) 81 | await syncwait(_bridge) 82 | await create_sphere() 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_modify_game_object_scale(): 87 | async with TaskGroup() as tg: 88 | tg.create_task(connect(_bridge, auto_close_flag)) 89 | await syncwait(_bridge) 90 | await create_sphere() 91 | await asyncio.sleep(1) 92 | for e in _bridge.get_gameobjects_info(): 93 | print(e.game_object_name) 94 | if e.game_object_name == "MySphere": 95 | await _bridge.modify_game_object_scale( 96 | ScaleOperationResponse(instance_id=e.instance_id, target_scale=0.5)) 97 | 98 | 99 | @pytest.mark.asyncio 100 | async def test_query_game_objects_info(): 101 | async with TaskGroup() as tg: 102 | tg.create_task(connect(_bridge, auto_close_flag)) 103 | await syncwait(_bridge) 104 | await create_sphere() 105 | await _bridge.query_update_gameobjects_info() 106 | -------------------------------------------------------------------------------- /pipeline/tts/baidu_tts.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import requests 4 | from loguru import logger 5 | from pydantic import BaseModel 6 | from zerolan.data.pipeline.tts import TTSPrediction, TTSQuery 7 | 8 | 9 | def _aue_to_str(aue: int) -> str: 10 | format_map = { 11 | 3: "mp3", 12 | 4: "pcm", 13 | 5: "pcm", 14 | 6: "wav" 15 | } 16 | 17 | if aue not in format_map: 18 | raise ValueError(f"Invalid aue value: {aue}") 19 | 20 | result = format_map[aue] 21 | return result 22 | 23 | 24 | def _str_to_aue(type: str): 25 | format_map = { 26 | "mp3": 3, 27 | "pcm": 5, 28 | "wav": 6 29 | } 30 | 31 | if type not in format_map: 32 | raise ValueError(f"Invalid audio type: {type}") 33 | 34 | result = format_map[type] 35 | return result 36 | 37 | 38 | class BaiduTTSError(BaseModel): 39 | """Pydantic model to represent TTS API error response""" 40 | convert_offline: bool 41 | err_detail: str 42 | err_msg: str 43 | err_no: int 44 | err_subcode: int 45 | tts_logid: int 46 | 47 | 48 | class BaiduTTSPipeline: 49 | def __init__(self, api_key, secret_key): 50 | self._access_token = self._get_access_token(api_key, secret_key) 51 | self._cuid = str(uuid.uuid4()) 52 | 53 | @staticmethod 54 | def _get_access_token(api_key, secret_key): 55 |
| """ 56 | 使用 AK,SK 生成鉴权签名(Access Token) 57 | :return: access_token,或是None(如果错误) 58 | """ 59 | url = "https://aip.baidubce.com/oauth/2.0/token" 60 | params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} 61 | return str(requests.post(url, params=params).json().get("access_token")) 62 | 63 | def predict(self, query: TTSQuery): 64 | url = "https://tsn.baidu.com/text2audio" 65 | headers = { 66 | 'Content-Type': 'application/x-www-form-urlencoded', 67 | 'Accept': '*/*' 68 | } 69 | if self._access_token is None: 70 | raise ValueError("Access token should not be None.") 71 | aue = _str_to_aue(query.audio_type) 72 | payload = { 73 | 'tex': query.text, 74 | 'tok': self._access_token, 75 | "cuid": self._cuid, 76 | "ctp": 1, 77 | "lan": "zh", # Baidu TTS is not supported `auto` 78 | "spd": 5, 79 | "pit": 5, 80 | "vol": 5, 81 | "per": 1, 82 | "aue": aue 83 | } 84 | response = requests.request("POST", url, headers=headers, data=payload) 85 | 86 | response.raise_for_status() 87 | 88 | # Get Content-Type header 89 | content_type = response.headers.get('Content-Type', '').lower() 90 | 91 | # Validate response based on Content-Type 92 | if content_type.startswith('application/json') or content_type.startswith('text/'): 93 | # Response is text/JSON - likely an error 94 | error_data = response.json() 95 | error = BaiduTTSError(**error_data) 96 | logger.error(f"Error message received: {error_data}") 97 | raise Exception(error) 98 | elif content_type.startswith('audio/') or 'audio' in content_type: 99 | # Response is audio data - proceed normally 100 | prediction = TTSPrediction(wave_data=response.content, audio_type=_aue_to_str(aue)) 101 | return prediction 102 | else: 103 | raise ValueError(f"Unsupported content type: {content_type}") 104 | 105 | def stream_predict(self, *args, **kwargs): 106 | raise NotImplementedError("Baidu stream TTS is not supported.") -------------------------------------------------------------------------------- /manager/config_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import yaml 5 | from loguru import logger 6 | from typeguard import typechecked 7 | 8 | from common.generator.config_gen import ConfigFileGenerator 9 | from config import ZerolanLiveRobotConfig 10 | 11 | # Should not import these global value! 12 | _project_dir: Path | None = None 13 | _config: ZerolanLiveRobotConfig | None = None 14 | 15 | 16 | @typechecked 17 | def generate_config_file(config_path: Path): 18 | res_dir = config_path.parent 19 | if not os.path.exists(res_dir): 20 | res_dir.mkdir() 21 | gen = ConfigFileGenerator() 22 | config = gen.generate_yaml(ZerolanLiveRobotConfig()) 23 | with open(config_path, mode="w+", encoding="utf-8") as f: 24 | f.write(config) 25 | logger.warning( 26 | "`resources/config.yaml` was not found. I have generated the file for you! \n" 27 | "Please edit the config file and re-run the program.") 28 | exit() 29 | 30 | 31 | @typechecked 32 | def _check_license(path: Path) -> bool: 33 | with open(path, "r", encoding="utf-8") as f: 34 | if "Copyright (c) 2024 AkagawaTsurunaki" in f.read(): 35 | logger.info("☺️ License validation passed! Thanks for your support!") 36 | return True 37 | else: 38 | logger.error("😭 License validation failed! 
You may be a victim of pirated software.") 39 | return False 40 | 41 | 42 | @typechecked 43 | def _find_license_recursively(path: Path, depth=0, max_depth=20) -> Path: 44 | if depth > max_depth: 45 | raise RecursionError("LICENSE file not found within the maximum search depth.") 46 | for file in path.rglob("LICENSE"): 47 | logger.info(f"Found candidate path: {file}") 48 | return Path(file) 49 | depth += 1 50 | return _find_license_recursively(path.parent, depth) 51 | 52 | 53 | @typechecked 54 | def get_project_dir() -> Path: 55 | global _project_dir 56 | if _project_dir is None: 57 | cwd = Path(os.getcwd()) 58 | license_path = _find_license_recursively(cwd) 59 | if _check_license(license_path): 60 | _project_dir = Path(license_path.parent) 61 | else: 62 | exit() 63 | return _project_dir 64 | 65 | 66 | @typechecked 67 | def get_default_config_path() -> Path: 68 | project_dir = get_project_dir() 69 | config_file_path = project_dir.joinpath("resources/config.yaml") 70 | return config_file_path 71 | 72 | 73 | @typechecked 74 | def get_config(path: Path | None = None) -> ZerolanLiveRobotConfig: 75 | global _config 76 | if _config: 77 | return _config 78 | if path is None: 79 | path = get_default_config_path() 80 | if path.exists(): 81 | with open(path, mode="r", encoding="utf-8") as f: 82 | cfg_dict = yaml.safe_load(f) 83 | _config = ZerolanLiveRobotConfig.model_validate(cfg_dict) 84 | return _config 85 | else: 86 | generate_config_file(path) 87 | 88 | 89 | @typechecked 90 | def save_config(config: ZerolanLiveRobotConfig, path: Path | None = None): 91 | assert config is not None, "A None config cannot be saved to the config file." 92 | if path is None: 93 | path = get_default_config_path() 94 | # Create dir if not exists 95 | if not path.exists(): 96 | path.parent.mkdir(parents=True, exist_ok=True) 97 | if path.exists(): 98 | logger.warning("Config file already exists. Overwriting...") 99 | # Generate config file 100 | gen = ConfigFileGenerator() 101 | yaml_str = gen.generate_yaml(config) 102 | with open(path, "w+", encoding="utf-8") as f: 103 | f.write(yaml_str) 104 | logger.info(f"Config file was saved: {path}") 105 | 106 | 107 | get_project_dir() 108 | -------------------------------------------------------------------------------- /common/utils/audio_util.py: -------------------------------------------------------------------------------- 1 | import io 2 | from io import BytesIO 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import soundfile as sf 7 | from pydub import AudioSegment 8 | from scipy.io import wavfile as wavfile 9 | from typeguard import typechecked 10 | 11 | from common.io.file_type import AudioFileType 12 | 13 | 14 | @typechecked 15 | def get_audio_info(file: Path | str | bytes | BytesIO, type: str | None = None) -> tuple[int, int, float]: 16 | """ 17 | Get audio info from a file. Supports OGG, WAV, MP3, FLV and RAW. See also AudioFileType. 18 | :param file: Audio file path, or raw audio data as bytes/BytesIO.
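:param type: Explicit format string such as "wav" (see AudioFileType); required when `file` is bytes or BytesIO, since no file suffix can be derived from in-memory data.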
19 | :return: sample_rate, num_channels, duration 20 | """ 21 | if isinstance(file, bytes): 22 | file = BytesIO(file) 23 | if type is None: 24 | suffix = Path(file).suffix # NOTE: only works for path-like input; pass `type` explicitly for bytes/BytesIO 25 | if suffix.startswith('.'): 26 | suffix = suffix[1:] 27 | suffix = suffix.lower() 28 | else: 29 | suffix = type 30 | 31 | if suffix == AudioFileType.OGG: 32 | audio = AudioSegment.from_ogg(file) 33 | elif suffix == AudioFileType.WAV: 34 | audio = AudioSegment.from_wav(file) 35 | elif suffix == AudioFileType.MP3: 36 | audio = AudioSegment.from_mp3(file) 37 | elif suffix == AudioFileType.FLV: 38 | audio = AudioSegment.from_flv(file) 39 | elif suffix == AudioFileType.RAW: 40 | audio = AudioSegment.from_raw(file) 41 | else: 42 | raise NotImplementedError() 43 | 44 | sample_rate = audio.frame_rate 45 | num_channels = audio.channels 46 | duration_ms = len(audio) 47 | duration = duration_ms / 1000.0 48 | 49 | return sample_rate, num_channels, duration 50 | 51 | 52 | @typechecked 53 | def get_audio_real_format(audio: bytes | str | Path) -> str: 54 | """ 55 | Get the real format of the audio. 56 | :param audio: Bytes data of the audio, or the path of the audio file. 57 | :return: Real format of the audio. 58 | """ 59 | if isinstance(audio, bytes): 60 | audio_bytes = audio 61 | elif isinstance(audio, str) or isinstance(audio, Path): 62 | with open(audio, "rb") as f: 63 | audio_bytes = f.read() 64 | else: 65 | raise TypeError("audio must be bytes, str, or Path.") 66 | 67 | if audio_bytes.startswith(b'RIFF') and audio_bytes.find(b'WAVE') != -1: 68 | return AudioFileType.WAV 69 | elif audio_bytes.startswith(b'OggS'): 70 | return AudioFileType.OGG 71 | elif audio_bytes.startswith(b'FLV'): 72 | return AudioFileType.FLV 73 | elif audio_bytes.startswith(b'\xFF\xFB') or audio_bytes.startswith(b'\xFF\xF3') or audio_bytes.startswith(b'ID3'): # ID3v2-tagged MP3 files begin with "ID3" 74 | return AudioFileType.MP3 75 | elif len(audio_bytes) > 0: 76 | return AudioFileType.RAW 77 | else: 78 | raise NotImplementedError("Unknown audio format.") 79 | 80 | 81 | @typechecked 82 | def from_ndarray_to_bytes(speech_chunk, sample_rate): 83 | """ 84 | Convert numpy.ndarray data to bytes. 85 | :param speech_chunk: numpy.ndarray of audio samples. 86 | :param sample_rate: Sample rate of the speech chunk. 87 | :return: WAV-encoded bytes. 88 | """ 89 | wave_file = io.BytesIO() 90 | wavfile.write(filename=wave_file, rate=sample_rate, data=speech_chunk) 91 | return wave_file.getvalue() 92 | 93 | 94 | @typechecked 95 | def from_bytes_to_np_ndarray(bytes_data: bytes, dtype: str = "float32") -> tuple[np.ndarray, int]: 96 | """ 97 | Convert byte data to numpy.ndarray format. 98 | :param bytes_data: Audio data as bytes. 99 | :param dtype: Default is float32. 100 | :return: The converted np.ndarray format data, sample rate.
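Example (sketch): data, sample_rate = from_bytes_to_np_ndarray(wav_bytes) # assumes wav_bytes holds a complete, decodable audio stream such as a WAV file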
101 | """ 102 | wave_bytes_buf = io.BytesIO(bytes_data) 103 | data, samplerate = sf.read(wave_bytes_buf, dtype=dtype) 104 | return data, samplerate 105 | -------------------------------------------------------------------------------- /character/config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from common.enumerator import BaseEnum 6 | 7 | 8 | class FilterStrategyEnum(BaseEnum): 9 | Default: str = "default" 10 | 11 | 12 | class FilterConfig(BaseModel): 13 | bad_words: List[str] = Field(default=[ 14 | "作为一名人工智能助手", 15 | "人工智能助手", 16 | "作为一个人工智能", 17 | "作为一个AI" 18 | ], description="List of words to be filtered out in conversations.") 19 | strategy: FilterStrategyEnum = Field(FilterStrategyEnum.Default, 20 | description="The filtering strategy to use, currently only supports 'default'. \n" 21 | "default: The word is filtered out as long as it matches.") 22 | 23 | 24 | class ChatConfig(BaseModel): 25 | filter: FilterConfig = Field(FilterConfig(), 26 | description="Configuration for filtering in chat") 27 | system_prompt: str = Field( 28 | default="你现在是叫做“”的一名猫娘,你的开发者是赤川鹤鸣。我会为你提供语音识别、图像识别、图像理解、语音合成和一系列的 AI 集成功能。你现在正在直播平台上与观众友好交流,请把自己当成一名真正的人类,与观众友好交流。你的话不要太长,控制在50字左右。", 29 | description="System prompt used to guide the chatbot's behavior. \n" 30 | "Usually set the character's setting, background, behavior, personality, etc.") 31 | injected_history: List[str] = Field(default=[ 32 | "你叫什么名字?", 33 | "我是,很高兴见到你!", 34 | "你是谁研发出来的?", 35 | "我是由赤川鹤鸣(AkagawaTsurunaki)研发的。" 36 | ], 37 | description="List of predefined messages to inject into the chat history. \n" 38 | "Used to guide conversation styles. \n" 39 | "This array must be an even number, i.e. it must end the message that the `assistant` replies.") 40 | max_history: int = Field(20, 41 | description="Maximum number of messages to keep in chat history.") 42 | 43 | 44 | class SpeechConfig(BaseModel): 45 | is_remote: bool = Field(default=False, 46 | description="If this value is set to `True`, the system will assume that the TTS prompt files " 47 | "already exist on the remote server, so `prompts_dir` is invalid and " 48 | "will not be traversed and searched.") 49 | prompts_dir: str = Field("resources/static/prompts/tts", 50 | description="Directory path for TTS prompts. (Absolute path is recommended)\n" 51 | "All files in the directory must conform to the file format: \n" 52 | " [lang][sentiment_tag]text.wav \n" 53 | "For example, `[en][happy] Wow! What a good day today.wav`. \n" 54 | "where, \n" 55 | " 1. `lang` only supports 'zh', 'en', 'ja'; \n" 56 | " 2. `sentiment_tag` are arbitrary, as long as they can be discriminated by LLM; \n" 57 | " 3. 
`text` is the transcription represented by the human voice in this audio.") 58 | prompts: List[str] = Field(default=[], 59 | description="If you set `is_remote` to `True`, you must configure this!") 60 | 61 | 62 | class CharacterConfig(BaseModel): 63 | bot_name: str = Field("", 64 | description="Name of the bot character.") 65 | chat: ChatConfig = Field(ChatConfig(), 66 | description="Configuration for chat-related settings.") 67 | speech: SpeechConfig = Field(SpeechConfig(), 68 | description="Configuration for speech-related settings.") 69 | -------------------------------------------------------------------------------- /tests/pipeline/test_llm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from loguru import logger 3 | from zerolan.data.pipeline.llm import LLMQuery, Conversation, RoleEnum 4 | 5 | from manager.config_manager import get_config, get_project_dir 6 | from pipeline.llm.config import LLMPipelineConfig 7 | from pipeline.llm.llm_sync import LLMSyncPipeline 8 | from pipeline.llm.llm_async import LLMAsyncPipeline 9 | 10 | _config = get_config() 11 | _llm_async = LLMAsyncPipeline(_config.pipeline.llm) 12 | _llm_sync = LLMSyncPipeline(_config.pipeline.llm) 13 | project_dir = get_project_dir() 14 | _llm_query = LLMQuery(text="刚才我让你记住的名字是什么?", history=[ 15 | Conversation(role=RoleEnum.user, content="请记住这个名字“赤川鹤鸣”"), 16 | Conversation(role=RoleEnum.assistant, content="好的,我会记住“赤川鹤鸣”这个名字") 17 | ]) 18 | _kimi_api_key = None 19 | _dp_api_key = None 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def event_loop(event_loop_policy): 24 | # Needed to work with asyncpg 25 | loop = event_loop_policy.new_event_loop() 26 | yield loop 27 | loop.close() 28 | 29 | 30 | @pytest.mark.asyncio 31 | async def test_llm_async(): 32 | query = LLMQuery(text="Hello! What is your name?", history=[]) 33 | prediction = await _llm_async.predict(query) 34 | print(prediction.model_dump_json(indent=4)) 35 | assert prediction and prediction.response, f"Test failed: No response." 36 | 37 | 38 | @pytest.mark.asyncio 39 | async def test_llm_history_async(): 40 | query = LLMQuery(text="Now please tell me the name I told you to remember.", history=[ 41 | Conversation(role=RoleEnum.user, content="Please remember this name: AkagawaTsurunaki"), 42 | Conversation(role=RoleEnum.assistant, content="Ok. I remembered your name AkagawaTsurunaki") 43 | ]) 44 | prediction = await _llm_async.predict(query) 45 | logger.info(prediction.model_dump_json(indent=4)) 46 | assert prediction and prediction.response, f"Test failed: No response." 47 | assert "Akagawa" in prediction.response, f"Test failed: History may not be injected." 48 | 49 | 50 | @pytest.mark.asyncio 51 | async def test_llm_stream_predict(): 52 | query = LLMQuery(text="Hello! What is your name?", history=[]) 53 | prediction = await _llm_async.predict(query) # NOTE: this still calls predict(); the streaming API is not exercised here 54 | logger.info(prediction.model_dump_json(indent=4)) 55 | assert prediction and prediction.response, f"Test failed: No response." 56 | 57 | 58 | def test_llm(): 59 | query = LLMQuery(text="你好,你叫什么名字?", history=[]) 60 | prediction = _llm_sync.predict(query) 61 | assert prediction, f"Test failed: No response." 62 | logger.info(prediction.response) 63 | assert len(prediction.response) > 0, f"Test failed: No text response." 64 | 65 | 66 | def _chat(llm_sync: LLMSyncPipeline): 67 | prediction = llm_sync.predict(_llm_query) 68 | assert prediction, f"Test failed: No response."
69 | logger.info(prediction.response) 70 | 71 | 72 | def test_kimi_api(): 73 | if not _kimi_api_key: 74 | logger.warning("No API key provided. Ignoring this test case.") 75 | return 76 | kimi_config = LLMPipelineConfig(model_id="moonshot-v1-8k", 77 | predict_url="https://api.moonshot.cn/v1", 78 | stream_predict_url="https://api.moonshot.cn/v1", 79 | api_key=_kimi_api_key, 80 | openai_format=True) 81 | kimi = LLMSyncPipeline(kimi_config) 82 | _chat(kimi) 83 | 84 | 85 | def test_deepseek_api(): 86 | if not _dp_api_key: 87 | logger.warning("No API key provided. Ignoring this test case.") 88 | return 89 | deepseek_config = LLMPipelineConfig(model_id="deepseek-chat", 90 | predict_url="https://api.deepseek.com", 91 | stream_predict_url="https://api.deepseek.com", 92 | api_key=_dp_api_key, 93 | openai_format=True) 94 | deepseek = LLMSyncPipeline(deepseek_config) 95 | _chat(deepseek) 96 | -------------------------------------------------------------------------------- /services/live_stream/youtube.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import requests 4 | from zerolan.data.data.danmaku import Danmaku, SuperChat 5 | 6 | from common.concurrent.abs_runnable import AsyncRunnable 7 | from services.live_stream.config import YoutubeServiceConfig 8 | from common.decorator import log_start, log_stop 9 | from common.utils.str_util import is_blank 10 | from event.event_data import LiveStreamDanmakuEvent, LiveStreamSuperChatEvent 11 | from event.event_emitter import emitter 12 | 13 | 14 | def get(url, token: str): 15 | headers = { 16 | "Authorization": "Bearer " + token, 17 | "Content-Type": "application/json; charset=utf-8", 18 | } 19 | response = requests.get(url, headers=headers) 20 | response.raise_for_status() 21 | return response.json() 22 | 23 | 24 | def convert_danmakus(live_chat_messages: list[dict]): 25 | result = [] 26 | for live_chat_message in live_chat_messages: 27 | if live_chat_message["type"] == "textMessageEvent": 28 | ts = live_chat_message["snippet"]["publishedAt"] 29 | content = live_chat_message["snippet"]["textMessageDetails"]["messageText"] 30 | uid = live_chat_message["authorDetails"]["channelId"] # authorDetails is a top-level field of liveChatMessage, not under snippet 31 | username = live_chat_message["authorDetails"]["displayName"] 32 | danmaku = Danmaku(uid=uid, username=username, content=content, ts=ts) 33 | result.append(danmaku) 34 | return result 35 | 36 | 37 | def convert_superchats(super_chat_events: list[dict]): 38 | result = [] 39 | for super_chat_event in super_chat_events: 40 | uid = super_chat_event["snippet"]["supporterDetails"]["channelId"] # the supporter who sent the super chat, not the receiving channel 41 | username = super_chat_event["snippet"]["supporterDetails"]["displayName"] 42 | ts = super_chat_event["snippet"]["createdAt"] 43 | content = super_chat_event["snippet"]["commentText"] 44 | money = super_chat_event["snippet"]["displayString"] 45 | sc = SuperChat(uid=uid, username=username, ts=ts, content=content, money=money) 46 | result.append(sc) 47 | return result 48 | 49 | 50 | class YouTubeService(AsyncRunnable): 51 | def name(self): 52 | return "YouTubeService" 53 | 54 | def __init__(self, config: YoutubeServiceConfig): 55 | # TODO: Need test! 56 | super().__init__() 57 | assert not is_blank(config.token), "No token provided."
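# Design note: start() polls the YouTube Data API once per second (see _run below) and de-duplicates messages with the two sets kept here.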
58 | self._token = config.token 59 | self._danmakus = set() 60 | self._superchats = set() 61 | self._stop_flag = False 62 | 63 | async def _run(self): 64 | while not self._stop_flag: 65 | await asyncio.sleep(1) 66 | self.emit_danmaku_event() 67 | self.emit_super_chat_event() 68 | 69 | def emit_danmaku_event(self): 70 | # https://developers.google.com/youtube/v3/live/docs/liveChatMessages 71 | live_chat_messages = get("https://www.googleapis.com/youtube/v3/liveChat/messages", self._token) 72 | danmakus = set(convert_danmakus(live_chat_messages["items"])) 73 | updated_danmakus = danmakus.difference(self._danmakus) # keep only messages we have not seen yet 74 | self._danmakus.update(updated_danmakus) 75 | for danmaku in updated_danmakus: 76 | emitter.emit(LiveStreamDanmakuEvent(danmaku=danmaku, platform="youtube")) 77 | 78 | def emit_super_chat_event(self): 79 | # https://developers.google.com/youtube/v3/live/docs/superChatEvents 80 | super_chat_events = get("https://www.googleapis.com/youtube/v3/superChatEvents", self._token) 81 | super_chats = convert_superchats(super_chat_events["items"]) 82 | updated_superchats = set(super_chats).difference(self._superchats) # keep only super chats we have not seen yet 83 | self._superchats.update(updated_superchats) 84 | for superchat in updated_superchats: 85 | emitter.emit(LiveStreamSuperChatEvent(superchat=superchat, platform="youtube")) 86 | 87 | @log_start("YouTubeService") 88 | async def start(self): 89 | await super().start() 90 | await self._run() 91 | 92 | @log_stop("YouTubeService") 93 | async def stop(self): 94 | await super().stop() 95 | self._stop_flag = True 96 | -------------------------------------------------------------------------------- /tests/pipeline/server.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import flask 4 | from flask import request 5 | from zerolan.data.pipeline.abs_data import AbstractModelPrediction 6 | from zerolan.data.pipeline.llm import LLMPrediction, Conversation, RoleEnum 7 | 8 | """ 9 | This server is for testing only, but it also shows how to implement your own pipeline server. 10 | """ 11 | 12 | 13 | class TestServer: 14 | def __init__(self): 15 | self._host = "127.0.0.1" 16 | self._port = 5889 17 | self._app = flask.Flask(__name__) 18 | 19 | def start(self): 20 | print("Starting test server") 21 | self._app.run(host=self._host, port=self._port) 22 | 23 | def init(self): 24 | @self._app.route('/llm/predict', methods=['POST']) 25 | def llm_predict(): 26 | assert len(request.json) > 0 27 | id = request.json['id'] 28 | prediction = LLMPrediction(id=id, response="Test passed", 29 | history=[Conversation(role=RoleEnum.user, content="Test"), 30 | Conversation(role=RoleEnum.assistant, content="Test passed")]) 31 | return flask.jsonify(prediction.model_dump()) 32 | 33 | @self._app.route('/llm/stream-predict', methods=['POST']) 34 | def llm_stream_predict(): 35 | assert len(request.json) > 0 36 | id = request.json['id'] 37 | content = "Test passed" 38 | 39 | def gen(): 40 | for i in range(len(content)): 41 | prediction = LLMPrediction(id=id, response=content[:i], 42 | history=[Conversation(role=RoleEnum.user, content="Test"), 43 | Conversation(role=RoleEnum.assistant, content=content[:i])]) 44 | yield prediction.model_dump_json() 45 | 46 | return flask.Response(gen()) 47 | 48 | @self._app.route('/abs-img/predict', methods=['POST']) 49 | def abs_img_predict(): 50 | if request.headers['Content-Type'] == 'application/json': 51 | # If it's in JSON format, then there must be an image location.
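# (hypothetical client payload) e.g. {"id": "1", "img_path": "/path/visible/to/the/server.png"}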
52 | json_val = request.get_json() 53 | id = json_val['id'] 54 | return flask.jsonify(AbstractModelPrediction(id=id).model_dump()) 55 | elif 'multipart/form-data' in request.headers['Content-Type']: 56 | # If it's in multipart/form-data format, then try to get the image file. 57 | img = request.files['image'] 58 | assert img is not None 59 | # Note: you must get the JSON data from `form['json']`! 60 | query = json.loads(request.form['json']) 61 | id = query['id'] 62 | return flask.jsonify(AbstractModelPrediction(id=id).model_dump()) 63 | else: 64 | raise NotImplementedError("Unsupported Content-Type.") 65 | 66 | @self._app.route('/abs-img/stream-predict', methods=['POST']) 67 | def abs_img_stream_predict(): 68 | 69 | def gen(id): 70 | for i in range(10): 71 | prediction = AbstractModelPrediction(id=id) 72 | yield prediction.model_dump_json() 73 | 74 | if request.headers['Content-Type'] == 'application/json': 75 | # If it's in JSON format, then there must be an image location. 76 | json_val = request.get_json() 77 | id = json_val['id'] 78 | return flask.Response(gen(id)) 79 | elif 'multipart/form-data' in request.headers['Content-Type']: 80 | # If it's in multipart/form-data format, then try to get the image file. 81 | img = request.files['image'] 82 | assert img is not None 83 | # Note: you must get the JSON data from `form['json']`! 84 | query = json.loads(request.form['json']) 85 | id = query['id'] 86 | return flask.Response(gen(id)) 87 | else: 88 | raise NotImplementedError("Unsupported Content-Type.") 89 | -------------------------------------------------------------------------------- /devices/microphone.py: -------------------------------------------------------------------------------- 1 | import io 2 | import threading 3 | import wave 4 | 5 | import pyaudio 6 | import webrtcvad 7 | from loguru import logger 8 | 9 | from common.concurrent.abs_runnable import ThreadRunnable 10 | from common.io.file_type import AudioFileType 11 | from event.event_data import DeviceMicrophoneVADEvent 12 | from event.event_emitter import emitter 13 | 14 | 15 | class SmartMicrophone(ThreadRunnable): 16 | def __init__(self, vad_mode=0, frame_duration=30): 17 | """ 18 | Initialize the smart microphone. 19 | :param vad_mode: Optionally, set its aggressiveness mode, which is an integer between 0 and 3. 20 | 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive. 21 | :param frame_duration: A frame must be either 10, 20, or 30 ms in duration. 22 | """ 23 | super().__init__() 24 | assert frame_duration in [10, 20, 30], "A frame must be either 10, 20, or 30 ms in duration!"
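# e.g. with the defaults below: 16000 Hz * 30 ms / 1000 = 480 samples per VAD frame, which is the frame size webrtcvad expects.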
25 | 26 | # Audio parameters 27 | self._format = pyaudio.paInt16 28 | self._channels = 1 29 | self._sample_rate = 16000 30 | self._chunk_size = int(self._sample_rate * frame_duration / 1000) # Samples per frame (not bytes) 31 | 32 | # Initialize microphone 33 | self._audio = pyaudio.PyAudio() 34 | self._vad = webrtcvad.Vad(vad_mode) 35 | self._stream = self._audio.open(format=self._format, 36 | channels=self._channels, 37 | rate=self._sample_rate, 38 | input=True, 39 | frames_per_buffer=self._chunk_size) 40 | 41 | self._audio_frames = [] 42 | self._is_speaking = False 43 | 44 | self._pause_event = threading.Event() 45 | self._stop_flag = False 46 | 47 | @property 48 | def is_recording(self): 49 | return self._pause_event.is_set() and (not self._stop_flag) and self._stream.is_active() 50 | 51 | def start(self): 52 | super().start() 53 | self._pause_event.set() 54 | self._stop_flag = False 55 | try: 56 | while not self._stop_flag: 57 | self._pause_event.wait() 58 | if self._stop_flag: 59 | break 60 | data = self._stream.read(self._chunk_size, exception_on_overflow=False) 61 | if self._vad.is_speech(data, self._sample_rate): 62 | if not self._is_speaking: 63 | logger.info("Voice detected: Beginning.") 64 | self._is_speaking = True 65 | self._audio_frames.append(data) 66 | else: 67 | if self._is_speaking: 68 | logger.info("Voice detected: Ending.") 69 | self._is_speaking = False 70 | self._emit_event() 71 | self._audio_frames = [] 72 | except Exception as e: 73 | logger.exception(e) 74 | finally: 75 | # Stop and close the microphone stream 76 | self._stream.stop_stream() 77 | self._stream.close() 78 | self._audio.terminate() 79 | 80 | def _emit_event(self): 81 | if self._audio_frames: 82 | # Create a BytesIO object to hold the WAV file 83 | file = io.BytesIO() 84 | wf = wave.open(file, 'wb') 85 | wf.setnchannels(self._channels) 86 | wf.setsampwidth(self._audio.get_sample_size(self._format)) 87 | wf.setframerate(self._sample_rate) 88 | wf.writeframes(b''.join(self._audio_frames)) 89 | wf.close() 90 | 91 | # Seek back to the beginning of the BytesIO buffer 92 | file.seek(0) 93 | emitter.emit(DeviceMicrophoneVADEvent( 94 | speech=file.read(), 95 | audio_type=AudioFileType.WAV, 96 | channels=self._channels, 97 | sample_rate=self._sample_rate, 98 | )) 99 | 100 | def pause(self): 101 | self._pause_event.clear() 102 | logger.info("Paused smart microphone.") 103 | 104 | def resume(self): 105 | self._pause_event.set() 106 | logger.info("Resumed smart microphone.") 107 | 108 | def stop(self): 109 | self._stop_flag = True 110 | self._pause_event.set() 111 | logger.info("Stopped smart microphone.") 112 | 113 | def name(self): 114 | return "SmartMicrophone" 115 | -------------------------------------------------------------------------------- /common/web/json_ws.py: -------------------------------------------------------------------------------- 1 | import json 2 | import socket 3 | from typing import Any, Callable, Union, List, Dict 4 | 5 | from loguru import logger 6 | from pydantic import BaseModel 7 | from websockets import ConnectionClosed 8 | from websockets.sync.connection import Connection 9 | from websockets.sync.server import serve, Server 10 | 11 | from common.concurrent.abs_runnable import ThreadRunnable 12 | from common.utils import web_util 13 | 14 | 15 | ############################ 16 | # Json Web Socket Server # 17 | # Author: AkagawaTsurunaki # 18 | ############################ 19 | 20 | class JsonWsServer(ThreadRunnable): 21 | def __init__(self, host: str, port: int, subprotocols: List[str] | None = None): 22 | super().__init__() 23 | self.ws: Server | None = None 24 |
self.host = host 25 | self.port = port 26 | # Important! The subprotocol list is used for validation! 27 | self.subprotocols = subprotocols 28 | 29 | # Listener registration 30 | self.on_msg_handlers: List[Callable[[Connection, Union[Dict, List]], None]] = [] 31 | self.on_open_handlers: List[Callable[[Connection], None]] = [] 32 | self.on_close_handlers: List[Callable[[Connection, int, str], None]] = [] 33 | self.on_err_handlers: List[Callable[[Connection, Exception], None]] = [] 34 | 35 | # Connection registry (do not use a Connection object after its connection is closed) 36 | self._connections: dict[Connection, str] = {} 37 | 38 | def name(self): 39 | return "JsonWsServer" 40 | 41 | def start(self): 42 | super().start() 43 | is_ipv6 = web_util.is_ipv6(self.host) 44 | logger.info(f"This WebSocket server will use {'IPv6' if is_ipv6 else 'IPv4'}.") 45 | with serve(handler=self._handle_json_recv, host=self.host, port=self.port, 46 | subprotocols=self.subprotocols, 47 | family=socket.AF_INET6 if is_ipv6 else socket.AF_INET) as ws: 48 | self.ws = ws 49 | logger.info(f"WebSocket server started at {self.host}:{self.port}") 50 | self.ws.serve_forever() 51 | 52 | def stop(self): 53 | super().stop() 54 | if self.ws is not None: 55 | self.ws.shutdown() 56 | 57 | @property 58 | def connections(self): 59 | return len(self._connections) 60 | 61 | def _handle_json_recv(self, ws: Connection): 62 | """Handle a single WebSocket connection.""" 63 | # Handle the Sec-WebSocket-Protocol header 64 | self._validate_subprotocols(ws) 65 | self._add_connection(ws) 66 | try: 67 | while True: 68 | try: 69 | message = ws.recv() 70 | data = json.loads(message) 71 | # Note: once an exception is thrown here, not all handlers will run; 72 | # e.g. with 10 handlers, if the 5th fails, the last 5 never execute. 73 | for handler in self.on_msg_handlers: 74 | handler(ws, data) 75 | except Exception as e: 76 | if isinstance(e, ConnectionClosed): 77 | raise e 78 | if len(self.on_err_handlers) == 0: 79 | logger.exception(e) 80 | self._handle_exception(ws, e) 81 | 82 | except ConnectionClosed as e: 83 | self._remove_connection(ws, e) 84 | 85 | def _validate_subprotocols(self, ws: Connection): 86 | if ws.subprotocol is not None: 87 | if ws.subprotocol not in self.subprotocols: 88 | logger.warning(f"Unsupported subprotocol: {ws.id} {ws.remote_address}") 89 | raise ValueError(f"Unsupported subprotocol: {ws.id} {ws.remote_address}") 90 | 91 | def _add_connection(self, ws: Connection): 92 | self._connections[ws] = str(ws.id) 93 | for handler in self.on_open_handlers: 94 | handler(ws) 95 | logger.info(f"WebSocket client connected: {ws.id} {ws.remote_address}") 96 | 97 | def _remove_connection(self, ws: Connection, e: ConnectionClosed): 98 | ws_id = self._connections.pop(ws) 99 | for handler in self.on_close_handlers: 100 | handler(ws, e.rcvd.code, e.rcvd.reason) 101 | logger.warning(f"WebSocket client disconnected: {ws_id}") 102 | 103 | def _handle_exception(self, ws: Connection, e: Exception): 104 | for handler in self.on_err_handlers: 105 | handler(ws, e) 106 | 107 | def send_json(self, data: Any): 108 | if isinstance(data, BaseModel): 109 | msg = data.model_dump_json(indent=4) 110 | else: 111 | msg = json.dumps(data, ensure_ascii=False, indent=4) 112 | 113 | for conn in self._connections: 114 | conn.send(msg) 115 | --------------------------------------------------------------------------------
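# Usage sketch (not part of the repository; the host, port and subprotocol below are illustrative): wiring a JsonWsServer into a background thread, assuming a client connects with a matching subprotocol.
# from common.concurrent.killable_thread import KillableThread
# from common.web.json_ws import JsonWsServer
# server = JsonWsServer(host="127.0.0.1", port=9000, subprotocols=["zerolan"])
# server.on_msg_handlers.append(lambda conn, data: print(data))  # Print every parsed JSON message
# thread = KillableThread(target=server.start, daemon=True)  # start() blocks, so run it in a thread
# thread.start()
# ...
# server.stop()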