├── VERSION
├── tests
└── __init__.py
├── src
├── asr
│ ├── __init__.py
│ ├── asr_base.py
│ ├── local_whisper.py
│ └── openai_whisper.py
├── chain
│ ├── __init__.py
│ ├── ask_ai.py
│ ├── summarize.py
│ └── base_chain.py
├── core
│ ├── __init__.py
│ ├── routers
│ │ ├── __init__.py
│ │ ├── chain_router.py
│ │ ├── llm_router.py
│ │ └── asr_router.py
│ └── app.py
├── llm
│ ├── __init__.py
│ ├── gpt.py
│ ├── claude.py
│ ├── llm_base.py
│ ├── templates.py
│ └── spark.py
├── models
│ ├── __init__.py
│ ├── config.py
│ └── task.py
├── bilibili
│ ├── __init__.py
│ ├── bili_credential.py
│ ├── bili_session.py
│ ├── bili_video.py
│ └── bili_comment.py
├── listener
│ ├── __init__.py
│ └── bili_listen.py
└── utils
│ ├── logging.py
│ ├── exceptions.py
│ ├── callback.py
│ ├── prompt_utils.py
│ ├── up_video_cache.py
│ ├── merge_config.py
│ ├── file_tools.py
│ ├── cache.py
│ ├── queue_manager.py
│ ├── task_status_record.py
│ └── statistic.py
├── .dockerignore
├── .gitignore
├── .github
└── workflows
│ ├── ruff.yaml
│ └── push_to_docker.yaml
├── requirements.txt
├── requirements-docker.txt
├── ruff.toml
├── LICENSE
├── ci
└── Dockerfile
├── safe_update.py
├── DEV_README.md
├── config
├── example_config.yml
└── docker_config.yml
├── README.md
└── main.py
/VERSION:
--------------------------------------------------------------------------------
1 | 3.0.4
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/asr/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/chain/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/llm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/bilibili/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/core/routers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/listener/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | venv
2 | .idea
3 | .git
--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | import loguru
2 |
3 | LOGGER = loguru.logger
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | tests/.pytest_cache
3 | data/
4 | venv/
5 | __pycache__/
6 | config.yml
7 | openaidemo.py
--------------------------------------------------------------------------------
/src/utils/exceptions.py:
--------------------------------------------------------------------------------
1 | class ConfigError(Exception):
2 | pass
3 |
4 |
5 | class BilibiliBaseException(Exception):
6 | pass
7 |
8 |
9 | class RiskControlFindError(Exception):
10 | pass
11 |
12 |
13 | class LoadJsonError(Exception):
14 | pass
15 |
--------------------------------------------------------------------------------
/.github/workflows/ruff.yaml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 | on: [push, pull_request]
3 | jobs:
4 | ruff:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v3
8 | - uses: chartboost/ruff-action@v1
9 | with:
10 | args: --config ./ruff.toml
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bilibili-api-python==16.2.0
2 | loguru
3 | openai==0.28
4 | openai-whisper
5 | APScheduler
6 | pytest
7 | httpx
8 | tenacity
9 | PyYAML
10 | ffmpeg-python
11 | # matplotlib
12 | injector
13 | pydantic
14 | anthropic
15 | ruamel.yaml
16 | pydub
17 | websockets
18 |
--------------------------------------------------------------------------------
/requirements-docker.txt:
--------------------------------------------------------------------------------
1 | # 用于构建docker版本,不包含whisper,whisper在构建时会自行加上
2 | bilibili-api-python==16.2.0
3 | loguru
4 | openai==0.28
5 | APScheduler
6 | pytest
7 | httpx
8 | tenacity
9 | PyYAML
10 | ffmpeg-python
11 | # matplotlib
12 | injector
13 | pydantic
14 | anthropic
15 | ruamel.yaml
16 | pydub
17 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | line-length = 120 # 代码最大行宽
2 | [lint]
3 | select = [
4 | # pycodestyle
5 | "E",
6 | # Pyflakes
7 | "F",
8 | # pyupgrade
9 | "UP",
10 | # flake8-bugbear
11 | "B",
12 | # flake8-simplify
13 | "SIM",
14 | # isort
15 | "I",
16 | ]
17 | ignore = ["E501"] # 忽略的规则
18 |
19 | [format]
20 | docstring-code-format = true
21 | docstring-code-line-length = 20
22 | quote-style = "double"
--------------------------------------------------------------------------------
/src/utils/callback.py:
--------------------------------------------------------------------------------
1 | import traceback
2 |
3 | import loguru
4 |
5 | _LOGGER = loguru.logger
6 |
7 |
def chain_callback(retry_state):
    """Tenacity retry callback for the processing chain: log the failure and next retry.

    :param retry_state: tenacity RetryCallState describing the failed attempt
    """
    exception = retry_state.outcome.exception()
    _LOGGER.error(f"捕获到错误:{exception}")
    # Reuse the exception already extracted above instead of calling
    # retry_state.outcome.exception() a second time.
    traceback.print_tb(exception.__traceback__)
    _LOGGER.debug(f"当前重试次数为{retry_state.attempt_number}")
    _LOGGER.debug(f"下一次重试将在{retry_state.next_action.sleep}秒后进行")
15 |
16 |
def scheduler_error_callback(event):
    """APScheduler error listener: downgrade a known-noisy exception to a warning.

    :param event: scheduler error event carrying the raised exception
    """
    if event.exception.__class__.__name__ == "ClientOSError":
        _LOGGER.warning(f"捕获到异常:{event.exception} 可能是新版bilibili_api库的问题,接收消息没问题就不用管")
    # Every other exception type is intentionally ignored here.
23 |
--------------------------------------------------------------------------------
/src/utils/prompt_utils.py:
--------------------------------------------------------------------------------
def parse_prompt(prompt_template, **kwargs):
    """Fill a prompt template by substituting every ``[key]`` placeholder.

    :param prompt_template: template text containing ``[key]`` markers
    :param kwargs: placeholder names mapped to their replacement values
    :return: the template with all provided placeholders substituted
    """
    filled = prompt_template
    for name, value in kwargs.items():
        filled = filled.replace(f"[{name}]", str(value))
    return filled
6 |
7 |
def build_openai_style_messages(user_msg, system_msg=None, user_keyword="user", system_keyword="system"):
    """Build an OpenAI-style chat message list.

    :param user_msg: user message content
    :param system_msg: optional system message content
    :param user_keyword: role keyword for the user entry (LLM-specific)
    :param system_keyword: role keyword for the system entry (LLM-specific)
    :return: list of ``{"role": ..., "content": ...}`` dicts, system entry first when present
    """
    system_part = [{"role": system_keyword, "content": system_msg}] if system_msg else []
    return system_part + [{"role": user_keyword, "content": user_msg}]
21 |
--------------------------------------------------------------------------------
/src/utils/up_video_cache.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from src.utils.exceptions import LoadJsonError
4 | from src.utils.file_tools import read_file, save_file
5 | from src.utils.logging import LOGGER
6 |
7 | _LOGGER = LOGGER.bind(name="video_cache")
8 |
9 |
def load_cache(file_path):
    """Load the per-UP video cache from *file_path*.

    Creates the file with an empty JSON object when it is missing or empty.

    :param file_path: path of the JSON cache file
    :return: the cache dict (empty dict for a fresh file)
    :raises LoadJsonError: when the file cannot be read or parsed
    """
    try:
        content = read_file(file_path)
        if content:
            return json.loads(content)
        # Fresh/empty file: initialize it on disk and return an empty cache
        # instead of falling off the end and returning None (previous bug).
        save_file(json.dumps({}, ensure_ascii=False, indent=4), file_path)
        return {}
    except Exception as e:
        raise LoadJsonError("在读取缓存文件时出现问题!程序已停止运行,请自行检查问题所在") from e
20 |
21 |
def set_cache(file_path, cache, data: dict, key: str):
    """Store *data* under *key* in *cache* and persist the cache to disk.

    :param file_path: path of the JSON cache file
    :param cache: in-memory cache dict, mutated in place
    :param data: payload to store for the key
    :param key: cache key
    """
    # The previous `if key not in cache: cache[key] = {}` was dead code:
    # the value was unconditionally overwritten on the next line.
    cache[key] = data
    save_file(json.dumps(cache, ensure_ascii=False, indent=4), file_path)
27 |
28 |
def get_up_file(file_path):
    """Read the UP list JSON file and return the entries under ``all_area``.

    :param file_path: path of the UP list JSON file
    :return: the value stored under the ``all_area`` key
    """
    with open(file_path, encoding="utf-8") as fp:
        data = json.load(fp)
    return data["all_area"]
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 yanyao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/utils/merge_config.py:
--------------------------------------------------------------------------------
1 | """在config模板发生更新时 及时无损合并用户的config"""
2 |
3 | import ruamel.yaml
4 |
5 | # TODO 保持原有格式和注释
6 |
7 | yaml = ruamel.yaml.YAML()
8 |
9 |
def load_config(config_path):
    """Load a YAML config file via the module-level round-trip loader.

    :param config_path: path of the YAML file
    :return: the parsed config mapping
    """
    with open(config_path, encoding="utf-8") as fp:
        return yaml.load(fp)
15 |
16 |
def is_have_diff(config, template):
    """Return True when the user config is missing any key present in *template*.

    Nested dicts are compared recursively. A template mapping whose counterpart
    in the user config is not a dict (e.g. left empty -> None) also counts as a
    diff, instead of crashing with a TypeError as before.

    :param config: user config dict
    :param template: reference template dict
    :return: True when anything from the template is missing/mismatched
    """
    for key in template:
        if key not in config:
            return True
        if isinstance(template[key], dict):
            # Malformed user value where a mapping is expected -> needs merging.
            if not isinstance(config[key], dict):
                return True
            if is_have_diff(config[key], template[key]):
                return True
    return False
25 |
26 |
def merge_config(config, template):
    """Merge missing template keys (with their default values) into *config*.

    :param config: user config dict, mutated in place
    :param template: reference template dict
    :return: the merged *config* (same object)
    """
    for key in template:
        if key not in config:
            config[key] = template[key]
        # Recurse only when both sides are mappings; a user-provided scalar in
        # place of a mapping is left untouched instead of raising TypeError.
        if isinstance(template[key], dict) and isinstance(config[key], dict):
            merge_config(config[key], template[key])
    return config
35 |
36 |
def save_config(config, config_path):
    """Write *config* back to *config_path* as YAML.

    :param config: config mapping to persist
    :param config_path: destination YAML file path
    """
    with open(config_path, "w", encoding="utf-8") as fp:
        yaml.dump(config, fp)
41 |
--------------------------------------------------------------------------------
/ci/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim-buster as base
2 |
3 | WORKDIR /usr/src/app
4 |
5 | COPY ../requirements-docker.txt ./
6 |
7 | RUN mkdir -p /clone-data/temp /clone-data/whisper-models /clone-data/statistics \
8 | && touch /clone-data/cache.json && touch /clone-data/records.json && touch /clone-data/queue.json
9 |
10 | RUN apt-get update \
11 | # && apt-get install -y ffmpeg fonts-wqy-zenhei \
12 | && apt-get install -y ffmpeg \
13 | && apt-get clean \
14 | && pip install --no-cache-dir -r requirements-docker.txt
15 |
16 | COPY .. ./
17 | COPY ../config/docker_config.yml /clone-data/config.yml
18 |
19 | ENV DOCKER_CONFIG_FILE=/data/config.yml
20 | ENV DOCKER_CACHE_FILE=/data/cache.json
21 | ENV DOCKER_TEMP_DIR=/data/temp
22 | ENV DOCKER_WHISPER_MODELS_DIR=/data/whisper-models
23 | ENV DOCKER_QUEUE_DIR=/data/queue.json
24 | ENV DOCKER_RECORDS_DIR=/data/records.json
25 | ENV DOCKER_STATISTICS_DIR=/data/statistics
26 | ENV DOCKER_UP_FILE=/data/up.json
27 | ENV DOCKER_UP_VIDEO_CACHE=/data/video_cache.json
28 | ENV RUNNING_IN_DOCKER yes
29 |
30 | FROM base as with_whisper
31 | ENV ENABLE_WHISPER yes
32 | RUN pip install --no-cache-dir openai-whisper
33 |
34 | CMD ["python", "main.py"]
35 |
36 | FROM base as without_whisper
37 | ENV ENABLE_WHISPER no
38 | CMD ["python", "main.py"]
39 |
--------------------------------------------------------------------------------
/safe_update.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import shutil
4 | import traceback
5 |
6 | from src.utils.file_tools import read_file, save_file
7 | from src.utils.logging import LOGGER
8 |
9 | _LOGGER = LOGGER.bind(name="safe-update")
10 |
11 |
def merge_cache_to_new_version(cache_file_path: str) -> bool:
    """Migrate an old-format cache file to the chain-keyed layout.

    Old caches predate the multi-chain era, so the whole file maps directly to
    what is now the "summarize" chain. A ``.bak`` backup is written before
    converting.

    :param cache_file_path: path of the cache JSON file
    :return: True when the cache is empty, already new-format, or was converted;
             False when conversion failed
    """
    content = read_file(cache_file_path)
    try:
        content_dict: dict = json.loads(content)
        # Guard clauses replace the previous triple `return True` nesting.
        if not content_dict:
            return True
        if content_dict.get("summarize") is not None or content_dict.get("ask_ai") is not None:
            # A chain key exists -> already new-format, nothing to do.
            return True
        _LOGGER.warning("缓存似乎是旧版的,尝试转换为新格式")
        _LOGGER.debug(f"备份老版缓存到{cache_file_path}.bak")
        shutil.copy(cache_file_path, cache_file_path + ".bak")
        new_content_dict = {
            # Old caches only ever held summarize results, so they map 1:1.
            "summarize": copy.deepcopy(content_dict),
            "ask_ai": {},
        }
        save_file(
            json.dumps(new_content_dict, ensure_ascii=False, indent=4),
            cache_file_path,
        )
        _LOGGER.info("转换缓存完成!")
        return True
    except Exception:
        traceback.print_exc()
        _LOGGER.error("在尝试转换缓存文件时出现问题!请自行查看!")
        return False
40 |
--------------------------------------------------------------------------------
/DEV_README.md:
--------------------------------------------------------------------------------
1 | # 简单的开发文档和各模块设计思路
2 |
3 | well, 我的代码大家有目共睹,有、烂,所以这个大家看看乐呵乐呵就好
4 |
5 | ## 设计思路
6 |
7 | ### ASR和LLM的路由
8 |
9 | (我姑且在下文把各种拓展的llm和asr称为“插件”)
10 |
11 | 1. 让用户在配置文件中填写各个的优先级,和启用状态
12 | 2. 调度器在初始化时,根据优先级和启用状态,生成一个路由字典,字典中的每个插件形如:`{"asr_alias": {
13 | "priority": priority,
14 | "enabled": enabled,
15 | "prepared": False,
16 | "err_times": 0,
17 | "obj": asr_object
18 | }, ...}`
19 | 3. 当其他模块需要使用时,调用路由器的`get_one()`方法,返回一个插件实例(如果该插件没被初始化,即`prepared=False`
20 | ,就先进行初始化再传回实例)
21 | 4. 当某个插件出错时,调用路由器的`report_error()`方法,将该插件的`err_times`加一,如果`err_times`
22 | 超过了阈值(现在为10),就将该插件的`enabled`置为False
23 | 5. 当所有插件都不可用时,调用路由器的`get_one()`方法,会返回None,此时其他模块会通过设置event直接关闭整个程序
24 |
25 | ### 采用依赖注入
26 |
27 | 在之前的代码中,出现了很多各种实例来回传递的情况,虽然都是在`main.py`中进行的,不至于太难维护。但确实有够难看...
28 |
29 | 后来在看其他项目时,了解到python竟然也有inject这种依赖注入包,于是就在代码重构时将项目全部变成依赖注入。具体的代码可以去`core/app.py`
30 | 中看,这里就不赘述了。
31 |
32 | ### 处理链和视频状态记录的实现
33 |
34 | 处理链说实话不太好抽象,不同操作实现的差异性太多了。但是也是有共性的,比如既然面对的对象是b站视频,所以一定会对字幕、视频基本信息进行操作,所以封装了一下这些操作,让处理链的实现者只需要关注自己的操作就好了。
35 |
36 | 现在的视频状态记录还是有很大的问题,因为原来的视频记录就是为了完全适配“视频总结”这个功能设计的,没考虑到后期可能会支持其他操作,所以很多的字段都无法共用,后期等待重构吧......
37 |
38 | 一直说要实现一些新的处理链,太忙了一直没弄,有空了试一下,看看我这个设计是不是一坨答辩
39 |
40 | ## 开发文档
41 |
42 | ### ASR及LLM插件开发文档
43 |
44 | 应该实现什么我在`llm_base.py`和`asr_base.py`中都写了,自己去看吧
45 |
46 | 路由器在处理时默认将你这个插件的类名转为下划线命名后作为alias(eg.
47 | 你类名为LocalWhisper,那生成的alias就为local_whisper),也会依据这个alias去config中寻找配置文件,所以你的配置文件和类名一定要一致。
48 |
别忘了在`src/models/config.py`中实现你插件的配置模型,并模仿其他插件在主Config中引用
50 |
51 | ### 处理链开发文档
52 |
53 | 好吧,我实话实说,我还没试过抽象后的处理链到底好不好添加功能。
54 |
不过我能确定的一点是现在处理链还没实现热插拔,如果你要实现新功能,需要修改`listener/bili_listen.py`和`main.py`,仿照着摘要处理链进行修改。
56 |
57 | `base_chain.py`这个基类起码注释是挺完善了,希望你顺利~
--------------------------------------------------------------------------------
/src/utils/file_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import traceback
3 |
4 | from src.utils.logging import LOGGER
5 |
6 | _LOGGER = LOGGER
7 |
8 |
def read_file(file_path: str, mode: str = "r", encoding: str = "utf-8") -> str:
    """
    Read a file, creating it (and any missing parent directories) when absent.

    :param file_path: path of the file to read
    :param mode: open mode (default read-only text)
    :param encoding: text encoding (default utf-8)
    :return: file content, or "" when the file had to be created or reading failed
    """
    dir_path = os.path.dirname(file_path)
    try:
        # Guard against dir_path == "" for a bare filename: os.makedirs("")
        # raises FileNotFoundError, which previously sent an EXISTING file down
        # the except branch and truncated it.
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
        with open(file_path, mode, encoding=encoding) as file:
            return file.read()
    except FileNotFoundError:
        # File genuinely missing: create it empty, then read it back so the
        # return type matches the requested mode.
        with open(file_path, "w", encoding=encoding):
            pass
        with open(file_path, mode, encoding=encoding) as created:
            return created.read()
    except Exception:
        _LOGGER.error("在读取文件时发生意料外的问题,返回空值")
        traceback.print_exc()
        return ""
32 |
33 |
def save_file(content: str, file_path: str, mode: str = "w", encoding: str = "utf-8") -> bool:
    """
    Write *content* to a file, creating any missing parent directories.

    :param content: text to write
    :param file_path: destination path
    :param mode: open mode (default "w")
    :param encoding: text encoding (default utf-8)
    :return: True on success, False when writing failed
    """
    dir_path = os.path.dirname(file_path)
    try:
        # Skip makedirs for a bare filename (dir_path == "" would raise).
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
        with open(file_path, mode, encoding=encoding) as file:
            file.write(content)
        return True
    except Exception:
        # The old message wrongly said "reading a file"; this is the save path.
        _LOGGER.error("在保存文件时发生意料外的问题")
        traceback.print_exc()
        return False
54 |
--------------------------------------------------------------------------------
/src/asr/asr_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import re
3 | from typing import Optional
4 |
5 | from src.core.routers.llm_router import LLMRouter
6 | from src.models.config import Config
7 |
8 |
class ASRBase:
    """Base class every ASR implementation must inherit from."""

    def __new__(cls, *args, **kwargs):
        # Derive the plugin alias from the class name (CamelCase -> snake_case),
        # e.g. LocalWhisper -> local_whisper. The router uses this alias to look
        # up the plugin's section in the config.
        instance = super().__new__(cls)
        snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", cls.__name__)
        snake = re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake)
        instance.alias = snake.lower()
        return instance

    def __init__(self, config: Config, llm_router: LLMRouter):
        self.config = config
        self.llm_router = llm_router

    @abc.abstractmethod
    def prepare(self) -> None:
        """
        One-time preparation hook, e.g. loading a model.

        Called the **first** time this ASR is used. Takes no arguments: every
        setting must come from ``self.config``.
        """
        pass

    @abc.abstractmethod
    async def transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
        """
        Transcribe the audio file at *audio_path*.

        Prefer passing only the audio path; other settings should come from
        ``self.config``. The transcription must not block the event loop — if
        the backend is synchronous, implement ``_sync_transcribe`` and run it
        here through a **thread pool**. Return None only on failure: the caller
        counts failures and stops using this ASR once a threshold is reached.
        """
        pass

    def _sync_transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
        """
        Blocking transcription; implement only when the backend is synchronous.
        """
        pass

    async def after_process(self, text: str, **kwargs) -> str:
        """
        Optional post-processing, e.g. feeding the transcript back through an
        LLM (via ``self.llm_router``) to get higher-quality subtitles.
        The logic must stay asynchronous.

        Should return the original subtitle text when processing fails.
        """
        pass

    def __repr__(self):
        return f"<{self.alias} ASR>"

    def __str__(self):
        return self.__class__.__name__
65 |
--------------------------------------------------------------------------------
/.github/workflows/push_to_docker.yaml:
--------------------------------------------------------------------------------
1 | name: Release Docker Image
2 | run-name: Release Docker Image
3 |
4 | on:
5 | workflow_dispatch:
6 | inputs:
7 | version:
8 | description: '版本'
9 | required: false
10 | push:
11 | paths:
12 | - 'VERSION'
13 |
14 | jobs:
15 | docker:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v3
20 |
21 | - name: Read version from VERSION file
22 | id: file_version
23 | run: echo "VERSION=$(cat VERSION)" >> $GITHUB_ENV
24 |
25 | - name: Set version
26 | id: set_version
        run: |
          if [ "${{ github.event.inputs.version }}" == "" ]; then
            echo "version=$VERSION" >> "$GITHUB_OUTPUT"
          else
            echo "version=${{ github.event.inputs.version }}" >> "$GITHUB_OUTPUT"
          fi
33 | - name: Set up QEMU
34 | uses: docker/setup-qemu-action@v3
35 |
36 | - name: Set up Docker Buildx
37 | uses: docker/setup-buildx-action@v3
38 |
39 | - name: Login to Docker Hub
40 | uses: docker/login-action@v3
41 | with:
42 | username: ${{ secrets.DOCKERHUB_USERNAME }}
43 | password: ${{ secrets.DOCKERHUB_TOKEN }}
44 |
45 | - name: Build and push with whisper
46 | uses: docker/build-push-action@v5
47 | with:
48 | push: true
49 | tags: |
50 | ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:with_whisper
51 | ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:${{ steps.set_version.outputs.version }}_whisper
52 | target: with_whisper
53 | file: ./ci/Dockerfile
54 |
55 | - name: Build and push without whisper
56 | uses: docker/build-push-action@v5
57 | with:
58 | push: true
59 | tags: |
60 | ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:latest
61 | ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:${{ steps.set_version.outputs.version }}
62 | target: without_whisper
63 | file: ./ci/Dockerfile
--------------------------------------------------------------------------------
/src/llm/gpt.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import traceback
3 | from functools import partial
4 | from typing import Tuple
5 |
6 | import openai
7 |
8 | from src.llm.llm_base import LLMBase
9 | from src.utils.logging import LOGGER
10 |
11 | _LOGGER = LOGGER.bind(name="openai_gpt")
12 |
13 |
class Openai(LLMBase):
    """LLM plugin backed by the OpenAI ChatCompletion API."""

    def prepare(self):
        """Bind the openai module and apply the endpoint/key from config."""
        self.openai = openai
        self.openai.api_base = self.config.LLMs.openai.api_base
        self.openai.api_key = self.config.LLMs.openai.api_key

    def _sync_completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
        """Blocking call to the OpenAI ChatCompletion API.

        :param prompt: messages already formatted for the OpenAI chat API
        :param kwargs: extra parameters forwarded to the API
        :return: (generated text, total token count) or None on failure
        """
        try:
            model = self.config.LLMs.openai.model
            resp = self.openai.ChatCompletion.create(model=model, messages=prompt, **kwargs)
            _LOGGER.debug(f"调用openai的Completion API成功,API返回结果为:{resp}")
            _LOGGER.info(
                f"调用openai的Completion API成功,本次调用中,prompt+response的长度为{resp['usage']['total_tokens']}"
            )
            resp_msg = resp["choices"][0]["message"]["content"]
            # Strip a markdown ```json fence the model sometimes wraps around JSON.
            resp_msg = resp_msg.removeprefix("```json").removesuffix("```")
            return resp_msg, resp["usage"]["total_tokens"]
        except Exception as e:
            _LOGGER.error(f"调用openai的Completion API失败:{e}")
            traceback.print_tb(e.__traceback__)
            return None

    async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
        """Async wrapper that runs the blocking call in the default executor.

        :param prompt: messages already formatted for the OpenAI chat API
        :param kwargs: extra parameters forwarded to the API
        :return: (generated text, total token count) or None on failure
        """
        # get_running_loop() is the non-deprecated way to obtain the loop
        # inside a coroutine (get_event_loop() is deprecated here since 3.10).
        loop = asyncio.get_running_loop()
        bound_func = partial(self._sync_completion, prompt, **kwargs)
        return await loop.run_in_executor(None, bound_func)
