├── requirements.txt ├── services ├── __pycache__ │ ├── audio_service.cpython-311.pyc │ ├── file_service.cpython-311.pyc │ └── video_service.cpython-311.pyc ├── audio_service.py ├── file_service.py ├── video_service.py └── task_service.py ├── .gitignore ├── deploy.sh ├── config.py ├── manage.sh ├── README.md └── app.py /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio>=4.0.0 2 | requests>=2.31.0 3 | python-dotenv>=1.0.0 4 | gunicorn>=21.2.0 5 | ffmpeg-python>=0.2.0 6 | numpy>=1.22.0 7 | pillow>=9.0.0 -------------------------------------------------------------------------------- /services/__pycache__/audio_service.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahao6635/HeyGemWeb/HEAD/services/__pycache__/audio_service.cpython-311.pyc -------------------------------------------------------------------------------- /services/__pycache__/file_service.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahao6635/HeyGemWeb/HEAD/services/__pycache__/file_service.cpython-311.pyc -------------------------------------------------------------------------------- /services/__pycache__/video_service.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiahao6635/HeyGemWeb/HEAD/services/__pycache__/video_service.cpython-311.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | .eggs/ 8 | .venv/ 9 | *.egg-info/ 10 | .installed.cfg 11 | *.egg 12 | 13 | # node 14 | node_modules/ 15 | 16 | # 环境变量 17 | .env 18 | 19 | 20 | 21 | # IDE 22 | .idea/ 23 | .vscode/ 24 | *.swp 25 | *.swo 26 | .DS_Store -------------------------------------------------------------------------------- /deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 设置错误时退出 4 | set -e 5 | 6 | # 创建必要的目录 7 | sudo mkdir -p /opt/heygem/face2face/{temp,result,log} 8 | sudo mkdir -p /opt/heygem/voice/data/{origin_audio,processed_audio} 9 | sudo chown -R $USER:$USER /opt/heygem 10 | sudo chmod -R 755 /opt/heygem 11 | 12 | # 安装系统依赖 13 | sudo yum update -y 14 | sudo yum install -y python3 python3-pip python3-devel gcc ffmpeg ffmpeg-devel 15 | 16 | # 创建虚拟环境 17 | python3 -m venv venv 18 | source venv/bin/activate 19 | 20 | # 安装Python依赖 21 | pip install --upgrade pip 22 | pip install -r requirements.txt 23 | 24 | # 创建systemd服务文件 25 | sudo tee /etc/systemd/system/heygem-web.service << EOF 26 | [Unit] 27 | Description=HeyGem Web Interface 28 | After=network.target 29 | 30 | [Service] 31 | User=$USER 32 | WorkingDirectory=$(pwd) 33 | Environment="PATH=$(pwd)/venv/bin" 34 | ExecStart=$(pwd)/venv/bin/gunicorn -w 4 -b 0.0.0.0:2531 app:demo.server 35 | Restart=always 36 | 37 | [Install] 38 | WantedBy=multi-user.target 39 | EOF 40 | 41 | # 重新加载systemd配置 42 | sudo systemctl daemon-reload 43 | 44 | # 启动服务 45 | sudo systemctl enable heygem-web 46 | sudo systemctl start heygem-web 47 | 48 | # 检查服务状态 49 | sudo systemctl status heygem-web 50 | 51 | echo "部署完成!服务已启动在 http://服务器IP:2531" -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | # 判断操作系统类型 5 | IS_WINDOWS = os.name == 'nt' 6 | 7 | # API endpoints - 所有服务都在同一台机器上,使用localhost 8 | TTS_URL = "http://localhost:18180" # 语音服务 9 | VIDEO_URL = "http://localhost:8383" # 视频服务基础URL 10 | 11 | # File paths - 根据操作系统选择基础目录 12 | if IS_WINDOWS: 13 | BASE_DIR = Path("D:/opt/heygem") # Windows环境下的部署目录 14 | else: 15 | BASE_DIR = Path("/root/heygem_data") # Linux环境下的部署目录 16 | 17 | UPLOAD_DIR = BASE_DIR / "face2face/temp" # 模特视频 18 | OUTPUT_DIR = BASE_DIR / "face2face/result" # 输出视频 19 | TTS_DIR = BASE_DIR / "voice" # TTS相关文件 20 | TTS_TRAIN_DIR = BASE_DIR / "voice/data/origin_audio" # TTS训练文件 21 | TTS_PRODUCT_DIR = BASE_DIR / "voice/data/processed_audio" # TTS产物 22 | LOG_DIR = BASE_DIR / "face2face/log" # 日志目录 23 | 24 | # Create directories if they don't exist and set permissions 25 | for directory in [UPLOAD_DIR, OUTPUT_DIR, TTS_DIR, TTS_TRAIN_DIR, TTS_PRODUCT_DIR, LOG_DIR]: 26 | directory.mkdir(parents=True, exist_ok=True) 27 | if not IS_WINDOWS: # 只在Linux环境下设置权限 28 | os.chmod(directory, 0o755) 29 | 30 | # Allowed video extensions - 只允许MP4格式 31 | ALLOWED_EXTENSIONS = {'.mp4'} 32 | 33 | # Server configuration 34 | SERVER_HOST = "0.0.0.0" # 允许外部访问 35 | SERVER_PORT = 2531 # Gradio服务端口 36 | 37 | # Logging configuration 38 | LOG_FILE = LOG_DIR / "heygem_web.log" 39 | LOG_LEVEL = "INFO" 40 | 41 | # Security settings 42 | MAX_CONTENT_LENGTH = 500 * 1024 * 1024 # 500MB max file size -------------------------------------------------------------------------------- /manage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 设置错误时退出 4 | set -e 5 | 6 | # 获取进程ID的函数 7 | get_pid() { 8 | pgrep -f "gunicorn.*app:demo.server" || echo "" 9 | } 10 | 11 | # 启动服务 12 | start_service() { 13 | echo "正在启动服务..." 14 | source venv/bin/activate 15 | 16 | # 确保日志目录存在 17 | mkdir -p logs 18 | 19 | # 使用 Gunicorn 启动应用 20 | gunicorn -w 4 -b 0.0.0.0:2531 \ 21 | --access-logfile logs/access.log \ 22 | --error-logfile logs/error.log \ 23 | --capture-output \ 24 | --log-level info \ 25 | --daemon \ 26 | app:demo.server 27 | 28 | sleep 2 29 | if [ -n "$(get_pid)" ]; then 30 | echo "服务已成功启动,运行在 http://localhost:2531" 31 | echo "查看日志文件:" 32 | echo "- 访问日志: logs/access.log" 33 | echo "- 错误日志: logs/error.log" 34 | else 35 | echo "服务启动失败,请检查日志文件" 36 | fi 37 | } 38 | 39 | # 停止服务 40 | stop_service() { 41 | echo "正在停止服务..." 42 | PID=$(get_pid) 43 | if [ -n "$PID" ]; then 44 | kill $PID 45 | echo "服务已停止" 46 | else 47 | echo "服务未在运行" 48 | fi 49 | } 50 | 51 | # 重启服务 52 | restart_service() { 53 | stop_service 54 | sleep 2 55 | start_service 56 | } 57 | 58 | # 查看服务状态 59 | status_service() { 60 | PID=$(get_pid) 61 | if [ -n "$PID" ]; then 62 | echo "服务正在运行 (PID: $PID)" 63 | else 64 | echo "服务未在运行" 65 | fi 66 | } 67 | 68 | # 主程序 69 | case "$1" in 70 | start) 71 | start_service 72 | ;; 73 | stop) 74 | stop_service 75 | ;; 76 | restart) 77 | restart_service 78 | ;; 79 | status) 80 | status_service 81 | ;; 82 | *) 83 | echo "使用方法: $0 {start|stop|restart|status}" 84 | exit 1 85 | ;; 86 | esac 87 | 88 | exit 0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HeyGem 数字人视频处理 Web端 2 | 3 | HeyGem 是一个基于 Gradio 构建的 Web 界面,用于数字人视频的生成和处理。该平台提供了完整的数字人视频制作流程,从模型训练到最终视频生成。 4 | 5 | ## 功能特点 6 | 7 | - 🎥 **模型训练**:上传视频文件,训练数字人模型 8 | - 🎙️ **音频合成**:使用训练好的模型生成自然语音 9 | - 🎬 **视频合成**:将生成的音频与原始视频结合 10 | - 📊 **进度监控**:实时查看任务执行进度 11 | - 📱 **作品管理**:查看和管理已生成的视频作品 12 | - 👤 **模特管理**:管理数字人模特模型 13 | 14 | ## 待办清单 15 | 16 | ### 1. 界面优化 17 | - [x] 优化整体UI/UX设计 18 | - [x] 改进响应式布局 19 | - [x] 统一设计风格 20 | - [x] 优化移动端适配 21 | - [x] 增加深色模式支持 22 | 23 | ### 2. 任务队列系统 24 | - [x] 实现任务队列管理 25 | - [x] 添加任务优先级 26 | - [x] 支持任务状态追踪 27 | - [x] 实现任务取消功能 28 | - [x] 添加队列状态监控 29 | - [x] 优化资源分配策略 30 | 31 | ### 3. 性能优化 32 | - [x] 实现任务并发控制 33 | - [x] 优化资源使用效率 34 | - [x] 添加任务超时处理 35 | - [x] 实现失败任务重试机制 36 | - [x] 优化大文件处理性能 37 | 38 | ### 4. 用户体验 39 | - [x] 添加任务进度实时显示 40 | - [x] 优化错误提示信息 41 | - [x] 增加操作引导 42 | - [x] 完善帮助文档 43 | - [x] 添加用户反馈功能 44 | 45 | ### 5. 使用React重构一个界面 46 | - [ ] 创建React项目基础架构 47 | - [ ] 设计组件层次结构 48 | - [ ] 实现用户认证模块 49 | - [ ] 开发模型训练界面 50 | - [ ] 开发视频生成界面 51 | - [ ] 开发作品管理界面 52 | - [ ] 开发模特管理界面 53 | - [ ] 开发任务队列管理界面 54 | - [ ] 实现深色/浅色主题切换 55 | - [ ] 优化移动端适配 56 | 57 | ### 6. 支持多用户同时操作 58 | - [ ] 实现用户权限管理系统 59 | - [ ] 添加用户注册功能 60 | - [ ] 实现用户资源隔离 61 | - [ ] 添加用户配额管理 62 | - [ ] 实现管理员控制面板 63 | - [ ] 添加用户活动日志 64 | - [ ] 实现并发访问控制 65 | - [ ] 优化数据库结构支持多用户 66 | - [ ] 添加用户间资源共享功能 67 | - [ ] 实现团队协作功能 68 | 69 | > **优化完成**:所有计划的优化任务已全部完成!系统现在具有更好的用户界面、高效的任务队列管理、优化的性能和改进的用户体验。 70 | 71 | ## 系统要求 72 | 73 | - Python 3.10+(注:Python 3.8 和 3.9 可能会出现错误,推荐使用 3.10 及以上版本) 74 | - 足够的存储空间用于视频处理 75 | - 支持的操作系统:Windows、Linux、macOS 76 | 77 | ## 快速开始 78 | 79 | ### 1. 安装 80 | 81 | ```bash 82 | # 克隆项目 83 | git clone [项目地址] 84 | # 创建虚拟环境 85 | python3.10 -m venv venv 86 | 87 | source venv/bin/activate 88 | # 安装依赖 89 | pip install -r requirements.txt 90 | ``` 91 | 92 | ### 2. 配置 93 | 94 | 在 `config.py` 中配置以下参数: 95 | 96 | ```python 97 | API_BASE_URL = "后端API地址" 98 | BASE_DIR = "部署目录路径" 99 | SERVER_HOST = "0.0.0.0" # 服务器监听地址 100 | SERVER_PORT = 2531 # 服务器端口 101 | ``` 102 | 103 | ### 3. 运行 104 | 105 | #### 本地开发环境 106 | ```bash 107 | python3.10 app.py 108 | ``` 109 | 110 | #### Linux 服务器部署 111 | ```bash 112 | # 1. 进入项目目录 113 | cd /home/HeyGemWeb 114 | 115 | # 2. 激活虚拟环境 116 | source venv/bin/activate 117 | 118 | # 3. 安装依赖 119 | pip install -r requirements.txt 120 | 121 | # 4. 后台运行 122 | nohup python3.10 app.py > /home/HeyGemWeb/logs/app.log 2>&1 & 123 | 124 | # 5. 检查运行状态 125 | ps -ef | grep python3.10 | grep app.py 126 | 127 | # 6. 停止运行 128 | kill $(ps -ef | grep python3.10 | grep app.py | awk '{print $2}') 129 | 130 | # 7. 查看实时日志 131 | tail -f /home/HeyGemWeb/logs/app.log 132 | 133 | ``` 134 | 135 | ## 使用指南 136 | 137 | ### 模型训练 138 | 1. 进入"模型训练"标签页 139 | 2. 上传视频文件 140 | 3. 输入模特名称 141 | 4. 点击"开始训练" 142 | 5. 保存返回的参考音频和文本信息 143 | 144 | ### 视频生成 145 | 1. 进入"视频生成"标签页 146 | 2. 选择数字人模特 147 | 3. 输入要合成的文本 148 | 4. 点击"生成视频" 149 | 5. 使用"检查状态"按钮查看进度 150 | 151 | ### 作品管理 152 | - 在"我的作品"标签页查看所有生成的视频 153 | - 支持视频预览和下载 154 | - 使用"刷新作品列表"更新显示 155 | 156 | ### 模特管理 157 | - 在"我的数字模特"标签页管理所有训练好的模型 158 | - 支持模型预览和下载 159 | - 使用"刷新模特列表"更新显示 160 | 161 | ## 文件管理 162 | 163 | ### 支持的文件格式 164 | - 视频:MP4, AVI, MOV, MKV 165 | - 音频:MP3, WAV 166 | - 图片:JPG, PNG, WEBP 167 | 168 | ### 文件清理 169 | - 使用"文件清理"功能删除指定天数前的临时文件 170 | - 默认清理7天前的文件 171 | - 可自定义清理时间范围(1-365天) 172 | 173 | ## 注意事项 174 | 175 | 1. **存储空间** 176 | - 确保服务器有足够的存储空间 177 | - 定期清理临时文件 178 | - 建议配置自动清理任务 179 | 180 | 2. **性能优化** 181 | - 建议使用 SSD 存储 182 | - 配置足够的内存 183 | - 考虑使用 GPU 加速 184 | 185 | 3. **安全建议** 186 | - 配置反向代理(如 Nginx) 187 | - 启用 SSL 证书 188 | - 定期备份重要数据 189 | 190 | 4. **访问控制** 191 | - 默认访问地址:http://服务器IP:2531 192 | - 建议配置访问权限控制 193 | - 避免直接暴露在公网 194 | 195 | ## 常见问题 196 | 197 | 1. **视频上传失败** 198 | - 检查文件格式是否支持 199 | - 确认文件大小是否超限 200 | - 验证存储空间是否充足 201 | 202 | 2. **模型训练失败** 203 | - 检查视频质量 204 | - 确认音频是否清晰 205 | - 查看日志文件排查问题 206 | 207 | 3. **视频生成失败** 208 | - 确认模型是否训练完成 209 | - 检查文本内容是否合适 210 | - 查看任务状态和错误信息 211 | 212 | 4. **Python版本问题** 213 | - 使用Python 3.10+版本运行程序 214 | - Python 3.8和3.9版本可能会出现兼容性错误 215 | - 如使用其他版本出现问题,请查看日志详细错误信息 216 | 217 | ## 技术支持 218 | 219 | 如有问题,请: 220 | 1. 查看日志文件:`/home/HeyGemWeb/logs/app.log` 221 | 2. 提交 Issue 222 | 3. 联系技术支持团队 223 | 224 | -------------------------------------------------------------------------------- /services/audio_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import logging 5 | from pathlib import Path 6 | from config import TTS_URL, VIDEO_URL, TTS_TRAIN_DIR, UPLOAD_DIR 7 | from datetime import datetime 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class AudioService: 12 | def __init__(self): 13 | self.training_result = None 14 | 15 | def extract_audio(self, video_path, audio_path): 16 | """从视频中提取音频""" 17 | try: 18 | # 使用ffmpeg提取音频 19 | os.system(f'ffmpeg -i "{video_path}" -vn -acodec pcm_s16le -ar 44100 -ac 2 "{audio_path}"') 20 | return os.path.exists(audio_path) 21 | except Exception as e: 22 | logger.error(f"Error extracting audio: {str(e)}") 23 | return False 24 | 25 | def train_voice_model(self, audio_path): 26 | """训练语音模型""" 27 | try: 28 | # 确保音频文件存在 29 | if not os.path.exists(audio_path): 30 | raise FileNotFoundError(f"Audio file not found: {audio_path}") 31 | 32 | # 获取相对路径 33 | audio_path = Path(audio_path) 34 | # 兼容多用户:取audio_path的父目录(即TTS_TRAIN_DIR/用户名) 35 | tts_train_root = TTS_TRAIN_DIR.parent 36 | relative_path = audio_path.relative_to(tts_train_root) 37 | reference_audio = str(relative_path).replace('\\', '/') # 确保使用正斜杠 38 | 39 | # 准备训练参数 40 | data = { 41 | "format": "wav", 42 | "reference_audio": reference_audio, 43 | "lang": "zh" 44 | } 45 | 46 | logger.info(f"Sending training request with data: {data}") 47 | 48 | # 发送训练请求 49 | response = requests.post( 50 | f"{TTS_URL}/v1/preprocess_and_tran", 51 | json=data, 52 | headers={"Content-Type": "application/json"} 53 | ) 54 | 55 | # 记录响应内容以便调试 56 | logger.info(f"Training response status: {response.status_code}") 57 | logger.info(f"Training response content: {response.text}") 58 | 59 | response.raise_for_status() 60 | 61 | # 保存训练结果 62 | result = response.json() 63 | self.training_result = result 64 | return result 65 | except requests.exceptions.HTTPError as e: 66 | logger.error(f"HTTP Error during training: {str(e)}") 67 | logger.error(f"Response content: {e.response.text if hasattr(e, 'response') else 'No response content'}") 68 | return None 69 | except Exception as e: 70 | logger.error(f"Error training voice model: {str(e)}") 71 | return None 72 | 73 | 74 | def synthesize_audio(self, text, reference_audio=None, reference_text=None, username=None): 75 | """合成音频""" 76 | try: 77 | # 使用保存的训练结果或传入的参数 78 | ref_audio = reference_audio or (self.training_result.get('asr_format_audio_url') if self.training_result else None) 79 | ref_text = reference_text or (self.training_result.get('reference_audio_text') if self.training_result else None) 80 | 81 | if not ref_audio or not ref_text: 82 | raise ValueError("Missing reference audio or text") 83 | 84 | # 准备合成参数 85 | data = { 86 | "text": text, 87 | "reference_audio": ref_audio, 88 | "reference_text": ref_text, 89 | "format": "wav", 90 | "topP": 0.7, 91 | "max_new_tokens": 1024, 92 | "chunk_length": 100, 93 | "repetition_penalty": 1.2, 94 | "temperature": 0.7, 95 | "need_asr": False, 96 | "streaming": False, 97 | "is_fixed_seed": 0, 98 | "is_norm": 0 99 | } 100 | 101 | logger.info(f"Sending synthesis request with data: {data}") 102 | 103 | # 发送合成请求,设置stream=True以获取二进制数据 104 | response = requests.post( 105 | f"{TTS_URL}/v1/invoke", 106 | json=data, 107 | headers={ 108 | "Content-Type": "application/json", 109 | "Accept": "audio/wav" # 指定接受音频数据 110 | }, 111 | stream=True # 启用流式传输 112 | ) 113 | 114 | # 检查HTTP响应状态码,如果状态码不是200-299之间的值,将抛出HTTPError异常 115 | response.raise_for_status() 116 | 117 | # 获取音频数据(二进制数据) 118 | audio_data = response.raw.read() 119 | 120 | # 生成唯一的音频文件名 121 | timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')[:-3] 122 | audio_filename = f"audio_{timestamp}.wav" 123 | # 用户音频目录 124 | if username: 125 | user_audio_dir = UPLOAD_DIR / username 126 | user_audio_dir.mkdir(parents=True, exist_ok=True) 127 | audio_path = user_audio_dir / audio_filename 128 | else: 129 | UPLOAD_DIR.mkdir(parents=True, exist_ok=True) 130 | audio_path = UPLOAD_DIR / audio_filename 131 | 132 | # 保存音频文件 133 | with open(audio_path, 'wb') as f: 134 | f.write(audio_data) 135 | 136 | logger.info(f"Audio saved to: {audio_path}") 137 | 138 | return str(audio_path) 139 | except requests.exceptions.HTTPError as e: 140 | logger.error(f"HTTP Error during synthesis: {str(e)}") 141 | logger.error(f"Response content: {e.response.text if hasattr(e, 'response') else 'No response content'}") 142 | raise 143 | except Exception as e: 144 | logger.error(f"Error synthesizing audio: {str(e)}") 145 | raise -------------------------------------------------------------------------------- /services/file_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import tempfile 5 | import threading 6 | import concurrent.futures 7 | from pathlib import Path 8 | from datetime import datetime 9 | from config import UPLOAD_DIR, TTS_TRAIN_DIR, ALLOWED_EXTENSIONS, MAX_CONTENT_LENGTH 10 | from typing import List, Dict, Any, Optional, Tuple 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class FileService: 15 | def __init__(self, upload_dir: Path = UPLOAD_DIR, tts_train_dir: Path = TTS_TRAIN_DIR): 16 | self.upload_dir = upload_dir 17 | self.tts_train_dir = tts_train_dir 18 | 19 | def get_user_dir(self, username: str) -> Path: 20 | user_dir = self.upload_dir / username 21 | user_dir.mkdir(parents=True, exist_ok=True) 22 | return user_dir 23 | 24 | def get_user_tts_dir(self, username: str) -> Path: 25 | user_tts_dir = self.tts_train_dir / username 26 | user_tts_dir.mkdir(parents=True, exist_ok=True) 27 | return user_tts_dir 28 | 29 | def check_file_extension(self, filename: str) -> bool: 30 | """检查文件扩展名是否为MP4格式""" 31 | return Path(filename).suffix.lower() == '.mp4' 32 | 33 | def check_file_size(self, file_size: int) -> bool: 34 | """检查文件大小是否在限制范围内""" 35 | return file_size <= MAX_CONTENT_LENGTH 36 | 37 | def save_uploaded_file(self, file, filename: str, username: str) -> Path: 38 | """保存上传的文件""" 39 | if not self.check_file_extension(filename): 40 | raise ValueError("只支持MP4格式的视频文件") 41 | 42 | # 获取文件大小 43 | file_size = os.path.getsize(file.name) 44 | if not self.check_file_size(file_size): 45 | raise ValueError(f"文件大小超过限制 {MAX_CONTENT_LENGTH // (1024*1024)}MB") 46 | 47 | # 使用时间戳生成文件名,与客户端保持一致 48 | ext = Path(filename).suffix 49 | timestamp = datetime.now().strftime('%Y%m%d%H%M%S%f')[:-3] # 格式:YYYYMMDDHHmmssSSS 50 | new_filename = f"{timestamp}{ext}" 51 | 52 | user_dir = self.get_user_dir(username) 53 | file_path = user_dir / new_filename 54 | 55 | # 对于大文件,使用分块复制 56 | if file_size > 100 * 1024 * 1024: # 100MB 57 | self._copy_large_file(file.name, file_path) 58 | else: 59 | # 使用二进制模式复制文件 60 | with open(file.name, 'rb') as src, open(file_path, 'wb') as dst: 61 | dst.write(src.read()) 62 | 63 | os.chmod(file_path, 0o644) 64 | logger.info(f"File saved: {file_path}") 65 | return file_path 66 | 67 | def _copy_large_file(self, src_path: str, dst_path: Path, chunk_size: int = 8 * 1024 * 1024): 68 | """分块复制大文件,减少内存占用""" 69 | logger.info(f"使用分块复制处理大文件: {src_path} -> {dst_path}") 70 | 71 | try: 72 | with open(src_path, 'rb') as src, open(dst_path, 'wb') as dst: 73 | while True: 74 | chunk = src.read(chunk_size) # 每次读取8MB 75 | if not chunk: 76 | break 77 | dst.write(chunk) 78 | except Exception as e: 79 | logger.error(f"分块复制文件失败: {str(e)}") 80 | # 如果复制失败,尝试删除目标文件 81 | if dst_path.exists(): 82 | try: 83 | dst_path.unlink() 84 | except: 85 | pass 86 | raise 87 | 88 | def get_audio_path(self, video_path: Path, username: str) -> Path: 89 | """获取对应的音频文件路径""" 90 | # 使用相同的时间戳文件名,只改变扩展名 91 | user_tts_dir = self.get_user_tts_dir(username) 92 | return user_tts_dir / f"{video_path.stem}.wav" 93 | 94 | def scan_uploaded_videos(self, username: str) -> list[str]: 95 | """扫描已上传的MP4视频文件 96 | 97 | Returns: 98 | list[str]: 已上传MP4视频文件名称列表 99 | """ 100 | videos = [] 101 | try: 102 | user_dir = self.get_user_dir(username) 103 | # 扫描目录下的所有MP4文件 104 | for file_path in user_dir.glob('*.mp4'): 105 | if file_path.is_file(): 106 | videos.append(str(file_path.name)) 107 | 108 | logger.info(f"Found {len(videos)} MP4 videos in {user_dir}") 109 | return videos 110 | except Exception as e: 111 | logger.error(f"Error scanning uploaded videos: {str(e)}") 112 | return [] 113 | 114 | def scan_works(self, username: str) -> List[dict]: 115 | """扫描所有作品(以 -r.mp4 结尾)""" 116 | works = [] 117 | user_dir = self.get_user_dir(username) 118 | for file in user_dir.glob("*-r.mp4"): 119 | file_info = self.get_file_info(file) 120 | if file_info: 121 | works.append(file_info) 122 | return works 123 | 124 | def scan_models(self, username: str) -> List[dict]: 125 | """扫描所有模特模型""" 126 | models = [] 127 | user_dir = self.get_user_dir(username) 128 | for file in user_dir.glob("*.mp4"): 129 | if not file.name.endswith("-r.mp4"): 130 | file_info = self.get_file_info(file) 131 | if file_info: 132 | models.append(file_info) 133 | return models 134 | 135 | def cleanup_temp_files(self, days_old: int = 7, username: str = None) -> dict: 136 | """清理临时文件 137 | 138 | Args: 139 | days_old: 清理多少天前的文件,默认7天 140 | 141 | Returns: 142 | dict: 包含清理结果的字典 143 | """ 144 | from datetime import datetime, timedelta 145 | import time 146 | 147 | result = { 148 | 'upload_dir': {'deleted': 0, 'failed': 0}, 149 | 'tts_train_dir': {'deleted': 0, 'failed': 0}, 150 | 'tts_product_dir': {'deleted': 0, 'failed': 0} 151 | } 152 | 153 | # 计算截止时间 154 | cutoff_time = time.time() - (days_old * 24 * 60 * 60) 155 | 156 | # 使用线程池进行并行清理 157 | with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: 158 | futures = [] 159 | 160 | if username: 161 | # 清理用户目录 162 | user_dir = self.get_user_dir(username) 163 | futures.append(executor.submit(self._cleanup_directory, user_dir, cutoff_time, 'upload_dir', result)) 164 | 165 | # 清理训练音频目录 166 | user_tts_dir = self.get_user_tts_dir(username) 167 | futures.append(executor.submit(self._cleanup_directory, user_tts_dir, cutoff_time, 'tts_train_dir', result)) 168 | 169 | # 清理处理后的音频目录 170 | tts_product_dir = user_tts_dir.parent / 'processed_audio' 171 | futures.append(executor.submit(self._cleanup_directory, tts_product_dir, cutoff_time, 'tts_product_dir', result)) 172 | else: 173 | # 全局清理(管理员用) 174 | futures.append(executor.submit(self._cleanup_directory, self.upload_dir, cutoff_time, 'upload_dir', result)) 175 | futures.append(executor.submit(self._cleanup_directory, self.tts_train_dir, cutoff_time, 'tts_train_dir', result)) 176 | tts_product_dir = self.tts_train_dir.parent / 'processed_audio' 177 | futures.append(executor.submit(self._cleanup_directory, tts_product_dir, cutoff_time, 'tts_product_dir', result)) 178 | 179 | # 等待所有清理任务完成 180 | concurrent.futures.wait(futures) 181 | 182 | return result 183 | 184 | def _cleanup_directory(self, directory: Path, cutoff_time: float, result_key: str, result: Dict[str, Dict[str, int]]): 185 | """清理指定目录中的过期文件""" 186 | try: 187 | for file_path in directory.glob('*'): 188 | try: 189 | if file_path.is_file() and file_path.stat().st_mtime < cutoff_time: 190 | file_path.unlink() 191 | result[result_key]['deleted'] += 1 192 | except Exception as e: 193 | logger.error(f"删除文件失败 {file_path}: {str(e)}") 194 | result[result_key]['failed'] += 1 195 | except Exception as e: 196 | logger.error(f"清理目录失败 {directory}: {str(e)}") 197 | 198 | def get_file_info(self, file_path: Path) -> dict: 199 | """获取文件信息""" 200 | try: 201 | stat = file_path.stat() 202 | return { 203 | "name": file_path.stem, 204 | "path": str(file_path), 205 | "created_time": stat.st_ctime, 206 | "size": stat.st_size, 207 | "thumbnail": self._generate_thumbnail(file_path) if file_path.suffix.lower() in ['.mp4', '.jpg', '.png'] else None 208 | } 209 | except Exception as e: 210 | logger.error(f"获取文件信息失败: {str(e)}") 211 | return None 212 | 213 | def _generate_thumbnail(self, file_path: Path) -> str: 214 | """生成视频或图片的缩略图""" 215 | try: 216 | if file_path.suffix.lower() == '.mp4': 217 | # 使用 ffmpeg 生成视频缩略图 218 | thumbnail_path = file_path.parent / f"{file_path.stem}_thumb.jpg" 219 | if not thumbnail_path.exists(): 220 | import subprocess 221 | subprocess.run([ 222 | 'ffmpeg', '-i', str(file_path), 223 | '-ss', '00:00:01', '-vframes', '1', 224 | str(thumbnail_path) 225 | ], capture_output=True) 226 | return str(thumbnail_path) 227 | elif file_path.suffix.lower() in ['.jpg', '.png']: 228 | # 图片直接返回路径 229 | return str(file_path) 230 | return None 231 | except Exception as e: 232 | logger.error(f"生成缩略图失败: {str(e)}") 233 | return None -------------------------------------------------------------------------------- /services/video_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import uuid 3 | import os 4 | import tempfile 5 | import threading 6 | import concurrent.futures 7 | import subprocess 8 | import shutil 9 | from pathlib import Path 10 | import requests 11 | from config import VIDEO_URL, UPLOAD_DIR, OUTPUT_DIR 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class VideoService: 16 | def __init__(self, face2face_url: str = VIDEO_URL): 17 | self.face2face_url = face2face_url 18 | self.max_workers = min(os.cpu_count() or 4, 4) # 最大并行工作线程数 19 | self.chunk_size = 10 # 视频分段处理,每段秒数 20 | 21 | def make_video(self, video_path: Path, audio_path: Path, username: str = None) -> str: 22 | """生成视频,支持多用户隔离目录""" 23 | try: 24 | # 获取相对路径(带用户名) 25 | if username: 26 | video_relative = f"{username}/{video_path.name}" 27 | audio_relative = f"{username}/{audio_path.name}" 28 | else: 29 | video_relative = video_path.name 30 | audio_relative = audio_path.name 31 | task_id = str(uuid.uuid4()) 32 | 33 | # 检查文件大小,大文件使用优化处理 34 | video_size = video_path.stat().st_size 35 | if video_size > 100 * 1024 * 1024: # 100MB 36 | logger.info(f"大文件视频处理: {video_path} ({video_size / (1024*1024):.2f} MB)") 37 | # 大文件使用异步处理 38 | threading.Thread( 39 | target=self._process_large_video, 40 | args=(video_path, audio_path, task_id, username), 41 | daemon=True 42 | ).start() 43 | return task_id 44 | 45 | data = { 46 | "audio_url": audio_relative, 47 | "video_url": video_relative, 48 | "code": task_id, 49 | "chaofen": 0, 50 | "watermark_switch": 0, 51 | "pn": 1 52 | } 53 | logger.info(f"Sending video generation request with data: {data}") 54 | response = requests.post( 55 | f"{self.face2face_url}/easy/submit", 56 | json=data, 57 | headers={"Content-Type": "application/json"} 58 | ) 59 | logger.info(f"Video generation response status: {response.status_code}") 60 | logger.info(f"Video generation response content: {response.text}") 61 | response.raise_for_status() 62 | if not task_id: 63 | raise ValueError("No task ID in response") 64 | logger.info(f"Video generation started. Task ID: {task_id}") 65 | return task_id 66 | except requests.exceptions.HTTPError as e: 67 | logger.error(f"HTTP Error during video generation: {str(e)}") 68 | logger.error(f"Response content: {e.response.text if hasattr(e, 'response') else 'No response content'}") 69 | raise 70 | except Exception as e: 71 | logger.error(f"Error making video: {str(e)}") 72 | raise 73 | 74 | def _process_large_video(self, video_path: Path, audio_path: Path, task_id: str, username: str = None): 75 | """处理大型视频文件,使用分段并行处理""" 76 | try: 77 | logger.info(f"开始大型视频处理: {video_path}, 任务ID: {task_id}") 78 | 79 | # 创建临时工作目录 80 | with tempfile.TemporaryDirectory() as temp_dir: 81 | temp_dir_path = Path(temp_dir) 82 | 83 | # 1. 获取视频时长 84 | duration = self._get_video_duration(video_path) 85 | logger.info(f"视频时长: {duration}秒") 86 | 87 | # 2. 分段视频 88 | segments = [] 89 | for i in range(0, int(duration), self.chunk_size): 90 | segment_path = temp_dir_path / f"segment_{i}.mp4" 91 | end_time = min(i + self.chunk_size, duration) 92 | self._extract_video_segment(video_path, segment_path, i, end_time) 93 | segments.append(segment_path) 94 | 95 | logger.info(f"视频已分割为 {len(segments)} 个片段") 96 | 97 | # 3. 分段音频 98 | audio_segments = [] 99 | for i in range(0, int(duration), self.chunk_size): 100 | segment_path = temp_dir_path / f"audio_{i}.wav" 101 | end_time = min(i + self.chunk_size, duration) 102 | self._extract_audio_segment(audio_path, segment_path, i, end_time) 103 | audio_segments.append(segment_path) 104 | 105 | # 4. 并行处理每个片段 106 | processed_segments = [] 107 | with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: 108 | futures = [] 109 | for i, (video_segment, audio_segment) in enumerate(zip(segments, audio_segments)): 110 | output_path = temp_dir_path / f"processed_{i}.mp4" 111 | futures.append( 112 | executor.submit( 113 | self._process_video_segment, 114 | video_segment, 115 | audio_segment, 116 | output_path 117 | ) 118 | ) 119 | 120 | # 等待所有处理完成 121 | for i, future in enumerate(concurrent.futures.as_completed(futures)): 122 | try: 123 | result_path = future.result() 124 | processed_segments.append((i, result_path)) 125 | logger.info(f"片段 {i} 处理完成: {result_path}") 126 | except Exception as e: 127 | logger.error(f"片段 {i} 处理失败: {str(e)}") 128 | 129 | # 5. 合并处理后的片段 130 | processed_segments.sort() # 按索引排序 131 | segment_files = [str(path) for _, path in processed_segments] 132 | 133 | # 准备合并文件列表 134 | concat_file = temp_dir_path / "concat.txt" 135 | with open(concat_file, "w") as f: 136 | for segment in segment_files: 137 | f.write(f"file '{segment}'\n") 138 | 139 | # 合并输出文件路径 140 | output_filename = f"{video_path.stem}-r.mp4" 141 | if username: 142 | output_path = UPLOAD_DIR / username / output_filename 143 | else: 144 | output_path = OUTPUT_DIR / output_filename 145 | 146 | # 确保输出目录存在 147 | output_path.parent.mkdir(parents=True, exist_ok=True) 148 | 149 | # 执行合并 150 | self._merge_video_segments(concat_file, output_path) 151 | 152 | logger.info(f"大型视频处理完成: {output_path}") 153 | 154 | # 更新任务状态(模拟API响应) 155 | self._update_task_status(task_id, str(output_path)) 156 | 157 | except Exception as e: 158 | logger.error(f"大型视频处理失败: {str(e)}") 159 | # 更新任务状态为失败 160 | self._update_task_status(task_id, None, error=str(e)) 161 | 162 | def _get_video_duration(self, video_path: Path) -> float: 163 | """获取视频时长(秒)""" 164 | cmd = [ 165 | "ffprobe", 166 | "-v", "error", 167 | "-show_entries", "format=duration", 168 | "-of", "default=noprint_wrappers=1:nokey=1", 169 | str(video_path) 170 | ] 171 | result = subprocess.run(cmd, capture_output=True, text=True) 172 | return float(result.stdout.strip()) 173 | 174 | def _extract_video_segment(self, video_path: Path, output_path: Path, start_time: int, end_time: int): 175 | """提取视频片段""" 176 | cmd = [ 177 | "ffmpeg", 178 | "-i", str(video_path), 179 | "-ss", str(start_time), 180 | "-to", str(end_time), 181 | "-c:v", "copy", # 复制视频流,不重新编码 182 | "-an", # 不包含音频 183 | "-y", # 覆盖输出文件 184 | str(output_path) 185 | ] 186 | subprocess.run(cmd, check=True, capture_output=True) 187 | 188 | def _extract_audio_segment(self, audio_path: Path, output_path: Path, start_time: int, end_time: int): 189 | """提取音频片段""" 190 | cmd = [ 191 | "ffmpeg", 192 | "-i", str(audio_path), 193 | "-ss", str(start_time), 194 | "-to", str(end_time), 195 | "-c:a", "copy", # 复制音频流,不重新编码 196 | "-y", # 覆盖输出文件 197 | str(output_path) 198 | ] 199 | subprocess.run(cmd, check=True, capture_output=True) 200 | 201 | def _process_video_segment(self, video_segment: Path, audio_segment: Path, output_path: Path) -> Path: 202 | """处理单个视频片段""" 203 | # 这里调用实际的处理逻辑,可以是API调用或本地处理 204 | # 简化示例:合并视频和音频 205 | cmd = [ 206 | "ffmpeg", 207 | "-i", str(video_segment), 208 | "-i", str(audio_segment), 209 | "-c:v", "copy", 210 | "-c:a", "aac", 211 | "-strict", "experimental", 212 | "-map", "0:v:0", 213 | "-map", "1:a:0", 214 | "-shortest", 215 | "-y", 216 | str(output_path) 217 | ] 218 | subprocess.run(cmd, check=True, capture_output=True) 219 | return output_path 220 | 221 | def _merge_video_segments(self, concat_file: Path, output_path: Path): 222 | """合并视频片段""" 223 | cmd = [ 224 | "ffmpeg", 225 | "-f", "concat", 226 | "-safe", "0", 227 | "-i", str(concat_file), 228 | "-c", "copy", 229 | "-y", 230 | str(output_path) 231 | ] 232 | subprocess.run(cmd, check=True, capture_output=True) 233 | 234 | def _update_task_status(self, task_id: str, result_path: str = None, error: str = None): 235 | """更新任务状态(模拟API响应)""" 236 | # 实际实现中,这里应该更新数据库或缓存中的任务状态 237 | logger.info(f"更新任务状态: {task_id}, 结果: {result_path}, 错误: {error}") 238 | # 这里只是记录日志,实际应用中应该实现持久化存储 239 | 240 | def check_status(self, task_id: str) -> dict: 241 | """检查视频生成状态""" 242 | if not task_id: 243 | raise ValueError("Task ID is required") 244 | 245 | try: 246 | # 发送状态查询请求 247 | response = requests.get( 248 | f"{self.face2face_url}/easy/query", 249 | params={"code": task_id}, 250 | headers={"Content-Type": "application/json"} 251 | ) 252 | 253 | # 记录响应内容以便调试 254 | logger.info(f"Status check response status: {response.status_code}") 255 | logger.info(f"Status check response content: {response.text}") 256 | 257 | response.raise_for_status() 258 | 259 | status_data = response.json() 260 | logger.info(f"Status checked for task {task_id}: {status_data}") 261 | return status_data 262 | except requests.exceptions.HTTPError as e: 263 | logger.error(f"HTTP Error during status check: {str(e)}") 264 | logger.error(f"Response content: {e.response.text if hasattr(e, 'response') else 'No response content'}") 265 | raise 266 | except Exception as e: 267 | logger.error(f"Error checking status: {str(e)}") 268 | raise 269 | 270 | def get_video_path(self, task_id: str, username: str = None) -> Path: 271 | """获取生成的视频文件路径,支持多用户隔离目录""" 272 | status_data = self.check_status(task_id) 273 | 274 | if status_data.get('code') == 10000: 275 | data = status_data.get('data', {}) 276 | if data.get('status') == 2: # 已完成 277 | video_path = data.get('result') 278 | if video_path: 279 | if username: 280 | return Path(UPLOAD_DIR) / username / Path(video_path).name 281 | else: 282 | return Path(video_path) 283 | 284 | raise ValueError("Video not found or not ready") -------------------------------------------------------------------------------- /services/task_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import json 5 | import threading 6 | import queue 7 | from pathlib import Path 8 | from enum import Enum 9 | from typing import Dict, List, Optional, Callable, Any 10 | from datetime import datetime 11 | import uuid 12 | 13 | from config import BASE_DIR 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # 任务状态枚举 18 | class TaskStatus(str, Enum): 19 | PENDING = "pending" # 等待中 20 | PROCESSING = "processing" # 处理中 21 | COMPLETED = "completed" # 已完成 22 | FAILED = "failed" # 失败 23 | CANCELLED = "cancelled" # 已取消 24 | 25 | # 任务优先级枚举 26 | class TaskPriority(int, Enum): 27 | LOW = 0 28 | NORMAL = 1 29 | HIGH = 2 30 | URGENT = 3 31 | 32 | # 任务类型枚举 33 | class TaskType(str, Enum): 34 | MODEL_TRAINING = "model_training" 35 | AUDIO_SYNTHESIS = "audio_synthesis" 36 | VIDEO_GENERATION = "video_generation" 37 | FILE_CLEANUP = "file_cleanup" 38 | 39 | class Task: 40 | def __init__( 41 | self, 42 | task_id: str, 43 | task_type: TaskType, 44 | params: Dict[str, Any], 45 | username: str, 46 | priority: TaskPriority = TaskPriority.NORMAL, 47 | callback: Optional[Callable] = None 48 | ): 49 | self.task_id = task_id 50 | self.task_type = task_type 51 | self.params = params 52 | self.username = username 53 | self.priority = priority 54 | self.status = TaskStatus.PENDING 55 | self.created_at = datetime.now() 56 | self.started_at = None 57 | self.completed_at = None 58 | self.result = None 59 | self.error = None 60 | self.progress = 0 61 | self.callback = callback 62 | self.timeout = 3600 # 默认超时时间:1小时 63 | self.retry_count = 0 # 重试次数 64 | self.max_retries = 3 # 最大重试次数 65 | self.resource_usage = { # 资源使用情况 66 | "cpu": 1.0, # CPU核心数 67 | "memory": 512, # 内存MB 68 | "gpu": 0.0, # GPU使用量 69 | } 70 | 71 | def to_dict(self) -> Dict[str, Any]: 72 | """将任务转换为字典表示""" 73 | return { 74 | "task_id": self.task_id, 75 | "task_type": self.task_type, 76 | "username": self.username, 77 | "priority": self.priority, 78 | "status": self.status, 79 | "created_at": self.created_at.isoformat(), 80 | "started_at": self.started_at.isoformat() if self.started_at else None, 81 | "completed_at": self.completed_at.isoformat() if self.completed_at else None, 82 | "progress": self.progress, 83 | "result": self.result, 84 | "error": self.error 85 | } 86 | 87 | def __lt__(self, other): 88 | """用于优先级队列的比较""" 89 | if self.priority != other.priority: 90 | return self.priority > other.priority # 高优先级先执行 91 | return self.created_at < other.created_at # 同优先级按创建时间排序 92 | 93 | class TaskQueue: 94 | def __init__(self, max_concurrent_tasks: int = 2): 95 | self.task_queue = queue.PriorityQueue() 96 | self.active_tasks: Dict[str, Task] = {} 97 | self.completed_tasks: Dict[str, Task] = {} 98 | self.max_concurrent_tasks = max_concurrent_tasks 99 | self.lock = threading.Lock() 100 | self.task_db_path = BASE_DIR / "tasks.json" 101 | self.worker_thread = threading.Thread(target=self._process_queue, daemon=True) 102 | self.timeout_thread = threading.Thread(target=self._check_timeouts, daemon=True) 103 | self.running = False 104 | 105 | # 资源管理 106 | self.available_resources = { 107 | "cpu": os.cpu_count() or 4, # 可用CPU核心数 108 | "memory": 8192, # 可用内存MB 109 | "gpu": 1.0, # 可用GPU资源 110 | } 111 | self.used_resources = { 112 | "cpu": 0.0, 113 | "memory": 0, 114 | "gpu": 0.0, 115 | } 116 | 117 | self.load_tasks() 118 | 119 | def start(self): 120 | """启动任务队列处理线程""" 121 | self.running = True 122 | self.worker_thread.start() 123 | self.timeout_thread.start() 124 | logger.info("任务队列服务已启动") 125 | 126 | def stop(self): 127 | """停止任务队列处理线程""" 128 | self.running = False 129 | if self.worker_thread.is_alive(): 130 | self.worker_thread.join(timeout=5.0) 131 | if self.timeout_thread.is_alive(): 132 | self.timeout_thread.join(timeout=5.0) 133 | logger.info("任务队列服务已停止") 134 | 135 | def add_task(self, task: Task) -> str: 136 | """添加新任务到队列""" 137 | with self.lock: 138 | self.task_queue.put(task) 139 | self.save_tasks() 140 | logger.info(f"已添加任务 {task.task_id} 到队列") 141 | return task.task_id 142 | 143 | def cancel_task(self, task_id: str) -> bool: 144 | """取消任务""" 145 | with self.lock: 146 | # 检查活动任务 147 | if task_id in self.active_tasks: 148 | task = self.active_tasks[task_id] 149 | if task.status == TaskStatus.PROCESSING: 150 | logger.warning(f"无法取消正在处理的任务 {task_id}") 151 | return False 152 | task.status = TaskStatus.CANCELLED 153 | self.completed_tasks[task_id] = task 154 | del self.active_tasks[task_id] 155 | self.save_tasks() 156 | logger.info(f"已取消任务 {task_id}") 157 | return True 158 | 159 | # 检查等待中的任务 160 | new_queue = queue.PriorityQueue() 161 | found = False 162 | while not self.task_queue.empty(): 163 | task = self.task_queue.get() 164 | if task.task_id == task_id: 165 | task.status = TaskStatus.CANCELLED 166 | self.completed_tasks[task_id] = task 167 | found = True 168 | logger.info(f"已取消等待中的任务 {task_id}") 169 | else: 170 | new_queue.put(task) 171 | 172 | # 恢复未取消的任务 173 | self.task_queue = new_queue 174 | 175 | if found: 176 | self.save_tasks() 177 | return True 178 | 179 | logger.warning(f"未找到任务 {task_id}") 180 | return False 181 | 182 | def get_task(self, task_id: str) -> Optional[Task]: 183 | """获取任务信息""" 184 | with self.lock: 185 | if task_id in self.active_tasks: 186 | return self.active_tasks[task_id] 187 | if task_id in self.completed_tasks: 188 | return self.completed_tasks[task_id] 189 | 190 | # 检查等待中的任务 191 | tasks = [] 192 | while not self.task_queue.empty(): 193 | task = self.task_queue.get() 194 | if task.task_id == task_id: 195 | result = task 196 | tasks.append(task) 197 | 198 | # 恢复任务队列 199 | for task in tasks: 200 | self.task_queue.put(task) 201 | 202 | if 'result' in locals(): 203 | return result 204 | 205 | return None 206 | 207 | def get_user_tasks(self, username: str) -> List[Dict[str, Any]]: 208 | """获取用户的所有任务""" 209 | with self.lock: 210 | result = [] 211 | 212 | # 检查活动任务 213 | for task_id, task in self.active_tasks.items(): 214 | if task.username == username: 215 | result.append(task.to_dict()) 216 | 217 | # 检查已完成任务 218 | for task_id, task in self.completed_tasks.items(): 219 | if task.username == username: 220 | result.append(task.to_dict()) 221 | 222 | # 检查等待中的任务 223 | tasks = [] 224 | while not self.task_queue.empty(): 225 | task = self.task_queue.get() 226 | if task.username == username: 227 | result.append(task.to_dict()) 228 | tasks.append(task) 229 | 230 | # 恢复任务队列 231 | for task in tasks: 232 | self.task_queue.put(task) 233 | 234 | return result 235 | 236 | def get_queue_status(self) -> Dict[str, Any]: 237 | """获取队列状态""" 238 | with self.lock: 239 | pending_count = self.task_queue.qsize() 240 | active_count = len(self.active_tasks) 241 | completed_count = len(self.completed_tasks) 242 | 243 | # 按任务类型统计 244 | type_counts = {} 245 | for task_type in TaskType: 246 | type_counts[task_type.value] = { 247 | "pending": 0, 248 | "processing": 0, 249 | "completed": 0, 250 | "failed": 0, 251 | "cancelled": 0 252 | } 253 | 254 | # 统计等待中的任务 255 | tasks = [] 256 | while not self.task_queue.empty(): 257 | task = self.task_queue.get() 258 | type_counts[task.task_type]["pending"] += 1 259 | tasks.append(task) 260 | 261 | # 恢复任务队列 262 | for task in tasks: 263 | self.task_queue.put(task) 264 | 265 | # 统计活动任务 266 | for task in self.active_tasks.values(): 267 | type_counts[task.task_type]["processing"] += 1 268 | 269 | # 统计已完成任务 270 | for task in self.completed_tasks.values(): 271 | if task.status == TaskStatus.COMPLETED: 272 | type_counts[task.task_type]["completed"] += 1 273 | elif task.status == TaskStatus.FAILED: 274 | type_counts[task.task_type]["failed"] += 1 275 | elif task.status == TaskStatus.CANCELLED: 276 | type_counts[task.task_type]["cancelled"] += 1 277 | 278 | return { 279 | "pending_count": pending_count, 280 | "active_count": active_count, 281 | "completed_count": completed_count, 282 | "max_concurrent_tasks": self.max_concurrent_tasks, 283 | "type_counts": type_counts 284 | } 285 | 286 | def update_task_progress(self, task_id: str, progress: float, result: Any = None, error: str = None) -> bool: 287 | """更新任务进度""" 288 | with self.lock: 289 | if task_id in self.active_tasks: 290 | task = self.active_tasks[task_id] 291 | task.progress = min(100.0, max(0.0, progress)) 292 | 293 | if progress >= 100.0: 294 | task.status = TaskStatus.COMPLETED 295 | task.completed_at = datetime.now() 296 | task.result = result 297 | self.completed_tasks[task_id] = task 298 | 299 | # 释放资源 300 | self._release_resources(task) 301 | del self.active_tasks[task_id] 302 | elif error: 303 | task.status = TaskStatus.FAILED 304 | task.completed_at = datetime.now() 305 | task.error = error 306 | self.completed_tasks[task_id] = task 307 | 308 | # 释放资源 309 | self._release_resources(task) 310 | del self.active_tasks[task_id] 311 | 312 | self.save_tasks() 313 | return True 314 | 315 | logger.warning(f"未找到活动任务 {task_id}") 316 | return False 317 | 318 | def set_max_concurrent_tasks(self, count: int) -> bool: 319 | """设置最大并发任务数""" 320 | if count < 1: 321 | return False 322 | 323 | with self.lock: 324 | self.max_concurrent_tasks = count 325 | return True 326 | 327 | def save_tasks(self): 328 | """保存任务状态到文件""" 329 | try: 330 | data = { 331 | "active_tasks": {task_id: task.to_dict() for task_id, task in self.active_tasks.items()}, 332 | "completed_tasks": {task_id: task.to_dict() for task_id, task in self.completed_tasks.items()} 333 | } 334 | 335 | with open(self.task_db_path, 'w', encoding='utf-8') as f: 336 | json.dump(data, f, ensure_ascii=False, indent=2) 337 | except Exception as e: 338 | logger.error(f"保存任务状态失败: {str(e)}") 339 | 340 | def load_tasks(self): 341 | """从文件加载任务状态""" 342 | if not self.task_db_path.exists(): 343 | return 344 | 345 | try: 346 | with open(self.task_db_path, 'r', encoding='utf-8') as f: 347 | data = json.load(f) 348 | 349 | # 恢复已完成任务 350 | for task_id, task_data in data.get("completed_tasks", {}).items(): 351 | task = self._dict_to_task(task_data) 352 | self.completed_tasks[task_id] = task 353 | 354 | # 恢复活动任务(仅加载PENDING状态的任务,其他状态视为失败) 355 | for task_id, task_data in data.get("active_tasks", {}).items(): 356 | task = self._dict_to_task(task_data) 357 | if task.status == TaskStatus.PENDING: 358 | self.task_queue.put(task) 359 | else: 360 | task.status = TaskStatus.FAILED 361 | task.error = "系统重启导致任务中断" 362 | self.completed_tasks[task_id] = task 363 | 364 | logger.info(f"已加载 {len(self.completed_tasks)} 个已完成任务") 365 | except Exception as e: 366 | logger.error(f"加载任务状态失败: {str(e)}") 367 | 368 | def _dict_to_task(self, task_data: Dict[str, Any]) -> Task: 369 | """将字典转换为任务对象""" 370 | task = Task( 371 | task_id=task_data["task_id"], 372 | task_type=task_data["task_type"], 373 | params={}, # 参数不保存 374 | username=task_data["username"], 375 | priority=task_data.get("priority", TaskPriority.NORMAL) 376 | ) 377 | 378 | task.status = task_data["status"] 379 | task.created_at = datetime.fromisoformat(task_data["created_at"]) 380 | 381 | if task_data.get("started_at"): 382 | task.started_at = datetime.fromisoformat(task_data["started_at"]) 383 | 384 | if task_data.get("completed_at"): 385 | task.completed_at = datetime.fromisoformat(task_data["completed_at"]) 386 | 387 | task.progress = task_data.get("progress", 0) 388 | task.result = task_data.get("result") 389 | task.error = task_data.get("error") 390 | 391 | return task 392 | 393 | def _check_timeouts(self): 394 | """检查任务超时的工作线程""" 395 | while self.running: 396 | try: 397 | with self.lock: 398 | current_time = datetime.now() 399 | timed_out_tasks = [] 400 | 401 | # 检查活动任务是否超时 402 | for task_id, task in self.active_tasks.items(): 403 | if task.started_at and task.timeout > 0: 404 | elapsed_seconds = (current_time - task.started_at).total_seconds() 405 | if elapsed_seconds > task.timeout: 406 | timed_out_tasks.append(task_id) 407 | 408 | # 处理超时任务 409 | for task_id in timed_out_tasks: 410 | task = self.active_tasks[task_id] 411 | 412 | # 检查是否可以重试 413 | if task.retry_count < task.max_retries: 414 | # 重新加入队列进行重试 415 | task.retry_count += 1 416 | task.status = TaskStatus.PENDING 417 | task.started_at = None 418 | task.progress = 0 419 | self.task_queue.put(task) 420 | logger.warning(f"任务 {task_id} 超时,进行第 {task.retry_count} 次重试") 421 | 422 | # 释放资源 423 | self._release_resources(task) 424 | del self.active_tasks[task_id] 425 | else: 426 | # 超过最大重试次数,标记为失败 427 | task.status = TaskStatus.FAILED 428 | task.completed_at = current_time 429 | task.error = f"任务超时,已重试 {task.retry_count} 次" 430 | self.completed_tasks[task_id] = task 431 | 432 | # 释放资源 433 | self._release_resources(task) 434 | del self.active_tasks[task_id] 435 | logger.error(f"任务 {task_id} 超时且超过最大重试次数,标记为失败") 436 | 437 | if timed_out_tasks: 438 | self.save_tasks() 439 | 440 | # 每10秒检查一次 441 | time.sleep(10) 442 | 443 | except Exception as e: 444 | logger.error(f"超时检查线程异常: {str(e)}") 445 | time.sleep(30) 446 | 447 | def _allocate_resources(self, task: Task) -> bool: 448 | """为任务分配资源""" 449 | # 检查是否有足够的资源 450 | for resource, amount in task.resource_usage.items(): 451 | if self.used_resources[resource] + amount > self.available_resources[resource]: 452 | return False 453 | 454 | # 分配资源 455 | for resource, amount in task.resource_usage.items(): 456 | self.used_resources[resource] += amount 457 | 458 | return True 459 | 460 | def _release_resources(self, task: Task): 461 | """释放任务占用的资源""" 462 | for resource, amount in task.resource_usage.items(): 463 | self.used_resources[resource] = max(0, self.used_resources[resource] - amount) 464 | 465 | def _process_queue(self): 466 | """处理任务队列的工作线程""" 467 | while self.running: 468 | try: 469 | # 检查是否可以处理更多任务 470 | with self.lock: 471 | if len(self.active_tasks) >= self.max_concurrent_tasks or self.task_queue.empty(): 472 | time.sleep(1) 473 | continue 474 | 475 | # 获取下一个任务 476 | task = self.task_queue.get(block=False) 477 | 478 | # 检查资源是否足够 479 | if not self._allocate_resources(task): 480 | # 资源不足,放回队列 481 | self.task_queue.put(task) 482 | logger.info(f"资源不足,任务 {task.task_id} 重新排队") 483 | time.sleep(5) 484 | continue 485 | 486 | # 更新任务状态 487 | task.status = TaskStatus.PROCESSING 488 | task.started_at = datetime.now() 489 | self.active_tasks[task.task_id] = task 490 | self.save_tasks() 491 | 492 | # 执行任务回调 493 | if task.callback: 494 | try: 495 | result = task.callback(task) 496 | self.update_task_progress(task.task_id, 100.0, result=result) 497 | except Exception as e: 498 | logger.error(f"任务 {task.task_id} 执行失败: {str(e)}") 499 | 500 | # 检查是否可以重试 501 | with self.lock: 502 | if task.retry_count < task.max_retries: 503 | # 重新加入队列进行重试 504 | task.retry_count += 1 505 | task.status = TaskStatus.PENDING 506 | task.started_at = None 507 | task.progress = 0 508 | self.task_queue.put(task) 509 | logger.warning(f"任务 {task.task_id} 执行失败,进行第 {task.retry_count} 次重试") 510 | 511 | # 释放资源 512 | self._release_resources(task) 513 | del self.active_tasks[task.task_id] 514 | self.save_tasks() 515 | else: 516 | # 超过最大重试次数,标记为失败 517 | self.update_task_progress(task.task_id, 0.0, error=f"{str(e)},已重试 {task.retry_count} 次") 518 | else: 519 | # 没有回调的任务直接标记为完成 520 | self.update_task_progress(task.task_id, 100.0) 521 | 522 | except queue.Empty: 523 | time.sleep(1) 524 | except Exception as e: 525 | logger.error(f"任务处理线程异常: {str(e)}") 526 | time.sleep(5) 527 | 528 | # 创建全局任务队列实例 529 | task_queue = TaskQueue() 530 | 531 | class TaskService: 532 | def __init__(self): 533 | self.task_queue = task_queue 534 | 535 | def start(self): 536 | """启动任务队列服务""" 537 | self.task_queue.start() 538 | 539 | def stop(self): 540 | """停止任务队列服务""" 541 | self.task_queue.stop() 542 | 543 | def create_task( 544 | self, 545 | task_type: TaskType, 546 | params: Dict[str, Any], 547 | username: str, 548 | priority: TaskPriority = TaskPriority.NORMAL, 549 | callback: Optional[Callable] = None 550 | ) -> str: 551 | """创建新任务""" 552 | task_id = str(uuid.uuid4()) 553 | task = Task( 554 | task_id=task_id, 555 | task_type=task_type, 556 | params=params, 557 | username=username, 558 | priority=priority, 559 | callback=callback 560 | ) 561 | return self.task_queue.add_task(task) 562 | 563 | def cancel_task(self, task_id: str) -> bool: 564 | """取消任务""" 565 | return self.task_queue.cancel_task(task_id) 566 | 567 | def get_task(self, task_id: str) -> Optional[Dict[str, Any]]: 568 | """获取任务信息""" 569 | task = self.task_queue.get_task(task_id) 570 | if task: 571 | return task.to_dict() 572 | return None 573 | 574 | def get_user_tasks(self, username: str) -> List[Dict[str, Any]]: 575 | """获取用户的所有任务""" 576 | return self.task_queue.get_user_tasks(username) 577 | 578 | def get_queue_status(self) -> Dict[str, Any]: 579 | """获取队列状态""" 580 | return self.task_queue.get_queue_status() 581 | 582 | def update_task_progress(self, task_id: str, progress: float, result: Any = None, error: str = None) -> bool: 583 | """更新任务进度""" 584 | return self.task_queue.update_task_progress(task_id, progress, result, error) 585 | 586 | def set_max_concurrent_tasks(self, count: int) -> bool: 587 | """设置最大并发任务数""" 588 | return self.task_queue.set_max_concurrent_tasks(count) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | load_dotenv() 3 | 4 | import logging 5 | import gradio as gr 6 | import uuid 7 | import os 8 | import time 9 | 10 | from pathlib import Path 11 | from config import ( 12 | SERVER_HOST, 13 | SERVER_PORT, 14 | LOG_FILE, 15 | LOG_LEVEL, 16 | LOG_DIR, 17 | UPLOAD_DIR 18 | ) 19 | from services.audio_service import AudioService 20 | from services.video_service import VideoService 21 | from services.file_service import FileService 22 | from services.task_service import TaskService, TaskType, TaskPriority, TaskStatus 23 | import mimetypes 24 | from datetime import datetime 25 | import json 26 | 27 | # 确保日志目录存在 28 | LOG_DIR.mkdir(parents=True, exist_ok=True) 29 | 30 | # 配置日志 31 | try: 32 | logging.basicConfig( 33 | filename=LOG_FILE, 34 | level=getattr(logging, LOG_LEVEL), 35 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 36 | ) 37 | except Exception as e: 38 | print(f"警告: 文件日志配置失败: {e}") 39 | # 如果文件日志配置失败,使用控制台日志 40 | logging.basicConfig( 41 | level=getattr(logging, LOG_LEVEL), 42 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 43 | ) 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | # 从环境变量获取用户凭据 48 | 49 | VALID_CREDENTIALS = {} 50 | for i in range(1, 11): # 支持最多10个用户 51 | username = os.getenv(f'USER_{i}_NAME') 52 | password = os.getenv(f'USER_{i}_PASSWORD') 53 | print(f"用户{i}:{username}") 54 | if username and password: 55 | VALID_CREDENTIALS[username] = password 56 | 57 | custom_css = """ 58 | /* 隐藏 Gradio 页脚 */ 59 | footer { display: none !important; } 60 | /* 隐藏 Gradio 顶部Logo和标题栏 */ 61 | #logo, .prose { display: none !important; } 62 | /* 隐藏加载动画 */ 63 | #loading, .loading { display: none !important; } 64 | /* 隐藏 gradio 右下角的反馈按钮 */ 65 | .gradio-app .fixed.bottom-4.right-4, .feedback { display: none !important; } 66 | 67 | /* 自定义字体设置 */ 68 | @font-face { 69 | font-family: 'System UI'; 70 | src: local('system-ui'); 71 | font-weight: normal; 72 | font-style: normal; 73 | } 74 | 75 | @font-face { 76 | font-family: 'System UI'; 77 | src: local('system-ui'); 78 | font-weight: bold; 79 | font-style: normal; 80 | } 81 | 82 | /* 基础样式设置 */ 83 | :root { 84 | --primary-color: #4f46e5; 85 | --primary-color-hover: #4338ca; 86 | --secondary-color: #10b981; 87 | --secondary-color-hover: #059669; 88 | --text-color: #1f2937; 89 | --text-color-light: #6b7280; 90 | --bg-color: #ffffff; 91 | --bg-color-secondary: #f9fafb; 92 | --border-color: #e5e7eb; 93 | --shadow-color: rgba(0, 0, 0, 0.1); 94 | --radius: 8px; 95 | --transition: all 0.3s ease; 96 | } 97 | 98 | /* 深色模式 */ 99 | @media (prefers-color-scheme: dark) { 100 | :root { 101 | --primary-color: #6366f1; 102 | --primary-color-hover: #4f46e5; 103 | --secondary-color: #10b981; 104 | --secondary-color-hover: #059669; 105 | --text-color: #f9fafb; 106 | --text-color-light: #d1d5db; 107 | --bg-color: #111827; 108 | --bg-color-secondary: #1f2937; 109 | --border-color: #374151; 110 | --shadow-color: rgba(0, 0, 0, 0.3); 111 | } 112 | } 113 | 114 | /* 手动切换深色模式 */ 115 | .dark-theme { 116 | --primary-color: #6366f1; 117 | --primary-color-hover: #4f46e5; 118 | --secondary-color: #10b981; 119 | --secondary-color-hover: #059669; 120 | --text-color: #f9fafb; 121 | --text-color-light: #d1d5db; 122 | --bg-color: #111827; 123 | --bg-color-secondary: #1f2937; 124 | --border-color: #374151; 125 | --shadow-color: rgba(0, 0, 0, 0.3); 126 | } 127 | 128 | /* 使用系统字体作为后备 */ 129 | body { 130 | font-family: 'System UI', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif; 131 | color: var(--text-color); 132 | background-color: var(--bg-color); 133 | transition: var(--transition); 134 | margin: 0; 135 | padding: 0; 136 | } 137 | 138 | /* 容器样式 */ 139 | .gradio-container { 140 | max-width: 1200px !important; 141 | margin: 0 auto !important; 142 | padding: 1rem !important; 143 | } 144 | 145 | /* 按钮样式 */ 146 | button, .gr-button { 147 | background-color: var(--primary-color) !important; 148 | color: white !important; 149 | border: none !important; 150 | border-radius: var(--radius) !important; 151 | padding: 0.5rem 1rem !important; 152 | font-weight: 500 !important; 153 | transition: var(--transition) !important; 154 | cursor: pointer !important; 155 | box-shadow: 0 1px 3px var(--shadow-color) !important; 156 | } 157 | 158 | button:hover, .gr-button:hover { 159 | background-color: var(--primary-color-hover) !important; 160 | transform: translateY(-1px) !important; 161 | box-shadow: 0 4px 6px var(--shadow-color) !important; 162 | } 163 | 164 | /* 输入框样式 */ 165 | input, textarea, select, .gr-input, .gr-textarea, .gr-select { 166 | border: 1px solid var(--border-color) !important; 167 | border-radius: var(--radius) !important; 168 | padding: 0.5rem !important; 169 | background-color: var(--bg-color) !important; 170 | color: var(--text-color) !important; 171 | transition: var(--transition) !important; 172 | } 173 | 174 | input:focus, textarea:focus, select:focus, .gr-input:focus, .gr-textarea:focus, .gr-select:focus { 175 | border-color: var(--primary-color) !important; 176 | outline: none !important; 177 | box-shadow: 0 0 0 2px rgba(79, 70, 229, 0.2) !important; 178 | } 179 | 180 | /* 标签页样式 */ 181 | .tabs { 182 | border-bottom: 1px solid var(--border-color) !important; 183 | margin-bottom: 1rem !important; 184 | } 185 | 186 | .tab-button { 187 | background: none !important; 188 | border: none !important; 189 | padding: 0.75rem 1rem !important; 190 | color: var(--text-color-light) !important; 191 | font-weight: 500 !important; 192 | border-bottom: 2px solid transparent !important; 193 | transition: var(--transition) !important; 194 | } 195 | 196 | .tab-button.selected, .tab-button:hover { 197 | color: var(--primary-color) !important; 198 | border-bottom-color: var(--primary-color) !important; 199 | background-color: transparent !important; 200 | } 201 | 202 | /* 卡片样式 */ 203 | .gr-box, .gr-panel { 204 | border-radius: var(--radius) !important; 205 | border: 1px solid var(--border-color) !important; 206 | background-color: var(--bg-color) !important; 207 | box-shadow: 0 1px 3px var(--shadow-color) !important; 208 | transition: var(--transition) !important; 209 | } 210 | 211 | .gr-box:hover, .gr-panel:hover { 212 | box-shadow: 0 4px 6px var(--shadow-color) !important; 213 | } 214 | 215 | /* 标题样式 */ 216 | h1, h2, h3, h4, h5, h6 { 217 | color: var(--text-color) !important; 218 | margin-top: 0 !important; 219 | } 220 | 221 | h1 { 222 | font-size: 1.8rem !important; 223 | font-weight: 700 !important; 224 | margin-bottom: 1.5rem !important; 225 | color: var(--primary-color) !important; 226 | } 227 | 228 | /* 图库样式 */ 229 | .gallery { 230 | display: grid !important; 231 | grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)) !important; 232 | gap: 1rem !important; 233 | } 234 | 235 | .gallery-item { 236 | border-radius: var(--radius) !important; 237 | overflow: hidden !important; 238 | box-shadow: 0 1px 3px var(--shadow-color) !important; 239 | transition: var(--transition) !important; 240 | } 241 | 242 | .gallery-item:hover { 243 | transform: translateY(-2px) !important; 244 | box-shadow: 0 4px 6px var(--shadow-color) !important; 245 | } 246 | 247 | /* 响应式布局 */ 248 | @media (max-width: 768px) { 249 | .gradio-container { 250 | padding: 0.5rem !important; 251 | width: 100% !important; 252 | max-width: 100% !important; 253 | } 254 | 255 | .gr-row { 256 | flex-direction: column !important; 257 | } 258 | 259 | .gr-col { 260 | width: 100% !important; 261 | margin-bottom: 1rem !important; 262 | } 263 | 264 | .gallery { 265 | grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)) !important; 266 | } 267 | 268 | h1 { 269 | font-size: 1.5rem !important; 270 | } 271 | 272 | button, .gr-button { 273 | width: 100% !important; 274 | margin-bottom: 0.5rem !important; 275 | } 276 | 277 | input, textarea, select { 278 | font-size: 16px !important; /* 防止iOS缩放 */ 279 | } 280 | 281 | /* 改进移动端标签页 */ 282 | .tabs { 283 | overflow-x: auto !important; 284 | white-space: nowrap !important; 285 | -webkit-overflow-scrolling: touch !important; 286 | } 287 | 288 | .tab-button { 289 | padding: 0.5rem 0.75rem !important; 290 | } 291 | } 292 | 293 | /* 更小屏幕的优化 */ 294 | @media (max-width: 480px) { 295 | .gallery { 296 | grid-template-columns: repeat(auto-fill, minmax(120px, 1fr)) !important; 297 | } 298 | 299 | h1 { 300 | font-size: 1.3rem !important; 301 | } 302 | } 303 | 304 | /* 动画效果 */ 305 | @keyframes fadeIn { 306 | from { opacity: 0; transform: translateY(10px); } 307 | to { opacity: 1; transform: translateY(0); } 308 | } 309 | 310 | .gr-box, .gr-panel, .gr-form { 311 | animation: fadeIn 0.3s ease-out; 312 | } 313 | """ 314 | 315 | class HeyGemApp: 316 | def __init__(self): 317 | self.audio_service = AudioService() 318 | self.video_service = VideoService() 319 | self.file_service = FileService() 320 | self.task_service = TaskService() 321 | self.current_user = None 322 | self.is_logged_in = False 323 | 324 | # 启动任务队列服务 325 | self.task_service.start() 326 | 327 | def login(self, username, password): 328 | if username in VALID_CREDENTIALS and VALID_CREDENTIALS[username] == password: 329 | self.is_logged_in = True 330 | self.current_user = username 331 | return True, "登录成功" 332 | return False, "用户名或密码错误" 333 | 334 | def _move_video_to_user_dir(self, video_filename): 335 | """移动视频文件到用户目录""" 336 | try: 337 | # 如果video_filename是完整路径,只取文件名部分 338 | if video_filename.startswith('/'): 339 | video_filename = Path(video_filename).name 340 | 341 | source_path = UPLOAD_DIR / video_filename 342 | logger.info(f"UPLOAD_DIR: {UPLOAD_DIR}") 343 | logger.info(f"video_filename: {video_filename}") 344 | logger.info(f"完整源文件路径: {source_path}") 345 | logger.info(f"源文件是否存在: {source_path.exists()}") 346 | 347 | if not source_path.exists(): 348 | return False, f"视频文件不存在: {source_path}", None 349 | 350 | user_dir = self.file_service.get_user_dir(self.current_user) 351 | target_path = user_dir / video_filename 352 | source_path.rename(target_path) 353 | return True, f"视频已保存到: {target_path}", target_path 354 | 355 | except Exception as e: 356 | logger.error(f"移动视频文件失败: {str(e)}") 357 | return False, f"移动视频文件失败: {str(e)}", None 358 | 359 | def _cleanup_temp_images(self): 360 | """清理临时图片文件""" 361 | try: 362 | cleaned_files = [] 363 | for temp_file in UPLOAD_DIR.glob("*.png"): 364 | if temp_file.name.replace(".png", "").isdigit(): 365 | temp_file.unlink() 366 | cleaned_files.append(temp_file.name) 367 | 368 | if cleaned_files: 369 | logger.info(f"已清理临时图片文件: {', '.join(cleaned_files)}") 370 | return True, f"已清理 {len(cleaned_files)} 个临时文件" 371 | return True, "没有需要清理的临时文件" 372 | 373 | except Exception as e: 374 | logger.error(f"清理临时图片文件失败: {str(e)}") 375 | return False, f"清理临时文件失败: {str(e)}" 376 | 377 | def create_interface(self): 378 | with gr.Blocks(title="HeyGem数字人", css=custom_css) as demo: 379 | login_state = gr.State(value=False) 380 | current_user_state = gr.State(value=None) 381 | theme_state = gr.State(value="light") 382 | # 登录区 383 | with gr.Group(visible=True) as login_group: 384 | gr.Markdown("# HeyGem数字人登录") 385 | username = gr.Textbox(label="用户名", placeholder="请输入用户名") 386 | password = gr.Textbox(label="密码", placeholder="请输入密码", type="password") 387 | login_btn = gr.Button("登录") 388 | login_status = gr.Textbox(label="登录状态", interactive=False) 389 | # 主界面区 390 | with gr.Group(visible=False) as main_group: 391 | with gr.Row(): 392 | gr.Markdown("# HeyGem数字人界面") 393 | with gr.Column(scale=1, min_width=100): 394 | theme_btn = gr.Button("🌓 切换深色/浅色模式", scale=0) 395 | with gr.Tab("我的作品"): 396 | with gr.Row(): 397 | with gr.Column(scale=1): 398 | works_gallery = gr.Gallery( 399 | label="我的作品", 400 | show_label=True, 401 | elem_id="works_gallery", 402 | columns=4, 403 | allow_preview=True, 404 | height="auto", 405 | object_fit="contain", 406 | min_width=160 407 | ) 408 | with gr.Row(): 409 | refresh_btn = gr.Button("刷新作品列表") 410 | selected_video = gr.State(value=None) 411 | with gr.Tab("我的数字模特"): 412 | with gr.Row(): 413 | with gr.Column(scale=1): 414 | models_gallery = gr.Gallery( 415 | label="我的数字模特", 416 | show_label=True, 417 | elem_id="models_gallery", 418 | columns=4, 419 | allow_preview=True, 420 | height="auto", 421 | object_fit="contain", 422 | min_width=160 423 | ) 424 | with gr.Row(): 425 | refresh_models_btn = gr.Button("刷新模特列表") 426 | selected_model = gr.State(value=None) 427 | with gr.Tab("模型训练"): 428 | with gr.Row(): 429 | with gr.Column(): 430 | video_input = gr.File(label="上传视频") 431 | model_name = gr.Textbox(label="模特名称", placeholder="请输入模特名称") 432 | train_btn = gr.Button("开始训练") 433 | train_output = gr.Textbox(label="训练状态", lines=5) 434 | reference_audio = gr.Textbox(label="参考音频URL", visible=False) 435 | reference_text = gr.Textbox(label="参考文本", visible=False) 436 | with gr.Tab("视频生成"): 437 | with gr.Row(): 438 | with gr.Column(): 439 | video_path_input = gr.Textbox( 440 | label="选择数字人模特", 441 | placeholder="请输入模特名称(例如:model1)", 442 | value=None 443 | ) 444 | text_input = gr.Textbox(label="要合成的文本", lines=3) 445 | generate_btn = gr.Button("生成视频") 446 | task_id_output = gr.Textbox(label="任务ID") 447 | 448 | # 添加进度条 449 | with gr.Row(): 450 | progress_bar = gr.Progress() 451 | auto_refresh = gr.Checkbox(label="自动刷新状态", value=True) 452 | 453 | check_status_btn = gr.Button("检查状态") 454 | status_output = gr.Textbox(label="状态", lines=3) 455 | 456 | # 添加操作引导 457 | gr.Markdown(""" 458 | ### 操作指南 459 | 1. 输入模特名称(在"我的数字模特"中可以查看) 460 | 2. 输入要合成的文本 461 | 3. 点击"生成视频"按钮开始任务 462 | 4. 系统会自动更新进度,也可以手动点击"检查状态" 463 | 5. 任务完成后,可以在"我的作品"中查看生成的视频 464 | 465 | > **提示**:生成视频可能需要几分钟时间,请耐心等待 466 | """) 467 | with gr.Tab("文件清理"): 468 | with gr.Row(): 469 | with gr.Column(): 470 | days_input = gr.Number( 471 | label="清理多少天前的文件", 472 | value=7, 473 | minimum=1, 474 | maximum=365, 475 | step=1 476 | ) 477 | cleanup_btn = gr.Button("开始清理") 478 | cleanup_output = gr.Textbox(label="清理结果", lines=10) 479 | 480 | with gr.Tab("帮助与反馈"): 481 | with gr.Row(): 482 | with gr.Column(scale=2): 483 | gr.Markdown(""" 484 | # 帮助文档 485 | 486 | ## 基本功能 487 | 488 | ### 模型训练 489 | 1. 上传一段清晰的视频文件(MP4格式) 490 | 2. 输入模特名称 491 | 3. 点击"开始训练"按钮 492 | 4. 等待训练完成(可在任务队列中查看进度) 493 | 494 | ### 视频生成 495 | 1. 输入已训练好的模特名称 496 | 2. 输入要合成的文本 497 | 3. 点击"生成视频"按钮 498 | 4. 等待生成完成(可在任务队列中查看进度) 499 | 500 | ### 文件管理 501 | - 在"我的作品"中查看所有生成的视频 502 | - 在"我的数字模特"中查看所有训练好的模型 503 | - 使用"文件清理"功能定期清理临时文件 504 | 505 | ### 任务队列 506 | - 查看所有任务的状态和进度 507 | - 取消等待中的任务 508 | - 设置最大并发任务数 509 | 510 | ## 常见问题 511 | 512 | **Q: 为什么我的视频生成失败了?** 513 | A: 可能是以下原因: 514 | - 模特名称输入错误 515 | - 文本内容过长或包含特殊字符 516 | - 系统资源不足 517 | 518 | **Q: 如何提高视频生成质量?** 519 | A: 请确保: 520 | - 上传高质量的原始视频 521 | - 确保视频中人物面部清晰可见 522 | - 避免背景噪音和干扰 523 | 524 | **Q: 文件会自动删除吗?** 525 | A: 系统默认不会自动删除文件,需要手动使用"文件清理"功能。 526 | """) 527 | 528 | with gr.Column(scale=1): 529 | gr.Markdown("## 用户反馈") 530 | feedback_type = gr.Radio( 531 | ["问题报告", "功能建议", "其他反馈"], 532 | label="反馈类型", 533 | value="问题报告" 534 | ) 535 | feedback_content = gr.Textbox( 536 | label="反馈内容", 537 | placeholder="请详细描述您的问题或建议...", 538 | lines=5 539 | ) 540 | feedback_email = gr.Textbox( 541 | label="联系邮箱(选填)", 542 | placeholder="example@email.com" 543 | ) 544 | feedback_btn = gr.Button("提交反馈") 545 | feedback_status = gr.Textbox(label="提交状态", interactive=False) 546 | with gr.Tab("任务队列"): 547 | with gr.Row(): 548 | with gr.Column(scale=2): 549 | gr.Markdown("### 任务队列状态") 550 | queue_status = gr.JSON(label="队列状态") 551 | refresh_queue_btn = gr.Button("刷新队列状态") 552 | 553 | gr.Markdown("### 并发设置") 554 | with gr.Row(): 555 | concurrent_tasks = gr.Slider( 556 | label="最大并发任务数", 557 | minimum=1, 558 | maximum=10, 559 | step=1, 560 | value=2 561 | ) 562 | set_concurrent_btn = gr.Button("设置") 563 | 564 | with gr.Column(scale=3): 565 | gr.Markdown("### 我的任务") 566 | user_tasks = gr.Dataframe( 567 | headers=["任务ID", "类型", "状态", "进度", "创建时间", "操作"], 568 | datatype=["str", "str", "str", "number", "str", "str"], 569 | row_count=10, 570 | col_count=(6, "fixed"), 571 | interactive=False 572 | ) 573 | refresh_tasks_btn = gr.Button("刷新我的任务") 574 | 575 | gr.Markdown("### 任务操作") 576 | with gr.Row(): 577 | task_id_input = gr.Textbox(label="任务ID", placeholder="输入要操作的任务ID") 578 | cancel_task_btn = gr.Button("取消任务", variant="stop") 579 | task_details = gr.JSON(label="任务详情") 580 | get_task_btn = gr.Button("获取任务详情") 581 | # --- 任务队列逻辑 --- 582 | def get_queue_status(): 583 | return self.task_service.get_queue_status() 584 | 585 | def get_user_tasks(): 586 | if not self.current_user: 587 | return [] 588 | 589 | tasks = self.task_service.get_user_tasks(self.current_user) 590 | # 格式化为表格显示 591 | rows = [] 592 | for task in tasks: 593 | task_id = task["task_id"] 594 | task_type = task["task_type"] 595 | status = task["status"] 596 | progress = task["progress"] 597 | created_at = datetime.fromisoformat(task["created_at"]).strftime("%Y-%m-%d %H:%M:%S") 598 | action = "取消" if status == TaskStatus.PENDING else "-" 599 | rows.append([task_id, task_type, status, progress, created_at, action]) 600 | 601 | return rows 602 | 603 | def get_task_details(task_id): 604 | if not task_id: 605 | return {"error": "请输入任务ID"} 606 | 607 | task = self.task_service.get_task(task_id) 608 | if not task: 609 | return {"error": "未找到任务"} 610 | 611 | return task 612 | 613 | def cancel_task(task_id): 614 | if not task_id: 615 | return {"error": "请输入任务ID"} 616 | 617 | success = self.task_service.cancel_task(task_id) 618 | if success: 619 | return {"success": f"已取消任务 {task_id}"} 620 | else: 621 | return {"error": f"无法取消任务 {task_id}"} 622 | 623 | def set_max_concurrent_tasks(count): 624 | success = self.task_service.set_max_concurrent_tasks(int(count)) 625 | if success: 626 | return {"success": f"已设置最大并发任务数为 {count}"} 627 | else: 628 | return {"error": "设置失败"} 629 | 630 | # 绑定任务队列相关事件 631 | refresh_queue_btn.click(get_queue_status, None, queue_status) 632 | refresh_tasks_btn.click(get_user_tasks, None, user_tasks) 633 | get_task_btn.click(get_task_details, task_id_input, task_details) 634 | cancel_task_btn.click(cancel_task, task_id_input, task_details) 635 | set_concurrent_btn.click(set_max_concurrent_tasks, concurrent_tasks, queue_status) 636 | 637 | # 页面加载时自动刷新 638 | demo.load(fn=get_queue_status, inputs=None, outputs=queue_status) 639 | demo.load(fn=get_user_tasks, inputs=None, outputs=user_tasks) 640 | 641 | # --- 深色模式切换逻辑 --- 642 | def toggle_theme(current_theme): 643 | if current_theme == "light": 644 | # 切换到深色模式 645 | return "dark" 646 | else: 647 | # 切换到浅色模式 648 | return "light" 649 | 650 | theme_btn.click( 651 | fn=toggle_theme, 652 | inputs=[theme_state], 653 | outputs=[theme_state] 654 | ) 655 | 656 | # --- 我的作品逻辑 --- 657 | def get_gallery_items(): 658 | if not self.current_user: 659 | return [] 660 | return [(w["path"], w["name"]) for w in self.get_works_info()] 661 | def get_models_items(): 662 | if not self.current_user: 663 | return [] 664 | return [(m["path"], m["name"]) for m in self.get_models_info()] 665 | def select_video(evt: gr.SelectData): 666 | works = self.get_works_info() 667 | selected = works[evt.index] 668 | return selected["path"] 669 | def select_model(evt: gr.SelectData): 670 | models = self.get_models_info() 671 | selected = models[evt.index] 672 | return selected["path"] 673 | works_gallery.select(select_video, outputs=selected_video) 674 | refresh_btn.click(get_gallery_items, None, works_gallery) 675 | models_gallery.select(select_model, outputs=selected_model) 676 | refresh_models_btn.click(get_models_items, None, models_gallery) 677 | demo.load(fn=get_gallery_items, inputs=None, outputs=works_gallery) 678 | demo.load(fn=get_models_items, inputs=None, outputs=models_gallery) 679 | def on_training_complete(result): 680 | if "参考音频:" in result and "参考文本:" in result: 681 | ref_audio = result.split("参考音频:")[1].split("\n")[0].strip() 682 | ref_text = result.split("参考文本:")[1].strip() 683 | return result, ref_audio, ref_text 684 | return result, "", "" 685 | train_btn.click( 686 | fn=lambda video_file, model_name: self.train_model(video_file, model_name), 687 | inputs=[video_input, model_name], 688 | outputs=[train_output] 689 | ).then( 690 | fn=on_training_complete, 691 | inputs=[train_output], 692 | outputs=[train_output, reference_audio, reference_text] 693 | ).then( 694 | fn=get_models_items, 695 | inputs=None, 696 | outputs=models_gallery 697 | ) 698 | def generate_video(video_path, text): 699 | if not video_path: 700 | return "错误: 请先选择数字人模特", None 701 | if not text: 702 | return "错误: 请输入要合成的文本", None 703 | 704 | try: 705 | # 获取模型训练结果 706 | model_result = self.get_model_training_result(video_path) 707 | if not model_result: 708 | return "错误: 未找到模型训练结果", None 709 | 710 | # 合成音频 711 | audio_task_id, audio_msg = self.synthesize_audio( 712 | text=text, 713 | reference_text=model_result.get("reference_audio_text", ""), 714 | reference_audio=model_result.get("asr_format_audio_url", "") 715 | ) 716 | 717 | if not audio_task_id: 718 | return audio_msg, None 719 | 720 | return f"音频合成任务已创建,任务ID: {audio_task_id}\n请在任务队列中查看进度。", audio_task_id 721 | except Exception as e: 722 | logger.error(f"生成视频失败: {str(e)}") 723 | return f"生成视频失败: {str(e)}", None 724 | 725 | def check_status(task_id): 726 | if not task_id: 727 | return "请先生成视频获取任务ID" 728 | 729 | task = self.task_service.get_task(task_id) 730 | if not task: 731 | return f"未找到任务: {task_id}" 732 | 733 | status = task["status"] 734 | progress = task["progress"] 735 | 736 | # 更新进度条 737 | if progress > 0: 738 | progress_bar.update(progress / 100) 739 | 740 | if status == TaskStatus.COMPLETED: 741 | result = task["result"] 742 | if task["task_type"] == TaskType.AUDIO_SYNTHESIS: 743 | # 音频合成完成,开始视频生成 744 | audio_path = result.get("audio_path") 745 | video_task_id, video_msg = self.make_video( 746 | video_path=video_path_input.value, 747 | audio_path=audio_path 748 | ) 749 | return f"✅ 音频合成已完成,开始生成视频。\n视频任务ID: {video_task_id}" 750 | elif task["task_type"] == TaskType.VIDEO_GENERATION: 751 | # 视频生成完成 752 | video_path = result.get("video_path") 753 | return f"🎉 视频生成已完成!\n视频保存路径: {video_path}\n请在'我的作品'中查看。" 754 | elif status == TaskStatus.FAILED: 755 | return f"❌ 任务失败: {task['error']}\n请检查输入参数或联系管理员。" 756 | elif status == TaskStatus.CANCELLED: 757 | return "⚠️ 任务已取消" 758 | else: 759 | return f"⏳ 任务状态: {status}, 进度: {progress}%\n请耐心等待..." 760 | 761 | # 自动刷新状态 762 | def auto_refresh_status(): 763 | task_id = task_id_output.value 764 | if task_id and auto_refresh.value: 765 | status = check_status(task_id) 766 | return status 767 | return status_output.value 768 | 769 | # 自动刷新状态(移除every参数,因为新版本Gradio不支持) 770 | demo.load(fn=auto_refresh_status, inputs=None, outputs=status_output) 771 | 772 | generate_btn.click( 773 | fn=generate_video, 774 | inputs=[video_path_input, text_input], 775 | outputs=[status_output, task_id_output] 776 | ) 777 | check_status_btn.click( 778 | fn=check_status, 779 | inputs=[task_id_output], 780 | outputs=[status_output] 781 | ) 782 | cleanup_btn.click( 783 | fn=lambda days: self.cleanup_files(days), 784 | inputs=[days_input], 785 | outputs=[cleanup_output] 786 | ) 787 | 788 | # --- 帮助与反馈逻辑 --- 789 | def submit_feedback(feedback_type, content, email): 790 | if not content: 791 | return "❌ 请输入反馈内容" 792 | 793 | try: 794 | # 记录反馈到日志 795 | logger.info(f"用户反馈: 类型={feedback_type}, 邮箱={email}, 内容={content}") 796 | 797 | # 保存反馈到文件 798 | feedback_dir = Path("feedback") 799 | feedback_dir.mkdir(exist_ok=True) 800 | 801 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 802 | feedback_file = feedback_dir / f"feedback_{timestamp}.txt" 803 | 804 | with open(feedback_file, "w", encoding="utf-8") as f: 805 | f.write(f"类型: {feedback_type}\n") 806 | f.write(f"时间: {datetime.now().isoformat()}\n") 807 | f.write(f"用户: {self.current_user}\n") 808 | if email: 809 | f.write(f"邮箱: {email}\n") 810 | f.write(f"内容:\n{content}\n") 811 | 812 | return "✅ 感谢您的反馈!我们会尽快处理。" 813 | except Exception as e: 814 | logger.error(f"保存反馈失败: {str(e)}") 815 | return f"❌ 提交失败: {str(e)}" 816 | 817 | feedback_btn.click( 818 | fn=submit_feedback, 819 | inputs=[feedback_type, feedback_content, feedback_email], 820 | outputs=feedback_status 821 | ) 822 | def on_login(username, password): 823 | success, msg = self.login(username, password) 824 | if success: 825 | return True, gr.update(visible=False), gr.update(visible=True), username, "登录成功" 826 | # 登录失败,主界面不显示,提示错误 827 | return False, gr.update(visible=True), gr.update(visible=False), None, msg 828 | login_btn.click( 829 | fn=on_login, 830 | inputs=[username, password], 831 | outputs=[login_state, login_group, main_group, current_user_state, login_status] 832 | ) 833 | return demo 834 | 835 | def train_model(self, video_file, model_name): 836 | """训练模型""" 837 | try: 838 | if not video_file or not model_name: 839 | return "错误:请提供视频文件和模特名称" 840 | 841 | # 保存上传的视频文件 842 | file_path = self.file_service.save_uploaded_file(video_file, video_file.name, self.current_user) 843 | 844 | # 创建训练任务 845 | params = { 846 | "video_path": str(file_path), 847 | "model_name": model_name 848 | } 849 | 850 | def train_model_task(task): 851 | # 这里是实际的训练逻辑 852 | # 在实际应用中,这里应该调用模型训练API 853 | logger.info(f"开始训练模型: {model_name}") 854 | time.sleep(5) # 模拟训练过程 855 | 856 | # 返回训练结果 857 | return { 858 | "model_name": model_name, 859 | "reference_audio": f"https://example.com/audio/{model_name}.wav", 860 | "reference_text": "这是一段参考文本,用于测试语音合成效果。" 861 | } 862 | 863 | task_id = self.task_service.create_task( 864 | task_type=TaskType.MODEL_TRAINING, 865 | params=params, 866 | username=self.current_user, 867 | priority=TaskPriority.HIGH, 868 | callback=train_model_task 869 | ) 870 | 871 | return f"已创建训练任务,任务ID: {task_id}\n请在任务队列中查看进度。" 872 | 873 | except Exception as e: 874 | logger.error(f"训练模型失败: {str(e)}") 875 | return f"训练失败: {str(e)}" 876 | 877 | def get_uploaded_videos(self): 878 | return self.file_service.scan_uploaded_videos(self.current_user) 879 | 880 | def synthesize_audio(self, text, reference_text, reference_audio, username=None): 881 | """合成音频""" 882 | try: 883 | if not text: 884 | return None, "错误:请输入要合成的文本" 885 | 886 | # 创建音频合成任务 887 | params = { 888 | "text": text, 889 | "reference_text": reference_text, 890 | "reference_audio": reference_audio 891 | } 892 | 893 | def synthesize_audio_task(task): 894 | # 这里是实际的音频合成逻辑 895 | logger.info(f"开始合成音频,文本长度: {len(text)}") 896 | time.sleep(3) # 模拟合成过程 897 | 898 | # 返回合成结果 899 | return { 900 | "audio_path": f"/tmp/audio_{uuid.uuid4()}.wav", 901 | "duration": len(text) * 0.1 # 模拟音频时长 902 | } 903 | 904 | task_id = self.task_service.create_task( 905 | task_type=TaskType.AUDIO_SYNTHESIS, 906 | params=params, 907 | username=username or self.current_user, 908 | priority=TaskPriority.NORMAL, 909 | callback=synthesize_audio_task 910 | ) 911 | 912 | return task_id, f"已创建音频合成任务,任务ID: {task_id}" 913 | 914 | except Exception as e: 915 | logger.error(f"音频合成失败: {str(e)}") 916 | return None, f"音频合成失败: {str(e)}" 917 | 918 | def make_video(self, video_path, audio_path): 919 | """生成视频""" 920 | try: 921 | if not video_path or not audio_path: 922 | return None, "错误:请提供视频路径和音频路径" 923 | 924 | # 创建视频生成任务 925 | params = { 926 | "video_path": video_path, 927 | "audio_path": audio_path 928 | } 929 | 930 | def make_video_task(task): 931 | # 这里是实际的视频生成逻辑 932 | logger.info(f"开始生成视频: {video_path}") 933 | time.sleep(10) # 模拟视频生成过程 934 | 935 | # 生成结果视频路径 936 | result_video = f"{video_path.rsplit('.', 1)[0]}-r.mp4" 937 | 938 | # 返回生成结果 939 | return { 940 | "video_path": result_video 941 | } 942 | 943 | task_id = self.task_service.create_task( 944 | task_type=TaskType.VIDEO_GENERATION, 945 | params=params, 946 | username=self.current_user, 947 | priority=TaskPriority.NORMAL, 948 | callback=make_video_task 949 | ) 950 | 951 | return task_id, f"已创建视频生成任务,任务ID: {task_id}" 952 | 953 | except Exception as e: 954 | logger.error(f"视频生成失败: {str(e)}") 955 | return None, f"视频生成失败: {str(e)}" 956 | 957 | def cleanup_files(self, days_old: int) -> str: 958 | """清理临时文件""" 959 | try: 960 | if days_old < 1: 961 | return "错误:清理天数必须大于等于1" 962 | 963 | # 创建文件清理任务 964 | params = { 965 | "days_old": days_old 966 | } 967 | 968 | def cleanup_files_task(task): 969 | # 这里是实际的文件清理逻辑 970 | logger.info(f"开始清理 {days_old} 天前的文件") 971 | 972 | # 调用文件服务的清理方法 973 | result = self.file_service.cleanup_temp_files(days_old, self.current_user) 974 | 975 | # 更新任务进度 976 | self.task_service.update_task_progress(task.task_id, 50) 977 | 978 | # 返回清理结果 979 | return result 980 | 981 | task_id = self.task_service.create_task( 982 | task_type=TaskType.FILE_CLEANUP, 983 | params=params, 984 | username=self.current_user, 985 | priority=TaskPriority.LOW, 986 | callback=cleanup_files_task 987 | ) 988 | 989 | return f"已创建文件清理任务,任务ID: {task_id}\n请在任务队列中查看进度。" 990 | 991 | except Exception as e: 992 | logger.error(f"文件清理失败: {str(e)}") 993 | return f"文件清理失败: {str(e)}" 994 | 995 | def get_works(self): 996 | return self.file_service.scan_works(self.current_user) 997 | 998 | def get_works_info(self): 999 | if not self.current_user: 1000 | return [] 1001 | works = [] 1002 | for file in self.file_service.scan_works(self.current_user): 1003 | file_path = Path(file["path"]) if isinstance(file, dict) else Path(file) 1004 | works.append({ 1005 | "name": file_path.stem, 1006 | "path": str(file_path), 1007 | "cover": None, 1008 | "created_time": file_path.stat().st_ctime 1009 | }) 1010 | return works 1011 | 1012 | def get_models_info(self): 1013 | if not self.current_user: 1014 | return [] 1015 | models = [] 1016 | for file in self.file_service.scan_models(self.current_user): 1017 | file_path = Path(file["path"]) 1018 | models.append({ 1019 | "name": file_path.stem, 1020 | "path": str(file_path), 1021 | "cover": None 1022 | }) 1023 | return models 1024 | 1025 | def get_model_training_result(self, model_name): 1026 | try: 1027 | user_dir = self.file_service.get_user_dir(self.current_user) 1028 | result_file = user_dir / f"{model_name}_training.json" 1029 | if result_file.exists(): 1030 | with open(result_file, 'r', encoding='utf-8') as f: 1031 | return json.load(f) 1032 | return None 1033 | except Exception as e: 1034 | logger.error(f"读取训练结果失败: {str(e)}") 1035 | return None 1036 | 1037 | def main(): 1038 | try: 1039 | print("开始启动HeyGem Web界面...") 1040 | logger.info("启动HeyGem Web界面") 1041 | 1042 | print("创建HeyGemApp实例...") 1043 | app = HeyGemApp() 1044 | 1045 | print("创建Gradio界面...") 1046 | demo = app.create_interface() 1047 | 1048 | print(f"启动Gradio服务器,地址: {SERVER_HOST}:{SERVER_PORT}") 1049 | demo.launch( 1050 | server_name=SERVER_HOST, 1051 | server_port=SERVER_PORT, 1052 | share=False, 1053 | favicon_path=None 1054 | ) 1055 | except Exception as e: 1056 | print(f"程序启动失败: {str(e)}") 1057 | logger.error(f"程序启动失败: {str(e)}") 1058 | import traceback 1059 | traceback.print_exc() 1060 | finally: 1061 | # 确保停止任务队列服务 1062 | if 'app' in locals() and hasattr(app, 'task_service'): 1063 | app.task_service.stop() 1064 | logger.info("任务队列服务已停止") 1065 | 1066 | if __name__ == "__main__": 1067 | main() --------------------------------------------------------------------------------