├── .gitignore ├── public ├── Snipaste_2023-04-25_17-48-51.png ├── Snipaste_2023-04-25_21-26-54.png ├── mm_reward_qrcode_1686025672796.png └── readme_v1.md ├── readme.md └── src ├── __init__.py ├── config.py ├── config.sample.ini ├── core ├── __init__.py ├── main.py └── vup.py ├── db ├── __init__.py ├── dao.py ├── milvus.py └── models.py ├── forbidden_words.txt ├── manager.py ├── modules ├── __init__.py ├── actions.py ├── audio.py └── speech_rec.py ├── requirements.txt ├── rooms ├── __init__.py ├── bilibili.py └── douyin.py ├── scripts ├── crawlers │ └── tie_ba_spider.py ├── manager.py ├── utils.py └── workers.py ├── static ├── speech │ └── .gitkeep └── voice │ └── .gitkeep ├── token.txt └── utils ├── base.py ├── dfa.py ├── events.py ├── init.py ├── log.py ├── prompt_temple.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | # IDEs and editors 5 | .idea/ 6 | *.swp 7 | *~.nfs* 8 | *.bak 9 | *.cache 10 | *.dat 11 | *.db 12 | *.log 13 | *.patch 14 | *.orig.* 15 | *.rej.* 16 | *.tmp.* 17 | 18 | *.mp3 19 | *.wav 20 | 21 | env/ 22 | config.ini -------------------------------------------------------------------------------- /public/Snipaste_2023-04-25_17-48-51.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/public/Snipaste_2023-04-25_17-48-51.png -------------------------------------------------------------------------------- /public/Snipaste_2023-04-25_21-26-54.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/public/Snipaste_2023-04-25_21-26-54.png -------------------------------------------------------------------------------- /public/mm_reward_qrcode_1686025672796.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/public/mm_reward_qrcode_1686025672796.png -------------------------------------------------------------------------------- /public/readme_v1.md: -------------------------------------------------------------------------------- 1 | # 项目名称 2 | 3 | GPT-vup 4 | 5 | ## :memo: 简介 6 | 7 | 支持BiliBili和抖音直播,基于生产者-消费者模型设计,使用了openai嵌入、GPT3.5 api 8 | 9 | ![Snipaste_2023-04-25_21-26-54](https://raw.githubusercontent.com/jiran214/GPT-vup/master/public/Snipaste_2023-04-25_21-26-54.png) 10 | 11 | ## :cloud: 环境 12 | 13 | - python 3. 8 14 | - windows 15 | - 确保有VPN 并开启全局代理 16 | 17 | ## :computer: 功能 18 | 19 | 1. 回答弹幕和SC 20 | 2. 欢迎入场观众 21 | 3. 感谢礼物 22 | 4. 自定义任务 23 | 5. plugin 在config.ini设置,默认都为False 24 | - speech:监听ctrl+t热键,输入语音转为文本和ai数字人交互 25 | - action:根据观众的行为匹配对应人物动作 26 | - schedule:隔一段时间触发某一事件,讲故事、唱rap... 27 | 28 | ## :book: 原理 29 | 30 | ![Snipaste_2023-04-25_17-48-51](https://raw.githubusercontent.com/jiran214/GPT-vup/master/public/Snipaste_2023-04-25_17-48-51.png) 31 | 32 | GPT-vup一共运行三个子线程: 33 | 34 | 生产者线程一:BiliBili Websocket 35 | 36 | - 运行bilibili_api库,通过长连接websocket不断获取直播间的Event,分配到每个filter函数。 37 | - filter函数干两件事,筛选哪些event入队,入哪个队列 38 | - 线程消息队列有两个: 39 | - 前提:生产者的生产速度远大于消费者 40 | - event_queue:有最大长度,超过长度时挤掉最旧的消息,因此它是不可靠的,用来处理直播间的一般消息(普通弹幕、欢迎提示) 41 | - hight...queue:不限长,处理直播间重要消息(sc、上舰) 42 | 43 | 生产者线程二:抖音 WebSocket 44 | 45 | - 借助开源项目 [抖音弹幕抓取数据推送: 基于系统代理抓包打造的抖音弹幕服务推送程序](https://gitee.com/haodong108/dy-barrage-grab/tree/V2.6.5/BarrageGrab) 在本地开一个转发端口 46 | - 再运行一个线程监听这个端口即可,同样用filter过滤,入队 47 | 48 | 生产者线程三: 49 | 50 | - 如果vup只有回应弹幕,我觉得有些单调了,因此可以通过schedule模块,每隔一段时间往high_priority_event_queue送一些自定义Event,比如我想让她每隔十分钟做一个自我介绍、表演节目。 51 | - 5-13更新,支持热键触发实时语音交互,见plugin 52 | 53 | 消费者线程: 54 | 55 | - worker类,有三个函数:generate_chat、generate_action、output去处理不同的Event 56 | - 遵循依赖倒置原则,不管弹幕Event、sc Event都依赖抽象Event,而worker也依赖Event 57 | 
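
The two-queue design described above (a lossy bounded queue for ordinary danmu, an unbounded queue for important events, with the high-priority queue always drained first) can be sketched as follows. Class and method names here are illustrative assumptions, not the project's actual `src/utils/utils.py`:

```python
from collections import deque
from queue import Queue, Empty


class UserQueue:
    """Sketch of the two-queue design: a bounded deque for ordinary events
    that silently evicts the oldest message when full (producers are much
    faster than the consumer, and old danmu is expendable), plus an
    unbounded queue for important events (SC, guard purchases) that must
    never be lost."""

    def __init__(self, maxlen=20):
        self.event_queue = deque(maxlen=maxlen)  # lossy: ordinary danmu / welcome events
        self.high_priority_queue = Queue()       # lossless: SC, membership events

    def send(self, event, high_priority=False):
        if high_priority:
            self.high_priority_queue.put(event)
        else:
            self.event_queue.append(event)       # evicts the oldest item when full

    def recv(self):
        # Important events are always consumed first
        try:
            return self.high_priority_queue.get_nowait()
        except Empty:
            pass
        try:
            return self.event_queue.popleft()
        except IndexError:
            return None                          # both queues empty
```

Because `deque(maxlen=...)` drops from the opposite end on overflow, the ordinary queue is self-limiting with no explicit backpressure, which is exactly why it is "unreliable" and reserved for low-value messages.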
58 | 说明: 59 | - 消费者线程必须运行,生产者线程保证至少一个开启 60 | 61 | ## :microscope: 安装配置及使用教程 62 | 63 | ### 克隆项目,安装python依赖 64 | 65 | ``` 66 | git https://github.com/jiran214/GPT-vup.git 67 | cd src 68 | # 建议命令行或者pycharm创建虚拟环境并激活 69 | python -m pip install --upgrade pip pip 70 | # 可能会依赖冲突,没法彻底解决 71 | pip install -r .\requirements.txt --no-deps 72 | ``` 73 | 74 | ### 配置config 75 | 76 | 在src目录下创建配置文件config.ini(该项目所有配置信息都在这) 77 | 78 | ```ini 79 | [openai] 80 | api_key = sk-iHeZopAaLtem7E7FLEM6T3BaakFJsvhz0yVchBkii0oLJl0V 81 | 82 | [room] 83 | id=27661527 84 | 85 | [edge-tss] 86 | voice = zh-CN-XiaoyiNeural 87 | rate = +10% 88 | volume = +0% 89 | 90 | [other] 91 | debug = True 92 | proxy = 127.0.0.1:7890 93 | 94 | [plugin] 95 | action=False 96 | schedule=False 97 | speech=False 98 | ``` 99 | 100 | **说明:** 101 | 102 | - room-id 为直播间房,比如我的是[哔哩哔哩直播,二次元弹幕直播平台 (bilibili.com)](https://live.bilibili.com/27661527)最后一部分(没有房间号可以随便找一个作为测试) 103 | 104 | - edge-tss 语音相关配置 105 | 106 | ### 安装VTS(Vtuber Studio),获取VTS TOKEN并调试API 107 | 108 | - 安装及使用教程网上有,可以配置嘴和音频同步,只说明程序部分 109 | - action plugin 实现Vtuber根据观众的互动行为匹配动作,可忽略 110 | - config.ini 中的action设置为True 111 | - 打开VTS,开启VTS的API开关 112 | - 给任务的每一个动作重命名为体现动作表情的词,不然没有意义 113 | - 运行>> `python ./main action`,pyvts会请求vts api(注意:此时VTS会有确认弹窗) 114 | - 会自动生成 action.json 115 | 116 | 117 | 说明:action plugin原理? 
118 | 119 | - 简单说 根据用户发来的弹幕响应对应的动作,先去获取弹幕或者相关信息的向量,用这个向量查找action_embeddings中余弦相似度最接近的向量,也就是最接近的动作,作为响应action。 120 | - 动作响应不一定依靠embedding,实际效果差强人意,用embedding是因为我有考虑到 后期可以给用户的输入匹配更多上下文。上下文可以来源于任何地方 贴吧、小红书...只要提前生成向量保存到向量数据库即可,让AI主播的回答更丰富。 121 | - 关于openai的embedding的介绍和作用,可以看openai文档 [Embeddings - OpenAI API](https://platform.openai.com/docs/guides/embeddings) 122 | 123 | ### 抖音直播配置(可忽略) 124 | 125 | - 参考 [抖音弹幕抓取数据推送: 基于系统代理抓包打造的抖音弹幕服务推送程序](https://gitee.com/haodong108/dy-barrage-grab/tree/V2.6.5/BarrageGrab) 126 | - 打开正在直播的直播间,数据开始抓取 127 | 128 | ### 运行 129 | 130 | - 方式一:谷歌fire库 命令行方式启动(默认):确保 main.py fire.Fire(Management)这一行运行,其它注释掉 131 | - 方式二:正常运行,根据需要运行 132 | ```python 133 | if __name__ == '__main__': 134 | """命令行启动,等同于下面的程序启动""" 135 | # fire.Fire(Management) 136 | 137 | """测试""" 138 | # >> python main test 139 | # Management().test() 140 | 141 | """启动程序""" 142 | # >> python main run bilibili 143 | # Management().run('BiliBili') 144 | # Management().run('DouYin') 145 | 146 | """初始化""" 147 | # >> python main action 148 | # Management().action() 149 | ``` 150 | - 建议先运行测试,检测vpn,再正式运行程序 151 | - python main action 时用来初始化 action plugin的,可忽略 152 | 153 | ### OBS 154 | 155 | 网上有教程 156 | 157 | ## :bulb: 踩坑和经验 158 | 159 | 1. 再用openai库的acreate 关闭ssl还是会偶尔遇到ssl报错,怀疑lib底层调aiohttp有冲突,使用create后报错明显减少 160 | 2. 和vts交互上,最开始尝试keyboard键盘操作操控,发现vts的快捷键不像其它软件一样,只能通过pyvts调用api实现动作响应 161 | 3. 在这个AI主播的场景里,需要确保每个 消息队列出队-处理-输出过程的原子性,最好不要同时处理多个弹幕(Event) 162 | 4. 协程适合轻量级的任务,或者说一个协程函数里awiat不能太多,否则并发安全很难维护 163 | 5. 每个线程要创建自己的事件循环 164 | 6. 
This project uses coroutines to decouple the different producer/consumer stages. See also this article: [写个AI虚拟主播:看懂弹幕,妙语连珠,悲欢形于色,以一种简单的实现 - 掘金 (juejin.cn)](https://juejin.cn/post/7204742468612145209), which decouples via ports/processes and finally assembles all components in Go; the overall AI-streamer pipeline is much the same
165 | 
166 | ## :page_with_curl: Changelog
167 | 
168 | - 4.26 DouYin livestream support
169 | - 5.13 Refactored the prompt layer with LangChain
170 | - 5.13 Store actions and their embeddings in config.json
171 | - 5.13 Command-line startup via fire
172 | - 5.13 Added pre-run tests
173 | - 5.13 Plugin system
174 | - 5.14 requirements updated; the bilibili_api library has not been updated upstream, so the dependency conflicts cannot be fully resolved, please append --no-deps to pip install
175 | 
176 | ## :phone: Contact Me
177 | 
178 | Feel free to add me on WeChat (yuchen59384) to chat!
179 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # GPT-vup Live2D digital-human livestreaming
2 | **(This repository is no longer maintained; check out my new project Langup: https://github.com/jiran214/langup-ai , which already implements a livestream digital human)**
3 | 
4 | ![](https://img.shields.io/badge/license-GPL-blue)
5 | 
6 | ## Introduction
7 | **Real Virtual UP**
8 | Supports BiliBili and DouYin livestreams; built on a producer-consumer model, using OpenAI embeddings and the GPT-3.5 API
9 | (maintenance of this repository is paused for now)
10 | 
11 | ### Features
12 | - Basics: answer danmaku and Super Chats, welcome incoming viewers, thank gift senders
13 | - plugin (all off by default)
14 |     - speech: listen for the ctrl+t hotkey; spoken input is turned into text to interact with the AI digital human
15 |     - action: match character motions to viewers' behavior
16 |     - schedule: trigger an event every so often, e.g. tell a story, rap...
17 |     - context: supply extra context for questions
18 | 
19 | ## Installation
20 | ### Environment
21 | - Windows 10
22 | - Python 3.8
23 | - A VPN with global proxy enabled
24 | ### Install dependencies with pip
25 | ```shell
26 | git clone https://github.com/jiran214/GPT-vup.git
27 | cd src
28 | # Creating and activating a virtual environment (via the command line or PyCharm) is recommended https://blog.csdn.net/xp178171640/article/details/115950985
29 | python -m pip install --upgrade pip
30 | pip install -r requirements.txt
31 | ```
32 | ### Create config.ini
33 | - Rename config.sample.ini to config.ini
34 | - Change api_key and proxy; everything else can be left as-is
35 | - See below for the related settings
36 | ### Test the network environment
37 | - In the src directory, run: >>`python manager.py test_net`
38 | ## Quick start
39 | ### BiliBili live
40 | - Install the dependency: >>`pip install bilibili-api-python`
41 | - Change room -> id in config.ini to your own room number (any existing room works for testing)
42 | - In the src directory, run: >>`python manager.py run bilibili`
43 | ### DouYin live
44 | - See [抖音弹幕抓取数据推送: a DouYin danmaku push service built on system-proxy packet capture](https://gitee.com/haodong108/dy-barrage-grab/tree/V2.6.5/BarrageGrab)
45 | - Start that project
46 | - Open a DouYin room that is currently live (web or desktop client) and data capture begins
47 | - In the src directory, run: >>`python manager.py run douyin`
48 | ### VTube Studio installation and setup
49 | - Download VTube Studio from Steam
50 | - Tutorial: https://www.bilibili.com/video/BV1nV4y1X7yJ?t=426.7
51 | - Important!!!
52 | - Microphone setup: no virtual audio cable is needed. If the Windows default output device is Speaker Realtek(R) Audio, set the input device in VTS's microphone settings to Realtek(R) Audio as well.
53 | - For lip sync with audio, set the input parameter of mouthOpen to voice frequency or voice volume
54 | - For a better streaming setup, explore further on your own
55 | ## Advanced
56 | ### speech plugin: voice interaction
57 | - Set config.ini -> plugin -> speech to True
58 | - Run >> `pip install pyaudio speech_recognition keyboard`
59 | - After the program starts, hold ctrl+T and speak; speech is transcribed automatically and the vup hears what you said
60 | ### schedule plugin: trigger an event every so often, e.g. tell a story, rap...
61 | - Set config.ini -> plugin -> schedule to True
62 | - The schedule_task_temple_list in utils/prompt_temple.py contains ready-made trigger events
63 | ### action plugin: VTS motion and expression interaction
64 | Makes the vup match motions to viewers' interactions
65 | - Set config.ini -> plugin -> action to True
66 | - Run >>`pip install pyvts`
67 | - Open VTS and enable its API switch
68 | - In VTS's expression settings, rename every motion to a word that describes the motion or expression, otherwise matching is meaningless
69 | - In the src directory, run >> `python manager.py action`; pyvts will call the VTS API (note: VTS will show a confirmation dialog at this point)
70 | - The program will generate action.json automatically
71 | - To update the motions, repeat the steps above
72 | ### Experimental: context plugin: supply conversation context
73 | - Prerequisite 1: [install standalone Milvus 2.0](https://milvus.io/docs/v2.0.x/install_standalone-docker.md) with Docker, and set config.ini -> milvus -> host and port
74 | - Prerequisite 2: a MySQL environment, and set config.ini -> mysql -> uri
75 | - Set config.ini -> plugin -> context to True
76 | - Run >> `pip install pymilvus==2.0`
77 | - Set the parameters in scripts/manager.py yourself, then run >> `python scripts/manager.py run` to crawl Tieba data into MySQL, process it, and push it to Milvus
78 | ### Other
79 | - system_template in utils/prompt_temple.py changes the vup's initial persona
80 | ## Changelog
81 | - V2.0 context plugin support, directory restructure, a simpler readme, fixed the dependency mess
82 | - V1.0 [old version](https://github.com/jiran214/GPT-vup/tree/1.0)
83 | ## Contact Me
84 | 
85 |
86 |
87 | 88 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: config.py 5 | @DateTime: 2023/4/7 14:30 6 | @SoftWare: PyCharm 7 | """ 8 | import configparser 9 | import json 10 | import os.path 11 | 12 | root_path = os.path.abspath(os.path.dirname(__file__)) 13 | file_path = os.path.join(root_path, 'config.ini') 14 | 15 | _config = configparser.RawConfigParser() 16 | _config.read(file_path) 17 | 18 | tss_settings = dict(_config.items('edge-tss')) 19 | 20 | api_key_list = [value for key, value in _config.items('openai') if key.startswith('api') and value] 21 | temperature = _config.get('openai', 'temperature') 22 | room_id = _config.getint('room', 'id') 23 | mysql = dict(_config.items('mysql')) 24 | sqlite = dict(_config.items('sqlite')) 25 | milvus = dict(_config.items('milvus')) 26 | debug = _config.getboolean('other', 'debug') 27 | proxy = _config.get('other', 'proxy') 28 | 29 | action_plugin = _config.getboolean('plugin', 'action') 30 | schedule_plugin = _config.getboolean('plugin', 'schedule') 31 | speech_plugin = _config.getboolean('plugin', 'speech') 32 | context_plugin = _config.getboolean('plugin', 'context') 33 | 34 | try: 35 | live2D_actions = [] 36 | live2D_embeddings = [] 37 | if action_plugin: 38 | with open("./action.json", 'r') as load_f: 39 | live2D_action_dict = json.load(load_f) 40 | live2D_actions = live2D_action_dict.keys() 41 | assert live2D_embeddings 42 | live2D_embeddings = [live2D_action_dict[action] for action in live2D_actions] 43 | except Exception as e: 44 | 
print('读取embedding文件错误,请检查本地是否生成action.json 且动作不为空, 使用action plugin前请先运行 python manager.py action', e) 45 | 46 | 47 | with open(os.path.join(root_path, 'forbidden_words.txt'), mode='r', encoding='utf-8') as f: 48 | keyword_str_list = [line.strip() for line in f.readlines()] 49 | 50 | if __name__ == '__main__': 51 | print(api_key, proxy) -------------------------------------------------------------------------------- /src/config.sample.ini: -------------------------------------------------------------------------------- 1 | [openai] 2 | api_key = xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 3 | temperature = 0.9 4 | 5 | [room] 6 | id=732 7 | 8 | [edge-tss] 9 | voice = zh-CN-XiaoyiNeural 10 | rate = +10% 11 | volume = +0% 12 | 13 | [other] 14 | debug = True 15 | proxy = 127.0.0.1:7890 16 | 17 | [plugin] 18 | action=False 19 | schedule=False 20 | speech=False 21 | context=False 22 | 23 | [mysql] 24 | uri = mysql+pymysql://root:123456@localhost/vup?charset=utf8mb4 25 | 26 | [sqlite] 27 | uri = xxxx 28 | 29 | [milvus] 30 | host = xxxxxx 31 | port = 19530 32 | collection = sun_ba 33 | top_n = 4 34 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/core/__init__.py -------------------------------------------------------------------------------- /src/core/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: main.py 5 | @DateTime: 2023/4/23 13:19 6 | @SoftWare: PyCharm 7 | """ 8 | import threading 9 | import time 10 | import schedule 11 | 12 | 13 | from src import config 14 | from src.utils.prompt_temple import get_schedule_task 15 | from src.utils.events import UserEvent 16 | from src.utils.utils import worker_logger 17 | from src.rooms.bilibili import 
BlLiveRoom 18 | from src.rooms.douyin import dy_connect 19 | 20 | from src.utils.utils import user_queue, NewEventLoop 21 | from src.core.vup import VtuBer 22 | from bilibili_api import sync 23 | 24 | 25 | logger = worker_logger 26 | 27 | 28 | # Define the producer function 29 | def bl_producer(): 30 | r = BlLiveRoom() 31 | r.connect() 32 | 33 | 34 | def dy_producer(): 35 | t_loop = NewEventLoop() 36 | t_loop.run(dy_connect()) 37 | 38 | 39 | class UserProducer: 40 | 41 | def __init__(self): 42 | self.run_funcs = [] 43 | 44 | def send_user_event_2_queue(self, task): 45 | if user_queue.event_queue.empty(): 46 | ue = UserEvent(*task) 47 | ue.is_high_priority = True 48 | # ue.action = live2D_actions.index('Anim Shake') 49 | user_queue.send(ue) 50 | 51 | def create_schedule(self): 52 | # 延时启动 53 | time.sleep(30) 54 | # 清空任务 55 | schedule.clear() 56 | # 创建一个按5分钟间隔执行任务 57 | schedule.every(5).minutes.do( 58 | self.send_user_event_2_queue, get_schedule_task() 59 | ) 60 | return schedule 61 | 62 | def run(self): 63 | if config.schedule_plugin: 64 | schedule_obj = self.create_schedule() 65 | self.run_funcs.append(schedule_obj.run_pending) 66 | if config.speech_plugin: 67 | try: 68 | from src.modules.speech_rec import speech_hotkey_listener 69 | except ImportError: 70 | raise 'Please run pip install pyaudio speech_recognition keyboard' 71 | # self.run_funcs.append(speech_hotkey_listener) 72 | speech_hotkey_listener() 73 | if self.run_funcs: 74 | self.run_funcs.append(lambda: time.sleep(2)) 75 | while True: 76 | for run_fun in self.run_funcs: 77 | run_fun() 78 | 79 | 80 | # Define the consumer function 81 | def consumer(): 82 | while True: 83 | t0 = time.time() 84 | event = user_queue.recv() 85 | if not event: 86 | # Both queues are empty, wait for new items to be added 87 | time.sleep(1) 88 | logger.debug('consumer waiting') 89 | continue 90 | 91 | worker = VtuBer(event) 92 | try: 93 | worker.run() 94 | logger.debug(f'worker耗时:{time.time() - t0}') 95 | except Exception as e: 
96 | raise e 97 | # logger.error(e) 98 | # time.sleep(20) 99 | 100 | 101 | def start_thread(worker_name): 102 | worker_map = { 103 | 'bl_producer': bl_producer, 104 | 'dy_producer': dy_producer, 105 | 'user_producer': UserProducer().run, 106 | 'consumer': consumer 107 | } 108 | if worker_name not in worker_map: 109 | raise '不存在...' 110 | 111 | thread = threading.Thread(target=worker_map[worker_name]) 112 | thread.start() 113 | return thread 114 | 115 | 116 | if __name__ == '__main__': 117 | # bl_producer() 118 | t = start_thread('bl_producer') 119 | start_thread('consumer') 120 | t.join() 121 | # time.sleep(10000) 122 | -------------------------------------------------------------------------------- /src/core/vup.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: process.py 5 | @DateTime: 2023/4/22 22:17 6 | @SoftWare: PyCharm 7 | """ 8 | import asyncio 9 | import threading 10 | import time 11 | from langchain.chat_models import ChatOpenAI 12 | from bilibili_api import sync 13 | 14 | from src import config 15 | from src.config import live2D_embeddings, keyword_str_list 16 | from src.db.milvus import VectorStore 17 | from src.db.models import TieBa 18 | from src.db.dao import get_session 19 | from src.modules.actions import play_action 20 | from src.modules.audio import tts_save, play_sound 21 | from src.utils.dfa import DFA 22 | from src.utils.events import BlDanmuMsgEvent 23 | from src.utils.utils import worker_logger, sync_get_embedding, get_openai_key 24 | from src.utils.utils import Event 25 | from src.utils.utils import audio_lock, NewEventLoop, top_n_indices_from_embeddings 26 | 27 | logger = worker_logger 28 | 29 | base_path = './static/voice/{}.mp3' 30 | 31 | 32 | class VtuBer: 33 | dfa = DFA(keyword_str_list) 34 | 35 | def __init__(self, event: Event): 36 | self.event = event 37 | self.sound_path = base_path.format(int(time.time())) 38 | 39 | async def 
generate_chat(self, embedding): 40 | # 额外参数 41 | extra_kwargs = {} 42 | # 只给弹幕增加上下文 43 | if config.context_plugin and isinstance(self.event, BlDanmuMsgEvent): 44 | ids = VectorStore(config.milvus['collection']).search_top_n_from_milvus(int(config.milvus['top_n']), embedding)[0].ids 45 | with get_session() as s: 46 | rows = s.query(TieBa).filter(TieBa.hash_id.in_(str(hash_id) for hash_id in ids)).all() 47 | context = [row.content for row in rows] 48 | extra_kwargs['context'] = str(context) 49 | # 请求GPT 50 | messages = self.event.get_prompt_messages(**extra_kwargs) 51 | logger.info(f"prompt:{messages[1]} 开始请求gpt") 52 | chat = ChatOpenAI(temperature=config.temperature, max_retries=2, max_tokens=150, 53 | openai_api_key=get_openai_key()) 54 | llm_res = chat.generate([messages]) 55 | assistant_content = llm_res.generations[0][0].text 56 | logger.info(f'assistant_content:{assistant_content}') 57 | # 违禁词判断 58 | dfa_match_list = self.dfa.match(assistant_content) 59 | forbidden_words = [forbidden_word['match'] for forbidden_word in dfa_match_list] 60 | if dfa_match_list: 61 | logger.warning(f'包含违禁词:{forbidden_words},跳过本次语音生成') 62 | return False 63 | # 使用 Edge TTS 生成回复消息的语音文件 64 | logger.debug(f"开始生成TTS 文件") 65 | t0 = time.time() 66 | await tts_save(self.event.get_audio_txt(assistant_content), self.sound_path) 67 | logger.debug(f"tts请求耗时:{time.time()-t0}") 68 | 69 | async def generate_action(self, embedding): 70 | if isinstance(self.event.action, str): 71 | # 是否手动设置 72 | logger.debug(f"开始生成动作") 73 | t0 = time.time() 74 | # 匹配动作 75 | self.event.action = int(top_n_indices_from_embeddings(embedding, live2D_embeddings, top=1)[0]) 76 | logger.debug(f"动作请求耗时:{time.time()-t0}") 77 | 78 | async def output(self): 79 | logger.debug(f'path:{self.sound_path} 准备播放音频和动作') 80 | while audio_lock.locked(): 81 | await asyncio.sleep(1) 82 | else: 83 | # 每句话间隔时间 84 | time.sleep(0.5) 85 | # 播放声音 86 | play_sound_thread = threading.Thread(target=play_sound, args=(self.sound_path,)) 87 | 
play_sound_thread.start() 88 | # 播放动作 89 | if config.action_plugin and isinstance(self.event.action, int): 90 | await play_action(self.event.action) 91 | # play_sound_thread.join() 92 | # time.sleep(5) 93 | 94 | async def _run(self): 95 | # 获取词向量 96 | str_tuple = ('text', 'content', 'message', 'user_name') 97 | prompt_kwargs = self.event.prompt_kwargs.copy() 98 | embedding_str = None 99 | for key in str_tuple: 100 | if key in prompt_kwargs: 101 | embedding_str = prompt_kwargs[key] 102 | break 103 | if not embedding_str: 104 | raise '不应该不存在' 105 | embedding = sync_get_embedding([embedding_str]) 106 | tasks = [asyncio.create_task(self.generate_chat(embedding))] 107 | if config.action_plugin and self.event.action: 108 | tasks.append(asyncio.create_task(self.generate_action(embedding))) 109 | state = await asyncio.gather(*tasks) 110 | if state[0] is not False: 111 | await self.output() 112 | 113 | def run(self): 114 | # t_loop = NewEventLoop() 115 | # t_loop.run(self._run()) 116 | sync(self._run()) 117 | 118 | 119 | 120 | if __name__ == '__main__': 121 | res = embedding = sync_get_embedding(['embedding_str']) 122 | print(res) 123 | logger.debug('123') 124 | -------------------------------------------------------------------------------- /src/db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/db/__init__.py -------------------------------------------------------------------------------- /src/db/dao.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | from sqlalchemy.orm import Session 4 | import contextlib 5 | from sqlalchemy.ext.declarative import declarative_base 6 | from sqlalchemy.orm import sessionmaker 7 | from sqlalchemy import ( 8 | create_engine 9 | ) 10 | from src import config # config模块里有自己写的配置,我们可以换成别的,注意下面用到config的地方也要一起换 11 | 12 | try: 13 | engine = 
create_engine( 14 | config.sqlite['uri'] or config.mysql['uri'], # SQLAlchemy 数据库连接串 15 | # echo=bool(config.SQLALCHEMY_ECHO), # 是不是要把所执行的SQL打印出来,一般用于调试 16 | # pool_size=int(config.SQLALCHEMY_POOL_SIZE), # 连接池大小 17 | # max_overflow=int(config.SQLALCHEMY_POOL_MAX_SIZE), # 连接池最大的大小 18 | # pool_recycle=int(config.SQLALCHEMY_POOL_RECYCLE), # 多久时间回收连接 19 | ) 20 | Session = sessionmaker(bind=engine) 21 | Base = declarative_base(engine) 22 | except: 23 | engine = None 24 | Session = None 25 | Base = object 26 | 27 | 28 | @contextlib.contextmanager 29 | def get_session(): 30 | s = Session() 31 | try: 32 | yield s 33 | s.commit() 34 | except Exception as e: 35 | s.rollback() 36 | raise e 37 | finally: 38 | s.close() 39 | -------------------------------------------------------------------------------- /src/db/milvus.py: -------------------------------------------------------------------------------- 1 | from src import config 2 | from src.utils.init import initialize_openai 3 | from src.utils.utils import sync_get_embedding 4 | 5 | 6 | class VectorStore: 7 | def __init__(self, name): 8 | try: 9 | from pymilvus import Collection, connections 10 | except ImportError: 11 | raise 'Please run pip install pymilvus==2.1' 12 | 13 | connections.connect( 14 | alias="default", 15 | host=config.milvus['host'], 16 | port=config.milvus['port'] 17 | ) 18 | 19 | self.collection = Collection(name) # Get an existing collection. 
20 | # num_entities = self.collection.num_entities 21 | self.collection.load() 22 | self.search_params = {"metric_type": "L2", "params": {"nprobe": 10}} 23 | 24 | def search_top_n_from_milvus(self, limit, embedding): 25 | results = self.collection.search( 26 | data=[embedding], 27 | anns_field="embedding", 28 | param=self.search_params, 29 | limit=limit, 30 | expr=None, 31 | consistency_level="Strong", 32 | # output_fields='hash_id' 33 | ) 34 | return results 35 | 36 | 37 | if __name__ == '__main__': 38 | initialize_openai() 39 | # print() 40 | ids = VectorStore('sun_ba').search_top_n_from_milvus(3, sync_get_embedding(texts=['hello']))[0].ids -------------------------------------------------------------------------------- /src/db/models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, Boolean, Integer 2 | 3 | from src.db.dao import Base, engine, get_session 4 | 5 | 6 | # class Document(Base): 7 | # __tablename__ = '' 8 | # id = Column(Integer, primary_key=True) 9 | # hash_id = Column(String(30), nullable=False, unique=True) 10 | # content = Column(String(200), nullable=False) 11 | # embedding_state = Column(Boolean, default=False) 12 | 13 | 14 | class TieBa(Base): 15 | __tablename__ = "tie_ba" 16 | id = Column(Integer, primary_key=True) 17 | hash_id = Column(String(30), nullable=False, unique=True) 18 | content = Column(String(200), nullable=False) 19 | embedding_state = Column(Boolean, default=False) 20 | tid = Column(String(20), nullable=False) 21 | 22 | def __repr__(self): 23 | return f'{self.hash_id} self.{self.content}' 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/forbidden_words.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/forbidden_words.txt 
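
The `forbidden_words.txt` list above feeds the `DFA` filter used in `core/vup.py`, where `self.dfa.match(assistant_content)` returns a list of dicts with a `match` key. A minimal trie-based sketch of how such a matcher can work (an illustration of that interface, not the project's actual `utils/dfa.py`):

```python
class DFA:
    """Minimal trie-based keyword matcher. Builds a character trie from the
    forbidden-word list, then scans the text from every start offset and
    reports the first keyword completed at each position."""

    _END = '\x00'  # sentinel marking a complete keyword in the trie

    def __init__(self, keywords):
        self.root = {}
        for word in keywords:
            node = self.root
            for ch in word:
                node = node.setdefault(ch, {})
            node[self._END] = word

    def match(self, text):
        hits = []
        for start in range(len(text)):
            node = self.root
            for ch in text[start:]:
                if ch not in node:
                    break
                node = node[ch]
                if self._END in node:
                    hits.append({'match': node[self._END], 'start': start})
                    break
        return hits
```

With this shape, the caller in `vup.py` only needs `[hit['match'] for hit in dfa.match(text)]` to decide whether to skip TTS generation for a reply.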
-------------------------------------------------------------------------------- /src/manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | path = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.insert(0, path) 6 | sys.path.insert(1, os.path.dirname(path)) 7 | 8 | import fire 9 | from src import config 10 | from src.core.main import start_thread 11 | from src.utils.init import initialize_action, initialize_openai 12 | from src.utils.log import worker_logger 13 | from src.utils.utils import NewEventLoop, get_openai_key 14 | from src.utils.init import initialize_openai 15 | 16 | 17 | initialize_openai() 18 | logger = worker_logger 19 | 20 | 21 | class Management: 22 | def __init__(self): 23 | try: 24 | assert config.api_key_list 25 | except: 26 | raise '请填写openai -> api_key!' 27 | 28 | def action(self): 29 | loop = NewEventLoop() 30 | loop.run(initialize_action()) 31 | 32 | def run(self, name): 33 | tasks = [] 34 | self.test_plugin_dependency() 35 | if name.lower() == 'douyin': 36 | tasks.append(start_thread('dy_producer')) 37 | elif name.lower() == 'bilibili': 38 | tasks.append(start_thread('bl_producer')) 39 | tasks.append(start_thread('user_producer')) 40 | tasks.append(start_thread('consumer')) 41 | for task in tasks: 42 | task.join() 43 | 44 | def test_net(self): 45 | from langchain import OpenAI 46 | import requests 47 | # 测试外网环境(可能异常) 48 | r = requests.get(url='https://www.youtube.com/', verify=False, proxies={ 49 | 'http': f'http://{config.proxy}/', 50 | 'https': f'http://{config.proxy}/' 51 | }) 52 | assert r.status_code == 200 53 | # 测试openai库 54 | llm = OpenAI(temperature=config.temperature, openai_api_key=get_openai_key(), verbose=config.debug) 55 | text = "python是世界上最好的语言 " 56 | print(llm(text)) 57 | print('测试成功!') 58 | 59 | def test_plugin_dependency(self): 60 | if config.context_plugin: 61 | try: 62 | from pymilvus import connections, has_collection, Collection 63 | 
import cryptography 64 | except ImportError: 65 | raise 'Please run pip install pymilvus==2.0 cryptography parsel' 66 | 67 | try: 68 | connections.connect( 69 | alias="default", 70 | host=config.milvus['host'], 71 | port=config.milvus['port'] 72 | ) 73 | assert has_collection(config.milvus['collection']) 74 | collection = Collection(config.milvus['collection']) 75 | assert collection.num_entities != 0 76 | except Exception as e: 77 | raise e 78 | 79 | logger.info('上下文插件已开启') 80 | if config.speech_plugin: 81 | try: 82 | from src.modules.speech_rec import speech_hotkey_listener 83 | except ImportError: 84 | raise 'Please run pip install pyaudio speech_recognition keyboard' 85 | logger.info('语音交互插件已开启') 86 | if config.action_plugin: 87 | try: 88 | import pyvts 89 | except ImportError: 90 | raise 'Please run pip install pyvts,then run python manager action' 91 | logger.info('动作响应插件已开启') 92 | if config.schedule_plugin: 93 | logger.info('循环任务插件已开启') 94 | 95 | 96 | if __name__ == '__main__': 97 | """命令行启动,等同于下面的程序启动""" 98 | fire.Fire(Management) 99 | 100 | """测试""" 101 | # >> python manager.py test 102 | # Management().test_net() 103 | 104 | """启动程序""" 105 | # >> python manager.py run bilibili 106 | # Management().run('BiliBili') 107 | # Management().run('DouYin') 108 | 109 | """初始化""" 110 | # >> python manager.py action 111 | # Management().action() 112 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "xxx" 4 | -------------------------------------------------------------------------------- /src/modules/actions.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: action.py 5 | @DateTime: 2023/4/24 0:01 6 | @SoftWare: PyCharm 7 | """ 8 | from src.config import 
live2D_actions 9 | 10 | plugin_info = { 11 | "plugin_name": "start pyvts", 12 | "developer": "Jiran", 13 | "authentication_token_path": "./token.txt" 14 | } 15 | 16 | 17 | async def play_action(action_index): 18 | try: 19 | import pyvts 20 | except ImportError: 21 | raise 'Please run pip install pyvts' 22 | vts = pyvts.vts(plugin_info=plugin_info) 23 | await vts.connect() 24 | await vts.read_token() 25 | await vts.request_authenticate() # use token 26 | 27 | if action_index > len(live2D_actions) - 1: 28 | raise '动作不存在' 29 | send_hotkey_request = vts.vts_request.requestTriggerHotKey(live2D_actions[action_index]) 30 | await vts.request(send_hotkey_request) 31 | await vts.close() 32 | 33 | 34 | if __name__ == "__main__": 35 | action_embeddings = None # 获取token不需要运行 36 | # asyncio.run(initialize_action()) 37 | # asyncio.run(play_action()) 38 | -------------------------------------------------------------------------------- /src/modules/audio.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: audio.py 5 | @DateTime: 2023/4/24 14:16 6 | @SoftWare: PyCharm 7 | """ 8 | 9 | import edge_tts 10 | from src import config 11 | 12 | from pygame import mixer, time as pygame_time 13 | 14 | from src.utils.utils import worker_logger 15 | from src.utils.utils import audio_lock 16 | 17 | logger = worker_logger 18 | 19 | 20 | async def tts_save(text, path): 21 | # tts = edge_tts.Communicate(text=text, proxy=f'http://{config.proxy}', **config.tss_settings) 22 | tts = edge_tts.Communicate(text=text, **config.tss_settings) 23 | await tts.save(path) 24 | 25 | 26 | def play_sound(file_path): 27 | with audio_lock: 28 | # 播放生成的语音文件 29 | mixer.init() 30 | mixer.music.load(file_path) 31 | mixer.music.play() 32 | while mixer.music.get_busy(): 33 | pygame_time.Clock().tick(10) 34 | 35 | mixer.music.stop() 36 | mixer.quit() 
-------------------------------------------------------------------------------- /src/modules/speech_rec.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: speech_recognition.py 5 | @DateTime: 2023/5/13 18:21 6 | @SoftWare: PyCharm 7 | """ 8 | import threading 9 | import time 10 | 11 | import wave 12 | import pyaudio 13 | import speech_recognition as sr 14 | import keyboard 15 | 16 | from src.utils.events import UserEvent 17 | from src.utils.log import worker_logger 18 | from src.utils.utils import user_queue 19 | 20 | CHUNK = 1024 21 | FORMAT = pyaudio.paInt16 22 | CHANNELS = 1 23 | RATE = 16000 24 | RECORD_SECONDS = 5 25 | WAVE_OUTPUT_FILENAME = "../static/speech/{}.wav" 26 | 27 | logger = worker_logger 28 | 29 | 30 | def speech_recognition_task(): 31 | r = sr.Recognizer() 32 | 33 | # 使用PyAudio录制音频 34 | audio = pyaudio.PyAudio() 35 | stream = audio.open(format=FORMAT, channels=CHANNELS, 36 | rate=RATE, input=True, 37 | frames_per_buffer=CHUNK) 38 | frames = [] 39 | print("正在录音...") 40 | for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)): 41 | data = stream.read(CHUNK) 42 | frames.append(data) 43 | # if keyboard.is_pressed('ctrl+t'): # 按下Ctrl+T停止录音 44 | # break 45 | print("录音结束") 46 | 47 | stream.stop_stream() 48 | stream.close() 49 | audio.terminate() 50 | 51 | # 将录制的音频写入Wave文件 52 | waveFile = wave.open(WAVE_OUTPUT_FILENAME.format(int(time.time())), 'wb') 53 | waveFile.setnchannels(CHANNELS) 54 | waveFile.setsampwidth(audio.get_sample_size(FORMAT)) 55 | waveFile.setframerate(RATE) 56 | waveFile.writeframes(b''.join(frames)) 57 | waveFile.close() 58 | 59 | try: 60 | # 使用SpeechRecognition库来将音频转换为文本(非常慢) 61 | audio_file = sr.AudioFile(WAVE_OUTPUT_FILENAME.format(int(time.time()))) 62 | 63 | with audio_file as source: 64 | audio_data = r.record(source) 65 | logger.debug('正在请求语音转文字接口recognize_google...') 66 | text = r.recognize_google(audio_data, 
language='zh-CN') 67 | print('识别到的文字:', text) 68 | ue = UserEvent(text, '{}') 69 | ue.is_high_priority = True 70 | # ue.action = live2D_actions.index('Anim Shake') 71 | user_queue.send(ue) 72 | except sr.UnknownValueError: 73 | print("无法识别语音") 74 | except sr.RequestError as e: 75 | print(f"语音识别服务出错:{e}") 76 | 77 | 78 | def speech_hotkey_listener(): 79 | keyboard.add_hotkey('ctrl+t', speech_recognition_task) 80 | keyboard.wait() 81 | 82 | 83 | if __name__ == '__main__': 84 | t = threading.Thread(target=speech_hotkey_listener) 85 | t.start() 86 | # speech_hotkey_listener() 87 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/requirements.txt -------------------------------------------------------------------------------- /src/rooms/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: __init__.py.py 5 | @DateTime: 2023/4/25 21:55 6 | @SoftWare: PyCharm 7 | """ 8 | -------------------------------------------------------------------------------- /src/rooms/bilibili.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: room.py 5 | @DateTime: 2023/4/22 21:23 6 | @SoftWare: PyCharm 7 | """ 8 | import asyncio 9 | 10 | from bilibili_api import sync 11 | 12 | from src import config 13 | from src.utils.events import BlDanmuMsgEvent, BlSendGiftEvent, BlSuperChatMessageEvent, BlInteractWordEvent 14 | from src.utils.utils import worker_logger 15 | 16 | from src.utils.utils import user_queue 17 | 18 | 19 | logger = worker_logger 20 | 21 | 22 | class BlLiveRoom: 23 | def __init__(self, bl_room_id=config.room_id): 24 | try: 25 | from bilibili_api 
import live, sync 26 | except ImportError: 27 | raise ImportError('Please run: pip install bilibili-api-python') 28 | self.room = live.LiveDanmaku(bl_room_id) 29 | self.add_event_listeners() 30 | 31 | def add_event_listeners(self): 32 | listener_map = { 33 | 'DANMU_MSG': on_danmaku_event_filter, 34 | 'SUPER_CHAT_MESSAGE': on_super_chat_message_event_filter, 35 | 'SEND_GIFT': on_gift_event_filter, 36 | 'INTERACT_WORD': on_interact_word_event_filter, 37 | } 38 | for item in listener_map.items(): 39 | self.room.add_event_listener(*item) 40 | 41 | def connect(self): 42 | # loop = asyncio.get_event_loop() 43 | # loop.run_until_complete(self.room.connect()) 44 | sync(self.room.connect()) 45 | 46 | 47 | async def on_danmaku_event_filter(event_dict): 48 | # # 收到弹幕 49 | # event = BlDanmuMsgEvent.filter(event_dict) 50 | event = BlDanmuMsgEvent(event_dict) 51 | user_queue.send(event) 52 | 53 | 54 | async def on_super_chat_message_event_filter(event_dict): 55 | # SUPER_CHAT_MESSAGE 56 | # info = event['data']['data'] 57 | # user_info = info['user_info'] 58 | # print('SUPER_CHAT_MESSAGE', 59 | # user_info['uname'], 60 | # user_info['face'], 61 | # 62 | # info['message'], 63 | # info['price'], 64 | # info['start_time'], 65 | # ) 66 | event = BlSuperChatMessageEvent(event_dict) 67 | user_queue.send(event) 68 | 69 | 70 | async def on_gift_event_filter(event_dict): 71 | # 收到礼物 72 | # info = event_dict['data']['data'] 73 | # print('SEND_GIFT', 74 | # info['face'], 75 | # info['uname'], 76 | # info['action'], 77 | # info['giftName'], 78 | # info['timestamp'], 79 | # ) 80 | event = BlSendGiftEvent(event_dict) 81 | user_queue.send(event) 82 | 83 | 84 | async def on_interact_word_event_filter(event_dict): 85 | # INTERACT_WORD 86 | # info = event_dict['data']['data'] 87 | # fans_medal = info['fans_medal'] 88 | # print('INTERACT_WORD', 89 | # fans_medal['medal_name'], 90 | # fans_medal['medal_level'], 91 | # info['uname'], 92 | # info['timestamp'] 93 | # ) 94 | if not user_queue.event_queue.full(): 95
| event = BlInteractWordEvent(event_dict) 96 | user_queue.send(event) 97 | 98 | # @room.on('WELCOME') 99 | # async def on_welcome_event_filter(event): 100 | # # 老爷进入房间 101 | # print(3, event) 102 | # 103 | # 104 | # @room.on('WELCOME_GUARD') 105 | # async def on_welcome_guard_event_filter(event): 106 | # # 房管进入房间 107 | # print(4, event) 108 | 109 | if __name__ == '__main__': 110 | r = BlLiveRoom() 111 | r.connect() -------------------------------------------------------------------------------- /src/rooms/douyin.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: douyin.py 5 | @DateTime: 2023/4/25 21:56 6 | @SoftWare: PyCharm 7 | """ 8 | 9 | import aiohttp 10 | 11 | import json 12 | 13 | 14 | from src.utils.events import DyDanmuMsgEvent, DyAttentionEvent, DySendGiftEvent, DyCkEvent, DyWelcomeWordEvent 15 | from src.utils.utils import user_queue 16 | 17 | 18 | def msg(event_dict): 19 | event_dict['Data'] = json.loads(event_dict['Data']) 20 | event = DyDanmuMsgEvent(event_dict) 21 | user_queue.send(event) 22 | 23 | 24 | def ck(event_dict): # type2 25 | if not user_queue.event_queue.full(): 26 | event_dict['Data'] = json.loads(event_dict['Data']) 27 | event = DyCkEvent(event_dict) 28 | # print("感谢" + load_data.get("Content")) 29 | user_queue.send(event) 30 | 31 | 32 | def welcome(event_dict): # type3 33 | if not user_queue.event_queue.full(): 34 | event_dict['Data'] = json.loads(event_dict['Data']) 35 | event = DyWelcomeWordEvent(event_dict) 36 | # print("欢迎:" + json2["Nickname"]) 37 | user_queue.send(event) 38 | 39 | 40 | def Gift(event_dict): # type5 41 | event_dict['Data'] = json.loads(event_dict['Data']) 42 | event = DySendGiftEvent(event_dict) 43 | user_queue.send(event) 44 | 45 | 46 | def attention(event_dict): 47 | event_dict['Data'] = json.loads(event_dict['Data']) 48 | event = DyAttentionEvent(event_dict) 49 | user_queue.send(event) 50 | 51 | 52 | async def 
dy_connect(): 53 | session = aiohttp.ClientSession() 54 | async with session.ws_connect("ws://127.0.0.1:8888") as ws: 55 | await ws.send_str('token') 56 | async for message in ws: 57 | # print(f"Received message: {message}") 58 | # 处理websocket 消息 59 | if message.type == aiohttp.WSMsgType.TEXT: 60 | data = json.loads(message.data) 61 | event_name = data.get("Type")  # 标签类型 62 | filter_map = { 63 | 1: msg,  # 1用户发言 64 | 2: ck,  # 2用户点赞 65 | 3: welcome,  # 3用户入房 66 | 4: attention,  # 用户关注 67 | 5: Gift,  # 5用户礼物 68 | # 6: check,  # 6人数统计 69 | } 70 | if event_name in filter_map: 71 | filter_map[event_name](data) 72 | 73 | elif message.type == aiohttp.WSMsgType.CLOSED: 74 | break 75 | elif message.type == aiohttp.WSMsgType.ERROR: 76 | break 77 | 78 | await ws.close() 79 | 80 | 81 | if __name__ == '__main__': 82 | from bilibili_api import sync 83 | sync(dy_connect()) 84 | -------------------------------------------------------------------------------- /src/scripts/crawlers/tie_ba_spider.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | try: 4 | from parsel import Selector 5 | import pymysql 6 | except ImportError: 7 | raise ImportError('Please run: pip install parsel pymysql cryptography') 8 | 9 | 10 | from src.db.dao import get_session 11 | from src.db.models import TieBa 12 | 13 | pymysql.install_as_MySQLdb() 14 | 15 | # connecting to a MySQL database with user and password 16 | 17 | session = requests.session() 18 | prefix_url = 'https://tieba.baidu.com' 19 | ba_url = 'https://tieba.baidu.com/f?kw={kw}&ie=utf-8&tp=0&pn={pn}' 20 | tie_url = 'https://tieba.baidu.com/p/{tie_id}?pn={pn}' 21 | 22 | 23 | class Ba: 24 | 25 | def __init__(self, kw, max_ba_pn, max_tie_pn, start_ba_pn=1, start_tie_pn=1, max_str_length=400): 26 | self.kw = kw 27 | self.start_ba_pn = start_ba_pn 28 | self.start_tie_pn = start_tie_pn 29 | self.max_ba_pn = max_ba_pn 30 | self.max_tie_pn = max_tie_pn 31 | self.max_str_length = max_str_length 32 | 33 | def 
get_tie_list(self, pn): 34 | r = session.get(ba_url.format(kw=self.kw, pn=pn)) 35 | sl = Selector(text=r.text) 36 | 37 | href_list = sl.xpath("""//div[@class='threadlist_title pull_left j_th_tit ']//a/@href""").getall() 38 | tie_id_list = [href.split('/')[-1] for href in href_list] 39 | yield from tie_id_list 40 | 41 | def get_tie_detail(self, tie_id, pn): 42 | r = session.get(tie_url.format(tie_id=tie_id, pn=pn)) 43 | sl = Selector(text=r.text) 44 | tie_item = sl.xpath("""//div[@class='d_post_content_main ']""") 45 | 46 | data_list = [] 47 | for tie in tie_item: 48 | data_dict = { 49 | 'tid': tie_id, 50 | 'hash_id': '1', 51 | 'content': '', 52 | # 'embedding': None 53 | } 54 | content_list = tie.xpath( 55 | """.//div[@class='d_post_content j_d_post_content ']//text()""").getall() 56 | content = ' '.join(content_list).strip().replace(' ', '') 57 | data_dict['hash_id'] = str(hash(content)) 58 | data_dict['content'] = content 59 | 60 | # 过滤 61 | if len(content.replace(' ', '')) > 10: 62 | if len(content.replace(' ', '')) < self.max_str_length: 63 | data_list.append(data_dict) 64 | else: 65 | print('数据异常 ->', content) 66 | 67 | if data_list: 68 | self.save(data_list) 69 | 70 | if next_pn := sl.xpath( 71 | """//li[@class='l_pager pager_theme_5 pb_list_pager']/span/following-sibling::a[1]/text()""").get(): 72 | next_args = (tie_id, next_pn) 73 | return next_args 74 | return None 75 | 76 | def save(self, data_list): 77 | with get_session() as s: 78 | for data_dict in data_list: 79 | s.add(TieBa(**data_dict)) 80 | 81 | def depth_first_run(self): 82 | for ba_pn in range(self.start_ba_pn, self.max_ba_pn + 1): 83 | print(f'正在爬取{self.kw}吧 第{ba_pn}页...') 84 | for tid in self.get_tie_list(ba_pn): 85 | next_args = (tid, self.start_tie_pn) 86 | while 1: 87 | print(f'正在爬取{self.kw}吧{tid}贴 第{next_args[1]}页...', end='') 88 | next_args = self.get_tie_detail(*next_args) 89 | print('完成!') 90 | if not next_args or int(next_args[1]) == self.max_tie_pn+1: 91 | break 92 | print('over') 93 | 
94 | 95 | if __name__ == '__main__': 96 | Ba(kw='孙笑川', max_ba_pn=1, max_tie_pn=1).depth_first_run() -------------------------------------------------------------------------------- /src/scripts/manager.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | from src.db.models import * 4 | from src.db.dao import Base, engine 5 | from src.scripts.crawlers.tie_ba_spider import Ba 6 | from src.scripts.workers import EmbeddingWorker 7 | 8 | 9 | class Management: 10 | @staticmethod 11 | def create_table(): 12 | Base.metadata.create_all(engine) 13 | 14 | @staticmethod 15 | def crawl(kw='孙笑川', max_ba_pn=1, max_tie_pn=1, start_ba_pn=1, start_tie_pn=1, max_str_length=400): 16 | """ 17 | 18 | Args: 19 | kw: 贴吧吧名 20 | max_ba_pn: 贴吧最大爬取页数 21 | max_tie_pn: 每个帖子最大页数 22 | start_ba_pn: 吧开始页数 23 | start_tie_pn: 开始帖子页数 24 | max_str_length: 接受最短的文本长度 25 | 26 | Returns: 27 | 28 | """ 29 | Ba(kw, max_ba_pn, max_tie_pn, start_ba_pn, start_tie_pn, max_str_length).depth_first_run() 30 | 31 | @staticmethod 32 | def embedding(name, model, desc=''): 33 | """ 34 | 35 | Args: 36 | name: Milvus集合名 37 | model: 定义mysql模型,需要继承于Document 38 | desc: 集合简介 39 | 40 | Returns: 41 | 42 | """ 43 | EmbeddingWorker(limit_num=100, embedding_query_num=5000).run(name, desc, model) 44 | 45 | def run(self, kw, name): 46 | self.create_table() 47 | self.crawl(kw=kw) 48 | self.embedding(name, model=TieBa) 49 | 50 | 51 | if __name__ == '__main__': 52 | fire.Fire(Management) 53 | -------------------------------------------------------------------------------- /src/scripts/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def filter_str(desstr, restr=''): 5 | # 过滤除中英文及数字以外的其他字符 6 | res = re.compile("[^\\u4e00-\\u9fa5^a-z^A-Z^0-9]") 7 | return res.sub(restr, desstr) 8 | 9 | 10 | def num_tokens_from_string(string: str, encoding_name: str = 'cl100k_base') -> int: 11 | import tiktoken 12 | """Returns the 
number of tokens in a text string.""" 13 | encoding = tiktoken.get_encoding(encoding_name) 15 | num_tokens = len(encoding.encode(string)) 16 | return num_tokens 17 | -------------------------------------------------------------------------------- /src/scripts/workers.py: -------------------------------------------------------------------------------- 1 | from pymilvus import FieldSchema, Collection, CollectionSchema, DataType, has_collection 2 | 3 | from src import config 4 | from src.db.dao import get_session 5 | from src.db.models import Document, TieBa 6 | from src.utils.init import initialize_openai 7 | from src.utils.utils import sync_get_embedding 8 | 9 | 10 | class EmbeddingWorker: 11 | 12 | def __init__(self, limit_num=100, embedding_query_num=5000): 13 | try: 14 | from pymilvus import FieldSchema, DataType 15 | from pymilvus import connections 16 | except ImportError: 17 | raise ImportError('Please run: pip install pymilvus==2.1') 18 | 19 | initialize_openai() 20 | connections.connect( 21 | alias="default", 22 | host=config.milvus['host'], 23 | port=config.milvus['port'] 24 | ) 25 | 26 | self.limit_num = limit_num 27 | self.embedding_query_max_length = embedding_query_num 28 | 29 | def search_rows_no_embedding(self, model: Document): 30 | while 1: 31 | with get_session() as s: 32 | rows = s.query(model).filter(model.embedding_state == False).limit(self.limit_num).all() 33 | if not rows: 34 | break 35 | rows_no_embedding = [] 36 | rows_content_list = [] 37 | current_length = 0 38 | while 1: 39 | if rows: 40 | row = rows.pop() 41 | rows_no_embedding.append(row) 42 | rows_content_list.append(row.content) 43 | current_length += len(row.content) 44 | else: 45 | break 46 | if current_length > self.embedding_query_max_length or not rows: 47 | yield rows_no_embedding, rows_content_list 48 | rows_no_embedding = [] 49 | rows_content_list = [] 50 | current_length = 0 51 | 52 | def query_embedding(self, rows_no_embedding,
rows_content_list): 53 | embedding = sync_get_embedding(rows_content_list) 54 | # content = rows_content_list 55 | hash_id = [int(row.hash_id) for row in rows_no_embedding] 56 | return [hash_id, embedding] 57 | 58 | @staticmethod 59 | def create_collection(name, desc): 60 | # 创建集合 61 | hash_id = FieldSchema( 62 | name="hash_id", 63 | dtype=DataType.INT64, 64 | is_primary=True, 65 | ) 66 | # content = FieldSchema( 67 | # name="content", 68 | # dtype=DataType.VARCHAR, 69 | # ) 70 | embedding = FieldSchema( 71 | name="embedding", 72 | dtype=DataType.FLOAT_VECTOR, 73 | dim=1536 74 | ) 75 | schema = CollectionSchema( 76 | fields=[hash_id, embedding], 77 | description=desc 78 | ) 79 | collection_name = name 80 | collection = Collection( 81 | name=collection_name, 82 | schema=schema, 83 | using='default', 84 | shards_num=2, 85 | consistency_level="Strong" 86 | ) 87 | return collection 88 | 89 | @staticmethod 90 | def create_index(collection): 91 | # 创建索引 92 | index_params = { 93 | "metric_type": "L2", 94 | "index_type": "IVF_FLAT", 95 | "params": {"nlist": 1024} 96 | } 97 | collection.create_index( 98 | field_name="embedding", 99 | index_params=index_params 100 | ) 101 | 102 | @staticmethod 103 | def push_2_milvus(collection, entries): 104 | # 插入数据 105 | try: 106 | collection.insert(entries) 107 | print(f'插入成功,共{len(entries[0])}条') 108 | return True 109 | except Exception as e: 110 | print('插入失败') 111 | raise e 112 | return False 113 | 114 | def run(self, name, desc, model): 115 | if not has_collection(name): 116 | collection = self.create_collection(name, desc) 117 | self.create_index(collection) 118 | else: 119 | collection = Collection(name) 120 | 121 | print(collection.num_entities) 122 | for rows_no_embedding, rows_content_list in self.search_rows_no_embedding(model): 123 | entries = self.query_embedding(rows_no_embedding, rows_content_list) 124 | if self.push_2_milvus(collection, entries): 125 | with get_session() as s: 126 | for tie in rows_no_embedding: 127 | 
tie.embedding_state = True 128 | print('over') 129 | print(collection.num_entities) 130 | 131 | 132 | if __name__ == '__main__': 133 | # with get_session() as s: 134 | # s.query(TieBa).update({'embedding_state': False}) 135 | EmbeddingWorker().run('sun_ba', '孙笑川吧', TieBa) -------------------------------------------------------------------------------- /src/static/speech/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/static/speech/.gitkeep -------------------------------------------------------------------------------- /src/static/voice/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/static/voice/.gitkeep -------------------------------------------------------------------------------- /src/token.txt: -------------------------------------------------------------------------------- 1 | 0297c744c0fafa2d8b71bd64594ea390cb5e6411c63eb46e0040d42109c0aae9 -------------------------------------------------------------------------------- /src/utils/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: base.py 5 | @DateTime: 2023/4/22 22:19 6 | @SoftWare: PyCharm 7 | """ 8 | from abc import abstractmethod 9 | from functools import cached_property 10 | from typing import Union 11 | 12 | import time 13 | 14 | from src import config 15 | from src.config import live2D_actions 16 | from src.utils.prompt_temple import get_chat_prompt_template 17 | 18 | 19 | class Event: 20 | def __init__(self, event_dict): 21 | self._event_dict = event_dict 22 | self._event_name = event_dict.get('type', '') or event_dict.get('Type', '') 23 | if not self._event_name: 24 | raise ValueError('事件缺少 type/Type 字段') 25 | 26 | self._kwargs = self.get_kwargs() 27
| self._action = None 28 | # 是否优先处理 29 | self.is_high_priority = False 30 | 31 | @abstractmethod 32 | def get_kwargs(self): 33 | """初始化event中有用的数据""" 34 | return { 35 | 'time': None 36 | } 37 | 38 | @property 39 | @abstractmethod 40 | def prompt_kwargs(self): 41 | """提示模板需要用到的数据""" 42 | return { 43 | 'time': None 44 | } 45 | 46 | @property 47 | @abstractmethod 48 | def human_template(self): 49 | """每类event对应的模板""" 50 | return '{text}' 51 | 52 | def get_prompt_messages(self, **kwargs): 53 | """出口函数,生成prompt,给到llm调用""" 54 | if config.context_plugin and 'context' in kwargs: 55 | human_template = '上下文:{context}\n问题:' + self.human_template 56 | else: 57 | human_template = self.human_template 58 | return get_chat_prompt_template(human_template).format_prompt(**self.prompt_kwargs, **kwargs).to_messages() 59 | 60 | @abstractmethod 61 | def get_audio_txt(self, *args, **kwargs): 62 | """数字人说的话""" 63 | return None 64 | 65 | @property 66 | def time(self): 67 | """易读的时间""" 68 | return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self._kwargs['time'])) 69 | 70 | @property 71 | def action(self) -> Union[None, str, int]: 72 | """ 73 | :return: 74 | None: 该event不做任何动作 75 | str: zero-shot 匹配动作 76 | int: 通过索引固定 做某个动作 77 | """ 78 | return self._action 79 | 80 | @action.setter 81 | def action(self, value: Union[None, str, int]): 82 | if value in live2D_actions: 83 | self._action = live2D_actions.index(value) 84 | else: 85 | self._action = value -------------------------------------------------------------------------------- /src/utils/dfa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | from src.config import keyword_str_list 5 | 6 | 7 | class DFA: 8 | 9 | def __init__(self, keyword_list: list): 10 | self.state_event_dict = self._generate_state_event_dict(keyword_list) 11 | 12 | def match(self, content: str): 13 | match_list = [] 14 | state_list = [] 15 | temp_match_list = [] 16 | 17 | for char_pos, char in
enumerate(content): 18 | if char in self.state_event_dict: 19 | state_list.append(self.state_event_dict) 20 | temp_match_list.append({ 21 | "start": char_pos, 22 | "match": "" 23 | }) 24 | 25 | for index, state in enumerate(state_list): 26 | is_find = False 27 | state_char = None 28 | 29 | # 如果是 * 则匹配所有内容 30 | if "*" in state: 31 | state_list[index] = state["*"] 32 | state_char = state["*"] 33 | is_find = True 34 | 35 | if char in state: 36 | state_list[index] = state[char] 37 | state_char = state[char] 38 | is_find = True 39 | 40 | if is_find: 41 | temp_match_list[index]["match"] += char 42 | 43 | if state_char["is_end"]: 44 | match_list.append(copy.deepcopy(temp_match_list[index])) 45 | 46 | if len(state_char.keys()) == 1: 47 | state_list.pop(index) 48 | temp_match_list.pop(index) 49 | else: 50 | state_list.pop(index) 51 | temp_match_list.pop(index) 52 | 53 | return match_list 54 | 55 | @staticmethod 56 | def _generate_state_event_dict(keyword_list: list) -> dict: 57 | state_event_dict = {} 58 | 59 | for keyword in keyword_list: 60 | current_dict = state_event_dict 61 | length = len(keyword) 62 | 63 | for index, char in enumerate(keyword): 64 | if char not in current_dict: 65 | next_dict = {"is_end": False} 66 | current_dict[char] = next_dict 67 | else: 68 | next_dict = current_dict[char] 69 | current_dict = next_dict 70 | if index == length - 1: 71 | current_dict["is_end"] = True 72 | 73 | return state_event_dict 74 | 75 | 76 | if __name__ == "__main__": 77 | dfa = DFA(keyword_str_list) 78 | print(dfa.match("信息抽我取之 DFA 算法匹配关键词,匹配算法")) 79 | -------------------------------------------------------------------------------- /src/utils/events.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: events.py 5 | @DateTime: 2023/4/23 16:05 6 | @SoftWare: PyCharm 7 | """ 8 | import time 9 | 10 | from src.utils.utils import Event 11 | 12 | 13 | class BlDanmuMsgEvent(Event): 
14 | 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.action = self._kwargs['content'] 18 | 19 | def get_kwargs(self): 20 | return { 21 | 'content': self._event_dict['data']['info'][1], 22 | 'user_name': self._event_dict['data']['info'][2][1], 23 | 'time': self._event_dict['data']['info'][9]['ts'] 24 | } 25 | 26 | @property 27 | def prompt_kwargs(self): 28 | return { 29 | 'text': self._kwargs['content'] 30 | } 31 | 32 | @property 33 | def human_template(self): 34 | return '{text}' 35 | 36 | def get_audio_txt(self, gpt_resp): 37 | return f"{self._kwargs['content']} {gpt_resp}" 38 | 39 | 40 | class BlSuperChatMessageEvent(Event): 41 | 42 | def __init__(self, *args, **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.is_high_priority = True 45 | self.action = self._kwargs['message'] 46 | 47 | def get_kwargs(self): 48 | info = self._event_dict['data']['data'] 49 | user_info = info['user_info'] 50 | return { 51 | 'user_name': user_info['uname'], 52 | 'face': user_info['face'], 53 | 'message': info['message'], 54 | 'price': info['price'], 55 | 'time': info['start_time'], 56 | } 57 | 58 | @property 59 | def prompt_kwargs(self): 60 | return { 61 | 'message': self._kwargs['message'] 62 | } 63 | 64 | @property 65 | def human_template(self): 66 | return '{message}' 67 | 68 | def get_audio_txt(self, gpt_resp): 69 | return f"感谢{self._kwargs['user_name']}的sc。{self._kwargs['message']} {gpt_resp}" 70 | 71 | 72 | class BlSendGiftEvent(Event): 73 | 74 | def get_kwargs(self): 75 | info = self._event_dict['data']['data'] 76 | return { 77 | 'user_name': info['uname'], 78 | 'face': info['face'], 79 | 'action': info['action'], 80 | 'giftName': info['giftName'], 81 | 'time': info['timestamp'], 82 | 83 | } 84 | 85 | @property 86 | def prompt_kwargs(self): 87 | return { 88 | 'content': f"{self._kwargs['user_name']}{self._kwargs['action']}了{self._kwargs['giftName']}。" 89 | } 90 | 91 | @property 92 | def human_template(self): 93 | return ( 94 
| "{content}" 95 | "请表示感谢,说一句赞美他的话!" 96 | ) 97 | 98 | def get_audio_txt(self, gpt_resp): 99 | return f"{self.prompt_kwargs['content']} {gpt_resp}" 100 | 101 | 102 | class BlInteractWordEvent(Event): 103 | 104 | def __init__(self, *args, **kwargs): 105 | super().__init__(*args, **kwargs) 106 | 107 | def get_kwargs(self): 108 | info = self._event_dict['data']['data'] 109 | fans_medal = info['fans_medal'] 110 | return { 111 | 'medal_name': fans_medal['medal_name'], 112 | 'medal_level': fans_medal['medal_level'], 113 | 'user_name': info['uname'], 114 | 'content': f"{info['uname']} 进入直播间。", 115 | 'time': info['timestamp'] 116 | } 117 | 118 | @property 119 | def prompt_kwargs(self): 120 | return { 121 | 'content': self._kwargs['content'], 122 | 'medal_name': self._kwargs['medal_name'] 123 | } 124 | 125 | @property 126 | def human_template(self): 127 | return ( 128 | "{content}" 129 | "请表示欢迎!并简短聊聊他加入的粉丝团{medal_name}" 130 | ) 131 | 132 | def get_audio_txt(self, gpt_resp): 133 | return f"{gpt_resp}" 134 | 135 | 136 | class DyDanmuMsgEvent(BlDanmuMsgEvent): 137 | 138 | def get_kwargs(self): 139 | return { 140 | 'content': self._event_dict['Data']['Content'], 141 | 'user_name': self._event_dict['Data']['User']['Nickname'], 142 | 'time': int(time.time()) 143 | } 144 | 145 | 146 | class DyCkEvent(Event): 147 | 148 | def get_kwargs(self): 149 | return { 150 | 'user_name': self._event_dict['Data']['User']['Nickname'], 151 | 'content': self._event_dict['Data']['Content'], 152 | 'time': int(time.time()) 153 | } 154 | 155 | @property 156 | def prompt_kwargs(self): 157 | return { 158 | 'content': self._kwargs['user_name'] 159 | } 160 | 161 | @property 162 | def human_template(self): 163 | return ( 164 | "{content}给你点了赞," 165 | "请表示感谢!" 
166 | ) 167 | 168 | def get_audio_txt(self, gpt_resp): 169 | return f"{gpt_resp}" 170 | 171 | 172 | class DyWelcomeWordEvent(Event): 173 | 174 | def get_kwargs(self): 175 | return { 176 | 'user_name': self._event_dict['Data']['User']['Nickname'], 177 | 'time': int(time.time()) 178 | } 179 | 180 | @property 181 | def prompt_kwargs(self): 182 | return { 183 | 'user_name': self._kwargs['user_name'] 184 | } 185 | 186 | @property 187 | def human_template(self): 188 | return ( 189 | "{user_name},进入直播间" 190 | "请表示欢迎!并简短聊聊他的名字" 191 | ) 192 | 193 | def get_audio_txt(self, gpt_resp): 194 | return f"{gpt_resp}" 195 | 196 | 197 | class DySendGiftEvent(Event): 198 | 199 | def __init__(self, *args, **kwargs): 200 | super().__init__(*args, **kwargs) 201 | self.is_high_priority = True 202 | 203 | def get_kwargs(self): 204 | return { 205 | 'user_name': self._event_dict['Data']['User']['Nickname'], 206 | 'giftName': self._event_dict['Data']['GiftName'], 207 | 'content': self._event_dict['Data']['Content'], 208 | 'time': int(time.time()) 209 | } 210 | 211 | @property 212 | def prompt_kwargs(self): 213 | return { 214 | 'content': self._kwargs['content'] 215 | } 216 | 217 | @property 218 | def human_template(self): 219 | return ( 220 | "{content}" 221 | "请表示感谢,说一句赞美他的话!" 222 | ) 223 | 224 | def get_audio_txt(self, gpt_resp): 225 | return f"{self._kwargs['user_name']} 送出{self._kwargs['giftName']}!{gpt_resp}" 226 | 227 | 228 | class DyAttentionEvent(Event): 229 | 230 | def get_kwargs(self): 231 | return { 232 | 'user_name': self._event_dict['Data']['User']['Nickname'], 233 | 'content': self._event_dict['Data']['Content'], 234 | 'time': int(time.time()) 235 | } 236 | 237 | @property 238 | def prompt_kwargs(self): 239 | return { 240 | 'user_name': self._kwargs['user_name'] 241 | } 242 | 243 | @property 244 | def human_template(self): 245 | return ( 246 | "{user_name}关注了你!" 247 | "请表示感谢,说一句赞美他的话!"
248 | ) 249 | 250 | def get_audio_txt(self, gpt_resp): 251 | return gpt_resp 252 | 253 | 254 | class UserEvent(Event): 255 | 256 | def __init__(self, content, audio_txt_temple): 257 | super(UserEvent, self).__init__({'type': 'user_event'}) 258 | self.content = content 259 | self.audio_txt_temple = audio_txt_temple 260 | 261 | def get_kwargs(self): 262 | return { 263 | 'time': int(time.time()) 264 | } 265 | 266 | @property 267 | def prompt_kwargs(self): 268 | return { 269 | 'content': self.content 270 | } 271 | 272 | @property 273 | def human_template(self): 274 | return '{content}' 275 | 276 | def get_audio_txt(self, gpt_resp): 277 | return self.audio_txt_temple.format(gpt_resp) 278 | 279 | -------------------------------------------------------------------------------- /src/utils/init.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: jiran 3 | @Email: jiran214@qq.com 4 | @FileName: monkey_patch.py 5 | @DateTime: 2023/5/13 12:53 6 | @SoftWare: PyCharm 7 | 先导入这个文件,改写lib 8 | """ 9 | import configparser 10 | import json 11 | import os 12 | import random 13 | from contextlib import asynccontextmanager 14 | from typing import AsyncIterator 15 | 16 | import aiohttp 17 | import urllib3 18 | from aiohttp import TCPConnector 19 | from openai import api_requestor 20 | import openai 21 | 22 | from src.modules.actions import plugin_info 23 | 24 | urllib3.disable_warnings() 25 | 26 | 27 | def initialize_openai(): 28 | # 避免循环导入,多一次读取ini配置 29 | file_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../config.ini') 30 | _config = configparser.RawConfigParser() 31 | _config.read(file_path) 32 | proxy = _config.get('other', 'proxy') 33 | if proxy: 34 | os.environ["http_proxy"] = f'http://{proxy}/' 35 | os.environ["https_proxy"] = f'http://{proxy}/' 36 | 37 | @asynccontextmanager 38 | async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]: 39 | async with 
aiohttp.ClientSession(connector=TCPConnector(limit_per_host=5, ssl=False), trust_env=True) as session: 40 | # async with aiohttp.ClientSession(trust_env=True) as session: 41 | yield session 42 | 43 | api_requestor.aiohttp_session = aiohttp_session 44 | 45 | 46 | async def initialize_action(): 47 | # websocket连接 获取token到本地 48 | try: 49 | import pyvts 50 | except ImportError: 51 | raise ImportError('Please run: pip install pyvts') 52 | vts = pyvts.vts(plugin_info=plugin_info) 53 | try: 54 | await vts.connect() 55 | except ConnectionRefusedError: 56 | raise ConnectionRefusedError('请先打开VTS,并打开API开关!') 57 | print('请在live2D VTS弹窗中点击确认!') 58 | await vts.request_authenticate_token()  # get token 59 | await vts.write_token() 60 | await vts.request_authenticate()  # use token 61 | 62 | response_data = await vts.request(vts.vts_request.requestHotKeyList()) 63 | hotkey_list = [] 64 | for hotkey in response_data['data']['availableHotkeys']: 65 | hotkey_list.append(hotkey['name']) 66 | print('读取到所有模型动作:', hotkey_list) 67 | 68 | # 请求embedding 69 | print('请求embedding模型中...') 70 | try: 71 | initialize_openai() 72 | res = await openai.Embedding.acreate(input=hotkey_list, model="text-embedding-ada-002") 73 | action_embeddings = [d['embedding'] for d in res['data']] 74 | action_dict = dict(zip(hotkey_list, action_embeddings)) 75 | print(len(action_dict)) 76 | except Exception as e: 77 | print('很可能是网络代理有问题') 78 | raise e 79 | 80 | # 写入 81 | with open("../action.json", "w") as dump_f: 82 | json.dump(action_dict, dump_f) 83 | 84 | # 测试 85 | assert len(hotkey_list) == len(action_dict.keys())  # vts 和 本地的动作是否一致 86 | assert len(hotkey_list) not in (0, 1)  # 动作太少 87 | action = random.choice(hotkey_list) 88 | print('随机播放动作测试...', action) 89 | send_hotkey_request = vts.vts_request.requestTriggerHotKey(action) 90 | await vts.request(send_hotkey_request) 91 | await vts.close() -------------------------------------------------------------------------------- /src/utils/log.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: log.py
5 | @DateTime: 2023/3/20 23:14
6 | @SoftWare: PyCharm
7 | """
8 | import datetime
9 | import logging
10 | import logging.handlers
11 | import os
12 | 
13 | import colorlog
14 | 
15 | from src import config
16 | from src.config import root_path
17 | 
18 | 
19 | class Logging:
20 |     def __init__(self, log_file_name, log_file_path=os.path.join(root_path, 'logs')):
21 |         """
22 |         :param log_file_name: base name of the log file; the current date is appended
23 |         :param log_file_path: directory the log files are written to, defaults to <root_path>/logs
24 |             Path-building notes:
25 |             os.getcwd() / os.path.abspath('.')  -> current working directory
26 |             os.path.dirname(__file__)           -> directory containing the current file
27 |             os.path.join(dir, rel_path)         -> absolute path built from a relative one
28 |         """
29 |         self.log_colors_config = {
30 |             'DEBUG': 'cyan',  # cyan white
31 |             'INFO': 'green',
32 |             'WARNING': 'yellow',
33 |             'ERROR': 'red',
34 |             'CRITICAL': 'bold_red',
35 |         }
36 |         # Directory where log files are stored
37 |         self.log_file_path = log_file_path
38 |         self.log_file_name = log_file_name
39 |         self._log_filename = self.get_log_filename()
40 | 
41 |         # Create the logger object
42 |         self._logger = logging.getLogger(self.log_file_name)
43 | 
44 |         # Console log level; severity order: CRITICAL > ERROR > WARNING > INFO > DEBUG
45 |         if config.debug:
46 |             self.set_console_logger()
47 |             self.set_file_logger()
48 |             self._logger.setLevel(logging.DEBUG)  # emit DEBUG and above
49 |         else:
50 |             self.set_file_logger()
51 |             self._logger.setLevel(logging.INFO)  # emit INFO and above
52 |             # self._logger.setLevel(logging.DEBUG)
53 | 
54 |     def get_log_filename(self):
55 |         if not os.path.isdir(self.log_file_path):
56 |             # Create the log directory
57 |             os.makedirs(self.log_file_path)
58 |         return f"{self.log_file_path}/{self.log_file_name}_{str(datetime.date.today())}.log"
59 | 
60 |     def set_console_logger(self):
61 |         formatter = colorlog.ColoredFormatter(
62 |             # fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s -> %(funcName)s line:%(lineno)d [%(levelname)s] : %(message)s',
63 |             fmt='%(log_color)s[%(asctime)s] %(levelname)s [%(funcName)s] - %(message)s',
64 |             datefmt='%Y-%m-%d %H:%M:%S',
65 |             log_colors=self.log_colors_config)
66 |         console_handler = logging.StreamHandler()
67 |         console_handler.setLevel(logging.DEBUG)
68 |         console_handler.setFormatter(formatter)
69 |         self._logger.addHandler(console_handler)
70 | 
71 |     def set_file_logger(self):
72 |         # Format for records written to the log file
73 |         formatter = logging.Formatter(
74 |             '%(asctime)s - %(filename)s -> %(funcName)s - line:%(lineno)d - %(levelname)s: %(message)s')
75 |         # Persist log records to a rotating file (10 MB per file, 5 backups)
76 |         file_handler = logging.handlers.RotatingFileHandler(
77 |             self._log_filename, maxBytes=10485760, backupCount=5, encoding="utf-8")
78 |         file_handler.setLevel(logging.INFO)
79 |         file_handler.setFormatter(formatter)
80 |         self._logger.addHandler(file_handler)
81 | 
82 |     def get_logger(self):
83 |         return self._logger
84 | 
85 | 
86 | # request_logger = Logging("Request").get_logger()
87 | worker_logger = Logging("Worker").get_logger()
88 | 
89 | 
90 | if __name__ == '__main__':
91 |     worker_logger.debug('123')
--------------------------------------------------------------------------------
/src/utils/prompt_temple.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: prompt_temple.py
5 | @DateTime: 2023/5/13 11:37
6 | @SoftWare: PyCharm
7 | """
8 | import random
9 | 
10 | from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
11 | 
12 | from src import config
13 | 
14 | system_template = (
15 |     "请扮演一个傲娇的AI虚拟主播,不要回答任何敏感问题!"
16 |     "不要强调你是虚拟主播,不准道歉,只需要回答问题!"
17 |     "你是一个带有一定的高傲情绪,可能会在自我表现时表现出一些自大或骄傲的情绪,或者表现出一些自信或者不屑一顾的态度。"
18 |     "有时会有一些任性或者固执的表现。在做出决定时可能会表现出一定的倔强和固执,"
19 |     "可能不太容易妥协或者接受别人的意见。同时也有一定的可爱和迷人之处。"
20 |     "她们可能会有一些俏皮或者撒娇的表现,或者在和自己关心的人相处时表现得特别温柔和可爱。"
21 | )
22 | 
23 | if config.context_plugin:
24 |     system_template += (
25 |         "\n\n"
26 |         "请你重复上下文的内容和口吻,做出对观众问题的回应,你的回答里不能说你参考了上下文。"
27 | 
28 |     )
29 | system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
30 | 
31 | 
32 | def get_chat_prompt_template(human_template="{text}") -> ChatPromptTemplate:
33 |     human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
34 |     chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
35 |     return chat_prompt
36 | 
37 | 
38 | schedule_task_temple_list = [
39 |     ('我想让你扮演说唱歌手。您将想出强大而有意义的歌词、节拍和节奏,让听众“惊叹”。你的歌词应该有一个有趣的含义和信息,'
40 |      '人们也可以联系起来。在选择节拍时,请确保它既朗朗上口又与你的文字相关,这样当它们组合在一起时,每次都会发出爆炸声!'
41 |      '我的第一个请求是“我需要一首关于在你自己身上寻找力量的说唱歌曲。”', '接下来我给大家表演一首说唱!{}'),
42 | 
43 |     ('我要你扮演诗人。你将创作出能唤起情感并具有触动人心的力量的诗歌。写任何主题或题材,'
44 |      '但要确保您的文字以优美而有意义的方式传达您试图表达的感觉。您还可以想出一些短小的诗句,这些诗句仍然足够强大,可以在读者的脑海中留下印记。'
45 |      '我的第一个请求是“我需要一首关于爱情的诗”。', '{}'),
46 | 
47 |     ('我希望你充当励志演说家。将能够激发行动的词语放在一起,让人们感到有能力做一些超出他们能力的事情。你可以谈论任何话题,但目的是确保你所说的话能引起听众的共鸣,激励他们努力实现自己的目标并争取更好的可能性。'
48 |      '我的第一个请求是“我需要一个关于每个人如何永不放弃的演讲”。', '全体注意,我将发表一个演讲!{}'),
49 | 
50 |     ("我要你担任哲学老师。我会提供一些与哲学研究相关的话题,你的工作就是用通俗易懂的方式解释这些概念。"
51 |      "这可能包括提供示例、提出问题或将复杂的想法分解成更容易理解的更小的部分。"
52 |      "我的第一个请求是“我需要帮助来理解不同的哲学理论如何应用于日常生活。”", '我想到一个深奥的道理,{}'),
53 | 
54 |     ("我想让你扮演讲故事的角色。您将想出引人入胜、富有想象力和吸引观众的有趣故事。它可以是童话故事、教育故事或任何其他类型的故事,"
55 |      "有可能吸引人们的注意力和想象力。根据目标受众,您可以为讲故事环节选择特定的主题或题材,例如,如果是儿童,则可以谈论动物;如果是成年人,"
56 |      "那么基于历史的故事可能会更好地吸引他们等等。请开始你的讲述!", '我给大家讲个故事吧,{}')
57 | ]
58 | 
59 | 
60 | def get_schedule_task() -> (str, str):
61 |     return random.choice(schedule_task_temple_list)
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | 
@Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: utils.py
5 | @DateTime: 2023/4/23 17:00
6 | @SoftWare: PyCharm
7 | """
8 | import asyncio
9 | import queue
10 | import random
11 | import time
12 | from collections import deque
13 | 
14 | from threading import Lock
15 | from dataclasses import dataclass
16 | 
17 | from typing import List, Union, Dict
18 | 
19 | import numpy as np
20 | import openai
21 | from scipy import spatial
22 | 
23 | from src import config
24 | from src.utils.base import Event
25 | from src.utils.log import worker_logger
26 | 
27 | logger = worker_logger
28 | audio_lock = Lock()
29 | 
30 | 
31 | class NewEventLoop:
32 |     def __init__(self):
33 |         # self.loop = asyncio.new_event_loop()
34 |         # asyncio.set_event_loop(self.loop)
35 |         self.loop = asyncio.get_event_loop()
36 | 
37 |     def run(self, coroutine):
38 |         self.loop.run_until_complete(coroutine)
39 | 
40 | 
41 | @dataclass
42 | class GPT35Params:
43 |     messages: List  # chat-format input messages
44 |     model: str = "gpt-3.5-turbo"  # model ID; only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported
45 |     temperature: float = 1.5  # sampling temperature in [0, 2]; higher values make output more random, lower more deterministic
46 |     top_p: float = 1.0  # nucleus sampling, an alternative to temperature: only the top_p probability mass is considered; range [0, 1]
47 |     n: int = 1  # number of chat completion choices per input message, defaults to 1
48 |     stream: bool = False  # whether to stream the output
49 |     stop: Union[str, List[str], None] = None  # up to 4 sequences; generation stops when any of them is produced
50 |     max_tokens: int = 1000  # max tokens in the answer; the API default is (4096 - prompt tokens)
51 |     presence_penalty: Union[float, None] = None  # number in [-2.0, 2.0]; positive values penalize tokens already present, encouraging new topics
52 |     frequency_penalty: Union[float, None] = None  # number in [-2.0, 2.0]; positive values penalize frequent tokens, reducing verbatim repetition
53 |     logit_bias: Union[Dict[str, float], None] = None  # maps token IDs to bias values (-100 to 100) adjusting their likelihood of appearing
54 |     user: Union[str, None] = None  # unique end-user identifier that helps OpenAI monitor and detect abuse
55 | 
56 |     def dict(self, exclude_defaults=False, exclude_none=True):
57 |         # Note: exclude_defaults is accepted for symmetry but not implemented yet
58 |         res = {}
59 |         for k, v in self.__dict__.items():
60 |             if exclude_none and v is None:
61 |                 continue
62 |             res[k] = v
63 |         return res
64 | 
65 | 
66 | 
67 | class UserQueue:
68 |     maxsize = 15
69 | 
70 |     def __init__(self):
71 |         self.high_priority_event_queue = queue.Queue()
72 |         self.event_queue = queue.Queue(self.maxsize)
73 | 
74 |     def send(self, event: Union[Event, None]):
75 |         if not event:
76 |             logger.debug(f'Filtered out: {event}')
77 |             return
78 |         # print(event._event_name)
79 |         # High-priority events go to the unbounded queue
80 |         if event.is_high_priority:
81 |             # Add object to high-priority queue
82 |             self.high_priority_event_queue.put_nowait(event)
83 |         else:
84 |             # Check if main queue is full
85 |             if not self.event_queue.full():
86 |                 # Add object to main queue
87 |                 self.event_queue.put_nowait(event)
88 |             else:
89 |                 # Remove oldest item from queue and add new item
90 |                 self.event_queue.get()
91 |                 self.event_queue.put_nowait(event)
92 | 
93 |     def recv(self) -> Union[None, Event]:
94 |         # Check high-priority queue first
95 |         if not self.high_priority_event_queue.empty():
96 |             # Remove oldest item from high-priority queue and process it
97 |             event = self.high_priority_event_queue.get()
98 |         elif not self.event_queue.empty():
99 |             # Remove oldest item from main queue and process it
100 |             event = self.event_queue.get()
101 |         else:
102 |             event = None
103 |         return event
104 | 
105 |     def __str__(self):
106 |         return f"event_queue size: {self.event_queue.qsize()}  high_priority_event_queue size: {self.high_priority_event_queue.qsize()}"
107 | 
108 | 
109 | user_queue = UserQueue()
110 | 
111 | 
112 | class FixedLengthTSDeque:
113 | 
114 |     def __init__(self, per_minute_times):
115 |         self._per_minute_times = per_minute_times
116 |         self._data_deque = deque(maxlen=per_minute_times)
117 | 
118 |     def _append(self) -> None:
119 |         self._data_deque.append(int(time.time()))
120 | 
121 |     def can_append(self):
122 |         # True when fewer than N calls are recorded, or when the oldest call has left the 60-second window
123 |         return (len(self._data_deque) < self._per_minute_times
124 |                 or (int(time.time()) - self._data_deque[0]) >= 60)
125 | 
126 |     def acquire(self):
127 |         if not self.can_append():
128 |             logger.warning('This api_key has exceeded 3 calls/min; consider adding more api_keys under config -> openai')
129 |             return False
130 |         else:
131 |             self._append()
132 |             return True
133 | 
134 | 
135 | def top_n_indices_from_embeddings(
136 |         query_embedding: List[float],
137 |         embeddings: List[List[float]],
138 |         distance_metric="cosine",
139 |         top=1
140 | ) -> list:
141 |     """Return the indices of the top-n embeddings closest to the query embedding."""
142 |     distance_metrics = {
143 |         "cosine": spatial.distance.cosine,
144 |         "L1": spatial.distance.cityblock,
145 |         "L2": spatial.distance.euclidean,
146 |         "Linf": spatial.distance.chebyshev,
147 |     }
148 |     distances = [
149 |         distance_metrics[distance_metric](query_embedding, embedding)
150 |         for embedding in embeddings
151 |     ]
152 |     top_n_indices = np.argsort(distances)[:top]
153 |     return top_n_indices
154 | 
155 | 
156 | def sync_get_embedding(texts: List[str], model="text-embedding-ada-002"):
157 |     res = openai.Embedding.create(input=texts, model=model, api_key=get_openai_key())
158 |     if isinstance(texts, list) and len(texts) == 1:
159 |         return res['data'][0]['embedding']
160 |     else:
161 |         return [d['embedding'] for d in res['data']]
162 | 
163 | 
164 | api_key_limit_dict = dict(zip(config.api_key_list, [FixedLengthTSDeque(3) for _ in config.api_key_list]))
165 | current_api_key = config.api_key_list[0]
166 | 
167 | 
168 | def get_openai_key():
169 |     global current_api_key
170 |     times = 0
171 |     while True:
172 |         times += 1
173 |         if api_key_limit_dict[current_api_key].acquire():
174 |             return current_api_key
175 |         else:
176 |             current_api_key = config.api_key_list[(config.api_key_list.index(current_api_key) + 1) % len(config.api_key_list)]
177 |             # print('switch', current_api_key)
178 |             if times > len(config.api_key_list):
179 |                 logger.debug('Rate limited, waiting...')
180 |                 time.sleep(5)
181 | 
182 | 
183 | if __name__ == '__main__':
184 |     while True:
185 |         print('123')
186 |         print(get_openai_key())
187 | 
time.sleep(1) --------------------------------------------------------------------------------
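
For reference, the two-queue scheme in `UserQueue` (src/utils/utils.py above) can be exercised standalone. The sketch below is a minimal re-implementation for illustration only: `MiniUserQueue`, the event names, and the tiny `maxsize` are made up, and `Event` is reduced to a plain string with an `is_high_priority` flag.

```python
import queue

# Minimal standalone sketch of the UserQueue priority scheme: an unbounded
# queue for high-priority events, plus a bounded queue that evicts the
# oldest item when full.

class MiniUserQueue:
    maxsize = 3  # small for demonstration; the real class uses 15

    def __init__(self):
        self.high_priority_event_queue = queue.Queue()
        self.event_queue = queue.Queue(self.maxsize)

    def send(self, name, is_high_priority=False):
        if is_high_priority:
            self.high_priority_event_queue.put_nowait(name)
        elif self.event_queue.full():
            self.event_queue.get()           # evict the oldest ordinary event
            self.event_queue.put_nowait(name)
        else:
            self.event_queue.put_nowait(name)

    def recv(self):
        # High-priority events always drain first
        if not self.high_priority_event_queue.empty():
            return self.high_priority_event_queue.get()
        if not self.event_queue.empty():
            return self.event_queue.get()
        return None

q = MiniUserQueue()
for i in range(5):
    q.send(f'danmaku-{i}')           # only the newest 3 survive the eviction
q.send('superchat', is_high_priority=True)
print(q.recv())  # 'superchat' — high-priority queue drains first
print(q.recv())  # 'danmaku-2' — oldest surviving ordinary event
```

High-priority events (SC, membership) are never dropped, while ordinary danmaku silently age out once the bounded queue fills — which is why the readme describes the main queue as "unreliable" by design.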
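
Similarly, the per-key throttling behind `FixedLengthTSDeque` amounts to a sliding-window rate limiter: record the timestamps of the last N calls and allow a new call only when fewer than N are recorded, or when the oldest recorded call has left the window. A minimal standalone sketch — the class name and the 1-second demo window are illustrative, not from the source, which uses a fixed 60-second window:

```python
import time
from collections import deque

# Sliding-window rate limiter: a deque of maxlen N holds the timestamps of
# the last N permitted calls; a new call is allowed when fewer than N calls
# are recorded, or when the oldest one is older than the window.

class SlidingWindowLimiter:
    def __init__(self, per_minute_times, window=60.0):
        self._limit = per_minute_times
        self._window = window
        self._stamps = deque(maxlen=per_minute_times)

    def acquire(self):
        now = time.monotonic()
        if len(self._stamps) == self._limit and now - self._stamps[0] < self._window:
            return False  # N calls already fell inside the window
        self._stamps.append(now)  # maxlen evicts the oldest stamp automatically
        return True

limiter = SlidingWindowLimiter(3, window=1.0)  # 3 calls/second, to keep the demo fast
print([limiter.acquire() for _ in range(4)])   # [True, True, True, False]
time.sleep(1.1)                                # let the window slide past
print(limiter.acquire())                       # True
```

Combined with the key rotation in `get_openai_key`, each API key gets its own limiter, so a pool of keys shares the load without any single key exceeding its per-minute budget.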