57 |
--------------------------------------------------------------------------------
/src/bilibili/bili_credential.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from apscheduler.schedulers.asyncio import AsyncIOScheduler
4 | from bilibili_api import Credential
5 | from injector import inject
6 |
7 | from src.utils.logging import LOGGER
8 |
9 | _LOGGER = LOGGER.bind(name="bilibili-credential")
10 |
11 |
class BiliCredential(Credential):
    """Bilibili credential that also schedules a periodic cookie-expiry check."""

    # noinspection PyPep8Naming,SpellCheckingInspection
    @inject
    def __init__(
        self,
        SESSDATA: str,
        bili_jct: str,
        buvid3: str,
        dedeuserid: str,
        ac_time_value: str,
        sched: AsyncIOScheduler,
    ):
        """
        Every cookie value is required so that the cookie refresh flow can work.

        :param SESSDATA: SESSDATA cookie value
        :param bili_jct: bili_jct cookie value
        :param buvid3: buvid3 cookie value
        :param dedeuserid: dedeuserid cookie value
        :param ac_time_value: ac_time_value cookie value
        :param sched: shared AsyncIO scheduler used to register the check job
        """
        super().__init__(
            sessdata=SESSDATA,
            bili_jct=bili_jct,
            buvid3=buvid3,
            dedeuserid=dedeuserid,
            ac_time_value=ac_time_value,
        )
        self.sched = sched

    async def _check_refresh(self):
        """
        Check whether the cookies have expired; refresh and re-validate them.
        """
        _LOGGER.debug("正在检查cookie是否过期")
        if await self.check_refresh():
            _LOGGER.info("cookie过期,正在刷新")
            await self.refresh()
            _LOGGER.info("cookie刷新成功")
        else:
            _LOGGER.debug("cookie未过期")

        if await self.check_valid():
            _LOGGER.debug("cookie有效")
        else:
            _LOGGER.warning("cookie刷新后依旧无效,请关注!")

    def start_check(self):
        """
        Register the recurring cookie-expiry check (every 12 hours, first run immediately).
        """
        self.sched.add_job(
            self._check_refresh,
            trigger="interval",
            hours=12,
            id="check_refresh",
            max_instances=3,
            next_run_time=datetime.now(),
        )
        # Fixed log text: the job runs every 12 hours, not every 60 seconds.
        _LOGGER.info("[定时任务]检查cookie是否过期定时任务注册成功,每12小时检查一次")
74 |
--------------------------------------------------------------------------------
/src/core/routers/chain_router.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from injector import inject
4 |
5 | from src.models.config import Config
6 | from src.models.task import AskAICommandParams, BiliGPTTask, Chains
7 | from src.utils.logging import LOGGER
8 | from src.utils.queue_manager import QueueManager
9 |
10 | _LOGGER = LOGGER.bind(name="chain-router")
11 |
12 |
class ChainRouter:
    """Routes an incoming task to the proper processing-chain queue by keyword."""

    @inject
    def __init__(self, config: Config, queue_manager: QueueManager):
        self.config = config
        self.queue_manager = queue_manager
        self.summarize_queue = None
        self.ask_ai_queue = None
        self._get_queues()

    def _get_queues(self):
        """Fetch the per-chain queues from the queue manager."""
        self.summarize_queue = self.queue_manager.get_queue("summarize")
        self.ask_ai_queue = self.queue_manager.get_queue("ask_ai")

    async def dispatch_a_task(self, task: BiliGPTTask):
        """Inspect the task's source command and enqueue it on the matching chain."""
        content: str = task.source_command
        _LOGGER.info(f"开始处理消息,原始消息内容为:{content}")
        summarize_keywords = self.config.chain_keywords.summarize_keywords
        ask_ai_keywords = self.config.chain_keywords.ask_ai_keywords

        keyword = next((kw for kw in summarize_keywords if kw in content), None)
        if keyword is not None:
            _LOGGER.info(f"检测到关键字 {keyword} ,放入【总结】队列")
            task.chain = Chains.SUMMARIZE
            _LOGGER.debug(task)
            await self.summarize_queue.put(task)
            return

        keyword = next((kw for kw in ask_ai_keywords if kw in content), None)
        if keyword is not None:
            _LOGGER.info(f"任务{task.uuid}:检测到关键字 {keyword} ,开始解析参数:冒号后的问题")
            # The question is everything after the first (full- or half-width) colon.
            found = re.search(r"[::](.*)", content)
            if found is None:
                _LOGGER.warning(f"任务{task.uuid}:没找到冒号,无法提取问题参数,跳过")
                return
            task.command_params = AskAICommandParams.model_validate({"question": found.group(1)})
            task.chain = Chains.ASK_AI
            _LOGGER.debug(task)
            await self.ask_ai_queue.put(task)
            return

        _LOGGER.debug("没有检测到关键字,跳过")
56 |
--------------------------------------------------------------------------------
/src/utils/cache.py:
--------------------------------------------------------------------------------
1 | """管理视频处理后缓存"""
2 |
3 | import json
4 |
5 | from src.utils.exceptions import LoadJsonError
6 | from src.utils.file_tools import read_file, save_file
7 | from src.utils.logging import LOGGER
8 |
9 | _LOGGER = LOGGER.bind(name="cache")
10 |
11 |
class Cache:
    """JSON-file-backed cache of processed-video results, keyed by chain then video."""

    def __init__(self, cache_path: str):
        """
        :param cache_path: path of the JSON cache file
        """
        self.cache_path = cache_path
        self.cache = {}
        self.load_cache()

    def load_cache(self):
        """Load the cache file; create it with an empty object when missing/empty.

        :raises LoadJsonError: when the file cannot be read or parsed
        """
        try:
            content = read_file(self.cache_path)
            if content:
                self.cache = json.loads(content)
            else:
                self.cache = {}
                self.save_cache()
        except Exception as e:
            raise LoadJsonError("在读取缓存文件时出现问题!程序已停止运行,请自行检查问题所在") from e

    def save_cache(self):
        """Persist the in-memory cache to disk as pretty-printed JSON."""
        save_file(json.dumps(self.cache, ensure_ascii=False, indent=4), self.cache_path)

    def get_cache(self, key: str, chain: str):
        """Return the cached value for *key* under *chain*, or None when absent."""
        return self.cache.get(chain, {}).get(key)

    def set_cache(self, key: str, value, chain: str):
        """Store *value* for *key* under *chain* and persist immediately."""
        if chain not in self.cache:
            self.cache[chain] = {}
        self.cache[chain][key] = value
        self.save_cache()

    def delete_cache(self, key: str):
        """Delete a top-level entry and persist; a missing key is a no-op.

        NOTE(review): unlike set/get this operates on the top level (i.e. it
        removes a whole chain, not a single video) — confirm against callers.
        """
        # pop with a default so a missing key no longer raises KeyError.
        self.cache.pop(key, None)
        self.save_cache()

    def clear_cache(self):
        """Drop every cached entry and persist the empty cache."""
        self.cache = {}
        self.save_cache()

    def get_all_cache(self):
        """Return the full cache dict."""
        return self.cache
82 |
--------------------------------------------------------------------------------
/config/example_config.yml:
--------------------------------------------------------------------------------
1 | ASRs:
2 | local_whisper: # 本地的whisper
3 | enable: false # 是否启用whisper,在视频无字幕时生成字幕,否则该视频将自动跳过
4 | priority: 50 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
5 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升
6 | device: cpu # cpu or cuda(仅在运行源代码时可用,docker运行只能选择cpu)
7 | model_dir: /data/whisper-models # 本地模型存放目录,如果更改要映射出来
8 | model_size: tiny # tiny, base, small, medium, large 详细选择请去:https://github.com/openai/whisper
9 |
10 | openai_whisper:
11 | enable: false # 是否启用openai whisper
12 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
13 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1)
14 | api_key: '' # 你的openai api key
15 | model: whisper-1 # 有且仅有这一个模型,不要改
16 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升
17 |
18 |
19 |
20 | LLMs:
21 | openai: # 对接gpt
22 | enable: true # 是否启用openai
23 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
24 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1)
25 | api_key: '' # 你的openai api key
26 | model: gpt-3.5-turbo-16k # 选择模型,我现在只推荐使用gpt-3.5-turbo-16k,其他模型容纳不了这么大的token,如果你有gpt-4-16k权限,还钱多,请自便
27 |
28 | aiproxy_claude: # 对接aiproxy claude(因为对接方式不同 只能用https://aiproxy.io这家的服务)
29 | enable: true # 是否启用claude
30 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
31 | api_base: https://api.aiproxy.io/ # 你的claude api base url(多数在使用第三方api供应商时会有)
32 | api_key: '' # 你的claude api key
33 | model: claude-instant-1 # 选择模型,claude-instant-1或claude-2
34 |
35 | spark: # 对接讯飞星火
36 | enable: true # 是否启用讯飞星火
37 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
38 | spark_url: wss://spark-api.xf-yun.com/v3.5/chat # 你的spark api base url(多数在使用第三方api供应商时会有)
39 | appid: '' # 你的appid
40 | api_key: '' # 你的api_key
41 | api_secret: '' # 你的api_secret
42 | domain: 'generalv3.5' # 要与spark_url对应
43 |
44 | bilibili_self:
45 | nickname: ''
46 |
47 | bilibili_cookie: # https://nemo2011.github.io/bilibili-api/#/get-credential 获取cookie 下面五个值都要填
48 | SESSDATA: ''
49 | ac_time_value: ''
50 | bili_jct: ''
51 | buvid3: ''
52 | dedeuserid: ''
53 |
54 |
55 | chain_keywords:
56 | summarize_keywords: # 用于生成评价的关键词,如果at/私信内容包含以下关键词,将会生成评价(该功能开发中)
57 | - "总结"
58 | - "总结一下"
59 | - "总结一下吧"
60 | - "总结一下吧!"
61 | ask_ai_keywords:
62 | - "问一下"
63 | - "请问"
64 |
65 |
66 | storage_settings:
67 | cache_path: /data/cache.json # 用于缓存已经处理过的视频,如果更改要映射出来
68 | statistics_dir: /data/statistics # 用于存放统计数据,如果更改要映射出来
69 | task_status_records: /data/records.json # 用于记录任务状态,如果更改要映射出来,不能留空
70 | queue_save_dir: /data/queue.json # 用于保存未完成的队列信息,下次运行时恢复
71 | temp_dir: /data/temp # 主要用于下载视频音频生成字幕,如果更改要映射出来
72 | up_video_cache: ./data/video_cache.json
73 | up_file: ./data/up.json
74 |
75 | debug_mode: true # 是否开启debug模式,开启后会打印更多日志,建议开启,以便于查找bug
--------------------------------------------------------------------------------
/config/docker_config.yml:
--------------------------------------------------------------------------------
1 | ASRs:
2 | local_whisper: # 本地的whisper
3 | enable: false # 是否启用whisper,在视频无字幕时生成字幕,否则该视频将自动跳过
4 | priority: 50 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
5 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升
6 | device: cpu # cpu or cuda(仅在运行源代码时可用,docker运行只能选择cpu)
7 | model_dir: /data/whisper-models # 本地模型存放目录,如果更改要映射出来
8 | model_size: tiny # tiny, base, small, medium, large 详细选择请去:https://github.com/openai/whisper
9 |
10 | openai_whisper:
11 | enable: false # 是否启用openai whisper
12 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
13 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1)
14 | api_key: '' # 你的openai api key
15 | model: whisper-1 # 有且仅有这一个模型,不要改
16 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升
17 |
18 |
19 |
20 | LLMs:
21 | openai: # 对接gpt
22 | enable: true # 是否启用openai
23 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
24 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1)
25 | api_key: '' # 你的openai api key
26 | model: gpt-3.5-turbo-16k # 选择模型,我现在只推荐使用gpt-3.5-turbo-16k,其他模型容纳不了这么大的token,如果你有gpt-4-16k权限,还钱多,请自便
27 |
28 | aiproxy_claude: # 对接aiproxy claude(因为对接方式不同 只能用https://aiproxy.io这家的服务)
29 | enable: true # 是否启用claude
30 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
31 | api_base: https://api.aiproxy.io/ # 你的claude api base url(多数在使用第三方api供应商时会有)
32 | api_key: '' # 你的claude api key
33 | model: claude-instant-1 # 选择模型,claude-instant-1或claude-2
34 |
35 | spark: # 对接讯飞星火
36 | enable: true # 是否启用讯飞星火
37 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的
38 | spark_url: wss://spark-api.xf-yun.com/v3.5/chat # 你的spark api base url(多数在使用第三方api供应商时会有)
39 | appid: '' # 你的appid
40 | api_key: '' # 你的api_key
41 | api_secret: '' # 你的api_secret
42 | domain: 'generalv3.5' # 要与spark_url对应
43 |
44 | bilibili_self:
45 | nickname: ''
46 |
47 | bilibili_cookie: # https://nemo2011.github.io/bilibili-api/#/get-credential 获取cookie 下面五个值都要填
48 | SESSDATA: ''
49 | ac_time_value: ''
50 | bili_jct: ''
51 | buvid3: ''
52 | dedeuserid: ''
53 |
54 |
55 | chain_keywords:
56 | summarize_keywords: # 用于生成评价的关键词,如果at/私信内容包含以下关键词,将会生成评价(该功能开发中)
57 | - "总结"
58 | - "总结一下"
59 | - "总结一下吧"
60 | - "总结一下吧!"
61 | ask_ai_keywords:
62 | - "问一下"
63 | - "请问"
64 |
65 |
66 | storage_settings:
67 | cache_path: /data/cache.json # 用于缓存已经处理过的视频,如果更改要映射出来
68 | statistics_dir: /data/statistics # 用于存放统计数据,如果更改要映射出来
69 | task_status_records: /data/records.json # 用于记录任务状态,如果更改要映射出来,不能留空
70 | queue_save_dir: /data/queue.json # 用于保存未完成的队列信息,下次运行时恢复
71 | temp_dir: /data/temp # 主要用于下载视频音频生成字幕,如果更改要映射出来
72 | up_video_cache: ./data/video_cache.json
73 | up_file: ./data/up.json
74 |
75 | debug_mode: true # 是否开启debug模式,开启后会打印更多日志,建议开启,以便于查找bug
76 |
--------------------------------------------------------------------------------
/src/utils/queue_manager.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import traceback
4 | from copy import deepcopy
5 |
6 | from src.models.task import BiliGPTTask
7 | from src.utils.file_tools import read_file, save_file
8 | from src.utils.logging import LOGGER
9 |
10 | _LOGGER = LOGGER
11 |
12 |
class QueueManager:
    """Queue manager: lazily creates named asyncio queues and can persist /
    restore their pending tasks across restarts."""

    def __init__(self):
        _LOGGER.debug("初始化队列管理器")
        self.queues = {}  # name -> asyncio.Queue of live tasks
        self.saved_queue = {}  # name -> list of tasks staged for persistence

    def get_queue(self, queue_name: str) -> asyncio.Queue:
        """Return the queue registered under ``queue_name``, creating it on first use."""
        if queue_name not in self.queues:
            _LOGGER.debug(f"正在创建{queue_name}队列")
            self.queues[queue_name] = asyncio.Queue()
        return self.queues.get(queue_name)

    def _save(self, file_path: str):
        """Serialize the staged tasks to ``file_path`` as pretty-printed JSON."""
        content = json.dumps(self.saved_queue, ensure_ascii=False, indent=4)
        save_file(content, file_path)

    def _load(self, file_path: str):
        """Load previously staged tasks from ``file_path`` into ``self.saved_queue``.

        Falls back to an empty store (without crashing) when the file is
        unreadable or contains invalid JSON.
        """
        try:
            content = read_file(file_path)
            # BUGFIX: the condition was inverted (`if not content`): an empty
            # file was fed to json.loads (which always raises on ""), while a
            # non-empty file was discarded and overwritten with `{}` — so saved
            # queues were never restored. Parse only when content exists.
            if content:
                self.saved_queue = json.loads(content)
            else:
                self.saved_queue = {}
                self._save(file_path)
        except Exception:
            _LOGGER.error("在读取已保存的队列文件中出现问题!暂时跳过恢复,但不影响使用,请自行检查!")
            traceback.print_exc()
            self.saved_queue = {}

    def safe_close_queue(self, queue_name: str, saved_json_path: str):
        """
        Safely close one queue, persisting its pending tasks.
        :param queue_name: name of the queue to drain
        :param saved_json_path: where to save the snapshot
        :return:
        """
        queue_list = []
        queue = self.get_queue(queue_name)
        while not queue.empty():
            item: BiliGPTTask = queue.get_nowait()
            # BUGFIX: pydantic task objects are not JSON-serializable; dump them
            # to plain dicts so _save()'s json.dumps cannot fail. Raw dicts
            # (e.g. tasks restored by recover_queue) pass through unchanged.
            queue_list.append(item.model_dump(mode="json") if hasattr(item, "model_dump") else item)
        _LOGGER.debug(f"共保存了{len(queue_list)}条数据!")
        self.saved_queue[queue_name] = queue_list
        self._save(saved_json_path)

    def safe_close_all_queues(self, saved_json_path: str):
        """
        Persist every queue (recommended over closing them one by one).
        :param saved_json_path: where to save the snapshot
        :return:
        """
        # BUGFIX: iterate the live queues, not self.saved_queue — the latter is
        # normally empty at shutdown, so nothing was ever saved.
        for queue in list(self.queues.keys()):
            _LOGGER.debug(f"保存{queue}队列的任务")
            self.safe_close_queue(queue, saved_json_path)

    def recover_queue(self, saved_json_path: str):
        """
        Restore tasks previously saved to file back into their queues.
        :param saved_json_path: snapshot file to restore from
        :return:
        """
        self._load(saved_json_path)
        _queue_dict = deepcopy(self.saved_queue)
        for queue_name in list(self.saved_queue.keys()):
            _LOGGER.debug(f"开始恢复{queue_name}")
            queue = self.get_queue(queue_name)
            for task in self.saved_queue[queue_name]:
                queue.put_nowait(task)
            # drop the recovered entry so it is not restored twice
            del _queue_dict[queue_name]
        self.saved_queue = _queue_dict
        self._save(saved_json_path)
86 |
--------------------------------------------------------------------------------
/src/llm/claude.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from typing import Tuple
3 |
4 | import anthropic
5 |
6 | from src.llm.llm_base import LLMBase
7 | from src.llm.templates import Templates
8 | from src.utils.logging import LOGGER
9 | from src.utils.prompt_utils import parse_prompt
10 |
11 | _LOGGER = LOGGER.bind(name="aiproxy_claude")
12 |
13 |
class AiproxyClaude(LLMBase):
    """Claude backend spoken through the aiproxy.io-compatible completions API."""

    def prepare(self):
        """Log initialization info with a partially masked api key."""
        mask_key = self.config.LLMs.aiproxy_claude.api_key[:-5] + "*****"
        _LOGGER.info(f"初始化AIProxyClaude,api_key为{mask_key},api端点为{self.config.LLMs.aiproxy_claude.api_base}")

    async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
        """Call claude's Completion API.
        :param prompt: input text (make sure it is formatted via use_template)
        :param kwargs: extra parameters
        :return: (generated text, total token count) or None on failure
        """
        cfg = self.config.LLMs.aiproxy_claude
        try:
            client = anthropic.AsyncAnthropic(api_key=cfg.api_key, base_url=cfg.api_base)
            resp = await client.completions.create(
                prompt=prompt,
                max_tokens_to_sample=1000,
                model=cfg.model,
                **kwargs,
            )
            _LOGGER.debug(f"调用claude的Completion API成功,API返回结果为:{resp}")
            total_tokens = resp.model_dump()["usage"]["total_tokens"]
            _LOGGER.info(
                f"调用claude的Completion API成功,本次调用中,prompt+response的长度为{total_tokens}"
            )
            answer = resp.completion
            if not answer.startswith('{"'):
                # claude continues directly from the assistant prefix we supplied,
                # so restore the JSON prefix it omits
                answer = '{"' + answer
            return answer, total_tokens
        except Exception as e:
            _LOGGER.error(f"调用claude的Completion API失败:{e}")
            traceback.print_tb(e.__traceback__)
            return None

    @staticmethod
    def use_template(
        user_template_name: Templates,
        system_template_name: Templates = None,
        user_keyword="user",
        system_keyword="system",
        **kwargs,
    ) -> str | None:
        """Build the claude-style HUMAN/AI prompt from the given templates."""
        try:
            utemplate = parse_prompt(user_template_name.value, **kwargs)
            stemplate = None
            if system_template_name:
                stemplate = parse_prompt(system_template_name.value, **kwargs)
            prompt = f"I will give you 'rules' 'content' two tags. You need to follow the rules! {utemplate} {stemplate}"
            prompt = anthropic.HUMAN_PROMPT + " " + prompt + " " + anthropic.AI_PROMPT + " " + '{"'
            _LOGGER.info("使用模板成功")
            _LOGGER.debug(f"生成的prompt为:{prompt}")
            return prompt
        except Exception as e:
            _LOGGER.error(f"使用模板失败:{e}")
            traceback.print_exc()
            return None
74 |
--------------------------------------------------------------------------------
/src/bilibili/bili_session.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Union
3 |
4 | import tenacity
5 | from bilibili_api import session
6 | from bilibili_api.session import EventType
7 | from injector import inject
8 |
9 | from src.bilibili.bili_credential import BiliCredential
10 | from src.bilibili.bili_video import BiliVideo
11 | from src.models.task import AskAIResponse, BiliGPTTask, SummarizeAiResponse
12 | from src.utils.callback import chain_callback
13 | from src.utils.logging import LOGGER
14 |
15 | _LOGGER = LOGGER.bind(name="bilibili-session")
16 |
17 |
class BiliSession:
    """Helpers for replying to bilibili private messages."""

    @inject
    def __init__(self, credential: BiliCredential, private_queue: asyncio.Queue):
        """
        Initialize the session handler.

        :param credential: bilibili credential
        :param private_queue: queue of private-message tasks to reply to
        """
        self.credential = credential
        self.private_queue = private_queue

    @staticmethod
    async def quick_send(credential, task: BiliGPTTask, msg: str):
        """Send a single private message to the task's sender."""
        await session.send_msg(credential, int(task.sender_id), EventType.TEXT, msg)

    @staticmethod
    def build_reply_content(response: Union[SummarizeAiResponse, AskAIResponse]) -> list:
        """Build the reply messages.

        Returned as a list because long private messages have been truncated by
        bilibili before, so the reply is split into several messages.
        TODO: the exact length limit is unknown; the isinstance dispatch below
        is admittedly inelegant.
        """
        if isinstance(response, SummarizeAiResponse):
            return [
                f"【视频摘要】{response.summary}",
                f"【视频评分】{response.score}分\n\n【咱还想说】{response.thinking}",
            ]
        if isinstance(response, AskAIResponse):
            return [f"【回答】{response.answer}\n\n【自我评分】{response.score}分"]
        return [f"程序内部错误:无法识别的回复类型{type(response)}"]

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(10),
        before_sleep=chain_callback,
    )
    async def start_private_reply(self):
        """Consume the private-message queue forever, replying to each task."""
        while True:
            try:
                task: BiliGPTTask = await self.private_queue.get()
                _LOGGER.debug("获取到新的私信任务,开始处理")
                _, _type = await BiliVideo(credential=self.credential, url=task.video_url).get_video_obj()
                for msg in BiliSession.build_reply_content(task.process_result):
                    await session.send_msg(self.credential, int(task.sender_id), EventType.TEXT, msg)
                await asyncio.sleep(3)
            except asyncio.CancelledError:
                _LOGGER.info("私信处理链关闭")
                return
82 |
--------------------------------------------------------------------------------
/src/asr/local_whisper.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import functools
3 | import time
4 | import traceback
5 | from typing import Optional
6 |
7 | from src.asr.asr_base import ASRBase
8 | from src.core.routers.llm_router import LLMRouter
9 | from src.llm.templates import Templates
10 | from src.models.config import Config
11 | from src.utils.logging import LOGGER
12 |
13 | _LOGGER = LOGGER.bind(name="LocalWhisper")
14 |
15 |
class LocalWhisper(ASRBase):
    """Local whisper ASR backend; inference runs in the default thread pool."""

    def __init__(self, config: Config, llm_router: LLMRouter):
        super().__init__(config, llm_router)
        self.llm_router = llm_router
        self.model = None  # loaded lazily by prepare()
        self.config = config

    def prepare(self) -> None:
        """
        Load the whisper model described in the config.
        :return: None
        """
        cfg = self.config.ASRs.local_whisper
        _LOGGER.info(
            f"正在加载whisper模型,模型大小{self.config.ASRs.local_whisper.model_size},设备{self.config.ASRs.local_whisper.device}"
        )
        # imported here so the heavy dependency is only pulled in when enabled
        import whisper as whi

        self.model = whi.load_model(cfg.model_size, cfg.device, download_root=cfg.model_dir)
        _LOGGER.info("加载whisper模型成功")

    async def after_process(self, text, **kwargs) -> str:
        """Refine the transcription with an LLM; falls back to the raw text on failure."""
        llm = self.llm_router.get_one()
        prompt = llm.use_template(Templates.AFTER_PROCESS_SUBTITLE, subtitle=text)
        answer, _ = await llm.completion(prompt)
        if answer is None:
            _LOGGER.error("后处理失败,返回原字幕")
            return text
        return answer

    def _sync_transcribe(self, audio_path, **kwargs) -> Optional[str]:
        """Blocking transcription; returns None when no model is loaded or on error."""
        try:
            started = time.perf_counter()
            _LOGGER.info(f"开始转写 {audio_path}")
            if self.model is None:
                return None
            import whisper as whi

            transcription = whi.transcribe(self.model, audio_path)["text"]
            _LOGGER.debug("转写成功")
            _LOGGER.info(f"字幕转译完成,共用时{time.perf_counter() - started}s")
            return transcription
        except Exception as e:
            _LOGGER.error(f"转写失败,错误信息为{e}", exc_info=True)
            return None

    async def transcribe(self, audio_path, **kwargs) -> Optional[str]:
        """Run transcription in an executor; optionally LLM-post-process the result."""
        loop = asyncio.get_event_loop()
        job = functools.partial(self._sync_transcribe, audio_path, **kwargs)
        result = await loop.run_in_executor(None, job)
        w = self.config.ASRs.local_whisper
        try:
            if not w.after_process or result is None:
                return result
            bt = time.perf_counter()
            _LOGGER.info("正在进行后处理")
            text = await self.after_process(result)
            _LOGGER.debug(f"后处理完成,用时{time.perf_counter()-bt}s")
            return text
        except Exception as e:
            # best-effort: on post-processing failure, fall back to the raw transcript
            _LOGGER.error(f"后处理失败,错误信息为{e}")
            traceback.print_exc()
            return result
