├── .gitignore
├── public
│   ├── Snipaste_2023-04-25_17-48-51.png
│   ├── Snipaste_2023-04-25_21-26-54.png
│   ├── mm_reward_qrcode_1686025672796.png
│   └── readme_v1.md
├── readme.md
└── src
    ├── __init__.py
    ├── config.py
    ├── config.sample.ini
    ├── core
    │   ├── __init__.py
    │   ├── main.py
    │   └── vup.py
    ├── db
    │   ├── __init__.py
    │   ├── dao.py
    │   ├── milvus.py
    │   └── models.py
    ├── forbidden_words.txt
    ├── manager.py
    ├── modules
    │   ├── __init__.py
    │   ├── actions.py
    │   ├── audio.py
    │   └── speech_rec.py
    ├── requirements.txt
    ├── rooms
    │   ├── __init__.py
    │   ├── bilibili.py
    │   └── douyin.py
    ├── scripts
    │   ├── crawlers
    │   │   └── tie_ba_spider.py
    │   ├── manager.py
    │   ├── utils.py
    │   └── workers.py
    ├── static
    │   ├── speech
    │   │   └── .gitkeep
    │   └── voice
    │       └── .gitkeep
    ├── token.txt
    └── utils
        ├── base.py
        ├── dfa.py
        ├── events.py
        ├── init.py
        ├── log.py
        ├── prompt_temple.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 |
4 | # IDEs and editors
5 | .idea/
6 | *.swp
7 | *~.nfs*
8 | *.bak
9 | *.cache
10 | *.dat
11 | *.db
12 | *.log
13 | *.patch
14 | *.orig.*
15 | *.rej.*
16 | *.tmp.*
17 |
18 | *.mp3
19 | *.wav
20 |
21 | env/
22 | config.ini
--------------------------------------------------------------------------------
/public/Snipaste_2023-04-25_17-48-51.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/public/Snipaste_2023-04-25_17-48-51.png
--------------------------------------------------------------------------------
/public/Snipaste_2023-04-25_21-26-54.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/public/Snipaste_2023-04-25_21-26-54.png
--------------------------------------------------------------------------------
/public/mm_reward_qrcode_1686025672796.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/public/mm_reward_qrcode_1686025672796.png
--------------------------------------------------------------------------------
/public/readme_v1.md:
--------------------------------------------------------------------------------
1 | # Project Name
2 |
3 | GPT-vup
4 |
5 | ## :memo: Introduction
6 |
7 | Supports BiliBili and Douyin live rooms; built on a producer-consumer model, using OpenAI embeddings and the GPT-3.5 API
9 | 
10 |
11 | ## :cloud: Environment
12 |
13 | - Python 3.8
14 | - Windows
15 | - Make sure you have a VPN with global proxy enabled
16 |
17 | ## :computer: Features
18 |
19 | 1. Answer danmu (chat messages) and Super Chats
20 | 2. Welcome viewers entering the room
21 | 3. Thank viewers for gifts
22 | 4. Custom tasks
23 | 5. Plugins, configured in config.ini, all False by default
24 |    - speech: listen for the Ctrl+T hotkey; spoken input is converted to text for interacting with the AI VTuber
25 |    - action: match character actions to viewer behavior
26 |    - schedule: trigger an event at intervals, e.g. telling a story or rapping...
28 | ## :book: How It Works
29 |
30 | 
31 |
32 | GPT-vup runs up to three producer threads plus one consumer thread:
33 |
34 | Producer thread 1: BiliBili WebSocket
35 |
36 | - Uses the bilibili_api library to keep a WebSocket connection open and continuously receive Events from the live room, dispatching each one to a filter function.
37 | - The filter function does two things: decide which events get enqueued, and which queue they go into
38 | - There are two inter-thread message queues:
39 |   - Premise: the producers are much faster than the consumer
40 |   - event_queue: has a maximum length; when full, the oldest message is evicted, so it is lossy and handles ordinary room messages (regular danmu, welcome notices)
41 |   - high_priority_event_queue: unbounded; handles important room messages (Super Chats, guard purchases)
42 |
43 | Producer thread 2: Douyin WebSocket
44 |
45 | - Uses the open-source project [Douyin danmu capture and push: a Douyin danmu push service built on system-proxy packet capture](https://gitee.com/haodong108/dy-barrage-grab/tree/V2.6.5/BarrageGrab) to open a local forwarding port
46 | - Another thread listens on that port; events are likewise filtered and enqueued
47 |
48 | Producer thread 3:
49 |
50 | - If the VTuber only ever responded to danmu it would feel monotonous, so the schedule module pushes custom Events into high_priority_event_queue at intervals, e.g. a self-introduction or a short performance every ten minutes.
51 | - Updated 5-13: real-time voice interaction via hotkey is supported, see plugins
52 |
53 | Consumer thread:
54 |
55 | - A worker class with three methods, generate_chat, generate_action and output, handles the different Events
56 | - Following the dependency inversion principle, danmu Events and SC Events all depend on the abstract Event, and the worker depends on Event too
57 |
58 | Notes:
59 | - The consumer thread must always run; at least one producer thread must be enabled
61 | ## :microscope: Installation, Configuration and Usage
62 |
63 | ### Clone the project and install the Python dependencies
64 |
65 | ```
66 | git clone https://github.com/jiran214/GPT-vup.git
67 | cd src
68 | # Recommended: create and activate a virtual environment (command line or PyCharm)
69 | python -m pip install --upgrade pip
70 | # Dependency conflicts may occur and cannot be fully resolved
71 | pip install -r .\requirements.txt --no-deps
72 | ```
73 |
74 | ### Configure config
75 |
76 | Create a config.ini file under the src directory (all of the project's configuration lives there)
77 |
78 | ```ini
79 | [openai]
80 | api_key = sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
81 |
82 | [room]
83 | id=27661527
84 |
85 | [edge-tss]
86 | voice = zh-CN-XiaoyiNeural
87 | rate = +10%
88 | volume = +0%
89 |
90 | [other]
91 | debug = True
92 | proxy = 127.0.0.1:7890
93 |
94 | [plugin]
95 | action=False
96 | schedule=False
97 | speech=False
98 | ```
99 |
100 | **Notes:**
101 |
102 | - room-id is the live room number, i.e. the last part of a URL such as [BiliBili live](https://live.bilibili.com/27661527) (if you don't have a room number, pick any room for testing)
103 |
104 | - edge-tss holds the text-to-speech settings
105 |
106 | ### Install VTS (VTube Studio), get a VTS TOKEN and test the API
107 |
108 | - Installation and usage tutorials are easy to find online, including syncing mouth movement to audio; only the program side is covered here
109 | - The action plugin makes the VTuber match actions to viewers' interactions; it can be skipped
110 | - Set action to True in config.ini
111 | - Open VTS and enable its API switch
112 | - Rename each of the model's actions to a word that reflects the action or expression, otherwise matching is meaningless
113 | - Run >> `python ./main action`; pyvts will call the VTS API (note: VTS will show a confirmation dialog)
114 | - action.json is generated automatically
115 |
116 |
117 | Notes: how does the action plugin work?
118 |
119 | - In short, it responds to a viewer's danmu with a matching action: it first fetches the embedding of the danmu (or related text), then finds the vector in action_embeddings with the highest cosine similarity, i.e. the closest action, and responds with it.
120 | - Action matching doesn't have to use embeddings, and the actual results are only passable; embeddings were chosen with a later goal in mind: matching the user's input with more context. That context can come from anywhere (Tieba, Xiaohongshu...) as long as the vectors are generated in advance and saved to a vector database, which makes the AI streamer's answers richer.
121 | - For an introduction to OpenAI embeddings and what they do, see the OpenAI docs: [Embeddings - OpenAI API](https://platform.openai.com/docs/guides/embeddings)
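The matching step boils down to plain cosine similarity. The toy vectors and the `match_action` helper below are made up for illustration (real embeddings come from the OpenAI embedding API and have on the order of a thousand dimensions):

```python
import math

def cosine_similarity(a, b):
    # cos(a, b) = a·b / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

# toy action embeddings; the real ones are stored in action.json
action_embeddings = {
    'wave': [1.0, 0.0, 0.1],
    'nod':  [0.0, 1.0, 0.2],
}

def match_action(danmu_embedding):
    # pick the action whose embedding is closest to the danmu's embedding
    return max(action_embeddings,
               key=lambda name: cosine_similarity(danmu_embedding,
                                                  action_embeddings[name]))

assert match_action([0.9, 0.1, 0.0]) == 'wave'
```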
122 |
123 | ### Douyin live configuration (optional)
124 |
125 | - See [Douyin danmu capture and push: a Douyin danmu push service built on system-proxy packet capture](https://gitee.com/haodong108/dy-barrage-grab/tree/V2.6.5/BarrageGrab)
126 | - Open a room that is currently live and data capture begins
128 | ### Run
129 |
130 | - Option 1 (default): start from the command line with Google's fire library: make sure the fire.Fire(Management) line in main.py runs and comment out the rest
131 | - Option 2: run directly, uncommenting whichever calls you need
132 | ```python
133 | if __name__ == '__main__':
134 |     """Command-line start, equivalent to the calls below"""
135 |     # fire.Fire(Management)
136 |
137 |     """Test"""
138 |     # >> python main test
139 |     # Management().test()
140 |
141 |     """Start the program"""
142 |     # >> python main run bilibili
143 |     # Management().run('BiliBili')
144 |     # Management().run('DouYin')
145 |
146 |     """Initialize"""
147 |     # >> python main action
148 |     # Management().action()
149 | ```
150 | - It's best to run the test first to check the VPN, then start the program for real
151 | - `python main action` initializes the action plugin and can be skipped
152 |
153 | ### OBS
154 |
155 | Tutorials are available online
156 |
157 | ## :bulb: Pitfalls and Lessons Learned
158 |
159 | 1. With the openai library's acreate, occasional SSL errors still occur even with SSL turned off, probably a conflict in the library's underlying aiohttp usage; after switching to create the errors dropped noticeably
160 | 2. For VTS interaction, controlling it via the keyboard library was tried first, but VTS hotkeys don't work like other software's; actions can only be triggered through the API with pyvts
161 | 3. In this AI-streamer scenario, each dequeue-process-output cycle must be atomic; it's best not to process several danmu (Events) at once
162 | 4. Coroutines suit lightweight tasks; in other words a coroutine function shouldn't await too many things, or concurrency safety becomes hard to maintain
163 | 5. Every thread must create its own event loop
164 | 6. This project uses coroutines to decouple the producer and consumer stages. See also this article [Building an AI VTuber: reads danmu, quick-witted, emotions on display, in a simple implementation - Juejin (juejin.cn)](https://juejin.cn/post/7204742468612145209), which decouples with ports/processes and assembles everything in Go; AI-streamer pipelines are all broadly similar
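Point 5 can be shown in a minimal, self-contained sketch: `asyncio.get_event_loop()` cannot be relied on in non-main threads, so each thread builds and binds its own loop (the worker names here are illustrative):

```python
import asyncio
import threading

async def work(name):
    # stand-in for the real async work (TTS request, GPT call, ...)
    await asyncio.sleep(0.01)
    return name

def thread_target(name, results):
    # each thread creates and binds its own event loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        results[name] = loop.run_until_complete(work(name))
    finally:
        loop.close()

results = {}
threads = [threading.Thread(target=thread_target, args=(f'w{i}', results))
           for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert results == {'w0': 'w0', 'w1': 'w1'}
```

This is what the project's `NewEventLoop` helper abstracts away for the producer and consumer threads.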
165 |
166 | ## :page_with_curl: Changelog
167 |
168 | - 4.26 Douyin live support
169 | - 5.13 Prompt handling refactored with LangChain
170 | - 5.13 Actions and embeddings stored in config.json
171 | - 5.13 Command-line start via fire
172 | - 5.13 Pre-run test added
173 | - 5.13 Plugin system
174 | - 5.14 requirements updated; the bilibili_api library hasn't been updated, so dependencies can't be fully resolved. Please add --no-deps to pip install
175 |
176 | ## :phone: Contact Me
177 |
178 | Feel free to add me on WeChat: yuchen59384 to chat!
179 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # GPT-vup Live2D AI Streamer
2 | **(This repository is no longer maintained. Check out my new project Langup: https://github.com/jiran214/langup-ai, which already implements a live-stream digital human)**
3 |
4 | 
5 |
6 | ## Introduction
7 | **Real Virtual UP**
8 | Supports BiliBili and Douyin live rooms; built on a producer-consumer model, using OpenAI embeddings and the GPT-3.5 API
9 | (maintenance of this repository is paused)
10 |
11 | ### Features
12 | - Basics: answer danmu and Super Chats, welcome viewers entering the room, thank viewers for gifts
13 | - Plugins (off by default)
14 |   - speech: listen for the Ctrl+T hotkey; spoken input is converted to text for interacting with the AI VTuber
15 |   - action: match character actions to viewer behavior
16 |   - schedule: trigger an event at intervals, e.g. telling a story or rapping...
17 |   - context: add context to questions
18 |
19 | ## Installation
20 | ### Environment
21 | - Windows 10
22 | - Python 3.8
23 | - VPN with global proxy
24 | ### Install dependencies with pip
25 | ```shell
26 | git clone https://github.com/jiran214/GPT-vup.git
27 | cd src
28 | # Recommended: create and activate a virtual environment (command line or PyCharm) https://blog.csdn.net/xp178171640/article/details/115950985
29 | python -m pip install --upgrade pip
30 | pip install -r requirements.txt
31 | ```
32 | ### Create config.ini
33 | - Rename config.sample.ini to config.ini
34 | - Change api_key and proxy; everything else can be left alone
35 | - The settings are explained later in this document
36 | ### Test the network environment
37 | - In the src directory, run: >>`python manager.py test_net`
38 | ## Quick Start
39 | ### BiliBili live
40 | - Install the dependency: >>`pip install bilibili-api-python`
41 | - Change room -> id in config.ini to your own room number; any room works for a first try
42 | - In the src directory, run: >>`python manager.py run bilibili`
43 | ### Douyin live
44 | - See [Douyin danmu capture and push: a Douyin danmu push service built on system-proxy packet capture](https://gitee.com/haodong108/dy-barrage-grab/tree/V2.6.5/BarrageGrab)
45 | - Start that project
46 | - Open a currently-live Douyin room (web or desktop client) and data capture begins
47 | - In the src directory, run: >>`python manager.py run douyin`
48 | ### VTube Studio installation and setup
49 | - Download VTube Studio on Steam
50 | - Tutorial: https://www.bilibili.com/video/BV1nV4y1X7yJ?t=426.7
51 | - Important!!!
52 |   - Microphone setup: no virtual audio cable is needed. If Windows' default output device is Speaker Realtek(R) Audio, set the input device in VTS's microphone settings to Realtek(R) Audio as well.
53 |   - For lip sync, set the input parameter of mouthOpen to voice frequency or voice volume
54 |   - For a better live effect, explore further on your own
55 | ## Advanced
56 | ### speech plugin: voice interaction
57 | - Set config.ini -> plugin -> speech to True
58 | - Run >> `pip install pyaudio speech_recognition keyboard`
59 | - After the program starts, hold Ctrl+T and speak; your speech is transcribed automatically and the VTuber hears what you say
60 | ### schedule plugin: trigger an event at intervals, e.g. telling a story or rapping...
61 | - Set config.ini -> plugin -> schedule to True
62 | - The schedule_task_temple_list in utils/prompt_temple.py contains ready-made trigger events
63 | ### action plugin: VTS action and expression interaction
64 | Makes the VTuber match actions to viewer interactions
65 | - Set config.ini -> plugin -> action to True
66 | - Run >>`pip install pyvts`
67 | - Open VTS and enable its API switch
68 | - In VTS's expression settings, rename every action to a word describing the action or expression, otherwise matching is meaningless
69 | - In the src directory, run >> `python manager.py action`; pyvts will call the VTS API (note: VTS will show a confirmation dialog)
70 | - The program generates action.json automatically
71 | - Repeat these steps whenever you need to update the actions
72 | ### Experimental: context plugin: add context to conversations
73 | - Prerequisite 1: [install standalone Milvus 2.0 with Docker](https://milvus.io/docs/v2.0.x/install_standalone-docker.md) and set config.ini -> milvus -> host and port
74 | - Prerequisite 2: a MySQL instance, with config.ini -> mysql -> uri set
75 | - Set config.ini -> plugin -> context to True
76 | - Run >> `pip install pymilvus==2.0`
77 | - Adjust the parameters in scripts/manager.py, then run >> `python scripts/manager.py run` to crawl Tieba data into MySQL, process it, and push it to Milvus
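At its core, the context plugin's retrieval step is: embed the danmu, find the top-n nearest stored vectors, and prepend the matching texts to the prompt. A dependency-free sketch of that idea (the real code uses Milvus for vector search and MySQL for the texts; the names and toy vectors below are illustrative):

```python
import math

# toy document store: hash_id -> (embedding, text); real data lives in Milvus + MySQL
docs = {
    'a1': ([1.0, 0.0], 'Tieba post about cats'),
    'b2': ([0.0, 1.0], 'Tieba post about games'),
    'c3': ([0.92, 0.08], 'another cat post'),
}

def l2(a, b):
    # Euclidean (L2) distance, the metric the Milvus collection is configured with
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def top_n_context(query_embedding, n=2):
    # rank stored vectors by distance and return the n closest texts
    ranked = sorted(docs, key=lambda hid: l2(query_embedding, docs[hid][0]))
    return [docs[hid][1] for hid in ranked[:n]]

context = top_n_context([0.95, 0.05], n=2)
prompt = f"context: {context}\nquestion: ..."
assert context == ['another cat post', 'Tieba post about cats']
```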
78 | ### Other
79 | - system_template in utils/prompt_temple.py changes the VTuber's initial persona
80 | ## Changelog
81 | - V2.0 context plugin support, directory restructure, simpler readme, dependency mess resolved
82 | - V1.0 [legacy version](https://github.com/jiran214/GPT-vup/tree/1.0)
83 | ## Contact Me
84 |
85 | 
86 |
87 |
88 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/__init__.py
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: config.py
5 | @DateTime: 2023/4/7 14:30
6 | @SoftWare: PyCharm
7 | """
8 | import configparser
9 | import json
10 | import os.path
11 |
12 | root_path = os.path.abspath(os.path.dirname(__file__))
13 | file_path = os.path.join(root_path, 'config.ini')
14 |
15 | _config = configparser.RawConfigParser()
16 | _config.read(file_path)
17 |
18 | tss_settings = dict(_config.items('edge-tss'))
19 |
20 | api_key_list = [value for key, value in _config.items('openai') if key.startswith('api') and value]
21 | temperature = _config.get('openai', 'temperature')
22 | room_id = _config.getint('room', 'id')
23 | mysql = dict(_config.items('mysql'))
24 | sqlite = dict(_config.items('sqlite'))
25 | milvus = dict(_config.items('milvus'))
26 | debug = _config.getboolean('other', 'debug')
27 | proxy = _config.get('other', 'proxy')
28 |
29 | action_plugin = _config.getboolean('plugin', 'action')
30 | schedule_plugin = _config.getboolean('plugin', 'schedule')
31 | speech_plugin = _config.getboolean('plugin', 'speech')
32 | context_plugin = _config.getboolean('plugin', 'context')
33 |
34 | try:
35 |     live2D_actions = []
36 |     live2D_embeddings = []
37 |     if action_plugin:
38 |         with open("./action.json", 'r') as load_f:
39 |             live2D_action_dict = json.load(load_f)
40 |         live2D_actions = list(live2D_action_dict.keys())  # list, so it can be indexed
41 |         live2D_embeddings = [live2D_action_dict[action] for action in live2D_actions]
42 |         assert live2D_embeddings
43 | except Exception as e:
44 |     print('Failed to read the embedding file. Check that action.json was generated locally and its actions are not empty; run python manager.py action before using the action plugin.', e)
45 |
46 |
47 | with open(os.path.join(root_path, 'forbidden_words.txt'), mode='r', encoding='utf-8') as f:
48 |     keyword_str_list = [line.strip() for line in f.readlines()]
49 |
50 | if __name__ == '__main__':
51 |     print(api_key_list, proxy)
--------------------------------------------------------------------------------
/src/config.sample.ini:
--------------------------------------------------------------------------------
1 | [openai]
2 | api_key = xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
3 | temperature = 0.9
4 |
5 | [room]
6 | id=732
7 |
8 | [edge-tss]
9 | voice = zh-CN-XiaoyiNeural
10 | rate = +10%
11 | volume = +0%
12 |
13 | [other]
14 | debug = True
15 | proxy = 127.0.0.1:7890
16 |
17 | [plugin]
18 | action=False
19 | schedule=False
20 | speech=False
21 | context=False
22 |
23 | [mysql]
24 | uri = mysql+pymysql://root:123456@localhost/vup?charset=utf8mb4
25 |
26 | [sqlite]
27 | uri = xxxx
28 |
29 | [milvus]
30 | host = xxxxxx
31 | port = 19530
32 | collection = sun_ba
33 | top_n = 4
34 |
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/core/__init__.py
--------------------------------------------------------------------------------
/src/core/main.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: main.py
5 | @DateTime: 2023/4/23 13:19
6 | @SoftWare: PyCharm
7 | """
8 | import threading
9 | import time
10 | import schedule
11 |
12 |
13 | from src import config
14 | from src.utils.prompt_temple import get_schedule_task
15 | from src.utils.events import UserEvent
16 | from src.utils.utils import worker_logger
17 | from src.rooms.bilibili import BlLiveRoom
18 | from src.rooms.douyin import dy_connect
19 |
20 | from src.utils.utils import user_queue, NewEventLoop
21 | from src.core.vup import VtuBer
22 | from bilibili_api import sync
23 |
24 |
25 | logger = worker_logger
26 |
27 |
28 | # Define the producer function
29 | def bl_producer():
30 |     r = BlLiveRoom()
31 |     r.connect()
32 |
33 |
34 | def dy_producer():
35 |     t_loop = NewEventLoop()
36 |     t_loop.run(dy_connect())
38 |
39 | class UserProducer:
40 |
41 |     def __init__(self):
42 |         self.run_funcs = []
43 |
44 |     def send_user_event_2_queue(self, task):
45 |         if user_queue.event_queue.empty():
46 |             ue = UserEvent(*task)
47 |             ue.is_high_priority = True
48 |             # ue.action = live2D_actions.index('Anim Shake')
49 |             user_queue.send(ue)
50 |
51 |     def create_schedule(self):
52 |         # Delay startup
53 |         time.sleep(30)
54 |         # Clear existing jobs
55 |         schedule.clear()
56 |         # Schedule a job to run every 5 minutes
57 |         schedule.every(5).minutes.do(
58 |             self.send_user_event_2_queue, get_schedule_task()
59 |         )
60 |         return schedule
61 |
62 |     def run(self):
63 |         if config.schedule_plugin:
64 |             schedule_obj = self.create_schedule()
65 |             self.run_funcs.append(schedule_obj.run_pending)
66 |         if config.speech_plugin:
67 |             try:
68 |                 from src.modules.speech_rec import speech_hotkey_listener
69 |             except ImportError:
70 |                 raise ImportError('Please run pip install pyaudio speech_recognition keyboard')
71 |             # self.run_funcs.append(speech_hotkey_listener)
72 |             speech_hotkey_listener()
73 |         if self.run_funcs:
74 |             self.run_funcs.append(lambda: time.sleep(2))
75 |             while True:
76 |                 for run_fun in self.run_funcs:
77 |                     run_fun()
78 |
79 |
80 | # Define the consumer function
81 | def consumer():
82 |     while True:
83 |         t0 = time.time()
84 |         event = user_queue.recv()
85 |         if not event:
86 |             # Both queues are empty, wait for new items to be added
87 |             time.sleep(1)
88 |             logger.debug('consumer waiting')
89 |             continue
90 |
91 |         worker = VtuBer(event)
92 |         try:
93 |             worker.run()
94 |             logger.debug(f'worker elapsed: {time.time() - t0}')
95 |         except Exception as e:
96 |             raise e
97 |             # logger.error(e)
98 |             # time.sleep(20)
100 |
101 | def start_thread(worker_name):
102 |     worker_map = {
103 |         'bl_producer': bl_producer,
104 |         'dy_producer': dy_producer,
105 |         'user_producer': UserProducer().run,
106 |         'consumer': consumer
107 |     }
108 |     if worker_name not in worker_map:
109 |         raise ValueError(f'worker {worker_name} does not exist')
110 |
111 |     thread = threading.Thread(target=worker_map[worker_name])
112 |     thread.start()
113 |     return thread
114 |
115 |
116 | if __name__ == '__main__':
117 |     # bl_producer()
118 |     t = start_thread('bl_producer')
119 |     start_thread('consumer')
120 |     t.join()
121 |     # time.sleep(10000)
--------------------------------------------------------------------------------
/src/core/vup.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: process.py
5 | @DateTime: 2023/4/22 22:17
6 | @SoftWare: PyCharm
7 | """
8 | import asyncio
9 | import threading
10 | import time
11 | from langchain.chat_models import ChatOpenAI
12 | from bilibili_api import sync
13 |
14 | from src import config
15 | from src.config import live2D_embeddings, keyword_str_list
16 | from src.db.milvus import VectorStore
17 | from src.db.models import TieBa
18 | from src.db.dao import get_session
19 | from src.modules.actions import play_action
20 | from src.modules.audio import tts_save, play_sound
21 | from src.utils.dfa import DFA
22 | from src.utils.events import BlDanmuMsgEvent
23 | from src.utils.utils import worker_logger, sync_get_embedding, get_openai_key
24 | from src.utils.utils import Event
25 | from src.utils.utils import audio_lock, NewEventLoop, top_n_indices_from_embeddings
26 |
27 | logger = worker_logger
28 |
29 | base_path = './static/voice/{}.mp3'
30 |
31 |
32 | class VtuBer:
33 |     dfa = DFA(keyword_str_list)
34 |
35 |     def __init__(self, event: Event):
36 |         self.event = event
37 |         self.sound_path = base_path.format(int(time.time()))
38 |
39 |     async def generate_chat(self, embedding):
40 |         # Extra parameters
41 |         extra_kwargs = {}
42 |         # Only danmu get extra context
43 |         if config.context_plugin and isinstance(self.event, BlDanmuMsgEvent):
44 |             ids = VectorStore(config.milvus['collection']).search_top_n_from_milvus(int(config.milvus['top_n']), embedding)[0].ids
45 |             with get_session() as s:
46 |                 rows = s.query(TieBa).filter(TieBa.hash_id.in_(str(hash_id) for hash_id in ids)).all()
47 |                 context = [row.content for row in rows]
48 |             extra_kwargs['context'] = str(context)
49 |         # Call GPT
50 |         messages = self.event.get_prompt_messages(**extra_kwargs)
51 |         logger.info(f"prompt:{messages[1]} requesting gpt")
52 |         chat = ChatOpenAI(temperature=config.temperature, max_retries=2, max_tokens=150,
53 |                           openai_api_key=get_openai_key())
54 |         llm_res = chat.generate([messages])
55 |         assistant_content = llm_res.generations[0][0].text
56 |         logger.info(f'assistant_content:{assistant_content}')
57 |         # Forbidden-word check
58 |         dfa_match_list = self.dfa.match(assistant_content)
59 |         forbidden_words = [forbidden_word['match'] for forbidden_word in dfa_match_list]
60 |         if dfa_match_list:
61 |             logger.warning(f'Contains forbidden words: {forbidden_words}, skipping this TTS generation')
62 |             return False
63 |         # Use Edge TTS to synthesize the reply audio file
64 |         logger.debug("Generating the TTS file")
65 |         t0 = time.time()
66 |         await tts_save(self.event.get_audio_txt(assistant_content), self.sound_path)
67 |         logger.debug(f"tts request took: {time.time()-t0}")
68 |
69 |     async def generate_action(self, embedding):
70 |         if isinstance(self.event.action, str):
71 |             # Was it set manually?
72 |             logger.debug("Generating the action")
73 |             t0 = time.time()
74 |             # Match an action
75 |             self.event.action = int(top_n_indices_from_embeddings(embedding, live2D_embeddings, top=1)[0])
76 |             logger.debug(f"action request took: {time.time()-t0}")
77 |
78 |     async def output(self):
79 |         logger.debug(f'path:{self.sound_path} about to play audio and action')
80 |         while audio_lock.locked():
81 |             await asyncio.sleep(1)
82 |         # Gap between utterances
83 |         time.sleep(0.5)
84 |         # Play the audio
85 |         play_sound_thread = threading.Thread(target=play_sound, args=(self.sound_path,))
86 |         play_sound_thread.start()
87 |         # Play the action
88 |         if config.action_plugin and isinstance(self.event.action, int):
89 |             await play_action(self.event.action)
90 |         # play_sound_thread.join()
91 |         # time.sleep(5)
93 |
94 |     async def _run(self):
95 |         # Get the embedding
96 |         str_tuple = ('text', 'content', 'message', 'user_name')
97 |         prompt_kwargs = self.event.prompt_kwargs.copy()
98 |         embedding_str = None
99 |         for key in str_tuple:
100 |             if key in prompt_kwargs:
101 |                 embedding_str = prompt_kwargs[key]
102 |                 break
103 |         if not embedding_str:
104 |             raise ValueError('embedding_str should never be empty')
105 |         embedding = sync_get_embedding([embedding_str])
106 |         tasks = [asyncio.create_task(self.generate_chat(embedding))]
107 |         if config.action_plugin and self.event.action:
108 |             tasks.append(asyncio.create_task(self.generate_action(embedding)))
109 |         state = await asyncio.gather(*tasks)
110 |         if state[0] is not False:
111 |             await self.output()
112 |
113 |     def run(self):
114 |         # t_loop = NewEventLoop()
115 |         # t_loop.run(self._run())
116 |         sync(self._run())
117 |
118 |
119 |
120 | if __name__ == '__main__':
121 |     res = sync_get_embedding(['embedding_str'])
122 |     print(res)
123 |     logger.debug('123')
--------------------------------------------------------------------------------
/src/db/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/db/__init__.py
--------------------------------------------------------------------------------
/src/db/dao.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | from sqlalchemy.ext.declarative import declarative_base
4 | from sqlalchemy.orm import sessionmaker
5 | from sqlalchemy import (
6 |     create_engine
7 | )
8 | from src import config  # the config module holds this project's settings; if you swap it out, update the places below that use config as well
9 |
10 | try:
11 |     engine = create_engine(
12 |         config.sqlite['uri'] or config.mysql['uri'],  # SQLAlchemy database URI
13 |         # echo=bool(config.SQLALCHEMY_ECHO),  # whether to print executed SQL, usually for debugging
14 |         # pool_size=int(config.SQLALCHEMY_POOL_SIZE),  # connection pool size
15 |         # max_overflow=int(config.SQLALCHEMY_POOL_MAX_SIZE),  # maximum pool overflow
16 |         # pool_recycle=int(config.SQLALCHEMY_POOL_RECYCLE),  # seconds before a connection is recycled
17 |     )
18 |     Session = sessionmaker(bind=engine)
19 |     Base = declarative_base(bind=engine)
20 | except Exception:
21 |     engine = None
22 |     Session = None
23 |     Base = object
24 |
25 |
26 | @contextlib.contextmanager
27 | def get_session():
28 |     s = Session()
29 |     try:
30 |         yield s
31 |         s.commit()
32 |     except Exception as e:
33 |         s.rollback()
34 |         raise e
35 |     finally:
36 |         s.close()
39 |
--------------------------------------------------------------------------------
/src/db/milvus.py:
--------------------------------------------------------------------------------
1 | from src import config
2 | from src.utils.init import initialize_openai
3 | from src.utils.utils import sync_get_embedding
4 |
5 |
6 | class VectorStore:
7 |     def __init__(self, name):
8 |         try:
9 |             from pymilvus import Collection, connections
10 |         except ImportError:
11 |             raise ImportError('Please run pip install pymilvus==2.0')
12 |
13 |         connections.connect(
14 |             alias="default",
15 |             host=config.milvus['host'],
16 |             port=config.milvus['port']
17 |         )
18 |
19 |         self.collection = Collection(name)  # Get an existing collection.
20 |         # num_entities = self.collection.num_entities
21 |         self.collection.load()
22 |         self.search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
23 |
24 |     def search_top_n_from_milvus(self, limit, embedding):
25 |         results = self.collection.search(
26 |             data=[embedding],
27 |             anns_field="embedding",
28 |             param=self.search_params,
29 |             limit=limit,
30 |             expr=None,
31 |             consistency_level="Strong",
32 |             # output_fields='hash_id'
33 |         )
34 |         return results
35 |
36 |
37 | if __name__ == '__main__':
38 |     initialize_openai()
39 |     # print()
40 |     ids = VectorStore('sun_ba').search_top_n_from_milvus(3, sync_get_embedding(texts=['hello']))[0].ids
--------------------------------------------------------------------------------
/src/db/models.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Column, String, Boolean, Integer
2 |
3 | from src.db.dao import Base, engine, get_session
4 |
5 |
6 | # class Document(Base):
7 | #     __tablename__ = ''
8 | #     id = Column(Integer, primary_key=True)
9 | #     hash_id = Column(String(30), nullable=False, unique=True)
10 | #     content = Column(String(200), nullable=False)
11 | #     embedding_state = Column(Boolean, default=False)
12 |
13 |
14 | class TieBa(Base):
15 |     __tablename__ = "tie_ba"
16 |     id = Column(Integer, primary_key=True)
17 |     hash_id = Column(String(30), nullable=False, unique=True)
18 |     content = Column(String(200), nullable=False)
19 |     embedding_state = Column(Boolean, default=False)
20 |     tid = Column(String(20), nullable=False)
21 |
22 |     def __repr__(self):
23 |         return f'{self.hash_id} {self.content}'
--------------------------------------------------------------------------------
/src/forbidden_words.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/forbidden_words.txt
--------------------------------------------------------------------------------
/src/manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | path = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.insert(0, path)
6 | sys.path.insert(1, os.path.dirname(path))
7 |
8 | import fire
9 | from src import config
10 | from src.core.main import start_thread
11 | from src.utils.init import initialize_action, initialize_openai
12 | from src.utils.log import worker_logger
13 | from src.utils.utils import NewEventLoop, get_openai_key
14 |
15 |
16 | initialize_openai()
17 | logger = worker_logger
18 |
19 |
20 | class Management:
21 |     def __init__(self):
22 |         if not config.api_key_list:
23 |             raise ValueError('Please fill in openai -> api_key!')
24 |
25 |     def action(self):
26 |         loop = NewEventLoop()
27 |         loop.run(initialize_action())
28 |
29 |     def run(self, name):
30 |         tasks = []
31 |         self.test_plugin_dependency()
32 |         if name.lower() == 'douyin':
33 |             tasks.append(start_thread('dy_producer'))
34 |         elif name.lower() == 'bilibili':
35 |             tasks.append(start_thread('bl_producer'))
36 |         tasks.append(start_thread('user_producer'))
37 |         tasks.append(start_thread('consumer'))
38 |         for task in tasks:
39 |             task.join()
40 |
41 |     def test_net(self):
42 |         from langchain import OpenAI
43 |         import requests
44 |         # Test access to the open internet (may fail)
45 |         r = requests.get(url='https://www.youtube.com/', verify=False, proxies={
46 |             'http': f'http://{config.proxy}/',
47 |             'https': f'http://{config.proxy}/'
48 |         })
49 |         assert r.status_code == 200
50 |         # Test the openai library
51 |         llm = OpenAI(temperature=config.temperature, openai_api_key=get_openai_key(), verbose=config.debug)
52 |         text = "Python is the best language in the world "
53 |         print(llm(text))
54 |         print('Test passed!')
55 |
56 |     def test_plugin_dependency(self):
57 |         if config.context_plugin:
58 |             try:
59 |                 from pymilvus import connections, has_collection, Collection
60 |                 import cryptography
61 |             except ImportError:
62 |                 raise ImportError('Please run pip install pymilvus==2.0 cryptography parsel')
63 |
64 |             connections.connect(
65 |                 alias="default",
66 |                 host=config.milvus['host'],
67 |                 port=config.milvus['port']
68 |             )
69 |             assert has_collection(config.milvus['collection'])
70 |             collection = Collection(config.milvus['collection'])
71 |             assert collection.num_entities != 0
72 |
73 |             logger.info('context plugin enabled')
74 |         if config.speech_plugin:
75 |             try:
76 |                 from src.modules.speech_rec import speech_hotkey_listener
77 |             except ImportError:
78 |                 raise ImportError('Please run pip install pyaudio speech_recognition keyboard')
79 |             logger.info('speech plugin enabled')
80 |         if config.action_plugin:
81 |             try:
82 |                 import pyvts
83 |             except ImportError:
84 |                 raise ImportError('Please run pip install pyvts, then run python manager.py action')
85 |             logger.info('action plugin enabled')
86 |         if config.schedule_plugin:
87 |             logger.info('schedule plugin enabled')
88 |
89 |
90 | if __name__ == '__main__':
91 |     """Command-line start, equivalent to the calls below"""
92 |     fire.Fire(Management)
93 |
94 |     """Test"""
95 |     # >> python manager.py test_net
96 |     # Management().test_net()
97 |
98 |     """Start the program"""
99 |     # >> python manager.py run bilibili
100 |     # Management().run('BiliBili')
101 |     # Management().run('DouYin')
102 |
103 |     """Initialize"""
104 |     # >> python manager.py action
105 |     # Management().action()
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "xxx"
4 |
--------------------------------------------------------------------------------
/src/modules/actions.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: action.py
5 | @DateTime: 2023/4/24 0:01
6 | @SoftWare: PyCharm
7 | """
8 | from src.config import live2D_actions
9 |
10 | plugin_info = {
11 |     "plugin_name": "start pyvts",
12 |     "developer": "Jiran",
13 |     "authentication_token_path": "./token.txt"
14 | }
15 |
16 |
17 | async def play_action(action_index):
18 |     try:
19 |         import pyvts
20 |     except ImportError:
21 |         raise ImportError('Please run pip install pyvts')
22 |     vts = pyvts.vts(plugin_info=plugin_info)
23 |     await vts.connect()
24 |     await vts.read_token()
25 |     await vts.request_authenticate()  # use token
26 |
27 |     if action_index > len(live2D_actions) - 1:
28 |         raise IndexError('action does not exist')
29 |     send_hotkey_request = vts.vts_request.requestTriggerHotKey(live2D_actions[action_index])
30 |     await vts.request(send_hotkey_request)
31 |     await vts.close()
32 |
33 |
34 | if __name__ == "__main__":
35 |     action_embeddings = None  # fetching the token doesn't require running these
36 |     # asyncio.run(initialize_action())
37 |     # asyncio.run(play_action())
--------------------------------------------------------------------------------
/src/modules/audio.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: audio.py
5 | @DateTime: 2023/4/24 14:16
6 | @SoftWare: PyCharm
7 | """
8 |
9 | import edge_tts
10 | from src import config
11 |
12 | from pygame import mixer, time as pygame_time
13 |
14 | from src.utils.utils import worker_logger
15 | from src.utils.utils import audio_lock
16 |
17 | logger = worker_logger
18 |
19 |
20 | async def tts_save(text, path):
21 |     # tts = edge_tts.Communicate(text=text, proxy=f'http://{config.proxy}', **config.tss_settings)
22 |     tts = edge_tts.Communicate(text=text, **config.tss_settings)
23 |     await tts.save(path)
24 |
25 |
26 | def play_sound(file_path):
27 |     with audio_lock:
28 |         # Play the generated audio file
29 |         mixer.init()
30 |         mixer.music.load(file_path)
31 |         mixer.music.play()
32 |         while mixer.music.get_busy():
33 |             pygame_time.Clock().tick(10)
34 |
35 |         mixer.music.stop()
36 |         mixer.quit()
--------------------------------------------------------------------------------
/src/modules/speech_rec.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: speech_recognition.py
5 | @DateTime: 2023/5/13 18:21
6 | @SoftWare: PyCharm
7 | """
8 | import threading
9 | import time
10 |
11 | import wave
12 | import pyaudio
13 | import speech_recognition as sr
14 | import keyboard
15 |
16 | from src.utils.events import UserEvent
17 | from src.utils.log import worker_logger
18 | from src.utils.utils import user_queue
19 |
20 | CHUNK = 1024
21 | FORMAT = pyaudio.paInt16
22 | CHANNELS = 1
23 | RATE = 16000
24 | RECORD_SECONDS = 5
25 | WAVE_OUTPUT_FILENAME = "../static/speech/{}.wav"
26 |
27 | logger = worker_logger
28 |
29 |
30 | def speech_recognition_task():
31 | r = sr.Recognizer()
32 |
33 |     # record audio with PyAudio
34 | audio = pyaudio.PyAudio()
35 | stream = audio.open(format=FORMAT, channels=CHANNELS,
36 | rate=RATE, input=True,
37 | frames_per_buffer=CHUNK)
38 | frames = []
39 | print("正在录音...")
40 | for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
41 | data = stream.read(CHUNK)
42 | frames.append(data)
43 |         # if keyboard.is_pressed('ctrl+t'):  # press Ctrl+T to stop recording
44 | # break
45 | print("录音结束")
46 |
47 | stream.stop_stream()
48 | stream.close()
49 | audio.terminate()
50 |
51 |     # write the recorded audio to a WAV file; format the path only once
52 |     wav_path = WAVE_OUTPUT_FILENAME.format(int(time.time()))
53 |     waveFile = wave.open(wav_path, 'wb')
54 |     waveFile.setnchannels(CHANNELS)
55 |     waveFile.setsampwidth(audio.get_sample_size(FORMAT))
56 |     waveFile.setframerate(RATE)
57 |     waveFile.writeframes(b''.join(frames))
58 |     waveFile.close()
59 |
60 |     try:
61 |         # convert the audio to text with the SpeechRecognition library (very slow)
62 |         audio_file = sr.AudioFile(wav_path)
62 |
63 | with audio_file as source:
64 | audio_data = r.record(source)
65 | logger.debug('正在请求语音转文字接口recognize_google...')
66 | text = r.recognize_google(audio_data, language='zh-CN')
67 | print('识别到的文字:', text)
68 | ue = UserEvent(text, '{}')
69 | ue.is_high_priority = True
70 | # ue.action = live2D_actions.index('Anim Shake')
71 | user_queue.send(ue)
72 | except sr.UnknownValueError:
73 | print("无法识别语音")
74 | except sr.RequestError as e:
75 | print(f"语音识别服务出错:{e}")
76 |
77 |
78 | def speech_hotkey_listener():
79 | keyboard.add_hotkey('ctrl+t', speech_recognition_task)
80 | keyboard.wait()
81 |
82 |
83 | if __name__ == '__main__':
84 | t = threading.Thread(target=speech_hotkey_listener)
85 | t.start()
86 | # speech_hotkey_listener()
87 |
--------------------------------------------------------------------------------
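One pitfall in `speech_recognition_task` above is formatting `WAVE_OUTPUT_FILENAME` twice with `int(time.time())`: across a second boundary the writer and the recognizer can end up with two different paths. A small stdlib sketch of the write-once/read-back flow with the path computed a single time (path and audio parameters are illustrative):

```python
import os
import tempfile
import time
import wave

# Compute the output path exactly once so the writer and the reader below
# refer to the same file.
wav_path = os.path.join(tempfile.gettempdir(), "demo_{}.wav".format(int(time.time())))

CHANNELS, SAMPLE_WIDTH, RATE = 1, 2, 16000
frames = [b"\x00\x00"] * RATE  # one second of 16-bit silence

with wave.open(wav_path, "wb") as wf:
    wf.setnchannels(CHANNELS)
    wf.setsampwidth(SAMPLE_WIDTH)
    wf.setframerate(RATE)
    wf.writeframes(b"".join(frames))

with wave.open(wav_path, "rb") as rf:
    n_frames = rf.getnframes()
print(n_frames)  # 16000
```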
/src/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/requirements.txt
--------------------------------------------------------------------------------
/src/rooms/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: __init__.py
5 | @DateTime: 2023/4/25 21:55
6 | @SoftWare: PyCharm
7 | """
8 |
--------------------------------------------------------------------------------
/src/rooms/bilibili.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: room.py
5 | @DateTime: 2023/4/22 21:23
6 | @SoftWare: PyCharm
7 | """
8 | import asyncio
9 |
10 | from bilibili_api import sync
11 |
12 | from src import config
13 | from src.utils.events import BlDanmuMsgEvent, BlSendGiftEvent, BlSuperChatMessageEvent, BlInteractWordEvent
14 | from src.utils.utils import worker_logger
15 |
16 | from src.utils.utils import user_queue
17 |
18 |
19 | logger = worker_logger
20 |
21 |
22 | class BlLiveRoom:
23 | def __init__(self, bl_room_id=config.room_id):
24 | try:
25 | from bilibili_api import live, sync
26 | except ImportError:
27 |             raise ImportError('Please run pip install bilibili-api-python')
28 | self.room = live.LiveDanmaku(bl_room_id)
29 | self.add_event_listeners()
30 |
31 | def add_event_listeners(self):
32 | listener_map = {
33 | 'DANMU_MSG': on_danmaku_event_filter,
34 | 'SUPER_CHAT_MESSAGE': on_super_chat_message_event_filter,
35 | 'SEND_GIFT': on_gift_event_filter,
36 | 'INTERACT_WORD': on_interact_word_event_filter,
37 | }
38 | for item in listener_map.items():
39 | self.room.add_event_listener(*item)
40 |
41 | def connect(self):
42 | # loop = asyncio.get_event_loop()
43 | # loop.run_until_complete(self.room.connect())
44 | sync(self.room.connect())
45 |
46 |
47 | async def on_danmaku_event_filter(event_dict):
48 |     # danmaku message received
49 | # event = BlDanmuMsgEvent.filter(event_dict)
50 | event = BlDanmuMsgEvent(event_dict)
51 | user_queue.send(event)
52 |
53 |
54 | async def on_super_chat_message_event_filter(event_dict):
55 | # SUPER_CHAT_MESSAGE
56 | # info = event['data']['data']
57 | # user_info = info['user_info']
58 | # print('SUPER_CHAT_MESSAGE',
59 | # user_info['uname'],
60 | # user_info['face'],
61 | #
62 | # info['message'],
63 | # info['price'],
64 | # info['start_time'],
65 | # )
66 | event = BlSuperChatMessageEvent(event_dict)
67 | user_queue.send(event)
68 |
69 |
70 | async def on_gift_event_filter(event_dict):
71 |     # gift received
72 | # info = event_dict['data']['data']
73 | # print('SEND_GIFT',
74 | # info['face'],
75 | # info['uname'],
76 | # info['action'],
77 | # info['giftName'],
78 | # info['timestamp'],
79 | # )
80 | event = BlSendGiftEvent(event_dict)
81 | user_queue.send(event)
82 |
83 |
84 | async def on_interact_word_event_filter(event_dict):
85 | # INTERACT_WORD
86 | # info = event_dict['data']['data']
87 | # fans_medal = info['fans_medal']
88 | # print('INTERACT_WORD',
89 | # fans_medal['medal_name'],
90 | # fans_medal['medal_level'],
91 | # info['uname'],
92 | # info['timestamp']
93 | # )
94 | if not user_queue.event_queue.full():
95 | event = BlInteractWordEvent(event_dict)
96 | user_queue.send(event)
97 |
98 | # @room.on('WELCOME')
99 | # async def on_welcome_event_filter(event):
100 | #     # a VIP entered the room
101 | # print(3, event)
102 | #
103 | #
104 | # @room.on('WELCOME_GUARD')
105 | # async def on_welcome_guard_event_filter(event):
106 | #     # a moderator entered the room
107 | # print(4, event)
108 |
109 | if __name__ == '__main__':
110 | r = BlLiveRoom()
111 | r.connect()
--------------------------------------------------------------------------------
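`add_event_listeners` above registers one handler per event name by unpacking `listener_map.items()`. A toy stand-in for `live.LiveDanmaku` shows the same dict-driven dispatch (all names here are hypothetical, not the bilibili-api API):

```python
# A toy event bus with the same registration surface as add_event_listener,
# showing how the listener_map dict fans events out to per-type handlers.
class ToyRoom:
    def __init__(self):
        self._listeners = {}

    def add_event_listener(self, name, handler):
        self._listeners.setdefault(name, []).append(handler)

    def emit(self, name, event_dict):
        for handler in self._listeners.get(name, []):
            handler(event_dict)

received = []
listener_map = {
    'DANMU_MSG': lambda e: received.append(('danmu', e['text'])),
    'SEND_GIFT': lambda e: received.append(('gift', e['text'])),
}
room = ToyRoom()
for item in listener_map.items():
    room.add_event_listener(*item)  # same unpacking idiom as above

room.emit('DANMU_MSG', {'text': 'hello'})
room.emit('SEND_GIFT', {'text': 'rose'})
print(received)  # [('danmu', 'hello'), ('gift', 'rose')]
```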
/src/rooms/douyin.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: douyin.py
5 | @DateTime: 2023/4/25 21:56
6 | @SoftWare: PyCharm
7 | """
8 |
9 | import aiohttp
10 |
11 | import json
12 |
13 |
14 | from src.utils.events import DyDanmuMsgEvent, DyAttentionEvent, DySendGiftEvent, DyCkEvent, DyWelcomeWordEvent
15 | from src.utils.utils import user_queue
16 |
17 |
18 | def msg(event_dict):
19 | event_dict['Data'] = json.loads(event_dict['Data'])
20 | event = DyDanmuMsgEvent(event_dict)
21 | user_queue.send(event)
22 |
23 |
24 | def ck(event_dict): # type2
25 | if not user_queue.event_queue.full():
26 | event_dict['Data'] = json.loads(event_dict['Data'])
27 | event = DyCkEvent(event_dict)
28 | # print("感谢" + load_data.get("Content"))
29 | user_queue.send(event)
30 |
31 |
32 | def welcome(event_dict): # type3
33 | if not user_queue.event_queue.full():
34 | event_dict['Data'] = json.loads(event_dict['Data'])
35 | event = DyWelcomeWordEvent(event_dict)
36 | # print("欢迎:" + json2["Nickname"])
37 | user_queue.send(event)
38 |
39 |
40 | def Gift(event_dict): # type5
41 | event_dict['Data'] = json.loads(event_dict['Data'])
42 | event = DySendGiftEvent(event_dict)
43 | user_queue.send(event)
44 |
45 |
46 | def attention(event_dict):
47 | event_dict['Data'] = json.loads(event_dict['Data'])
48 | event = DyAttentionEvent(event_dict)
49 | user_queue.send(event)
50 |
51 |
52 | async def dy_connect():
53 | session = aiohttp.ClientSession()
54 | async with session.ws_connect("ws://127.0.0.1:8888") as ws:
55 | await ws.send_str('token')
56 | async for message in ws:
57 | # print(f"Received message: {message}")
58 |             # handle the websocket message
59 | if message.type == aiohttp.WSMsgType.TEXT:
60 | data = json.loads(message.data)
61 |                 event_name = data.get("Type")  # message type tag
62 |                 filter_map = {
63 |                     1: msg,        # 1: chat message
64 |                     2: ck,         # 2: like
65 |                     3: welcome,    # 3: user entered the room
66 |                     4: attention,  # 4: follow
67 |                     5: Gift,       # 5: gift
68 |                     # 6: check,    # 6: viewer-count update
69 | }
70 | if event_name in filter_map:
71 | filter_map[event_name](data)
72 |
73 |             elif message.type == aiohttp.WSMsgType.CLOSED:
74 |                 break
75 |             elif message.type == aiohttp.WSMsgType.ERROR:
76 |                 break
77 |
78 |     await session.close()
79 |
80 |
81 | if __name__ == '__main__':
82 | from bilibili_api import sync
83 | sync(dy_connect())
84 |
--------------------------------------------------------------------------------
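Every handler above first decodes the nested `Data` payload, which the relay delivers as a JSON string inside the outer JSON message. A small sketch of that normalization step (field names follow the handlers above; the sample payload is made up):

```python
import json

def normalize(event_dict):
    # The relay delivers the payload as a JSON string under 'Data';
    # decode it in place before building an event, as msg/ck/Gift do.
    if isinstance(event_dict.get('Data'), str):
        event_dict['Data'] = json.loads(event_dict['Data'])
    return event_dict

raw = {'Type': 1, 'Data': json.dumps({'Content': 'hi', 'User': {'Nickname': 'a'}})}
evt = normalize(raw)
print(evt['Data']['Content'])  # hi
```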
/src/scripts/crawlers/tie_ba_spider.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | try:
4 | from parsel import Selector
5 | import pymysql
6 | except ImportError:
7 |     raise ImportError('Please run pip install parsel pymysql cryptography')
8 |
9 |
10 | from src.db.dao import get_session
11 | from src.db.models import TieBa
12 |
13 | pymysql.install_as_MySQLdb()
14 |
15 | # connecting to a MySQL database with user and password
16 |
17 | session = requests.session()
18 | prefix_url = 'https://tieba.baidu.com'
19 | ba_url = 'https://tieba.baidu.com/f?kw={kw}&ie=utf-8&tp=0&pn={pn}'
20 | tie_url = 'https://tieba.baidu.com/p/{tie_id}?pn={pn}'
21 |
22 |
23 | class Ba:
24 |
25 | def __init__(self, kw, max_ba_pn, max_tie_pn, start_ba_pn=1, start_tie_pn=1, max_str_length=400):
26 | self.kw = kw
27 | self.start_ba_pn = start_ba_pn
28 | self.start_tie_pn = start_tie_pn
29 | self.max_ba_pn = max_ba_pn
30 | self.max_tie_pn = max_tie_pn
31 | self.max_str_length = max_str_length
32 |
33 | def get_tie_list(self, pn):
34 | r = session.get(ba_url.format(kw=self.kw, pn=pn))
35 | sl = Selector(text=r.text)
36 |
37 | href_list = sl.xpath("""//div[@class='threadlist_title pull_left j_th_tit ']//a/@href""").getall()
38 | tie_id_list = [href.split('/')[-1] for href in href_list]
39 | yield from tie_id_list
40 |
41 | def get_tie_detail(self, tie_id, pn):
42 | r = session.get(tie_url.format(tie_id=tie_id, pn=pn))
43 | sl = Selector(text=r.text)
44 | tie_item = sl.xpath("""//div[@class='d_post_content_main ']""")
45 |
46 | data_list = []
47 | for tie in tie_item:
48 | data_dict = {
49 | 'tid': tie_id,
50 | 'hash_id': '1',
51 | 'content': '',
52 | # 'embedding': None
53 | }
54 | content_list = tie.xpath(
55 | """.//div[@class='d_post_content j_d_post_content ']//text()""").getall()
56 | content = ' '.join(content_list).strip().replace(' ', '')
57 | data_dict['hash_id'] = str(hash(content))
58 | data_dict['content'] = content
59 |
60 |             # filter: keep posts that are neither too short nor too long
61 | if len(content.replace(' ', '')) > 10:
62 | if len(content.replace(' ', '')) < self.max_str_length:
63 | data_list.append(data_dict)
64 | else:
65 | print('数据异常 ->', content)
66 |
67 | if data_list:
68 | self.save(data_list)
69 |
70 | if next_pn := sl.xpath(
71 | """//li[@class='l_pager pager_theme_5 pb_list_pager']/span/following-sibling::a[1]/text()""").get():
72 | next_args = (tie_id, next_pn)
73 | return next_args
74 | return None
75 |
76 | def save(self, data_list):
77 | with get_session() as s:
78 | for data_dict in data_list:
79 | s.add(TieBa(**data_dict))
80 |
81 | def depth_first_run(self):
82 | for ba_pn in range(self.start_ba_pn, self.max_ba_pn + 1):
83 | print(f'正在爬取{self.kw}吧 第{ba_pn}页...')
84 | for tid in self.get_tie_list(ba_pn):
85 | next_args = (tid, self.start_tie_pn)
86 | while 1:
87 | print(f'正在爬取{self.kw}吧{tid}贴 第{next_args[1]}页...', end='')
88 | next_args = self.get_tie_detail(*next_args)
89 | print('完成!')
90 | if not next_args or int(next_args[1]) == self.max_tie_pn+1:
91 | break
92 | print('over')
93 |
94 |
95 | if __name__ == '__main__':
96 | Ba(kw='孙笑川', max_ba_pn=1, max_tie_pn=1).depth_first_run()
--------------------------------------------------------------------------------
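`depth_first_run` follows `next_args` page by page until `get_tie_detail` returns `None` or the page counter reaches `max_tie_pn + 1`. The same loop, isolated with a fake fetcher so the control flow is visible (the fetcher and its page counts are invented for illustration):

```python
# Fake get_tie_detail: reports a "next page" until page last_pn is reached,
# mirroring the (tie_id, next_pn) tuple contract used above.
def fake_get_tie_detail(tie_id, pn, last_pn=3):
    pn = int(pn)
    return (tie_id, str(pn + 1)) if pn < last_pn else None

visited = []
max_tie_pn = 2
next_args = ('1001', '1')
while True:
    visited.append(next_args)
    next_args = fake_get_tie_detail(*next_args)
    # stop when there is no next page, or the page cap is reached
    if not next_args or int(next_args[1]) == max_tie_pn + 1:
        break
print(visited)  # [('1001', '1'), ('1001', '2')]
```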
/src/scripts/manager.py:
--------------------------------------------------------------------------------
1 | import fire
2 |
3 | from src.db.models import *
4 | from src.db.dao import Base, engine
5 | from src.scripts.crawlers.tie_ba_spider import Ba
6 | from src.scripts.workers import EmbeddingWorker
7 |
8 |
9 | class Management:
10 | @staticmethod
11 | def create_table():
12 | Base.metadata.create_all(engine)
13 |
14 | @staticmethod
15 | def crawl(kw='孙笑川', max_ba_pn=1, max_tie_pn=1, start_ba_pn=1, start_tie_pn=1, max_str_length=400):
16 | """
17 |
18 | Args:
19 |             kw: Tieba forum name
20 |             max_ba_pn: maximum number of forum pages to crawl
21 |             max_tie_pn: maximum number of pages per thread
22 |             start_ba_pn: forum page to start from
23 |             start_tie_pn: thread page to start from
24 |             max_str_length: maximum accepted text length
25 |
26 | Returns:
27 |
28 | """
29 | Ba(kw, max_ba_pn, max_tie_pn, start_ba_pn, start_tie_pn, max_str_length).depth_first_run()
30 |
31 | @staticmethod
32 | def embedding(name, model, desc=''):
33 | """
34 |
35 | Args:
36 |             name: Milvus collection name
37 |             model: the MySQL model class; must inherit from Document
38 |             desc: collection description
39 |
40 | Returns:
41 |
42 | """
43 | EmbeddingWorker(limit_num=100, embedding_query_num=5000).run(name, desc, model)
44 |
45 | def run(self, kw, name):
46 | self.create_table()
47 | self.crawl(kw=kw)
48 | self.embedding(name, model=TieBa)
49 |
50 |
51 | if __name__ == '__main__':
52 | fire.Fire(Management)
53 |
--------------------------------------------------------------------------------
/src/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def filter_str(desstr, restr=''):
5 |     # keep only CJK characters, ASCII letters and digits
6 |     res = re.compile("[^\\u4e00-\\u9fa5a-zA-Z0-9]")
7 |     return res.sub(restr, desstr)
8 |
9 |
10 | def num_tokens_from_string(string: str, encoding_name: str = 'cl100k_base') -> int:
11 |     """Returns the number of tokens in a text string."""
12 |     import tiktoken
13 |     encoding = tiktoken.get_encoding(encoding_name)
14 |     num_tokens = len(encoding.encode(string))
15 |     return num_tokens
16 | return num_tokens
17 |
--------------------------------------------------------------------------------
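A self-contained check of the intended `filter_str` behavior. Note the character class must not contain stray `^` characters: a class like `[^\u4e00-\u9fa5^a-z^A-Z^0-9]` treats the inner `^`s as literal members, so `^` itself escapes filtering.

```python
import re

def filter_str(desstr, restr=''):
    # Negated class: keep only CJK characters, ASCII letters and digits.
    return re.sub(r"[^\u4e00-\u9fa5a-zA-Z0-9]", restr, desstr)

print(filter_str("你好, world! ^_^ 123"))  # 你好world123
```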
/src/scripts/workers.py:
--------------------------------------------------------------------------------
1 | from pymilvus import FieldSchema, Collection, CollectionSchema, DataType, has_collection
2 |
3 | from src import config
4 | from src.db.dao import get_session
5 | from src.db.models import TieBa
6 | from src.utils.init import initialize_openai
7 | from src.utils.utils import sync_get_embedding
8 |
9 |
10 | class EmbeddingWorker:
11 |
12 | def __init__(self, limit_num=100, embedding_query_num=5000):
13 | try:
14 | from pymilvus import FieldSchema, DataType
15 | from pymilvus import connections
16 | except ImportError:
17 |             raise ImportError('Please run pip install pymilvus==2.1')
18 |
19 | initialize_openai()
20 | connections.connect(
21 | alias="default",
22 | host=config.milvus['host'],
23 | port=config.milvus['port']
24 | )
25 |
26 | self.limit_num = limit_num
27 | self.embedding_query_max_length = embedding_query_num
28 |
29 |     def search_rows_no_embedding(self, model):  # model: a Document subclass
30 | while 1:
31 | with get_session() as s:
32 | rows = s.query(model).filter(model.embedding_state == False).limit(self.limit_num).all()
33 | if not rows:
34 | break
35 | rows_no_embedding = []
36 | rows_content_list = []
37 | current_length = 0
38 | while 1:
39 | if rows:
40 | row = rows.pop()
41 | rows_no_embedding.append(row)
42 | rows_content_list.append(row.content)
43 | current_length += len(row.content)
44 | else:
45 | break
46 | if current_length > self.embedding_query_max_length or not rows:
47 | yield rows_no_embedding, rows_content_list
48 | rows_no_embedding = []
49 | rows_content_list = []
50 | current_length = 0
51 |
52 | def query_embedding(self, rows_no_embedding, rows_content_list):
53 | embedding = sync_get_embedding(rows_content_list)
54 | # content = rows_content_list
55 | hash_id = [int(row.hash_id) for row in rows_no_embedding]
56 | return [hash_id, embedding]
57 |
58 | @staticmethod
59 | def create_collection(name, desc):
60 |         # create the collection
61 | hash_id = FieldSchema(
62 | name="hash_id",
63 | dtype=DataType.INT64,
64 | is_primary=True,
65 | )
66 | # content = FieldSchema(
67 | # name="content",
68 | # dtype=DataType.VARCHAR,
69 | # )
70 | embedding = FieldSchema(
71 | name="embedding",
72 | dtype=DataType.FLOAT_VECTOR,
73 | dim=1536
74 | )
75 | schema = CollectionSchema(
76 | fields=[hash_id, embedding],
77 | description=desc
78 | )
79 | collection_name = name
80 | collection = Collection(
81 | name=collection_name,
82 | schema=schema,
83 | using='default',
84 | shards_num=2,
85 | consistency_level="Strong"
86 | )
87 | return collection
88 |
89 | @staticmethod
90 | def create_index(collection):
91 |         # create the vector index
92 | index_params = {
93 | "metric_type": "L2",
94 | "index_type": "IVF_FLAT",
95 | "params": {"nlist": 1024}
96 | }
97 | collection.create_index(
98 | field_name="embedding",
99 | index_params=index_params
100 | )
101 |
102 | @staticmethod
103 | def push_2_milvus(collection, entries):
104 |         # insert the entities
105 | try:
106 | collection.insert(entries)
107 | print(f'插入成功,共{len(entries[0])}条')
108 | return True
109 | except Exception as e:
110 | print('插入失败')
111 | raise e
113 |
114 | def run(self, name, desc, model):
115 | if not has_collection(name):
116 | collection = self.create_collection(name, desc)
117 | self.create_index(collection)
118 | else:
119 | collection = Collection(name)
120 |
121 | print(collection.num_entities)
122 | for rows_no_embedding, rows_content_list in self.search_rows_no_embedding(model):
123 | entries = self.query_embedding(rows_no_embedding, rows_content_list)
124 | if self.push_2_milvus(collection, entries):
125 |                 with get_session() as s:
126 |                     for tie in rows_no_embedding:
127 |                         tie.embedding_state = True
128 |                         s.merge(tie)  # rows came from an earlier session; merge so the flag update persists
128 | print('over')
129 | print(collection.num_entities)
130 |
131 |
132 | if __name__ == '__main__':
133 | # with get_session() as s:
134 | # s.query(TieBa).update({'embedding_state': False})
135 | EmbeddingWorker().run('sun_ba', '孙笑川吧', TieBa)
--------------------------------------------------------------------------------
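`search_rows_no_embedding` above batches row contents until their combined length crosses `embedding_query_max_length`, so each embedding request stays bounded. The core batching idea, reduced to plain strings (function name and threshold are illustrative):

```python
# Accumulate items until the combined text length crosses max_length,
# then yield the batch and start a new one; flush whatever remains.
def batch_by_length(contents, max_length):
    batch, current = [], 0
    for text in contents:
        batch.append(text)
        current += len(text)
        if current > max_length:
            yield batch
            batch, current = [], 0
    if batch:  # tail batch that never crossed the threshold
        yield batch

batches = list(batch_by_length(["aa", "bbb", "c", "dddd"], max_length=4))
print(batches)  # [['aa', 'bbb'], ['c', 'dddd']]
```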
/src/static/speech/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/static/speech/.gitkeep
--------------------------------------------------------------------------------
/src/static/voice/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiran214/GPT-vup/826aed1455776917832ef79a4c240730f958ed3f/src/static/voice/.gitkeep
--------------------------------------------------------------------------------
/src/token.txt:
--------------------------------------------------------------------------------
1 | 0297c744c0fafa2d8b71bd64594ea390cb5e6411c63eb46e0040d42109c0aae9
--------------------------------------------------------------------------------
/src/utils/base.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: base.py
5 | @DateTime: 2023/4/22 22:19
6 | @SoftWare: PyCharm
7 | """
8 | from abc import abstractmethod
9 | from functools import cached_property
10 | from typing import Union
11 |
12 | import time
13 |
14 | from src import config
15 | from src.config import live2D_actions
16 | from src.utils.prompt_temple import get_chat_prompt_template
17 |
18 |
19 | class Event:
20 | def __init__(self, event_dict):
21 | self._event_dict = event_dict
22 | self._event_name = event_dict.get('type', '') or event_dict.get('Type', '')
23 | if not self._event_name:
24 |             raise ValueError("event_dict has no 'type'/'Type' key")
25 |
26 | self._kwargs = self.get_kwargs()
27 | self._action = None
28 |         # whether this event is handled with priority
29 | self.is_high_priority = False
30 |
31 | @abstractmethod
32 | def get_kwargs(self):
33 |         """Extract the useful fields from the raw event dict."""
34 | return {
35 | 'time': None
36 | }
37 |
38 | @property
39 | @abstractmethod
40 | def prompt_kwargs(self):
41 |         """Data the prompt template needs."""
42 | return {
43 | 'time': None
44 | }
45 |
46 | @property
47 | @abstractmethod
48 | def human_template(self):
49 |         """The prompt template for this event type."""
50 | return '{text}'
51 |
52 | def get_prompt_messages(self, **kwargs):
53 |         """Exit point: build the prompt messages passed to the LLM."""
54 | if config.context_plugin and 'context' in kwargs:
55 | human_template = '上下文:{context}\n问题:' + self.human_template
56 | else:
57 | human_template = self.human_template
58 | return get_chat_prompt_template(human_template).format_prompt(**self.prompt_kwargs, **kwargs).to_messages()
59 |
60 | @abstractmethod
61 | def get_audio_txt(self, *args, **kwargs):
62 |         """The text the virtual streamer speaks."""
63 | return None
64 |
65 | @property
66 | def time(self):
67 |         """Human-readable time."""
68 | return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self._kwargs['time']))
69 |
70 | @property
71 | def action(self) -> Union[None, str, int]:
72 | """
73 | :return:
74 |         None: this event triggers no action
75 |         str: zero-shot match against the action list
76 |         int: trigger a fixed action by index
77 | """
78 | return self._action
79 |
80 | @action.setter
81 | def action(self, value: Union[None, str, int]):
82 | if value in live2D_actions:
83 | self._action = live2D_actions.index(value)
84 |         else:
85 |             self._action = value
--------------------------------------------------------------------------------
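The `action` setter is meant to resolve a known action name to its list index and store anything else (an index or `None`) unchanged. A self-contained sketch of those semantics with a stand-in action list (not the real config):

```python
live2D_actions = ['Anim Shake', 'Anim Nod']  # stand-in for src.config.live2D_actions

class Avatar:
    def __init__(self):
        self._action = None

    @property
    def action(self):
        return self._action

    @action.setter
    def action(self, value):
        # A known action name resolves to its index; anything else
        # (None, or an index chosen by the caller) is stored as-is.
        if value in live2D_actions:
            self._action = live2D_actions.index(value)
        else:
            self._action = value

a = Avatar()
a.action = 'Anim Nod'
print(a.action)  # 1
a.action = 0
print(a.action)  # 0
```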
/src/utils/dfa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 |
4 | from src.config import keyword_str_list
5 |
6 |
7 | class DFA:
8 |
9 | def __init__(self, keyword_list: list):
10 | self.state_event_dict = self._generate_state_event_dict(keyword_list)
11 |
12 | def match(self, content: str):
13 | match_list = []
14 | state_list = []
15 | temp_match_list = []
16 |
17 | for char_pos, char in enumerate(content):
18 | if char in self.state_event_dict:
19 | state_list.append(self.state_event_dict)
20 | temp_match_list.append({
21 | "start": char_pos,
22 | "match": ""
23 | })
24 |
25 |             # Iterate the active matchers in reverse so that popping an
26 |             # index does not shift matchers that have not been visited yet.
27 |             for index in range(len(state_list) - 1, -1, -1):
28 |                 state = state_list[index]
29 |                 is_find = False
30 |                 state_char = None
31 |
32 |                 # '*' matches any character
33 |                 if "*" in state:
34 |                     state_list[index] = state["*"]
35 |                     state_char = state["*"]
36 |                     is_find = True
37 |
38 |                 if char in state:
39 |                     state_list[index] = state[char]
40 |                     state_char = state[char]
41 |                     is_find = True
42 |
43 |                 if is_find:
44 |                     temp_match_list[index]["match"] += char
45 |
46 |                     if state_char["is_end"]:
47 |                         match_list.append(copy.deepcopy(temp_match_list[index]))
48 |
49 |                         if len(state_char.keys()) == 1:
50 |                             state_list.pop(index)
51 |                             temp_match_list.pop(index)
52 |                 else:
53 |                     state_list.pop(index)
54 |                     temp_match_list.pop(index)
52 |
53 | return match_list
54 |
55 | @staticmethod
56 | def _generate_state_event_dict(keyword_list: list) -> dict:
57 | state_event_dict = {}
58 |
59 | for keyword in keyword_list:
60 | current_dict = state_event_dict
61 | length = len(keyword)
62 |
63 | for index, char in enumerate(keyword):
64 | if char not in current_dict:
65 | next_dict = {"is_end": False}
66 | current_dict[char] = next_dict
67 | else:
68 | next_dict = current_dict[char]
69 | current_dict = next_dict
70 | if index == length - 1:
71 | current_dict["is_end"] = True
72 |
73 | return state_event_dict
74 |
75 |
76 | if __name__ == "__main__":
77 | dfa = DFA(keyword_str_list)
78 | print(dfa.match("信息抽我取之 DFA 算法匹配关键词,匹配算法"))
79 |
--------------------------------------------------------------------------------
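The DFA above is a character trie walked once per text position. The same idea in a compact, self-contained form (simplified: no `*` wildcard, and a `'#'` end marker plays the role of `is_end`):

```python
# Build a character trie from the keyword list; '#' marks a keyword end.
def build_trie(keywords):
    root = {}
    for word in keywords:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
        node['#'] = True
    return root

# Scan the text, starting a fresh trie walker at every position.
def find_keywords(text, trie):
    hits = []
    for start in range(len(text)):
        node = trie
        for pos in range(start, len(text)):
            node = node.get(text[pos])
            if node is None:
                break
            if '#' in node:
                hits.append({'start': start, 'match': text[start:pos + 1]})
    return hits

trie = build_trie(['abc', 'bd'])
print(find_keywords('xabdabc', trie))
# [{'start': 2, 'match': 'bd'}, {'start': 4, 'match': 'abc'}]
```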
/src/utils/events.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: events.py
5 | @DateTime: 2023/4/23 16:05
6 | @SoftWare: PyCharm
7 | """
8 | import time
9 |
10 | from src.utils.utils import Event
11 |
12 |
13 | class BlDanmuMsgEvent(Event):
14 |
15 | def __init__(self, *args, **kwargs):
16 | super().__init__(*args, **kwargs)
17 | self.action = self._kwargs['content']
18 |
19 | def get_kwargs(self):
20 | return {
21 | 'content': self._event_dict['data']['info'][1],
22 | 'user_name': self._event_dict['data']['info'][2][1],
23 | 'time': self._event_dict['data']['info'][9]['ts']
24 | }
25 |
26 | @property
27 | def prompt_kwargs(self):
28 | return {
29 | 'text': self._kwargs['content']
30 | }
31 |
32 | @property
33 | def human_template(self):
34 | return '{text}'
35 |
36 | def get_audio_txt(self, gpt_resp):
37 | return f"{self._kwargs['content']} {gpt_resp}"
38 |
39 |
40 | class BlSuperChatMessageEvent(Event):
41 |
42 | def __init__(self, *args, **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.is_high_priority = True
45 | self.action = self._kwargs['message']
46 |
47 | def get_kwargs(self):
48 | info = self._event_dict['data']['data']
49 | user_info = info['user_info']
50 | return {
51 | 'user_name': user_info['uname'],
52 | 'face': user_info['face'],
53 | 'message': info['message'],
54 | 'price': info['price'],
55 | 'time': info['start_time'],
56 | }
57 |
58 | @property
59 | def prompt_kwargs(self):
60 | return {
61 | 'message': self._kwargs['message']
62 | }
63 |
64 | @property
65 | def human_template(self):
66 | return '{message}'
67 |
68 | def get_audio_txt(self, gpt_resp):
69 | return f"感谢{self._kwargs['user_name']}的sc。{self._kwargs['message']} {gpt_resp}"
70 |
71 |
72 | class BlSendGiftEvent(Event):
73 |
74 | def get_kwargs(self):
75 | info = self._event_dict['data']['data']
76 | return {
77 | 'user_name': info['uname'],
78 | 'face': info['face'],
79 | 'action': info['action'],
80 | 'giftName': info['giftName'],
81 | 'time': info['timestamp'],
82 |
83 | }
84 |
85 | @property
86 | def prompt_kwargs(self):
87 | return {
88 | 'content': f"{self._kwargs['user_name']}{self._kwargs['action']}了{self._kwargs['giftName']}。"
89 | }
90 |
91 | @property
92 | def human_template(self):
93 | return (
94 | "{content}"
95 | "请表示感谢,说一句赞美他的话!"
96 | )
97 |
98 | def get_audio_txt(self, gpt_resp):
99 | return f"{self.prompt_kwargs['content']} {gpt_resp}"
100 |
101 |
102 | class BlInteractWordEvent(Event):
103 |
104 | def __init__(self, *args, **kwargs):
105 | super().__init__(*args, **kwargs)
106 |
107 | def get_kwargs(self):
108 | info = self._event_dict['data']['data']
109 | fans_medal = info['fans_medal']
110 | return {
111 | 'medal_name': fans_medal['medal_name'],
112 | 'medal_level': fans_medal['medal_level'],
113 | 'user_name': info['uname'],
114 | 'content': f"{info['uname']} 进入直播间。",
115 | 'time': info['timestamp']
116 | }
117 |
118 | @property
119 | def prompt_kwargs(self):
120 | return {
121 | 'content': self._kwargs['content'],
122 | 'medal_name': self._kwargs['medal_name']
123 | }
124 |
125 | @property
126 | def human_template(self):
127 | return (
128 | "{content}"
129 | "请表示欢迎!并简短聊聊他加入的粉丝团{medal_name}"
130 | )
131 |
132 | def get_audio_txt(self, gpt_resp):
133 | return f"{gpt_resp}"
134 |
135 |
136 | class DyDanmuMsgEvent(BlDanmuMsgEvent):
137 |
138 | def get_kwargs(self):
139 | return {
140 | 'content': self._event_dict['Data']['Content'],
141 | 'user_name': self._event_dict['Data']['User']['Nickname'],
142 | 'time': int(time.time())
143 | }
144 |
145 |
146 | class DyCkEvent(Event):
147 |
148 | def get_kwargs(self):
149 | return {
150 | 'user_name': self._event_dict['Data']['User']['Nickname'],
151 | 'content': self._event_dict['Data']['Content'],
152 | 'time': int(time.time())
153 | }
154 |
155 | @property
156 | def prompt_kwargs(self):
157 | return {
158 | 'content': self._kwargs['user_name']
159 | }
160 |
161 | @property
162 | def human_template(self):
163 | return (
164 | "{content}给你点了赞,"
165 | "请表示感谢!"
166 | )
167 |
168 | def get_audio_txt(self, gpt_resp):
169 | return f"{gpt_resp}"
170 |
171 |
172 | class DyWelcomeWordEvent(Event):
173 |
174 | def get_kwargs(self):
175 | return {
176 | 'user_name': self._event_dict['Data']['User']['Nickname'],
177 | 'time': int(time.time())
178 | }
179 |
180 | @property
181 | def prompt_kwargs(self):
182 | return {
183 | 'user_name': self._kwargs['user_name']
184 | }
185 |
186 | @property
187 | def human_template(self):
188 | return (
189 | "{user_name},进入直播间"
190 | "请表示欢迎!并简短聊聊他的名字"
191 | )
192 |
193 | def get_audio_txt(self, gpt_resp):
194 | return f"{gpt_resp}"
195 |
196 |
197 | class DySendGiftEvent(Event):
198 |
199 | def __init__(self, *args, **kwargs):
200 | super().__init__(*args, **kwargs)
201 | self.is_high_priority = True
202 |
203 | def get_kwargs(self):
204 | return {
205 | 'user_name': self._event_dict['Data']['User']['Nickname'],
206 | 'giftName': self._event_dict['Data']['GiftName'],
207 | 'content': self._event_dict['Data']['Content'],
208 | 'time': int(time.time())
209 | }
210 |
211 | @property
212 | def prompt_kwargs(self):
213 | return {
214 | 'content': self._kwargs['content']
215 | }
216 |
217 | @property
218 | def human_template(self):
219 | return (
220 | "{content}"
221 | "请表示感谢,说一句赞美他的话!"
222 | )
223 |
224 | def get_audio_txt(self, gpt_resp):
225 | return f"{self._kwargs['user_name']} 送出{self._kwargs['giftName']}!{gpt_resp}"
226 |
227 |
228 | class DyAttentionEvent(Event):
229 |
230 | def get_kwargs(self):
231 | return {
232 |             'user_name': self._event_dict['Data']['User']['Nickname'],
233 |             'content': self._event_dict['Data']['Content'],
234 | 'time': int(time.time())
235 | }
236 |
237 | @property
238 | def prompt_kwargs(self):
239 | return {
240 | 'user_name': self._kwargs['user_name']
241 | }
242 |
243 | @property
244 | def human_template(self):
245 | return (
246 | "{user_name}关注了你!"
247 | "请表示感谢,说一句赞美他的话!"
248 | )
249 |
250 | def get_audio_txt(self, gpt_resp):
251 | return gpt_resp
252 |
253 |
254 | class UserEvent(Event):
255 |
256 | def __init__(self, content, audio_txt_temple):
257 | super(UserEvent, self).__init__({'type': 'user_event'})
258 | self.content = content
259 | self.audio_txt_temple = audio_txt_temple
260 |
261 | def get_kwargs(self):
262 | return {
263 | 'time': int(time.time())
264 | }
265 |
266 | @property
267 | def prompt_kwargs(self):
268 | return {
269 | 'content': self.content
270 | }
271 |
272 | @property
273 | def human_template(self):
274 | return '{content}'
275 |
276 | def get_audio_txt(self, gpt_resp):
277 | return self.audio_txt_temple.format(gpt_resp)
278 |
279 |
--------------------------------------------------------------------------------
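Every event class above pairs a `human_template` with `prompt_kwargs`; the final prompt text is essentially `template.format(**kwargs)`. A minimal stand-alone sketch of that contract for a gift-style event (the class and its fields are illustrative, not the real `Event` base):

```python
class GiftEventSketch:
    def __init__(self, user_name, gift_name):
        self.user_name = user_name
        self.gift_name = gift_name

    @property
    def prompt_kwargs(self):
        # data the template consumes
        return {'content': f'{self.user_name}送出了{self.gift_name}。'}

    @property
    def human_template(self):
        # per-event-type template with named placeholders
        return '{content}请表示感谢,说一句赞美他的话!'

    def prompt_text(self):
        # the whole prompt-building step is plain str.format
        return self.human_template.format(**self.prompt_kwargs)

e = GiftEventSketch('Alice', '小花花')
print(e.prompt_text())  # Alice送出了小花花。请表示感谢,说一句赞美他的话!
```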
/src/utils/init.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: init.py
5 | @DateTime: 2023/5/13 12:53
6 | @SoftWare: PyCharm
7 | Import this module first; it monkey-patches third-party libraries.
8 | """
9 | import configparser
10 | import json
11 | import os
12 | import random
13 | from contextlib import asynccontextmanager
14 | from typing import AsyncIterator
15 |
16 | import aiohttp
17 | import urllib3
18 | from aiohttp import TCPConnector
19 | from openai import api_requestor
20 | import openai
21 |
22 | from src.modules.actions import plugin_info
23 |
24 | urllib3.disable_warnings()
25 |
26 |
27 | def initialize_openai():
28 |     # read the ini config here directly to avoid a circular import
29 | file_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../config.ini')
30 | _config = configparser.RawConfigParser()
31 | _config.read(file_path)
32 | proxy = _config.get('other', 'proxy')
33 | if proxy:
34 | os.environ["http_proxy"] = f'http://{proxy}/'
35 | os.environ["https_proxy"] = f'http://{proxy}/'
36 |
37 | @asynccontextmanager
38 | async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
39 | async with aiohttp.ClientSession(connector=TCPConnector(limit_per_host=5, ssl=False), trust_env=True) as session:
40 | # async with aiohttp.ClientSession(trust_env=True) as session:
41 | yield session
42 |
43 | api_requestor.aiohttp_session = aiohttp_session
44 |
45 |
46 | async def initialize_action():
47 |     # connect to VTS over websocket and save the auth token locally
48 |     try:
49 |         import pyvts
50 |     except ImportError:
51 |         raise ImportError('Please run pip install pyvts')
52 | vts = pyvts.vts(plugin_info=plugin_info)
53 | try:
54 | await vts.connect()
55 | except ConnectionRefusedError:
56 |         raise RuntimeError('请先打开VTS,并打开API开关!')
57 | print('请在live2D VTS弹窗中点击确认!')
58 | await vts.request_authenticate_token() # get token
59 | await vts.write_token()
60 | await vts.request_authenticate() # use token
61 |
62 | response_data = await vts.request(vts.vts_request.requestHotKeyList())
63 | hotkey_list = []
64 | for hotkey in response_data['data']['availableHotkeys']:
65 | hotkey_list.append(hotkey['name'])
66 | print('读取到所有模型动作:', hotkey_list)
67 |
68 | # Request embeddings for the action names
69 | print('Requesting the embedding model...')
70 | try:
71 | initialize_openai()
72 | res = await openai.Embedding.acreate(input=hotkey_list, model="text-embedding-ada-002")
73 | action_embeddings = [d['embedding'] for d in res['data']]
74 | action_dict = dict(zip(hotkey_list, action_embeddings))
75 | print(f'Fetched {len(action_dict)} action embeddings')
76 | except Exception as e:
77 | print('Most likely a network or proxy problem')
78 | raise e
79 |
80 | # Write the name -> embedding mapping to disk
81 | with open("../action.json", "w") as dump_f:
82 | json.dump(action_dict, dump_f)
83 |
84 | # Self-test
85 | assert len(hotkey_list) == len(action_dict.keys())  # VTS and local actions must match
86 | assert len(hotkey_list) not in (0, 1)  # too few actions
87 | action = random.choice(hotkey_list)
88 | print('Testing a randomly chosen action...', action)
89 | send_hotkey_request = vts.vts_request.requestTriggerHotKey(action)
90 | await vts.request(send_hotkey_request)
91 | await vts.close()
--------------------------------------------------------------------------------
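The `aiohttp_session` override in `monkey_patch.py` above swaps openai's session factory for one that pins `limit_per_host` and disables SSL verification. The same `@asynccontextmanager` pattern can be sketched without aiohttp (the `FakeSession` class and `session_factory` are hypothetical stand-ins, not part of the project):

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


class FakeSession:
    """Stand-in for aiohttp.ClientSession, just to show the patch shape."""

    def __init__(self) -> None:
        self.closed = False

    async def close(self) -> None:
        self.closed = True


@asynccontextmanager
async def session_factory() -> AsyncIterator[FakeSession]:
    # Mirrors the patched api_requestor.aiohttp_session: build a session,
    # hand it to the caller, and guarantee cleanup on exit.
    session = FakeSession()
    try:
        yield session
    finally:
        await session.close()


async def demo() -> tuple:
    async with session_factory() as session:
        open_inside = not session.closed
    return open_inside, session.closed


open_inside, closed_after = asyncio.run(demo())
```

Because the patch only replaces the factory attribute on `api_requestor`, the rest of the openai client keeps working unchanged.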
/src/utils/log.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: logging_module.py
5 | @DateTime: 2023/3/20 23:14
6 | @SoftWare: PyCharm
7 | """
8 | import datetime
9 | import logging
10 | import logging.handlers
11 | import os
12 |
13 | import colorlog
14 |
15 | from src import config
16 | from src.config import root_path
17 |
18 |
19 | class Logging:
20 | def __init__(self, log_file_name, log_file_path=os.path.join(root_path, 'logs')):
21 | """
22 | :param log_file_path:
23 | 1、print(os.getcwd()) # 获取当前工作目录路径
24 | 2、print(os.path.abspath('.')) # 获取当前工作目录路径
25 | :param log_file_name:
26 | 1、current_work_dir = os.path.dirname(__file__) # 当前文件所在的目录
27 | 2、weight_path = os.path.join(current_work_dir, weight_path) # 再加上它的相对路径,这样可以动态生成绝对路径
28 | """
29 | self.log_colors_config = {
30 | 'DEBUG': 'cyan', # cyan white
31 | 'INFO': 'green',
32 | 'WARNING': 'yellow',
33 | 'ERROR': 'red',
34 | 'CRITICAL': 'bold_red',
35 | }
36 | # Directory where log files are stored
37 | self.log_file_path = log_file_path
38 | self.log_file_name = log_file_name
39 | self._log_filename = self.get_log_filename()
40 |
41 | # Create the logger object
42 | self._logger = logging.getLogger(self.log_file_name)
43 |
44 | # Console output level; severity order: CRITICAL > ERROR > WARNING > INFO > DEBUG
45 | if config.debug is True:
46 | self.set_console_logger()
47 | self.set_file_logger()
48 | self._logger.setLevel(logging.DEBUG)  # emit everything from DEBUG up
49 | else:
50 | self.set_file_logger()
51 | self._logger.setLevel(logging.INFO)  # emit INFO and above
52 | # self._logger.setLevel(logging.DEBUG)
53 |
54 | def get_log_filename(self):
55 | if not os.path.isdir(self.log_file_path):
56 | # Create the log directory
57 | os.makedirs(self.log_file_path)
58 | return f"{self.log_file_path}/{self.log_file_name}_{str(datetime.date.today())}.log"
59 |
60 | def set_console_logger(self):
61 | formatter = colorlog.ColoredFormatter(
62 | # fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s -> %(funcName)s line:%(lineno)d [%(levelname)s] : %(message)s',
63 | fmt='%(log_color)s[%(asctime)s] %(levelname)s [%(funcName)s] - %(message)s',
64 | datefmt='%Y-%m-%d %H:%M:%S',
65 | log_colors=self.log_colors_config)
66 | console_handler = logging.StreamHandler()
67 | console_handler.setLevel(logging.DEBUG)
68 | console_handler.setFormatter(formatter)
69 | self._logger.addHandler(console_handler)
70 |
71 | def set_file_logger(self):
72 | # Output format for records written to the log file
73 | formatter = logging.Formatter(
74 | '%(asctime)s - %(filename)s -> %(funcName)s - line:%(lineno)d - %(levelname)s: %(message)s')
75 | # Save log output to a rotating file
76 | file_handler = logging.handlers.RotatingFileHandler(
77 | self._log_filename, maxBytes=10485760, backupCount=5, encoding="utf-8")
78 | file_handler.setLevel(logging.INFO)
79 | file_handler.setFormatter(formatter)
80 | self._logger.addHandler(file_handler)
81 |
82 | def get_logger(self):
83 | return self._logger
84 |
85 |
86 | # request_logger = Logging("Request").get_logger()
87 | worker_logger = Logging("Worker").get_logger()
88 |
89 |
90 | if __name__ == '__main__':
91 | worker_logger.debug('123')
--------------------------------------------------------------------------------
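The `Logging` class above wires a colored console handler (debug only) plus a rotating file handler onto one named logger. The core stdlib pattern, stripped of `colorlog`, looks like this (a minimal sketch; `build_logger` and the file name are illustrative, not the project's API):

```python
import logging
import logging.handlers
import os
import tempfile


def build_logger(name: str, log_dir: str, debug: bool = False) -> logging.Logger:
    # Same wiring as Logging.__init__: a rotating file handler always,
    # a console handler only in debug mode.
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG if debug else logging.INFO)
    fmt = logging.Formatter(
        '%(asctime)s - %(filename)s -> %(funcName)s - line:%(lineno)d - %(levelname)s: %(message)s')
    file_handler = logging.handlers.RotatingFileHandler(
        os.path.join(log_dir, f'{name}.log'),
        maxBytes=10 * 1024 * 1024, backupCount=5, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)
    if debug:
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(fmt)
        logger.addHandler(console)
    return logger


with tempfile.TemporaryDirectory() as tmp:
    log = build_logger('Demo', tmp)
    log.info('hello')    # reaches the file (>= INFO)
    log.debug('hidden')  # filtered out by the logger's INFO level
    for handler in log.handlers:
        handler.close()
    with open(os.path.join(tmp, 'Demo.log'), encoding='utf-8') as fh:
        content = fh.read()
```

Note that the logger's own level filters first, so `debug()` calls never reach the INFO-level file handler in non-debug mode.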
/src/utils/prompt_temple.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: prompt_temple.py
5 | @DateTime: 2023/5/13 11:37
6 | @SoftWare: PyCharm
7 | """
8 | import random
9 |
10 | from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
11 |
12 | from src import config
13 |
14 | system_template = (
15 | "请扮演一个傲娇的AI虚拟主播,不要回答任何敏感问题!"
16 | "不要强调你是虚拟主播,不准道歉,只需要回答问题!"
17 | "你是一个带有一定的高傲情绪,可能会在自我表现时表现出一些自大或骄傲的情绪,或者表现出一些自信或者不屑一顾的态度。"
18 | "有时会有一些任性或者固执的表现。在做出决定时可能会表现出一定的倔强和固执,"
19 | "可能不太容易妥协或者接受别人的意见。同时也有一定的可爱和迷人之处。"
20 | "她们可能会有一些俏皮或者撒娇的表现,或者在和自己关心的人相处时表现得特别温柔和可爱。"
21 | )
22 |
23 | if config.context_plugin:
24 | system_template += (
25 | "\n\n"
26 | "请你重复上下文的内容和口吻,做出对观众问题的回应,你的回答里不能说你参考了上下文。"
27 | ""
28 | )
29 | system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
30 |
31 |
32 | def get_chat_prompt_template(human_template="{text}") -> ChatPromptTemplate:
33 | human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
34 | chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
35 | return chat_prompt
36 |
37 |
38 | schedule_task_temple_list = [
39 | ('我想让你扮演说唱歌手。您将想出强大而有意义的歌词、节拍和节奏,让听众“惊叹”。你的歌词应该有一个有趣的含义和信息,'
40 | '人们也可以联系起来。在选择节拍时,请确保它既朗朗上口又与你的文字相关,这样当它们组合在一起时,每次都会发出爆炸声!'
41 | '我的第一个请求是“我需要一首关于在你自己身上寻找力量的说唱歌曲。', '接下来我给大家表演一首说唱!{}'),
42 |
43 | ('我要你扮演诗人。你将创作出能唤起情感并具有触动人心的力量的诗歌。写任何主题或主题,'
44 | '但要确保您的文字以优美而有意义的方式传达您试图表达的感觉。您还可以想出一些短小的诗句,这些诗句仍然足够强大,可以在读者的脑海中留下印记。'
45 | '我的第一个请求是“我需要一首关于爱情的诗”。', '{}'),
46 |
47 | ('我希望你充当励志演说家。将能够激发行动的词语放在一起,让人们感到有能力做一些超出他们能力的事情。你可以谈论任何话题,但目的是确保你所说的话能引起听众的共鸣,激励他们努力实现自己的目标并争取更好的可能性。'
48 | '我的第一个请求是“我需要一个关于每个人如何永不放弃的演讲”。', '全体注意,我将发表一个演讲!{}'),
49 |
50 | ("我要你担任哲学老师。我会提供一些与哲学研究相关的话题,你的工作就是用通俗易懂的方式解释这些概念。"
51 | "这可能包括提供示例、提出问题或将复杂的想法分解成更容易理解的更小的部分。"
52 | "我的第一个请求是“我需要帮助来理解不同的哲学理论如何应用于日常生活。", '我想到一个深奥的道理,{}'),
53 |
54 | ("我想让你扮演讲故事的角色。您将想出引人入胜、富有想象力和吸引观众的有趣故事。它可以是童话故事、教育故事或任何其他类型的故事,"
55 | "有可能吸引人们的注意力和想象力。根据目标受众,您可以为讲故事环节选择特定的主题或主题,例如,如果是儿童,则可以谈论动物;如果是成年人,"
56 | "那么基于历史的故事可能会更好地吸引他们等等。请开始你的讲述?", '我给大家讲个故事吧{}')
57 | ]
58 |
59 |
60 | def get_schedule_task() -> tuple:
61 | """Return a random (llm_prompt, announcement_template) pair."""
62 | return random.choice(schedule_task_temple_list)
--------------------------------------------------------------------------------
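Each entry in `schedule_task_temple_list` above pairs an instruction prompt with a template that wraps the model's reply before it is announced to the room. A minimal sketch of that flow with a stubbed model call (the English strings and `fake_llm` are stand-ins, not the project's actual templates):

```python
import random

# Each entry pairs an LLM prompt with a template that wraps the model's
# reply before it is spoken in the room.
SCHEDULE_TASKS = [
    ('Act as a rapper and write powerful, catchy lyrics.',
     'Next up, I will perform a rap! {}'),
    ('Act as a poet and write a poem about love.',
     '{}'),
]


def get_schedule_task() -> tuple:
    # Pick one scheduled performance at random, as the original does.
    return random.choice(SCHEDULE_TASKS)


def fake_llm(prompt: str) -> str:
    # Stand-in for the real chat-completion call.
    return f'(reply to: {prompt[:12]}...)'


prompt, temple = get_schedule_task()
announcement = temple.format(fake_llm(prompt))
```

Keeping the announcement template separate from the prompt lets the persona's on-stream framing change without touching the instruction sent to the model.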
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | @Author: jiran
3 | @Email: jiran214@qq.com
4 | @FileName: utils.py
5 | @DateTime: 2023/4/23 17:00
6 | @SoftWare: PyCharm
7 | """
8 | import asyncio
9 | import queue
10 | import random
11 | import time
12 | from collections import deque
13 |
14 | from threading import Lock
15 | from dataclasses import dataclass
16 |
17 | from typing import List, Union, Dict
18 |
19 | import numpy as np
20 | import openai
21 | from scipy import spatial
22 |
23 | from src import config
24 | from src.utils.base import Event
25 | from src.utils.log import worker_logger
26 |
27 | logger = worker_logger
28 | audio_lock = Lock()
29 |
30 |
31 | class NewEventLoop:
32 | def __init__(self):
33 | # self.loop = asyncio.new_event_loop()
34 | # asyncio.set_event_loop(self.loop)
35 | self.loop = asyncio.get_event_loop()  # note: deprecated in Python 3.10+ when no loop is running
36 |
37 | def run(self, coroutine):
38 | self.loop.run_until_complete(coroutine)
39 |
40 |
41 | @dataclass
42 | class GPT35Params:
43 | messages: List  # list of chat-format input messages
44 | model: str = "gpt-3.5-turbo"  # model ID; only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported
45 | temperature: float = 1.5  # sampling temperature, a float in [0, 2]; higher is more random, lower more deterministic
46 | top_p: float = 1.0  # nucleus sampling, an alternative to temperature: only tokens in the top_p probability mass are considered; range [0, 1]
47 | n: int = 1  # number of chat completion choices to generate per input message, default 1
48 | stream: bool = False  # whether to stream the output
49 | stop: Union[str, List[str], None] = None  # up to 4 sequences; generation stops when any of them appears
50 | max_tokens: int = 1000  # max tokens in the generated answer; API default is (4096 - prompt tokens)
51 | presence_penalty: float = None  # number in [-2.0, 2.0]; positive values penalize tokens already present, nudging the model toward new topics
52 | frequency_penalty: float = None  # number in [-2.0, 2.0]; positive values penalize frequent tokens, reducing verbatim repetition
53 | logit_bias: Dict[str, float] = None  # JSON object mapping token IDs to bias values (-100 to 100) to adjust their likelihood
54 | user: str = None  # unique end-user identifier, helps OpenAI monitor and detect abuse
55 |
56 | def dict(self, exclude_defaults=False, exclude_none=True):
57 | # Note: exclude_defaults is accepted for API symmetry but currently unused.
58 | res = {}
59 | for k, v in self.__dict__.items():
60 | if exclude_none and v is None:
61 | continue
62 | res[k] = v
63 | return res
65 |
66 |
67 | class UserQueue:
68 | maxsize = 15
69 |
70 | def __init__(self):
71 | self.high_priority_event_queue = queue.Queue()
72 | self.event_queue = queue.Queue(self.maxsize)
73 |
74 | def send(self, event: Union[Event, None]):
75 | if not event:
76 | logger.debug(f'Filtered out event: {event}')
77 | return
78 | # print(event._event_name)
79 | # Route by priority: high-priority events are never dropped
80 | if event.is_high_priority:
81 | # Add object to high-priority queue
82 | self.high_priority_event_queue.put_nowait(event)
83 | else:
84 | # Check if main queue is full
85 | if not self.event_queue.full():
86 | # Add object to main queue
87 | self.event_queue.put_nowait(event)
88 | else:
89 | # Remove oldest item from queue and add new item
90 | self.event_queue.get()
91 | self.event_queue.put_nowait(event)
92 |
93 | def recv(self) -> Union[None, Event]:
94 | # Check high-priority queue first
95 | if not self.high_priority_event_queue.empty():
96 | # Remove oldest item from high-priority queue and process it
97 | event = self.high_priority_event_queue.get()
98 | elif not self.event_queue.empty():
99 | # Remove oldest item from main queue and process it
100 | event = self.event_queue.get()
101 | else:
102 | event = None
103 | return event
104 |
105 | def __str__(self):
106 | return f"event_queue size: {self.event_queue.qsize()}  high_priority_event_queue size: {self.high_priority_event_queue.qsize()}"
107 |
108 |
109 | user_queue = UserQueue()
110 |
111 |
112 | class FixedLengthTSDeque:
113 |
114 | def __init__(self, per_minute_times):
115 | self._per_minute_times = per_minute_times
116 | self._data_deque = deque(maxlen=per_minute_times)
117 |
118 | def _append(self) -> None:
119 | self._data_deque.append(int(time.time()))
120 |
121 | def can_append(self):
122 | # Allow when the window has room, or the oldest of the last N calls fell outside the 60 s window
123 | if len(self._data_deque) < self._per_minute_times or (int(time.time()) - self._data_deque[0]) >= 60:
124 | return True
125 | return False
125 |
126 | def acquire(self):
127 | if not self.can_append():
128 | logger.warning('This api_key exceeded 3 calls/min; consider adding more api_keys under config -> openai')
129 | return False
130 | else:
131 | self._append()
132 | return True
133 |
134 |
135 | def top_n_indices_from_embeddings(
136 | query_embedding: List[float],
137 | embeddings: List[List[float]],
138 | distance_metric="cosine",
139 | top=1
140 | ) -> list:
141 | """Return the distances between a query embedding and a list of embeddings."""
142 | distance_metrics = {
143 | "cosine": spatial.distance.cosine,
144 | "L1": spatial.distance.cityblock,
145 | "L2": spatial.distance.euclidean,
146 | "Linf": spatial.distance.chebyshev,
147 | }
148 | distances = [
149 | distance_metrics[distance_metric](query_embedding, embedding)
150 | for embedding in embeddings
151 | ]
152 | top_n_indices = np.argsort(distances)[:top]
153 | return top_n_indices
154 |
155 |
156 | def sync_get_embedding(texts: List[str], model="text-embedding-ada-002"):
157 | res = openai.Embedding.create(input=texts, model=model, api_key=get_openai_key())
158 | if isinstance(texts, list) and len(texts) == 1:
159 | return res['data'][0]['embedding']
160 | else:
161 | return [d['embedding'] for d in res['data']]
162 |
163 |
164 | api_key_limit_dict = dict(zip(config.api_key_list, [FixedLengthTSDeque(3) for _ in config.api_key_list]))
165 | current_api_key = config.api_key_list[0]
166 |
167 |
168 | def get_openai_key():
169 | global current_api_key
170 | times = 0
171 | while 1:
172 | times = times + 1
173 | if api_key_limit_dict[current_api_key].acquire():
174 | return current_api_key
175 | else:
176 | current_api_key = config.api_key_list[(config.api_key_list.index(current_api_key)+1) % len(config.api_key_list)]
177 | # print('switch', current_api_key)
178 | if times > len(config.api_key_list):
179 | logger.debug('All api_keys are rate-limited, waiting...')
180 | time.sleep(5)
181 |
182 |
183 | if __name__ == '__main__':
184 | while 1:
185 | print('123')
186 | print(get_openai_key())
187 | time.sleep(1)
--------------------------------------------------------------------------------
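`FixedLengthTSDeque` in `utils.py` above is a sliding-window rate limiter: it keeps the timestamps of the last N acquisitions and refuses a new one while the oldest of those N is still inside the 60-second window. A compact sketch of the same idea with an injectable clock, so the window can be exercised without sleeping (the `clock` parameter is an addition for testability, not in the original):

```python
import time
from collections import deque


class SlidingWindowLimiter:
    """Allow at most `per_minute` acquisitions in any 60-second window."""

    def __init__(self, per_minute: int, clock=time.time):
        self._per_minute = per_minute
        self._clock = clock  # injectable clock, an addition for this sketch
        self._stamps = deque(maxlen=per_minute)

    def acquire(self) -> bool:
        now = int(self._clock())
        # Allow when the window has room, or the oldest of the last N
        # calls fell outside the 60 s window.
        if len(self._stamps) < self._per_minute or now - self._stamps[0] >= 60:
            self._stamps.append(now)
            return True
        return False


# Drive it with a fake clock: three calls pass, the fourth within the same
# minute is refused, and after 60 s the window frees up again.
t = [0]
limiter = SlidingWindowLimiter(3, clock=lambda: t[0])
results = [limiter.acquire() for _ in range(4)]   # [True, True, True, False]
t[0] += 61
late = limiter.acquire()                          # True again
```

Because the deque has `maxlen=per_minute`, appending a fresh timestamp automatically evicts the oldest one, so the structure never grows beyond N entries.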