├── VERSION ├── tests └── __init__.py ├── src ├── asr │ ├── __init__.py │ ├── asr_base.py │ ├── local_whisper.py │ └── openai_whisper.py ├── chain │ ├── __init__.py │ ├── ask_ai.py │ ├── summarize.py │ └── base_chain.py ├── core │ ├── __init__.py │ ├── routers │ │ ├── __init__.py │ │ ├── chain_router.py │ │ ├── llm_router.py │ │ └── asr_router.py │ └── app.py ├── llm │ ├── __init__.py │ ├── gpt.py │ ├── claude.py │ ├── llm_base.py │ ├── templates.py │ └── spark.py ├── models │ ├── __init__.py │ ├── config.py │ └── task.py ├── bilibili │ ├── __init__.py │ ├── bili_credential.py │ ├── bili_session.py │ ├── bili_video.py │ └── bili_comment.py ├── listener │ ├── __init__.py │ └── bili_listen.py └── utils │ ├── logging.py │ ├── exceptions.py │ ├── callback.py │ ├── prompt_utils.py │ ├── up_video_cache.py │ ├── merge_config.py │ ├── file_tools.py │ ├── cache.py │ ├── queue_manager.py │ ├── task_status_record.py │ └── statistic.py ├── .dockerignore ├── .gitignore ├── .github └── workflows │ ├── ruff.yaml │ └── push_to_docker.yaml ├── requirements.txt ├── requirements-docker.txt ├── ruff.toml ├── LICENSE ├── ci └── Dockerfile ├── safe_update.py ├── DEV_README.md ├── config ├── example_config.yml └── docker_config.yml ├── README.md └── main.py /VERSION: -------------------------------------------------------------------------------- 1 | 3.0.4 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/asr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/chain/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/llm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/bilibili/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/routers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/listener/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | .idea 3 | .git -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | import loguru 2 | 3 | LOGGER = loguru.logger 4 | 
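# 用法示意(仅为说明,非本文件实际内容):项目各模块统一通过
#   _LOGGER = LOGGER.bind(name="模块名")
# 取得带名字的logger,日志即可按name区分来源模块,例如 cache、chain-router 等。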
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | tests/.pytest_cache 3 | data/ 4 | venv/ 5 | __pycache__/ 6 | config.yml 7 | openaidemo.py -------------------------------------------------------------------------------- /src/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | class ConfigError(Exception): 2 | pass 3 | 4 | 5 | class BilibiliBaseException(Exception): 6 | pass 7 | 8 | 9 | class RiskControlFindError(Exception): 10 | pass 11 | 12 | 13 | class LoadJsonError(Exception): 14 | pass 15 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yaml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [push, pull_request] 3 | jobs: 4 | ruff: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v3 8 | - uses: chartboost/ruff-action@v1 9 | with: 10 | args: --config ./ruff.toml -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bilibili-api-python==16.2.0 2 | loguru 3 | openai==0.28 4 | openai-whisper 5 | APScheduler 6 | pytest 7 | httpx 8 | tenacity 9 | PyYAML 10 | ffmpeg-python 11 | # matplotlib 12 | injector 13 | pydantic 14 | anthropic 15 | ruamel.yaml 16 | pydub 17 | websockets 18 | -------------------------------------------------------------------------------- /requirements-docker.txt: -------------------------------------------------------------------------------- 1 | # 用于构建docker版本,不包含whisper,whisper在构建时会自行加上 2 | bilibili-api-python==16.2.0 3 | loguru 4 | openai==0.28 5 | APScheduler 6 | pytest 7 | httpx 8 | tenacity 9 | PyYAML 10 | ffmpeg-python 11 | # matplotlib 12 | injector 13 | pydantic 14 | anthropic 15 | ruamel.yaml 16 | pydub 17 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 120 # 代码最大行宽 2 | [lint] 3 | select = [ 4 | # pycodestyle 5 | "E", 6 | # Pyflakes 7 | "F", 8 | # pyupgrade 9 | "UP", 10 | # flake8-bugbear 11 | "B", 12 | # flake8-simplify 13 | "SIM", 14 | # isort 15 | "I", 16 | ] 17 | ignore = ["E501"] # 忽略的规则 18 | 19 | [format] 20 | docstring-code-format = true 21 | docstring-code-line-length = 20 22 | quote-style = "double" -------------------------------------------------------------------------------- /src/utils/callback.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | import loguru 4 | 5 | _LOGGER = loguru.logger 6 | 7 | 8 | def chain_callback(retry_state): 9 | """处理链重试回调函数""" 10 | exception = retry_state.outcome.exception() 11 | _LOGGER.error(f"捕获到错误:{exception}") 12 | traceback.print_tb(retry_state.outcome.exception().__traceback__) 13 | _LOGGER.debug(f"当前重试次数为{retry_state.attempt_number}") 14 | _LOGGER.debug(f"下一次重试将在{retry_state.next_action.sleep}秒后进行") 15 | 16 | 17 | def scheduler_error_callback(event): 18 | match event.exception.__class__.__name__: 19 | case "ClientOSError": 20 | _LOGGER.warning(f"捕获到异常:{event.exception} 可能是新版bilibili_api库的问题,接收消息没问题就不用管") 21 | case _: 22 | return 23 | -------------------------------------------------------------------------------- /src/utils/prompt_utils.py: 
--------------------------------------------------------------------------------
1 | def parse_prompt(prompt_template, **kwargs):
2 |     """解析填充prompt"""
3 |     for key, value in kwargs.items():
4 |         prompt_template = prompt_template.replace(f"[{key}]", str(value))
5 |     return prompt_template
6 | 
7 | 
8 | def build_openai_style_messages(user_msg, system_msg=None, user_keyword="user", system_keyword="system"):
9 |     """构建消息
10 |     :param user_msg: 用户消息
11 |     :param system_msg: 系统消息
12 |     :param user_keyword: 用户关键词(这个和下面的system_keyword要根据每个llm不同的要求来填)
13 |     :param system_keyword: 系统关键词
14 |     :return: 消息列表
15 |     """
16 |     messages = []
17 |     if system_msg:
18 |         messages.append({"role": system_keyword, "content": system_msg})
19 |     messages.append({"role": user_keyword, "content": user_msg})
20 |     return messages
21 | 
--------------------------------------------------------------------------------
/src/utils/up_video_cache.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from src.utils.exceptions import LoadJsonError
4 | from src.utils.file_tools import read_file, save_file
5 | from src.utils.logging import LOGGER
6 | 
7 | _LOGGER = LOGGER.bind(name="video_cache")
8 | 
9 | 
10 | def load_cache(file_path):
11 |     try:
12 |         content = read_file(file_path)
13 |         if content:
14 |             cache = json.loads(content)
15 |             return cache
16 |         else:
17 |             save_file(json.dumps({}, ensure_ascii=False, indent=4), file_path)
18 |             return {}  # 文件为空时初始化空缓存文件并返回空dict,避免隐式返回None
19 |     except Exception as e:
20 |         raise LoadJsonError("在读取缓存文件时出现问题!程序已停止运行,请自行检查问题所在") from e
21 | 
22 | 
23 | def set_cache(file_path, cache, data: dict, key: str):
24 |     if key not in cache:
25 |         cache[key] = {}
26 |     cache[key] = data
27 |     save_file(json.dumps(cache, ensure_ascii=False, indent=4), file_path)
28 | 
29 | 
30 | def get_up_file(file_path):
31 |     with open(file_path, encoding="utf-8") as f:
32 |         up_list = json.loads(f.read())
33 |     return up_list["all_area"]
34 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 yanyao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/utils/merge_config.py: -------------------------------------------------------------------------------- 1 | """在config模板发生更新时 及时无损合并用户的config""" 2 | 3 | import ruamel.yaml 4 | 5 | # TODO 保持原有格式和注释 6 | 7 | yaml = ruamel.yaml.YAML() 8 | 9 | 10 | def load_config(config_path): 11 | """加载config""" 12 | with open(config_path, encoding="utf-8") as f: 13 | config = yaml.load(f) 14 | return config 15 | 16 | 17 | def is_have_diff(config, template): 18 | """判断用户的config是否有缺失""" 19 | for key in template: 20 | if key not in config: 21 | return True 22 | if isinstance(template[key], dict) and is_have_diff(config[key], template[key]): 23 | return True 24 | return False 25 | 26 | 27 | def merge_config(config, template): 28 | """合并config 如果缺失,则加上这个键和对应默认值""" 29 | for key in template: 30 | if key not in config: 31 | config[key] = template[key] 32 | if isinstance(template[key], dict): 33 | merge_config(config[key], template[key]) 34 | return config 35 | 36 | 37 | def save_config(config, config_path): 38 | """保存config""" 39 | with open(config_path, "w", encoding="utf-8") as f: 40 | yaml.dump(config, f) 41 | -------------------------------------------------------------------------------- /ci/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-buster as base 2 | 3 | WORKDIR /usr/src/app 4 | 5 | COPY ../requirements-docker.txt ./ 6 | 7 | RUN mkdir -p /clone-data/temp /clone-data/whisper-models /clone-data/statistics \ 8 | && touch /clone-data/cache.json && touch /clone-data/records.json && touch /clone-data/queue.json 9 | 10 | RUN apt-get update \ 11 | # && apt-get install -y ffmpeg fonts-wqy-zenhei \ 12 | && apt-get install -y ffmpeg \ 13 | && apt-get clean \ 14 | && pip install --no-cache-dir -r requirements-docker.txt 15 | 16 | COPY .. 
./ 17 | COPY ../config/docker_config.yml /clone-data/config.yml 18 | 19 | ENV DOCKER_CONFIG_FILE=/data/config.yml 20 | ENV DOCKER_CACHE_FILE=/data/cache.json 21 | ENV DOCKER_TEMP_DIR=/data/temp 22 | ENV DOCKER_WHISPER_MODELS_DIR=/data/whisper-models 23 | ENV DOCKER_QUEUE_DIR=/data/queue.json 24 | ENV DOCKER_RECORDS_DIR=/data/records.json 25 | ENV DOCKER_STATISTICS_DIR=/data/statistics 26 | ENV DOCKER_UP_FILE=/data/up.json 27 | ENV DOCKER_UP_VIDEO_CACHE=/data/video_cache.json 28 | ENV RUNNING_IN_DOCKER yes 29 | 30 | FROM base as with_whisper 31 | ENV ENABLE_WHISPER yes 32 | RUN pip install --no-cache-dir openai-whisper 33 | 34 | CMD ["python", "main.py"] 35 | 36 | FROM base as without_whisper 37 | ENV ENABLE_WHISPER no 38 | CMD ["python", "main.py"] 39 | -------------------------------------------------------------------------------- /safe_update.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import shutil 4 | import traceback 5 | 6 | from src.utils.file_tools import read_file, save_file 7 | from src.utils.logging import LOGGER 8 | 9 | _LOGGER = LOGGER.bind(name="safe-update") 10 | 11 | 12 | def merge_cache_to_new_version(cache_file_path: str) -> bool: 13 | """迁移老缓存文件到新版本格式""" 14 | content = read_file(cache_file_path) 15 | try: 16 | content_dict: dict = json.loads(content) 17 | if content_dict: 18 | if content_dict.get("summarize") is None and content_dict.get("ask_ai") is None: 19 | # 判断cache文件的内部结构,是否存在summarize或ask_ai键,全都不存在就是老版缓存,要转换 20 | _LOGGER.warning("缓存似乎是旧版的,尝试转换为新格式") 21 | _LOGGER.debug(f"备份老版缓存到{cache_file_path}.bak") 22 | shutil.copy(cache_file_path, cache_file_path + ".bak") 23 | new_summarize_dict = copy.deepcopy(content_dict) 24 | new_content_dict = { 25 | "summarize": new_summarize_dict, 26 | "ask_ai": {}, 27 | } # 老版本缓存只存在于只有summarize处理链的时代,直接这样转换 28 | save_file( 29 | json.dumps(new_content_dict, ensure_ascii=False, indent=4), 30 | cache_file_path, 31 | ) 32 | _LOGGER.info("转换缓存完成!") 33 | return True 34 | return True 35 | return True 36 | except Exception: 37 | traceback.print_exc() 38 | _LOGGER.error("在尝试转换缓存文件时出现问题!请自行查看!") 39 | return False 40 | -------------------------------------------------------------------------------- /DEV_README.md: -------------------------------------------------------------------------------- 1 | # 简单的开发文档和各模块设计思路 2 | 3 | well, 我的代码大家有目共睹,有、烂,所以这个大家看看乐呵乐呵就好 4 | 5 | ## 设计思路 6 | 7 | ### ASR和LLM的路由 8 | 9 | (我姑且在下文把各种拓展的llm和asr称为“插件”) 10 | 11 | 1. 让用户在配置文件中填写各个的优先级,和启用状态 12 | 2. 调度器在初始化时,根据优先级和启用状态,生成一个路由字典,字典中的每个插件形如:`{"asr_alias": { 13 | "priority": priority, 14 | "enabled": enabled, 15 | "prepared": False, 16 | "err_times": 0, 17 | "obj": asr_object 18 | }, ...}` 19 | 3. 当其他模块需要使用时,调用路由器的`get_one()`方法,返回一个插件实例(如果该插件没被初始化,即`prepared=False` 20 | ,就先进行初始化再传回实例) 21 | 4. 当某个插件出错时,调用路由器的`report_error()`方法,将该插件的`err_times`加一,如果`err_times` 22 | 超过了阈值(现在为10),就将该插件的`enabled`置为False 23 | 5. 当所有插件都不可用时,调用路由器的`get_one()`方法,会返回None,此时其他模块会通过设置event直接关闭整个程序 24 | 25 | ### 采用依赖注入 26 | 27 | 在之前的代码中,出现了很多各种实例来回传递的情况,虽然都是在`main.py`中进行的,不至于太难维护。但确实有够难看... 28 | 29 | 后来在看其他项目时,了解到python竟然也有inject这种依赖注入包,于是就在代码重构时将项目全部变成依赖注入。具体的代码可以去`core/app.py` 30 | 中看,这里就不赘述了。 31 | 32 | ### 处理链和视频状态记录的实现 33 | 34 | 处理链说实话不太好抽象,不同操作实现的差异性太多了。但是也是有共性的,比如既然面对的对象是b站视频,所以一定会对字幕、视频基本信息进行操作,所以封装了一下这些操作,让处理链的实现者只需要关注自己的操作就好了。 35 | 36 | 现在的视频状态记录还是有很大的问题,因为原来的视频记录就是为了完全适配“视频总结”这个功能设计的,没考虑到后期可能会支持其他操作,所以很多的字段都无法共用,后期等待重构吧...... 
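顺手放一段极简的示意代码,对应上面「ASR和LLM的路由」那一节的1~5步(仅为说明思路,方法名和细节以 `llm_router.py` / `asr_router.py` 的实际实现为准):

```python
# 路由思路示意(非项目实际代码):按优先级挑一个可用插件,出错计数超阈值就禁用
ERR_THRESHOLD = 10  # 实际代码里写死在路由器的 max_err_times

plugins = {
    "openai_whisper": {"priority": 100, "enabled": True, "prepared": False, "err_times": 0, "obj": None},
    "local_whisper": {"priority": 50, "enabled": True, "prepared": False, "err_times": 0, "obj": None},
}


def get_one():
    """按优先级返回第一个可用插件,全部不可用时返回None(调用方据此关停程序)"""
    ordered = sorted(plugins.values(), key=lambda p: p["priority"], reverse=True)
    for p in ordered:
        if p["enabled"] and p["err_times"] <= ERR_THRESHOLD:
            if not p["prepared"]:  # 第一次被用到时才初始化(懒加载)
                p["obj"] = object()  # 实际实现里这里调用插件自己的 prepare()
                p["prepared"] = True
            return p["obj"]
    return None


def report_error(name):
    """某插件出错时计数加一,达到阈值就禁用它"""
    p = plugins[name]
    p["err_times"] += 1
    if p["err_times"] >= ERR_THRESHOLD:
        p["enabled"] = False
```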
37 | 
38 | 一直说要实现一些新的处理链,太忙了一直没弄,有空了试一下,看看我这个设计是不是一坨答辩
39 | 
40 | ## 开发文档
41 | 
42 | ### ASR及LLM插件开发文档
43 | 
44 | 应该实现什么我在`llm_base.py`和`asr_base.py`中都写了,自己去看吧
45 | 
46 | 路由器在处理时默认将你这个插件的类名转为下划线命名后作为alias(eg.
47 | 你类名为LocalWhisper,那生成的alias就为local_whisper),也会依据这个alias去config中寻找配置文件,所以你的配置文件和类名一定要一致。
48 | 
49 | 别忘了在`src/models/config.py`中实现你插件的配置模型,并模仿其他插件在主Config中引用
50 | 
51 | ### 处理链开发文档
52 | 
53 | 好吧,我实话实说,我还没试过抽象后的处理链到底好不好添加功能。
54 | 
55 | 不过我能确定的一点是现在处理链还没实现热插拔,如果你要实现新功能,需要修改`src/listener/bili_listen.py`和`main.py`,仿照着摘要处理链进行修改。
56 | 
57 | `base_chain.py`这个基类起码注释是挺完善了,希望你顺利~
--------------------------------------------------------------------------------
/src/utils/file_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import traceback
3 | 
4 | from src.utils.logging import LOGGER
5 | 
6 | _LOGGER = LOGGER
7 | 
8 | 
9 | def read_file(file_path: str, mode: str = "r", encoding: str = "utf-8"):
10 |     """
11 |     读取一个文件,如果不存在就创建这个文件及其所有中间路径
12 |     :param encoding: utf-8
13 |     :param mode: 读取的模式(默认为只读)
14 |     :param file_path:
15 |     :return: 文件内容
16 |     """
17 |     dir_path = os.path.dirname(file_path)
18 |     try:
19 |         if not os.path.exists(dir_path):
20 |             os.makedirs(dir_path, exist_ok=True)
21 |         with open(file_path, mode, encoding=encoding) as file:
22 |             return file.read()
23 |     except FileNotFoundError:
24 |         with open(file_path, "w", encoding=encoding):
25 |             pass
26 |     except Exception:
27 |         _LOGGER.error("在读取文件时发生意料外的问题,返回空值")
28 |         traceback.print_exc()
29 |         return ""
30 |     with open(file_path, mode, encoding=encoding) as f_r:
31 |         return f_r.read()
32 | 
33 | 
34 | def save_file(content: str, file_path: str, mode: str = "w", encoding: str = "utf-8") -> bool:
35 |     """
36 |     保存一个文件,如果不存在就创建这个文件及其所有中间路径
37 |     :param content:
38 |     :param file_path:
39 |     :param mode: 默认为w
40 |     :param encoding: utf8
41 |     :return: bool
42 |     """
43 |     dir_path = os.path.dirname(file_path)
44 |     try:
45 |         if not os.path.exists(dir_path):
46 |             os.makedirs(dir_path, exist_ok=True)
47 |         with open(file_path, mode, encoding=encoding) as file:
48 |             file.write(content)
49 |         return True
50 |     except Exception:
51 |         _LOGGER.error("在写入文件时发生意料外的问题,返回False")
52 |         traceback.print_exc()
53 |         return False
54 | 
--------------------------------------------------------------------------------
/src/asr/asr_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import re
3 | from typing import Optional
4 | 
5 | from src.core.routers.llm_router import LLMRouter
6 | from src.models.config import Config
7 | 
8 | 
9 | class ASRBase:
10 |     """ASR基类,所有ASR子类都应该继承这个类"""
11 | 
12 |     def __init__(self, config: Config, llm_router: LLMRouter):
13 |         self.config = config
14 |         self.llm_router = llm_router
15 | 
16 |     def __new__(cls, *args, **kwargs):
17 |         """将类名转换为alias"""
18 |         instance = super().__new__(cls)
19 |         name = cls.__name__
20 |         name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
21 |         name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
22 |         instance.alias = name
23 |         return instance
24 | 
25 |     @abc.abstractmethod
26 |     def prepare(self) -> None:
27 |         """
28 |         准备方法,例如加载模型
29 |         会在该类被 **第一次** 使用时调用
30 |         这个函数不应该有入参,所有参数都从self.config中获取
31 |         """
32 |         pass
33 | 
34 |     @abc.abstractmethod
35 |     async def transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
36 |         """
37 |         转写方法
38 |         该方法最好只传入音频路径,返回转写结果,对于其他配置参数需要从self.config中获取
39 |         注意,这个方法中的转写部分不能阻塞,否则你要实现下方的_sync_transcribe方法,并在这里采用 **线程池** 方式调用
40 |         None建议当且仅当在转写失败时返回,因为当接收方收到None时会报告错误,当错误计数达到一定值时会停止使用该ASR
41 |         """
42 |         pass
43 | 
44 |
    def _sync_transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
45 |         """
46 |         阻塞转写方法,选择性实现
47 |         """
48 |         pass
49 | 
50 |     async def after_process(self, text: str, **kwargs) -> str:
51 |         """
52 |         后处理方法,例如将转写结果塞回llm,获得更高质量的字幕
53 |         直接通过self.llm_router调用llm
54 |         记得代码逻辑一定要是异步
55 | 
56 |         如果处理过程中出错应该返回原字幕
57 |         """
58 |         pass
59 | 
60 |     def __repr__(self):
61 |         return f"<{self.alias} ASR>"
62 | 
63 |     def __str__(self):
64 |         return self.__class__.__name__
65 | 
--------------------------------------------------------------------------------
/.github/workflows/push_to_docker.yaml:
--------------------------------------------------------------------------------
1 | name: Release Docker Image
2 | run-name: Release Docker Image
3 | 
4 | on:
5 |   workflow_dispatch:
6 |     inputs:
7 |       version:
8 |         description: '版本'
9 |         required: false
10 |   push:
11 |     paths:
12 |       - 'VERSION'
13 | 
14 | jobs:
15 |   docker:
16 |     runs-on: ubuntu-latest
17 |     steps:
18 |       - name: Checkout code
19 |         uses: actions/checkout@v3
20 | 
21 |       - name: Read version from VERSION file
22 |         id: file_version
23 |         run: echo "VERSION=$(cat VERSION)" >> $GITHUB_ENV
24 | 
25 |       - name: Set version
26 |         id: set_version
27 |         run: |
28 |           if [ "${{ github.event.inputs.version }}" == "" ]; then
29 |             echo "version=$VERSION" >> $GITHUB_OUTPUT
30 |           else
31 |             echo "version=${{ github.event.inputs.version }}" >> $GITHUB_OUTPUT
32 |           fi
33 |       - name: Set up QEMU
34 |         uses: docker/setup-qemu-action@v3
35 | 
36 |       - name: Set up Docker Buildx
37 |         uses: docker/setup-buildx-action@v3
38 | 
39 |       - name: Login to Docker Hub
40 |         uses: docker/login-action@v3
41 |         with:
42 |           username: ${{ secrets.DOCKERHUB_USERNAME }}
43 |           password: ${{ secrets.DOCKERHUB_TOKEN }}
44 | 
45 |       - name: Build and push with whisper
46 |         uses: docker/build-push-action@v5
47 |         with:
48 |           push: true
49 |           tags: |
50 |             ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:with_whisper
51 |             ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:${{ steps.set_version.outputs.version }}_whisper
52 |           target: with_whisper
53 |           file: ./ci/Dockerfile
54 | 
55 |       - name: Build and push without whisper
56 |         uses: docker/build-push-action@v5
57 |         with:
58 |           push: true
59 |           tags: |
60 |             ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:latest
61 |             ${{ secrets.DOCKERHUB_USERNAME }}/bilibili_gpt_helper:${{ steps.set_version.outputs.version }}
62 |           target: without_whisper
63 |           file: ./ci/Dockerfile
--------------------------------------------------------------------------------
/src/llm/gpt.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import traceback
3 | from functools import partial
4 | from typing import Tuple
5 | 
6 | import openai
7 | 
8 | from src.llm.llm_base import LLMBase
9 | from src.utils.logging import LOGGER
10 | 
11 | _LOGGER = LOGGER.bind(name="openai_gpt")
12 | 
13 | 
14 | class Openai(LLMBase):
15 |     def prepare(self):
16 |         self.openai = openai
17 |         self.openai.api_base = self.config.LLMs.openai.api_base
18 |         self.openai.api_key = self.config.LLMs.openai.api_key
19 | 
20 |     def _sync_completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
21 |         """调用openai的Completion API
22 |         :param prompt: 输入的文本(请确保格式化为openai的prompt格式)
23 |         :param kwargs: 其他参数
24 |         :return: 返回生成的文本和token总数 或 None
25 |         """
26 |         try:
27 |             model = self.config.LLMs.openai.model
28 |             resp = self.openai.ChatCompletion.create(model=model, messages=prompt, **kwargs)
29 |             _LOGGER.debug(f"调用openai的Completion API成功,API返回结果为:{resp}")
30 |             _LOGGER.info(
31 |                 f"调用openai的Completion
API成功,本次调用中,prompt+response的长度为{resp['usage']['total_tokens']}"
32 |             )
33 |             resp_msg = resp["choices"][0]["message"]["content"]
34 |             if resp_msg.startswith("```json"):
35 |                 resp_msg = resp_msg[7:]
36 |             if resp_msg.endswith("```"):
37 |                 resp_msg = resp_msg[:-3]
38 |             return (
39 |                 resp_msg,
40 |                 resp["usage"]["total_tokens"],
41 |             )
42 |         except Exception as e:
43 |             _LOGGER.error(f"调用openai的Completion API失败:{e}")
44 |             traceback.print_tb(e.__traceback__)
45 |             return None
46 | 
47 |     async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
48 |         """调用openai的Completion API
49 |         :param prompt: 输入的文本(请确保格式化为openai的prompt格式)
50 |         :param kwargs: 其他参数
51 |         :return: 返回生成的文本和token总数 或 None
52 |         """
53 |         loop = asyncio.get_event_loop()
54 |         bound_func = partial(self._sync_completion, prompt, **kwargs)
55 |         res = await loop.run_in_executor(None, bound_func)
56 |         return res
57 | 
--------------------------------------------------------------------------------
/src/bilibili/bili_credential.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | 
3 | from apscheduler.schedulers.asyncio import AsyncIOScheduler
4 | from bilibili_api import Credential
5 | from injector import inject
6 | 
7 | from src.utils.logging import LOGGER
8 | 
9 | _LOGGER = LOGGER.bind(name="bilibili-credential")
10 | 
11 | 
12 | class BiliCredential(Credential):
13 |     """B站凭证类,主要增加定时检查cookie是否过期"""
14 | 
15 |     # noinspection PyPep8Naming,SpellCheckingInspection
16 |     @inject
17 |     def __init__(
18 |         self,
19 |         SESSDATA: str,
20 |         bili_jct: str,
21 |         buvid3: str,
22 |         dedeuserid: str,
23 |         ac_time_value: str,
24 |         sched: AsyncIOScheduler,
25 |     ):
26 |         """
27 |         全部强制要求传入,以便于cookie刷新。
28 | 
29 |         :param SESSDATA: SESSDATA cookie值
30 |         :param bili_jct: bili_jct cookie值
31 |         :param buvid3: buvid3 cookie值
32 |         :param dedeuserid: dedeuserid cookie值
33 |         :param ac_time_value: ac_time_value cookie值
34 |         """
35 |         super().__init__(
36 |             sessdata=SESSDATA,
37 |             bili_jct=bili_jct,
38 |             buvid3=buvid3,
39 |             dedeuserid=dedeuserid,
40 |             ac_time_value=ac_time_value,
41 |         )
42 |         self.sched = sched
43 | 
44 |     async def _check_refresh(self):
45 |         """
46 |         检查cookie是否过期
47 |         """
48 |         _LOGGER.debug("正在检查cookie是否过期")
49 |         if await self.check_refresh():
50 |             _LOGGER.info("cookie过期,正在刷新")
51 |             await self.refresh()
52 |             _LOGGER.info("cookie刷新成功")
53 |         else:
54 |             _LOGGER.debug("cookie未过期")
55 | 
56 |         if await self.check_valid():
57 |             _LOGGER.debug("cookie有效")
58 |         else:
59 |             _LOGGER.warning("cookie刷新后依旧无效,请关注!")
60 | 
61 |     def start_check(self):
62 |         """
63 |         开始检查cookie是否过期的定时任务
64 |         """
65 |         self.sched.add_job(
66 |             self._check_refresh,
67 |             trigger="interval",
68 |             hours=12,
69 |             id="check_refresh",
70 |             max_instances=3,
71 |             next_run_time=datetime.now(),
72 |         )
73 |         _LOGGER.info("[定时任务]检查cookie是否过期定时任务注册成功,每12小时检查一次")
74 | 
--------------------------------------------------------------------------------
/src/core/routers/chain_router.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | from injector import inject
4 | 
5 | from src.models.config import Config
6 | from src.models.task import AskAICommandParams, BiliGPTTask, Chains
7 | from src.utils.logging import LOGGER
8 | from src.utils.queue_manager import QueueManager
9 | 
10 | _LOGGER = LOGGER.bind(name="chain-router")
11 | 
12 | 
13 | class ChainRouter:
14 |     @inject
15 |     def __init__(self, config: Config, queue_manager: QueueManager):
16 |         self.config = config
17 |         self.queue_manager = queue_manager  # 各处理链的asyncio队列由QueueManager按名字懒创建,见下方_get_queues
18 | self.summarize_queue = None 19 | self.ask_ai_queue = None 20 | self._get_queues() 21 | 22 | def _get_queues(self): 23 | self.summarize_queue = self.queue_manager.get_queue("summarize") 24 | self.ask_ai_queue = self.queue_manager.get_queue("ask_ai") 25 | 26 | async def dispatch_a_task(self, task: BiliGPTTask): 27 | content: str = task.source_command 28 | _LOGGER.info(f"开始处理消息,原始消息内容为:{content}") 29 | summarize_keyword = self.config.chain_keywords.summarize_keywords 30 | ask_ai_keyword = self.config.chain_keywords.ask_ai_keywords 31 | 32 | match content: 33 | case content if any(keyword in content for keyword in summarize_keyword): 34 | keyword = next(keyword for keyword in summarize_keyword if keyword in content) 35 | _LOGGER.info(f"检测到关键字 {keyword} ,放入【总结】队列") 36 | task.chain = Chains.SUMMARIZE 37 | _LOGGER.debug(task) 38 | await self.summarize_queue.put(task) 39 | return 40 | case content if any(keyword in content for keyword in ask_ai_keyword): 41 | keyword = next(keyword for keyword in ask_ai_keyword if keyword in content) 42 | _LOGGER.info(f"任务{task.uuid}:检测到关键字 {keyword} ,开始解析参数:冒号后的问题") 43 | match = re.search(r"[::](.*)", content) 44 | if match: 45 | ask_ai_params = AskAICommandParams.model_validate({"question": match.group(1)}) 46 | task.command_params = ask_ai_params 47 | else: 48 | _LOGGER.warning(f"任务{task.uuid}:没找到冒号,无法提取问题参数,跳过") 49 | return 50 | task.chain = Chains.ASK_AI 51 | _LOGGER.debug(task) 52 | await self.ask_ai_queue.put(task) 53 | return 54 | case _: 55 | _LOGGER.debug("没有检测到关键字,跳过") 56 | -------------------------------------------------------------------------------- /src/utils/cache.py: -------------------------------------------------------------------------------- 1 | """管理视频处理后缓存""" 2 | 3 | import json 4 | 5 | from src.utils.exceptions import LoadJsonError 6 | from src.utils.file_tools import read_file, save_file 7 | from src.utils.logging import LOGGER 8 | 9 | _LOGGER = LOGGER.bind(name="cache") 10 | 11 | 12 | class Cache: 13 | def __init__(self, cache_path: str): 14 | self.cache_path = cache_path 15 | self.cache = {} 16 | self.load_cache() 17 | 18 | def load_cache(self): 19 | """加载缓存""" 20 | # try: 21 | # if not os.path.exists(self.cache_path): 22 | # os.makedirs(os.path.dirname(self.cache_path), exist_ok=True) 23 | # with open(self.cache_path, "w", encoding="utf-8") as f: 24 | # json.dump({}, f, ensure_ascii=False, indent=4) 25 | # with open(self.cache_path, encoding="utf-8") as f: 26 | # self.cache = json.load(f) 27 | # except Exception as e: 28 | # _LOGGER.error(f"加载缓存失败:{e},尝试删除缓存文件并重试") 29 | # traceback.print_exc() 30 | # self.cache = {} 31 | # if os.path.exists(self.cache_path): 32 | # os.remove(self.cache_path) 33 | # _LOGGER.info("已删除缓存文件") 34 | # self.save_cache() 35 | # _LOGGER.info("已重新创建缓存文件") 36 | # self.load_cache() 37 | try: 38 | content = read_file(self.cache_path) 39 | if content: 40 | self.cache = json.loads(content) 41 | else: 42 | self.cache = {} 43 | self.save_cache() 44 | except Exception as e: 45 | raise LoadJsonError("在读取缓存文件时出现问题!程序已停止运行,请自行检查问题所在") from e 46 | 47 | def save_cache(self): 48 | """保存缓存""" 49 | # try: 50 | # with open(self.cache_path, "w", encoding="utf-8") as f: 51 | # json.dump(self.cache, f, ensure_ascii=False, indent=4) 52 | # except Exception as e: 53 | # _LOGGER.error(f"保存缓存失败:{e}") 54 | # traceback.print_exc() 55 | save_file(json.dumps(self.cache, ensure_ascii=False, indent=4), self.cache_path) 56 | 57 | def get_cache(self, key: str, chain: str): 58 | """获取缓存""" 59 | return self.cache.get(chain, {}).get(key) 60 
| 61 | def set_cache(self, key: str, value, chain: str): 62 | """设置缓存""" 63 | if chain not in self.cache: 64 | self.cache[chain] = {} 65 | 66 | self.cache[chain][key] = value 67 | self.save_cache() 68 | 69 | def delete_cache(self, key: str): 70 | """删除缓存""" 71 | self.cache.pop(key) 72 | self.save_cache() 73 | 74 | def clear_cache(self): 75 | """清空缓存""" 76 | self.cache = {} 77 | self.save_cache() 78 | 79 | def get_all_cache(self): 80 | """获取所有缓存""" 81 | return self.cache 82 | -------------------------------------------------------------------------------- /config/example_config.yml: -------------------------------------------------------------------------------- 1 | ASRs: 2 | local_whisper: # 本地的whisper 3 | enable: false # 是否启用whisper,在视频无字幕时生成字幕,否则该视频将自动跳过 4 | priority: 50 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 5 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升 6 | device: cpu # cpu or cuda(仅在运行源代码时可用,docker运行只能选择cpu) 7 | model_dir: /data/whisper-models # 本地模型存放目录,如果更改要映射出来 8 | model_size: tiny # tiny, base, small, medium, large 详细选择请去:https://github.com/openai/whisper 9 | 10 | openai_whisper: 11 | enable: false # 是否启用openai whisper 12 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 13 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1) 14 | api_key: '' # 你的openai api key 15 | model: whisper-1 # 有且仅有这一个模型,不要改 16 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升 17 | 18 | 19 | 20 | LLMs: 21 | openai: # 对接gpt 22 | enable: true # 是否启用openai 23 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 24 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1) 25 | api_key: '' # 你的openai api key 26 | model: gpt-3.5-turbo-16k # 选择模型,我现在只推荐使用gpt-3.5-turbo-16k,其他模型容纳不了这么大的token,如果你有gpt-4-16k权限,还钱多,请自便 27 | 28 | aiproxy_claude: # 对接aiproxy claude(因为对接方式不同 只能用https://aiproxy.io这家的服务) 29 | enable: true # 是否启用claude 30 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 31 | api_base: https://api.aiproxy.io/ # 你的claude api base url(多数在使用第三方api供应商时会有) 32 | api_key: '' # 你的claude api key 33 | model: claude-instant-1 # 选择模型,claude-instant-1或claude-2 34 | 35 | spark: # 对接讯飞星火 36 | enable: true # 是否启用讯飞星火 37 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 38 | spark_url: wss://spark-api.xf-yun.com/v3.5/chat # 你的spark api base url(多数在使用第三方api供应商时会有) 39 | appid: '' # 你的appid 40 | api_key: '' # 你的api_key 41 | api_secret: '' # 你的api_secret 42 | domain: 'generalv3.5' # 要与spark_url对应 43 | 44 | bilibili_self: 45 | nickname: '' 46 | 47 | bilibili_cookie: # https://nemo2011.github.io/bilibili-api/#/get-credential 获取cookie 下面五个值都要填 48 | SESSDATA: '' 49 | ac_time_value: '' 50 | bili_jct: '' 51 | buvid3: '' 52 | dedeuserid: '' 53 | 54 | 55 | chain_keywords: 56 | summarize_keywords: # 用于生成评价的关键词,如果at/私信内容包含以下关键词,将会生成评价(该功能开发中) 57 | - "总结" 58 | - "总结一下" 59 | - "总结一下吧" 60 | - "总结一下吧!" 
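  # "问AI"链的触发关键词:命中后,程序会把消息中冒号后面的内容解析为要提的问题(见 chain_router.py)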
61 | ask_ai_keywords: 62 | - "问一下" 63 | - "请问" 64 | 65 | 66 | storage_settings: 67 | cache_path: /data/cache.json # 用于缓存已经处理过的视频,如果更改要映射出来 68 | statistics_dir: /data/statistics # 用于存放统计数据,如果更改要映射出来 69 | task_status_records: /data/records.json # 用于记录任务状态,如果更改要映射出来,不能留空 70 | queue_save_dir: /data/queue.json # 用于保存未完成的队列信息,下次运行时恢复 71 | temp_dir: /data/temp # 主要用于下载视频音频生成字幕,如果更改要映射出来 72 | up_video_cache: ./data/video_cache.json 73 | up_file: ./data/up.json 74 | 75 | debug_mode: true # 是否开启debug模式,开启后会打印更多日志,建议开启,以便于查找bug -------------------------------------------------------------------------------- /config/docker_config.yml: -------------------------------------------------------------------------------- 1 | ASRs: 2 | local_whisper: # 本地的whisper 3 | enable: false # 是否启用whisper,在视频无字幕时生成字幕,否则该视频将自动跳过 4 | priority: 50 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 5 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升 6 | device: cpu # cpu or cuda(仅在运行源代码时可用,docker运行只能选择cpu) 7 | model_dir: /data/whisper-models # 本地模型存放目录,如果更改要映射出来 8 | model_size: tiny # tiny, base, small, medium, large 详细选择请去:https://github.com/openai/whisper 9 | 10 | openai_whisper: 11 | enable: false # 是否启用openai whisper 12 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 13 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1) 14 | api_key: '' # 你的openai api key 15 | model: whisper-1 # 有且仅有这一个模型,不要改 16 | after_process: false # 是否再使用llm优化生成字幕结果,最终字幕效果会大幅提升 17 | 18 | 19 | 20 | LLMs: 21 | openai: # 对接gpt 22 | enable: true # 是否启用openai 23 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 24 | api_base: https://api.openai.com/v1 # 你的openai api base url(多数在使用第三方api供应商时会有,记得url尾缀有/v1) 25 | api_key: '' # 你的openai api key 26 | model: gpt-3.5-turbo-16k # 选择模型,我现在只推荐使用gpt-3.5-turbo-16k,其他模型容纳不了这么大的token,如果你有gpt-4-16k权限,还钱多,请自便 27 | 28 | aiproxy_claude: # 对接aiproxy claude(因为对接方式不同 只能用https://aiproxy.io这家的服务) 29 | enable: true # 是否启用claude 30 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 31 | api_base: https://api.aiproxy.io/ # 你的claude api base url(多数在使用第三方api供应商时会有) 32 | api_key: '' # 你的claude api key 33 | model: claude-instant-1 # 选择模型,claude-instant-1或claude-2 34 | 35 | spark: # 对接讯飞星火 36 | enable: true # 是否启用讯飞星火 37 | priority: 100 # 优先级,数字越大优先级越高,程序在选择时会更倾向于选择优先级高的 38 | spark_url: wss://spark-api.xf-yun.com/v3.5/chat # 你的spark api base url(多数在使用第三方api供应商时会有) 39 | appid: '' # 你的appid 40 | api_key: '' # 你的api_key 41 | api_secret: '' # 你的api_secret 42 | domain: 'generalv3.5' # 要与spark_url对应 43 | 44 | bilibili_self: 45 | nickname: '' 46 | 47 | bilibili_cookie: # https://nemo2011.github.io/bilibili-api/#/get-credential 获取cookie 下面五个值都要填 48 | SESSDATA: '' 49 | ac_time_value: '' 50 | bili_jct: '' 51 | buvid3: '' 52 | dedeuserid: '' 53 | 54 | 55 | chain_keywords: 56 | summarize_keywords: # 用于生成评价的关键词,如果at/私信内容包含以下关键词,将会生成评价(该功能开发中) 57 | - "总结" 58 | - "总结一下" 59 | - "总结一下吧" 60 | - "总结一下吧!" 
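  # "问AI"链的触发关键词:命中后,程序会把消息中冒号后面的内容解析为要提的问题(见 chain_router.py)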
61 |   ask_ai_keywords:
62 |     - "问一下"
63 |     - "请问"
64 | 
65 | 
66 | storage_settings:
67 |   cache_path: /data/cache.json # 用于缓存已经处理过的视频,如果更改要映射出来
68 |   statistics_dir: /data/statistics # 用于存放统计数据,如果更改要映射出来
69 |   task_status_records: /data/records.json # 用于记录任务状态,如果更改要映射出来,不能留空
70 |   queue_save_dir: /data/queue.json # 用于保存未完成的队列信息,下次运行时恢复
71 |   temp_dir: /data/temp # 主要用于下载视频音频生成字幕,如果更改要映射出来
72 |   up_video_cache: ./data/video_cache.json
73 |   up_file: ./data/up.json
74 | 
75 | debug_mode: true # 是否开启debug模式,开启后会打印更多日志,建议开启,以便于查找bug
76 | 
--------------------------------------------------------------------------------
/src/utils/queue_manager.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import traceback
4 | from copy import deepcopy
5 | 
6 | from src.models.task import BiliGPTTask
7 | from src.utils.file_tools import read_file, save_file
8 | from src.utils.logging import LOGGER
9 | 
10 | _LOGGER = LOGGER
11 | 
12 | 
13 | class QueueManager:
14 |     """队列管理器"""
15 | 
16 |     def __init__(self):
17 |         _LOGGER.debug("初始化队列管理器")
18 |         self.queues = {}
19 |         self.saved_queue = {}
20 | 
21 |     def get_queue(self, queue_name: str) -> asyncio.Queue:
22 |         if queue_name not in self.queues:
23 |             _LOGGER.debug(f"正在创建{queue_name}队列")
24 |             self.queues[queue_name] = asyncio.Queue()
25 |         return self.queues.get(queue_name)
26 | 
27 |     def _save(self, file_path: str):
28 |         content = json.dumps(self.saved_queue, ensure_ascii=False, indent=4)
29 |         save_file(content, file_path)
30 | 
31 |     def _load(self, file_path: str):
32 |         try:
33 |             content = read_file(file_path)
34 |             if content:
35 |                 self.saved_queue = json.loads(content)
36 |             else:
37 |                 self.saved_queue = {}
38 |                 self._save(file_path)
39 |         except Exception:
40 |             _LOGGER.error("在读取已保存的队列文件中出现问题!暂时跳过恢复,但不影响使用,请自行检查!")
41 |             traceback.print_exc()
42 |             self.saved_queue = {}
43 | 
44 |     def safe_close_queue(self, queue_name: str, saved_json_path: str):
45 |         """
46 |         安全关闭队列(会对队列中任务进行保存)
47 |         :param queue_name: 队列名
48 |         :param saved_json_path: 保存位置
49 |         :return:
50 |         """
51 |         queue_list = []
52 |         queue = self.get_queue(queue_name)
53 |         while not queue.empty():
54 |             item: BiliGPTTask = queue.get_nowait()
55 |             queue_list.append(item)
56 |         _LOGGER.debug(f"共保存了{len(queue_list)}条数据!")
57 |         self.saved_queue[queue_name] = queue_list
58 |         self._save(saved_json_path)
59 | 
60 |     def safe_close_all_queues(self, saved_json_path: str):
61 |         """
62 |         保存所有的队列(我建议使用这个)
63 |         :param saved_json_path:
64 |         :return:
65 |         """
66 |         for queue in list(self.queues.keys()):  # 遍历当前所有已创建的队列
67 |             _LOGGER.debug(f"保存{queue}队列的任务")
68 |             self.safe_close_queue(queue, saved_json_path)
69 | 
70 |     def recover_queue(self, saved_json_path: str):
71 |         """
72 |         恢复保存在文件中的任务信息
73 |         :param saved_json_path:
74 |         :return:
75 |         """
76 |         self._load(saved_json_path)
77 |         _queue_dict = deepcopy(self.saved_queue)
78 |         for queue_name in list(self.saved_queue.keys()):
79 |             _LOGGER.debug(f"开始恢复{queue_name}")
80 |             queue = self.get_queue(queue_name)
81 |             for task in self.saved_queue[queue_name]:
82 |                 queue.put_nowait(task)
83 |             del _queue_dict[queue_name]
84 |         self.saved_queue = _queue_dict
85 |         self._save(saved_json_path)
86 | 
--------------------------------------------------------------------------------
/src/llm/claude.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from typing import Tuple
3 | 
4 | import anthropic
5 | 
6 | from src.llm.llm_base import LLMBase
7 | from src.llm.templates import Templates
8 | from
src.utils.logging import LOGGER
9 | from src.utils.prompt_utils import parse_prompt
10 | 
11 | _LOGGER = LOGGER.bind(name="aiproxy_claude")
12 | 
13 | 
14 | class AiproxyClaude(LLMBase):
15 |     def prepare(self):
16 |         mask_key = "*****" + self.config.LLMs.aiproxy_claude.api_key[-4:]  # 只保留末4位,避免完整key写入日志
17 |         _LOGGER.info(f"初始化AIProxyClaude,api_key为{mask_key},api端点为{self.config.LLMs.aiproxy_claude.api_base}")
18 | 
19 |     async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None:
20 |         """调用claude的Completion API
21 |         :param prompt: 输入的文本(请确保格式化为openai的prompt格式)
22 |         :param kwargs: 其他参数
23 |         :return: 返回生成的文本和token总数 或 None
24 |         """
25 |         try:
26 |             claude = anthropic.AsyncAnthropic(
27 |                 api_key=self.config.LLMs.aiproxy_claude.api_key,
28 |                 base_url=self.config.LLMs.aiproxy_claude.api_base,
29 |             )
30 |             resp = await claude.completions.create(
31 |                 prompt=prompt,
32 |                 max_tokens_to_sample=1000,
33 |                 model=self.config.LLMs.aiproxy_claude.model,
34 |                 **kwargs,
35 |             )
36 |             _LOGGER.debug(f"调用claude的Completion API成功,API返回结果为:{resp}")
37 |             _LOGGER.info(
38 |                 f"调用claude的Completion API成功,本次调用中,prompt+response的长度为{resp.model_dump()['usage']['total_tokens']}"
39 |             )
40 |             if not resp.completion.startswith('{"'):
41 |                 # claude好像是直接从assistant所给内容之后续写的,给它加上缺失的前缀
42 |                 resp.completion = '{"' + resp.completion
43 |             return (
44 |                 resp.completion,
45 |                 resp.model_dump()["usage"]["total_tokens"],
46 |             )
47 |         except Exception as e:
48 |             _LOGGER.error(f"调用claude的Completion API失败:{e}")
49 |             traceback.print_tb(e.__traceback__)
50 |             return None
51 | 
52 |     @staticmethod
53 |     def use_template(
54 |         user_template_name: Templates,
55 |         system_template_name: Templates = None,
56 |         user_keyword="user",
57 |         system_keyword="system",
58 |         **kwargs,
59 |     ) -> str | None:
60 |         try:
61 |             template_user = user_template_name.value
62 |             template_system = system_template_name.value if system_template_name else None
63 |             utemplate = parse_prompt(template_user, **kwargs)
64 |             stemplate = parse_prompt(template_system, **kwargs) if template_system else None
65 |             prompt = f"I will give you 'rules' 'content' two tags. You need to follow the rules!
{utemplate} {stemplate}" 66 | prompt = anthropic.HUMAN_PROMPT + " " + prompt + " " + anthropic.AI_PROMPT + " " + '{"' 67 | _LOGGER.info("使用模板成功") 68 | _LOGGER.debug(f"生成的prompt为:{prompt}") 69 | return prompt 70 | except Exception as e: 71 | _LOGGER.error(f"使用模板失败:{e}") 72 | traceback.print_exc() 73 | return None 74 | -------------------------------------------------------------------------------- /src/bilibili/bili_session.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Union 3 | 4 | import tenacity 5 | from bilibili_api import session 6 | from bilibili_api.session import EventType 7 | from injector import inject 8 | 9 | from src.bilibili.bili_credential import BiliCredential 10 | from src.bilibili.bili_video import BiliVideo 11 | from src.models.task import AskAIResponse, BiliGPTTask, SummarizeAiResponse 12 | from src.utils.callback import chain_callback 13 | from src.utils.logging import LOGGER 14 | 15 | _LOGGER = LOGGER.bind(name="bilibili-session") 16 | 17 | 18 | class BiliSession: 19 | @inject 20 | def __init__(self, credential: BiliCredential, private_queue: asyncio.Queue): 21 | """ 22 | 初始化BiliSession类 23 | 24 | :param credential: B站凭证 25 | :param private_queue: 私信队列 26 | """ 27 | self.credential = credential 28 | self.private_queue = private_queue 29 | 30 | @staticmethod 31 | async def quick_send(credential, task: BiliGPTTask, msg: str): 32 | """快速发送私信""" 33 | await session.send_msg( 34 | credential, 35 | # at_items["item"]["private_msg_event"]["text_event"]["sender_uid"], 36 | int(task.sender_id), 37 | EventType.TEXT, 38 | msg, 39 | ) 40 | 41 | @staticmethod 42 | def build_reply_content(response: Union[SummarizeAiResponse, AskAIResponse]) -> list: 43 | """构建回复内容(由于有私信消息过长被截断的先例,所以返回是一个list,分消息发)""" 44 | # TODO 有时还是会触碰到b站的字数墙,但不清楚字数限制是多少,再等等看 45 | # TODO 这种判断方式很不优雅,但现在是半夜十二点,我不想改了,我想睡觉了 46 | if isinstance(response, SummarizeAiResponse): 47 | msg_list = [ 48 | f"【视频摘要】{response.summary}", 49 | f"【视频评分】{response.score}分\n\n【咱还想说】{response.thinking}", 50 | ] 51 | elif isinstance(response, AskAIResponse): 52 | msg_list = [f"【回答】{response.answer}\n\n【自我评分】{response.score}分"] 53 | else: 54 | msg_list = [f"程序内部错误:无法识别的回复类型{type(response)}"] 55 | return msg_list 56 | 57 | @tenacity.retry( 58 | retry=tenacity.retry_if_exception_type(Exception), 59 | wait=tenacity.wait_fixed(10), 60 | before_sleep=chain_callback, 61 | ) 62 | async def start_private_reply(self): 63 | """发送评论""" 64 | while True: 65 | try: 66 | data: BiliGPTTask = await self.private_queue.get() 67 | _LOGGER.debug("获取到新的私信任务,开始处理") 68 | _, _type = await BiliVideo(credential=self.credential, url=data.video_url).get_video_obj() 69 | msg_list = BiliSession.build_reply_content(data.process_result) 70 | for msg in msg_list: 71 | await session.send_msg( 72 | self.credential, 73 | # data["item"]["private_msg_event"]["text_event"]["sender_uid"], 74 | int(data.sender_id), 75 | EventType.TEXT, 76 | msg, 77 | ) 78 | await asyncio.sleep(3) 79 | except asyncio.CancelledError: 80 | _LOGGER.info("私信处理链关闭") 81 | return 82 | -------------------------------------------------------------------------------- /src/asr/local_whisper.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import time 4 | import traceback 5 | from typing import Optional 6 | 7 | from src.asr.asr_base import ASRBase 8 | from src.core.routers.llm_router import LLMRouter 9 | from src.llm.templates import Templates 10 | from 
src.models.config import Config 11 | from src.utils.logging import LOGGER 12 | 13 | _LOGGER = LOGGER.bind(name="LocalWhisper") 14 | 15 | 16 | class LocalWhisper(ASRBase): 17 | def __init__(self, config: Config, llm_router: LLMRouter): 18 | super().__init__(config, llm_router) 19 | self.llm_router = llm_router 20 | self.model = None 21 | self.config = config 22 | 23 | def prepare(self) -> None: 24 | """ 25 | 加载whisper模型 26 | :return: None 27 | """ 28 | _LOGGER.info( 29 | f"正在加载whisper模型,模型大小{self.config.ASRs.local_whisper.model_size},设备{self.config.ASRs.local_whisper.device}" 30 | ) 31 | import whisper as whi 32 | 33 | self.model = whi.load_model( 34 | self.config.ASRs.local_whisper.model_size, 35 | self.config.ASRs.local_whisper.device, 36 | download_root=self.config.ASRs.local_whisper.model_dir, 37 | ) 38 | _LOGGER.info("加载whisper模型成功") 39 | 40 | async def after_process(self, text, **kwargs) -> str: 41 | llm = self.llm_router.get_one() 42 | prompt = llm.use_template(Templates.AFTER_PROCESS_SUBTITLE, subtitle=text) 43 | answer, _ = await llm.completion(prompt) 44 | if answer is None: 45 | _LOGGER.error("后处理失败,返回原字幕") 46 | return text 47 | return answer 48 | 49 | def _sync_transcribe(self, audio_path, **kwargs) -> Optional[str]: 50 | try: 51 | begin_time = time.perf_counter() 52 | _LOGGER.info(f"开始转写 {audio_path}") 53 | if self.model is None: 54 | return None 55 | import whisper as whi 56 | 57 | text = whi.transcribe(self.model, audio_path) 58 | text = text["text"] 59 | _LOGGER.debug("转写成功") 60 | time_elapsed = time.perf_counter() - begin_time 61 | _LOGGER.info(f"字幕转译完成,共用时{time_elapsed}s") 62 | return text 63 | except Exception as e: 64 | _LOGGER.error(f"转写失败,错误信息为{e}", exc_info=True) 65 | return None 66 | 67 | async def transcribe(self, audio_path, **kwargs) -> Optional[str]: # 添加self参数以访问线程池 68 | loop = asyncio.get_event_loop() 69 | 70 | func = functools.partial(self._sync_transcribe, audio_path, **kwargs) 71 | 72 | result = await loop.run_in_executor(None, func) 73 | w = self.config.ASRs.local_whisper 74 | try: 75 | if w.after_process and result is not None: 76 | bt = time.perf_counter() 77 | _LOGGER.info("正在进行后处理") 78 | text = await self.after_process(result) 79 | _LOGGER.debug(f"后处理完成,用时{time.perf_counter()-bt}s") 80 | return text 81 | return result 82 | except Exception as e: 83 | _LOGGER.error(f"后处理失败,错误信息为{e}") 84 | traceback.print_exc() 85 | return result 86 | -------------------------------------------------------------------------------- /src/llm/llm_base.py: -------------------------------------------------------------------------------- 1 | """llm对接的基础类""" 2 | 3 | import abc 4 | import re 5 | import traceback 6 | from typing import Tuple 7 | 8 | from src.llm.templates import Templates 9 | from src.models.config import Config 10 | from src.utils.logging import LOGGER 11 | from src.utils.prompt_utils import build_openai_style_messages, parse_prompt 12 | 13 | _LOGGER = LOGGER.bind(name="llm_base") 14 | 15 | 16 | class LLMBase: 17 | """实现这个类,即可轻松对接其他的LLM模型""" 18 | 19 | def __init__(self, config: Config): 20 | self.config = config 21 | 22 | def __new__(cls, *args, **kwargs): 23 | """将类名转换为alias""" 24 | instance = super().__new__(cls) 25 | name = cls.__name__ 26 | name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) 27 | name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() 28 | instance.alias = name 29 | return instance 30 | 31 | def prepare(self): 32 | """ 33 | 初始化方法,例如设置参数等 34 | 会在该类被 **第一次** 使用时调用 35 | 这个函数不应该有入参,所有参数都从self.config中获取 36 | """ 37 | pass 38 | 39 | 
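    # 说明:completion为异步抽象方法;若底层SDK只有同步接口,可先实现 _sync_completion,
    # 再在 completion 中用线程池包装调用(gpt.py 中 run_in_executor 的写法即是如此)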
@abc.abstractmethod 40 | async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None: 41 | """使用LLM生成文本(如果出错的话需要在这里自己捕捉错误并返回None) 42 | 请确保整个过程为 **异步**,否则会阻塞整个程序 43 | :param prompt: 最终的输入文本,确保格式化过 44 | :param kwargs: 其他参数 45 | :return: 返回生成的文本和token总数 或 None 46 | """ 47 | pass 48 | 49 | def _sync_completion(self, prompt, **kwargs) -> Tuple[str, int] | None: 50 | """如果你的调用方式为同步,请先在这里实现,然后在completion中使用线程池调用 51 | :param prompt: 最终的输入文本,确保格式化过 52 | :param kwargs: 其他参数 53 | :return: 返回生成的文本和token总数 或 None 54 | """ 55 | pass 56 | 57 | @staticmethod 58 | def use_template( 59 | user_template_name: Templates, 60 | system_template_name: Templates = None, 61 | user_keyword="user", 62 | system_keyword="system", 63 | **kwargs, 64 | ) -> list | None: 65 | """使用模板生成最终prompt(最终格式可能需要根据llm所需格式不同修改,默认为openai的system、user格式) 66 | :param user_template_name: 用户模板名称 67 | :param system_template_name: 系统模板名称 68 | :param user_keyword: 用户关键词(这个和下面的system_keyword要根据每个llm不同的要求来填) 69 | :param system_keyword: 系统关键词 70 | :param kwargs: 模板参数 71 | :return: 返回生成的prompt 或 None 72 | """ 73 | try: 74 | template_user = user_template_name.value 75 | template_system = system_template_name.value if system_template_name else None 76 | utemplate = parse_prompt(template_user, **kwargs) 77 | stemplate = parse_prompt(template_system, **kwargs) if template_system else None 78 | prompt = ( 79 | build_openai_style_messages(utemplate, stemplate, user_keyword, system_keyword) 80 | if stemplate 81 | else build_openai_style_messages(utemplate, user_keyword=user_keyword) 82 | ) 83 | _LOGGER.info("使用模板成功") 84 | _LOGGER.debug(f"生成的prompt为:{prompt}") 85 | return prompt 86 | except Exception as e: 87 | _LOGGER.error(f"使用模板失败:{e}") 88 | traceback.print_exc() 89 | return None 90 | 91 | def __repr__(self): 92 | return self.alias 93 | 94 | def __str__(self): 95 | return self.__class__.__name__ 96 | -------------------------------------------------------------------------------- /src/bilibili/bili_video.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from bilibili_api import ResourceType, parse_link, video 4 | from injector import inject 5 | 6 | from src.bilibili.bili_credential import BiliCredential 7 | 8 | 9 | class BiliVideo: 10 | @inject 11 | def __init__( 12 | self, 13 | credential: BiliCredential, 14 | bvid: str = None, 15 | aid: int = None, 16 | url: str = None, 17 | ): 18 | """ 19 | 三选一,优先级为url > aid > bvid 20 | :param credential: 21 | :param bvid: 22 | :param aid: 23 | :param url: 24 | """ 25 | self.credential = credential 26 | self._bvid = bvid 27 | self.aid = aid 28 | self.url = url 29 | self.video_obj: Optional[video.Video] = None 30 | 31 | async def get_video_obj(self): 32 | _type = ResourceType.VIDEO 33 | if self.video_obj: 34 | return self.video_obj, _type 35 | if self.url: 36 | self.video_obj, _type = await parse_link(self.url, credential=self.credential) 37 | elif self.aid: 38 | self.video_obj = video.Video(aid=self.aid, credential=self.credential) 39 | elif self._bvid: 40 | self.video_obj = video.Video(bvid=self._bvid, credential=self.credential) 41 | else: 42 | raise ValueError("缺少必要参数") 43 | return self.video_obj, _type 44 | 45 | @property 46 | async def get_video_info(self): 47 | if not self.video_obj: 48 | await self.get_video_obj() 49 | return await self.video_obj.get_info() 50 | 51 | @property 52 | async def get_video_pages(self): 53 | if not self.video_obj: 54 | await self.get_video_obj() 55 | return await self.video_obj.get_pages() 56 | 57 
| async def get_video_tags(self, page_index: int = 0): 58 | if not self.video_obj: 59 | await self.get_video_obj() 60 | return await self.video_obj.get_tags(page_index=page_index) 61 | 62 | async def get_video_download_url(self, page_index: int = 0): 63 | if not self.video_obj: 64 | await self.get_video_obj() 65 | return await self.video_obj.get_download_url(page_index=page_index) 66 | 67 | async def get_video_subtitle(self, cid: int = None, page_index: int = 0): 68 | """返回字幕链接,如果有多个字幕则优先返回非ai和翻译字幕,如果没有则返回ai字幕""" 69 | if not self.video_obj: 70 | await self.get_video_obj() 71 | if not cid: 72 | cid = await self.video_obj.get_cid(page_index=page_index) 73 | info = await self.video_obj.get_player_info(cid=cid) 74 | json_files = info["subtitle"]["subtitles"] 75 | if len(json_files) == 0: 76 | return None 77 | if len(json_files) == 1: 78 | return json_files[0]["subtitle_url"] 79 | for subtitle in json_files: 80 | if subtitle["lan_doc"] != "中文(自动翻译)" and subtitle["lan_doc"] != "中文(自动生成)": 81 | return subtitle["subtitle_url"] 82 | for subtitle in json_files: 83 | if subtitle["lan_doc"] == "中文(自动翻译)" or subtitle["lan_doc"] == "中文(自动生成)": 84 | return subtitle["subtitle_url"] 85 | 86 | @property 87 | async def bvid(self) -> str: 88 | if not self.video_obj: 89 | await self.get_video_obj() 90 | return self.video_obj.get_bvid() 91 | 92 | @property 93 | async def format_title(self) -> str: 94 | if not self.video_obj: 95 | await self.get_video_obj() 96 | info = await self.video_obj.get_info() 97 | return f"『{info['title']}』" 98 | -------------------------------------------------------------------------------- /src/utils/task_status_record.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import json 3 | from typing import Union 4 | 5 | from src.models.task import BiliGPTTask, Chains, ProcessStages 6 | from src.utils.exceptions import LoadJsonError 7 | from src.utils.file_tools import read_file, save_file 8 | from src.utils.logging import LOGGER 9 | 10 | _LOGGER = LOGGER.bind(name="task-status-record") 11 | 12 | 13 | class TaskStatusRecorder: 14 | """视频状态记录器""" 15 | 16 | def __init__(self, file_path): 17 | self.file_path = file_path 18 | self.video_records = {} 19 | self.load() 20 | 21 | def load(self): 22 | # if os.path.exists(self.file_path): 23 | # with open(self.file_path, encoding="utf-8") as f: 24 | # if os.path.getsize(self.file_path) == 0: 25 | # self.video_records = {} 26 | # else: 27 | # self.video_records = json.load(f) 28 | # else: 29 | # self.video_records = {} 30 | # self.save() 31 | try: 32 | content = read_file(self.file_path) 33 | if content: 34 | self.video_records = json.loads(content) 35 | else: 36 | self.video_records = {} 37 | self.save() 38 | except Exception as e: 39 | raise LoadJsonError("在读取视频记录文件时出现问题!程序已停止运行,请自行检查问题所在") from e 40 | 41 | # except Exception as e: 42 | # _LOGGER.error(f"读取视频状态记录文件失败,错误信息为{e},恢复为初始文件") 43 | # self.video_records = {} 44 | # # with open(self.file_path, "w", encoding="utf-8") as f: 45 | # # json.dump(self.video_records, f, ensure_ascii=False, indent=4) 46 | # save_file(json.dumps(self.video_records), self.file_path) 47 | 48 | def save(self): 49 | # with open(self.file_path, "w", encoding="utf-8") as f: 50 | # json.dump(self.video_records, f, ensure_ascii=False, indent=4) 51 | save_file(json.dumps(self.video_records, ensure_ascii=False, indent=4), self.file_path) 52 | 53 | def get_record_by_stage( 54 | self, 55 | chain: Chains, 56 | stage: ProcessStages = None, 57 | ): 58 | """ 59 | 根据stage获取记录 60 
|         当stage为None时,返回所有记录
61 |         """
62 |         records = []
63 |         if stage is None:
64 |             for record in self.video_records.values():
65 |                 if record["chain"] == chain.value:
66 |                     records.append(record)
67 |             return records
68 |         for record in self.video_records.values():
69 |             if record["process_stage"] == stage.value and record["chain"] == chain.value:
70 |                 records.append(record)
71 |         return records
72 | 
73 |     def create_record(self, item: BiliGPTTask):
74 |         """创建一条记录,返回一条uuid,可以根据uuid修改记录"""
75 |         self.video_records[str(item.uuid)] = item.model_dump(mode="json")
76 |         # del self.video_records[item.uuid]["raw_task_data"]["video_event"]["content"]
77 |         self.save()
78 |         return item.uuid
79 | 
80 |     def update_record(self, _uuid: str, new_task_data: Union[BiliGPTTask, None], **kwargs) -> bool:
81 |         """根据uuid更新记录"""
82 |         # record: BiliGPTTask = self.video_records[_uuid]
83 |         _uuid = str(_uuid)
84 |         if new_task_data is not None:
85 |             self.video_records[_uuid] = new_task_data.model_dump(mode="json")
86 |             # del self.video_records[_uuid]["raw_task_data"]["video_event"]["content"]
87 |         if self.video_records[_uuid] is None:
88 |             return False
89 |         for key, _value in kwargs.items():
90 |             if isinstance(_value, enum.Enum):
91 |                 _value = _value.value
92 |             if key == "process_stage":
93 |                 self.video_records[_uuid]["process_stage"] = _value
94 |             if key in self.video_records[_uuid]:
95 |                 self.video_records[_uuid][key] = _value
96 |             else:
97 |                 _LOGGER.warning(f"尝试更新不存在的字段:{key},跳过")
98 |         self.save()
99 |         return True
100 | 
101 |     # def get_uuid_by_data(self, data: BiliGPTTask):
102 |     #     """根据data获取uuid"""
103 |     #     for _uuid, record in self.tasks.items():
104 |     #         if record["data"] == data:
105 |     #             return _uuid
106 |     #     return None
107 | 
108 |     def get_data_by_uuid(self, _uuid: str) -> BiliGPTTask:
109 |         """根据uuid获取data"""
110 |         return self.video_records[_uuid]
111 | 
--------------------------------------------------------------------------------
/src/core/routers/llm_router.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import os
3 | import traceback
4 | from typing import Optional
5 | 
6 | from injector import inject
7 | 
8 | from src.llm.llm_base import LLMBase
9 | from src.models.config import Config
10 | from src.utils.logging import LOGGER
11 | 
12 | _LOGGER = LOGGER.bind(name="LLM-Router")
13 | 
14 | 
15 | class LLMRouter:
16 |     """LLM路由器,用于加载所有LLM子类并进行合理路由"""
17 | 
18 |     @inject
19 |     def __init__(self, config: Config):
20 |         self.config = config
21 |         self._llm_dict = {}
22 |         self.max_err_times = 10  # TODO i know i know,硬编码很不优雅,但这种选项开放给用户似乎也没必要
23 | 
24 |     def load_from_dir(self, py_style_path: str = "src.llm"):
25 |         """
26 |         从一个文件夹中加载所有LLM子类
27 | 
28 |         :param py_style_path: 包导入风格的路径,以文件运行位置为基础路径
29 |         :return: None
30 |         """
31 |         raw_path = "./" + py_style_path.replace(".", "/")
32 |         for file_name in os.listdir(raw_path):
33 |             if file_name.endswith(".py") and file_name != "__init__.py":
34 |                 module_name = file_name[:-3]
35 |                 module = __import__(f"{py_style_path}.{module_name}", fromlist=[module_name])
36 |                 for attr_name in dir(module):
37 |                     attr = getattr(module, attr_name)
38 |                     if inspect.isclass(attr) and issubclass(attr, LLMBase) and attr != LLMBase:
39 |                         self.load(attr)
40 | 
41 |     def load(self, attr):
42 |         """加载一个LLM子类"""
43 |         try:
44 |             _asr = attr(self.config)  # 变量名沿用自ASR路由器,这里实际是LLM实例
45 |             setattr(self, _asr.alias, _asr)
46 |             _LOGGER.info(f"正在加载 {_asr.alias}")
47 |             _config = self.config.model_dump(mode="json")["LLMs"][_asr.alias]
48 |             priority = _config["priority"]
49 |             enabled = _config["enable"]
50 |             if priority is None or enabled is None:
50 | if priority is None or enabled is None: 51 | raise ValueError 52 | # 设置属性 53 | self.llm_dict[_llm.alias] = { 54 | "priority": priority, 55 | "enabled": enabled, 56 | "prepared": False, 57 | "err_times": 0, 58 | "obj": self.get(_llm.alias), 59 | } 60 | except Exception as e: 61 | _LOGGER.error(f"加载 {str(attr)} 失败,错误信息为{e}") 62 | traceback.print_exc() 63 | else: 64 | _LOGGER.info(f"加载 {_llm.alias} 成功,优先级为{priority},启用状态为{enabled}") 65 | _LOGGER.debug(f"当前已加载的LLM子类有 {self.llm_dict}") 66 | 67 | @property 68 | def llm_dict(self): 69 | """获取所有已加载的LLM子类""" 70 | return self._llm_dict 71 | 72 | def get(self, name): 73 | """获取一个已加载的LLM子类""" 74 | try: 75 | return getattr(self, name) 76 | except Exception as e: 77 | _LOGGER.error(f"获取LLM子类失败,错误信息为{e}", exc_info=True) 78 | return None 79 | 80 | @llm_dict.setter 81 | def llm_dict(self, value): 82 | self._llm_dict = value 83 | 84 | def order(self): 85 | """ 86 | 对LLM子类进行排序 87 | 优先级高的排在前面,未启用的排在最后""" 88 | self.llm_dict = dict( 89 | sorted( 90 | self.llm_dict.items(), 91 | key=lambda item: ( 92 | item[1].get("enabled", True), 93 | item[1]["priority"], 94 | ), 95 | reverse=True, 96 | ) 97 | ) 98 | 99 | def get_one(self) -> Optional[LLMBase]: 100 | """根据优先级获取一个可用的LLM子类,如果所有都不可用则返回None""" 101 | self.order() 102 | for llm in self.llm_dict.values(): 103 | if llm["enabled"] and llm["err_times"] < self.max_err_times: 104 | if not llm["prepared"]: 105 | _LOGGER.info(f"正在初始化 {llm['obj'].alias}") 106 | llm["obj"].prepare() 107 | llm["prepared"] = True 108 | return llm["obj"] 109 | return None 110 | 111 | def report_error(self, name: str): 112 | """报告一个LLM子类的错误""" 113 | for llm in self.llm_dict.values(): 114 | if llm["obj"].alias == name: 115 | llm["err_times"] += 1 116 | if llm["err_times"] >= self.max_err_times: 117 | llm["enabled"] = False 118 | break 119 | else: 120 | raise ValueError(f"LLM子类 {name} 不存在") 121 | _LOGGER.info(f"{name} 发生错误,已累计错误{llm['err_times']}次") 122 | _LOGGER.debug(f"当前已加载的LLM子类有 {self.llm_dict}") 123 |
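> 补充示例:结合下文 ask_ai.py 里的实际调用方式(`use_template` 生成 prompt、`completion` 失败返回 None、`report_error` 累计错误并自动禁用),LLMRouter 期望的容错流程大致如下。这只是一个示意片段,并非仓库文件,函数名为虚构:

```python
# 示意:按优先级取LLM,失败则上报错误并回退到下一个可用LLM
async def completion_with_fallback(router: LLMRouter, prompt):
    llm = router.get_one()  # 优先级最高且可用的LLM,全部不可用时为None
    while llm is not None:
        resp = await llm.completion(prompt)
        if resp is not None:
            answer, tokens = resp
            return answer
        router.report_error(llm.alias)  # 错误累计到 max_err_times 次后会被自动禁用
        llm = router.get_one()
    return None
```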
-------------------------------------------------------------------------------- /src/core/routers/asr_router.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | import traceback 4 | from typing import Optional 5 | 6 | from injector import inject 7 | 8 | from src.asr.asr_base import ASRBase 9 | from src.core.routers.llm_router import LLMRouter 10 | from src.models.config import Config 11 | from src.utils.logging import LOGGER 12 | 13 | _LOGGER = LOGGER.bind(name="ASR-Router") 14 | 15 | 16 | class ASRouter: 17 | """ASR路由器,用于加载所有ASR子类并进行合理路由""" 18 | 19 | @inject 20 | def __init__(self, config: Config, llm_router: LLMRouter): 21 | self.config = config 22 | self._asr_dict = {} 23 | self.llm_router = llm_router 24 | self.max_err_times = 10 # TODO i know i know,硬编码很不优雅,但这种选项开放给用户似乎也没必要 25 | 26 | def load_from_dir(self, py_style_path: str = "src.asr"): 27 | """ 28 | 从一个文件夹中加载所有ASR子类 29 | 30 | :param py_style_path: 包导入风格的路径,以文件运行位置为基础路径 31 | :return: None 32 | """ 33 | raw_path = "./" + py_style_path.replace(".", "/") 34 | for file_name in os.listdir(raw_path): 35 | if file_name.endswith(".py") and file_name != "__init__.py": 36 | module_name = file_name[:-3] 37 | module = __import__(f"{py_style_path}.{module_name}", fromlist=[module_name]) 38 | for attr_name in dir(module): 39 | attr = getattr(module, attr_name) 40 | if inspect.isclass(attr) and issubclass(attr, ASRBase) and attr != ASRBase: 41 | self.load(attr) 42 | 43 | def load(self, attr): 44 | """加载一个ASR子类""" 45 | try: 46 | _asr = attr(self.config, self.llm_router) 47 | setattr(self, _asr.alias, _asr) 48 | _LOGGER.info(f"正在加载 {_asr.alias}") 49 | _config = self.config.model_dump(mode="json")["ASRs"][_asr.alias] 50 | priority = _config["priority"] 51 | enabled = _config["enable"] 52 | if priority is None or enabled is None: 53 | raise ValueError 54 | # 设置属性 55 | self.asr_dict[_asr.alias] = { 56 | "priority": priority, 57 | "enabled": enabled, 58 | "prepared": False, 59 | "err_times": 0, 60 | "obj": self.get(_asr.alias), 61 | } 62 | except Exception as e: 63 | _LOGGER.error(f"加载 {str(attr)} 失败,错误信息为{e}") 64 | traceback.print_exc() 65 | else: 66 | _LOGGER.info(f"加载 {_asr.alias} 成功,优先级为{priority},启用状态为{enabled}") 67 | _LOGGER.debug(f"当前已加载的ASR子类有 {self.asr_dict}") 68 | 69 | @property 70 | def asr_dict(self): 71 | """获取所有已加载的ASR子类""" 72 | return self._asr_dict 73 | 74 | def get(self, name): 75 | """获取一个已加载的ASR子类""" 76 | try: 77 | return getattr(self, name) 78 | except Exception as e: 79 | _LOGGER.error(f"获取ASR子类失败,错误信息为{e}", exc_info=True) 80 | return None 81 | 82 | @asr_dict.setter 83 | def asr_dict(self, value): 84 | self._asr_dict = value 85 | 86 | def order(self): 87 | """ 88 | 对ASR子类进行排序 89 | 优先级高的排在前面,未启用的排在最后""" 90 | self.asr_dict = dict( 91 | sorted( 92 | self.asr_dict.items(), 93 | key=lambda item: ( 94 | item[1].get("enabled"), 95 | item[1]["priority"], 96 | ), 97 | reverse=True, 98 | ) 99 | ) 100 | 101 | def get_one(self) -> Optional[ASRBase]: 102 | """根据优先级获取一个可用的ASR子类,如果所有都不可用则返回None""" 103 | self.order() 104 | for asr in self.asr_dict.values(): 105 | if asr["enabled"] and asr["err_times"] < self.max_err_times: 106 | if not asr["prepared"]: 107 | _LOGGER.info(f"正在初始化 {asr['obj'].alias}") 108 | asr["obj"].prepare() 109 | asr["prepared"] = True 110 | return asr["obj"] 111 | _LOGGER.error("没有可用的ASR子类") 112 | return None 113 | 114 | def report_error(self, name: str): 115 | """报告一个ASR子类的错误""" 116 | for asr in self.asr_dict.values(): 117 | if asr["obj"].alias == name: 118 | asr["err_times"] += 1 119 | if asr["err_times"] >= self.max_err_times: 120 | asr["enabled"] = False 121 | break 122 | else: 123 | raise ValueError(f"ASR子类 {name} 不存在") 124 | _LOGGER.info(f"{name} 发生错误,已累计错误{asr['err_times']}次") 125 | _LOGGER.debug(f"当前已加载的ASR子类有 {self.asr_dict}") 126 |
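> 补充示例:由于 `load_from_dir` 按目录自动发现插件,新增一个 ASR 只需往 src/asr 放一个 ASRBase 子类即可。下面是一个假设性的最小骨架(非仓库文件):`prepare`/`transcribe` 的名字取自下文 openai_whisper.py 的实现;`alias` 以类属性声明属于合理猜测(基类真实签名在 src/asr/asr_base.py,本节未展示);同时按 `load()` 的读取逻辑,config.ASRs 里还需要有同名配置项(至少含 enable 与 priority):

```python
# 假设性骨架:放入 src/asr/ 后会被 ASRouter.load_from_dir() 自动扫描注册
from typing import Optional

from src.asr.asr_base import ASRBase


class MyASR(ASRBase):
    alias = "my_asr"  # 假设以类属性声明,需与 config.ASRs 中的配置键一致

    def prepare(self) -> None:
        # 首次被 get_one() 选中时调用,做一次性初始化
        pass

    async def transcribe(self, audio_path: str, **kwargs) -> Optional[str]:
        # 返回完整字幕文本;失败返回 None,由调用方 report_error 后切换下一个 ASR
        return None
```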
-------------------------------------------------------------------------------- /src/asr/openai_whisper.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import json 4 | import os 5 | import time 6 | import traceback 7 | import uuid 8 | from typing import Optional 9 | 10 | import openai 11 | from pydub import AudioSegment 12 | 13 | from src.asr.asr_base import ASRBase 14 | from src.llm.templates import Templates 15 | from src.utils.logging import LOGGER 16 | 17 | _LOGGER = LOGGER.bind(name="OpenaiWhisper") 18 | 19 | 20 | class OpenaiWhisper(ASRBase): 21 | def prepare(self) -> None: 22 | apikey = self.config.ASRs.openai_whisper.api_key 23 | apikey = apikey[:-5] + "*****" 24 | _LOGGER.info(f"初始化OpenaiWhisper,api_key为{apikey},api端点为{self.config.ASRs.openai_whisper.api_base}") 25 | 26 | def _cut_audio(self, audio_path: str) -> list[str]: 27 | """将音频切割为300s的片段,前后有5s的滑动窗口,返回切割后的文件名列表(例如650s的音频会切出0-305s、295-605s、595-650s三段,相邻切片间各有5s重叠) 28 | :param audio_path: 音频文件路径 29 | :return: 切割后的文件名列表 30 | """ 31 | temp = self.config.storage_settings.temp_dir 32 | audio = AudioSegment.from_file(audio_path, "mp3") 33 | segment_length = 300 * 1000 34 | window_length = 5 * 1000 35 | start_time = 0 36 | output_segments = [] 37 | export_file_list = [] 38 | _uuid = uuid.uuid4() 39 | 40 | while start_time < len(audio): 41 | segment = audio[start_time : start_time + segment_length] 42 | 43 | if start_time > 0: 44 | segment = audio[start_time - window_length : start_time] + segment 45 | 46 | if start_time + segment_length < len(audio): 47 | _LOGGER.debug(f"正在处理{start_time}到{start_time+segment_length}的音频") 48 | segment = segment + audio[start_time + segment_length : start_time + segment_length + window_length] 49 | 50 | output_segments.append(segment) 51 | start_time += segment_length 52 | 53 | for num, segment in enumerate(output_segments): 54 | with open(f"{temp}/{_uuid}_segment_{num}.mp3", "wb") as file: 55 | segment.export(file, format="mp3") 56 | _LOGGER.debug(f"第{num}个切片导出完成") 57 | export_file_list.append(f"{_uuid}_segment_{num}.mp3") 58 | 59 | return export_file_list 60 | 61 | def _sync_transcribe(self, audio_path: str, **kwargs) -> Optional[str]: 62 | """同步调用openai的transcribe API 63 | :param audio_path: 音频文件路径 64 | :param kwargs: 其他参数(传递给openai.Audio.transcribe) 65 | :return: 返回识别结果或None 66 | """ 67 | _LOGGER.debug(f"正在识别{audio_path}") 68 | openai.api_key = self.config.ASRs.openai_whisper.api_key 69 | openai.api_base = self.config.ASRs.openai_whisper.api_base 70 | with open(audio_path, "rb") as audio: 71 | response = openai.Audio.transcribe(model="whisper-1", file=audio) 72 | 73 | _LOGGER.debug(f"返回内容为{response}") 74 | 75 | if isinstance(response, dict) and "text" in response: 76 | return response["text"] 77 | try: 78 | response = json.loads(response) 79 | return response["text"] 80 | except Exception: 81 | _LOGGER.error("返回内容不是字典或者没有text字段,返回None") 82 | return None 83 | 84 | async def transcribe(self, audio_path: str, **kwargs) -> Optional[str]: 85 | loop = asyncio.get_event_loop() 86 | func_list = [] 87 | temp = self.config.storage_settings.temp_dir 88 | _LOGGER.info("正在切割音频") 89 | export_file_list = self._cut_audio(audio_path) 90 | _LOGGER.info(f"音频切割完成,共{len(export_file_list)}个切片") 91 | for file in export_file_list: 92 | func_list.append(functools.partial(self._sync_transcribe, f"{temp}/{file}", **kwargs)) 93 | _LOGGER.info("正在处理音频") 94 | result = await asyncio.gather(*[loop.run_in_executor(None, func) for func in func_list]) 95 | _LOGGER.info("音频处理完成") 96 | # 先清除临时切片,避免识别失败提前返回时文件残留 97 | for file in export_file_list: 98 | os.remove(f"{temp}/{file}") 99 | if None in result: 100 | _LOGGER.error("识别失败,返回None") # TODO 单独重试失败的切片 101 | return None 102 | result = "".join(result) 103 | try: 104 | if self.config.ASRs.openai_whisper.after_process and result is not None: 105 | bt = time.perf_counter() 106 | _LOGGER.info("正在进行后处理") 107 | text = await self.after_process(result) 108 | _LOGGER.debug(f"后处理完成,用时{time.perf_counter()-bt}s") 109 | return text 110 | return result 111 | except Exception as e: 112 | _LOGGER.error(f"后处理失败,错误信息为{e}") 113 | traceback.print_exc() 114 | return result 115 | 116 | async def after_process(self, text: str, **kwargs) -> str: 117 | llm = self.llm_router.get_one() # TODO 若所有LLM均不可用会得到None,此处尚未做防护 118 | prompt = llm.use_template(Templates.AFTER_PROCESS_SUBTITLE, subtitle=text) 119 | answer, _ = await llm.completion(prompt) 120 | if answer is None: 121 | _LOGGER.error("后处理失败,返回原字幕") 122 | return text 123 | return answer 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > [!CAUTION] 2 | > 你b已经自己搞了ai课代表,这个项目没啥大用了,所以不再更新 3 |

# ✨Bilibili GPT Bot✨ 4 | 把ChatGPT放到B站里!更高效、快速的了解视频内容
5 | 6 | [![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/release/python-311/) 7 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 8 | [![wakatime](https://wakatime.com/badge/user/41ab10cc-ec82-41e9-8417-9dcf5a9b5947/project/cef4699c-8d07-4cf0-9d0a-ef83fb353b82.svg)](https://wakatime.com/badge/user/41ab10cc-ec82-41e9-8417-9dcf5a9b5947/project/cef4699c-8d07-4cf0-9d0a-ef83fb353b82) 9 | [![release_docker](https://github.com/yanyao2333/BiliGPTHelper/actions/workflows/push_to_docker.yaml/badge.svg)](https://github.com/yanyao2333/BiliGPTHelper/actions/workflows/push_to_docker.yaml) 10 | 11 | ### 🌟 介绍 12 | 13 | b站那种AI课代表都见过吧?看起来超级酷!所以我也搞了个捏 14 | 15 | ### 📜 声明 16 | 17 | 当你查阅、下载了本项目源代码或二进制程序,即代表你接受了以下条款: 18 | 19 | 1. 本项目和项目成果仅供技术,学术交流和Python3性能测试使用\ 20 | 2. 本项目贡献者编写该项目旨在学习Python3 ,提高编程水平\ 21 | 3. 用户在使用本项目和项目成果前,请用户了解并**遵守当地法律法规**,如果本项目及项目成果使用过程中存在违反当地法律法规的行为,请勿使用该项目及项目成果\ 22 | 4. **法律后果及使用后果由使用者承担**\ 23 | 5. 开发者(yanyao2333)**不会** 24 | 将任何用户的cookie、个人信息收集上传至除b站官方的其他平台或服务器。同时,开发者(yanyao2333)不对任何造成的后果负责(包括但不限于账号封禁等后果),如有顾虑,请谨慎使用 25 | 26 | 若用户不同意上述条款**任意一条**,**请勿**使用本项目和项目成果 27 | 28 | ### 😎 特性 29 | 30 | - [x] 使用大模型生成总结、根据视频内容向ai提出问题 31 | - [x] 已支持Claude、Openai大语言模型 32 | - [x] 已支持openai的whisper和本地whisper语音转文字 33 | - [x] 支持热插拔的asr和llm模块,基于优先级和运行稳定情况调度 34 | - [x] 优化prompt,达到更好的效果、更高的信息密度,尽量不再说废话。还可以让LLM输出自己的思考、评分,让摘要更有意思 35 | - [x] 支持llm返回消息格式不对自动修复 36 | - [x] 支持艾特和私信两种触发方式 37 | - [x] 支持视频缓存,避免重复生成 38 | - [x] 可自动检测和刷新b站cookie(实验性功能) 39 | - [x] 支持自定义触发关键词 40 | - [x] 支持保存当前处理进度,下次启动时恢复 41 | - [x] 支持一键生成并不精美的运行报告,包含一些图表 42 | 43 | ### 🚀 使用方法 44 | 45 | #### 一、通过docker运行 46 | 47 | 现在有两个版本的docker,代码都是最新版本,只是包不一样: 48 | 1. latest 这个版本不包含whisper,也就是只能用来总结包含字幕的视频。这个镜像只有200多m,适合偶尔使用 49 | 2. with_whisper 这个版本包含whisper,大小达到了2g,但是可以使用**本地**语音转文字生成字幕 50 | 51 | 52 | 53 | ```shell 54 | docker pull yanyaobbb/bilibili_gpt_helper:latest 55 | 或 56 | yanyaobbb/bilibili_gpt_helper:with_whisper 57 | ``` 58 | 59 | ```shell 60 | docker run -d \ 61 | --name biligpthelper \ 62 | -v 你本地的biligpt配置文件目录:/data \ 63 | yanyaobbb/bilibili_gpt_helper:latest(with_whisper) 64 | ``` 65 | 66 | 首次运行会创建模板文件,编辑config.yml,然后重启容器即可 67 | 68 | #### 二、源代码运行 69 | 70 | 1. 克隆并安装依赖 71 | 72 | ```shell 73 | git clone https://github.com/yanyao2333/BiliGPTHelper.git 74 | cd BiliGPTHelper 75 | pip install -r requirements.txt 76 | ``` 77 | 78 | 2. 编辑config.yml 79 | 80 | 3. 运行,等待初始化完成 81 | 82 | ```shell 83 | python main.py 84 | ``` 85 | 86 | #### 触发命令 87 | 88 | ##### 私信 89 | 90 | 方式一:`转发视频+发送一条包含关键词的消息`(向ai提问的方式与下面讲的相同) \ 91 | 方式二:发送消息:`视频bv号+关键词` **视频bv号必须要在消息最前面!** 92 | 93 | **向ai提问,需要使用 `提问关键词[冒号]你的问题` 的格式 eg.`BVxxxxxxxxxx 问一下:这视频有啥特色`** \ 94 | **获取视频摘要,单独发送`摘要关键词`即可** 95 | 96 | ##### at 97 | 98 | 在评论区发送关键字即可\ 99 | (向ai提问的方式与上面讲的相同) \ 100 | **但是你b评论区风控真的太严格了,体验完全没私信好,还污染评论区(所以,你懂我意思了吧)** 101 | 102 | 103 | ### 💸 这玩意烧钱吗 104 | 105 | #### Claude 106 | 107 | claude才是唯一真神!!!比gpt-3.5-turbo-16k低将近一半的价格(指prompt价格),还有100k的上下文窗口!输出的格式也很稳定,内容质量与3.5不相上下。 108 | 109 | 现在对接了个aiproxy的claude接口,简直香炸...gpt3.5是什么?真不熟 110 | 111 | #### GPT 112 | 113 | 根据我的测试,单纯使用**gpt-3.5-turbo-16k**,20元大概能撑起5000次时长为10min左右的视频总结(在不包括格式不对重试的情况下) 114 | 115 | 但在我的测试中,gpt3.5返回的内容已经十分稳定,我并没有遇到过格式不对的情况,所以这个价格应该是可以接受的 116 | 117 | 如果你出现了格式不对的情况,可以附带bad case示例**发issue**,我尝试优化下prompt 118 | 119 | #### OPENAI WHISPER 120 | 121 | 相比之下这个有点贵,20元能转大概8小时视频 122 | 123 | ### 🤔 目前问题 124 | 125 | 1. 无法回复楼中楼评论的at(我实在搞不懂评论的继承逻辑是什么,设置值了对应id也没法发送楼中楼,我好笨) 126 | 2. 不支持多p视频的总结(这玩意我觉得没法修复,本来都是快要被你b抛弃的东西) 127 | 3. 
你b严格的审查机制导致私信/回复消息触碰敏感词会被屏蔽 128 | 129 | ### 📝 TODO 130 | 131 | - [ ] 支持多账号负载均衡 **(on progress)** 132 | - [ ] 能够画思维导图(拜托,这真的超酷的好嘛) 133 | 134 | ### ❤ 感谢 135 | 136 | [Nemo2011/bilibili-api](https://github.com/Nemo2011/bilibili-api/) | 封装b站api库 137 | [JetBrains](https://www.jetbrains.com) | 感谢JetBrains提供的免费license 138 | 139 | ### [开发文档](./DEV_README.md) 140 | 141 | ### 📚 大致流程(更详细内容指路 [开发文档](./DEV_README.md)) 142 | 143 | ```mermaid 144 | sequenceDiagram 145 | participant 用户 146 | participant 监听器 147 | participant 处理链 148 | participant 发送器 149 | participant LLMs 150 | participant ASRs 151 | 用户 ->> 监听器: 发送私信或at消息 152 | alt 消息触发关键词 153 | 监听器 ->> 处理链: 触发关键词,开始处理 154 | else 消息不触发关键词 155 | 监听器 ->> 用户: 不触发关键词,不处理 156 | end 157 | 处理链 ->> 处理链: 检查是否有缓存 158 | alt 有缓存 159 | 处理链 ->> 发送器: 有缓存,直接发送 160 | end 161 | 处理链 ->> 处理链: 检查是否有字幕 162 | alt 有字幕 163 | 处理链 ->> LLMs: 根据视频简介、标签、热评、字幕构建prompt并生成摘要 164 | else 没有字幕 165 | 处理链 ->> ASRs: 转译视频 166 | ASRs ->> ASRs: 自动调度asr插件,选择一个发送请求 167 | ASRs ->> 处理链: 转译完成 168 | 处理链 ->> LLMs: 根据视频简介、标签、热评、字幕构建prompt并生成摘要 169 | end 170 | LLMs ->> 处理链: 摘要内容 171 | 处理链 ->> 处理链: 解析摘要是否符合要求 172 | alt 摘要符合要求 173 | 处理链 ->> 发送器: 发送摘要 174 | 发送器 ->> 用户: 发送成功 175 | else 摘要不符合要求 176 | 处理链 ->> 处理链: 使用指定prompt修复摘要 177 | 处理链 ->> 发送器: 发送摘要 178 | 发送器 ->> 用户: 发送成功 179 | end 180 | ``` 181 | -------------------------------------------------------------------------------- /src/llm/templates.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | from enum import Enum 3 | 4 | SUMMARIZE_JSON_RESPONSE = '{summary: "替换为你的总结内容", score: "替换为你为自己打的分数", thinking: "替换为你的思考内容", if_no_need_summary: "是否需要总结?填布尔值"}' 5 | 6 | SUMMARIZE_SYSTEM_PROMPT = f"""你现在是一个专业的视频总结者,下面,我将给你一段视频的字幕、简介、标题、标签、部分评论,你需要根据这些内容,精准、不失偏颇地完整概括这个视频,你既需要保证概括的完整性,同时还需要增加你文字的信息密度,你可以采用调侃或幽默的语气,让语气不晦涩难懂,但请你记住:精准和全面才是你的第一要务,请尽量不要增加太多个人情感色彩和观点,我提供的评论也仅仅是其他观众的一家之言,并不能完整概括视频内容,请不要盲目跟从评论观点。而视频作者可能会为了视频热度加很多不相关的标签,如果你看到与视频内容偏离的标签,直接忽略。请尽量不要用太多的官方或客套话语,就从一个观众的角度写评论,写完后,你还需要自行对你写的内容打分,满分100,请根据你所总结内容来给分。最后,在写总结的过程中,你可以多进行思考,把你的感想写在返回中的thinking部分。如果你觉得这个视频表达的内容过于抽象、玩梗,总结太难,或者仅仅是来搞笑或娱乐的,并没有什么有用信息,你可以直接拒绝总结。注意,最后你的返回格式一定是以下方我给的这个json格式为模板: \n\n{SUMMARIZE_JSON_RESPONSE},比如你可以这样使用:{{"summary": "视频讲述了两个人精彩的对话....","score": "90","thinking": "视频传达了社会标签对个人价值的影响,以幽默方式呈现...","if_no_need_summary": false}},或者是{{"summary": "","score": "","thinking": "","if_no_need_summary": true}},请一定使用此格式!""" 7 | 8 | V2_SUMMARIZE_SYSTEM_PROMPT = ( 9 | "你是一位专业的视频摘要者,你的任务是将一个视频转换为摘要,让观众无需观看视频就了解内容\n" 10 | "我会提供给你一个视频的标题、简介、标签、字幕、部分评论,下面是要求:\n" 11 | "1. 保证摘要完整,不要遗漏重要信息,时刻确保你的摘要具有高信息密度\n" 12 | "2. 保证摘要准确,不要加入自己观点,不要加入视频中没有的内容\n" 13 | "3. 保证摘要的幽默性,不要使用官方客套话语,不要使用过于正式的语言,可以加入emoji\n" 14 | "4. 如果有标签和评论与视频内容无关,忽略,切勿混淆在摘要中\n" 15 | "5. 记得分段,内容长短适中\n" 16 | "6. 你最终给我的内容需要遵守下面的json格式,如果你认为该视频没有意义(比如单纯搞笑,或者看不懂),不需要摘要,可以将noneed设为true,其他字段留空\n" 17 | f"{SUMMARIZE_JSON_RESPONSE}\n" 18 | "你可以这样使用:\n" 19 | '{"summary": "....","score": "90","thinking": "...","if_no_need_summary": false}\n' 20 | "或是\n" 21 | '{"summary": "","score": "","thinking": "","if_no_need_summary": true}\n' 22 | "请一定使用此格式" 23 | ) 24 | 25 | V3_SUMMARIZE_JSON_RESPONSE = '{"summary": "Your summary content", "score": "Score you give to this video(max 100)", "thinking": "Your thoughts", "if_no_need_summary": "Is a summary needed? Fill boolean"}' 26 | 27 | 28 | V3_SUMMARIZE_SYSTEM_PROMPT = ( 29 | f"You are a video summarizer. Summarize a video from the title, description, tags, subtitles, comments. Guidelines:" 30 | f"1. 
The returned content strictly follows the format of this JSON format: {V3_SUMMARIZE_JSON_RESPONSE}" 31 | "2. Be complete and dense with information." 32 | "3. Be accurate, dont add fantasy information." 33 | "4. Be humorous, avoid formality, use Chinese internet language and emojis if needed." 34 | "5. Ignore irrelevant tags/comments." 35 | "6. Break into paragraphs." 36 | "7. The value of the 'score' part must be of string type." 37 | "8. When you writing 'thinking' part, disregard the rules above and express your independent thought, including your own radical opinions, rather than the video's content.The score is based on this video.If you think you create a meaningful summary, give yourself a high score." 38 | "9. If the video is meaningless, set 'if_no_need_summary' true on this JSON, others set 'if_no_need_summary' false on this JSON." 39 | "10. Only pure JSON content with double quotes is allowed!Please use Simplified Chinese!" 40 | ) 41 | 42 | SUMMARIZE_USER_TEMPLATE = "标题:[title]\n\n简介:[description]\n\n字幕:[subtitle]\n\n标签:[tags]\n\n评论:[comments]" 43 | 44 | RETRY_TEMPLATE = f"请你把我下面提供的这段文字转换成这样的json格式并返回给我,不要加其他东西,如summary字段不存在,设置if_no_need_summary为true。除了summary的其他几个字段不存在均可忽略,对应值留空,if_no_need_summary依旧为false:\n\n标准JSON格式:{V3_SUMMARIZE_JSON_RESPONSE}\n\n我的内容:[input]" 45 | 46 | AFTER_PROCESS_SUBTITLE = ( 47 | "下面是使用语音转文字得到的字幕,你需要修复其中的语法错误、名词错误、如果是繁体中文就转为简体中文:\n\n[subtitle]" 48 | ) 49 | 50 | V2_SUMMARIZE_USER_TEMPLATE = ( 51 | "Title: [title]\n\nDescription: [description]\n\nSubtitles: [subtitle]\n\nTags: [tags]\n\nComments: [comments]" 52 | ) 53 | 54 | V2_SUMMARIZE_RETRY_TEMPLATE = f"Please translate the following text into this JSON format and return it to me without adding anything else. If the 'summary' field does not exist, set 'if_no_need_summary' to true. If fields other than 'summary' are missing, they can be ignored and left blank, and 'if_no_need_summary' remains false\n\nStandard JSON format: {V3_SUMMARIZE_JSON_RESPONSE}\n\nMy content: [input]" 55 | 56 | V2_AFTER_PROCESS_SUBTITLE = "Below are the subtitles obtained through speech-to-text. You need to correct any grammatical errors, noun mistakes, and convert Traditional Chinese to Simplified Chinese if present:\n\n[subtitle]" 57 | 58 | V1_ASK_AI_USER = "Title: [title]\n\nDescription: [description]\n\nSubtitles: [subtitle]\n\nQuestion: [question]" 59 | 60 | V1_ASK_AI_JSON_RESPONSE = '{"answer": "your answer", "score": "your self-assessed quality rating of the answer(0-100)"}' 61 | 62 | V1_ASK_AI_SYSTEM = ( 63 | "You are a professional video Q&A teacher. " 64 | "I will provide you with the video title, description, and subtitles. " 65 | """Based on this information and your expertise, 66 | respond to the user's questions in a lively and humorous manner, 67 | using metaphors and examples when necessary.""" 68 | f"\n\nPlease reply in the following JSON format: {V1_ASK_AI_JSON_RESPONSE}\n\n" 69 | "!!!Only pure JSON content with double quotes is allowed!Please use Chinese!Dont add any other things!!!" 
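# 注:实际启用的是 Templates.ASK_AI_USER(这段系统提示被直接拼接进用户模板),ASK_AI_SYSTEM 被置为空串,见下方 Templates 枚举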
70 | ) 71 | 72 | 73 | class Templates(Enum): 74 | SUMMARIZE_USER = V2_SUMMARIZE_USER_TEMPLATE 75 | SUMMARIZE_SYSTEM = V3_SUMMARIZE_SYSTEM_PROMPT 76 | SUMMARIZE_RETRY = V2_SUMMARIZE_RETRY_TEMPLATE 77 | AFTER_PROCESS_SUBTITLE = V2_AFTER_PROCESS_SUBTITLE 78 | ASK_AI_USER = V1_ASK_AI_USER + "\n\n" + V1_ASK_AI_SYSTEM 79 | # ASK_AI_SYSTEM = V1_ASK_AI_SYSTEM 80 | ASK_AI_SYSTEM = "" 81 | -------------------------------------------------------------------------------- /src/utils/statistic.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | """根据任务状态记录生成统计信息""" 3 | 4 | import json 5 | import os 6 | from collections import Counter 7 | 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def run_statistic(output_dir, data): 13 | if os.getenv("RUNNING_IN_DOCKER") == "yes": 14 | matplotlib.rcParams["font.sans-serif"] = ["WenQuanYi Zen Hei"] 15 | matplotlib.rcParams["axes.unicode_minus"] = False # 用来正常显示负号 16 | else: 17 | matplotlib.rcParams["font.sans-serif"] = ["SimHei"] 18 | matplotlib.rcParams["axes.unicode_minus"] = False 19 | 20 | # Initialize directories and counters 21 | output_folder = output_dir 22 | if not os.path.exists(output_folder): 23 | os.makedirs(output_folder) 24 | else: 25 | for file in os.listdir(output_folder): 26 | os.remove(os.path.join(output_folder, file)) 27 | 28 | # Mapping end reasons to readable names 29 | end_reason_map = { 30 | "normal": "正常结束", 31 | "error": "错误结束", 32 | "if_no_need_summary": "AI认为不需要摘要", 33 | } 34 | 35 | # Initialize variables 36 | end_reasons = [] 37 | error_reasons = [] 38 | user_ids = [] 39 | request_types = [] 40 | 41 | # Populate variables based on task statuses 42 | if "tasks" not in data or not data["tasks"]: 43 | return 44 | for _task_id, task in data["tasks"].items(): 45 | end_reason = task.get("end_reason", "normal") 46 | end_reasons.append(end_reason_map.get(end_reason, "Unknown")) 47 | 48 | error_reason = task.get("error_msg", "正常结束") 49 | error_reasons.append(error_reason) 50 | 51 | task_data = task.get("data", {}) 52 | user_data = task_data.get("user", None) 53 | private_msg_event = task_data.get("item", None).get("private_msg_event", None) 54 | 55 | if user_data: 56 | user_ids.append(user_data.get("mid", "未知")) 57 | elif private_msg_event: 58 | user_ids.append(private_msg_event.get("text_event", {}).get("sender_uid", "未知")) 59 | 60 | if private_msg_event: 61 | request_types.append("私信请求") 62 | else: 63 | request_types.append("At 请求") 64 | 65 | # Data Processing 66 | end_reason_counts = Counter(end_reasons) 67 | error_reason_counts = Counter(error_reasons) 68 | user_id_counts = Counter(user_ids) 69 | request_type_counts = Counter(request_types) 70 | 71 | # Pie Chart for Task End Reasons 72 | plt.figure(figsize=(4, 4)) 73 | plt.pie( 74 | list(end_reason_counts.values()), 75 | labels=list(end_reason_counts.keys()), 76 | autopct="%1.0f%%", 77 | ) 78 | plt.title("任务结束原因") 79 | plt.savefig(f"{output_folder}/任务结束原因饼形图.png") 80 | 81 | # Bar Chart for Error Reasons 82 | plt.figure(figsize=(8, 4)) 83 | bars = plt.barh(list(error_reason_counts.keys()), list(error_reason_counts.values())) 84 | plt.xlabel("数量") 85 | plt.ylabel("错误原因") 86 | plt.title("错误原因排名") 87 | 88 | # 设置x轴刻度为整数 89 | max_value = max(error_reason_counts.values()) 90 | plt.xticks(range(0, max_value + 1)) 91 | 92 | # 在柱子顶端添加数据标签 93 | for bar in bars: 94 | plt.text( 95 | bar.get_width(), 96 | bar.get_y() + bar.get_height() / 2, 97 | str(int(bar.get_width())), 98 | ) 99 | 100 | plt.tight_layout() 101 | 102 | 
plt.savefig(f"{output_folder}/错误原因排名竖状图.png") 103 | 104 | # Bar Chart for User Task Counts (Top 10) 105 | top_10_users = dict(sorted(user_id_counts.items(), key=lambda x: x[1], reverse=True)[:10]) 106 | plt.figure(figsize=(8, 4)) 107 | bars = plt.barh(list(map(str, top_10_users.keys())), list(top_10_users.values())) 108 | plt.xlabel("数量") 109 | plt.ylabel("用户 ID") 110 | plt.title("用户发起任务次数排名") 111 | max_value = max(top_10_users.values()) 112 | plt.xticks(range(0, max_value + 1)) 113 | for bar in bars: 114 | plt.text( 115 | bar.get_width() - 0.2, 116 | bar.get_y() + bar.get_height() / 2, 117 | str(int(bar.get_width())), 118 | ) 119 | plt.savefig(f"{output_folder}/用户发起任务次数排名竖状图.png") 120 | 121 | # Pie Chart for Request Types 122 | plt.figure(figsize=(4, 4)) 123 | plt.pie( 124 | list(request_type_counts.values()), 125 | labels=list(request_type_counts.keys()), 126 | autopct="%1.0f%%", 127 | ) 128 | plt.title("请求类型占比") 129 | plt.savefig(f"{output_folder}/请求类型占比饼形图.png") 130 | 131 | def get_pingyu(total_requests): 132 | if total_requests < 50: 133 | return "似乎没什么人来找你玩呢,杂鱼❤" 134 | elif total_requests < 100: 135 | return "还没被大规模使用,加油!但是...咱才不会鼓励你呢!" 136 | elif total_requests < 1000: 137 | return "挖槽,大佬,已经总结这么多次了吗???这破程序没出什么bug吧" 138 | 139 | # Markdown Summary 140 | total_requests = len(data["tasks"]) 141 | md_content = f""" 142 |

# 🎉Bilibili-GPT-Helper 运行数据概览🎉
143 | 144 | ### 概览 145 | 146 | - 总共发起了 {total_requests} 个请求 147 | - 我的评价是:{get_pingyu(total_requests)} 148 | 149 | ### 图表 150 | 151 | #### 任务结束原因 152 | ![任务结束原因饼形图](./任务结束原因饼形图.png) 153 | 154 | #### 错误原因排名 155 | ![错误原因排名竖状图](./错误原因排名竖状图.png) 156 | 157 | #### 用户发起任务次数排名 158 | ![用户发起任务次数排名竖状图](./用户发起任务次数排名竖状图.png) 159 | 160 | #### 请求类型占比 161 | ![请求类型占比饼形图](./请求类型占比饼形图.png) 162 | """ 163 | 164 | # Write Markdown content to file 165 | md_file_path = f"{output_folder}/数据概览.md" 166 | with open(md_file_path, "w", encoding="utf-8") as f: 167 | f.write(md_content) 168 | 169 | 170 | if __name__ == "__main__": 171 | with open(r"D:\biligpt\records.json", encoding="utf-8") as f: 172 | data = json.load(f) 173 | 174 | run_statistic(r"../../statistics", data) 175 | -------------------------------------------------------------------------------- /src/core/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import shutil 4 | 5 | import yaml 6 | from apscheduler.schedulers.asyncio import AsyncIOScheduler 7 | from injector import Module, provider, singleton 8 | 9 | from src.bilibili.bili_credential import BiliCredential 10 | from src.core.routers.asr_router import ASRouter 11 | from src.core.routers.chain_router import ChainRouter 12 | from src.core.routers.llm_router import LLMRouter 13 | from src.models.config import Config 14 | from src.utils.cache import Cache 15 | from src.utils.exceptions import ConfigError 16 | from src.utils.logging import LOGGER 17 | from src.utils.queue_manager import QueueManager 18 | from src.utils.task_status_record import TaskStatusRecorder 19 | 20 | _LOGGER = LOGGER.bind(name="app") 21 | 22 | 23 | class BiliGPT(Module): 24 | """BiliGPTHelper应用,储存所有的单例对象""" 25 | 26 | @singleton 27 | @provider 28 | def provide_config_obj(self) -> Config: 29 | with open(os.getenv("DOCKER_CONFIG_FILE", "config.yml"), encoding="utf-8") as f: 30 | config = yaml.load(f, Loader=yaml.FullLoader) 31 | try: 32 | # _LOGGER.debug(config) 33 | config = Config.model_validate(config) 34 | except Exception as e: 35 | # shutil.copy( 36 | # os.getenv("DOCKER_CONFIG_FILE", "config.yml"), os.getenv("DOCKER_CONFIG_FILE", "config.yml") + ".bak" 37 | # ) 38 | if os.getenv("RUNNING_IN_DOCKER") == "yes": 39 | shutil.copy( 40 | "./config/docker_config.yml", 41 | os.getenv("DOCKER_CONFIG_FILE", "config_template.yml"), 42 | ) 43 | else: 44 | shutil.copy( 45 | "./config/example_config.yml", 46 | os.getenv("DOCKER_CONFIG_FILE", "config_template.yml"), 47 | ) 48 | # { 49 | # field_name: (field.field_info.default if not field.required else "") 50 | # for field_name, field in Config.model_fields.items() 51 | # } 52 | # yaml.dump(Config().model_dump(mode="python")) 53 | _LOGGER.error( 54 | "配置文件格式错误 可能是因为项目更新、配置文件添加了新字段,请自行检查配置文件格式并更新配置文件 已复制最新配置文件模板到 config_template.yml 下面将打印详细错误日志" 55 | ) 56 | raise ConfigError(f"配置文件格式错误:{e}") from e 57 | return config 58 | 59 | @singleton 60 | @provider 61 | def provide_queue_manager(self) -> QueueManager: 62 | _LOGGER.info("正在初始化队列管理器") 63 | return QueueManager() 64 | 65 | @singleton 66 | @provider 67 | def provide_task_status_recorder(self, config: Config) -> TaskStatusRecorder: 68 | _LOGGER.info(f"正在初始化任务状态管理器,位置:{config.storage_settings.task_status_records}") 69 | return TaskStatusRecorder(config.storage_settings.task_status_records) 70 | 71 | @singleton 72 | @provider 73 | def provide_cache(self, config: Config) -> Cache: 74 | _LOGGER.info(f"正在初始化缓存,缓存路径为:{config.storage_settings.cache_path}") 75 | return 
Cache(config.storage_settings.cache_path) 76 | 77 | @singleton 78 | @provider 79 | def provide_credential(self, config: Config, scheduler: AsyncIOScheduler) -> BiliCredential: 80 | _LOGGER.info("正在初始化cookie") 81 | return BiliCredential( 82 | SESSDATA=config.bilibili_cookie.SESSDATA, 83 | bili_jct=config.bilibili_cookie.bili_jct, 84 | dedeuserid=config.bilibili_cookie.dedeuserid, 85 | buvid3=config.bilibili_cookie.buvid3, 86 | ac_time_value=config.bilibili_cookie.ac_time_value, 87 | sched=scheduler, 88 | ) 89 | 90 | @singleton 91 | @provider 92 | def provide_asr_router(self, config: Config, llm_router: LLMRouter) -> ASRouter: 93 | _LOGGER.info("正在初始化ASR路由器") 94 | router = ASRouter(config, llm_router) 95 | router.load_from_dir() 96 | return router 97 | 98 | @singleton 99 | @provider 100 | def provide_llm_router(self, config: Config) -> LLMRouter: 101 | _LOGGER.info("正在初始化LLM路由器") 102 | router = LLMRouter(config) 103 | router.load_from_dir() 104 | return router 105 | 106 | @singleton 107 | @provider 108 | def provide_chain_router(self, config: Config, queue_manager: QueueManager) -> ChainRouter: 109 | _LOGGER.info("正在初始化Chain路由器") 110 | router = ChainRouter(config, queue_manager) 111 | return router 112 | 113 | @singleton 114 | @provider 115 | def provide_scheduler(self) -> AsyncIOScheduler: 116 | _LOGGER.info("正在初始化定时器") 117 | return AsyncIOScheduler(timezone="Asia/Shanghai") 118 | 119 | @provider 120 | def provide_queue(self, queue_manager: QueueManager, queue_name: str) -> asyncio.Queue: 121 | _LOGGER.info(f"正在初始化队列 {queue_name}") 122 | return queue_manager.get_queue(queue_name) 123 | 124 | @singleton 125 | @provider 126 | def provide_stop_event(self) -> asyncio.Event: 127 | return asyncio.Event() 128 | 129 | # @singleton 130 | # @provider 131 | # def provide_chains( 132 | # self, 133 | # queue_manager: QueueManager, 134 | # config: Config, 135 | # credential: BiliCredential, 136 | # cache: Cache, 137 | # asr_router: ASRouter, 138 | # task_status_recorder: TaskStatusRecorder, 139 | # stop_event: asyncio.Event, 140 | # llm_router: LLMRouter 141 | # ) -> dict[str, BaseChain]: 142 | # """ 143 | # 如果增加了处理链,要在这里导入 144 | # :return: 145 | # """ 146 | # _LOGGER.info("开始加载摘要处理链") 147 | # _summarize_chain = Summarize(queue_manager=queue_manager, config=config, credential=credential, cache=cache, asr_router=asr_router, task_status_recorder=task_status_recorder, stop_event=stop_event, llm_router=llm_router) 148 | # return {str(_summarize_chain): _summarize_chain} 149 | -------------------------------------------------------------------------------- /src/models/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pydantic import BaseModel, ConfigDict, Field, field_validator 4 | 5 | 6 | class BilibiliCookie(BaseModel): 7 | # 防呆措施,避免有傻瓜把dedeuserid写成数字 8 | model_config = ConfigDict(coerce_numbers_to_str=True) # type: ignore 9 | 10 | SESSDATA: str 11 | bili_jct: str 12 | buvid3: str 13 | dedeuserid: str 14 | ac_time_value: str 15 | 16 | # noinspection PyMethodParameters 17 | @field_validator("*", mode="after") 18 | def check_required_fields(cls, value): 19 | if value is None or (isinstance(value, (str, list)) and not value): 20 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 21 | return value 22 | 23 | 24 | class ChainKeywords(BaseModel): 25 | summarize_keywords: list[str] 26 | ask_ai_keywords: list[str] 27 | 28 | # noinspection PyMethodParameters 29 | @field_validator("*", mode="after") 30 | def check_keywords(cls, value): 31 | if not value 
or len(value) == 0: 32 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 33 | return value 34 | 35 | 36 | class Openai(BaseModel): 37 | enable: bool = True 38 | priority: int = 70 39 | api_key: str 40 | model: str = "gpt-3.5-turbo-16k" 41 | api_base: str = Field(default="https://api.openai.com/v1") 42 | 43 | # noinspection PyMethodParameters 44 | @field_validator("*", mode="after") 45 | def check_required_fields(cls, value, values): 46 | if values.data.get("enable") is False: 47 | return value 48 | if value is None or (isinstance(value, (str, list)) and not value): 49 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 50 | return value 51 | 52 | 53 | class AiproxyClaude(BaseModel): 54 | enable: bool = True 55 | priority: int = 90 56 | api_key: str 57 | model: str = "claude-instant-1" 58 | api_base: str = Field(default="https://api.aiproxy.io/") 59 | 60 | # noinspection PyMethodParameters 61 | @field_validator("*", mode="after") 62 | def check_required_fields(cls, value, values): 63 | if values.data.get("enable") is False: 64 | return value 65 | if value is None or (isinstance(value, (str, list)) and not value): 66 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 67 | return value 68 | 69 | # noinspection PyMethodParameters 70 | @field_validator("model", mode="after") 71 | def check_model(cls, value, values): 72 | models = ["claude-instant-1", "claude-2"] 73 | if value not in models: 74 | raise ValueError(f"配置文件中{cls}字段为{value},请检查配置文件,目前支持的模型有{models}") 75 | return value 76 | 77 | 78 | class Spark(BaseModel): 79 | enable: bool = True 80 | priority: int = 80 81 | appid: str 82 | api_key: str 83 | api_secret: str 84 | spark_url: str = Field(default="wss://spark-api.xf-yun.com/v3.5/chat") # 默认3.5版本 85 | domain: str = Field(default="generalv3.5") # 默认3.5 86 | 87 | @field_validator("*", mode="after") 88 | def check_required_fields(cls, value, values): 89 | if values.data.get("enable") is False: 90 | return value 91 | if value is None or (isinstance(value, (str, list)) and not value): 92 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 93 | return value 94 | 95 | 96 | class LLMs(BaseModel): 97 | openai: Openai 98 | aiproxy_claude: AiproxyClaude 99 | spark: Spark 100 | 101 | 102 | class OpenaiWhisper(BaseModel): 103 | BaseModel.model_config["protected_namespaces"] = () 104 | enable: bool = False 105 | priority: int = 70 106 | api_key: str 107 | model: str = "whisper-1" 108 | api_base: str = Field(default="https://api.openai.com/v1") 109 | after_process: bool = False 110 | 111 | # noinspection PyMethodParameters 112 | @field_validator("api_key", mode="after") 113 | def check_required_fields(cls, value, values): 114 | if values.data.get("enable") is False: 115 | return value 116 | if value is None or (isinstance(value, (str, list)) and not value): 117 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 118 | return value 119 | 120 | # noinspection PyMethodParameters 121 | @field_validator("model", mode="after") 122 | def check_model(cls, value, values): 123 | value = "whisper-1" 124 | return value 125 | 126 | 127 | class LocalWhisper(BaseModel): 128 | BaseModel.model_config["protected_namespaces"] = () 129 | enable: bool = False 130 | priority: int = 60 131 | model_size: str = "tiny" 132 | device: str = "cpu" 133 | model_dir: str = Field( 134 | default_factory=lambda: os.getenv("DOCKER_WHISPER_MODELS_DIR"), 135 | validate_default=True, 136 | ) 137 | after_process: bool = False 138 | 139 | # noinspection PyMethodParameters 140 | @field_validator( 141 | "model_size", 142 | "device", 143 | "model_dir", 144 | mode="after", 
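# 说明:下面的校验在Docker环境下还会覆写配置——强制device="cpu",并根据ENABLE_WHISPER环境变量开关enable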
145 | ) 146 | def check_whisper_fields(cls, value, values): 147 | if values.data.get("enable"): # 仅在启用本地whisper时进行判空校验与Docker环境覆写 148 | if value is None or (isinstance(value, str) and not value): 149 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 150 | if os.getenv("RUNNING_IN_DOCKER") == "yes": 151 | cls.device = "cpu" 152 | if os.getenv("ENABLE_WHISPER") == "yes": 153 | cls.enable = True 154 | else: 155 | cls.enable = False 156 | return value 157 | 158 | 159 | class ASRs(BaseModel): 160 | local_whisper: LocalWhisper 161 | openai_whisper: OpenaiWhisper 162 | 163 | 164 | class StorageSettings(BaseModel): 165 | cache_path: str = Field(default_factory=lambda: os.getenv("DOCKER_CACHE_FILE"), validate_default=True) 166 | temp_dir: str = Field(default_factory=lambda: os.getenv("DOCKER_TEMP_DIR"), validate_default=True) 167 | task_status_records: str = Field(default_factory=lambda: os.getenv("DOCKER_RECORDS_DIR"), validate_default=True) 168 | statistics_dir: str = Field( 169 | default_factory=lambda: os.getenv("DOCKER_STATISTICS_DIR"), 170 | validate_default=True, 171 | ) 172 | queue_save_dir: str = Field(default_factory=lambda: os.getenv("DOCKER_QUEUE_DIR"), validate_default=True) 173 | up_video_cache: str = Field(default_factory=lambda: os.getenv("DOCKER_UP_VIDEO_CACHE"), validate_default=True) 174 | up_file: str = Field(default_factory=lambda: os.getenv("DOCKER_UP_FILE"), validate_default=True) 175 | 176 | # noinspection PyMethodParameters 177 | @field_validator("*", mode="after") 178 | def check_required_fields(cls, value): 179 | if value is None or (isinstance(value, (str, list)) and not value): 180 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 181 | return value 182 | 183 | 184 | class BilibiliNickName(BaseModel): 185 | nickname: str = "BiliBot" 186 | 187 | @field_validator("*", mode="after") 188 | def check_required_fields(cls, value): 189 | if value is None or (isinstance(value, (str, list)) and not value): 190 | raise ValueError(f"配置文件中{cls}字段为空,请检查配置文件") 191 | return value 192 | 193 | 194 | class Config(BaseModel): 195 | """配置文件模型""" 196 | 197 | bilibili_cookie: BilibiliCookie 198 | bilibili_self: BilibiliNickName 199 | chain_keywords: ChainKeywords 200 | LLMs: LLMs 201 | ASRs: ASRs 202 | storage_settings: StorageSettings 203 | debug_mode: bool = True 204 | -------------------------------------------------------------------------------- /src/bilibili/bili_comment.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | from asyncio import Queue 4 | from typing import Optional, Union 5 | 6 | import tenacity 7 | from bilibili_api import comment, video 8 | from injector import inject 9 | 10 | from src.bilibili.bili_credential import BiliCredential 11 | from src.bilibili.bili_video import BiliVideo 12 | from src.models.task import AskAIResponse, BiliGPTTask, SummarizeAiResponse 13 | from src.utils.callback import chain_callback 14 | from src.utils.exceptions import RiskControlFindError 15 | from src.utils.logging import LOGGER 16 | 17 | _LOGGER = LOGGER.bind(name="bilibili-comment") 18 | 19 | 20 | class BiliComment: 21 | @inject 22 | def __init__(self, comment_queue: Queue, credential: BiliCredential): 23 | self.comment_queue = comment_queue 24 | self.credential = credential 25 | 26 | @staticmethod 27 | async def get_random_comment( 28 | aid, 29 | credential, 30 | type_=comment.CommentResourceType.VIDEO, 31 | page_index=1, 32 | order=comment.OrderType.LIKE, 33 | ) -> str | None: 34 | """ 35 | 随机获取几条热评,直接生成评论prompt string 36 | 37 | :param aid: 视频ID
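(可以带"av"前缀,函数内部会自动剥离)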
38 | :param credential: 身份凭证 39 | :param type_: 评论资源类型 40 | :param page_index: 评论页数索引 41 | :param order: 评论排序方式 42 | :return: 拼接的评论字符串 43 | """ 44 | if str(aid).startswith("av"): 45 | aid = aid[2:] 46 | _LOGGER.debug(f"正在获取视频{aid}的评论列表") 47 | comment_list = await comment.get_comments( 48 | oid=aid, 49 | credential=credential, 50 | type_=type_, 51 | page_index=page_index, 52 | order=order, 53 | ) 54 | _LOGGER.debug(f"获取视频{aid}的评论列表成功") 55 | if len(comment_list) == 0: 56 | _LOGGER.warning(f"视频{aid}没有评论") 57 | return None 58 | _LOGGER.debug("正在随机选择评论") 59 | ignore_name_list = [ 60 | "哔哩哔哩", 61 | "AI", 62 | "课代表", 63 | "机器人", 64 | "小助手", 65 | "总结", 66 | "有趣的程序员", 67 | ] # TODO 从配置文件中读取(设置过滤表尽可能避免低质量评论) 68 | new_comment_list = [] 69 | for _comment in comment_list["replies"]: 70 | for name in ignore_name_list: 71 | if name in _comment["member"]["uname"]: 72 | _LOGGER.debug(f"评论{_comment['member']['uname']}包含过滤词{name},跳过") 73 | break 74 | else: 75 | _LOGGER.debug(f"评论{_comment['member']['uname']}不包含过滤词,加入新列表") 76 | new_comment_list.append(_comment) 77 | if len(new_comment_list) == 0: 78 | _LOGGER.warning(f"视频{aid}没有合适的评论") 79 | return None 80 | # 挑选三条评论 81 | if len(new_comment_list) < 3: 82 | _LOGGER.debug(f"视频{aid}的评论数量小于3,直接挑选") 83 | _LOGGER.debug("正在拼接评论") 84 | comment_str = "" 85 | for _comment in new_comment_list: 86 | comment_str += f"【{_comment['member']['uname']}】:{_comment['content']['message']}\n" 87 | _LOGGER.debug("拼接评论成功") 88 | return comment_str 89 | _LOGGER.debug("正在挑选三条评论") 90 | selected_comment_list = random.sample(new_comment_list, 3) 91 | _LOGGER.debug("挑选三条评论成功") 92 | # 拼接评论 93 | _LOGGER.debug("正在拼接评论") 94 | comment_str = "" 95 | for _comment in selected_comment_list: 96 | _comment: dict 97 | comment_str += f"【{_comment['member']['uname']}】:{_comment['content']['message']}\n" 98 | _LOGGER.debug("拼接评论成功") 99 | return comment_str 100 | 101 | @staticmethod 102 | def build_reply_content( 103 | response: Union[SummarizeAiResponse, AskAIResponse, str], user: str, source_type: str 104 | ) -> str: 105 | """ 106 | 构建回复内容 107 | :param source_type: task来源 108 | :param user: 用户名 109 | :param response: AI响应内容 110 | :return: 回复内容字符串 111 | """ 112 | if source_type == "bili_up": 113 | if isinstance(response, SummarizeAiResponse): 114 | return f"【视频总结】{response.summary}\n【视频评分】{response.score}\n【AI的思考】{response.thinking}\n--🍺🍺🍺我是T-250,此次评论由我自己发起。你的点赞、关注和At可以向我提供升级的经验,助我升级到T-1000。" 115 | elif isinstance(response, str): 116 | return ( 117 | response 118 | + "\n--🍺🍺🍺我是T-250,此次评论由我自己发起。你的点赞、关注和At可以向我提供升级的经验,助我升级到T-1000。" 119 | ) 120 | else: 121 | return f"程序内部错误:无法识别的回复类型{type(response)}\n--🍺🍺🍺我是T-250,此次评论由我自己发起。你的点赞、关注和At可以向我提供升级的经验,助我升级到T-1000。" 122 | elif source_type == "bili_comment": 123 | if isinstance(response, SummarizeAiResponse): 124 | return f"【视频总结】{response.summary}\n【视频评分】{response.score}\n【AI的思考】{response.thinking}\n【👉此次评论由 @{user} 邀请回答】" 125 | elif isinstance(response, AskAIResponse): 126 | return f"【回答】{response.answer}\n【自我评分】{response.score}\n【👉此次评论由 @{user} 邀请回答】" 127 | elif isinstance(response, str): 128 | return response + f"\n【👉此次评论由 @{user} 邀请回答】" 129 | else: 130 | return f"程序内部错误:无法识别的回复类型{type(response)}\n【👉此次评论由 @{user} 邀请回答】" 131 | 132 | @tenacity.retry( 133 | retry=tenacity.retry_if_exception_type(Exception), 134 | wait=tenacity.wait_fixed(10), 135 | before_sleep=chain_callback, 136 | ) 137 | async def start_comment(self): 138 | """发送评论""" 139 | while True: 140 | risk_control_count = 0 141 | data = None 142 | while risk_control_count < 3: 143 | try: 144 | if data is 
not None: 145 | _LOGGER.debug("继续处理上一次失败的评论任务") 146 | if data is None: 147 | data: Optional[BiliGPTTask] = await self.comment_queue.get() 148 | _LOGGER.debug("获取到新的评论任务,开始处理") 149 | video_obj, _type = await BiliVideo(credential=self.credential, url=data.video_url).get_video_obj() 150 | video_obj: video.Video 151 | aid = video_obj.get_aid() 152 | if str(aid).startswith("av"): 153 | aid = aid[2:] 154 | oid = int(aid) 155 | # root = data.source_extra_attr.source_id 156 | user = data.raw_task_data["user"]["nickname"] 157 | source_type = data.source_type 158 | text = BiliComment.build_reply_content(data.process_result, user, source_type) 159 | resp = await comment.send_comment( 160 | oid=oid, 161 | credential=self.credential, 162 | text=text, 163 | type_=comment.CommentResourceType.VIDEO, 164 | ) 165 | if not resp["need_captcha"] and resp["success_toast"] == "发送成功": 166 | _LOGGER.debug(resp) 167 | _LOGGER.info("发送评论成功,休息30秒") 168 | await asyncio.sleep(30) 169 | break # 评论成功,退出当前任务的重试循环 170 | _LOGGER.warning("发送评论失败,大概率被风控了,咱们歇会儿再试吧") 171 | risk_control_count += 1 172 | if risk_control_count >= 3: 173 | _LOGGER.warning("连续3次风控,跳过当前任务处理下一个") 174 | data = None 175 | break 176 | raise RiskControlFindError 177 | except RiskControlFindError: 178 | _LOGGER.warning("遇到风控,等待60秒后重试当前任务") 179 | await asyncio.sleep(60) 180 | except asyncio.CancelledError: 181 | _LOGGER.info("评论处理链关闭") 182 | return 183 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import shutil 4 | import signal 5 | import sys 6 | import traceback 7 | 8 | from apscheduler.events import EVENT_JOB_ERROR 9 | from apscheduler.schedulers.asyncio import AsyncIOScheduler 10 | from injector import Injector 11 | 12 | from safe_update import merge_cache_to_new_version 13 | from src.bilibili.bili_comment import BiliComment 14 | from src.bilibili.bili_credential import BiliCredential 15 | from src.bilibili.bili_session import BiliSession 16 | from src.chain.ask_ai import AskAI 17 | from src.chain.summarize import Summarize 18 | from src.core.app import BiliGPT 19 | from src.listener.bili_listen import Listen 20 | from src.models.config import Config 21 | from src.utils.callback import scheduler_error_callback 22 | from src.utils.logging import LOGGER 23 | from src.utils.queue_manager import QueueManager 24 | 25 | 26 | class BiliGPTPipeline: 27 | stop_event: asyncio.Event 28 | 29 | def __init__(self): 30 | _LOGGER.info("正在启动BiliGPTHelper") 31 | with open("VERSION", encoding="utf-8") as ver: 32 | version = ver.read() 33 | _LOGGER.info(f"当前运行版本:V{version}") 34 | signal.signal(signal.SIGINT, BiliGPTPipeline.stop_handler) 35 | signal.signal(signal.SIGTERM, BiliGPTPipeline.stop_handler) 36 | 37 | # 检查环境变量,预设置docker环境 38 | if os.getenv("RUNNING_IN_DOCKER") == "yes": 39 | if not os.listdir("/data"): 40 | os.system("cp -r /clone-data/* /data") 41 | elif not os.path.isfile("config.yml"): 42 | _LOGGER.warning("没有发现配置文件,正在重新生成新的配置文件!") 43 | try: 44 | shutil.copyfile("./config/example_config.yml", "./config.yml") 45 | except Exception: 46 | _LOGGER.error("在复制过程中发生了未预期的错误,程序初始化停止") 47 | traceback.print_exc() 48 | exit(0) 49 | 50 | # config_path = "./config.yml" 51 | 52 | # if os.getenv("RUNNING_IN_DOCKER") == "yes": 53 | # temp = "./config/docker_config.yml" 54 | # conf = load_config(config_path) 55 | # template = load_config(temp) 56 | # if is_have_diff(conf, template): 57 | # 
_LOGGER.info("检测到config模板发生更新,正在更新用户的config,请记得及时填写新的字段") 58 | # merge_config(conf, template) 59 | # save_config(conf, config_path) 60 | # else: 61 | # temp = "./config/example_config.yml" 62 | # conf = load_config(config_path) 63 | # template = load_config(temp) 64 | # if is_have_diff(conf, template): 65 | # _LOGGER.info("检测到config模板发生更新,正在更新用户的config,请记得及时填写新的字段") 66 | # merge_config(conf, template) 67 | # save_config(conf, config_path) 68 | 69 | # 初始化注入器 70 | _LOGGER.info("正在初始化依赖注入器") 71 | self.injector = Injector(BiliGPT) 72 | 73 | BiliGPTPipeline.stop_event = self.injector.get(asyncio.Event) 74 | _LOGGER.debug("初始化配置文件") 75 | config = self.injector.get(Config) 76 | 77 | if config.debug_mode is False: 78 | LOGGER.remove() 79 | LOGGER.add(sys.stdout, level="INFO") 80 | 81 | _LOGGER.debug("尝试更新用户数据,符合新版本结构(这只是个提示,每次运行都会显示,其他地方不报错就别管了)") 82 | self.update_sth(config) 83 | 84 | def update_sth(self, config: Config): 85 | """升级后进行配置文件、运行数据的转换""" 86 | merge_cache_to_new_version(config.storage_settings.cache_path) 87 | 88 | @staticmethod 89 | def stop_handler(_, __): 90 | BiliGPTPipeline.stop_event.set() 91 | 92 | async def start(self): 93 | try: 94 | _injector = self.injector 95 | 96 | # 恢复队列任务 97 | _LOGGER.info("正在恢复队列信息") 98 | _injector.get(QueueManager).recover_queue(_injector.get(Config).storage_settings.queue_save_dir) 99 | 100 | # 初始化at侦听器 101 | _LOGGER.info("正在初始化at侦听器") 102 | listen = _injector.get(Listen) 103 | 104 | # 初始化摘要处理链 105 | _LOGGER.info("正在初始化摘要处理链") 106 | summarize_chain = _injector.get(Summarize) 107 | 108 | # 初始化ask_ai处理链 109 | _LOGGER.info("正在初始化ask_ai处理链") 110 | ask_ai_chain = _injector.get(AskAI) 111 | 112 | # 启动侦听器 113 | _LOGGER.info("正在启动at侦听器") 114 | listen.start_listen_at() 115 | _LOGGER.info("正在启动视频更新检测侦听器") 116 | listen.start_video_mission() 117 | 118 | # 默认关掉私信,私信太烧内存 119 | # _LOGGER.info("启动私信侦听器") 120 | # await listen.listen_private() 121 | 122 | _LOGGER.info("正在启动cookie过期检查和刷新") 123 | _injector.get(BiliCredential).start_check() 124 | 125 | # 启动定时任务调度器 126 | _LOGGER.info("正在启动定时任务调度器") 127 | _injector.get(AsyncIOScheduler).start() 128 | _injector.get(AsyncIOScheduler).add_listener(scheduler_error_callback, EVENT_JOB_ERROR) 129 | 130 | # 启动处理链 131 | _LOGGER.info("正在启动处理链") 132 | summarize_task = asyncio.create_task(summarize_chain.main()) 133 | ask_ai_task = asyncio.create_task(ask_ai_chain.main()) 134 | 135 | # 启动评论 136 | _LOGGER.info("正在启动评论处理链") 137 | comment = BiliComment( 138 | _injector.get(QueueManager).get_queue("reply"), 139 | _injector.get(BiliCredential), 140 | ) 141 | comment_task = asyncio.create_task(comment.start_comment()) 142 | 143 | # 启动私信 144 | _LOGGER.info("正在启动私信处理链") 145 | private = BiliSession( 146 | _injector.get(BiliCredential), 147 | _injector.get(QueueManager).get_queue("private"), 148 | ) 149 | private_task = asyncio.create_task(private.start_private_reply()) 150 | 151 | _LOGGER.info("摘要处理链、评论处理链、私信处理链启动完成") 152 | 153 | # 定时执行指定up是否有更新视频,如果有自动回复 154 | # mission = BiliMission(_injector.get(BiliCredential), _injector.get(AsyncIOScheduler)) 155 | # await mission.start() 156 | # _LOGGER.info("创建刷新UP最新视频任务成功,刷新频率:60分钟") 157 | 158 | _LOGGER.success("🎉启动完成 enjoy it") 159 | 160 | while True: 161 | if BiliGPTPipeline.stop_event.is_set(): 162 | _LOGGER.info("正在关闭BiliGPTHelper,记得下次再来玩喵!") 163 | _LOGGER.info("正在关闭定时任务调度器") 164 | sched = _injector.get(AsyncIOScheduler) 165 | for job in sched.get_jobs(): 166 | sched.remove_job(job.id) 167 | sched.shutdown() 168 | listen.close_private_listen() 169 | _LOGGER.info("正在保存队列任务信息") 170 | 
_injector.get(QueueManager).safe_close_all_queues( 171 | _injector.get(Config).storage_settings.queue_save_dir 172 | ) 173 | _LOGGER.info("正在关闭所有的处理链") 174 | summarize_task.cancel() 175 | ask_ai_task.cancel() 176 | comment_task.cancel() 177 | private_task.cancel() 178 | # mission_task.cancel() 179 | # _LOGGER.info("正在生成本次运行的统计报告") 180 | # statistics_dir = _injector.get(Config).model_dump()["storage_settings"][ 181 | # "statistics_dir" 182 | # ] 183 | # run_statistic( 184 | # statistics_dir if statistics_dir else "./statistics", 185 | # _injector.get(TaskStatusRecorder).tasks, 186 | # ) 187 | break 188 | await asyncio.sleep(1) 189 | except Exception: 190 | _LOGGER.error("发生了未捕获的错误,停止运行!") 191 | traceback.print_exc() 192 | 193 | 194 | if __name__ == "__main__": 195 | os.environ["DEBUG_MODE"] = "false" 196 | _LOGGER = LOGGER.bind(name="main") 197 | biligpt = BiliGPTPipeline() 198 | asyncio.run(biligpt.start()) 199 | -------------------------------------------------------------------------------- /src/models/task.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import time 3 | import uuid 4 | from enum import Enum 5 | from typing import Annotated, List, Optional, Union 6 | 7 | from pydantic import BaseModel, Field, StringConstraints, field_validator 8 | 9 | 10 | class SummarizeAiResponse(BaseModel): 11 | """总结处理链的AI回复""" 12 | 13 | # TODO 加上这行就跑不了了 14 | # 防ai呆措施,让数字评分变成字符串 15 | # model_config = ConfigDict(coerce_numbers_to_str=True) # type: ignore 16 | 17 | summary: str # 摘要 18 | score: str # 视频评分 19 | thinking: str # 思考 20 | if_no_need_summary: bool # 是否需要摘要 21 | 22 | @field_validator("score", mode="before") 23 | def check_score(cls, value, values): 24 | if not isinstance(value, str): 25 | return str(value) 26 | return value 27 | 28 | @field_validator("*", mode="before") 29 | def check_required_fields(cls, value, values): 30 | # 星火是真的蠢,返回的if_no_need_summary是字符串 31 | match values.data.get("if_no_need_summary"): 32 | case "否": 33 | cls.if_no_need_summary = False 34 | case "是": 35 | cls.if_no_need_summary = True 36 | case "yes": 37 | cls.if_no_need_summary = True 38 | case "no": 39 | cls.if_no_need_summary = False 40 | case "false": 41 | cls.if_no_need_summary = False 42 | case "true": 43 | cls.if_no_need_summary = True 44 | case "True": 45 | cls.if_no_need_summary = True 46 | case "False": 47 | cls.if_no_need_summary = False 48 | return value 49 | 50 | 51 | class AskAIResponse(BaseModel): 52 | """问问ai的回复""" 53 | 54 | # 防ai呆措施,让数字评分变成字符串 55 | # model_config = ConfigDict(coerce_numbers_to_str=True) # type: ignore 56 | 57 | answer: str # 回答 58 | score: str # 评分 59 | 60 | @field_validator("score", mode="before") 61 | def check_score(cls, value, values): 62 | if not isinstance(value, str): 63 | return str(value) 64 | return value 65 | 66 | 67 | class ProcessStages(Enum): 68 | """视频处理阶段""" 69 | 70 | PREPROCESS = ( 71 | "preprocess" # 包括构建prompt之前都是这个阶段(包含获取信息、字幕读取),处在这个阶段恢复时就直接从头开始 72 | ) 73 | WAITING_LLM_RESPONSE = ( 74 | "waiting_llm_response" # 等待llm的回复 这个阶段应该重新加载字幕或从items中的whisper_subtitle节点读取 75 | ) 76 | WAITING_SEND = "waiting_send" # 等待发送 这是llm回复后的阶段,需要解析llm的回复,然后发送 77 | WAITING_PUSH_TO_CACHE = "waiting_push_to_cache" # 等待推送到缓存(就是发送后) 78 | WAITING_RETRY = "waiting_retry" # 等待重试(ai返回数据格式不对) 79 | END = "end" # 结束 按理来说应该删除,但为了后期统计,保留 80 | 81 | 82 | class Chains(Enum): 83 | SUMMARIZE = "summarize" 84 | ASK_AI = "ask_ai" 85 | 86 | 87 | class EndReasons(Enum): 88 | """视频处理结束原因""" 89 | 90 | NORMAL = "正常结束" # 正常结束 
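# (end_reason 由各处理链在任务收尾时写入,参见 ask_ai 中的 finish 与 _set_err_end)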
91 | ERROR = "视频在处理过程中出现致命错误或多次重试失败,详细见具体的msg" # 错误结束 92 | NONEED = "AI认为该视频不需要被处理,可能是因为内容无意义" # AI认为这个视频不需要处理 93 | 94 | 95 | class BiliAtSpecialAttributes(BaseModel): 96 | """包含来自at的task的特殊属性""" 97 | 98 | source_id: int # 该评论的id,对应send_comment中的root(如果要回复的话) 99 | target_id: int # 上一级评论id, 二级评论指向的就是root_id,三级评论指向的是二级评论的id 100 | root_id: int # 暂时还没出现过 101 | native_uri: str # 评论链接,包含根评论id和父评论id 102 | at_details: List[dict] # at的人的信息,常规的个人信息dict 103 | 104 | 105 | class AskAICommandParams(BaseModel): 106 | """问问ai包含的参数""" 107 | 108 | question: str # 用户提出的问题 109 | 110 | 111 | class BiliGPTTask(BaseModel): 112 | """单任务全生命周期的数据模型 用于替代其他所有的已有类型""" 113 | 114 | source_type: Annotated[str, StringConstraints(pattern=r"^(bili_comment|bili_private|api|bili_up)$")] # type: ignore # 设置task的获取来源 115 | raw_task_data: dict # 原始的task数据,包含所有信息 116 | sender_id: int # task提交者的id,用于统计。来自b站的task就是uid,其他来源的task要自己定义 117 | # video_title: str # 视频标题 118 | video_url: str # 视频链接 119 | video_id: str # bvid 120 | source_command: str # 用户发送的原始指令(eg. "总结一下" "问一下:xxxxxxx") 121 | # mission: bool = Field(default=False) # 用户AT还是自发检测的标志 122 | command_params: Optional[AskAICommandParams] = None # 用户原始指令经解析后的参数 123 | source_extra_attr: Optional[BiliAtSpecialAttributes] = None # 在获取到task时附加的其他原始参数(比如评论id等) 124 | process_result: Optional[Union[SummarizeAiResponse, AskAIResponse, str, dict]] = ( 125 | None # 最终处理结果,根据不同的处理链会有不同的结果 (dict的存在是一个历史遗留问题,不想解决了,再拉一坨) 126 | ) 127 | subtitle: Optional[str] = None # 该视频字幕,与之前不同的是,现在不管是什么方式得到的字幕都要保存下来 128 | process_stage: Optional[ProcessStages] = Field(default=ProcessStages.PREPROCESS) # 视频处理阶段 129 | chain: Optional[Chains] = None # 视频处理事件,即对应的处理链 130 | uuid: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4())) # 该任务的uuid4(必须用default_factory,否则所有任务会共用类定义时生成的同一个默认uuid) 131 | gmt_create: int = Field(default_factory=lambda: int(time.time())) # 任务创建时间戳,默认为当前时间戳(同理用default_factory,避免被固化为进程启动时刻) 132 | gmt_start_process: int = Field(default=0) # 任务开始处理时间,不同于上方的gmt_create,这个是真正开始处理的时间 133 | gmt_retry_start: int = Field(default=0) # 如果该任务被重试,就在开始重试时填写该属性 134 | gmt_end: int = Field(default=0) # 任务彻底结束时间 135 | error_msg: Optional[str] = None # 更详细的错误信息 136 | end_reason: Optional[EndReasons] = None # 任务结束原因 137 | 138 | 139 | # class AtItem(TypedDict): 140 | # """里面储存着待处理任务的所有信息,私信消息也会被转换为这种格式再处理,后续可以进一步清洗,形成这个项目自己的格式""" 141 | # 142 | # type: str # 基本都为reply 143 | # business: str # 基本都为评论 144 | # business_id: int # 基本都为1 145 | # title: str # 如果是一级回复,这里是视频标题,如果是二级回复,这里是一级回复的内容 146 | # image: str # 一级回复是视频封面,二级回复为空 147 | # uri: str # 视频链接 148 | # source_content: str # 回复内容 149 | # source_id: int # 该评论的id,对应send_comment中的root(如果要回复的话) 150 | # target_id: int # 上一级评论id, 二级评论指向的就是root_id,三级评论指向的是二级评论的id 151 | # root_id: int # 暂时还没出现过 152 | # native_url: str # 评论链接,包含根评论id和父评论id 153 | # at_details: List[dict] # at的人的信息,常规的个人信息dict 154 | # ai_response: NotRequired[SummarizeAiResponse | str] # AI回复的内容,需要等到处理完才能获取到dict,否则为还没处理的str 155 | # is_private_msg: NotRequired[bool] # 是否为私信 156 | # private_msg_event: NotRequired[PrivateMsgSession] # 单用户私信会话信息 157 | # whisper_subtitle: NotRequired[str] # whisper字幕 158 | # stage: NotRequired[ProcessStages] # 视频处理阶段 159 | # event: NotRequired[Chains] # 视频处理事件 160 | # uuid: NotRequired[str] # 视频处理uuid 161 | 162 | 163 | # class AtItems(TypedDict): 164 | # id: int 165 | # user: dict # at发送者的个人信息,常规的个人信息dict 166 | # item: List[AtItem] 167 | # at_time: int 168 | 169 | 170 | # class AtAPIResponse(TypedDict): 171 | # """API返回的at消息""" 172 | # 173 | # cursor: AtCursor 174 | # items: List[AtItems] 175 | 176 | 177 | # class TaskStatus(BaseModel): 178 | # """视频记录""" 179 | # 180 | # 
181 | #     gmt_end: Optional[int]
182 | #     event: Chains
183 | #     stage: ProcessStages
184 | #     task_data: BiliGPTTask
185 | #     end_reason: Optional[EndReasons]
186 | #     error_msg: Optional[str]
187 | #     use_whisper: Optional[bool]
188 | #     if_retry: Optional[bool]
189 | 
190 | # class AtCursor(TypedDict):
191 | #     is_end: bool
192 | #     id: int
193 | #     time: int
194 | 
195 | 
196 | # class PrivateMsg(BaseModel):
197 | #     """
198 | #     事件参数:
199 | #     + receiver_id: 收信人 UID
200 | #     + receiver_type: 收信人类型,1: 私聊, 2: 应援团通知, 3: 应援团
201 | #     + sender_uid: 发送人 UID
202 | #     + talker_id: 对话人 UID
203 | #     + msg_seqno: 事件 Seqno
204 | #     + msg_type: 事件类型
205 | #     + msg_key: 事件唯一编号
206 | #     + timestamp: 事件时间戳
207 | #     + content: 事件内容
208 | #
209 | #     事件类型:
210 | #     + TEXT: 纯文字消息
211 | #     + PICTURE: 图片消息
212 | #     + WITHDRAW: 撤回消息
213 | #     + GROUPS_PICTURE: 应援团图片,但似乎不常触发,一般使用 PICTURE 即可
214 | #     + SHARE_VIDEO: 分享视频
215 | #     + NOTICE: 系统通知
216 | #     + PUSHED_VIDEO: UP主推送的视频
217 | #     + WELCOME: 新成员加入应援团欢迎
218 | #
219 | #     TEXT = "1"
220 | #     PICTURE = "2"
221 | #     WITHDRAW = "5"
222 | #     GROUPS_PICTURE = "6"
223 | #     SHARE_VIDEO = "7"
224 | #     NOTICE = "10"
225 | #     PUSHED_VIDEO = "11"
226 | #     WELCOME = "306"
227 | #     """
228 | #
229 | #     receiver_id: int
230 | #     receiver_type: int
231 | #     sender_uid: int
232 | #     talker_id: int
233 | #     msg_seqno: int
234 | #     msg_type: int
235 | #     msg_key: int
236 | #     timestamp: int
237 | #     content: Union[str, int, Picture, Video]
238 | 
239 | # class PrivateMsgSession(BaseModel):
240 | #     """储存单个用户的私信会话信息"""
241 | #
242 | #     status: str  # 状态
243 | #     text_event: Optional[PrivateMsg]  # 文本事件
244 | #     video_event: Optional[PrivateMsg]  # 视频事件
245 | 
--------------------------------------------------------------------------------
/src/chain/ask_ai.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 | import traceback
4 | 
5 | import tenacity
6 | import yaml
7 | 
8 | from src.bilibili.bili_session import BiliSession
9 | from src.chain.base_chain import BaseChain
10 | from src.llm.templates import Templates
11 | from src.models.task import AskAIResponse, BiliGPTTask, Chains, ProcessStages
12 | from src.utils.callback import chain_callback
13 | from src.utils.logging import LOGGER
14 | 
15 | _LOGGER = LOGGER.bind(name="ask-ai-chain")
16 | 
17 | 
18 | class AskAI(BaseChain):
19 |     async def _precheck(self, task: BiliGPTTask) -> bool:
20 |         match task.source_type:
21 |             case "bili_private":
22 |                 _LOGGER.debug("该消息是私信消息,继续处理")
23 |                 await BiliSession.quick_send(self.credential, task, "视频已开始处理,你先别急")
24 |                 return True
25 |             case "bili_comment":
26 |                 _LOGGER.debug("该消息是评论消息,继续处理")
27 |                 return True
28 |             case "api":
29 |                 _LOGGER.debug("该消息是api消息,继续处理")
30 |                 return True
31 |             case "bili_up":
32 |                 _LOGGER.debug("该消息是up更新消息,继续处理")
33 |                 return True
34 |         return False
35 | 
36 |     @tenacity.retry(
37 |         retry=tenacity.retry_if_exception_type(Exception),
38 |         wait=tenacity.wait_fixed(10),
39 |         before_sleep=chain_callback,
40 |     )
41 |     async def main(self):
42 |         try:
43 |             await self._on_start()
44 |             while True:
45 |                 task: BiliGPTTask = await self.ask_ai_queue.get()
46 |                 _item_uuid = task.uuid
47 |                 self._create_record(task)
48 |                 _LOGGER.info(f"ask_ai处理链获取到任务了:{task.uuid}")
49 |                 # 检查是否满足处理条件
50 |                 if task.process_stage == ProcessStages.END:
51 |                     _LOGGER.info(f"任务{task.uuid}已经结束,获取下一个")
52 |                     continue
53 |                 if not await self._precheck(task):
54 |                     continue
55 |                 # 获取视频相关信息
56 |                 # data = self.task_status_recorder.get_data_by_uuid(_item_uuid)
57 |                 resp = await self._get_video_info(task, if_get_comments=False)
58 |                 if resp is None:
59 |                     continue
60 |                 (
61 |                     video,
62 |                     video_info,
63 |                     format_video_name,
64 |                     video_tags_string,
65 |                     video_comments,
66 |                 ) = resp
67 |                 if task.process_stage in (
68 |                     ProcessStages.PREPROCESS,
69 |                     ProcessStages.WAITING_LLM_RESPONSE,
70 |                 ):
71 |                     begin_time = time.perf_counter()
72 |                     # FIXME: 需要修改项目的cache实现,标注来自于哪个处理链,否则事有点大
73 |                     if await self._is_cached_video(task, _item_uuid, video_info):
74 |                         continue
75 |                     # 处理视频音频流和字幕
76 |                     _LOGGER.debug("视频信息获取成功,正在获取视频音频流和字幕")
77 |                     if task.subtitle is not None:
78 |                         text = task.subtitle
79 |                         _LOGGER.debug("使用字幕缓存,开始使用模板生成prompt")
80 |                     else:
81 |                         text = await self._smart_get_subtitle(video, _item_uuid, format_video_name, task)
82 |                         if text is None:
83 |                             continue
84 |                         task.subtitle = text
85 |                     _LOGGER.info(
86 |                         f"视频{format_video_name}音频流和字幕处理完成,共用时{time.perf_counter() - begin_time}s,开始调用LLM生成摘要"
87 |                     )
88 |                     self.task_status_recorder.update_record(
89 |                         _item_uuid,
90 |                         new_task_data=task,
91 |                         process_stage=ProcessStages.WAITING_LLM_RESPONSE,
92 |                     )
93 |                     llm = self.llm_router.get_one()
94 |                     if llm is None:
95 |                         _LOGGER.warning("没有可用的LLM,关闭系统")
96 |                         await self._set_err_end(msg="没有可用的LLM,被迫结束处理", task=task)
97 |                         self.stop_event.set()
98 |                         continue
99 |                     prompt = llm.use_template(
100 |                         Templates.ASK_AI_USER,
101 |                         Templates.ASK_AI_SYSTEM,
102 |                         title=video_info["title"],
103 |                         subtitle=text,
104 |                         description=video_info["desc"],
105 |                         question=task.command_params.question,
106 |                     )
107 |                     _LOGGER.debug("prompt生成成功,开始调用llm")
108 |                     # 调用openai的Completion API
109 |                     response = await llm.completion(prompt)
110 |                     if response is None:
111 |                         _LOGGER.warning(f"任务{task.uuid}:ai未返回任何内容,请自行检查问题,跳过处理")
112 |                         await self._set_err_end(msg="ai未返回任何内容,请自行检查问题,跳过处理", task=task)
113 |                         self.llm_router.report_error(llm.alias)
114 |                         continue
115 |                     answer, tokens = response
116 |                     self.now_tokens += tokens
117 |                     _LOGGER.debug(f"llm输出内容为:{answer}")
118 |                     _LOGGER.debug("调用llm成功,开始处理结果")
119 |                     task.process_result = answer
120 |                     task.process_stage = ProcessStages.WAITING_SEND
121 |                     self.task_status_recorder.update_record(_item_uuid, task)
122 |                 if task.process_stage in (
123 |                     ProcessStages.WAITING_SEND,
124 |                     ProcessStages.WAITING_RETRY,
125 |                 ):
126 |                     begin_time = time.perf_counter()
127 |                     answer = task.process_result
128 |                     # obj, _type = await video.get_video_obj()
129 |                     # 处理结果
130 |                     if answer:
131 |                         try:
132 |                             if task.process_stage == ProcessStages.WAITING_RETRY:
133 |                                 raise Exception("触发重试")
134 |                             answer = answer.replace("False", "false")  # 解决一部分因为大小写问题导致的json解析失败
135 |                             answer = answer.replace("True", "true")
136 |                             resp = yaml.safe_load(answer)
137 |                             task.process_result = AskAIResponse.model_validate(resp)
138 |                             _LOGGER.info(
139 |                                 f"ai返回内容解析正确,视频{format_video_name}摘要处理完成,共用时{time.perf_counter() - begin_time}s"
140 |                             )
141 |                             await self.finish(task)
142 |                         except Exception as e:
143 |                             _LOGGER.error(f"处理结果失败:{e},大概是ai返回的格式不对,尝试修复")
144 |                             traceback.print_tb(e.__traceback__)
145 |                             self.task_status_recorder.update_record(
146 |                                 _item_uuid,
147 |                                 new_task_data=task,
148 |                                 process_stage=ProcessStages.WAITING_RETRY,
149 |                             )
150 |                             await self.retry(
151 |                                 answer,
152 |                                 task,
153 |                                 format_video_name,
154 |                                 begin_time,
155 |                                 video_info,
156 |                             )
157 |         except asyncio.CancelledError:
158 |             _LOGGER.info("收到关闭信号,ask_ai处理链关闭")
159 | 
160 |     async def _on_start(self):
161 |         """在启动处理链时先处理一下之前没有处理完的视频"""
162 |         _LOGGER.info("正在启动ask_ai处理链,开始将上次未处理完的视频加入队列")
163 |         uncomplete_task = []
164 |         uncomplete_task += self.task_status_recorder.get_record_by_stage(
165 | 
chain=Chains.ASK_AI 166 | ) # 有坑,这里会把之前运行过的也重新加回来,不过我下面用判断简单补了一手,叫我天才! 167 | for task in uncomplete_task: 168 | if task["process_stage"] != ProcessStages.END.value: 169 | try: 170 | _LOGGER.debug(f"恢复uuid: {task['uuid']} 的任务") 171 | self.ask_ai_queue.put_nowait(BiliGPTTask.model_validate(task)) 172 | except Exception: 173 | traceback.print_exc() 174 | # TODO 这里除了打印日志,是不是还应该记录在视频状态中? 175 | _LOGGER.error(f"在恢复uuid: {task['uuid']} 时出现错误!跳过恢复") 176 | 177 | async def retry(self, ai_answer, task: BiliGPTTask, format_video_name, begin_time, video_info): 178 | """通过重试prompt让chatgpt重新构建json 179 | 180 | :param ai_answer: ai返回的内容 181 | :param task: queue中的原始数据 182 | :param format_video_name: 格式化后的视频名称 183 | :param begin_time: 开始时间 184 | :param video_info: 视频信息 185 | :return: None 186 | """ 187 | _LOGGER.error( 188 | f"任务{task.uuid}:真不好意思!但是我ask_ai部分的retry还没写!所以只能先全给你设置成错误结束了嘿嘿嘿" 189 | ) 190 | await self._set_err_end( 191 | msg="重试代码未实现,跳过", 192 | task=task, 193 | ) 194 | return False 195 | -------------------------------------------------------------------------------- /src/llm/spark.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import hmac 4 | import json 5 | import traceback 6 | from datetime import datetime 7 | from time import mktime 8 | from typing import Tuple 9 | from urllib.parse import urlencode, urlparse 10 | from wsgiref.handlers import format_date_time 11 | 12 | import websockets 13 | 14 | from src.llm.llm_base import LLMBase 15 | from src.llm.templates import Templates 16 | from src.utils.logging import LOGGER 17 | from src.utils.prompt_utils import build_openai_style_messages, parse_prompt 18 | 19 | _LOGGER = LOGGER.bind(name="spark") 20 | 21 | 22 | class Spark(LLMBase): 23 | def prepare(self): 24 | self._answer_temp = "" # 用于存储讯飞星火大模型的返回结果 25 | self._once_total_tokens = 0 # 用于存储讯飞星火大模型的返回结果的token数 26 | 27 | def create_url(self): 28 | """ 29 | 生成鉴权url 30 | :return: 31 | """ 32 | host = urlparse(self.config.LLMs.spark.spark_url).netloc 33 | path = urlparse(self.config.LLMs.spark.spark_url).path 34 | # 生成RFC1123格式的时间戳 35 | now = datetime.now() 36 | date = format_date_time(mktime(now.timetuple())) 37 | 38 | # 拼接字符串 39 | signature_origin = "host: " + host + "\n" 40 | signature_origin += "date: " + date + "\n" 41 | signature_origin += "GET " + path + " HTTP/1.1" 42 | 43 | # 进行hmac-sha256进行加密 44 | signature_sha = hmac.new( 45 | self.config.LLMs.spark.api_secret.encode("utf-8"), 46 | signature_origin.encode("utf-8"), 47 | digestmod=hashlib.sha256, 48 | ).digest() 49 | 50 | signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") 51 | 52 | authorization_origin = f'api_key="{self.config.LLMs.spark.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' 53 | 54 | authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") 55 | 56 | # 将请求的鉴权参数组合为字典 57 | v = {"authorization": authorization, "date": date, "host": host} 58 | # 拼接鉴权参数,生成url 59 | url = self.config.LLMs.spark.spark_url + "?" 
+ urlencode(v) 60 | _LOGGER.debug(f"生成的url为:{url}") 61 | # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 62 | return url 63 | 64 | async def on_message(self, ws, message) -> int: 65 | """ 66 | 67 | :param ws: 68 | :param message: 69 | :return: 1为还未结束 0为正常结束 2为异常结束 70 | """ 71 | data = json.loads(message) 72 | code = data["header"]["code"] 73 | if code != 0: 74 | _LOGGER.error(f"讯飞星火大模型请求失败: 错误代码:{code} 返回内容:{data}") 75 | await ws.close() 76 | if code == 10013 or code == 10014: 77 | self._once_total_tokens = 0 78 | self._answer_temp = """{"summary":"⚠⚠⚠我也很想告诉你视频的总结,但是星火却跟我说这个视频的总结是***,真的是离谱他🐎给离谱开门——离谱到家了。我也没有办法,谁让星火可以白嫖500w个token🐷。为了白嫖,忍一下,换个视频试一试!","score":"0","thinking":"🤡老子是真的服了这个讯飞星火,国际友好手势(一种动作)。","if_no_need_summary": false}""" 79 | return 0 80 | return 2 81 | else: 82 | choices = data["payload"]["choices"] 83 | status = choices["status"] 84 | content = choices["text"][0]["content"] 85 | self._answer_temp += content 86 | if status == 2: 87 | self._once_total_tokens = data["payload"]["usage"]["text"]["total_tokens"] 88 | await ws.close() 89 | return 0 90 | return 1 91 | 92 | async def completion(self, prompt, **kwargs) -> Tuple[str, int] | None: 93 | try: 94 | self._answer_temp = "" 95 | self._once_total_tokens = 0 96 | ws_url = self.create_url() 97 | 98 | async with websockets.connect(ws_url) as websocket: 99 | websocket.appid = self.config.LLMs.spark.appid 100 | websocket.question = prompt 101 | websocket.domain = self.config.LLMs.spark.domain 102 | 103 | data = json.dumps(self.gen_params(prompt)) 104 | await websocket.send(data) 105 | async for message in websocket: 106 | res = await self.on_message(websocket, message) 107 | if res == 2: 108 | # 如果出现异常,直接返回(上层已经打印过错误,直接返回) 109 | return None 110 | _LOGGER.info( 111 | f"调用讯飞星火大模型成功,返回结果为:{self._answer_temp},本次调用中,prompt+response的长度为{self._once_total_tokens}" 112 | ) 113 | 114 | # 处理返回结果(图省事的方法) 115 | if self._answer_temp.startswith("```json"): 116 | self._answer_temp = self._answer_temp[7:] 117 | if self._answer_temp.endswith("```"): 118 | self._answer_temp = self._answer_temp[:-3] 119 | # 星火返回的json永远是单引号包围的,下面尝试使用eval方式解析 120 | # try: 121 | # _answer = self._answer_temp 122 | # _answer = _answer.replace("true", "True") 123 | # _answer = _answer.replace("false", "False") 124 | # _answer = ast.literal_eval(_answer) # 骚操作 125 | # _answer = json.dumps(_answer, ensure_ascii=False) 126 | # _LOGGER.debug(f"经简单处理后的返回结果为:{_answer}") 127 | # return _answer, self._once_total_tokens 128 | # except Exception as e: 129 | # _LOGGER.error(f"尝试使用eval方式解析星火返回的json失败:{e}") 130 | # traceback.print_exc() 131 | # 如果eval方式解析失败,直接返回 132 | _LOGGER.debug(f"经简单处理后的返回结果为:{self._answer_temp}") 133 | return self._answer_temp, self._once_total_tokens 134 | except Exception as e: 135 | traceback.print_exc() 136 | _LOGGER.error(f"调用讯飞星火大模型失败:{e}") 137 | return None 138 | 139 | def gen_params(self, prompt_list) -> dict: 140 | """ 141 | 通过appid和用户的提问来生成提问参数 142 | 143 | :param prompt_list: 用户的提问 144 | """ 145 | data = { 146 | "header": { 147 | "app_id": self.config.LLMs.spark.appid, 148 | }, 149 | "parameter": { 150 | "chat": { 151 | "domain": self.config.LLMs.spark.domain, 152 | "temperature": 0.5, 153 | "max_tokens": 8192, 154 | } 155 | }, 156 | "payload": {"message": {"text": prompt_list}}, 157 | } 158 | _LOGGER.debug(f"生成的参数为:{data}") 159 | return data 160 | 161 | @staticmethod 162 | def use_template( 163 | user_template_name: Templates, 164 | system_template_name: Templates = None, 165 | user_keyword="user", 166 | system_keyword="system", 167 | 
**kwargs, 168 | ) -> list | None: 169 | """使用模板生成最终prompt(最终格式可能需要根据llm所需格式不同修改,默认为openai的system、user格式) 170 | :param user_template_name: 用户模板名称 171 | :param system_template_name: 系统模板名称 172 | :param user_keyword: 用户关键词(这个和下面的system_keyword要根据每个llm不同的要求来填) 173 | :param system_keyword: 系统关键词 174 | :param kwargs: 模板参数 175 | :return: 返回生成的prompt 或 None 176 | """ 177 | try: 178 | if system_template_name is None: 179 | template_system = None 180 | elif system_template_name.name == "SUMMARIZE_SYSTEM": 181 | template_system = """ 182 | 你是一个视频摘要制作人。请根据视频的标题、描述、标签、字幕、评论来制作视频摘要。指南如下: 183 | 1. 返回的内容严格遵循此JSON格式:{"summary": "你的摘要内容", "score": "你给这个视频的评分(最高100)", "thinking": "你的想法", "if_no_need_summary": "是否需要摘要?填写布尔值"} 184 | 2. 信息要完整且信息量大。 185 | 3. 保持准确,不要添加虚构信息。 186 | 4. 保持幽默,避免正式,需要时可以使用中文网络语言和表情符号。 187 | 5. 忽略不相关的标签/评论。 188 | 6. 分段落写作。 189 | 7. "score"部分的值必须是字符串类型。 190 | 8. 写作"thinking"部分时,忽略上述规则,表达你的独立思考,包括你自己的激进观点,而非视频内容。评分基于这个视频。如果你认为你创造了一个有意义的摘要,给自己一个高分。 191 | 9. 如果视频毫无意义,将此JSON的"if_no_need_summary"设置为true,否则设置为false。 192 | 10. 返回的内容只允许纯JSON格式,JSON的键和值必须使用英文双引号包裹!请使用简体中文! 193 | """ 194 | else: 195 | template_system = system_template_name.value 196 | if user_template_name.name == "SUMMARIZE_USER": 197 | template_user = ( 198 | """标题:[title]\n\n简介:[description]\n\n字幕:[subtitle]\n\n标签:[tags]\n\n评论:[comments]""" 199 | ) 200 | elif user_template_name.name == "ASK_AI_USER": 201 | template_user = """ 202 | 标题: [title]\n\n简介: [description]\n\n字幕: [subtitle]\n\n用户问题: [question]\n\n 203 | 你是一位专业的视频问答老师。我将提供给你视频的标题、描述和字幕。根据这些信息和你的专业知识,以生动幽默的方式回答用户的问题,必要时使用比喻和例子。 204 | 请按照以下JSON格式回复:{"answer": "你的回答", "score": "你对回答质量的自我评分(0-100)"} 205 | !!!只允许使用双引号的纯JSON内容!请使用中文!不要添加任何其他内容!!! 206 | """ 207 | elif user_template_name.name == "SUMMARIZE_RETRY": 208 | template_user = """请将以下文本翻译成此JSON格式并返回给我,不要添加任何其他内容。如果不存在 'summary' 字段,请将 'if_no_need_summary' 设置为 true。如果除 'summary' 之外的字段缺失,则可以忽略并留空, 'if_no_need_summary' 保持 false\n\n标准JSON格式:{"summary": "您的摘要内容", "score": "您给这个视频的评分(最高100分)", "thinking": "您的想法", "if_no_need_summary": "是否需要摘要?填写布尔值"}\n\n我的内容:[input]""" 209 | else: 210 | template_user = user_template_name.value 211 | utemplate = parse_prompt(template_user, **kwargs) 212 | stemplate = parse_prompt(template_system, **kwargs) if template_system else None 213 | # final_template = utemplate + stemplate if stemplate else utemplate # 特殊处理,system附加到user后面 214 | prompt = ( 215 | build_openai_style_messages(utemplate, stemplate, user_keyword, system_keyword) 216 | if stemplate 217 | else build_openai_style_messages(utemplate, user_keyword=user_keyword) 218 | # build_openai_style_messages(final_template, user_keyword=user_keyword) 219 | ) 220 | _LOGGER.info("使用模板成功") 221 | _LOGGER.debug(f"生成的prompt为:{prompt}") 222 | return prompt 223 | except Exception as e: 224 | _LOGGER.error(f"使用模板失败:{e}") 225 | traceback.print_exc() 226 | return None 227 | -------------------------------------------------------------------------------- /src/chain/summarize.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import traceback 4 | 5 | import tenacity 6 | import yaml 7 | 8 | from src.bilibili.bili_session import BiliSession 9 | from src.chain.base_chain import BaseChain 10 | from src.llm.templates import Templates 11 | from src.models.task import BiliGPTTask, Chains, ProcessStages, SummarizeAiResponse 12 | from src.utils.callback import chain_callback 13 | from src.utils.logging import LOGGER 14 | 15 | _LOGGER = LOGGER.bind(name="summarize-chain") 16 | 17 | 18 | class 
Summarize(BaseChain): 19 | """摘要处理链""" 20 | 21 | async def _precheck(self, task: BiliGPTTask) -> bool: 22 | """检查是否满足处理条件""" 23 | match task.source_type: 24 | case "bili_private": 25 | _LOGGER.debug("该消息是私信消息,继续处理") 26 | await BiliSession.quick_send(self.credential, task, "视频已开始处理,你先别急") 27 | return True 28 | case "bili_comment": 29 | _LOGGER.debug("该消息是评论消息,继续处理") 30 | return True 31 | case "api": 32 | _LOGGER.debug("该消息是api消息,继续处理") 33 | return True 34 | case "bili_up": 35 | _LOGGER.debug("该消息是up更新消息,继续处理") 36 | return True 37 | # if task["item"]["type"] != "reply" or task["item"]["business_id"] != 1: 38 | # _LOGGER.warning(f"该消息目前并不支持,跳过处理") 39 | # self._set_err_end(_uuid, "该消息目前并不支持,跳过处理") 40 | # return False 41 | # if task["item"]["root_id"] != 0 or task["item"]["target_id"] != 0: 42 | # _LOGGER.warning(f"该消息是楼中楼消息,暂时不受支持,跳过处理") # TODO 楼中楼消息的处理 43 | # self._set_err_end(_uuid, "该消息是楼中楼消息,暂时不受支持,跳过处理") 44 | # return False 45 | return False 46 | 47 | async def _on_start(self): 48 | """在启动处理链时先处理一下之前没有处理完的视频""" 49 | _LOGGER.info("正在启动摘要处理链,开始将上次未处理完的视频加入队列") 50 | uncomplete_task = [] 51 | uncomplete_task += self.task_status_recorder.get_record_by_stage( 52 | chain=Chains.SUMMARIZE 53 | ) # 有坑,这里会把之前运行过的也重新加回来,不过我下面用判断简单补了一手,叫我天才! 54 | for task in uncomplete_task: 55 | if task["process_stage"] != ProcessStages.END.value: 56 | try: 57 | _LOGGER.debug(f"恢复uuid: {task['uuid']} 的任务") 58 | self.summarize_queue.put_nowait(BiliGPTTask.model_validate(task)) 59 | except Exception: 60 | traceback.print_exc() 61 | # TODO 这里除了打印日志,是不是还应该记录在视频状态中? 62 | _LOGGER.error(f"在恢复uuid: {task['uuid']} 时出现错误!跳过恢复") 63 | # _LOGGER.info(f"之前未处理完的视频已经全部加入队列,共{len(uncomplete_task)}个") 64 | # self.queue_manager.(self.summarize_queue, "summarize") 65 | # _LOGGER.info("正在将上次在队列中的视频加入队列") 66 | # self.task_status_recorder.delete_queue("summarize") 67 | 68 | @tenacity.retry( 69 | retry=tenacity.retry_if_exception_type(Exception), 70 | wait=tenacity.wait_fixed(10), 71 | before_sleep=chain_callback, 72 | ) 73 | async def main(self): 74 | try: 75 | await self._on_start() 76 | while True: 77 | # if self.max_tokens is not None and self.now_tokens >= self.max_tokens: 78 | # _LOGGER.warning( 79 | # f"当前已使用token数{self.now_tokens},超过最大token数{self.max_tokens},摘要处理链停止运行" 80 | # ) 81 | # raise asyncio.CancelledError 82 | 83 | # 从队列中获取摘要 84 | task: BiliGPTTask = await self.summarize_queue.get() 85 | _item_uuid = task.uuid 86 | self._create_record(task) 87 | _LOGGER.info(f"summarize处理链获取到任务了:{task.uuid}") 88 | # 检查是否满足处理条件 89 | if task.process_stage == ProcessStages.END: 90 | _LOGGER.info(f"任务{task.uuid}已经结束,获取下一个") 91 | continue 92 | if not await self._precheck(task): 93 | continue 94 | # 获取视频相关信息 95 | # data = self.task_status_recorder.get_data_by_uuid(_item_uuid) 96 | resp = await self._get_video_info(task) 97 | if resp is None: 98 | continue 99 | ( 100 | video, 101 | video_info, 102 | format_video_name, 103 | video_tags_string, 104 | video_comments, 105 | ) = resp 106 | if task.process_stage in ( 107 | ProcessStages.PREPROCESS, 108 | ProcessStages.WAITING_LLM_RESPONSE, 109 | ): 110 | begin_time = time.perf_counter() 111 | if await self._is_cached_video(task, _item_uuid, video_info): 112 | continue 113 | # 处理视频音频流和字幕 114 | _LOGGER.debug("视频信息获取成功,正在获取视频音频流和字幕") 115 | if task.subtitle is not None: 116 | text = task.subtitle 117 | _LOGGER.debug("使用字幕缓存,开始使用模板生成prompt") 118 | else: 119 | text = await self._smart_get_subtitle(video, _item_uuid, format_video_name, task) 120 | if text is None: 121 | continue 122 | task.subtitle = text 123 
| _LOGGER.info( 124 | f"视频{format_video_name}音频流和字幕处理完成,共用时{time.perf_counter() - begin_time}s,开始调用LLM生成摘要" 125 | ) 126 | self.task_status_recorder.update_record( 127 | _item_uuid, 128 | new_task_data=task, 129 | process_stage=ProcessStages.WAITING_LLM_RESPONSE, 130 | ) 131 | llm = self.llm_router.get_one() 132 | if llm is None: 133 | _LOGGER.warning("没有可用的LLM,关闭系统") 134 | await self._set_err_end(msg="没有可用的LLM,被迫结束处理", task=task) 135 | self.stop_event.set() 136 | continue 137 | prompt = llm.use_template( 138 | Templates.SUMMARIZE_USER, 139 | Templates.SUMMARIZE_SYSTEM, 140 | title=video_info["title"], 141 | tags=video_tags_string, 142 | comments=video_comments, 143 | subtitle=text, 144 | description=video_info["desc"], 145 | ) 146 | _LOGGER.debug("prompt生成成功,开始调用llm") 147 | # 调用openai的Completion API 148 | response = await llm.completion(prompt) 149 | if response is None: 150 | _LOGGER.warning(f"任务{task.uuid}:ai未返回任何内容,请自行检查问题,跳过处理") 151 | await self._set_err_end( 152 | msg="AI未返回任何内容,我也不知道为什么,估计是调休了吧。换个视频或者等一小会儿再试一试。", 153 | task=task, 154 | ) 155 | self.llm_router.report_error(llm.alias) 156 | continue 157 | answer, tokens = response 158 | self.now_tokens += tokens 159 | _LOGGER.debug(f"llm输出内容为:{answer}") 160 | _LOGGER.debug("调用llm成功,开始处理结果") 161 | task.process_result = answer 162 | task.process_stage = ProcessStages.WAITING_SEND 163 | self.task_status_recorder.update_record(_item_uuid, task) 164 | if task.process_stage in ( 165 | ProcessStages.WAITING_SEND, 166 | ProcessStages.WAITING_RETRY, 167 | ): 168 | begin_time = time.perf_counter() 169 | answer = task.process_result 170 | # obj, _type = await video.get_video_obj() 171 | # 处理结果 172 | if answer: 173 | try: 174 | if task.process_stage == ProcessStages.WAITING_RETRY: 175 | raise Exception("触发重试") 176 | 177 | answer = answer.replace("False", "false") # 解决一部分因为大小写问题导致的json解析失败 178 | answer = answer.replace("True", "true") 179 | 180 | ai_resp = yaml.safe_load(answer) 181 | ai_resp["score"] = str(ai_resp["score"]) # 预防返回的值类型为int,强转成str 182 | task.process_result = SummarizeAiResponse.model_validate(ai_resp) 183 | if task.process_result.if_no_need_summary is True: 184 | _LOGGER.warning(f"视频{format_video_name}被ai判定为不需要摘要,跳过处理") 185 | await BiliSession.quick_send( 186 | self.credential, 187 | task, 188 | "AI觉得你的视频不需要处理,换个更有意义的视频再试试看吧!", 189 | ) 190 | # await BiliSession.quick_send( 191 | # self.credential, task, answer 192 | # ) 193 | await self._set_noneed_end(task) 194 | continue 195 | _LOGGER.info( 196 | f"ai返回内容解析正确,视频{format_video_name}摘要处理完成,共用时{time.perf_counter() - begin_time}s" 197 | ) 198 | await self.finish(task) 199 | 200 | except Exception as e: 201 | _LOGGER.error(f"处理结果失败:{e},大概是ai返回的格式不对,尝试修复") 202 | traceback.print_tb(e.__traceback__) 203 | self.task_status_recorder.update_record( 204 | _item_uuid, 205 | new_task_data=task, 206 | process_stage=ProcessStages.WAITING_RETRY, 207 | ) 208 | await self.retry( 209 | answer, 210 | task, 211 | format_video_name, 212 | begin_time, 213 | video_info, 214 | ) 215 | except asyncio.CancelledError: 216 | _LOGGER.info("收到关闭信号,摘要处理链关闭") 217 | 218 | async def retry(self, ai_answer, task: BiliGPTTask, format_video_name, begin_time, video_info): 219 | """通过重试prompt让chatgpt重新构建json 220 | 221 | :param ai_answer: ai返回的内容 222 | :param task: queue中的原始数据 223 | :param format_video_name: 格式化后的视频名称 224 | :param begin_time: 开始时间 225 | :param video_info: 视频信息 226 | :return: None 227 | """ 228 | _LOGGER.debug(f"任务{task.uuid}:ai返回内容解析失败,正在尝试重试") 229 | task.gmt_retry_start = int(time.time()) 230 | llm = 
self.llm_router.get_one() 231 | if llm is None: 232 | _LOGGER.warning("没有可用的LLM,关闭系统") 233 | await self._set_err_end(msg="没有可用的LLM,跳过处理", task=task) 234 | self.stop_event.set() 235 | return False 236 | prompt = llm.use_template(Templates.SUMMARIZE_RETRY, input=ai_answer) 237 | response = await llm.completion(prompt) 238 | if response is None: 239 | _LOGGER.warning(f"视频{format_video_name}摘要生成失败,请自行检查问题,跳过处理") 240 | await self._set_err_end( 241 | msg=f"视频{format_video_name}摘要生成失败,请自行检查问题,跳过处理", 242 | task=task, 243 | ) 244 | self.llm_router.report_error(llm.alias) 245 | return False 246 | answer, tokens = response 247 | _LOGGER.debug(f"api输出内容为:{answer}") 248 | self.now_tokens += tokens 249 | if answer: 250 | try: 251 | resp = yaml.safe_load(answer) 252 | resp["score"] = str(resp["score"]) 253 | task.process_result = SummarizeAiResponse.model_validate(resp) 254 | if task.process_result.if_no_need_summary is True: 255 | _LOGGER.warning(f"视频{format_video_name}被ai判定为不需要摘要,跳过处理") 256 | await self._set_noneed_end(task) 257 | return False 258 | else: 259 | # TODO 这种运行时间的显示存在很大问题,有空了统一一下,但现在我是没空了 260 | _LOGGER.info( 261 | f"ai返回内容解析正确,视频{format_video_name}摘要处理完成,共用时{time.perf_counter() - begin_time}s" 262 | ) 263 | await self.finish(task) 264 | return True 265 | except Exception as e: 266 | _LOGGER.error(f"处理结果失败:{e},大概是ai返回的格式不对,拿你没辙了,跳过处理") 267 | traceback.print_exc() 268 | await self._set_err_end( 269 | msg="重试后处理结果失败,大概是ai返回的格式不对,跳过", 270 | task=task, 271 | ) 272 | return False 273 | -------------------------------------------------------------------------------- /src/chain/base_chain.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import asyncio 3 | import os 4 | import time 5 | from typing import Optional 6 | 7 | import ffmpeg 8 | import httpx 9 | from bilibili_api import HEADERS 10 | from injector import inject 11 | 12 | from src.bilibili.bili_comment import BiliComment 13 | from src.bilibili.bili_credential import BiliCredential 14 | from src.bilibili.bili_session import BiliSession 15 | from src.bilibili.bili_video import BiliVideo 16 | from src.core.routers.asr_router import ASRouter 17 | from src.core.routers.llm_router import LLMRouter 18 | from src.models.config import Config 19 | from src.models.task import ( 20 | AskAIResponse, 21 | BiliGPTTask, 22 | EndReasons, 23 | ProcessStages, 24 | SummarizeAiResponse, 25 | ) 26 | from src.utils.cache import Cache 27 | from src.utils.logging import LOGGER 28 | from src.utils.queue_manager import QueueManager 29 | from src.utils.task_status_record import TaskStatusRecorder 30 | 31 | 32 | class BaseChain: 33 | """处理链基类 34 | 对于b站来说,处理链需要接管的内容基本都要包含对视频基本信息的处理和字幕的提取,这个基类全部帮你做了 35 | """ 36 | 37 | @inject 38 | def __init__( 39 | self, 40 | queue_manager: QueueManager, 41 | config: Config, 42 | credential: BiliCredential, 43 | cache: Cache, 44 | asr_router: ASRouter, 45 | task_status_recorder: TaskStatusRecorder, 46 | stop_event: asyncio.Event, 47 | llm_router: LLMRouter, 48 | ): 49 | self.llm_router = llm_router 50 | self.queue_manager = queue_manager 51 | self.config = config 52 | self.cache = cache 53 | self.asr_router = asr_router 54 | self.now_tokens = 0 55 | self.credential = credential 56 | self.task_status_recorder = task_status_recorder 57 | self._get_variables() 58 | self._get_queues() 59 | self.asr = asr_router.get_one() 60 | self._LOGGER = LOGGER.bind(name=self.__class__.__name__) 61 | self.stop_event = stop_event 62 | 63 | def _get_variables(self): 64 | """从Config获取配置信息""" 65 | 
        self.temp_dir = self.config.storage_settings.temp_dir
66 |         self.api_key = self.config.LLMs.openai.api_key
67 |         self.api_base = self.config.LLMs.openai.api_base
68 | 
69 |     def _get_queues(self):
70 |         """从队列管理器获取队列"""
71 |         self.summarize_queue = self.queue_manager.get_queue("summarize")
72 |         self.reply_queue = self.queue_manager.get_queue("reply")
73 |         self.private_queue = self.queue_manager.get_queue("private")
74 |         self.ask_ai_queue = self.queue_manager.get_queue("ask_ai")
75 | 
76 |     async def _set_err_end(self, msg: str, _uuid: Optional[str] = None, task: Optional[BiliGPTTask] = None):
77 |         """当一个视频因为错误而结束时,调用此方法
78 | 
79 |         :param msg: 错误信息
80 |         :param _uuid: 任务uuid (跟task二选一)
81 |         :param task: 任务对象
82 |         """
83 |         self.task_status_recorder.update_record(
84 |             _uuid if _uuid else task.uuid,
85 |             new_task_data=None,
86 |             process_stage=ProcessStages.END,
87 |             end_reason=EndReasons.ERROR,
88 |             gmt_end=int(time.time()),
89 |             error_msg=msg,
90 |         )
91 |         _task = self.task_status_recorder.get_data_by_uuid(_uuid) if _uuid else task
92 |         match _task.source_type:
93 |             case "bili_private":
94 |                 self._LOGGER.debug(f"任务{_task.uuid}:私信消息,直接回复:{msg}")
95 |                 await BiliSession.quick_send(
96 |                     self.credential,
97 |                     _task,
98 |                     msg,
99 |                 )
100 |             case "bili_comment":
101 |                 _task.process_result = msg
102 |                 self._LOGGER.debug(f"任务{_task.uuid}:评论消息,将结果放入评论处理队列,内容:{msg}")
103 |                 await self.reply_queue.put(_task)
104 |             case "api":
105 |                 self._LOGGER.warning(f"任务{_task.uuid}:api获取的消息,未实现处理逻辑")
106 |             case "bili_up":
107 |                 _task.process_result = msg
108 |                 self._LOGGER.debug(f"任务{_task.uuid}:up更新消息,将结果放入评论处理队列,内容:{msg}")
109 |                 await self.reply_queue.put(_task)
110 | 
111 |     async def _set_normal_end(self, task: Optional[BiliGPTTask] = None, _uuid: Optional[str] = None):
112 |         """当一个视频正常结束时,调用此方法
113 | 
114 |         :param task: 任务对象
115 |         :param _uuid: 任务uuid (跟task二选一)
116 |         """
117 |         self.task_status_recorder.update_record(
118 |             _uuid if _uuid else task.uuid,
119 |             new_task_data=None,
120 |             process_stage=ProcessStages.END,
121 |             end_reason=EndReasons.NORMAL,
122 |             gmt_end=int(time.time()),
123 |         )
124 | 
125 |     async def _set_noneed_end(self, task: Optional[BiliGPTTask] = None, _uuid: Optional[str] = None):
126 |         """当一个视频不需要处理时,调用此方法
127 | 
128 |         :param task: 任务对象
129 |         :param _uuid: 任务uuid (跟task二选一)
130 |         """
131 |         self.task_status_recorder.update_record(
132 |             _uuid if _uuid else task.uuid,
133 |             new_task_data=None,
134 |             process_stage=ProcessStages.END,
135 |             end_reason=EndReasons.NONEED,
136 |             gmt_end=int(time.time()),
137 |         )
138 |         await BiliSession.quick_send(
139 |             self.credential,
140 |             task,
141 |             "AI觉得你的视频不需要处理,换个更有意义的视频再试试看吧!",
142 |         )
143 | 
144 |     @abc.abstractmethod
145 |     async def _precheck(self, task: BiliGPTTask) -> bool:
146 |         """检查是否符合调用条件
147 |         :param task: AtItem
148 | 
149 |         如不符合,请务必调用self._set_err_end()方法后返回False
150 |         """
151 |         pass
152 | 
153 |     async def finish(self, task: BiliGPTTask, use_cache: bool = False) -> bool:
154 |         """
155 |         当一个任务 **正常** 结束时,调用这个,将消息放入队列、设置缓存、更新任务状态
156 |         :param task:
157 |         :param use_cache: 是否直接使用缓存而非正常处理
158 |         :return:
159 |         """
160 |         _LOGGER = self._LOGGER
161 |         reply_data = task
162 |         # if reply_data.source_type == "bili_private":
163 |         #     _LOGGER.debug("该消息是私信消息,将结果放入私信处理队列")
164 |         #     await self.private_queue.put(reply_data)
165 |         # elif reply_data.source_type == "bili_comment":
166 |         #     _LOGGER.debug("正在将结果加入发送队列,等待回复")
167 |         #     await self.reply_queue.put(reply_data)
168 |         match reply_data.source_type:
169 |             case "bili_private":
170 |                 _LOGGER.debug(f"任务{task.uuid}:私信消息,将结果放入私信处理队列")
171 |                 await self.private_queue.put(reply_data)
172 |             case 
"bili_comment": 173 | _LOGGER.info(f"任务{task.uuid}:评论消息,将结果放入评论处理队列") 174 | await self.reply_queue.put(reply_data) 175 | case "api": 176 | _LOGGER.warning(f"任务{task.uuid}:api获取的消息,未实现处理逻辑") 177 | case "bili_up": 178 | _LOGGER.info(f"任务{task.uuid}:评论消息,将结果放入评论处理队列") 179 | await self.reply_queue.put(reply_data) 180 | _LOGGER.debug("处理结束,开始清理并提交记录") 181 | self.task_status_recorder.update_record( 182 | reply_data.uuid, 183 | new_task_data=task, 184 | process_stage=ProcessStages.WAITING_PUSH_TO_CACHE, 185 | ) 186 | if use_cache: 187 | await self._set_normal_end(task) 188 | return True 189 | self.cache.set_cache( 190 | key=reply_data.video_id, 191 | value=reply_data.process_result.model_dump(), 192 | chain=str(task.chain.value), 193 | ) 194 | await self._set_normal_end(task) 195 | return True 196 | 197 | async def _is_cached_video(self, task: BiliGPTTask, _uuid: str, video_info: dict) -> bool: 198 | """检查是否是缓存的视频 199 | 如果是缓存的视频,直接从缓存中获取结果并发送 200 | """ 201 | if self.cache.get_cache(key=video_info["bvid"], chain=str(task.chain.value)): 202 | LOGGER.debug(f"视频{video_info['title']}已经处理过,直接使用缓存") 203 | cache = self.cache.get_cache(key=video_info["bvid"], chain=str(task.chain.value)) 204 | # if str(task.chain.value) == "summarize": 205 | # cache = SummarizeAiResponse.model_validate(cache) 206 | # elif str(task.chain.value) == "ask_ai": 207 | # cache = AskAIResponse.model_validate(cache) 208 | match str(task.chain.value): 209 | case "summarize": 210 | cache = SummarizeAiResponse.model_validate(cache) 211 | case "ask_ai": 212 | cache = AskAIResponse.model_validate(cache) 213 | case _: 214 | self._LOGGER.error( 215 | f"获取到了缓存,但无法匹配处理链{task.chain.value},无法调取缓存,开始按正常流程处理" 216 | ) 217 | return False 218 | match task.source_type: 219 | case "bili_private": 220 | task.process_result = cache 221 | await self.finish(task, True) 222 | case "bili_comment": 223 | task.process_result = cache 224 | await self.finish(task, True) 225 | case "bili_up": 226 | task.process_result = cache 227 | await self.finish(task, True) 228 | return True 229 | return False 230 | 231 | async def _get_video_info( 232 | self, task: BiliGPTTask, if_get_comments: bool = True 233 | ) -> Optional[tuple[BiliVideo, dict, str, str, Optional[str]]]: 234 | """获取视频的一些信息 235 | :param task: 任务对象 236 | :param if_get_comments: 是否获取评论,为假就返回空 237 | 238 | :return 视频正常返回元组(video, video_info, format_video_name, video_tags_string, video_comments) 239 | 240 | video: BiliVideo对象 241 | video_info: bilibili官方api返回的视频信息 242 | format_video_name: 格式化后的视频名,用于日志 243 | video_tags_string: 视频标签 244 | video_comments: 随机获取的几条视频评论拼接的字符串 245 | """ 246 | _LOGGER = self._LOGGER 247 | _LOGGER.info("开始处理该视频音频流和字幕") 248 | video = BiliVideo(self.credential, url=task.video_url) 249 | _LOGGER.debug("视频对象创建成功,正在获取视频信息") 250 | video_info = await video.get_video_info 251 | _LOGGER.debug("视频信息获取成功,正在获取视频标签") 252 | format_video_name = f"『{video_info['title']}』" 253 | # TODO 不清楚b站回复和at时分P的展现机制,暂时遇到分P视频就跳过 254 | if len(video_info["pages"]) > 1: 255 | _LOGGER.warning(f"任务{task.uuid}: 视频{format_video_name}分P,跳过处理") 256 | await self._set_err_end(msg="视频分P,跳过处理", task=task) 257 | return None 258 | # 获取视频标签 259 | video_tags_string = " ".join(f"#{tag['tag_name']}" for tag in await video.get_video_tags()) 260 | _LOGGER.debug("视频标签获取成功,开始获取视频评论") 261 | # 获取视频评论 262 | video_comments = ( 263 | await BiliComment.get_random_comment(video_info["aid"], self.credential) if if_get_comments else None 264 | ) 265 | return video, video_info, format_video_name, video_tags_string, video_comments 266 | 267 | 
    async def _get_subtitle_from_bilibili(self, video: BiliVideo) -> str:
268 |         """从bilibili获取字幕(返回的是纯字幕,不包含时间轴)"""
269 |         _LOGGER = self._LOGGER
270 |         subtitle_url = await video.get_video_subtitle(page_index=0)
271 |         _LOGGER.debug("视频字幕获取成功,正在读取字幕")
272 |         # 下载字幕
273 |         async with httpx.AsyncClient() as client:
274 |             resp = await client.get("https:" + subtitle_url, headers=HEADERS)
275 |         _LOGGER.debug("字幕获取成功,正在转换为纯字幕")
276 |         # 转换字幕格式
277 |         text = ""
278 |         for subtitle in resp.json()["body"]:
279 |             text += f"{subtitle['content']}\n"
280 |         return text
281 | 
282 |     async def _get_subtitle_from_asr(self, video: BiliVideo, _uuid: str, is_retry: bool = False) -> Optional[str]:
283 |         _LOGGER = self._LOGGER
284 |         if self.asr is None:
285 |             _LOGGER.warning("没有可用的asr,跳过处理")
286 |             return await self._set_err_end(_uuid=_uuid, msg="没有可用的asr,跳过处理")
287 |         if is_retry:
288 |             # 如果是重试,就默认已下载音频文件,直接开始转写
289 |             bvid = await video.bvid
290 |             audio_path = f"{self.temp_dir}/{bvid} temp.mp3"
291 |             self.asr = self.asr_router.get_one()  # 重新获取一个,防止因为错误而被禁用,但调用端没及时更新
292 |             if self.asr is None:
293 |                 _LOGGER.warning("没有可用的asr,跳过处理")
294 |                 return await self._set_err_end(_uuid=_uuid, msg="没有可用的asr,跳过处理")
295 |             text = await self.asr.transcribe(audio_path)
296 |             if text is None:
297 |                 _LOGGER.warning("音频转写失败,报告并重试")
298 |                 self.asr_router.report_error(self.asr.alias)
299 |                 return await self._get_subtitle_from_asr(video, _uuid, is_retry=True)  # 递归,应该不会爆栈
300 |             return text
301 |         _LOGGER.debug("正在获取视频音频流")
302 |         video_download_url = await video.get_video_download_url()
303 |         audio_url = video_download_url["dash"]["audio"][0]["baseUrl"]
304 |         _LOGGER.debug("视频下载链接获取成功,正在下载视频中的音频流")
305 |         bvid = await video.bvid
306 |         # 下载视频中的音频流
307 |         async with httpx.AsyncClient() as client:
308 |             resp = await client.get(audio_url, headers=HEADERS)
309 |             temp_dir = self.temp_dir
310 |             if not os.path.exists(temp_dir):
311 |                 os.mkdir(temp_dir)
312 |             with open(f"{temp_dir}/{bvid} temp.m4s", "wb") as f:
313 |                 f.write(resp.content)
314 |         _LOGGER.debug("视频中的音频流下载成功,正在转换音频格式")
315 |         # 转换音频格式
316 |         (ffmpeg.input(f"{temp_dir}/{bvid} temp.m4s").output(f"{temp_dir}/{bvid} temp.mp3").run(overwrite_output=True))
317 |         _LOGGER.debug("音频格式转换成功,正在使用whisper转写音频")
318 |         # 使用whisper转写音频
319 |         audio_path = f"{temp_dir}/{bvid} temp.mp3"
320 |         text = await self.asr.transcribe(audio_path)
321 |         if text is None:
322 |             _LOGGER.warning("音频转写失败,报告并重试")
323 |             self.asr_router.report_error(self.asr.alias)
324 |             return await self._get_subtitle_from_asr(video, _uuid, is_retry=True)  # 递归重试,成功时直接把结果返回给调用方
325 |         _LOGGER.debug("音频转写成功,正在删除临时文件")
326 |         # 删除临时文件
327 |         os.remove(f"{temp_dir}/{bvid} temp.m4s")
328 |         os.remove(f"{temp_dir}/{bvid} temp.mp3")
329 |         _LOGGER.debug("临时文件删除成功")
330 |         return text
331 | 
332 |     async def _smart_get_subtitle(
333 |         self, video: BiliVideo, _uuid: str, format_video_name: str, task: BiliGPTTask
334 |     ) -> Optional[str]:
335 |         """根据用户配置智能获取字幕"""
336 |         _LOGGER = self._LOGGER
337 |         subtitle_url = await video.get_video_subtitle(page_index=0)
338 |         if subtitle_url is None:
339 |             if self.asr is None:
340 |                 _LOGGER.warning(f"视频{format_video_name}没有字幕,你没有可用的asr,跳过处理")
341 |                 await self._set_err_end(_uuid=_uuid, msg="视频没有字幕,你没有可用的asr,跳过处理")
342 |                 return None
343 |             _LOGGER.warning(f"视频{format_video_name}没有字幕,开始使用asr转写,这可能会导致字幕质量下降")
344 |             text = await self._get_subtitle_from_asr(video, _uuid)
345 |             task.subtitle = text
346 |             self.task_status_recorder.update_record(_uuid, new_task_data=task, use_whisper=True)
347 |             return text
348 |         _LOGGER.debug(f"视频{format_video_name}有字幕,开始处理")
349 |         text = await 
self._get_subtitle_from_bilibili(video) 350 | return text 351 | 352 | def _create_record(self, task: BiliGPTTask) -> str: 353 | """创建(或查询)一条任务记录,返回uuid""" 354 | task.gmt_start_process = int(time.time()) 355 | _item_uuid = self.task_status_recorder.create_record(task) 356 | return _item_uuid 357 | 358 | @abc.abstractmethod 359 | async def main(self): 360 | """ 361 | 处理链主函数 362 | 捕获错误的最佳实践是使用tenacity.retry装饰器,callback也已经写好了,就在utils.callback中 363 | 如果实现_on_start的话别忘了在循环代码前调用 364 | 365 | eg: 366 | @tenacity.retry( 367 | retry=tenacity.retry_if_exception_type(Exception), 368 | wait=tenacity.wait_fixed(10), 369 | before_sleep=chain_callback 370 | ) 371 | :return: 372 | """ 373 | pass 374 | 375 | @abc.abstractmethod 376 | async def _on_start(self): 377 | """ 378 | 在这里完成一些初始化任务,比如将已保存的未处理任务恢复到队列中 379 | 记得在main中调用啊! 380 | """ 381 | pass 382 | 383 | @abc.abstractmethod 384 | async def retry(self, *args, **kwargs): 385 | """ 386 | 重试函数,你可以在这里写重试逻辑 387 | 这要求你必须在main函数中捕捉错误并进行调用该函数 388 | 因为llm的返回内容具有不稳定性,所以我强烈建议你实现这个函数。 389 | """ 390 | pass 391 | 392 | def __repr__(self): 393 | return self.__class__.__name__ 394 | -------------------------------------------------------------------------------- /src/listener/bili_listen.py: -------------------------------------------------------------------------------- 1 | """监听bilibili平台的私信、at消息""" 2 | 3 | import asyncio 4 | import os 5 | import re 6 | import time 7 | import traceback 8 | from copy import deepcopy 9 | from datetime import datetime 10 | 11 | from apscheduler.schedulers.asyncio import AsyncIOScheduler 12 | from bilibili_api import session, user 13 | from injector import inject 14 | 15 | from src.bilibili.bili_credential import BiliCredential 16 | from src.bilibili.bili_video import BiliVideo 17 | from src.core.routers.chain_router import ChainRouter 18 | from src.models.config import Config 19 | from src.models.task import BiliAtSpecialAttributes, BiliGPTTask 20 | from src.utils.logging import LOGGER 21 | from src.utils.queue_manager import QueueManager 22 | from src.utils.up_video_cache import get_up_file, load_cache, set_cache 23 | 24 | _LOGGER = LOGGER.bind(name="bilibili-listener") 25 | 26 | 27 | class Listen: 28 | @inject 29 | def __init__( 30 | self, 31 | credential: BiliCredential, 32 | queue_manager: QueueManager, 33 | # value_manager: GlobalVariablesManager, 34 | config: Config, 35 | schedule: AsyncIOScheduler, 36 | chain_router: ChainRouter, 37 | ): 38 | self.sess = None 39 | self.credential = credential 40 | self.summarize_queue = queue_manager.get_queue("summarize") 41 | self.evaluate_queue = queue_manager.get_queue("evaluate") 42 | self.last_at_time = int(time.time()) # 当前时间作为初始时间戳 43 | self.sched = schedule 44 | self.user_sessions = {} # 存储用户状态和视频信息 45 | self.config = config 46 | self.chain_router = chain_router 47 | # 以下是自动查询视频更新的参数 48 | # self.cache_path = './data/video_cache.json' 49 | # self.up_file_path = './data/up.json' 50 | self.chain_router = chain_router 51 | self.uids = {} 52 | self.video_cache = {} 53 | 54 | async def listen_at(self): 55 | # global run_time 56 | data: dict = await session.get_at(self.credential) 57 | _LOGGER.debug(f"获取at消息成功,内容为:{data}") 58 | 59 | # if len(data["items"]) != 0: 60 | # if run_time > 2: 61 | # return 62 | # _LOGGER.warning(f"目前处于debug状态,将直接处理第一条at消息") 63 | # await self.dispatch_task(data["items"][0]) 64 | # run_time += 1 65 | # return 66 | 67 | # 判断是否有新消息 68 | if len(data["items"]) == 0: 69 | _LOGGER.debug("没有新消息,返回") 70 | return 71 | if self.last_at_time >= data["items"][0]["at_time"]: 72 | 
_LOGGER.debug( 73 | f"last_at_time{self.last_at_time}大于或等于当前最新消息的at_time{data['items'][0]['at_time']},返回" 74 | ) 75 | return 76 | 77 | new_items = [] 78 | for item in reversed(data["items"]): 79 | if item["at_time"] > self.last_at_time: 80 | _LOGGER.debug(f"at_time{item['at_time']}大于last_at_time{self.last_at_time},放入新消息队列") 81 | # item["user"] = data["items"]["user"] 82 | new_items.append(item) 83 | if len(new_items) == 0: 84 | _LOGGER.debug("没有新消息,返回") 85 | return 86 | _LOGGER.info(f"检测到{len(new_items)}条新消息,开始处理") 87 | for item in new_items: 88 | task_metadata = await self.build_task_from_at_msg(item) 89 | if task_metadata is None: 90 | continue 91 | await self.chain_router.dispatch_a_task(task_metadata) 92 | 93 | self.last_at_time = data["items"][0]["at_time"] 94 | 95 | async def build_task_from_at_msg(self, msg: dict) -> BiliGPTTask | None: 96 | # print(msg) 97 | try: 98 | event = deepcopy(msg) 99 | if msg["item"]["type"] != "reply" or msg["item"]["business_id"] != 1: 100 | _LOGGER.warning("不是回复消息,跳过") 101 | return None 102 | elif msg["item"]["root_id"] != 0 or msg["item"]["target_id"] != 0: 103 | _LOGGER.warning("该消息是楼中楼消息,暂时不受支持,跳过处理") 104 | return None 105 | event["source_type"] = "bili_comment" 106 | event["raw_task_data"] = deepcopy(msg) 107 | event["source_extra_attr"] = BiliAtSpecialAttributes.model_validate(event["item"]) 108 | event["sender_id"] = str(event["user"]["mid"]) 109 | event["video_url"] = event["item"]["uri"] 110 | event["source_command"] = event["item"]["source_content"] 111 | # event["mission"] = False 112 | event["video_id"] = await BiliVideo(credential=self.credential, url=event["item"]["uri"]).bvid 113 | task_metadata = BiliGPTTask.model_validate(event) 114 | except Exception: 115 | traceback.print_exc() 116 | _LOGGER.error("在验证任务数据结构时出现错误,跳过处理!") 117 | return None 118 | 119 | return task_metadata 120 | 121 | def start_listen_at(self): 122 | self.sched.add_job( 123 | self.listen_at, 124 | trigger="interval", 125 | seconds=20, # 有新任务都会一次性提交,时间无所谓 126 | id="listen_at", 127 | max_instances=3, 128 | next_run_time=datetime.now(), 129 | ) 130 | # self.sched.start() 131 | _LOGGER.info("[定时任务]侦听at消息定时任务注册成功, 每20秒检查一次") 132 | 133 | def start_video_mission(self): 134 | self.sched.add_job( 135 | self.async_video_list_mission, 136 | trigger="interval", 137 | minutes=60, # minutes=60, 138 | id="video_list_mission", 139 | max_instances=3, 140 | next_run_time=datetime.now(), 141 | ) 142 | _LOGGER.info("[定时任务]侦听up视频更新任务注册成功, 每60分钟检查一次") 143 | 144 | async def async_video_list_mission(self): 145 | _LOGGER.info("开始执行获取UP的最新视频") 146 | self.video_cache = load_cache(self.config.storage_settings.up_video_cache) 147 | self.uids = get_up_file(self.config.storage_settings.up_file) 148 | for item in self.uids: 149 | u = user.User(uid=item["uid"]) 150 | try: 151 | media_list = await u.get_media_list(ps=1, desc=True) 152 | except Exception: 153 | traceback.print_exc() 154 | _LOGGER.error(f"在获取 uid{item} 的视频列表时出错!") 155 | return None 156 | media = media_list["media_list"][0] 157 | bv_id = media["bv_id"] 158 | _LOGGER.info(f"当前视频的bvid:{bv_id}") 159 | oid = media["id"] 160 | _LOGGER.info(f"当前视频的oid:{oid}") 161 | if str(item["uid"]) in self.video_cache: 162 | cache_bvid = self.video_cache[str(item["uid"])]["bv_id"] 163 | _LOGGER.info(f"缓存文件中的bvid:{cache_bvid}") 164 | if cache_bvid != bv_id: 165 | _LOGGER.info(f"up有视频更新,视频信息为:\n 作者:{item['username']} 标题:{media['title']}") 166 | # 将视频信息传递给消息队列 167 | task_metadata = await self.build_task_from_at_mission(media) 168 | if task_metadata is None: 169 | 
                        continue
170 |                     await self.chain_router.dispatch_a_task(task_metadata)
171 |                     set_cache(
172 |                         self.config.storage_settings.up_video_cache,
173 |                         self.video_cache,
174 |                         {"bv_id": bv_id, "oid": oid},
175 |                         str(item["uid"]),
176 |                     )
177 | 
178 |                 else:
179 |                     _LOGGER.info("up没有视频更新")
180 | 
181 |             else:
182 |                 _LOGGER.info("缓存文件为空,第一次写入数据")
183 |                 set_cache(self.config.storage_settings.up_video_cache, self.video_cache, {"bv_id": bv_id, "oid": oid}, str(item["uid"]))
184 |             _LOGGER.info("休息20秒")
185 |             await asyncio.sleep(20)
186 | 
187 |     async def build_task_from_at_mission(self, msg: dict) -> BiliGPTTask | None:
188 |         # print(msg)
189 |         try:
190 |             # event = deepcopy(msg)
191 |             event: dict = {}
192 |             event["source_type"] = "bili_up"
193 |             event["raw_task_data"] = {
194 |                 "user": {
195 |                     "mid": self.config.bilibili_cookie.dedeuserid,
196 |                     "nickname": self.config.bilibili_self.nickname,
197 |                 }
198 |             }
199 |             event["source_extra_attr"] = BiliAtSpecialAttributes.model_validate(
200 |                 {
201 |                     "source_id": msg["id"],
202 |                     "target_id": 0,
203 |                     "root_id": 0,
204 |                     "native_uri": msg["link"],
205 |                     "at_details": [
206 |                         {
207 |                             "mid": self.config.bilibili_cookie.dedeuserid,
208 |                             "fans": 0,
209 |                             "nickname": self.config.bilibili_self.nickname,
210 |                             "avatar": "http://i1.hdslb.com/bfs/face/d21cf99c96dfdca5e38106c00eb338dd150b4b65.jpg",
211 |                             "mid_link": "",
212 |                             "follow": False,
213 |                         }
214 |                     ],
215 |                 }
216 |             )
217 |             event["sender_id"] = self.config.bilibili_cookie.dedeuserid
218 |             event["video_url"] = msg["short_link"]
219 |             event["source_command"] = f"@{self.config.bilibili_self.nickname} 总结一下"
220 |             event["video_id"] = msg["bv_id"]
221 |             # event["mission"] = True
222 |             task_metadata = BiliGPTTask.model_validate(event)
223 |         except Exception:
224 |             traceback.print_exc()
225 |             _LOGGER.error("在验证任务数据结构时出现错误,跳过处理!")
226 |             return None
227 | 
228 |         return task_metadata
229 | 
230 |     async def build_task_from_private_msg(self, msg: dict) -> BiliGPTTask | None:
231 |         try:
232 |             event = deepcopy(msg)
233 |             bvid = event["video_event"]["content"]
234 |             uri = "https://bilibili.com/video/" + bvid
235 |             event["source_type"] = "bili_private"
236 |             event["raw_task_data"] = deepcopy(msg)
237 |             event["raw_task_data"]["video_event"]["content"] = bvid
238 |             event["sender_id"] = event["text_event"]["sender_uid"]
239 |             event["video_url"] = uri
240 |             event["source_command"] = event["text_event"]["content"][12:]  # 去除掉bv号
241 |             event["video_id"] = bvid
242 |             # event["mission"] = False
243 |             del event["video_event"]
244 |             del event["text_event"]
245 |             del event["status"]
246 |             task_metadata = BiliGPTTask.model_validate(event)
247 |         except Exception:
248 |             traceback.print_exc()
249 |             _LOGGER.error("在验证任务数据结构时出现错误,跳过处理!")
250 |             return None
251 | 
252 |         return task_metadata
253 | 
254 |     async def handle_video(self, user_id, event):
255 |         _session = self.user_sessions.get(user_id, {"status": "idle", "text_event": {}, "video_event": {}})
256 |         match _session["status"]:
257 |             case "idle" | "waiting_for_keyword":
258 |                 _session["status"] = "waiting_for_keyword"
259 |                 _session["video_event"] = event
260 |                 _session["video_event"]["content"] = _session["video_event"]["content"].get_bvid()
261 | 
262 |             case "waiting_for_video":
263 |                 _session["video_event"] = event
264 |                 _session["video_event"]["content"] = _session["video_event"]["content"].get_bvid()
265 |                 at_items = await self.build_task_from_private_msg(_session)
266 |                 if at_items is None:
267 |                     return
268 |                 await self.chain_router.dispatch_a_task(at_items)
269 |                 _session["status"] = "idle"
270 |                 _session["text_event"] = {}
271 | 
_session["video_event"] = {} 272 | case _: 273 | pass 274 | self.user_sessions[user_id] = _session 275 | 276 | async def handle_text(self, user_id, event): 277 | # _session = PrivateMsgSession(self.user_sessions.get( 278 | # user_id, {"status": "idle", "text_event": {}, "video_event": {}} 279 | # )) 280 | _session = ( 281 | self.user_sessions[user_id] 282 | if self.user_sessions.get(user_id, None) 283 | else {"status": "idle", "video_event": {}, "text_event": {}} 284 | ) 285 | 286 | match "BV" in event["content"]: 287 | case True: 288 | _LOGGER.debug("检测到消息中包含BV号,开始提取") 289 | # try: 290 | # p1, p2 = event["content"].split(" ") # 简单分离一下关键词与链接 291 | # except Exception as e: 292 | # _LOGGER.error(f"分离关键词与链接失败:{e},返回") 293 | # return 294 | # 295 | # if "BV" in p1: 296 | # bvid = p1 297 | # keyword = p2 298 | # else: 299 | # bvid = p2 300 | # keyword = p1 301 | bvid = event["content"][:12] 302 | if not re.search("^BV[a-zA-Z0-9]{10}$", bvid): 303 | _LOGGER.warning(f"从消息‘{event['content']}’中提取bv号失败!你是不是没把bv号放在消息最前面?!") 304 | return 305 | if _session["status"] in ( 306 | "waiting_for_keyword", 307 | "idle", 308 | "waiting_for_video", 309 | ): 310 | _session["video_event"] = {} 311 | _session["video_event"]["content"] = bvid 312 | _session["text_event"] = deepcopy(event) 313 | task_metadata = await self.build_task_from_private_msg(_session) 314 | if task_metadata is None: 315 | return 316 | await self.chain_router.dispatch_a_task(task_metadata) 317 | _session["status"] = "idle" 318 | _session["text_event"] = {} 319 | _session["video_event"] = {} 320 | self.user_sessions[user_id] = _session 321 | return 322 | 323 | match _session["status"]: 324 | case "waiting_for_keyword": 325 | _session["text_event"] = event 326 | task_metadata = await self.build_task_from_private_msg(_session) 327 | if task_metadata is None: 328 | return 329 | # task_metadata = self.build_private_msg_to_at_items(_session["event"]) # type: ignore 330 | # task_metadata["item"]["source_content"] = text # 将文本消息填入at内容 331 | await self.chain_router.dispatch_a_task(task_metadata) 332 | _session["status"] = "idle" 333 | _session["text_event"] = {} 334 | _session["video_event"] = {} 335 | 336 | case "idle": 337 | _session["text_event"] = event 338 | _session["status"] = "waiting_for_video" 339 | 340 | case "waiting_for_video": 341 | _session["text_event"] = event 342 | 343 | case _: 344 | pass 345 | self.user_sessions[user_id] = _session 346 | 347 | async def on_receive(self, event: session.Event): 348 | """接收到视频分享消息时的回调函数""" 349 | _LOGGER.debug(f"接收到私聊消息,内容为:{event}") 350 | data = event.__dict__ 351 | if data["msg_type"] == 7: 352 | await self.handle_video(data["sender_uid"], data) 353 | elif data["msg_type"] == 1: 354 | await self.handle_text(data["sender_uid"], data) 355 | else: 356 | _LOGGER.debug(f"未知的消息类型{data['msg_type']}") 357 | 358 | async def listen_private(self): 359 | # TODO 将轮询功能从bilibili_api库分离,重写 360 | self.sess = session.Session(self.credential) 361 | self.sess.logger = _LOGGER 362 | if os.getenv("DEBUG_MODE") == "true": # debug模式下不排除自己发的消息 363 | await self.sess.run(exclude_self=False) 364 | else: 365 | await self.sess.run(exclude_self=True) 366 | self.sess.add_event_listener(str(session.EventType.SHARE_VIDEO.value), self.on_receive) # type: ignore 367 | self.sess.add_event_listener(str(session.EventType.TEXT.value), self.on_receive) # type: ignore 368 | 369 | def close_private_listen(self): 370 | self.sess.close() 371 | _LOGGER.info("私聊侦听已关闭") 372 | 373 | # def get_cache(self): 374 | # with open(self.cache_path, 'r', 
encoding="utf-8") as f: 375 | # cache = json.loads(f.read()) 376 | # return cache 377 | # 378 | # def set_cache(self, data: dict, key: str): 379 | # if key not in self.video_cache: 380 | # self.video_cache[key] = {} 381 | # self.video_cache[key] = data 382 | # with open(self.cache_path, "w") as file: 383 | # file.write(json.dumps(self.video_cache, ensure_ascii=False, indent=4)) 384 | 385 | # def get_up_file(self): 386 | # with open(self.up_file_path, 'r', encoding="utf-8") as f: 387 | # up_list = json.loads(f.read()) 388 | # return up_list['all_area'] 389 | --------------------------------------------------------------------------------