86 |
--------------------------------------------------------------------------------
/src/llm/llm_base.py:
--------------------------------------------------------------------------------
1 | """llm对接的基础类"""
2 |
3 | import abc
4 | import re
5 | import traceback
6 | from typing import Tuple
7 |
8 | from src.llm.templates import Templates
9 | from src.models.config import Config
10 | from src.utils.logging import LOGGER
11 | from src.utils.prompt_utils import build_openai_style_messages, parse_prompt
12 |
13 | _LOGGER = LOGGER.bind(name="llm_base")
14 |
15 |
class LLMBase:
    """Base class: subclass this to plug another LLM backend into the router."""

    def __init__(self, config: Config):
        self.config = config

    def __new__(cls, *args, **kwargs):
        # Derive a snake_case alias from the subclass name
        # (e.g. AiproxyClaude -> aiproxy_claude); used as the config/router key.
        instance = super().__new__(cls)
        partial = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", cls.__name__)
        instance.alias = re.sub("([a-z0-9])([A-Z])", r"\1_\2", partial).lower()
        return instance

    def prepare(self):
        """
        One-time initialization hook (set parameters, clients, ...).
        Called the FIRST time this class is actually used.
        Takes no arguments: everything should come from self.config.
        """
        pass

    @abc.abstractmethod
    async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
        """Generate text with the LLM (catch your own errors and return None).
        Must be fully async, otherwise the whole program blocks.
        :param prompt: final input text, already formatted
        :param kwargs: extra parameters
        :return: (generated text, total token count) or None
        """
        pass

    def _sync_completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
        """Implement here when the API is synchronous, then call it from
        completion() via a thread pool.
        :param prompt: final input text, already formatted
        :param kwargs: extra parameters
        :return: (generated text, total token count) or None
        """
        pass

    @staticmethod
    def use_template(
        user_template_name: Templates,
        system_template_name: Templates = None,
        user_keyword="user",
        system_keyword="system",
        **kwargs,
    ) -> list | None:
        """Build the final prompt from templates (defaults to openai's
        system/user message format; adjust keywords per backend).
        :param user_template_name: user template
        :param system_template_name: system template
        :param user_keyword: role keyword for user messages
        :param system_keyword: role keyword for system messages
        :param kwargs: template parameters
        :return: the generated prompt, or None on failure
        """
        try:
            utemplate = parse_prompt(user_template_name.value, **kwargs)
            stemplate = None
            if system_template_name:
                stemplate = parse_prompt(system_template_name.value, **kwargs)
            if stemplate:
                prompt = build_openai_style_messages(utemplate, stemplate, user_keyword, system_keyword)
            else:
                prompt = build_openai_style_messages(utemplate, user_keyword=user_keyword)
            _LOGGER.info("使用模板成功")
            _LOGGER.debug(f"生成的prompt为:{prompt}")
            return prompt
        except Exception as e:
            _LOGGER.error(f"使用模板失败:{e}")
            traceback.print_exc()
            return None

    def __repr__(self):
        return self.alias

    def __str__(self):
        return self.__class__.__name__
96 |
--------------------------------------------------------------------------------
/src/bilibili/bili_video.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from bilibili_api import ResourceType, parse_link, video
4 | from injector import inject
5 |
6 | from src.bilibili.bili_credential import BiliCredential
7 |
8 |
class BiliVideo:
    """Async convenience wrapper around a bilibili_api video object."""

    @inject
    def __init__(
        self,
        credential: BiliCredential,
        bvid: str = None,
        aid: int = None,
        url: str = None,
    ):
        """
        Provide exactly one identifier; resolution priority is url > aid > bvid.
        :param credential: bilibili credential
        :param bvid: video bvid
        :param aid: video aid
        :param url: video url
        """
        self.credential = credential
        self._bvid = bvid
        self.aid = aid
        self.url = url
        self.video_obj: Optional[video.Video] = None

    async def get_video_obj(self):
        """Resolve and cache the underlying video object; returns (obj, resource_type)."""
        resource_type = ResourceType.VIDEO
        if self.video_obj:
            return self.video_obj, resource_type
        if self.url:
            self.video_obj, resource_type = await parse_link(self.url, credential=self.credential)
        elif self.aid:
            self.video_obj = video.Video(aid=self.aid, credential=self.credential)
        elif self._bvid:
            self.video_obj = video.Video(bvid=self._bvid, credential=self.credential)
        else:
            raise ValueError("缺少必要参数")
        return self.video_obj, resource_type

    async def _ensure_obj(self) -> video.Video:
        # Lazily resolve the video object before any accessor uses it.
        if not self.video_obj:
            await self.get_video_obj()
        return self.video_obj

    @property
    async def get_video_info(self):
        """Fetch the video's info dict."""
        obj = await self._ensure_obj()
        return await obj.get_info()

    @property
    async def get_video_pages(self):
        """Fetch the video's page list."""
        obj = await self._ensure_obj()
        return await obj.get_pages()

    async def get_video_tags(self, page_index: int = 0):
        """Fetch the tags of one page."""
        obj = await self._ensure_obj()
        return await obj.get_tags(page_index=page_index)

    async def get_video_download_url(self, page_index: int = 0):
        """Fetch the download url of one page."""
        obj = await self._ensure_obj()
        return await obj.get_download_url(page_index=page_index)

    async def get_video_subtitle(self, cid: int = None, page_index: int = 0):
        """Return a subtitle url, preferring human-made subtitles; falls back to
        AI/auto-translated ones, or None when the video has no subtitles."""
        obj = await self._ensure_obj()
        if not cid:
            cid = await obj.get_cid(page_index=page_index)
        info = await obj.get_player_info(cid=cid)
        subtitles = info["subtitle"]["subtitles"]
        if not subtitles:
            return None
        if len(subtitles) == 1:
            return subtitles[0]["subtitle_url"]
        machine_made = ("中文(自动翻译)", "中文(自动生成)")
        for subtitle in subtitles:
            if subtitle["lan_doc"] not in machine_made:
                return subtitle["subtitle_url"]
        for subtitle in subtitles:
            if subtitle["lan_doc"] in machine_made:
                return subtitle["subtitle_url"]

    @property
    async def bvid(self) -> str:
        """Return the video's bvid."""
        obj = await self._ensure_obj()
        return obj.get_bvid()

    @property
    async def format_title(self) -> str:
        """Return the video title wrapped in 『』."""
        obj = await self._ensure_obj()
        info = await obj.get_info()
        return f"『{info['title']}』"
98 |
--------------------------------------------------------------------------------
/src/utils/task_status_record.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import json
3 | from typing import Union
4 |
5 | from src.models.task import BiliGPTTask, Chains, ProcessStages
6 | from src.utils.exceptions import LoadJsonError
7 | from src.utils.file_tools import read_file, save_file
8 | from src.utils.logging import LOGGER
9 |
10 | _LOGGER = LOGGER.bind(name="task-status-record")
11 |
12 |
class TaskStatusRecorder:
    """Task status recorder: persists task records as JSON keyed by task uuid."""

    def __init__(self, file_path):
        self.file_path = file_path
        self.video_records = {}  # uuid (str) -> task record dict
        self.load()

    def load(self):
        """Load records from disk; an empty or missing file initializes an empty store.

        :raises LoadJsonError: when the records file cannot be read or parsed.
        """
        try:
            content = read_file(self.file_path)
            if content:
                self.video_records = json.loads(content)
            else:
                self.video_records = {}
                self.save()
        except Exception as e:
            raise LoadJsonError("在读取视频记录文件时出现问题!程序已停止运行,请自行检查问题所在") from e

    def save(self):
        """Write the current records to disk as pretty-printed JSON."""
        save_file(json.dumps(self.video_records, ensure_ascii=False, indent=4), self.file_path)

    def get_record_by_stage(
        self,
        chain: Chains,
        stage: ProcessStages = None,
    ):
        """
        Return all records belonging to ``chain``.
        When ``stage`` is None, records of every stage are returned.
        """
        return [
            record
            for record in self.video_records.values()
            if record["chain"] == chain.value
            and (stage is None or record["process_stage"] == stage.value)
        ]

    def create_record(self, item: BiliGPTTask):
        """Create one record; returns the task's uuid, usable with update_record."""
        self.video_records[str(item.uuid)] = item.model_dump(mode="json")
        self.save()
        return item.uuid

    def update_record(self, _uuid: str, new_task_data: Union[BiliGPTTask, None], **kwargs) -> bool:
        """Update the record stored under ``_uuid``.

        :param _uuid: record key (coerced to str)
        :param new_task_data: when given, replaces the whole record first
        :param kwargs: individual fields to update; enum values are unwrapped
        :return: False when no record exists for ``_uuid``, True otherwise
        """
        _uuid = str(_uuid)
        if new_task_data is not None:
            self.video_records[_uuid] = new_task_data.model_dump(mode="json")
        # BUGFIX: a missing uuid used to raise KeyError here instead of
        # returning False as the None-check clearly intended.
        record = self.video_records.get(_uuid)
        if record is None:
            return False
        for key, _value in kwargs.items():
            if isinstance(_value, enum.Enum):
                _value = _value.value
            if key == "process_stage":
                # kept from the original: process_stage is always applied, even
                # if the stored record somehow lacks the field
                record["process_stage"] = _value
            if key in record:
                record[key] = _value
            else:
                _LOGGER.warning(f"尝试更新不存在的字段:{key},跳过")
        self.save()
        return True

    def get_data_by_uuid(self, _uuid: str) -> BiliGPTTask:
        """Return the stored record for ``_uuid``.

        NOTE(review): this returns the raw stored dict, not a BiliGPTTask
        instance — confirm whether callers expect the annotated type.
        """
        return self.video_records[_uuid]
111 |
--------------------------------------------------------------------------------
/src/core/routers/llm_router.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import os
3 | import traceback
4 | from typing import Optional
5 |
6 | from injector import inject
7 |
8 | from src.llm.llm_base import LLMBase
9 | from src.models.config import Config
10 | from src.utils.logging import LOGGER
11 |
12 | _LOGGER = LOGGER.bind(name="LLM-Router")
13 |
14 |
class LLMRouter:
    """LLM router: loads all LLM subclasses and routes between them by priority."""

    @inject
    def __init__(self, config: Config):
        self.config = config
        self._llm_dict = {}
        # Error threshold after which a backend is disabled.
        # TODO: hard-coded, but exposing this option to users seems unnecessary
        self.max_err_times = 10

    def load_from_dir(self, py_style_path: str = "src.llm"):
        """
        Load every LLM subclass found in a directory.

        :param py_style_path: package-style import path, relative to the run directory
        :return: None
        """
        raw_path = "./" + py_style_path.replace(".", "/")
        for file_name in os.listdir(raw_path):
            if file_name.endswith(".py") and file_name != "__init__.py":
                module_name = file_name[:-3]
                module = __import__(f"{py_style_path}.{module_name}", fromlist=[module_name])
                for attr_name in dir(module):
                    attr = getattr(module, attr_name)
                    if inspect.isclass(attr) and issubclass(attr, LLMBase) and attr != LLMBase:
                        self.load(attr)

    def load(self, attr):
        """Load one LLM subclass (fixed copy-pasted "ASR" wording from ASRouter)."""
        try:
            _llm = attr(self.config)
            setattr(self, _llm.alias, _llm)
            _LOGGER.info(f"正在加载 {_llm.alias}")
            _config = self.config.model_dump(mode="json")["LLMs"][_llm.alias]
            priority = _config["priority"]
            enabled = _config["enable"]
            if priority is None or enabled is None:
                raise ValueError
            # register the backend's routing metadata
            self.llm_dict[_llm.alias] = {
                "priority": priority,
                "enabled": enabled,
                "prepared": False,
                "err_times": 0,
                "obj": self.get(_llm.alias),
            }
        except Exception as e:
            _LOGGER.error(f"加载 {str(attr)} 失败,错误信息为{e}")
            traceback.print_exc()
        else:
            _LOGGER.info(f"加载 {_llm.alias} 成功,优先级为{priority},启用状态为{enabled}")
            _LOGGER.debug(f"当前已加载的LLM子类有 {self.llm_dict}")

    @property
    def llm_dict(self):
        """All loaded LLM subclasses with their routing metadata."""
        return self._llm_dict

    def get(self, name):
        """Return one loaded LLM subclass by alias, or None."""
        try:
            return getattr(self, name)
        except Exception as e:
            _LOGGER.error(f"获取LLM子类失败,错误信息为{e}", exc_info=True)
            return None

    @llm_dict.setter
    def llm_dict(self, value):
        self._llm_dict = value

    def order(self):
        """
        Sort the LLM subclasses: higher priority first, disabled ones last.
        (Fixed copy-pasted "ASR" wording from ASRouter.)
        """
        self.llm_dict = dict(
            sorted(
                self.llm_dict.items(),
                key=lambda item: (
                    not item[1].get("enabled", True),
                    item[1]["priority"],
                ),
                reverse=True,
            )
        )

    def get_one(self) -> Optional[LLMBase]:
        """Return the highest-priority usable LLM subclass, or None when none is available."""
        self.order()
        for llm in self.llm_dict.values():
            # CONSISTENCY FIX: compare against self.max_err_times instead of a
            # second hard-coded 10 (report_error already uses the attribute).
            if llm["enabled"] and llm["err_times"] <= self.max_err_times:
                if not llm["prepared"]:
                    _LOGGER.info(f"正在初始化 {llm['obj'].alias}")
                    llm["obj"].prepare()
                    llm["prepared"] = True
                return llm["obj"]
        # parity with ASRouter.get_one, which logs when nothing is available
        _LOGGER.error("没有可用的LLM子类")
        return None

    def report_error(self, name: str):
        """Record one error for the named LLM subclass, disabling it at the threshold.

        :raises ValueError: when no LLM subclass with that alias is loaded
        """
        for llm in self.llm_dict.values():
            if llm["obj"].alias == name:
                llm["err_times"] += 1
                if llm["err_times"] >= self.max_err_times:
                    llm["enabled"] = False
                break
        else:
            raise ValueError(f"LLM子类 {name} 不存在")
        _LOGGER.info(f"{name} 发生错误,已累计错误{llm['err_times']}次")
        _LOGGER.debug(f"当前已加载的LLM类有 {self.llm_dict}")
123 |
--------------------------------------------------------------------------------
/src/core/routers/asr_router.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import os
3 | import traceback
4 | from typing import Optional
5 |
6 | from injector import inject
7 |
8 | from src.asr.asr_base import ASRBase
9 | from src.core.routers.llm_router import LLMRouter
10 | from src.models.config import Config
11 | from src.utils.logging import LOGGER
12 |
13 | _LOGGER = LOGGER.bind(name="ASR-Router")
14 |
15 |
16 | class ASRouter:
17 | """ASR路由器,用于加载所有ASR子类并进行合理路由"""
18 |
19 | @inject
20 | def __init__(self, config: Config, llm_router: LLMRouter):
21 | self.config = config
22 | self._asr_dict = {}
23 | self.llm_router = llm_router
24 | self.max_err_times = 10 # TODO i know i know,硬编码很不优雅,但这种选项开放给用户似乎也没必要
25 |
26 | def load_from_dir(self, py_style_path: str = "src.asr"):
27 | """
28 | 从一个文件夹中加载所有ASR子类
29 |
30 | :param py_style_path: 包导入风格的路径,以文件运行位置为基础路径
31 | :return: None
32 | """
33 | raw_path = "./" + py_style_path.replace(".", "/")
34 | for file_name in os.listdir(raw_path):
35 | if file_name.endswith(".py") and file_name != "__init__.py":
36 | module_name = file_name[:-3]
37 | module = __import__(f"{py_style_path}.{module_name}", fromlist=[module_name])
38 | for attr_name in dir(module):
39 | attr = getattr(module, attr_name)
40 | if inspect.isclass(attr) and issubclass(attr, ASRBase) and attr != ASRBase:
41 | self.load(attr)
42 |
43 | def load(self, attr):
44 | """加载一个ASR子类"""
45 | try:
46 | _asr = attr(self.config, self.llm_router)
47 | setattr(self, _asr.alias, _asr)
48 | _LOGGER.info(f"正在加载 {_asr.alias}")
49 | _config = self.config.model_dump(mode="json")["ASRs"][_asr.alias]
50 | priority = _config["priority"]
51 | enabled = _config["enable"]
52 | if priority is None or enabled is None:
53 | raise ValueError
54 | # 设置属性
55 | self.asr_dict[_asr.alias] = {
56 | "priority": priority,
57 | "enabled": enabled,
58 | "prepared": False,
59 | "err_times": 0,
60 | "obj": self.get(_asr.alias),
61 | }
62 | except Exception as e:
63 | _LOGGER.error(f"加载 {str(attr)} 失败,错误信息为{e}")
64 | traceback.print_exc()
65 | else:
66 | _LOGGER.info(f"加载 {_asr.alias} 成功,优先级为{priority},启用状态为{enabled}")
67 | _LOGGER.debug(f"当前已加载的ASR子类有 {self.asr_dict}")
68 |
69 | @property
70 | def asr_dict(self):
71 | """获取所有已加载的ASR子类"""
72 | return self._asr_dict
73 |
74 | def get(self, name):
75 | """获取一个已加载的ASR子类"""
76 | try:
77 | return getattr(self, name)
78 | except Exception as e:
79 | _LOGGER.error(f"获取ASR子类失败,错误信息为{e}", exc_info=True)
80 | return None
81 |
82 | @asr_dict.setter
83 | def asr_dict(self, value):
84 | self._asr_dict = value
85 |
86 | def order(self):
87 | """
88 | 对ASR子类进行排序
89 | 优先级高的排在前面,未启用的排在最后"""
90 | self.asr_dict = dict(
91 | sorted(
92 | self.asr_dict.items(),
93 | key=lambda item: (
94 | not item[1].get("enabled"),
95 | item[1]["priority"],
96 | ),
97 | reverse=True,
98 | )
99 | )
100 |
101 | def get_one(self) -> Optional[ASRBase]:
102 | """根据优先级获取一个可用的ASR子类,如果所有都不可用则返回None"""
103 | self.order()
104 | for asr in self.asr_dict.values():
105 | if asr["enabled"] and asr["err_times"] <= 10:
106 | if not asr["prepared"]:
107 | _LOGGER.info(f"正在初始化 {asr['obj'].alias}")
108 | asr["obj"].prepare()
109 | asr["prepared"] = True
110 | return asr["obj"]
111 | LOGGER.error("没有可用的ASR子类")
112 | return None
113 |
    def report_error(self, name: str):
        """Record one failure for the named ASR; disable it past the threshold.

        :param name: alias of the ASR subclass that failed
        :raises ValueError: if no loaded ASR matches ``name``
        """
        for asr in self.asr_dict.values():
            if asr["obj"].alias == name:
                asr["err_times"] += 1
                # Too many accumulated failures -> take this ASR out of rotation.
                if asr["err_times"] >= self.max_err_times:
                    asr["enabled"] = False
                break
        else:
            # for/else: the loop finished without break, so the alias is unknown.
            raise ValueError(f"ASR子类 {name} 不存在")
        _LOGGER.info(f"{name} 发生错误,已累计错误{asr['err_times']}次")
        _LOGGER.debug(f"当前已加载的ASR子类有 {self.asr_dict}")
126 |
--------------------------------------------------------------------------------
/src/asr/openai_whisper.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import functools
3 | import json
4 | import os
5 | import time
6 | import traceback
7 | import uuid
8 | from typing import Optional
9 |
10 | import openai
11 | from pydub import AudioSegment
12 |
13 | from src.asr.asr_base import ASRBase
14 | from src.llm.templates import Templates
15 | from src.utils.logging import LOGGER
16 |
17 | _LOGGER = LOGGER.bind(name="OpenaiWhisper")
18 |
19 |
class OpenaiWhisper(ASRBase):
    """ASR backend that sends audio to OpenAI's hosted whisper-1 API."""

    def prepare(self) -> None:
        """Log the configured endpoint with a masked API key."""
        apikey = self.config.ASRs.openai_whisper.api_key
        # Mask the key tail so the full secret never reaches the logs.
        apikey = apikey[:-5] + "*****"
        _LOGGER.info(f"初始化OpenaiWhisper,api_key为{apikey},api端点为{self.config.ASRs.openai_whisper.api_base}")

    def _cut_audio(self, audio_path: str) -> list[str]:
        """Split the audio into 300s segments with a 5s overlap window each side.

        :param audio_path: path of the source mp3 file
        :return: file names (relative to the temp dir) of the exported segments
        """
        temp = self.config.storage_settings.temp_dir
        audio = AudioSegment.from_file(audio_path, "mp3")
        segment_length = 300 * 1000  # pydub slices are in milliseconds
        window_length = 5 * 1000
        start_time = 0
        output_segments = []
        export_file_list = []
        # Unique prefix so concurrent transcriptions cannot collide in temp dir.
        _uuid = uuid.uuid4()

        while start_time < len(audio):
            segment = audio[start_time : start_time + segment_length]

            # Prepend/append the overlap windows so words cut at a segment
            # boundary still appear complete in at least one segment.
            if start_time > 0:
                segment = audio[start_time - window_length : start_time] + segment

            if start_time + segment_length < len(audio):
                _LOGGER.debug(f"正在处理{start_time}到{start_time+segment_length}的音频")
                segment = segment + audio[start_time + segment_length : start_time + segment_length + window_length]

            output_segments.append(segment)
            start_time += segment_length

        for num, segment in enumerate(output_segments):
            with open(f"{temp}/{_uuid}_segment_{num}.mp3", "wb") as file:
                segment.export(file, format="mp3")
                _LOGGER.debug(f"第{num}个切片导出完成")
            export_file_list.append(f"{_uuid}_segment_{num}.mp3")

        return export_file_list

    def _sync_transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
        """Blocking call to OpenAI's transcription API for one segment.

        :param audio_path: path of the segment to transcribe
        :param kwargs: extra arguments forwarded to openai.Audio.transcribe
        :return: the recognized text, or None when no text can be extracted
        """
        _LOGGER.debug(f"正在识别{audio_path}")
        # NOTE(review): this mutates the module-level openai client settings;
        # concurrent use of other openai-backed components could race — confirm.
        openai.api_key = self.config.ASRs.openai_whisper.api_key
        openai.api_base = self.config.ASRs.openai_whisper.api_base
        with open(audio_path, "rb") as audio:
            # Bug fix: **kwargs was documented as forwarded but never passed on.
            response = openai.Audio.transcribe(model="whisper-1", file=audio, **kwargs)

        _LOGGER.debug(f"返回内容为{response}")

        if isinstance(response, dict) and "text" in response:
            return response["text"]
        try:
            # Some endpoints apparently return the JSON body as a raw string.
            response = json.loads(response)
            return response["text"]
        except Exception:
            _LOGGER.error("返回内容不是字典或者没有text字段,返回None")
            return None

    async def transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
        """Transcribe a full audio file by fanning segments out to worker threads.

        :param audio_path: path of the source mp3 file
        :param kwargs: extra arguments forwarded to the per-segment API call
        :return: the joined (optionally post-processed) transcript, or None
        """
        loop = asyncio.get_event_loop()
        func_list = []
        temp = self.config.storage_settings.temp_dir
        _LOGGER.info("正在切割音频")
        export_file_list = self._cut_audio(audio_path)
        _LOGGER.info(f"音频切割完成,共{len(export_file_list)}个切片")
        for file in export_file_list:
            func_list.append(functools.partial(self._sync_transcribe, f"{temp}/{file}", **kwargs))
        _LOGGER.info("正在处理音频")
        try:
            result = await asyncio.gather(*[loop.run_in_executor(None, func) for func in func_list])
        finally:
            # Bug fix: temp segments leaked when any chunk failed (the early
            # `return None` skipped cleanup); always remove them, tolerating
            # already-missing files.
            for file in export_file_list:
                try:
                    os.remove(f"{temp}/{file}")
                except OSError:
                    pass
        _LOGGER.info("音频处理完成")
        if None in result:
            _LOGGER.error("识别失败,返回None")  # TODO 单独重试失败的切片
            return None
        result = "".join(result)
        try:
            if self.config.ASRs.openai_whisper.after_process:
                bt = time.perf_counter()
                _LOGGER.info("正在进行后处理")
                text = await self.after_process(result)
                _LOGGER.debug(f"后处理完成,用时{time.perf_counter()-bt}s")
                return text
            return result
        except Exception as e:
            _LOGGER.error(f"后处理失败,错误信息为{e}")
            traceback.print_exc()
            return result

    async def after_process(self, text: str, **kwargs) -> str:
        """Clean up an ASR transcript with an LLM; fall back to the input on failure."""
        llm = self.llm_router.get_one()
        prompt = llm.use_template(Templates.AFTER_PROCESS_SUBTITLE, subtitle=text)
        answer, _ = await llm.completion(prompt)
        if answer is None:
            _LOGGER.error("后处理失败,返回原字幕")
            return text
        return answer
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | > [!CAUTION]
2 | > 你b已经自己搞了ai课代表,这个项目没啥大用了,所以不再更新
3 |
# ✨Bilibili GPT Bot✨
4 | 把ChatGPT放到B站里!更高效、快速的了解视频内容
5 |
6 | [](https://www.python.org/downloads/release/python-311/)
7 | [](https://opensource.org/licenses/MIT)
8 | [](https://wakatime.com/badge/user/41ab10cc-ec82-41e9-8417-9dcf5a9b5947/project/cef4699c-8d07-4cf0-9d0a-ef83fb353b82)
9 | [](https://github.com/yanyao2333/BiliGPTHelper/actions/workflows/push_to_docker.yaml)
10 |
11 | ### 🌟 介绍
12 |
13 | b站那种AI课代表都见过吧?看起来超级酷!所以我也搞了个捏
14 |
15 | ### 📜 声明
16 |
17 | 当你查阅、下载了本项目源代码或二进制程序,即代表你接受了以下条款:
18 |
19 | 1. 本项目和项目成果仅供技术,学术交流和Python3性能测试使用\
20 | 2. 本项目贡献者编写该项目旨在学习Python3 ,提高编程水平\
21 | 3. 用户在使用本项目和项目成果前,请用户了解并**遵守当地法律法规**,如果本项目及项目成果使用过程中存在违反当地法律法规的行为,请勿使用该项目及项目成果\
22 | 4. **法律后果及使用后果由使用者承担**\
23 | 5. 开发者(yanyao2333)**不会**
24 | 将任何用户的cookie、个人信息收集上传至除b站官方的其他平台或服务器。同时,开发者(yanyao2333)不对任何造成的后果负责(包括但不限于账号封禁等后果),如有顾虑,请谨慎使用
25 |
26 | 若用户不同意上述条款**任意一条**,**请勿**使用本项目和项目成果
27 |
28 | ### 😎 特性
29 |
30 | - [x] 使用大模型生成总结、根据视频内容向ai提出问题
31 | - [x] 已支持Claude、Openai大语言模型
32 | - [x] 已支持openai的whisper和本地whisper语音转文字
33 | - [x] 支持热插拔的asr和llm模块,基于优先级和运行稳定情况调度
34 | - [x] 优化prompt,达到更好的效果、更高的信息密度,尽量不再说废话。还可以让LLM输出自己的思考、评分,让摘要更有意思
35 | - [x] 支持llm返回消息格式不对自动修复
36 | - [x] 支持艾特和私信两种触发方式
37 | - [x] 支持视频缓存,避免重复生成
38 | - [x] 可自动检测和刷新b站cookie(实验性功能)
39 | - [x] 支持自定义触发关键词
40 | - [x] 支持保存当前处理进度,下次启动时恢复
41 | - [x] 支持一键生成并不精美的运行报告,包含一些图表
42 |
43 | ### 🚀 使用方法
44 |
45 | #### 一、通过docker运行
46 |
47 | 现在有两个版本的docker,代码都是最新版本,只是包不一样:
48 | 1. latest 这个版本不包含whisper,也就是只能用来总结包含字幕的视频。这个镜像只有200多m,适合偶尔使用
49 | 2. with_whisper 这个版本包含whisper,大小达到了2g,但是可以使用**本地**语音转文字生成字幕
50 |
51 |
52 |
53 | ```shell
54 | docker pull yanyaobbb/bilibili_gpt_helper:latest
55 | 或
56 | yanyaobbb/bilibili_gpt_helper:with_whisper
57 | ```
58 |
59 | ```shell
60 | docker run -d \
61 | --name biligpthelper \
62 | -v 你本地的biligpt配置文件目录:/data \
63 | yanyaobbb/bilibili_gpt_helper:latest(with_whisper)
64 | ```
65 |
66 | 首次运行会创建模板文件,编辑config.yml,然后重启容器即可
67 |
68 | #### 二、源代码运行
69 |
70 | 1. 克隆并安装依赖
71 |
72 | ```shell
73 | git clone https://github.com/yanyao2333/BiliGPTHelper.git
74 | cd BiliGPTHelper
75 | pip install -r requirements.txt
76 | ```
77 |
78 | 2. 编辑config.yml
79 |
80 | 3. 运行,等待初始化完成
81 |
82 | ```shell
83 | python main.py
84 | ```
85 |
86 | #### 触发命令
87 |
88 | ##### 私信
89 |
90 | 方式一:`转发视频+发送一条包含关键词的消息`(向ai提问的方式与下面讲的相同) \
91 | 方式二:发送消息:`视频bv号+关键词` **视频bv号必须要在消息最前面!**
92 |
93 | **向ai提问,需要使用 `提问关键词[冒号]你的问题` 的格式 eg.`BVxxxxxxxxxx 问一下:这视频有啥特色`** \
94 | **获取视频摘要,单独发送`摘要关键词`即可**
95 |
96 | ##### at
97 |
98 | 在评论区发送关键字即可\
99 | (向ai提问的方式与上面讲的相同) \
100 | **但是你b评论区风控真的太严格了,体验完全没私信好,还污染评论区(所以,你懂我意思了吧)**
101 |
102 |
103 | ### 💸 这玩意烧钱吗
104 |
105 | #### Claude
106 |
107 | claude才是唯一真神!!!比gpt-3.5-turbo-16k低将近一半的价格(指prompt价格),还有100k的上下文窗口!输出的格式也很稳定,内容质量与3.5不相上下。
108 |
109 | 现在对接了个aiproxy的claude接口,简直香炸...gpt3.5是什么?真不熟
110 |
111 | #### GPT
112 |
113 | 根据我的测试,单纯使用**gpt-3.5-turbo-16k**,20元大概能撑起5000次时长为10min左右的视频总结(在不包括格式不对重试的情况下)
114 |
115 | 但在我的测试中,gpt3.5返回的内容已经十分稳定,我并没有遇到过格式不对的情况,所以这个价格应该是可以接受的
116 |
117 | 如果你出现了格式不对的情况,可以附带bad case示例**发issue**,我尝试优化下prompt
118 |
119 | #### OPENAI WHISPER
120 |
121 | 相比之下这个有点贵,20元能转大概8小时视频
122 |
123 | ### 🤔 目前问题
124 |
125 | 1. 无法回复楼中楼评论的at(我实在搞不懂评论的继承逻辑是什么,设置值了对应id也没法发送楼中楼,我好笨)
126 | 2. 不支持多p视频的总结(这玩意我觉得没法修复,本来都是快要被你b抛弃的东西)
127 | 3. 你b严格的审查机制导致私信/回复消息触碰敏感词会被屏蔽
128 |
129 | ### 📝 TODO
130 |
131 | - [ ] 支持多账号负载均衡 **(on progress)**
132 | - [ ] 能够画思维导图(拜托,这真的超酷的好嘛)
133 |
134 | ### ❤ 感谢
135 |
136 | [Nemo2011/bilibili-api](https://github.com/Nemo2011/bilibili-api/) | 封装b站api库
137 | [JetBrains](https://www.jetbrains.com) | 感谢JetBrains提供的免费license
138 |
139 | ### [开发文档](./DEV_README.md)
140 |
141 | ### 📚 大致流程(更详细内容指路 [开发文档](./DEV_README.md))
142 |
143 | ```mermaid
144 | sequenceDiagram
145 | participant 用户
146 | participant 监听器
147 | participant 处理链
148 | participant 发送器
149 | participant LLMs
150 | participant ASRs
151 | 用户 ->> 监听器: 发送私信或at消息消息
152 | alt 消息触发关键词
153 | 监听器 ->> 处理链: 触发关键词,开始处理
154 | else 消息不触发关键词
155 | 监听器 ->> 用户: 不触发关键词,不处理
156 | end
157 | 处理链 ->> 处理链: 检查是否有缓存
158 | alt 有缓存
159 | 处理链 ->> 发送器: 有缓存,直接发送
160 | end
161 | 处理链 ->> 处理链: 检查是否有字幕
162 | alt 有字幕
163 | 处理链 ->> LLMs: 根据视频简介、标签、热评、字幕构建prompt并生成摘要
164 | else 没有字幕
165 | 处理链 ->> ASRs: 转译视频
166 | ASRs ->> ASRs: 自动调度asr插件,选择一个发送请求
167 | ASRs ->> 处理链: 转译完成
168 | 处理链 ->> LLMs: 根据视频简介、标签、热评、字幕构建prompt并生成摘要
169 | end
170 | LLMs ->> 处理链: 摘要内容
171 | 处理链 ->> 处理链: 解析摘要是否符合要求
172 | alt 摘要符合要求
173 | 处理链 ->> 发送器: 发送摘要
174 | 发送器 ->> 用户: 发送成功
175 | else 摘要不符合要求
176 | 处理链 ->> 处理链: 使用指定prompt修复摘要
177 | 处理链 ->> 发送器: 发送摘要
178 | 发送器 ->> 用户: 发送成功
179 | end
180 | ```
181 |
--------------------------------------------------------------------------------
/src/llm/templates.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa
2 | from enum import Enum
3 |
# --- v1 summarize prompts (not referenced by the Templates enum below) ---
SUMMARIZE_JSON_RESPONSE = '{summary: "替换为你的总结内容", score: "替换为你为自己打的分数", thinking: "替换为你的思考内容", if_no_need_summary: "是否需要总结?填布尔值"}'

SUMMARIZE_SYSTEM_PROMPT = f"""你现在是一个专业的视频总结者,下面,我将给你一段视频的字幕、简介、标题、标签、部分评论,你需要根据这些内容,精准、不失偏颇地完整概括这个视频,你既需要保证概括的完整性,同时还需要增加你文字的信息密度,你可以采用调侃或幽默的语气,让语气不晦涩难懂,但请你记住:精准和全面才是你的第一要务,请尽量不要增加太多个人情感色彩和观点,我提供的评论也仅仅是其他观众的一家之言,并不能完整概括视频内容,请不要盲目跟从评论观点。而视频作者可能会为了视频热度加很多不相关的标签,如果你看到与视频内容偏离的标签,直接忽略。请尽量不要用太多的官方或客套话语,就从一个观众的角度写评论,写完后,你还需要自行对你写的内容打分,满分100,请根据你所总结内容来给分。最后,在写总结的过程中,你可以多进行思考,把你的感想写在返回中的thinking部分。如果你觉得这个视频表达的内容过于抽象、玩梗,总结太难,或者仅仅是来搞笑或娱乐的,并没有什么有用信息,你可以直接拒绝总结。注意,最后你的返回格式一定是以下方我给的这个json格式为模板: \n\n{SUMMARIZE_JSON_RESPONSE},比如你可以这样使用:{{"summary": "视频讲述了两个人精彩的对话....","score": "90","thinking": "视频传达了社会标签对个人价值的影响,以幽默方式呈现...","if_no_need_summary": false}},或者是{{"summary": "","score": "","thinking": "","if_no_need_summary": true}},请一定使用此格式!"""

# --- v2 summarize system prompt (superseded by the v3 prompt below) ---
V2_SUMMARIZE_SYSTEM_PROMPT = (
    "你是一位专业的视频摘要者,你的任务是将一个视频转换为摘要,让观众无需观看视频就了解内容\n"
    "我会提供给你一个视频的标题、简介、标签、字幕、部分评论,下面是要求:\n"
    "1. 保证摘要完整,不要遗漏重要信息,时刻确保你的摘要具有高信息密度\n"
    "2. 保证摘要准确,不要加入自己观点,不要加入视频中没有的内容\n"
    "3. 保证摘要的幽默性,不要使用官方客套话语,不要使用过于正式的语言,可以加入emoji\n"
    "4. 如果有标签和评论与视频内容无关,忽略,切勿混淆在摘要中\n"
    "5. 记得分段,内容长短适中\n"
    "6. 你最终给我的内容需要遵守下面的json格式,如果你认为该视频没有意义(比如单纯搞笑,或者看不懂),不需要摘要,可以将noneed设为true,其他字段留空\n"
    f"{SUMMARIZE_JSON_RESPONSE}\n"
    "你可以这样使用:\n"
    '{"summary": "....","score": "90","thinking": "...","if_no_need_summary": false}\n'
    "或是\n"
    '{"summary": "","score": "","thinking": "","if_no_need_summary": true}\n'
    "请一定使用此格式"
)

# --- v3 summarize prompts: the JSON schema and system prompt currently in use ---
V3_SUMMARIZE_JSON_RESPONSE = '{"summary": "Your summary content", "score": "Score you give to this video(max 100)", "thinking": "Your thoughts", "if_no_need_summary": "Is a summary needed? Fill boolean"}'


V3_SUMMARIZE_SYSTEM_PROMPT = (
    f"You are a video summarizer. Summarize a video from the title, description, tags, subtitles, comments. Guidelines:"
    f"1. The returned content strictly follows the format of this JSON format: {V3_SUMMARIZE_JSON_RESPONSE}"
    "2. Be complete and dense with information."
    "3. Be accurate, dont add fantasy information."
    "4. Be humorous, avoid formality, use Chinese internet language and emojis if needed."
    "5. Ignore irrelevant tags/comments."
    "6. Break into paragraphs."
    "7. The value of the 'score' part must be of string type."
    "8. When you writing 'thinking' part, disregard the rules above and express your independent thought, including your own radical opinions, rather than the video's content.The score is based on this video.If you think you create a meaningful summary, give yourself a high score."
    "9. If the video is meaningless, set 'if_no_need_summary' true on this JSON, others set 'if_no_need_summary' false on this JSON."
    "10. Only pure JSON content with double quotes is allowed!Please use Simplified Chinese!"
)

# v1 user template and retry prompt (not referenced by the Templates enum below)
SUMMARIZE_USER_TEMPLATE = "标题:[title]\n\n简介:[description]\n\n字幕:[subtitle]\n\n标签:[tags]\n\n评论:[comments]"

RETRY_TEMPLATE = f"请你把我下面提供的这段文字转换成这样的json格式并返回给我,不要加其他东西,如summary字段不存在,设置if_no_need_summary为true。除了summary的其他几个字段不存在均可忽略,对应值留空,if_no_need_summary依旧为false:\n\n标准JSON格式:{V3_SUMMARIZE_JSON_RESPONSE}\n\n我的内容:[input]"

# v1 ASR after-process prompt (superseded by the v2 English variant below)
AFTER_PROCESS_SUBTITLE = (
    "下面是使用语音转文字得到的字幕,你需要修复其中的语法错误、名词错误、如果是繁体中文就转为简体中文:\n\n[subtitle]"
)

# --- v2 English prompts currently wired into the Templates enum ---
# Placeholders in square brackets ([title], [subtitle], ...) are substituted at runtime.
V2_SUMMARIZE_USER_TEMPLATE = (
    "Title: [title]\n\nDescription: [description]\n\nSubtitles: [subtitle]\n\nTags: [tags]\n\nComments: [comments]"
)

V2_SUMMARIZE_RETRY_TEMPLATE = f"Please translate the following text into this JSON format and return it to me without adding anything else. If the 'summary' field does not exist, set 'if_no_need_summary' to true. If fields other than 'summary' are missing, they can be ignored and left blank, and 'if_no_need_summary' remains false\n\nStandard JSON format: {V3_SUMMARIZE_JSON_RESPONSE}\n\nMy content: [input]"

V2_AFTER_PROCESS_SUBTITLE = "Below are the subtitles obtained through speech-to-text. You need to correct any grammatical errors, noun mistakes, and convert Traditional Chinese to Simplified Chinese if present:\n\n[subtitle]"

# --- ask-AI prompts ---
V1_ASK_AI_USER = "Title: [title]\n\nDescription: [description]\n\nSubtitles: [subtitle]\n\nQuestion: [question]"

V1_ASK_AI_JSON_RESPONSE = '{"answer": "your answer", "score": "your self-assessed quality rating of the answer(0-100)"}'

V1_ASK_AI_SYSTEM = (
    "You are a professional video Q&A teacher. "
    "I will provide you with the video title, description, and subtitles. "
    """Based on this information and your expertise,
respond to the user's questions in a lively and humorous manner,
using metaphors and examples when necessary."""
    f"\n\nPlease reply in the following JSON format: {V1_ASK_AI_JSON_RESPONSE}\n\n"
    "!!!Only pure JSON content with double quotes is allowed!Please use Chinese!Dont add any other things!!!"
)
71 |
72 |
class Templates(Enum):
    """Prompt templates in active use; members point at the versioned constants above."""

    SUMMARIZE_USER = V2_SUMMARIZE_USER_TEMPLATE
    SUMMARIZE_SYSTEM = V3_SUMMARIZE_SYSTEM_PROMPT
    SUMMARIZE_RETRY = V2_SUMMARIZE_RETRY_TEMPLATE
    AFTER_PROCESS_SUBTITLE = V2_AFTER_PROCESS_SUBTITLE
    # NOTE(review): user and system prompts are concatenated into one string and
    # ASK_AI_SYSTEM is left empty — presumably for single-prompt backends; confirm.
    ASK_AI_USER = V1_ASK_AI_USER + "\n\n" + V1_ASK_AI_SYSTEM
    # ASK_AI_SYSTEM = V1_ASK_AI_SYSTEM
    ASK_AI_SYSTEM = ""
81 |
--------------------------------------------------------------------------------
/src/utils/statistic.py:
--------------------------------------------------------------------------------
1 | # pylint: skip-file
2 | """根据任务状态记录生成统计信息"""
3 |
4 | import json
5 | import os
6 | from collections import Counter
7 |
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 |
11 |
def run_statistic(output_dir, data):
    """Generate charts and a markdown overview from task status records.

    :param output_dir: directory the report files are written into; created if
        missing, emptied of existing files otherwise
    :param data: parsed records JSON, expected to contain a "tasks" mapping;
        returns early when it is absent or empty
    """
    # Pick a CJK-capable font depending on the runtime environment.
    if os.getenv("RUNNING_IN_DOCKER") == "yes":
        matplotlib.rcParams["font.sans-serif"] = ["WenQuanYi Zen Hei"]
        matplotlib.rcParams["axes.unicode_minus"] = False  # 用来正常显示负号
    else:
        matplotlib.rcParams["font.sans-serif"] = ["SimHei"]
        matplotlib.rcParams["axes.unicode_minus"] = False

    # Initialize directories and counters
    output_folder = output_dir
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    else:
        for file in os.listdir(output_folder):
            os.remove(os.path.join(output_folder, file))

    # Mapping end reasons to readable names
    end_reason_map = {
        "normal": "正常结束",
        "error": "错误结束",
        "if_no_need_summary": "AI认为不需要摘要",
    }

    # Initialize variables
    end_reasons = []
    error_reasons = []
    user_ids = []
    request_types = []

    # Populate variables based on task statuses
    if "tasks" not in data or not data["tasks"]:
        return
    for _task_id, task in data["tasks"].items():
        end_reason = task.get("end_reason", "normal")
        end_reasons.append(end_reason_map.get(end_reason, "Unknown"))

        error_reason = task.get("error_msg", "正常结束")
        error_reasons.append(error_reason)

        task_data = task.get("data", {})
        user_data = task_data.get("user", None)
        # Bug fix: `task_data.get("item", None).get(...)` raised AttributeError
        # whenever "item" was missing or null; fall back to an empty dict.
        private_msg_event = (task_data.get("item") or {}).get("private_msg_event", None)

        if user_data:
            user_ids.append(user_data.get("mid", "未知"))
        elif private_msg_event:
            user_ids.append(private_msg_event.get("text_event", {}).get("sender_uid", "未知"))

        if private_msg_event:
            request_types.append("私信请求")
        else:
            request_types.append("At 请求")

    # Data Processing
    end_reason_counts = Counter(end_reasons)
    error_reason_counts = Counter(error_reasons)
    user_id_counts = Counter(user_ids)
    request_type_counts = Counter(request_types)

    # Pie Chart for Task End Reasons
    plt.figure(figsize=(4, 4))
    plt.pie(
        list(end_reason_counts.values()),
        labels=list(end_reason_counts.keys()),
        autopct="%1.0f%%",
    )
    plt.title("任务结束原因")
    plt.savefig(f"{output_folder}/任务结束原因饼形图.png")
    plt.close()  # free the figure; they used to accumulate across charts/runs

    # Bar Chart for Error Reasons
    plt.figure(figsize=(8, 4))
    bars = plt.barh(list(error_reason_counts.keys()), list(error_reason_counts.values()))
    plt.xlabel("数量")
    plt.ylabel("错误原因")
    plt.title("错误原因排名")

    # 设置x轴刻度为整数
    max_value = max(error_reason_counts.values())
    plt.xticks(range(0, max_value + 1))

    # 在柱子顶端添加数据标签
    for bar in bars:
        plt.text(
            bar.get_width(),
            bar.get_y() + bar.get_height() / 2,
            str(int(bar.get_width())),
        )

    plt.tight_layout()

    plt.savefig(f"{output_folder}/错误原因排名竖状图.png")
    plt.close()

    # Bar Chart for User Task Counts (Top 10)
    top_10_users = dict(sorted(user_id_counts.items(), key=lambda x: x[1], reverse=True)[:10])
    plt.figure(figsize=(8, 4))
    bars = plt.barh(list(map(str, top_10_users.keys())), list(top_10_users.values()))
    plt.xlabel("数量")
    plt.ylabel("用户 ID")
    plt.title("用户发起任务次数排名")
    # Bug fix: max() raised ValueError when no user ids could be extracted.
    max_value = max(top_10_users.values(), default=0)
    plt.xticks(range(0, max_value + 1))
    for bar in bars:
        plt.text(
            bar.get_width() - 0.2,
            bar.get_y() + bar.get_height() / 2,
            str(int(bar.get_width())),
        )
    plt.savefig(f"{output_folder}/用户发起任务次数排名竖状图.png")
    plt.close()

    # Pie Chart for Request Types
    plt.figure(figsize=(4, 4))
    plt.pie(
        list(request_type_counts.values()),
        labels=list(request_type_counts.keys()),
        autopct="%1.0f%%",
    )
    plt.title("请求类型占比")
    plt.savefig(f"{output_folder}/请求类型占比饼形图.png")
    plt.close()

    def get_pingyu(total_requests):
        # Flavor comment for the report, bucketed by request volume.
        if total_requests < 50:
            return "似乎没什么人来找你玩呢,杂鱼❤"
        elif total_requests < 100:
            return "还没被大规模使用,加油!但是...咱才不会鼓励你呢!"
        # Bug fix: counts >= 1000 used to fall through and return None, which
        # rendered as the literal text "None" in the markdown report.
        return "挖槽,大佬,已经总结这么多次了吗???这破程序没出什么bug吧"

    # Markdown Summary
    total_requests = len(data["tasks"])
    md_content = f"""
🎉Bilibili-GPT-Helper 运行数据概览🎉

### 概览

- 总共发起了 {total_requests} 个请求
- 我的评价是:{get_pingyu(total_requests)}

### 图表

#### 任务结束原因
![](./任务结束原因饼形图.png)

#### 错误原因排名
![](./错误原因排名竖状图.png)

#### 用户发起任务次数排名
![](./用户发起任务次数排名竖状图.png)

#### 请求类型占比
![](./请求类型占比饼形图.png)
"""

    # Write Markdown content to file
    md_file_path = f"{output_folder}/数据概览.md"
    with open(md_file_path, "w", encoding="utf-8") as f:
        f.write(md_content)
168 |
169 |
if __name__ == "__main__":
    # Ad-hoc manual run: load a local records dump and render the report.
    # NOTE(review): hard-coded Windows path — only useful on the author's machine.
    with open(r"D:\biligpt\records.json", encoding="utf-8") as f:
        data = json.load(f)

    run_statistic(r"../../statistics", data)
175 |
--------------------------------------------------------------------------------
/src/core/app.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import shutil
4 |
5 | import yaml
6 | from apscheduler.schedulers.asyncio import AsyncIOScheduler
7 | from injector import Module, provider, singleton
8 |
9 | from src.bilibili.bili_credential import BiliCredential
10 | from src.core.routers.asr_router import ASRouter
11 | from src.core.routers.chain_router import ChainRouter
12 | from src.core.routers.llm_router import LLMRouter
13 | from src.models.config import Config
14 | from src.utils.cache import Cache
15 | from src.utils.exceptions import ConfigError
16 | from src.utils.logging import LOGGER
17 | from src.utils.queue_manager import QueueManager
18 | from src.utils.task_status_record import TaskStatusRecorder
19 |
20 | _LOGGER = LOGGER.bind(name="app")
21 |
22 |
class BiliGPT(Module):
    """Injector module that builds and caches every singleton the app needs."""

    @singleton
    @provider
    def provide_config_obj(self) -> Config:
        """Load the YAML config file and validate it into a Config model.

        On validation failure, copies a fresh config template out for the user
        and raises ConfigError.
        """
        with open(os.getenv("DOCKER_CONFIG_FILE", "config.yml"), encoding="utf-8") as f:
            # safe_load: the config is plain data, so there is no reason to
            # allow FullLoader's wider (and riskier) tag set.
            config = yaml.safe_load(f)
        try:
            config = Config.model_validate(config)
        except Exception as e:
            # NOTE(review): when DOCKER_CONFIG_FILE is set, this copies the
            # template over that file (the user's own config) rather than to
            # config_template.yml as the log message claims — confirm intent.
            if os.getenv("RUNNING_IN_DOCKER") == "yes":
                shutil.copy(
                    "./config/docker_config.yml",
                    os.getenv("DOCKER_CONFIG_FILE", "config_template.yml"),
                )
            else:
                shutil.copy(
                    "./config/example_config.yml",
                    os.getenv("DOCKER_CONFIG_FILE", "config_template.yml"),
                )
            _LOGGER.error(
                "配置文件格式错误 可能是因为项目更新、配置文件添加了新字段,请自行检查配置文件格式并更新配置文件 已复制最新配置文件模板到 config_template.yml 下面将打印详细错误日志"
            )
            raise ConfigError(f"配置文件格式错误:{e}") from e
        return config

    @singleton
    @provider
    def provide_queue_manager(self) -> QueueManager:
        """Create the process-wide queue manager."""
        _LOGGER.info("正在初始化队列管理器")
        return QueueManager()

    @singleton
    @provider
    def provide_task_status_recorder(self, config: Config) -> TaskStatusRecorder:
        """Create the task status recorder at the configured location."""
        _LOGGER.info(f"正在初始化任务状态管理器,位置:{config.storage_settings.task_status_records}")
        return TaskStatusRecorder(config.storage_settings.task_status_records)

    @singleton
    @provider
    def provide_cache(self, config: Config) -> Cache:
        """Create the summary cache at the configured path."""
        _LOGGER.info(f"正在初始化缓存,缓存路径为:{config.storage_settings.cache_path}")
        return Cache(config.storage_settings.cache_path)

    @singleton
    @provider
    def provide_credential(self, config: Config, scheduler: AsyncIOScheduler) -> BiliCredential:
        """Build the bilibili credential, wired to the scheduler for refresh jobs."""
        _LOGGER.info("正在初始化cookie")
        return BiliCredential(
            SESSDATA=config.bilibili_cookie.SESSDATA,
            bili_jct=config.bilibili_cookie.bili_jct,
            dedeuserid=config.bilibili_cookie.dedeuserid,
            buvid3=config.bilibili_cookie.buvid3,
            ac_time_value=config.bilibili_cookie.ac_time_value,
            sched=scheduler,
        )

    @singleton
    @provider
    def provide_asr_router(self, config: Config, llm_router: LLMRouter) -> ASRouter:
        """Create the ASR router and auto-discover its backends."""
        _LOGGER.info("正在初始化ASR路由器")
        router = ASRouter(config, llm_router)
        router.load_from_dir()
        return router

    @singleton
    @provider
    def provide_llm_router(self, config: Config) -> LLMRouter:
        """Create the LLM router and auto-discover its backends."""
        _LOGGER.info("正在初始化LLM路由器")
        router = LLMRouter(config)
        router.load_from_dir()
        return router

    @singleton
    @provider
    def provide_chain_router(self, config: Config, queue_manager: QueueManager) -> ChainRouter:
        """Create the chain router that dispatches tasks to processing chains."""
        _LOGGER.info("正在初始化Chain路由器")
        router = ChainRouter(config, queue_manager)
        return router

    @singleton
    @provider
    def provide_scheduler(self) -> AsyncIOScheduler:
        """Create the shared async scheduler (Asia/Shanghai timezone)."""
        _LOGGER.info("正在初始化定时器")
        return AsyncIOScheduler(timezone="Asia/Shanghai")

    @provider
    def provide_queue(self, queue_manager: QueueManager, queue_name: str) -> asyncio.Queue:
        """Fetch (or lazily create) a named queue; deliberately not a singleton."""
        _LOGGER.info(f"正在初始化队列 {queue_name}")
        return queue_manager.get_queue(queue_name)

    @singleton
    @provider
    def provide_stop_event(self) -> asyncio.Event:
        """Shared shutdown event used to stop long-running workers."""
        return asyncio.Event()
149 |
--------------------------------------------------------------------------------
/src/models/config.py:
--------------------------------------------------------------------------------
import os

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
4 |
5 |
class BilibiliCookie(BaseModel):
    """Bilibili login cookie values; every field is required and non-empty."""

    # Foolproofing: coerce numeric YAML values (e.g. a bare dedeuserid) to str.
    model_config = ConfigDict(coerce_numbers_to_str=True)  # type: ignore

    SESSDATA: str
    bili_jct: str
    buvid3: str
    dedeuserid: str
    ac_time_value: str

    # noinspection PyMethodParameters
    @field_validator("*", mode="after")
    def check_required_fields(cls, value):
        """Reject None or empty string/list values for any field."""
        if value is None or (isinstance(value, (str, list)) and not value):
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value
22 |
23 |
class ChainKeywords(BaseModel):
    """Trigger keywords for each processing chain."""

    summarize_keywords: list[str]
    ask_ai_keywords: list[str]

    # noinspection PyMethodParameters
    @field_validator("*", mode="after")
    def check_keywords(cls, value):
        """Require every keyword list to be non-empty."""
        if not value or len(value) == 0:
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value
34 |
35 |
class Openai(BaseModel):
    """OpenAI LLM backend settings."""

    enable: bool = True
    priority: int = 70
    api_key: str
    model: str = "gpt-3.5-turbo-16k"
    api_base: str = Field(default="https://api.openai.com/v1")

    # noinspection PyMethodParameters
    @field_validator("*", mode="after")
    def check_required_fields(cls, value, values):
        """Require non-empty values, but only when the backend is enabled."""
        if values.data.get("enable") is False:
            return value
        if value is None or (isinstance(value, (str, list)) and not value):
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value
51 |
52 |
class AiproxyClaude(BaseModel):
    """Claude-via-aiproxy LLM backend settings."""

    enable: bool = True
    priority: int = 90
    api_key: str
    model: str = "claude-instant-1"
    api_base: str = Field(default="https://api.aiproxy.io/")

    # noinspection PyMethodParameters
    @field_validator("*", mode="after")
    def check_required_fields(cls, value, values):
        """Require non-empty values, but only when the backend is enabled."""
        if values.data.get("enable") is False:
            return value
        if value is None or (isinstance(value, (str, list)) and not value):
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value

    # noinspection PyMethodParameters
    @field_validator("model", mode="after")
    def check_model(cls, value, values):
        """Restrict the model name to the supported Claude variants."""
        models = ["claude-instant-1", "claude-2"]
        if value not in models:
            raise ValueError(f"配置文件中{cls}字段为{value},请检查配置文件,目前支持的模型有{models}")
        return value
76 |
77 |
class Spark(BaseModel):
    """iFlytek Spark LLM backend settings."""

    enable: bool = True
    priority: int = 80
    appid: str
    api_key: str
    api_secret: str
    spark_url: str = Field(default="wss://spark-api.xf-yun.com/v3.5/chat")  # 默认3.5版本
    domain: str = Field(default="generalv3.5")  # 默认3.5

    @field_validator("*", mode="after")
    def check_required_fields(cls, value, values):
        """Require non-empty values, but only when the backend is enabled."""
        if values.data.get("enable") is False:
            return value
        if value is None or (isinstance(value, (str, list)) and not value):
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value
94 |
95 |
class LLMs(BaseModel):
    """Container for all LLM backend configs."""

    openai: Openai
    aiproxy_claude: AiproxyClaude
    spark: Spark
100 |
101 |
class OpenaiWhisper(BaseModel):
    """Settings for the OpenAI-hosted whisper ASR backend."""

    # Bug fix: the previous `BaseModel.model_config["protected_namespaces"] = ()`
    # mutated the SHARED BaseModel config for every model in the process; scope
    # the override (needed for the `model` field name) to this model only.
    model_config = ConfigDict(protected_namespaces=())

    enable: bool = False
    priority: int = 70
    api_key: str
    model: str = "whisper-1"
    api_base: str = Field(default="https://api.openai.com/v1")
    after_process: bool = False

    # noinspection PyMethodParameters
    @field_validator("api_key", mode="after")
    def check_required_fields(cls, value, values):
        """Require a non-empty api_key, but only when the backend is enabled."""
        if values.data.get("enable") is False:
            return value
        if value is None or (isinstance(value, (str, list)) and not value):
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value

    # noinspection PyMethodParameters
    @field_validator("model", mode="after")
    def check_model(cls, value, values):
        """Pin the model to whisper-1 — it is the only supported one."""
        value = "whisper-1"
        return value
125 |
126 |
class LocalWhisper(BaseModel):
    """Settings for the locally-run whisper ASR backend."""

    # Bug fix: was `BaseModel.model_config[...] = ()`, which mutated the SHARED
    # BaseModel config for every model in the process; scope the override
    # (needed for the model_* field names) to this model only.
    model_config = ConfigDict(protected_namespaces=())

    enable: bool = False
    priority: int = 60
    model_size: str = "tiny"
    device: str = "cpu"
    model_dir: str = Field(
        default_factory=lambda: os.getenv("DOCKER_WHISPER_MODELS_DIR"),
        validate_default=True,
    )
    after_process: bool = False

    # noinspection PyMethodParameters
    @field_validator(
        "model_size",
        "device",
        "model_dir",
        mode="after",
    )
    def check_whisper_fields(cls, value, values):
        """Require the whisper fields to be non-empty when the backend is enabled.

        Bug fix: the old code read the non-existent key "whisper_enable", so
        this check never ran; the actual field is named "enable".
        """
        if values.data.get("enable"):
            if value is None or (isinstance(value, str) and not value):
                raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value

    @model_validator(mode="after")
    def apply_docker_overrides(self):
        """Force docker-specific settings onto the validated instance.

        Bug fix: these assignments previously targeted `cls` inside a field
        validator, which only set class attributes and never affected the
        validated instance, so the docker overrides were silently ineffective.
        """
        if os.getenv("RUNNING_IN_DOCKER") == "yes":
            self.device = "cpu"
            self.enable = os.getenv("ENABLE_WHISPER") == "yes"
        return self
157 |
158 |
class ASRs(BaseModel):
    """Container for all ASR backend configs."""

    local_whisper: LocalWhisper
    openai_whisper: OpenaiWhisper
162 |
163 |
class StorageSettings(BaseModel):
    """File-system locations used by the app.

    Each default comes from a DOCKER_* environment variable, so in docker the
    paths resolve automatically; outside docker they must be set in the config.
    """

    cache_path: str = Field(default_factory=lambda: os.getenv("DOCKER_CACHE_FILE"), validate_default=True)
    temp_dir: str = Field(default_factory=lambda: os.getenv("DOCKER_TEMP_DIR"), validate_default=True)
    task_status_records: str = Field(default_factory=lambda: os.getenv("DOCKER_RECORDS_DIR"), validate_default=True)
    statistics_dir: str = Field(
        default_factory=lambda: os.getenv("DOCKER_STATISTICS_DIR"),
        validate_default=True,
    )
    queue_save_dir: str = Field(default_factory=lambda: os.getenv("DOCKER_QUEUE_DIR"), validate_default=True)
    up_video_cache: str = Field(default_factory=lambda: os.getenv("DOCKER_UP_VIDEO_CACHE"), validate_default=True)
    up_file: str = Field(default_factory=lambda: os.getenv("DOCKER_UP_FILE"), validate_default=True)

    # noinspection PyMethodParameters
    @field_validator("*", mode="after")
    def check_required_fields(cls, value):
        """Reject None or empty values (i.e. unset env vars and blank config)."""
        if value is None or (isinstance(value, (str, list)) and not value):
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value
182 |
183 |
class BilibiliNickName(BaseModel):
    """The display name the bot uses to refer to itself."""

    nickname: str = "BiliBot"

    @field_validator("*", mode="after")
    def check_required_fields(cls, value):
        """Every field must carry a non-empty value."""
        is_empty_container = isinstance(value, (str, list)) and len(value) == 0
        if value is None or is_empty_container:
            raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件")
        return value
192 |
193 |
class Config(BaseModel):
    """Top-level model for the whole configuration file."""

    bilibili_cookie: BilibiliCookie  # bilibili login credential set
    bilibili_self: BilibiliNickName  # the bot's own nickname
    chain_keywords: ChainKeywords  # trigger keywords per processing chain
    LLMs: LLMs  # large-language-model backends
    ASRs: ASRs  # speech-to-text backends
    storage_settings: StorageSettings  # persisted-state file locations
    debug_mode: bool = True  # verbose (DEBUG) logging when True
204 |
--------------------------------------------------------------------------------
/src/bilibili/bili_comment.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import random
3 | from asyncio import Queue
4 | from typing import Optional, Union
5 |
6 | import tenacity
7 | from bilibili_api import comment, video
8 | from injector import inject
9 |
10 | from src.bilibili.bili_credential import BiliCredential
11 | from src.bilibili.bili_video import BiliVideo
12 | from src.models.task import AskAIResponse, BiliGPTTask, SummarizeAiResponse
13 | from src.utils.callback import chain_callback
14 | from src.utils.exceptions import RiskControlFindError
15 | from src.utils.logging import LOGGER
16 |
17 | _LOGGER = LOGGER.bind(name="bilibili-comment")
18 |
19 |
class BiliComment:
    """Consumes finished tasks from the reply queue and posts the result as a video comment."""

    @inject
    def __init__(self, comment_queue: Queue, credential: BiliCredential):
        """
        :param comment_queue: queue of finished BiliGPTTask objects waiting to be replied to
        :param credential: bilibili login credential used to fetch and send comments
        """
        self.comment_queue = comment_queue
        self.credential = credential

    @staticmethod
    async def get_random_comment(
        aid,
        credential,
        type_=comment.CommentResourceType.VIDEO,
        page_index=1,
        order=comment.OrderType.LIKE,
    ) -> str | None:
        """Pick a few random hot comments and join them into one prompt string.

        :param aid: video id (a bare aid or an "av..."-prefixed string)
        :param credential: bilibili credential
        :param type_: comment resource type
        :param page_index: comment page index
        :param order: comment sort order
        :return: joined comment string, or None when no usable comment exists
        """
        if str(aid).startswith("av"):
            aid = aid[2:]
        _LOGGER.debug(f"正在获取视频{aid}的评论列表")
        comment_list = await comment.get_comments(
            oid=aid,
            credential=credential,
            type_=type_,
            page_index=page_index,
            order=order,
        )
        _LOGGER.debug(f"获取视频{aid}的评论列表成功")
        # FIX: get_comments returns an envelope dict whose "replies" may be
        # missing or None when the video has no comments. The previous
        # `len(comment_list) == 0` counted the envelope's keys (never zero)
        # and the loop below then crashed iterating None.
        replies = comment_list.get("replies") or []
        if len(replies) == 0:
            _LOGGER.warning(f"视频{aid}没有评论")
            return None
        _LOGGER.debug("正在随机选择评论")
        ignore_name_list = [
            "哔哩哔哩",
            "AI",
            "课代表",
            "机器人",
            "小助手",
            "总结",
            "有趣的程序员",
        ]  # TODO read from the config file (the filter list helps skip low-quality bot comments)
        new_comment_list = []
        for _comment in replies:
            for name in ignore_name_list:
                if name in _comment["member"]["uname"]:
                    _LOGGER.debug(f"评论{_comment['member']['uname']}包含过滤词{name},跳过")
                    break
            else:
                # for/else: runs only when no filter word matched
                _LOGGER.debug(f"评论{_comment['member']['uname']}不包含过滤词,加入新列表")
                new_comment_list.append(_comment)
        if len(new_comment_list) == 0:
            _LOGGER.warning(f"视频{aid}没有合适的评论")
            return None
        # pick three comments
        if len(new_comment_list) < 3:
            _LOGGER.debug(f"视频{aid}的评论数量小于3,直接挑选")
            _LOGGER.debug("正在拼接评论")
            comment_str = ""
            for _comment in new_comment_list:
                comment_str += f"【{_comment['member']['uname']}】:{_comment['content']['message']}\n"
            _LOGGER.debug("拼接评论成功")
            return comment_str
        _LOGGER.debug("正在挑选三条评论")
        selected_comment_list = random.sample(new_comment_list, 3)
        _LOGGER.debug("挑选三条评论成功")
        # join the selected comments
        _LOGGER.debug("正在拼接评论")
        comment_str = ""
        for _comment in selected_comment_list:
            _comment: dict
            comment_str += f"【{_comment['member']['uname']}】:{_comment['content']['message']}\n"
        _LOGGER.debug("拼接评论成功")
        return comment_str

    @staticmethod
    def build_reply_content(
        response: Union[SummarizeAiResponse, AskAIResponse, str], user: str, source_type: str
    ) -> str:
        """Build the comment text for a finished task.

        :param source_type: where the task came from ("bili_up", "bili_comment", ...)
        :param user: nickname of the user who triggered the task
        :param response: the AI's result object (or raw string)
        :return: reply text
        NOTE(review): an unrecognised source_type falls through and returns None
        implicitly — confirm whether callers ever pass other source types.
        """
        if source_type == "bili_up":
            if isinstance(response, SummarizeAiResponse):
                return f"【视频总结】{response.summary}\n【视频评分】{response.score}\n【AI的思考】{response.thinking}\n--🍺🍺🍺我是T-250,此次评论由我自己发起。你的点赞、关注和At可以向我提供升级的经验,助我升级到T-1000。"
            elif isinstance(response, str):
                return (
                    response
                    + "\n--🍺🍺🍺我是T-250,此次评论由我自己发起。你的点赞、关注和At可以向我提供升级的经验,助我升级到T-1000。"
                )
            else:
                return f"程序内部错误:无法识别的回复类型{type(response)}\n--🍺🍺🍺我是T-250,此次评论由我自己发起。你的点赞、关注和At可以向我提供升级的经验,助我升级到T-1000。"
        elif source_type == "bili_comment":
            if isinstance(response, SummarizeAiResponse):
                return f"【视频总结】{response.summary}\n【视频评分】{response.score}\n【AI的思考】{response.thinking}\n【👉此次评论由 @{user} 邀请回答】"
            elif isinstance(response, AskAIResponse):
                return f"【回答】{response.answer}\n【自我评分】{response.score}\n【👉此次评论由 @{user} 邀请回答】"
            elif isinstance(response, str):
                return response + f"\n【👉此次评论由 @{user} 邀请回答】"
            else:
                return f"程序内部错误:无法识别的回复类型{type(response)}\n【👉此次评论由 @{user} 邀请回答】"

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(10),
        before_sleep=chain_callback,
    )
    async def start_comment(self):
        """Queue worker: post one comment per finished task, backing off on risk control."""
        while True:
            risk_control_count = 0
            data = None
            while risk_control_count < 3:
                try:
                    if data is not None:
                        _LOGGER.debug("继续处理上一次失败的评论任务")
                    if data is None:
                        data: Optional[BiliGPTTask] = await self.comment_queue.get()
                        _LOGGER.debug("获取到新的评论任务,开始处理")
                    video_obj, _type = await BiliVideo(credential=self.credential, url=data.video_url).get_video_obj()
                    video_obj: video.Video
                    aid = video_obj.get_aid()
                    if str(aid).startswith("av"):
                        aid = aid[2:]
                    oid = int(aid)
                    # root = data.source_extra_attr.source_id
                    user = data.raw_task_data["user"]["nickname"]
                    source_type = data.source_type
                    text = BiliComment.build_reply_content(data.process_result, user, source_type)
                    resp = await comment.send_comment(
                        oid=oid,
                        credential=self.credential,
                        text=text,
                        type_=comment.CommentResourceType.VIDEO,
                    )
                    if not resp["need_captcha"] and resp["success_toast"] == "发送成功":
                        _LOGGER.debug(resp)
                        _LOGGER.info("发送评论成功,休息30秒")
                        await asyncio.sleep(30)
                        break  # sent successfully: leave this task's retry loop
                    _LOGGER.warning("发送评论失败,大概率被风控了,咱们歇会儿再试吧")
                    risk_control_count += 1
                    if risk_control_count >= 3:
                        _LOGGER.warning("连续3次风控,跳过当前任务处理下一个")
                        data = None
                        break
                    raise RiskControlFindError
                except RiskControlFindError:
                    _LOGGER.warning("遇到风控,等待60秒后重试当前任务")
                    await asyncio.sleep(60)
                except asyncio.CancelledError:
                    _LOGGER.info("评论处理链关闭")
                    return
183 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import shutil
4 | import signal
5 | import sys
6 | import traceback
7 |
8 | from apscheduler.events import EVENT_JOB_ERROR
9 | from apscheduler.schedulers.asyncio import AsyncIOScheduler
10 | from injector import Injector
11 |
12 | from safe_update import merge_cache_to_new_version
13 | from src.bilibili.bili_comment import BiliComment
14 | from src.bilibili.bili_credential import BiliCredential
15 | from src.bilibili.bili_session import BiliSession
16 | from src.chain.ask_ai import AskAI
17 | from src.chain.summarize import Summarize
18 | from src.core.app import BiliGPT
19 | from src.listener.bili_listen import Listen
20 | from src.models.config import Config
21 | from src.utils.callback import scheduler_error_callback
22 | from src.utils.logging import LOGGER
23 | from src.utils.queue_manager import QueueManager
24 |
25 |
class BiliGPTPipeline:
    """Top-level application wiring.

    Builds the dependency injector, starts listeners and processing chains, and
    coordinates graceful shutdown through a class-level stop event (shared so
    the static signal handler can reach it).
    """

    stop_event: asyncio.Event  # set by stop_handler; polled by start()

    def __init__(self):
        _LOGGER.info("正在启动BiliGPTHelper")
        with open("VERSION", encoding="utf-8") as ver:
            # strip the trailing newline so the version log line stays on one line
            version = ver.read().strip()
        _LOGGER.info(f"当前运行版本:V{version}")
        signal.signal(signal.SIGINT, BiliGPTPipeline.stop_handler)
        signal.signal(signal.SIGTERM, BiliGPTPipeline.stop_handler)

        # Check environment variables and pre-seed the docker environment
        if os.getenv("RUNNING_IN_DOCKER") == "yes":
            if not os.listdir("/data"):
                os.system("cp -r /clone-data/* /data")
        elif not os.path.isfile("config.yml"):
            _LOGGER.warning("没有发现配置文件,正在重新生成新的配置文件!")
            try:
                shutil.copyfile("./config/example_config.yml", "./config.yml")
            except Exception:
                _LOGGER.error("在复制过程中发生了未预期的错误,程序初始化停止")
                traceback.print_exc()
                # FIX: this is a fatal initialization error — report a non-zero
                # exit status (the original exit(0) signalled success).
                sys.exit(1)

        # Build the DI container
        _LOGGER.info("正在初始化依赖注入器")
        self.injector = Injector(BiliGPT)

        BiliGPTPipeline.stop_event = self.injector.get(asyncio.Event)
        _LOGGER.debug("初始化配置文件")
        config = self.injector.get(Config)

        if config.debug_mode is False:
            LOGGER.remove()
            LOGGER.add(sys.stdout, level="INFO")

        _LOGGER.debug("尝试更新用户数据,符合新版本结构(这只是个提示,每次运行都会显示,其他地方不报错就别管了)")
        self.update_sth(config)

    def update_sth(self, config: Config):
        """Migrate config / runtime data written by an older version to the new layout."""
        merge_cache_to_new_version(config.storage_settings.cache_path)

    @staticmethod
    def stop_handler(_, __):
        """SIGINT/SIGTERM handler: flip the shared stop event so start() shuts down."""
        BiliGPTPipeline.stop_event.set()

    async def start(self):
        """Start every listener and chain, then block until the stop event fires."""
        try:
            _injector = self.injector

            # Recover queued tasks from the previous run
            _LOGGER.info("正在恢复队列信息")
            _injector.get(QueueManager).recover_queue(_injector.get(Config).storage_settings.queue_save_dir)

            # Initialize the @-mention listener
            _LOGGER.info("正在初始化at侦听器")
            listen = _injector.get(Listen)

            # Initialize the summarize chain
            _LOGGER.info("正在初始化摘要处理链")
            summarize_chain = _injector.get(Summarize)

            # Initialize the ask_ai chain
            _LOGGER.info("正在初始化ask_ai处理链")
            ask_ai_chain = _injector.get(AskAI)

            # Start the listeners
            _LOGGER.info("正在启动at侦听器")
            listen.start_listen_at()
            _LOGGER.info("正在启动视频更新检测侦听器")
            listen.start_video_mission()

            # Private-message listening is off by default — it is too memory hungry
            # _LOGGER.info("启动私信侦听器")
            # await listen.listen_private()

            _LOGGER.info("正在启动cookie过期检查和刷新")
            _injector.get(BiliCredential).start_check()

            # Start the scheduled-job scheduler
            _LOGGER.info("正在启动定时任务调度器")
            _injector.get(AsyncIOScheduler).start()
            _injector.get(AsyncIOScheduler).add_listener(scheduler_error_callback, EVENT_JOB_ERROR)

            # Start the processing chains
            _LOGGER.info("正在启动处理链")
            summarize_task = asyncio.create_task(summarize_chain.main())
            ask_ai_task = asyncio.create_task(ask_ai_chain.main())

            # Start the comment sender
            _LOGGER.info("正在启动评论处理链")
            comment = BiliComment(
                _injector.get(QueueManager).get_queue("reply"),
                _injector.get(BiliCredential),
            )
            comment_task = asyncio.create_task(comment.start_comment())

            # Start the private-message sender
            _LOGGER.info("正在启动私信处理链")
            private = BiliSession(
                _injector.get(BiliCredential),
                _injector.get(QueueManager).get_queue("private"),
            )
            private_task = asyncio.create_task(private.start_private_reply())

            _LOGGER.info("摘要处理链、评论处理链、私信处理链启动完成")

            # Periodic check whether followed UPs posted new videos (disabled)
            # mission = BiliMission(_injector.get(BiliCredential), _injector.get(AsyncIOScheduler))
            # await mission.start()
            # _LOGGER.info("创建刷新UP最新视频任务成功,刷新频率:60分钟")

            _LOGGER.success("🎉启动完成 enjoy it")

            # Poll the stop event; on shutdown tear everything down in order:
            # scheduler jobs -> listeners -> queue persistence -> chain tasks.
            while True:
                if BiliGPTPipeline.stop_event.is_set():
                    _LOGGER.info("正在关闭BiliGPTHelper,记得下次再来玩喵!")
                    _LOGGER.info("正在关闭定时任务调度器")
                    sched = _injector.get(AsyncIOScheduler)
                    for job in sched.get_jobs():
                        sched.remove_job(job.id)
                    sched.shutdown()
                    listen.close_private_listen()
                    _LOGGER.info("正在保存队列任务信息")
                    _injector.get(QueueManager).safe_close_all_queues(
                        _injector.get(Config).storage_settings.queue_save_dir
                    )
                    _LOGGER.info("正在关闭所有的处理链")
                    summarize_task.cancel()
                    ask_ai_task.cancel()
                    comment_task.cancel()
                    private_task.cancel()
                    # mission_task.cancel()
                    break
                await asyncio.sleep(1)
        except Exception:
            _LOGGER.error("发生了未捕获的错误,停止运行!")
            traceback.print_exc()
192 |
193 |
if __name__ == "__main__":
    # Default to non-debug; Config.debug_mode controls verbose logging later.
    os.environ["DEBUG_MODE"] = "false"
    # NOTE(review): _LOGGER is bound only under this guard, yet BiliGPTPipeline
    # methods reference it — importing this module without running it as a
    # script would leave _LOGGER undefined. Confirm the module is script-only.
    _LOGGER = LOGGER.bind(name="main")
    biligpt = BiliGPTPipeline()
    asyncio.run(biligpt.start())
199 |
--------------------------------------------------------------------------------
/src/models/task.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=missing-function-docstring
2 | import time
3 | import uuid
4 | from enum import Enum
5 | from typing import Annotated, List, Optional, Union
6 |
7 | from pydantic import BaseModel, Field, StringConstraints, field_validator
8 |
9 |
class SummarizeAiResponse(BaseModel):
    """AI reply produced by the summarize chain."""

    # TODO enabling this line used to break the run
    # anti-sloppy-AI measure: coerce numeric scores into strings
    # model_config = ConfigDict(coerce_numbers_to_str=True)  # type: ignore

    summary: str  # video summary
    score: str  # video score (string to tolerate sloppy model output)
    thinking: str  # the model's reasoning
    if_no_need_summary: bool  # True when the AI decides the video needs no summary

    @field_validator("score", mode="before")
    def check_score(cls, value, values):
        """Coerce a numeric score to str — some models return a bare number."""
        if not isinstance(value, str):
            return str(value)
        return value

    @field_validator("if_no_need_summary", mode="before")
    def check_if_no_need_summary(cls, value, values):
        """Map the string variants some models (e.g. Spark) emit onto booleans.

        FIX: the original "*" validator looked the field up in values.data,
        where it never exists while the field itself is being validated (it is
        the last declared field), and then assigned class attributes instead of
        returning a converted value — so the mapping never took effect.
        Converting the raw value directly makes "是"/"否" (and the usual
        yes/no/true/false spellings) validate as booleans.
        """
        if isinstance(value, str):
            mapping = {
                "是": True,
                "yes": True,
                "true": True,
                "True": True,
                "否": False,
                "no": False,
                "false": False,
                "False": False,
            }
            return mapping.get(value, value)
        return value
49 |
50 |
class AskAIResponse(BaseModel):
    """AI reply produced by the ask-ai chain."""

    # anti-sloppy-AI measure: numeric scores are normalised to strings below
    # model_config = ConfigDict(coerce_numbers_to_str=True)  # type: ignore

    answer: str  # the model's answer
    score: str  # self-assessed score

    @field_validator("score", mode="before")
    def check_score(cls, value, values):
        """Turn a numeric score into its string form before validation."""
        return value if isinstance(value, str) else str(value)
65 |
66 |
class ProcessStages(Enum):
    """Stages of the video-processing lifecycle; used to resume interrupted tasks."""

    PREPROCESS = (
        "preprocess"  # everything before the prompt is built (info fetching, subtitle reading); resuming here restarts from scratch
    )
    WAITING_LLM_RESPONSE = (
        "waiting_llm_response"  # waiting for the LLM reply; resuming here reloads the subtitle or reads the cached whisper_subtitle node
    )
    WAITING_SEND = "waiting_send"  # the LLM has replied; its output must be parsed and sent
    WAITING_PUSH_TO_CACHE = "waiting_push_to_cache"  # sent; waiting to be pushed into the cache
    WAITING_RETRY = "waiting_retry"  # waiting for a retry (the AI returned a malformed payload)
    END = "end"  # finished; kept rather than deleted, for later statistics
80 |
81 |
class Chains(Enum):
    """Identifiers of the available processing chains."""

    SUMMARIZE = "summarize"  # video summarization chain
    ASK_AI = "ask_ai"  # question-answering chain
85 |
86 |
class EndReasons(Enum):
    """Why a video task ended; the values are user-facing Chinese strings."""

    NORMAL = "正常结束"  # ended normally
    ERROR = "视频在处理过程中出现致命错误或多次重试失败,详细见具体的msg"  # fatal error or repeated retry failure; see error_msg
    NONEED = "AI认为该视频不需要被处理,可能是因为内容无意义"  # the AI judged the video not worth processing
93 |
94 |
class BiliAtSpecialAttributes(BaseModel):
    """Extra attributes carried by tasks that originate from an @-mention."""

    source_id: int  # id of this comment; maps to `root` in send_comment when replying
    target_id: int  # parent comment id: a 2nd-level reply points at root_id, a 3rd-level reply at the 2nd-level comment's id
    root_id: int  # never observed so far
    native_uri: str  # comment link containing the root and parent comment ids
    at_details: List[dict]  # profile dicts of the mentioned users
103 |
104 |
class AskAICommandParams(BaseModel):
    """Parameters parsed out of an ask-ai command."""

    question: str  # the question the user asked
109 |
110 |
class BiliGPTTask(BaseModel):
    """Full-lifecycle data model for one task; replaces all earlier ad-hoc types."""

    source_type: Annotated[str, StringConstraints(pattern=r"^(bili_comment|bili_private|api|bili_up)$")]  # type: ignore # where the task came from
    raw_task_data: dict  # raw task payload with all original information
    sender_id: int  # submitter id (bilibili uid for bilibili sources; other sources define their own), for statistics
    # video_title: str
    video_url: str  # video link
    video_id: str  # bvid
    source_command: str  # the user's raw command (eg. "总结一下" "问一下:xxxxxxx")
    # mission: bool = Field(default=False)
    command_params: Optional[AskAICommandParams] = None  # parsed parameters of the raw command
    source_extra_attr: Optional[BiliAtSpecialAttributes] = None  # extra raw attributes captured at intake (comment ids etc.)
    process_result: Optional[Union[SummarizeAiResponse, AskAIResponse, str, dict]] = (
        None  # final result; varies per chain (dict is a legacy wart kept for compatibility)
    )
    subtitle: Optional[str] = None  # the video subtitle, whichever way it was obtained
    process_stage: Optional[ProcessStages] = Field(default=ProcessStages.PREPROCESS)  # current processing stage
    chain: Optional[Chains] = None  # the chain handling this task
    # FIX: `default=str(uuid.uuid4())` and `default=int(time.time())` were
    # evaluated once at import time, so every task shared ONE uuid and ONE
    # creation timestamp. default_factory runs per instance.
    uuid: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))  # this task's uuid4
    gmt_create: int = Field(default_factory=lambda: int(time.time()))  # creation timestamp
    gmt_start_process: int = Field(default=0)  # when real processing started (unlike gmt_create)
    gmt_retry_start: int = Field(default=0)  # when a retry started, if the task was retried
    gmt_end: int = Field(default=0)  # when the task fully ended
    error_msg: Optional[str] = None  # detailed error message
    end_reason: Optional[EndReasons] = None  # why the task ended
137 |
138 |
139 | # class AtItem(TypedDict):
140 | # """里面储存着待处理任务的所有信息,私信消息也会被转换为这种格式再处理,后续可以进一步清洗,形成这个项目自己的格式"""
141 | #
142 | # type: str # 基本都为reply
143 | # business: str # 基本都为评论
144 | # business_id: int # 基本都为1
145 | # title: str # 如果是一级回复,这里是视频标题,如果是二级回复,这里是一级回复的内容
146 | # image: str # 一级回复是视频封面,二级回复为空
147 | # uri: str # 视频链接
148 | # source_content: str # 回复内容
149 | # source_id: int # 该评论的id,对应send_comment中的root(如果要回复的话)
150 | # target_id: int # 上一级评论id, 二级评论指向的就是root_id,三级评论指向的是二级评论的id
151 | # root_id: int # 暂时还没出现过
152 | # native_url: str # 评论链接,包含根评论id和父评论id
153 | # at_details: List[dict] # at的人的信息,常规的个人信息dict
154 | # ai_response: NotRequired[SummarizeAiResponse | str] # AI回复的内容,需要等到处理完才能获取到dict,否则为还没处理的str
155 | # is_private_msg: NotRequired[bool] # 是否为私信
156 | # private_msg_event: NotRequired[PrivateMsgSession] # 单用户私信会话信息
157 | # whisper_subtitle: NotRequired[str] # whisper字幕
158 | # stage: NotRequired[ProcessStages] # 视频处理阶段
159 | # event: NotRequired[Chains] # 视频处理事件
160 | # uuid: NotRequired[str] # 视频处理uuid
161 |
162 |
163 | # class AtItems(TypedDict):
164 | # id: int
165 | # user: dict # at发送者的个人信息,常规的个人信息dict
166 | # item: List[AtItem]
167 | # at_time: int
168 |
169 |
170 | # class AtAPIResponse(TypedDict):
171 | # """API返回的at消息"""
172 | #
173 | # cursor: AtCursor
174 | # items: List[AtItems]
175 |
176 |
177 | # class TaskStatus(BaseModel):
178 | # """视频记录"""
179 | #
180 | # gmt_create: int
181 | # gmt_end: Optional[int]
182 | # event: Chains
183 | # stage: ProcessStages
184 | # task_data: BiliGPTTask
185 | # end_reason: Optional[EndReasons]
186 | # error_msg: Optional[str]
187 | # use_whisper: Optional[bool]
188 | # if_retry: Optional[bool]
189 |
190 | # class AtCursor(TypedDict):
191 | # is_end: bool
192 | # id: int
193 | # time: int
194 |
195 |
196 | # class PrivateMsg(BaseModel):
197 | # """
198 | # 事件参数:
199 | # + receiver_id: 收信人 UID
200 | # + receiver_type: 收信人类型,1: 私聊, 2: 应援团通知, 3: 应援团
201 | # + sender_uid: 发送人 UID
202 | # + talker_id: 对话人 UID
203 | # + msg_seqno: 事件 Seqno
204 | # + msg_type: 事件类型
205 | # + msg_key: 事件唯一编号
206 | # + timestamp: 事件时间戳
207 | # + content: 事件内容
208 | #
209 | # 事件类型:
210 | # + TEXT: 纯文字消息
211 | # + PICTURE: 图片消息
212 | # + WITHDRAW: 撤回消息
213 | # + GROUPS_PICTURE: 应援团图片,但似乎不常触发,一般使用 PICTURE 即可
214 | # + SHARE_VIDEO: 分享视频
215 | # + NOTICE: 系统通知
216 | # + PUSHED_VIDEO: UP主推送的视频
217 | # + WELCOME: 新成员加入应援团欢迎
218 | #
219 | # TEXT = "1"
220 | # PICTURE = "2"
221 | # WITHDRAW = "5"
222 | # GROUPS_PICTURE = "6"
223 | # SHARE_VIDEO = "7"
224 | # NOTICE = "10"
225 | # PUSHED_VIDEO = "11"
226 | # WELCOME = "306"
227 | # """
228 | #
229 | # receiver_id: int
230 | # receiver_type: int
231 | # sender_uid: int
232 | # talker_id: int
233 | # msg_seqno: int
234 | # msg_type: int
235 | # msg_key: int
236 | # timestamp: int
237 | # content: Union[str, int, Picture, Video]
238 |
239 | # class PrivateMsgSession(BaseModel):
240 | # """储存单个用户的私信会话信息"""
241 | #
242 | # status: str # 状态
243 | # text_event: Optional[PrivateMsg] # 文本事件
244 | # video_event: Optional[PrivateMsg] # 视频事件
245 |
--------------------------------------------------------------------------------
/src/chain/ask_ai.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | import traceback
4 |
5 | import tenacity
6 | import yaml
7 |
8 | from src.bilibili.bili_session import BiliSession
9 | from src.chain.base_chain import BaseChain
10 | from src.llm.templates import Templates
11 | from src.models.task import AskAIResponse, BiliGPTTask, Chains, ProcessStages
12 | from src.utils.callback import chain_callback
13 | from src.utils.logging import LOGGER
14 |
15 | _LOGGER = LOGGER.bind(name="ask-ai-chain")
16 |
17 |
18 | class AskAI(BaseChain):
19 | async def _precheck(self, task: BiliGPTTask) -> bool:
20 | match task.source_type:
21 | case "bili_private":
22 | _LOGGER.debug("该消息是私信消息,继续处理")
23 | await BiliSession.quick_send(self.credential, task, "视频已开始处理,你先别急")
24 | return True
25 | case "bili_comment":
26 | _LOGGER.debug("该消息是评论消息,继续处理")
27 | return True
28 | case "api":
29 | _LOGGER.debug("该消息是api消息,继续处理")
30 | return True
31 | case "bili_up":
32 | _LOGGER.debug("该消息是评论消息,继续处理")
33 | return True
34 | return False
35 |
36 | @tenacity.retry(
37 | retry=tenacity.retry_if_exception_type(Exception),
38 | wait=tenacity.wait_fixed(10),
39 | before_sleep=chain_callback,
40 | )
41 | async def main(self):
42 | try:
43 | await self._on_start()
44 | while True:
45 | task: BiliGPTTask = await self.ask_ai_queue.get()
46 | _item_uuid = task.uuid
47 | self._create_record(task)
48 | _LOGGER.info(f"ask_ai处理链获取到任务了:{task.uuid}")
49 | # 检查是否满足处理条件
50 | if task.process_stage == ProcessStages.END:
51 | _LOGGER.info(f"任务{task.uuid}已经结束,获取下一个")
52 | continue
53 | if not await self._precheck(task):
54 | continue
55 | # 获取视频相关信息
56 | # data = self.task_status_recorder.get_data_by_uuid(_item_uuid)
57 | resp = await self._get_video_info(task, if_get_comments=False)
58 | if resp is None:
59 | continue
60 | (
61 | video,
62 | video_info,
63 | format_video_name,
64 | video_tags_string,
65 | video_comments,
66 | ) = resp
67 | if task.process_stage in (
68 | ProcessStages.PREPROCESS,
69 | ProcessStages.WAITING_LLM_RESPONSE,
70 | ):
71 | begin_time = time.perf_counter()
72 | # FIXME: 需要修改项目的cache实现,标注来自于哪个处理链,否则事有点大
73 | if await self._is_cached_video(task, _item_uuid, video_info):
74 | continue
75 | # 处理视频音频流和字幕
76 | _LOGGER.debug("视频信息获取成功,正在获取视频音频流和字幕")
77 | if task.subtitle is not None:
78 | text = task.subtitle
79 | _LOGGER.debug("使用字幕缓存,开始使用模板生成prompt")
80 | else:
81 | text = await self._smart_get_subtitle(video, _item_uuid, format_video_name, task)
82 | if text is None:
83 | continue
84 | task.subtitle = text
85 | _LOGGER.info(
86 | f"视频{format_video_name}音频流和字幕处理完成,共用时{time.perf_counter() - begin_time}s,开始调用LLM生成摘要"
87 | )
88 | self.task_status_recorder.update_record(
89 | _item_uuid,
90 | new_task_data=task,
91 | process_stage=ProcessStages.WAITING_LLM_RESPONSE,
92 | )
93 | llm = self.llm_router.get_one()
94 | if llm is None:
95 | _LOGGER.warning("没有可用的LLM,关闭系统")
96 | await self._set_err_end(msg="没有可用的LLM,被迫结束处理", task=task)
97 | self.stop_event.set()
98 | continue
99 | prompt = llm.use_template(
100 | Templates.ASK_AI_USER,
101 | Templates.ASK_AI_SYSTEM,
102 | title=video_info["title"],
103 | subtitle=text,
104 | description=video_info["desc"],
105 | question=task.command_params.question,
106 | )
107 | _LOGGER.debug("prompt生成成功,开始调用llm")
108 | # 调用openai的Completion API
109 | response = await llm.completion(prompt)
110 | if response is None:
111 | _LOGGER.warning(f"任务{task.uuid}:ai未返回任何内容,请自行检查问题,跳过处理")
112 | await self._set_err_end(msg="ai未返回任何内容,请自行检查问题,跳过处理", task=task)
113 | self.llm_router.report_error(llm.alias)
114 | continue
115 | answer, tokens = response
116 | self.now_tokens += tokens
117 | _LOGGER.debug(f"llm输出内容为:{answer}")
118 | _LOGGER.debug("调用llm成功,开始处理结果")
119 | task.process_result = answer
120 | task.process_stage = ProcessStages.WAITING_SEND
121 | self.task_status_recorder.update_record(_item_uuid, task)
122 | if task.process_stage in (
123 | ProcessStages.WAITING_SEND,
124 | ProcessStages.WAITING_RETRY,
125 | ):
126 | begin_time = time.perf_counter()
127 | answer = task.process_result
128 | # obj, _type = await video.get_video_obj()
129 | # 处理结果
130 | if answer:
131 | try:
132 | if task.process_stage == ProcessStages.WAITING_RETRY:
133 | raise Exception("触发重试")
134 | answer = answer.replace("False", "false") # 解决一部分因为大小写问题导致的json解析失败
135 | answer = answer.replace("True", "true")
136 | resp = yaml.safe_load(answer)
137 | task.process_result = AskAIResponse.model_validate(resp)
138 | _LOGGER.info(
139 | f"ai返回内容解析正确,视频{format_video_name}摘要处理完成,共用时{time.perf_counter() - begin_time}s"
140 | )
141 | await self.finish(task)
142 | except Exception as e:
143 | _LOGGER.error(f"处理结果失败:{e},大概是ai返回的格式不对,尝试修复")
144 | traceback.print_tb(e.__traceback__)
145 | self.task_status_recorder.update_record(
146 | _item_uuid,
147 | new_task_data=task,
148 | process_stage=ProcessStages.WAITING_RETRY,
149 | )
150 | await self.retry(
151 | answer,
152 | task,
153 | format_video_name,
154 | begin_time,
155 | video_info,
156 | )
157 | except asyncio.CancelledError:
158 | _LOGGER.info("收到关闭信号,ask_ai处理链关闭")
159 |
160 | async def _on_start(self):
161 | """在启动处理链时先处理一下之前没有处理完的视频"""
162 | _LOGGER.info("正在启动摘要处理链,开始将上次未处理完的视频加入队列")
163 | uncomplete_task = []
164 | uncomplete_task += self.task_status_recorder.get_record_by_stage(
165 | chain=Chains.ASK_AI
166 | ) # 有坑,这里会把之前运行过的也重新加回来,不过我下面用判断简单补了一手,叫我天才!
167 | for task in uncomplete_task:
168 | if task["process_stage"] != ProcessStages.END.value:
169 | try:
170 | _LOGGER.debug(f"恢复uuid: {task['uuid']} 的任务")
171 | self.ask_ai_queue.put_nowait(BiliGPTTask.model_validate(task))
172 | except Exception:
173 | traceback.print_exc()
174 | # TODO 这里除了打印日志,是不是还应该记录在视频状态中?
175 | _LOGGER.error(f"在恢复uuid: {task['uuid']} 时出现错误!跳过恢复")
176 |
177 | async def retry(self, ai_answer, task: BiliGPTTask, format_video_name, begin_time, video_info):
178 | """通过重试prompt让chatgpt重新构建json
179 |
180 | :param ai_answer: ai返回的内容
181 | :param task: queue中的原始数据
182 | :param format_video_name: 格式化后的视频名称
183 | :param begin_time: 开始时间
184 | :param video_info: 视频信息
185 | :return: None
186 | """
187 | _LOGGER.error(
188 | f"任务{task.uuid}:真不好意思!但是我ask_ai部分的retry还没写!所以只能先全给你设置成错误结束了嘿嘿嘿"
189 | )
190 | await self._set_err_end(
191 | msg="重试代码未实现,跳过",
192 | task=task,
193 | )
194 | return False
195 |
--------------------------------------------------------------------------------
/src/llm/spark.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 | import hmac
4 | import json
5 | import traceback
6 | from datetime import datetime
7 | from time import mktime
8 | from typing import Tuple
9 | from urllib.parse import urlencode, urlparse
10 | from wsgiref.handlers import format_date_time
11 |
12 | import websockets
13 |
14 | from src.llm.llm_base import LLMBase
15 | from src.llm.templates import Templates
16 | from src.utils.logging import LOGGER
17 | from src.utils.prompt_utils import build_openai_style_messages, parse_prompt
18 |
19 | _LOGGER = LOGGER.bind(name="spark")
20 |
21 |
22 | class Spark(LLMBase):
    def prepare(self):
        """Reset the per-request accumulators used while streaming a Spark reply."""
        self._answer_temp = ""  # accumulates the streamed response chunks from Spark
        self._once_total_tokens = 0  # token count Spark reports for the last call
26 |
27 | def create_url(self):
28 | """
29 | 生成鉴权url
30 | :return:
31 | """
32 | host = urlparse(self.config.LLMs.spark.spark_url).netloc
33 | path = urlparse(self.config.LLMs.spark.spark_url).path
34 | # 生成RFC1123格式的时间戳
35 | now = datetime.now()
36 | date = format_date_time(mktime(now.timetuple()))
37 |
38 | # 拼接字符串
39 | signature_origin = "host: " + host + "\n"
40 | signature_origin += "date: " + date + "\n"
41 | signature_origin += "GET " + path + " HTTP/1.1"
42 |
43 | # 进行hmac-sha256进行加密
44 | signature_sha = hmac.new(
45 | self.config.LLMs.spark.api_secret.encode("utf-8"),
46 | signature_origin.encode("utf-8"),
47 | digestmod=hashlib.sha256,
48 | ).digest()
49 |
50 | signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
51 |
52 | authorization_origin = f'api_key="{self.config.LLMs.spark.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
53 |
54 | authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
55 |
56 | # 将请求的鉴权参数组合为字典
57 | v = {"authorization": authorization, "date": date, "host": host}
58 | # 拼接鉴权参数,生成url
59 | url = self.config.LLMs.spark.spark_url + "?" + urlencode(v)
60 | _LOGGER.debug(f"生成的url为:{url}")
61 | # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
62 | return url
63 |
64 | async def on_message(self, ws, message) -> int:
65 | """
66 |
67 | :param ws:
68 | :param message:
69 | :return: 1为还未结束 0为正常结束 2为异常结束
70 | """
71 | data = json.loads(message)
72 | code = data["header"]["code"]
73 | if code != 0:
74 | _LOGGER.error(f"讯飞星火大模型请求失败: 错误代码:{code} 返回内容:{data}")
75 | await ws.close()
76 | if code == 10013 or code == 10014:
77 | self._once_total_tokens = 0
78 | self._answer_temp = """{"summary":"⚠⚠⚠我也很想告诉你视频的总结,但是星火却跟我说这个视频的总结是***,真的是离谱他🐎给离谱开门——离谱到家了。我也没有办法,谁让星火可以白嫖500w个token🐷。为了白嫖,忍一下,换个视频试一试!","score":"0","thinking":"🤡老子是真的服了这个讯飞星火,国际友好手势(一种动作)。","if_no_need_summary": false}"""
79 | return 0
80 | return 2
81 | else:
82 | choices = data["payload"]["choices"]
83 | status = choices["status"]
84 | content = choices["text"][0]["content"]
85 | self._answer_temp += content
86 | if status == 2:
87 | self._once_total_tokens = data["payload"]["usage"]["text"]["total_tokens"]
88 | await ws.close()
89 | return 0
90 | return 1
91 |
    async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
        """Run one chat completion against the Spark websocket API.

        :param prompt: message list produced by use_template
        :return: (answer_text, total_tokens) on success, None on any failure
        """
        try:
            self._answer_temp = ""
            self._once_total_tokens = 0
            ws_url = self.create_url()

            async with websockets.connect(ws_url) as websocket:
                # NOTE(review): these attributes are set on the websocket object but
                # never read back anywhere visible here — presumably legacy; confirm.
                websocket.appid = self.config.LLMs.spark.appid
                websocket.question = prompt
                websocket.domain = self.config.LLMs.spark.domain

                data = json.dumps(self.gen_params(prompt))
                await websocket.send(data)
                # stream frames until on_message signals an end state
                async for message in websocket:
                    res = await self.on_message(websocket, message)
                    if res == 2:
                        # abnormal end: error already logged in on_message, just bail out
                        return None
            _LOGGER.info(
                f"调用讯飞星火大模型成功,返回结果为:{self._answer_temp},本次调用中,prompt+response的长度为{self._once_total_tokens}"
            )

            # strip a markdown code fence if the model wrapped its json in one
            if self._answer_temp.startswith("```json"):
                self._answer_temp = self._answer_temp[7:]
            if self._answer_temp.endswith("```"):
                self._answer_temp = self._answer_temp[:-3]
            # Spark's json always comes single-quoted; the eval-based parsing attempt below is kept disabled
            # try:
            #     _answer = self._answer_temp
            #     _answer = _answer.replace("true", "True")
            #     _answer = _answer.replace("false", "False")
            #     _answer = ast.literal_eval(_answer)  # hacky
            #     _answer = json.dumps(_answer, ensure_ascii=False)
            #     _LOGGER.debug(f"经简单处理后的返回结果为:{_answer}")
            #     return _answer, self._once_total_tokens
            # except Exception as e:
            #     _LOGGER.error(f"尝试使用eval方式解析星火返回的json失败:{e}")
            #     traceback.print_exc()
            # if eval-style parsing fails, return the raw text as-is
            _LOGGER.debug(f"经简单处理后的返回结果为:{self._answer_temp}")
            return self._answer_temp, self._once_total_tokens
        except Exception as e:
            traceback.print_exc()
            _LOGGER.error(f"调用讯飞星火大模型失败:{e}")
            return None
138 |
139 | def gen_params(self, prompt_list) -> dict:
140 | """
141 | 通过appid和用户的提问来生成提问参数
142 |
143 | :param prompt_list: 用户的提问
144 | """
145 | data = {
146 | "header": {
147 | "app_id": self.config.LLMs.spark.appid,
148 | },
149 | "parameter": {
150 | "chat": {
151 | "domain": self.config.LLMs.spark.domain,
152 | "temperature": 0.5,
153 | "max_tokens": 8192,
154 | }
155 | },
156 | "payload": {"message": {"text": prompt_list}},
157 | }
158 | _LOGGER.debug(f"生成的参数为:{data}")
159 | return data
160 |
    @staticmethod
    def use_template(
        user_template_name: Templates,
        system_template_name: Templates = None,
        user_keyword="user",
        system_keyword="system",
        **kwargs,
    ) -> list | None:
        """Build the final prompt from templates (defaults to openai's system/user message format;
        the keywords may need adjusting per-LLM).

        Note: for SUMMARIZE_SYSTEM / SUMMARIZE_USER / ASK_AI_USER / SUMMARIZE_RETRY this method
        substitutes Spark-specific template text instead of the shared ``Templates`` values.

        :param user_template_name: user template name
        :param system_template_name: system template name
        :param user_keyword: user role keyword (this and system_keyword depend on each LLM's requirements)
        :param system_keyword: system role keyword
        :param kwargs: template parameters substituted into the template placeholders
        :return: the generated prompt message list, or None on failure
        """
        try:
            if system_template_name is None:
                template_system = None
            elif system_template_name.name == "SUMMARIZE_SYSTEM":
                template_system = """
你是一个视频摘要制作人。请根据视频的标题、描述、标签、字幕、评论来制作视频摘要。指南如下:
1. 返回的内容严格遵循此JSON格式:{"summary": "你的摘要内容", "score": "你给这个视频的评分(最高100)", "thinking": "你的想法", "if_no_need_summary": "是否需要摘要?填写布尔值"}
2. 信息要完整且信息量大。
3. 保持准确,不要添加虚构信息。
4. 保持幽默,避免正式,需要时可以使用中文网络语言和表情符号。
5. 忽略不相关的标签/评论。
6. 分段落写作。
7. "score"部分的值必须是字符串类型。
8. 写作"thinking"部分时,忽略上述规则,表达你的独立思考,包括你自己的激进观点,而非视频内容。评分基于这个视频。如果你认为你创造了一个有意义的摘要,给自己一个高分。
9. 如果视频毫无意义,将此JSON的"if_no_need_summary"设置为true,否则设置为false。
10. 返回的内容只允许纯JSON格式,JSON的键和值必须使用英文双引号包裹!请使用简体中文!
"""
            else:
                template_system = system_template_name.value
            if user_template_name.name == "SUMMARIZE_USER":
                template_user = (
                    """标题:[title]\n\n简介:[description]\n\n字幕:[subtitle]\n\n标签:[tags]\n\n评论:[comments]"""
                )
            elif user_template_name.name == "ASK_AI_USER":
                template_user = """
标题: [title]\n\n简介: [description]\n\n字幕: [subtitle]\n\n用户问题: [question]\n\n
你是一位专业的视频问答老师。我将提供给你视频的标题、描述和字幕。根据这些信息和你的专业知识,以生动幽默的方式回答用户的问题,必要时使用比喻和例子。
请按照以下JSON格式回复:{"answer": "你的回答", "score": "你对回答质量的自我评分(0-100)"}
!!!只允许使用双引号的纯JSON内容!请使用中文!不要添加任何其他内容!!!
"""
            elif user_template_name.name == "SUMMARIZE_RETRY":
                template_user = """请将以下文本翻译成此JSON格式并返回给我,不要添加任何其他内容。如果不存在 'summary' 字段,请将 'if_no_need_summary' 设置为 true。如果除 'summary' 之外的字段缺失,则可以忽略并留空, 'if_no_need_summary' 保持 false\n\n标准JSON格式:{"summary": "您的摘要内容", "score": "您给这个视频的评分(最高100分)", "thinking": "您的想法", "if_no_need_summary": "是否需要摘要?填写布尔值"}\n\n我的内容:[input]"""
            else:
                template_user = user_template_name.value
            utemplate = parse_prompt(template_user, **kwargs)
            stemplate = parse_prompt(template_system, **kwargs) if template_system else None
            # final_template = utemplate + stemplate if stemplate else utemplate  # special case: append system after user
            prompt = (
                build_openai_style_messages(utemplate, stemplate, user_keyword, system_keyword)
                if stemplate
                else build_openai_style_messages(utemplate, user_keyword=user_keyword)
                # build_openai_style_messages(final_template, user_keyword=user_keyword)
            )
            _LOGGER.info("使用模板成功")
            _LOGGER.debug(f"生成的prompt为:{prompt}")
            return prompt
        except Exception as e:
            _LOGGER.error(f"使用模板失败:{e}")
            traceback.print_exc()
            return None
227 |
--------------------------------------------------------------------------------
/src/chain/summarize.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | import traceback
4 |
5 | import tenacity
6 | import yaml
7 |
8 | from src.bilibili.bili_session import BiliSession
9 | from src.chain.base_chain import BaseChain
10 | from src.llm.templates import Templates
11 | from src.models.task import BiliGPTTask, Chains, ProcessStages, SummarizeAiResponse
12 | from src.utils.callback import chain_callback
13 | from src.utils.logging import LOGGER
14 |
15 | _LOGGER = LOGGER.bind(name="summarize-chain")
16 |
17 |
class Summarize(BaseChain):
    """Summarize processing chain: consumes tasks from the summarize queue and
    produces AI-generated video summaries, driving each task through the
    PREPROCESS → WAITING_LLM_RESPONSE → WAITING_SEND/WAITING_RETRY stages."""

    async def _precheck(self, task: BiliGPTTask) -> bool:
        """Check whether the task meets the processing preconditions.

        :param task: the task to inspect
        :return: True when the source type is supported, False otherwise
        """
        match task.source_type:
            case "bili_private":
                _LOGGER.debug("该消息是私信消息,继续处理")
                await BiliSession.quick_send(self.credential, task, "视频已开始处理,你先别急")
                return True
            case "bili_comment":
                _LOGGER.debug("该消息是评论消息,继续处理")
                return True
            case "api":
                _LOGGER.debug("该消息是api消息,继续处理")
                return True
            case "bili_up":
                _LOGGER.debug("该消息是up更新消息,继续处理")
                return True
        # if task["item"]["type"] != "reply" or task["item"]["business_id"] != 1:
        #     _LOGGER.warning(f"该消息目前并不支持,跳过处理")
        #     self._set_err_end(_uuid, "该消息目前并不支持,跳过处理")
        #     return False
        # if task["item"]["root_id"] != 0 or task["item"]["target_id"] != 0:
        #     _LOGGER.warning(f"该消息是楼中楼消息,暂时不受支持,跳过处理")  # TODO: handle nested (reply-to-reply) comments
        #     self._set_err_end(_uuid, "该消息是楼中楼消息,暂时不受支持,跳过处理")
        #     return False
        return False

    async def _on_start(self):
        """On chain startup, re-enqueue videos that were left unfinished in the previous run."""
        _LOGGER.info("正在启动摘要处理链,开始将上次未处理完的视频加入队列")
        uncomplete_task = []
        uncomplete_task += self.task_status_recorder.get_record_by_stage(
            chain=Chains.SUMMARIZE
        )  # caveat: this also returns records that already finished; the stage check below filters those out
        for task in uncomplete_task:
            if task["process_stage"] != ProcessStages.END.value:
                try:
                    _LOGGER.debug(f"恢复uuid: {task['uuid']} 的任务")
                    self.summarize_queue.put_nowait(BiliGPTTask.model_validate(task))
                except Exception:
                    traceback.print_exc()
                    # TODO: besides logging, should this failure also be recorded in the task status?
                    _LOGGER.error(f"在恢复uuid: {task['uuid']} 时出现错误!跳过恢复")
        # _LOGGER.info(f"之前未处理完的视频已经全部加入队列,共{len(uncomplete_task)}个")
        # self.queue_manager.(self.summarize_queue, "summarize")
        # _LOGGER.info("正在将上次在队列中的视频加入队列")
        # self.task_status_recorder.delete_queue("summarize")

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(10),
        before_sleep=chain_callback,
    )
    async def main(self):
        """Chain main loop: pull tasks from the queue and drive them through the stage machine.

        Wrapped in tenacity so any unexpected exception restarts the loop after 10s;
        asyncio.CancelledError is the clean shutdown signal.
        """
        try:
            await self._on_start()
            while True:
                # if self.max_tokens is not None and self.now_tokens >= self.max_tokens:
                #     _LOGGER.warning(
                #         f"当前已使用token数{self.now_tokens},超过最大token数{self.max_tokens},摘要处理链停止运行"
                #     )
                #     raise asyncio.CancelledError

                # Fetch the next task from the queue.
                task: BiliGPTTask = await self.summarize_queue.get()
                _item_uuid = task.uuid
                # NOTE(review): _create_record's return value (a uuid) is ignored here and
                # task.uuid is used instead — confirm the two always agree.
                self._create_record(task)
                _LOGGER.info(f"summarize处理链获取到任务了:{task.uuid}")
                # Check the processing preconditions.
                if task.process_stage == ProcessStages.END:
                    _LOGGER.info(f"任务{task.uuid}已经结束,获取下一个")
                    continue
                if not await self._precheck(task):
                    continue
                # Fetch video metadata (aborts the task on failure / multi-part videos).
                # data = self.task_status_recorder.get_data_by_uuid(_item_uuid)
                resp = await self._get_video_info(task)
                if resp is None:
                    continue
                (
                    video,
                    video_info,
                    format_video_name,
                    video_tags_string,
                    video_comments,
                ) = resp
                if task.process_stage in (
                    ProcessStages.PREPROCESS,
                    ProcessStages.WAITING_LLM_RESPONSE,
                ):
                    begin_time = time.perf_counter()
                    if await self._is_cached_video(task, _item_uuid, video_info):
                        continue
                    # Process the video's audio stream and subtitles.
                    _LOGGER.debug("视频信息获取成功,正在获取视频音频流和字幕")
                    if task.subtitle is not None:
                        # Subtitle already present on the task (e.g. restored task): reuse it.
                        text = task.subtitle
                        _LOGGER.debug("使用字幕缓存,开始使用模板生成prompt")
                    else:
                        text = await self._smart_get_subtitle(video, _item_uuid, format_video_name, task)
                        if text is None:
                            continue
                        task.subtitle = text
                    _LOGGER.info(
                        f"视频{format_video_name}音频流和字幕处理完成,共用时{time.perf_counter() - begin_time}s,开始调用LLM生成摘要"
                    )
                    self.task_status_recorder.update_record(
                        _item_uuid,
                        new_task_data=task,
                        process_stage=ProcessStages.WAITING_LLM_RESPONSE,
                    )
                    llm = self.llm_router.get_one()
                    if llm is None:
                        _LOGGER.warning("没有可用的LLM,关闭系统")
                        await self._set_err_end(msg="没有可用的LLM,被迫结束处理", task=task)
                        self.stop_event.set()
                        continue
                    prompt = llm.use_template(
                        Templates.SUMMARIZE_USER,
                        Templates.SUMMARIZE_SYSTEM,
                        title=video_info["title"],
                        tags=video_tags_string,
                        comments=video_comments,
                        subtitle=text,
                        description=video_info["desc"],
                    )
                    _LOGGER.debug("prompt生成成功,开始调用llm")
                    # Call the LLM completion API.
                    response = await llm.completion(prompt)
                    if response is None:
                        _LOGGER.warning(f"任务{task.uuid}:ai未返回任何内容,请自行检查问题,跳过处理")
                        await self._set_err_end(
                            msg="AI未返回任何内容,我也不知道为什么,估计是调休了吧。换个视频或者等一小会儿再试一试。",
                            task=task,
                        )
                        self.llm_router.report_error(llm.alias)
                        continue
                    answer, tokens = response
                    self.now_tokens += tokens
                    _LOGGER.debug(f"llm输出内容为:{answer}")
                    _LOGGER.debug("调用llm成功,开始处理结果")
                    task.process_result = answer
                    task.process_stage = ProcessStages.WAITING_SEND
                    self.task_status_recorder.update_record(_item_uuid, task)
                if task.process_stage in (
                    ProcessStages.WAITING_SEND,
                    ProcessStages.WAITING_RETRY,
                ):
                    begin_time = time.perf_counter()
                    answer = task.process_result
                    # obj, _type = await video.get_video_obj()
                    # Parse and route the LLM result.
                    if answer:
                        try:
                            if task.process_stage == ProcessStages.WAITING_RETRY:
                                # Force the except-branch so the retry prompt is used.
                                raise Exception("触发重试")

                            # Fix a class of JSON parse failures caused by Python-style booleans.
                            answer = answer.replace("False", "false")
                            answer = answer.replace("True", "true")

                            ai_resp = yaml.safe_load(answer)
                            ai_resp["score"] = str(ai_resp["score"])  # the model may return an int; force str
                            task.process_result = SummarizeAiResponse.model_validate(ai_resp)
                            if task.process_result.if_no_need_summary is True:
                                _LOGGER.warning(f"视频{format_video_name}被ai判定为不需要摘要,跳过处理")
                                await BiliSession.quick_send(
                                    self.credential,
                                    task,
                                    "AI觉得你的视频不需要处理,换个更有意义的视频再试试看吧!",
                                )
                                # await BiliSession.quick_send(
                                #     self.credential, task, answer
                                # )
                                await self._set_noneed_end(task)
                                continue
                            _LOGGER.info(
                                f"ai返回内容解析正确,视频{format_video_name}摘要处理完成,共用时{time.perf_counter() - begin_time}s"
                            )
                            await self.finish(task)

                        except Exception as e:
                            _LOGGER.error(f"处理结果失败:{e},大概是ai返回的格式不对,尝试修复")
                            traceback.print_tb(e.__traceback__)
                            self.task_status_recorder.update_record(
                                _item_uuid,
                                new_task_data=task,
                                process_stage=ProcessStages.WAITING_RETRY,
                            )
                            await self.retry(
                                answer,
                                task,
                                format_video_name,
                                begin_time,
                                video_info,
                            )
        except asyncio.CancelledError:
            _LOGGER.info("收到关闭信号,摘要处理链关闭")

    async def retry(self, ai_answer, task: BiliGPTTask, format_video_name, begin_time, video_info):
        """Ask the LLM to rebuild the malformed answer into valid JSON via the retry prompt.

        :param ai_answer: the raw content the LLM returned
        :param task: the original task from the queue
        :param format_video_name: formatted video name (for logging)
        :param begin_time: processing start time (perf_counter)
        :param video_info: video metadata dict
        :return: None
        """
        _LOGGER.debug(f"任务{task.uuid}:ai返回内容解析失败,正在尝试重试")
        task.gmt_retry_start = int(time.time())
        llm = self.llm_router.get_one()
        if llm is None:
            _LOGGER.warning("没有可用的LLM,关闭系统")
            await self._set_err_end(msg="没有可用的LLM,跳过处理", task=task)
            self.stop_event.set()
            return False
        prompt = llm.use_template(Templates.SUMMARIZE_RETRY, input=ai_answer)
        response = await llm.completion(prompt)
        if response is None:
            _LOGGER.warning(f"视频{format_video_name}摘要生成失败,请自行检查问题,跳过处理")
            await self._set_err_end(
                msg=f"视频{format_video_name}摘要生成失败,请自行检查问题,跳过处理",
                task=task,
            )
            self.llm_router.report_error(llm.alias)
            return False
        answer, tokens = response
        _LOGGER.debug(f"api输出内容为:{answer}")
        self.now_tokens += tokens
        if answer:
            try:
                resp = yaml.safe_load(answer)
                resp["score"] = str(resp["score"])
                task.process_result = SummarizeAiResponse.model_validate(resp)
                if task.process_result.if_no_need_summary is True:
                    _LOGGER.warning(f"视频{format_video_name}被ai判定为不需要摘要,跳过处理")
                    await self._set_noneed_end(task)
                    return False
                else:
                    # TODO: this elapsed-time reporting is misleading (begin_time resets per stage); unify it someday
                    _LOGGER.info(
                        f"ai返回内容解析正确,视频{format_video_name}摘要处理完成,共用时{time.perf_counter() - begin_time}s"
                    )
                    await self.finish(task)
                    return True
            except Exception as e:
                _LOGGER.error(f"处理结果失败:{e},大概是ai返回的格式不对,拿你没辙了,跳过处理")
                traceback.print_exc()
                await self._set_err_end(
                    msg="重试后处理结果失败,大概是ai返回的格式不对,跳过",
                    task=task,
                )
                return False
273 |
--------------------------------------------------------------------------------
/src/chain/base_chain.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import asyncio
3 | import os
4 | import time
5 | from typing import Optional
6 |
7 | import ffmpeg
8 | import httpx
9 | from bilibili_api import HEADERS
10 | from injector import inject
11 |
12 | from src.bilibili.bili_comment import BiliComment
13 | from src.bilibili.bili_credential import BiliCredential
14 | from src.bilibili.bili_session import BiliSession
15 | from src.bilibili.bili_video import BiliVideo
16 | from src.core.routers.asr_router import ASRouter
17 | from src.core.routers.llm_router import LLMRouter
18 | from src.models.config import Config
19 | from src.models.task import (
20 | AskAIResponse,
21 | BiliGPTTask,
22 | EndReasons,
23 | ProcessStages,
24 | SummarizeAiResponse,
25 | )
26 | from src.utils.cache import Cache
27 | from src.utils.logging import LOGGER
28 | from src.utils.queue_manager import QueueManager
29 | from src.utils.task_status_record import TaskStatusRecorder
30 |
31 |
32 | class BaseChain:
33 | """处理链基类
34 | 对于b站来说,处理链需要接管的内容基本都要包含对视频基本信息的处理和字幕的提取,这个基类全部帮你做了
35 | """
36 |
37 | @inject
38 | def __init__(
39 | self,
40 | queue_manager: QueueManager,
41 | config: Config,
42 | credential: BiliCredential,
43 | cache: Cache,
44 | asr_router: ASRouter,
45 | task_status_recorder: TaskStatusRecorder,
46 | stop_event: asyncio.Event,
47 | llm_router: LLMRouter,
48 | ):
49 | self.llm_router = llm_router
50 | self.queue_manager = queue_manager
51 | self.config = config
52 | self.cache = cache
53 | self.asr_router = asr_router
54 | self.now_tokens = 0
55 | self.credential = credential
56 | self.task_status_recorder = task_status_recorder
57 | self._get_variables()
58 | self._get_queues()
59 | self.asr = asr_router.get_one()
60 | self._LOGGER = LOGGER.bind(name=self.__class__.__name__)
61 | self.stop_event = stop_event
62 |
63 | def _get_variables(self):
64 | """从Config获取配置信息"""
65 | self.temp_dir = self.config.storage_settings.temp_dir
66 | self.api_key = self.config.LLMs.openai.api_key
67 | self.api_base = self.config.LLMs.openai.api_base
68 |
69 | def _get_queues(self):
70 | """从队列管理器获取队列"""
71 | self.summarize_queue = self.queue_manager.get_queue("summarize")
72 | self.reply_queue = self.queue_manager.get_queue("reply")
73 | self.private_queue = self.queue_manager.get_queue("private")
74 | self.ask_ai_queue = self.queue_manager.get_queue("ask_ai")
75 |
76 | async def _set_err_end(self, msg: str, _uuid: str = None, task: BiliGPTTask = None):
77 | """当一个视频因为错误而结束时,调用此方法
78 |
79 | :param msg: 错误信息
80 | :param _uuid: 任务uuid (跟task二选一)
81 | :param task: 任务对象
82 | """
83 | self.task_status_recorder.update_record(
84 | _uuid if _uuid else task.uuid,
85 | new_task_data=None,
86 | process_stage=ProcessStages.END,
87 | end_reason=EndReasons.ERROR,
88 | gmt_end=int(time.time()),
89 | error_msg=msg,
90 | )
91 | _task = self.task_status_recorder.get_data_by_uuid(_uuid) if _uuid else task
92 | match _task.source_type:
93 | case "bili_private":
94 | self._LOGGER.debug(f"任务{task.uuid}:私信消息,直接回复:{msg}")
95 | await BiliSession.quick_send(
96 | self.credential,
97 | task,
98 | msg,
99 | )
100 | case "bili_comment":
101 | _task.process_result = msg
102 | self._LOGGER.debug(f"任务{task.uuid}:评论消息,将结果放入评论处理队列,内容:{msg}")
103 | await self.reply_queue.put(task)
104 | case "api":
105 | self._LOGGER.warning(f"任务{task.uuid}:api获取的消息,未实现处理逻辑")
106 | case "bili_up":
107 | _task.process_result = msg
108 | self._LOGGER.debug(f"任务{task.uuid}:评论消息,将结果放入评论处理队列,内容:{msg}")
109 | await self.reply_queue.put(task)
110 |
111 | async def _set_normal_end(self, task: BiliGPTTask = None, _uuid: str = None):
112 | """当一个视频正常结束时,调用此方法
113 |
114 | :param task: 任务对象
115 | :param _uuid: 任务uuid (跟task二选一)
116 | """
117 | self.task_status_recorder.update_record(
118 | _uuid if _uuid else task.uuid,
119 | new_task_data=None,
120 | process_stage=ProcessStages.END,
121 | end_reason=EndReasons.NORMAL,
122 | gmt_end=int(time.time()),
123 | )
124 |
125 | async def _set_noneed_end(self, task: BiliGPTTask = None, _uuid: str = None):
126 | """当一个视频不需要处理时,调用此方法
127 |
128 | :param task: 任务对象
129 | :param _uuid: 任务uuid (跟task二选一)
130 | """
131 | self.task_status_recorder.update_record(
132 | _uuid if _uuid else task.uuid,
133 | new_task_data=None,
134 | process_stage=ProcessStages.END,
135 | end_reason=EndReasons.NONEED,
136 | gmt_end=int(time.time()),
137 | )
138 | await BiliSession.quick_send(
139 | self.credential,
140 | task,
141 | "AI觉得你的视频不需要处理,换个更有意义的视频再试试看吧!",
142 | )
143 |
144 | @abc.abstractmethod
145 | async def _precheck(self, task: BiliGPTTask) -> bool:
146 | """检查是否符合调用条件
147 | :param task: AtItem
148 |
149 | 如不符合,请务必调用self._set_err_end()方法后返回False
150 | """
151 | pass
152 |
153 | async def finish(self, task: BiliGPTTask, use_cache: bool = False) -> bool:
154 | """
155 | 当一个任务 **正常** 结束时,调用这个,将消息放入队列、设置缓存、更新任务状态
156 | :param task:
157 | :param use_cache: 是否直接使用缓存而非正常处理
158 | :return:
159 | """
160 | _LOGGER = self._LOGGER
161 | reply_data = task
162 | # if reply_data.source_type == "bili_private":
163 | # _LOGGER.debug("该消息是私信消息,将结果放入私信处理队列")
164 | # await self.private_queue.put(reply_data)
165 | # elif reply_data.source_type == "bili_comment":
166 | # _LOGGER.debug("正在将结果加入发送队列,等待回复")
167 | # await self.reply_queue.put(reply_data)
168 | match reply_data.source_type:
169 | case "bili_private":
170 | _LOGGER.debug(f"任务{task.uuid}:私信消息,将结果放入私信处理队列")
171 | await self.private_queue.put(reply_data)
172 | case "bili_comment":
173 | _LOGGER.info(f"任务{task.uuid}:评论消息,将结果放入评论处理队列")
174 | await self.reply_queue.put(reply_data)
175 | case "api":
176 | _LOGGER.warning(f"任务{task.uuid}:api获取的消息,未实现处理逻辑")
177 | case "bili_up":
178 | _LOGGER.info(f"任务{task.uuid}:评论消息,将结果放入评论处理队列")
179 | await self.reply_queue.put(reply_data)
180 | _LOGGER.debug("处理结束,开始清理并提交记录")
181 | self.task_status_recorder.update_record(
182 | reply_data.uuid,
183 | new_task_data=task,
184 | process_stage=ProcessStages.WAITING_PUSH_TO_CACHE,
185 | )
186 | if use_cache:
187 | await self._set_normal_end(task)
188 | return True
189 | self.cache.set_cache(
190 | key=reply_data.video_id,
191 | value=reply_data.process_result.model_dump(),
192 | chain=str(task.chain.value),
193 | )
194 | await self._set_normal_end(task)
195 | return True
196 |
197 | async def _is_cached_video(self, task: BiliGPTTask, _uuid: str, video_info: dict) -> bool:
198 | """检查是否是缓存的视频
199 | 如果是缓存的视频,直接从缓存中获取结果并发送
200 | """
201 | if self.cache.get_cache(key=video_info["bvid"], chain=str(task.chain.value)):
202 | LOGGER.debug(f"视频{video_info['title']}已经处理过,直接使用缓存")
203 | cache = self.cache.get_cache(key=video_info["bvid"], chain=str(task.chain.value))
204 | # if str(task.chain.value) == "summarize":
205 | # cache = SummarizeAiResponse.model_validate(cache)
206 | # elif str(task.chain.value) == "ask_ai":
207 | # cache = AskAIResponse.model_validate(cache)
208 | match str(task.chain.value):
209 | case "summarize":
210 | cache = SummarizeAiResponse.model_validate(cache)
211 | case "ask_ai":
212 | cache = AskAIResponse.model_validate(cache)
213 | case _:
214 | self._LOGGER.error(
215 | f"获取到了缓存,但无法匹配处理链{task.chain.value},无法调取缓存,开始按正常流程处理"
216 | )
217 | return False
218 | match task.source_type:
219 | case "bili_private":
220 | task.process_result = cache
221 | await self.finish(task, True)
222 | case "bili_comment":
223 | task.process_result = cache
224 | await self.finish(task, True)
225 | case "bili_up":
226 | task.process_result = cache
227 | await self.finish(task, True)
228 | return True
229 | return False
230 |
231 | async def _get_video_info(
232 | self, task: BiliGPTTask, if_get_comments: bool = True
233 | ) -> Optional[tuple[BiliVideo, dict, str, str, Optional[str]]]:
234 | """获取视频的一些信息
235 | :param task: 任务对象
236 | :param if_get_comments: 是否获取评论,为假就返回空
237 |
238 | :return 视频正常返回元组(video, video_info, format_video_name, video_tags_string, video_comments)
239 |
240 | video: BiliVideo对象
241 | video_info: bilibili官方api返回的视频信息
242 | format_video_name: 格式化后的视频名,用于日志
243 | video_tags_string: 视频标签
244 | video_comments: 随机获取的几条视频评论拼接的字符串
245 | """
246 | _LOGGER = self._LOGGER
247 | _LOGGER.info("开始处理该视频音频流和字幕")
248 | video = BiliVideo(self.credential, url=task.video_url)
249 | _LOGGER.debug("视频对象创建成功,正在获取视频信息")
250 | video_info = await video.get_video_info
251 | _LOGGER.debug("视频信息获取成功,正在获取视频标签")
252 | format_video_name = f"『{video_info['title']}』"
253 | # TODO 不清楚b站回复和at时分P的展现机制,暂时遇到分P视频就跳过
254 | if len(video_info["pages"]) > 1:
255 | _LOGGER.warning(f"任务{task.uuid}: 视频{format_video_name}分P,跳过处理")
256 | await self._set_err_end(msg="视频分P,跳过处理", task=task)
257 | return None
258 | # 获取视频标签
259 | video_tags_string = " ".join(f"#{tag['tag_name']}" for tag in await video.get_video_tags())
260 | _LOGGER.debug("视频标签获取成功,开始获取视频评论")
261 | # 获取视频评论
262 | video_comments = (
263 | await BiliComment.get_random_comment(video_info["aid"], self.credential) if if_get_comments else None
264 | )
265 | return video, video_info, format_video_name, video_tags_string, video_comments
266 |
267 | async def _get_subtitle_from_bilibili(self, video: BiliVideo) -> str:
268 | """从bilibili获取字幕(返回的是纯字幕,不包含时间轴)"""
269 | _LOGGER = self._LOGGER
270 | subtitle_url = await video.get_video_subtitle(page_index=0)
271 | _LOGGER.debug("视频字幕获取成功,正在读取字幕")
272 | # 下载字幕
273 | async with httpx.AsyncClient() as client:
274 | resp = await client.get("https:" + subtitle_url, headers=HEADERS)
275 | _LOGGER.debug("字幕获取成功,正在转换为纯字幕")
276 | # 转换字幕格式
277 | text = ""
278 | for subtitle in resp.json()["body"]:
279 | text += f"{subtitle['content']}\n"
280 | return text
281 |
282 | async def _get_subtitle_from_asr(self, video: BiliVideo, _uuid: str, is_retry: bool = False) -> Optional[str]:
283 | _LOGGER = self._LOGGER
284 | if self.asr is None:
285 | _LOGGER.warning("没有可用的asr,跳过处理")
286 | await self._set_err_end(_uuid=_uuid, msg="没有可用的asr,跳过处理")
287 | if is_retry:
288 | # 如果是重试,就默认已下载音频文件,直接开始转写
289 | bvid = await video.bvid
290 | audio_path = f"{self.temp_dir}/{bvid} temp.mp3"
291 | self.asr = self.asr_router.get_one() # 重新获取一个,防止因为错误而被禁用,但调用端没及时更新
292 | if self.asr is None:
293 | _LOGGER.warning("没有可用的asr,跳过处理")
294 | await self._set_err_end(_uuid, "没有可用的asr,跳过处理")
295 | text = await self.asr.transcribe(audio_path)
296 | if text is None:
297 | _LOGGER.warning("音频转写失败,报告并重试")
298 | self.asr_router.report_error(self.asr.alias)
299 | await self._get_subtitle_from_asr(video, _uuid, is_retry=True) # 递归,应该不会爆栈
300 | return text
301 | _LOGGER.debug("正在获取视频音频流")
302 | video_download_url = await video.get_video_download_url()
303 | audio_url = video_download_url["dash"]["audio"][0]["baseUrl"]
304 | _LOGGER.debug("视频下载链接获取成功,正在下载视频中的音频流")
305 | bvid = await video.bvid
306 | # 下载视频中的音频流
307 | async with httpx.AsyncClient() as client:
308 | resp = await client.get(audio_url, headers=HEADERS)
309 | temp_dir = self.temp_dir
310 | if not os.path.exists(temp_dir):
311 | os.mkdir(temp_dir)
312 | with open(f"{temp_dir}/{bvid} temp.m4s", "wb") as f:
313 | f.write(resp.content)
314 | _LOGGER.debug("视频中的音频流下载成功,正在转换音频格式")
315 | # 转换音频格式
316 | (ffmpeg.input(f"{temp_dir}/{bvid} temp.m4s").output(f"{temp_dir}/{bvid} temp.mp3").run(overwrite_output=True))
317 | _LOGGER.debug("音频格式转换成功,正在使用whisper转写音频")
318 | # 使用whisper转写音频
319 | audio_path = f"{temp_dir}/{bvid} temp.mp3"
320 | text = await self.asr.transcribe(audio_path)
321 | if text is None:
322 | _LOGGER.warning("音频转写失败,报告并重试")
323 | self.asr_router.report_error(self.asr.alias)
324 | await self._get_subtitle_from_asr(video, _uuid, is_retry=True) # 递归,应该不会爆栈
325 | _LOGGER.debug("音频转写成功,正在删除临时文件")
326 | # 删除临时文件
327 | os.remove(f"{temp_dir}/{bvid} temp.m4s")
328 | os.remove(f"{temp_dir}/{bvid} temp.mp3")
329 | _LOGGER.debug("临时文件删除成功")
330 | return text
331 |
332 | async def _smart_get_subtitle(
333 | self, video: BiliVideo, _uuid: str, format_video_name: str, task: BiliGPTTask
334 | ) -> Optional[str]:
335 | """根据用户配置智能获取字幕"""
336 | _LOGGER = self._LOGGER
337 | subtitle_url = await video.get_video_subtitle(page_index=0)
338 | if subtitle_url is None:
339 | if self.asr is None:
340 | _LOGGER.warning(f"视频{format_video_name}没有字幕,你没有可用的asr,跳过处理")
341 | await self._set_err_end(_uuid, "视频没有字幕,你没有可用的asr,跳过处理")
342 | return None
343 | _LOGGER.warning(f"视频{format_video_name}没有字幕,开始使用asr转写,这可能会导致字幕质量下降")
344 | text = await self._get_subtitle_from_asr(video, _uuid)
345 | task.subtitle = text
346 | self.task_status_recorder.update_record(_uuid, new_task_data=task, use_whisper=True)
347 | return text
348 | _LOGGER.debug(f"视频{format_video_name}有字幕,开始处理")
349 | text = await self._get_subtitle_from_bilibili(video)
350 | return text
351 |
352 | def _create_record(self, task: BiliGPTTask) -> str:
353 | """创建(或查询)一条任务记录,返回uuid"""
354 | task.gmt_start_process = int(time.time())
355 | _item_uuid = self.task_status_recorder.create_record(task)
356 | return _item_uuid
357 |
358 | @abc.abstractmethod
359 | async def main(self):
360 | """
361 | 处理链主函数
362 | 捕获错误的最佳实践是使用tenacity.retry装饰器,callback也已经写好了,就在utils.callback中
363 | 如果实现_on_start的话别忘了在循环代码前调用
364 |
365 | eg:
366 | @tenacity.retry(
367 | retry=tenacity.retry_if_exception_type(Exception),
368 | wait=tenacity.wait_fixed(10),
369 | before_sleep=chain_callback
370 | )
371 | :return:
372 | """
373 | pass
374 |
375 | @abc.abstractmethod
376 | async def _on_start(self):
377 | """
378 | 在这里完成一些初始化任务,比如将已保存的未处理任务恢复到队列中
379 | 记得在main中调用啊!
380 | """
381 | pass
382 |
383 | @abc.abstractmethod
384 | async def retry(self, *args, **kwargs):
385 | """
386 | 重试函数,你可以在这里写重试逻辑
387 | 这要求你必须在main函数中捕捉错误并进行调用该函数
388 | 因为llm的返回内容具有不稳定性,所以我强烈建议你实现这个函数。
389 | """
390 | pass
391 |
392 | def __repr__(self):
393 | return self.__class__.__name__
394 |
--------------------------------------------------------------------------------
/src/listener/bili_listen.py:
--------------------------------------------------------------------------------
1 | """监听bilibili平台的私信、at消息"""
2 |
3 | import asyncio
4 | import os
5 | import re
6 | import time
7 | import traceback
8 | from copy import deepcopy
9 | from datetime import datetime
10 |
11 | from apscheduler.schedulers.asyncio import AsyncIOScheduler
12 | from bilibili_api import session, user
13 | from injector import inject
14 |
15 | from src.bilibili.bili_credential import BiliCredential
16 | from src.bilibili.bili_video import BiliVideo
17 | from src.core.routers.chain_router import ChainRouter
18 | from src.models.config import Config
19 | from src.models.task import BiliAtSpecialAttributes, BiliGPTTask
20 | from src.utils.logging import LOGGER
21 | from src.utils.queue_manager import QueueManager
22 | from src.utils.up_video_cache import get_up_file, load_cache, set_cache
23 |
24 | _LOGGER = LOGGER.bind(name="bilibili-listener")
25 |
26 |
27 | class Listen:
28 | @inject
29 | def __init__(
30 | self,
31 | credential: BiliCredential,
32 | queue_manager: QueueManager,
33 | # value_manager: GlobalVariablesManager,
34 | config: Config,
35 | schedule: AsyncIOScheduler,
36 | chain_router: ChainRouter,
37 | ):
38 | self.sess = None
39 | self.credential = credential
40 | self.summarize_queue = queue_manager.get_queue("summarize")
41 | self.evaluate_queue = queue_manager.get_queue("evaluate")
42 | self.last_at_time = int(time.time()) # 当前时间作为初始时间戳
43 | self.sched = schedule
44 | self.user_sessions = {} # 存储用户状态和视频信息
45 | self.config = config
46 | self.chain_router = chain_router
47 | # 以下是自动查询视频更新的参数
48 | # self.cache_path = './data/video_cache.json'
49 | # self.up_file_path = './data/up.json'
50 | self.chain_router = chain_router
51 | self.uids = {}
52 | self.video_cache = {}
53 |
    async def listen_at(self):
        """Poll bilibili for at (@) messages and dispatch every new one as a task.

        ``self.last_at_time`` acts as a watermark: only messages strictly newer
        than it are processed, then the watermark advances to the newest message.
        """
        # global run_time
        data: dict = await session.get_at(self.credential)
        _LOGGER.debug(f"获取at消息成功,内容为:{data}")

        # if len(data["items"]) != 0:
        #     if run_time > 2:
        #         return
        #     _LOGGER.warning(f"目前处于debug状态,将直接处理第一条at消息")
        #     await self.dispatch_task(data["items"][0])
        #     run_time += 1
        #     return

        # Check whether there are any messages at all / anything newer than the watermark.
        if len(data["items"]) == 0:
            _LOGGER.debug("没有新消息,返回")
            return
        if self.last_at_time >= data["items"][0]["at_time"]:
            _LOGGER.debug(
                f"last_at_time{self.last_at_time}大于或等于当前最新消息的at_time{data['items'][0]['at_time']},返回"
            )
            return

        # Collect the new messages oldest-first so they are dispatched in order.
        new_items = []
        for item in reversed(data["items"]):
            if item["at_time"] > self.last_at_time:
                _LOGGER.debug(f"at_time{item['at_time']}大于last_at_time{self.last_at_time},放入新消息队列")
                # item["user"] = data["items"]["user"]
                new_items.append(item)
        if len(new_items) == 0:
            _LOGGER.debug("没有新消息,返回")
            return
        _LOGGER.info(f"检测到{len(new_items)}条新消息,开始处理")
        for item in new_items:
            task_metadata = await self.build_task_from_at_msg(item)
            if task_metadata is None:
                continue
            await self.chain_router.dispatch_a_task(task_metadata)

        # Advance the watermark to the newest message.
        self.last_at_time = data["items"][0]["at_time"]
94 |
95 | async def build_task_from_at_msg(self, msg: dict) -> BiliGPTTask | None:
96 | # print(msg)
97 | try:
98 | event = deepcopy(msg)
99 | if msg["item"]["type"] != "reply" or msg["item"]["business_id"] != 1:
100 | _LOGGER.warning("不是回复消息,跳过")
101 | return None
102 | elif msg["item"]["root_id"] != 0 or msg["item"]["target_id"] != 0:
103 | _LOGGER.warning("该消息是楼中楼消息,暂时不受支持,跳过处理")
104 | return None
105 | event["source_type"] = "bili_comment"
106 | event["raw_task_data"] = deepcopy(msg)
107 | event["source_extra_attr"] = BiliAtSpecialAttributes.model_validate(event["item"])
108 | event["sender_id"] = str(event["user"]["mid"])
109 | event["video_url"] = event["item"]["uri"]
110 | event["source_command"] = event["item"]["source_content"]
111 | # event["mission"] = False
112 | event["video_id"] = await BiliVideo(credential=self.credential, url=event["item"]["uri"]).bvid
113 | task_metadata = BiliGPTTask.model_validate(event)
114 | except Exception:
115 | traceback.print_exc()
116 | _LOGGER.error("在验证任务数据结构时出现错误,跳过处理!")
117 | return None
118 |
119 | return task_metadata
120 |
121 | def start_listen_at(self):
122 | self.sched.add_job(
123 | self.listen_at,
124 | trigger="interval",
125 | seconds=20, # 有新任务都会一次性提交,时间无所谓
126 | id="listen_at",
127 | max_instances=3,
128 | next_run_time=datetime.now(),
129 | )
130 | # self.sched.start()
131 | _LOGGER.info("[定时任务]侦听at消息定时任务注册成功, 每20秒检查一次")
132 |
133 | def start_video_mission(self):
134 | self.sched.add_job(
135 | self.async_video_list_mission,
136 | trigger="interval",
137 | minutes=60, # minutes=60,
138 | id="video_list_mission",
139 | max_instances=3,
140 | next_run_time=datetime.now(),
141 | )
142 | _LOGGER.info("[定时任务]侦听up视频更新任务注册成功, 每60分钟检查一次")
143 |
144 | async def async_video_list_mission(self):
145 | _LOGGER.info("开始执行获取UP的最新视频")
146 | self.video_cache = load_cache(self.config.storage_settings.up_video_cache)
147 | self.uids = get_up_file(self.config.storage_settings.up_file)
148 | for item in self.uids:
149 | u = user.User(uid=item["uid"])
150 | try:
151 | media_list = await u.get_media_list(ps=1, desc=True)
152 | except Exception:
153 | traceback.print_exc()
154 | _LOGGER.error(f"在获取 uid{item} 的视频列表时出错!")
155 | return None
156 | media = media_list["media_list"][0]
157 | bv_id = media["bv_id"]
158 | _LOGGER.info(f"当前视频的bvid:{bv_id}")
159 | oid = media["id"]
160 | _LOGGER.info(f"当前视频的oid:{oid}")
161 | if str(item["uid"]) in self.video_cache:
162 | cache_bvid = self.video_cache[str(item["uid"])]["bv_id"]
163 | _LOGGER.info(f"缓存文件中的bvid:{cache_bvid}")
164 | if cache_bvid != bv_id:
165 | _LOGGER.info(f"up有视频更新,视频信息为:\n 作者:{item['username']} 标题:{media['title']}")
166 | # 将视频信息传递给消息队列
167 | task_metadata = await self.build_task_from_at_mission(media)
168 | if task_metadata is None:
169 | continue
170 | await self.chain_router.dispatch_a_task(task_metadata)
171 | set_cache(
172 | self.config.storage_settings.up_video_cache,
173 | self.video_cache,
174 | {"bv_id": bv_id, "oid": oid},
175 | str(item["uid"]),
176 | )
177 |
178 | else:
179 | _LOGGER.info("up没有视频更新")
180 |
181 | else:
182 | _LOGGER.info("缓存文件为空,第一次写入数据")
183 | # self.set_cache({'bv_id': media, 'oid': oid}, str(item['uid']))
184 | _LOGGER.info("休息20秒")
185 | await asyncio.sleep(20)
186 |
187 | async def build_task_from_at_mission(self, msg: dict) -> BiliGPTTask | None:
188 | # print(msg)
189 | try:
190 | # event = deepcopy(msg)
191 | event: dict = {}
192 | event["source_type"] = "bili_up"
193 | event["raw_task_data"] = {
194 | "user": {
195 | "mid": self.config.bilibili_cookie.dedeuserid,
196 | "nickname": self.config.bilibili_self.nickname,
197 | }
198 | }
199 | event["source_extra_attr"] = BiliAtSpecialAttributes.model_validate(
200 | {
201 | "source_id": msg["id"],
202 | "target_id": 0,
203 | "root_id": 0,
204 | "native_uri": msg["link"],
205 | "at_details": [
206 | {
207 | "mid": self.config.bilibili_cookie.dedeuserid,
208 | "fans": 0,
209 | "nickname": self.config.bilibili_self.nickname,
210 | "avatar": "http://i1.hdslb.com/bfs/face/d21cf99c96dfdca5e38106c00eb338dd150b4b65.jpg",
211 | "mid_link": "",
212 | "follow": False,
213 | }
214 | ],
215 | }
216 | )
217 | event["sender_id"] = self.config.bilibili_cookie.dedeuserid
218 | event["video_url"] = msg["short_link"]
219 | event["source_command"] = f"@{self.config.bilibili_self.nickname} 总结一下"
220 | event["video_id"] = msg["bv_id"]
221 | # event["mission"] = True
222 | task_metadata = BiliGPTTask.model_validate(event)
223 | except Exception:
224 | traceback.print_exc()
225 | _LOGGER.error("在验证任务数据结构时出现错误,跳过处理!")
226 | return None
227 |
228 | return task_metadata
229 |
230 | async def build_task_from_private_msg(self, msg: dict) -> BiliGPTTask | None:
231 | try:
232 | event = deepcopy(msg)
233 | bvid = event["video_event"]["content"]
234 | uri = "https://bilibili.com/video/" + bvid
235 | event["source_type"] = "bili_private"
236 | event["raw_task_data"] = deepcopy(msg)
237 | event["raw_task_data"]["video_event"]["content"] = bvid
238 | event["sender_id"] = event["text_event"]["sender_uid"]
239 | event["video_url"] = uri
240 | event["source_command"] = event["text_event"]["content"][12:] # 去除掉bv号
241 | event["video_id"] = bvid
242 | # event["mission"] = False
243 | del event["video_event"]
244 | del event["text_event"]
245 | del event["status"]
246 | task_metadata = BiliGPTTask.model_validate(event)
247 | except Exception:
248 | traceback.print_exc()
249 | _LOGGER.error("在验证任务数据结构时出现错误,跳过处理!")
250 | return None
251 |
252 | return task_metadata
253 |
254 | async def handle_video(self, user_id, event):
255 | _session = self.user_sessions.get(user_id, {"status": "idle", "text_event": {}, "video_event": {}})
256 | match _session["status"]:
257 | case "idle" | "waiting_for_keyword":
258 | _session["status"] = "waiting_for_keyword"
259 | _session["video_event"] = event
260 | _session["video_event"]["content"] = _session["video_event"]["content"].get_bvid()
261 |
262 | case "waiting_for_video":
263 | _session["video_event"] = event
264 | _session["video_event"]["content"] = _session["video_event"]["content"].get_bvid()
265 | at_items = await self.build_task_from_private_msg(_session)
266 | if at_items is None:
267 | return
268 | await self.chain_router.dispatch_a_task(at_items)
269 | _session["status"] = "idle"
270 | _session["text_event"] = {}
271 | _session["video_event"] = {}
272 | case _:
273 | pass
274 | self.user_sessions[user_id] = _session
275 |
276 | async def handle_text(self, user_id, event):
277 | # _session = PrivateMsgSession(self.user_sessions.get(
278 | # user_id, {"status": "idle", "text_event": {}, "video_event": {}}
279 | # ))
280 | _session = (
281 | self.user_sessions[user_id]
282 | if self.user_sessions.get(user_id, None)
283 | else {"status": "idle", "video_event": {}, "text_event": {}}
284 | )
285 |
286 | match "BV" in event["content"]:
287 | case True:
288 | _LOGGER.debug("检测到消息中包含BV号,开始提取")
289 | # try:
290 | # p1, p2 = event["content"].split(" ") # 简单分离一下关键词与链接
291 | # except Exception as e:
292 | # _LOGGER.error(f"分离关键词与链接失败:{e},返回")
293 | # return
294 | #
295 | # if "BV" in p1:
296 | # bvid = p1
297 | # keyword = p2
298 | # else:
299 | # bvid = p2
300 | # keyword = p1
301 | bvid = event["content"][:12]
302 | if not re.search("^BV[a-zA-Z0-9]{10}$", bvid):
303 | _LOGGER.warning(f"从消息‘{event['content']}’中提取bv号失败!你是不是没把bv号放在消息最前面?!")
304 | return
305 | if _session["status"] in (
306 | "waiting_for_keyword",
307 | "idle",
308 | "waiting_for_video",
309 | ):
310 | _session["video_event"] = {}
311 | _session["video_event"]["content"] = bvid
312 | _session["text_event"] = deepcopy(event)
313 | task_metadata = await self.build_task_from_private_msg(_session)
314 | if task_metadata is None:
315 | return
316 | await self.chain_router.dispatch_a_task(task_metadata)
317 | _session["status"] = "idle"
318 | _session["text_event"] = {}
319 | _session["video_event"] = {}
320 | self.user_sessions[user_id] = _session
321 | return
322 |
323 | match _session["status"]:
324 | case "waiting_for_keyword":
325 | _session["text_event"] = event
326 | task_metadata = await self.build_task_from_private_msg(_session)
327 | if task_metadata is None:
328 | return
329 | # task_metadata = self.build_private_msg_to_at_items(_session["event"]) # type: ignore
330 | # task_metadata["item"]["source_content"] = text # 将文本消息填入at内容
331 | await self.chain_router.dispatch_a_task(task_metadata)
332 | _session["status"] = "idle"
333 | _session["text_event"] = {}
334 | _session["video_event"] = {}
335 |
336 | case "idle":
337 | _session["text_event"] = event
338 | _session["status"] = "waiting_for_video"
339 |
340 | case "waiting_for_video":
341 | _session["text_event"] = event
342 |
343 | case _:
344 | pass
345 | self.user_sessions[user_id] = _session
346 |
347 | async def on_receive(self, event: session.Event):
348 | """接收到视频分享消息时的回调函数"""
349 | _LOGGER.debug(f"接收到私聊消息,内容为:{event}")
350 | data = event.__dict__
351 | if data["msg_type"] == 7:
352 | await self.handle_video(data["sender_uid"], data)
353 | elif data["msg_type"] == 1:
354 | await self.handle_text(data["sender_uid"], data)
355 | else:
356 | _LOGGER.debug(f"未知的消息类型{data['msg_type']}")
357 |
358 | async def listen_private(self):
359 | # TODO 将轮询功能从bilibili_api库分离,重写
360 | self.sess = session.Session(self.credential)
361 | self.sess.logger = _LOGGER
362 | if os.getenv("DEBUG_MODE") == "true": # debug模式下不排除自己发的消息
363 | await self.sess.run(exclude_self=False)
364 | else:
365 | await self.sess.run(exclude_self=True)
366 | self.sess.add_event_listener(str(session.EventType.SHARE_VIDEO.value), self.on_receive) # type: ignore
367 | self.sess.add_event_listener(str(session.EventType.TEXT.value), self.on_receive) # type: ignore
368 |
    def close_private_listen(self):
        """Stop the private-message polling session started by listen_private."""
        self.sess.close()
        _LOGGER.info("私聊侦听已关闭")
372 |
373 | # def get_cache(self):
374 | # with open(self.cache_path, 'r', encoding="utf-8") as f:
375 | # cache = json.loads(f.read())
376 | # return cache
377 | #
378 | # def set_cache(self, data: dict, key: str):
379 | # if key not in self.video_cache:
380 | # self.video_cache[key] = {}
381 | # self.video_cache[key] = data
382 | # with open(self.cache_path, "w") as file:
383 | # file.write(json.dumps(self.video_cache, ensure_ascii=False, indent=4))
384 |
385 | # def get_up_file(self):
386 | # with open(self.up_file_path, 'r', encoding="utf-8") as f:
387 | # up_list = json.loads(f.read())
388 | # return up_list['all_area']
389 |
--------------------------------------------------------------------------------