├── .gitignore ├── README.md ├── backend ├── app.db ├── app │ ├── __init__.py │ ├── api │ │ └── v1 │ │ │ ├── __init__.py │ │ │ ├── assets.py │ │ │ ├── common.py │ │ │ ├── settings.py │ │ │ ├── tasks.py │ │ │ ├── terminal.py │ │ │ └── upload.py │ ├── config.py │ ├── database.py │ ├── main.py │ ├── middleware │ │ └── error_handler.py │ ├── models │ │ ├── asset.py │ │ ├── constants.py │ │ ├── setting.py │ │ ├── task.py │ │ ├── training.py │ │ └── upload_file.py │ ├── schemas │ │ ├── asset.py │ │ ├── setting.py │ │ ├── task.py │ │ └── training.py │ ├── services │ │ ├── asset_service.py │ │ ├── common_service.py │ │ ├── config_service.py │ │ ├── local_asset_service.py │ │ ├── task_service.py │ │ ├── task_services │ │ │ ├── __init__.py │ │ │ ├── base_task_service.py │ │ │ ├── marking_service.py │ │ │ ├── result_service.py │ │ │ ├── scheduler_service.py │ │ │ ├── task_image_service.py │ │ │ └── training_service.py │ │ ├── terminal_service.py │ │ └── upload_service.py │ └── utils │ │ ├── common.py │ │ ├── file_handler.py │ │ ├── json_encoder.py │ │ ├── logger.py │ │ ├── mark_handler.py │ │ ├── response.py │ │ ├── ssh.py │ │ ├── train_handler.py │ │ └── validators.py ├── data │ └── workflow │ │ ├── mark_workflow_api.json │ │ └── mark_workflow_api_new.json ├── requirements.txt ├── run.py ├── start_services.sh ├── task_scheduler │ ├── __init__.py │ ├── comfyui_api.py │ ├── comfyui_error_parser.py │ ├── comfyui_precheck.py │ ├── comfyui_ws_api.py │ └── task_scheduler.py └── tests │ ├── README.md │ └── test_task_status.py ├── docs ├── Snipaste_2025-06-20_19-18-44.png ├── Snipaste_2025-06-20_19-18-55.png ├── Snipaste_2025-06-20_19-19-01.png ├── Snipaste_2025-06-20_19-19-50.png └── Snipaste_2025-06-20_19-20-00.png └── fronted-ui ├── .eslintrc.js ├── .gitignore ├── README.md ├── babel.config.js ├── index.html ├── jsconfig.json ├── package-lock.json ├── package.json ├── public ├── docs │ ├── faq.md │ └── guide.md ├── favicon.ico └── index.html ├── src ├── App.vue ├── api │ ├── asset.js │ ├── common.js │ ├── settings.js │ ├── tasks.js │ ├── terminal.js │ └── upload.js ├── assets │ ├── logo.png │ └── styles │ │ └── global.css ├── components │ ├── assets │ │ ├── AssetCard.vue │ │ └── AssetForm.vue │ ├── common │ │ ├── Checkbox.vue │ │ ├── ContextMenu.vue │ │ ├── FileUploader.vue │ │ ├── HighlightEditableDiv.vue │ │ ├── KeyValueConfig.vue │ │ ├── LoadingSpinner.vue │ │ ├── LoraParamAssetView.vue │ │ ├── LoraParamSettingsView.vue │ │ ├── LoraTrainingParams.vue │ │ ├── Message.vue │ │ ├── Modal.vue │ │ ├── PageTabs.vue │ │ ├── PromptTooltip.vue │ │ ├── SwitchButton.vue │ │ ├── TaskConfigCard.vue │ │ ├── TooltipText.vue │ │ └── TrainingLossChart.vue │ ├── layout │ │ └── AppLayout.vue │ ├── tasks │ │ ├── ImageGrid.vue │ │ ├── ImageUploader.vue │ │ ├── ImageViewer.vue │ │ ├── LogViewer.vue │ │ ├── TaskCard.vue │ │ ├── TaskConfigCard.vue │ │ ├── TaskForm.vue │ │ ├── TaskList.vue │ │ ├── TaskStatus.vue │ │ ├── TrainingDetails.vue │ │ ├── TrainingHistoryDetails.vue │ │ └── TrainingHistoryDropdown.vue │ └── terminal │ │ ├── FileManager.vue │ │ ├── RemotePanel.vue │ │ └── Terminal.vue ├── composables │ └── useLoraParams.js ├── main.js ├── router │ └── index.js ├── utils │ ├── datetime.js │ ├── eventBus.js │ ├── message.js │ ├── object.js │ ├── paramUtils.js │ ├── request.js │ ├── taskStatus.js │ ├── textFormatters.js │ └── translationCache.js └── views │ ├── Assets.vue │ ├── Guide.vue │ ├── NotFound.vue │ ├── Settings.vue │ ├── TaskDetail.vue │ ├── Tasks.vue │ └── Training.vue ├── vite.config.js └── vue.config.js /.gitignore: -------------------------------------------------------------------------------- 1 | # Node 2 | node_modules/ 3 | dist/ 4 | .env 5 | .env.* 6 | 7 | # IDE 8 | .idea/ 9 | .vscode/ 10 | *.suo 11 | *.ntvs* 12 | *.njsproj 13 | *.sln 14 | *.sw? 15 | .history 16 | 17 | # OS 18 | .DS_Store 19 | Thumbs.db 20 | 21 | # Logs 22 | logs 23 | *.log 24 | npm-debug.log* 25 | yarn-debug.log* 26 | yarn-error.log* 27 | pnpm-debug.log* 28 | 29 | # Vue 30 | .nuxt 31 | .nitro 32 | .cache 33 | .output 34 | .env 35 | *.local 36 | 37 | # Testing 38 | /coverage 39 | /test/unit/coverage/ 40 | 41 | # Build 42 | /dist/ 43 | build/ 44 | data/ 45 | 46 | # Python 47 | __pycache__/ 48 | *.py[cod] 49 | *$py.class 50 | *.so 51 | .Python 52 | env/ 53 | build/ 54 | develop-eggs/ 55 | dist/ 56 | downloads/ 57 | eggs/ 58 | .eggs/ 59 | lib/ 60 | lib64/ 61 | parts/ 62 | sdist/ 63 | var/ 64 | *.egg-info/ 65 | .installed.cfg 66 | *.egg 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RICK LORA TRAINER (RLT) 2 | 3 | ![RLT Logo](docs/Snipaste_2025-06-20_19-18-44.png) 4 | 5 | ## 项目简介 6 | 7 | RICK LORA TRAINER (RLT) 是一个基于ComfyUI和秋叶训练器的自动化LORA训练平台,旨在简化AI绘图LORA模型的训练流程。该项目通过直观的用户界面和自动化工作流程,使得即使是零基础的用户也能轻松训练出高质量的LORA模型。 8 | 9 | ## 核心特性 10 | 11 | ### 1. 零参数配置训练 12 | - 预设优化的训练参数 13 | - 自动处理训练流程 14 | - 适合新手快速上手 15 | 16 | ### 2. 资产管理系统 17 | - 支持多训练节点管理 18 | - 训练素材自动下载与处理 19 | - 模型资产统一管理 20 | 21 | ### 3. 自动化工作流程 22 | - 素材自动下载(支持百度网盘) 23 | - 智能数据标注处理 24 | - 一键式LORA模型训练 25 | - 训练完成后自动上传成品模型 26 | 27 | ### 4. 任务管理 28 | ![任务管理](docs/Snipaste_2025-06-20_19-19-50.png) 29 | - 可视化任务状态追踪 30 | - 实时训练进度展示 31 | - 训练日志实时查看 32 | - 失败任务自动重试 33 | - 历史训练数据查询 34 | 35 | ### 5. 灵活配置 36 | ![系统配置](docs/Snipaste_2025-06-20_19-18-55.png) 37 | - 可定制的训练参数 38 | - 节点资源分配策略 39 | - 调度间隔设置 40 | - 并发限制控制 41 | 42 | ## 系统架构 43 | 44 | ### 后端架构 45 | - Flask + SQLAlchemy 46 | - WebSocket实时通信 47 | - 多线程任务调度 48 | - ComfyUI API集成 49 | 50 | ### 前端架构 51 | - Vue 3框架 52 | - 响应式界面设计 53 | - 实时数据更新 54 | 55 | ## 工作流程 56 | ![工作流程](docs/Snipaste_2025-06-20_19-19-01.png) 57 | 58 | ``` 59 | 素材收集 -> 自动标注 -> LORA训练 -> 模型生成 -> 效果预览 60 | ``` 61 | 62 | ## 快速开始 63 | 64 | ### 环境要求 65 | - Python 3.8+ 66 | - Node.js 14+ 67 | - ComfyUI环境 68 | - 秋叶训练器(可选) 69 | 70 | ### 安装步骤 71 | 72 | 1. 克隆项目 73 | ```bash 74 | git clone 75 | cd lora-automatic-traning 76 | ``` 77 | 78 | 2. 安装后端依赖 79 | ```bash 80 | cd backend 81 | pip install -r requirements.txt 82 | ``` 83 | 84 | 3. 安装前端依赖 85 | ```bash 86 | cd fronted-ui 87 | npm install 88 | ``` 89 | 90 | 4. 启动服务 91 | ```bash 92 | # 启动后端 93 | cd backend 94 | python run.py 95 | 96 | # 启动前端 97 | cd fronted-ui 98 | npm run serve 99 | ``` 100 | 101 | ## 使用指南 102 | ![使用界面](docs/Snipaste_2025-06-20_19-20-00.png) 103 | 104 | 1. **创建训练任务** 105 | - 上传训练素材或提供下载链接 106 | - 选择训练类型和参数(或使用默认配置) 107 | - 提交任务 108 | 109 | 2. **监控训练进度** 110 | - 在任务列表查看所有任务状态 111 | - 点击任务查看详细训练日志和进度 112 | 113 | 3. **查看训练结果** 114 | - 训练完成后自动生成预览图 115 | - 下载训练好的LORA模型 116 | 117 | ## 核心模块 118 | 119 | - **TaskScheduler**: 任务调度与管理 120 | - **ComfyUIAPI**: 与ComfyUI交互的API封装 121 | - **AssetManager**: 资产管理系统 122 | - **TrainingService**: 训练服务 123 | - **ConfigService**: 配置管理 124 | 125 | ## 常见问题 126 | 127 | 1. **如何调整训练参数?** 128 | - 在设置页面可以修改默认训练参数 129 | - 也可以在创建任务时为单个任务指定参数 130 | 131 | 2. **支持哪些训练模型?** 132 | - 目前主要支持Flux-LORA模型训练 133 | - 计划后续添加更多模型支持 134 | 135 | 3. **如何处理训练失败?** 136 | - 系统会自动重试失败的任务 137 | - 可以查看详细错误日志进行排查 138 | 139 | ## 开发计划 140 | 141 | - [ ] 添加更多模型训练支持 142 | - [ ] 优化训练参数自动推荐 143 | - [ ] 增强批量处理能力 144 | - [ ] 添加模型评分系统 145 | 146 | ## 贡献指南 147 | 148 | 欢迎提交Pull Request或Issue来帮助改进项目。 149 | 150 | ## 许可证 151 | 152 | MIT License 153 | -------------------------------------------------------------------------------- /backend/app.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/backend/app.db -------------------------------------------------------------------------------- /backend/app/__init__.py: -------------------------------------------------------------------------------- 1 | # 空文件,标记这是一个 Python 包 -------------------------------------------------------------------------------- /backend/app/api/v1/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint 2 | from .tasks import tasks_bp 3 | from .assets import assets_bp 4 | from .settings import settings_bp 5 | from .terminal import terminal_bp 6 | from .common import common_bp 7 | from .upload import upload_bp 8 | 9 | api_v1 = Blueprint('api_v1', __name__, url_prefix='/api/v1') 10 | 11 | # 注册路由 12 | api_v1.register_blueprint(tasks_bp, url_prefix='/tasks') 13 | api_v1.register_blueprint(assets_bp, url_prefix='/assets') 14 | api_v1.register_blueprint(settings_bp, url_prefix='/settings') 15 | api_v1.register_blueprint(terminal_bp, url_prefix='/terminal') 16 | api_v1.register_blueprint(upload_bp, url_prefix='/upload') 17 | api_v1.register_blueprint(common_bp, url_prefix='/common') -------------------------------------------------------------------------------- /backend/app/api/v1/assets.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, request 2 | from ...services.asset_service import AssetService 3 | from ...services.config_service import ConfigService 4 | from ...utils.logger import setup_logger 5 | from ...utils.validators import validate_asset_create 6 | from ...schemas.asset import SshVerifyRequest,AssetCreate, AssetUpdate 7 | from ...utils.response import success_json, error_json, exception_handler, response_template 8 | from ...models.asset import Asset as AssetModel 9 | from ...utils.common import copy_attributes 10 | 11 | logger = setup_logger('assets_api') 12 | assets_bp = Blueprint('assets', __name__) 13 | 14 | @assets_bp.route('', methods=['GET']) 15 | @exception_handler 16 | def list_assets(): 17 | """获取资产列表""" 18 | assets = AssetService.list_assets() 19 | return success_json([asset.dict() for asset in assets]) 20 | 21 | @assets_bp.route('', methods=['POST']) 22 | @exception_handler 23 | def create_asset(): 24 | """创建新资产""" 25 | asset_data = AssetCreate(**request.json) 26 | asset = AssetService.create_asset(asset_data) 27 | if not asset: 28 | return error_json(2001, "创建资产失败") 29 | return response_template("created", data=asset.dict()) 30 | 31 | @assets_bp.route('/', methods=['PUT']) 32 | @exception_handler 33 | def update_asset(asset_id): 34 | """更新资产""" 35 | logger.debug(f"更新资产: {asset_id}, 原始请求数据: {request.json}") 36 | 37 | # 检查是否为本地资产 38 | from ...services.local_asset_service import LocalAssetService 39 | is_local = LocalAssetService.is_local_asset(asset_id) 40 | logger.info(f"资产 {asset_id} 是否为本地资产: {is_local}") 41 | 42 | # 如果是本地资产,确保is_local字段为True 43 | if is_local and 'is_local' not in request.json: 44 | request_data = dict(request.json) 45 | request_data['is_local'] = True 46 | logger.info(f"为本地资产 {asset_id} 添加is_local=True标志") 47 | else: 48 | request_data = request.json 49 | 50 | logger.debug(f"处理后的请求数据: {request_data}") 51 | update_data = AssetUpdate(**request_data) 52 | logger.debug(f"验证后的更新数据: {update_data}") 53 | 54 | asset = AssetService.update_asset(asset_id, update_data) 55 | if not asset: 56 | logger.warning(f"资产不存在或更新失败: {asset_id}") 57 | return response_template("not_found", code=2002, msg="资产不存在或更新失败") 58 | 59 | return response_template("updated", data=asset.dict()) 60 | 61 | @assets_bp.route('/', methods=['DELETE']) 62 | @exception_handler 63 | def delete_asset(asset_id): 64 | """删除资产""" 65 | if AssetService.delete_asset(asset_id): 66 | return response_template("deleted") 67 | return error_json(2003, "删除资产失败") 68 | 69 | @assets_bp.route('//verify', methods=['POST']) 70 | @exception_handler 71 | def verify_capabilities(asset_id): 72 | """验证资产能力""" 73 | logger.info(f"开始验证资产 {asset_id} 的能力") 74 | results = AssetService.verify_capabilities(asset_id) 75 | logger.info(f"验证资产 {asset_id} 能力成功: {results}") 76 | return success_json(results) 77 | 78 | @assets_bp.route('/verify-ssh', methods=['POST']) 79 | @exception_handler 80 | def verify_ssh_connection(): 81 | """验证SSH连接""" 82 | # 验证请求数据 83 | asset_data = SshVerifyRequest(**request.json) 84 | 85 | # 创建临时资产对象用于验证 86 | asset = AssetModel() 87 | 88 | # 使用copy_attributes工具函数拷贝属性 89 | copy_attributes(asset_data, asset) 90 | 91 | # 执行SSH连接验证 92 | from ...services.terminal_service import TerminalService 93 | success, message = TerminalService.verify_asset_ssh_connection(asset) 94 | 95 | if success: 96 | return success_json(None, message) 97 | else: 98 | return error_json(4001, message) 99 | 100 | @assets_bp.route('//configs/lora', methods=['GET']) 101 | @exception_handler 102 | def get_asset_lora_config(asset_id): 103 | """获取资产的Lora训练配置""" 104 | config = ConfigService.get_asset_lora_config(asset_id) 105 | if config is None: 106 | return response_template("not_found", code=1004, msg="资产不存在") 107 | return success_json(config) 108 | 109 | @assets_bp.route('//configs/ai-engine', methods=['GET']) 110 | @exception_handler 111 | def get_asset_ai_engine_config(asset_id): 112 | """获取资产的AI引擎配置""" 113 | config = ConfigService.get_asset_ai_engine_config(asset_id) 114 | if config is None: 115 | return response_template("not_found", code=1004, msg="资产不存在") 116 | return success_json(config) 117 | 118 | @assets_bp.route('//headers/lora', methods=['GET']) 119 | @exception_handler 120 | def get_asset_lora_headers(asset_id): 121 | """获取资产的Lora训练请求头""" 122 | headers = ConfigService.get_asset_lora_headers(asset_id) 123 | if headers is None: 124 | return response_template("not_found", code=1004, msg="资产不存在") 125 | return success_json(headers) 126 | 127 | @assets_bp.route('//headers/ai-engine', methods=['GET']) 128 | @exception_handler 129 | def get_asset_ai_engine_headers(asset_id): 130 | """获取资产的AI引擎请求头""" 131 | headers = ConfigService.get_asset_ai_engine_headers(asset_id) 132 | if headers is None: 133 | return response_template("not_found", code=1004, msg="资产不存在") 134 | return success_json(headers) 135 | 136 | @assets_bp.route('//toggle', methods=['POST']) 137 | @exception_handler 138 | def toggle_asset(asset_id): 139 | """开启或关闭资产""" 140 | data = request.json 141 | if 'enabled' not in data: 142 | return error_json(2004, "请提供enabled参数") 143 | 144 | enabled = bool(data['enabled']) 145 | asset = AssetService.toggle_asset_status(asset_id, enabled) 146 | 147 | if not asset: 148 | return response_template("not_found", code=2002, msg="资产不存在") 149 | 150 | return success_json(asset.dict(), f"资产已{'启用' if enabled else '禁用'}") -------------------------------------------------------------------------------- /backend/app/api/v1/common.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, request, jsonify 2 | from typing import List, Dict, Any 3 | from ...services.common_service import CommonService 4 | from ...services.config_service import ConfigService 5 | from ...utils.response import error, success, success_json, error_json, exception_handler, response_template 6 | 7 | common_bp = Blueprint('common_bp', __name__) 8 | 9 | @common_bp.route('/translate', methods=['POST']) 10 | def translate_text(): 11 | """ 12 | 调用百度翻译API翻译文本 13 | """ 14 | try: 15 | # 检查翻译功能是否开启 16 | if not ConfigService.is_translate_enabled(): 17 | return error(msg="翻译功能未开启") 18 | 19 | data = request.get_json() 20 | if not data or 'text' not in data: 21 | return error(msg="缺少必要参数: text") 22 | 23 | text = data.get('text') 24 | to_lang = data.get('to_lang') 25 | from_lang = data.get('from_lang') 26 | 27 | result = CommonService.translate_text( 28 | text=text, 29 | to_lang=to_lang, 30 | from_lang=from_lang 31 | ) 32 | 33 | if not result['success']: 34 | return error(msg=result.get('error', '翻译失败')) 35 | 36 | return success(data=result) 37 | except Exception as e: 38 | return error(msg=f"翻译失败: {str(e)}") 39 | 40 | @common_bp.route('/batch-translate', methods=['POST']) 41 | def batch_translate(): 42 | """ 43 | 批量翻译多个文本 44 | """ 45 | try: 46 | # 检查翻译功能是否开启 47 | if not ConfigService.is_translate_enabled(): 48 | return error(msg="翻译功能未开启") 49 | 50 | data = request.get_json() 51 | if not data or 'texts' not in data: 52 | return error(msg="缺少必要参数: texts") 53 | 54 | texts = data.get('texts') 55 | to_lang = data.get('to_lang') 56 | from_lang = data.get('from_lang') 57 | 58 | if not isinstance(texts, list): 59 | return error(msg="texts参数必须是文本列表") 60 | 61 | result = CommonService.batch_translate( 62 | texts=texts, 63 | to_lang=to_lang, 64 | from_lang=from_lang 65 | ) 66 | 67 | return success(data=result) 68 | except Exception as e: 69 | return error(msg=f"批量翻译失败: {str(e)}") 70 | 71 | @common_bp.route('/translate', methods=['GET']) 72 | def translate_text_get(): 73 | """ 74 | 调用百度翻译API翻译文本(GET方式) 75 | """ 76 | try: 77 | # 检查翻译功能是否开启 78 | if not ConfigService.is_translate_enabled(): 79 | return error(msg="翻译功能未开启") 80 | 81 | text = request.args.get('text') 82 | to_lang = request.args.get('to_lang') 83 | from_lang = request.args.get('from_lang') 84 | 85 | if not text: 86 | return error(msg="缺少必要参数: text") 87 | 88 | result = CommonService.translate_text( 89 | text=text, 90 | to_lang=to_lang, 91 | from_lang=from_lang 92 | ) 93 | 94 | if not result['success']: 95 | return error(msg=result.get('error', '翻译失败')) 96 | 97 | return success(data=result) 98 | except Exception as e: 99 | return error(msg=f"翻译失败: {str(e)}") -------------------------------------------------------------------------------- /backend/app/api/v1/settings.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, request 2 | from ...database import get_db 3 | from ...services.config_service import ConfigService 4 | from ...utils.logger import setup_logger 5 | from ...utils.response import success_json, error_json, exception_handler, response_template 6 | from ...models.constants import COMMON_TRAINING_PARAMS, COMMON_MARK_PARAMS, FLUX_LORA_PARAMS 7 | 8 | logger = setup_logger('settings_api') 9 | settings_bp = Blueprint('settings', __name__) 10 | 11 | @settings_bp.route('', methods=['GET']) 12 | @exception_handler 13 | def get_settings(): 14 | """获取所有设置""" 15 | settings = ConfigService.get_config() 16 | return success_json(settings) 17 | 18 | @settings_bp.route('', methods=['PUT']) 19 | @exception_handler 20 | def update_settings(): 21 | """更新设置""" 22 | data = request.get_json() 23 | if not data: 24 | return response_template("bad_request", msg="无效的设置数据") 25 | 26 | result = ConfigService.update_config(data) 27 | if result: 28 | return success_json(None, "设置更新成功") 29 | return error_json(msg="设置更新失败") 30 | 31 | @settings_bp.route('/', methods=['GET']) 32 | @exception_handler 33 | def get_setting_value(key): 34 | """获取指定设置值""" 35 | value = ConfigService.get_value(key) 36 | if value is not None: 37 | return success_json({key: value}) 38 | return response_template("not_found", msg=f"设置项 {key} 不存在") 39 | 40 | @settings_bp.route('/common-training-params', methods=['GET']) 41 | @exception_handler 42 | def get_common_training_params(): 43 | """获取常用训练参数列表""" 44 | return success_json(COMMON_TRAINING_PARAMS) 45 | 46 | @settings_bp.route('/common-mark-params', methods=['GET']) 47 | @exception_handler 48 | def get_common_mark_params(): 49 | """获取常用标记参数列表""" 50 | return success_json(COMMON_MARK_PARAMS) 51 | 52 | @settings_bp.route('/flux-lora-params', methods=['GET']) 53 | @exception_handler 54 | def get_flux_lora_params(): 55 | """获取Flux-Lora特有参数列表""" 56 | return success_json(FLUX_LORA_PARAMS) 57 | 58 | @settings_bp.route('/tasks//mark-config', methods=['GET']) 59 | @exception_handler 60 | def get_task_mark_config(task_id): 61 | """获取任务的打标配置""" 62 | mark_config = ConfigService.get_task_mark_config(task_id) 63 | if mark_config is None: 64 | return response_template("not_found", code=1004, msg="任务不存在或无法获取打标配置") 65 | return success_json(mark_config) 66 | 67 | @settings_bp.route('/tasks//training-config', methods=['GET']) 68 | @exception_handler 69 | def get_task_training_config(task_id): 70 | """获取任务的训练配置""" 71 | training_config = ConfigService.get_task_training_config(task_id) 72 | if training_config is None: 73 | return response_template("not_found", code=1004, msg="任务不存在或无法获取训练配置") 74 | return success_json(training_config) 75 | 76 | @settings_bp.route('/assets//training-config', methods=['GET']) 77 | @exception_handler 78 | def get_asset_lora_config(asset_id): 79 | """获取资产的Lora训练配置""" 80 | config = ConfigService.get_asset_lora_config(asset_id) 81 | if config is None: 82 | return response_template("not_found", code=1004, msg="资产不存在或获取配置失败") 83 | return success_json(config) 84 | 85 | @settings_bp.route('/assets//ai-engine-config', methods=['GET']) 86 | @exception_handler 87 | def get_asset_ai_engine_config(asset_id): 88 | """获取资产的AI引擎配置""" 89 | config = ConfigService.get_asset_ai_engine_config(asset_id) 90 | if config is None: 91 | return response_template("not_found", code=1004, msg="资产不存在或获取配置失败") 92 | return success_json(config) -------------------------------------------------------------------------------- /backend/app/api/v1/upload.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, request, jsonify, send_from_directory 2 | from ...services.upload_service import UploadService 3 | from ...config import config 4 | import os 5 | 6 | upload_bp = Blueprint('upload', __name__) 7 | 8 | @upload_bp.route('/files', methods=['POST']) 9 | def upload_file(): 10 | """上传文件接口""" 11 | # 检查是否有文件 12 | if 'file' not in request.files: 13 | return jsonify({'error': '没有上传文件'}), 400 14 | 15 | file = request.files['file'] 16 | if file.filename == '': 17 | return jsonify({'error': '没有选择文件'}), 400 18 | 19 | # 获取文件描述 20 | description = request.form.get('description', '') 21 | 22 | # 保存文件 23 | result = UploadService.save_file(file, description) 24 | if result: 25 | return jsonify({'message': '文件上传成功', 'file': result}), 201 26 | else: 27 | return jsonify({'error': '文件上传失败'}), 500 28 | 29 | @upload_bp.route('/files', methods=['GET']) 30 | def get_files(): 31 | """获取所有文件列表""" 32 | files = UploadService.get_all_files() 33 | return jsonify({'files': files}) 34 | 35 | @upload_bp.route('/files/', methods=['GET']) 36 | def get_file(file_id): 37 | """获取单个文件信息""" 38 | file = UploadService.get_file_by_id(file_id) 39 | if file: 40 | return jsonify({'file': file}) 41 | return jsonify({'error': '文件不存在'}), 404 42 | 43 | @upload_bp.route('/files/', methods=['DELETE']) 44 | def delete_file(file_id): 45 | """删除文件""" 46 | result = UploadService.delete_file(file_id) 47 | if result: 48 | return jsonify({'message': '文件删除成功'}) 49 | return jsonify({'error': '文件删除失败'}), 404 50 | 51 | @upload_bp.route('/files//download', methods=['GET']) 52 | def download_file(file_id): 53 | """下载文件""" 54 | file = UploadService.get_file_by_id(file_id) 55 | if not file: 56 | return jsonify({'error': '文件不存在'}), 404 57 | 58 | # 获取文件路径 59 | file_dir = os.path.dirname(os.path.join(config.PROJECT_ROOT, file['storage_path'])) 60 | filename = os.path.basename(file['storage_path']) 61 | 62 | # 设置下载文件名为原始文件名 63 | download_name = file['filename'] 64 | 65 | return send_from_directory( 66 | file_dir, 67 | filename, 68 | as_attachment=True, 69 | download_name=download_name 70 | ) 71 | 72 | def init_app(app): 73 | """注册蓝图""" 74 | app.register_blueprint(upload_bp, url_prefix=f'{config.API_V1_PREFIX}/upload') -------------------------------------------------------------------------------- /backend/app/database.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.ext.declarative import declarative_base 3 | from sqlalchemy.orm import sessionmaker 4 | from contextlib import contextmanager 5 | from .config import config 6 | from .utils.logger import setup_logger 7 | 8 | # 设置日志记录器 9 | logger = setup_logger('database') 10 | 11 | # 创建数据库引擎 12 | # 如果是SQLite,添加check_same_thread=False参数 13 | if config.DATABASE_URL.startswith('sqlite'): 14 | engine = create_engine( 15 | config.DATABASE_URL, 16 | connect_args={"check_same_thread": False} # 允许多线程访问SQLite连接 17 | ) 18 | logger.info("SQLite数据库已配置为多线程模式") 19 | else: 20 | engine = create_engine(config.DATABASE_URL) 21 | 22 | # 创建会话工厂 23 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 24 | 25 | # 创建基类 26 | Base = declarative_base() 27 | 28 | @contextmanager 29 | def get_db(): 30 | """获取数据库会话""" 31 | db = SessionLocal() 32 | try: 33 | yield db 34 | finally: 35 | db.close() 36 | 37 | def init_db(): 38 | """初始化数据库""" 39 | # 导入所有模型以确保它们被注册 40 | from .models import task # noqa 41 | from .models import training # noqa 42 | from .models import asset # noqa 43 | 44 | # 创建所有表 45 | Base.metadata.create_all(bind=engine) 46 | # 初始化本地资产 47 | try: 48 | from .services.local_asset_service import LocalAssetService 49 | logger.info("正在初始化本地资产...") 50 | local_asset = LocalAssetService.init_local_asset() 51 | if local_asset: 52 | logger.info(f"本地资产初始化成功: ID={local_asset.id}, 名称={local_asset.name}") 53 | else: 54 | logger.warning("本地资产初始化失败") 55 | 56 | # 初始化系统设置 57 | from .services.config_service import ConfigService 58 | ConfigService.init_settings() 59 | except Exception as e: 60 | logger.error(f"初始化本地资产时出错: {str(e)}", exc_info=True) -------------------------------------------------------------------------------- /backend/app/main.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, send_from_directory 2 | from flask_cors import CORS 3 | from .config import config 4 | from .api.v1 import api_v1 5 | from .utils.logger import setup_logger 6 | from .middleware.error_handler import ErrorHandler 7 | from .database import init_db 8 | from .api.v1.terminal import sock # 确保导入 sock 9 | from .services.config_service import ConfigService # 添加这行 10 | from .services.task_services.scheduler_service import SchedulerService 11 | from .utils.json_encoder import CustomJSONEncoder 12 | from .utils.ssh import close_ssh_connection_pool 13 | import os 14 | import atexit 15 | 16 | logger = setup_logger('main') 17 | 18 | def create_app(): 19 | """创建 Flask 应用""" 20 | logger.info("开始创建Flask应用...") 21 | app = Flask(__name__) 22 | 23 | # 基础配置 24 | app.config['SECRET_KEY'] = config.SECRET_KEY 25 | app.config['MAX_CONTENT_LENGTH'] = config.MAX_CONTENT_LENGTH 26 | 27 | # 注册自定义 JSON 编码器 28 | app.json_encoder = CustomJSONEncoder 29 | 30 | # 初始化数据库 31 | init_db() 32 | 33 | # 启用CORS 34 | CORS(app) 35 | 36 | # 注册错误处理 37 | ErrorHandler.init_app(app) 38 | 39 | # 注册 API 路由 40 | app.register_blueprint(api_v1) 41 | 42 | # 注册请求前处理器 43 | @app.before_request 44 | def before_request(): 45 | # TODO: 添加认证和请求日志 46 | pass 47 | 48 | # 注册请求后处理器 49 | @app.after_request 50 | def after_request(response): 51 | # TODO: 添加响应日志 52 | return response 53 | 54 | # 注册 WebSocket 扩展 55 | sock.init_app(app) 56 | 57 | # 注册静态文件路由,用于访问data目录下的任意文件 58 | @app.route('/data/') 59 | def serve_data_files(filepath): 60 | """处理data目录下的文件访问,包括上传的图片和打标后的文本等""" 61 | # 构建完整的文件路径 62 | full_path = os.path.join(config.DATA_DIR, filepath) 63 | directory, filename = os.path.split(full_path) 64 | 65 | # 检查目录是否存在 66 | if not os.path.exists(directory): 67 | logger.error(f"Directory not found: {directory}") 68 | return "Directory not found", 404 69 | 70 | # 检查文件是否存在 71 | if not os.path.exists(full_path): 72 | logger.error(f"File not found: {full_path}") 73 | return "File not found", 404 74 | 75 | return send_from_directory(directory, filename) 76 | 77 | # 初始化任务服务和调度器 78 | SchedulerService.init_scheduler() 79 | logger.info("任务服务已启动") 80 | 81 | # 注册应用关闭处理函数 82 | atexit.register(close_ssh_connection_pool) 83 | logger.info("注册了SSH连接池关闭函数") 84 | 85 | return app 86 | 87 | app = create_app() -------------------------------------------------------------------------------- /backend/app/middleware/error_handler.py: -------------------------------------------------------------------------------- 1 | from flask import jsonify 2 | from werkzeug.exceptions import HTTPException 3 | from ..utils.logger import setup_logger 4 | 5 | logger = setup_logger('error_handler') 6 | 7 | class ErrorHandler: 8 | @staticmethod 9 | def init_app(app): 10 | @app.errorhandler(Exception) 11 | def handle_exception(e): 12 | """处理所有异常""" 13 | if isinstance(e, HTTPException): 14 | response = { 15 | 'error': e.name, 16 | 'message': e.description, 17 | 'status_code': e.code 18 | } 19 | return jsonify(response), e.code 20 | 21 | # 处理其他异常 22 | logger.exception('Unhandled Exception') 23 | response = { 24 | 'error': 'Internal Server Error', 25 | 'message': str(e), 26 | 'status_code': 500 27 | } 28 | return jsonify(response), 500 29 | 30 | class ValidationError(Exception): 31 | """验证错误""" 32 | def __init__(self, message): 33 | self.message = message 34 | super().__init__(self.message) 35 | 36 | class NotFoundError(Exception): 37 | """资源不存在错误""" 38 | def __init__(self, message="Resource not found"): 39 | self.message = message 40 | super().__init__(self.message) 41 | 42 | class ServiceError(Exception): 43 | """服务错误""" 44 | def __init__(self, message): 45 | self.message = message 46 | super().__init__(self.message) -------------------------------------------------------------------------------- /backend/app/models/asset.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sqlalchemy import Column, Integer, String, Boolean, DateTime, JSON 3 | from ..database import Base 4 | 5 | class Asset(Base): 6 | __tablename__ = 'assets' 7 | 8 | id = Column(Integer, primary_key=True) 9 | name = Column(String(50), nullable=False) 10 | ip = Column(String(255), nullable=False, comment='IP地址或域名') 11 | ssh_port = Column(Integer, default=22) 12 | ssh_username = Column(String(50), nullable=False) 13 | ssh_password = Column(String(255)) 14 | ssh_key_path = Column(String(255)) 15 | ssh_auth_type = Column(String(20), default='KEY') 16 | status = Column(String(20), default='PENDING') 17 | is_local = Column(Boolean, default=False, comment='是否为本地系统资产') 18 | port_access_mode = Column(String(20), default='DIRECT', comment='端口访问模式: DIRECT直连模式, DOMAIN域名模式') 19 | enabled = Column(Boolean, default=True, comment='资产是否启用') 20 | 21 | # 存储为JSON字段,包含高级配置参数 22 | lora_training = Column(JSON, default={ 23 | 'enabled': False, 24 | 'port': None, 25 | 'params': {}, # 高级参数配置,可覆盖全局配置 26 | 'verified': False, 27 | 'headers': { # 请求头配置 28 | 'Content-Type': 'application/json', 29 | 'Authorization': '' 30 | }, 31 | 'use_global_config': True, # 是否使用全局配置 32 | }) 33 | 34 | ai_engine = Column(JSON, default={ 35 | 'enabled': False, 36 | 'port': None, 37 | 'timeout': 300, 38 | 'headers': { # 请求头配置 39 | 'Content-Type': 'application/json', 40 | 'Authorization': '' 41 | }, 42 | 'max_retries': 3, 43 | 'retry_interval': 5, 44 | 'use_global_config': True, # 是否使用全局配置 45 | 'verified': False 46 | }) 47 | 48 | created_at = Column(DateTime, default=datetime.utcnow) 49 | updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) 50 | 51 | # 添加任务计数字段 52 | marking_tasks_count = Column(Integer, default=0, comment='当前标记任务数') 53 | training_tasks_count = Column(Integer, default=0, comment='当前训练任务数') 54 | max_concurrent_tasks = Column(Integer, default=10, comment='最大并发任务数(标记任务最大10个,训练任务最大1个)') 55 | 56 | def to_dict(self): 57 | return { 58 | 'id': self.id, 59 | 'name': self.name, 60 | 'ip': self.ip, 61 | 'ssh_port': self.ssh_port, 62 | 'ssh_username': self.ssh_username, 63 | 'ssh_auth_type': self.ssh_auth_type, 64 | 'status': self.status, 65 | 'port_access_mode': self.port_access_mode, 66 | 'lora_training': self.lora_training, 67 | 'ai_engine': self.ai_engine, 68 | 'created_at': self.created_at.isoformat() if self.created_at else None, 69 | 'updated_at': self.updated_at.isoformat() if self.updated_at else None, 70 | 'marking_tasks_count': self.marking_tasks_count, 71 | 'training_tasks_count': self.training_tasks_count, 72 | 'max_concurrent_tasks': self.max_concurrent_tasks, 73 | 'is_local': self.is_local, 74 | 'enabled': self.enabled 75 | } -------------------------------------------------------------------------------- /backend/app/models/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | 常量定义 3 | """ 4 | 5 | # 常用训练参数,用户可以在任务级别配置这些参数 6 | COMMON_TRAINING_PARAMS = { 7 | 'max_train_epochs': '最大训练轮次', 8 | 'train_batch_size': '批量大小', 9 | 'network_dim': '网络维度 (Dim)', 10 | 'network_alpha': '网络Alpha值', 11 | 'learning_rate': '基础学习率', 12 | 'unet_lr': 'Unet学习率', 13 | 'text_encoder_lr': '文本编码器学习率', 14 | 'resolution': '分辨率', 15 | 'lr_scheduler': '学习率调度器', 16 | 'lr_warmup_steps': '预热步数', 17 | 'lr_scheduler_num_cycles': '学习率循环次数', 18 | 'save_every_n_epochs': '每N轮保存一次', 19 | 'sample_every_n_epochs': '每N轮采样一次', 20 | 'clip_skip': 'CLIP跳过层数', 21 | 'seed': '随机种子', 22 | 'mixed_precision': '混合精度', 23 | 'optimizer_type': '优化器类型', 24 | 'repeat_num': '单张图片重复次数', 25 | } 26 | 27 | # 常用标记参数,用户可以在任务级别配置这些参数 28 | COMMON_MARK_PARAMS = { 29 | 'resolution': '图像分辨率', 30 | 'ratio': '图像比例', 31 | 'max_tokens': '最大标记令牌数', 32 | 'min_confidence': '最小置信度', 33 | 'trigger_words': '触发词', 34 | 'auto_crop': '自动裁剪' 35 | } 36 | 37 | # Flux-Lora训练特有参数 38 | FLUX_LORA_PARAMS = { 39 | 'model_train_type': '训练模型类型', 40 | 'pretrained_model_name_or_path': '预训练模型路径', 41 | 'ae': 'AutoEncoder模型路径', 42 | 'clip_l': 'CLIP-L模型路径', 43 | 't5xxl': 'T5XXL模型路径', 44 | 'timestep_sampling': '时间步采样方法', 45 | 'sigmoid_scale': 'Sigmoid缩放系数', 46 | 'model_prediction_type': '模型预测类型', 47 | 'discrete_flow_shift': '离散流偏移', 48 | 'loss_type': '损失函数类型', 49 | 'guidance_scale': '引导缩放系数', 50 | 'prior_loss_weight': '先验损失权重', 51 | 'enable_bucket': '启用分桶', 52 | 'min_bucket_reso': '最小桶分辨率', 53 | 'max_bucket_reso': '最大桶分辨率', 54 | 'bucket_reso_steps': '桶分辨率步长', 55 | 'bucket_no_upscale': '桶不上采样', 56 | 'network_module': '网络模块', 57 | 'network_train_unet_only': '仅训练UNet', 58 | 'network_train_text_encoder_only': '仅训练文本编码器', 59 | 'fp8_base': '使用FP8基础模型', 60 | 'sdpa': '使用SDPA', 61 | 'lowram': '低内存模式', 62 | 'cache_latents': '缓存潜变量', 63 | 'cache_latents_to_disk': '潜变量缓存到磁盘', 64 | 'cache_text_encoder_outputs': '缓存文本编码器输出', 65 | 'cache_text_encoder_outputs_to_disk': '文本编码器输出缓存到磁盘', 66 | 'persistent_data_loader_workers': '持久化数据加载器工作线程', 67 | 'gradient_checkpointing': '梯度检查点', 68 | 'gradient_accumulation_steps': '梯度累积步数', 69 | } 70 | 71 | # 域名访问模式配置 72 | DOMAIN_ACCESS_CONFIG = { 73 | # 域名访问模式 74 | 'PORT_ACCESS_MODES': { 75 | 'DIRECT': '直连模式', # 直接使用IP:PORT访问 76 | 'DOMAIN': '域名模式', # 使用域名方式访问 77 | }, 78 | 79 | # SSH域名后缀,例如:mq1xrkw51rq0vj8w.ssh.x-gpu.com 80 | 'SSH_DOMAIN_SUFFIX': '.ssh.x-gpu.com', 81 | 82 | # 服务容器域名格式,例如:mq1xrkw51rq0vj8w-80.container.x-gpu.com 83 | # 格式为:{hostname}-{port}.container.x-gpu.com 84 | 'CONTAINER_DOMAIN_FORMAT': '{hostname}-{port}.container.x-gpu.com', 85 | 86 | # 容器域名后缀 87 | 'CONTAINER_DOMAIN_SUFFIX': '.container.x-gpu.com', 88 | 89 | # 默认协议前缀,可以是http或https 90 | 'DEFAULT_PROTOCOL': 'https://' 91 | } -------------------------------------------------------------------------------- /backend/app/models/setting.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, Integer, JSON, Text 2 | from ..database import Base 3 | from ..config import config 4 | import json 5 | 6 | class Setting(Base): 7 | __tablename__ = 'settings' 8 | 9 | key = Column(String(50), primary_key=True) 10 | value = Column(Text, nullable=False) # 使用Text类型以支持更长的配置值 11 | type = Column(String(20), nullable=False) # string, integer, json 12 | description = Column(String(200)) 13 | 14 | # 添加默认配置 15 | @staticmethod 16 | def get_defaults(): 17 | return { 18 | 'mark_workflow_api': { 19 | 'type': 'string', 20 | 'value': 'data/workflow/mark_workflow_api.json', # 从文件加载的默认工作流 21 | 'description': 'ComfyUI标记工作流配置' 22 | }, 23 | 'mark_poll_interval': { 24 | 'type': 'integer', 25 | 'value': '5', 26 | 'description': '标记任务轮询间隔(秒)' 27 | }, 28 | 'train_poll_interval': { 29 | 'type': 'integer', 30 | 'value': '15', 31 | 'description': '训练任务轮询间隔(秒)' 32 | }, 33 | 'scheduling_minute': { 34 | 'type': 'integer', 35 | 'value': '5', 36 | 'description': '调度间隔(分钟)' 37 | }, 38 | 'mark_pan_dir': { 39 | 'type': 'string', 40 | 'value': config.SYSTEM_CONFIG['mark_pan_dir'], 41 | 'description': '标记中间目录' 42 | }, 43 | 'lora_pan_upload_dir': { 44 | 'type': 'string', 45 | 'value': config.SYSTEM_CONFIG['lora_pan_upload_dir'], 46 | 'description': 'Lora上传中间目录' 47 | }, 48 | 'mark_config': { 49 | 'type': 'json', 50 | 'value': json.dumps(config.MARK_CONFIG), 51 | 'description': '打标全局配置' 52 | }, 53 | 'lora_training_config': { 54 | 'type': 'json', 55 | 'value': json.dumps(config.LORA_TRAINING_CONFIG), 56 | 'description': 'Lora训练全局配置' 57 | }, 58 | 'ai_engine_config': { 59 | 'type': 'json', 60 | 'value': json.dumps(config.AI_ENGINE_CONFIG), 61 | 'description': 'AI引擎全局配置' 62 | }, 63 | 'lora_training_headers': { 64 | 'type': 'json', 65 | 'value': json.dumps(config.HEADERS_CONFIG['lora_training']), 66 | 'description': 'Lora训练引擎请求头' 67 | }, 68 | 'ai_engine_headers': { 69 | 'type': 'json', 70 | 'value': json.dumps(config.HEADERS_CONFIG['ai_engine']), 71 | 'description': 'AI引擎请求头' 72 | }, 73 | 'baidu_translate_config': { 74 | 'type': 'json', 75 | 'value': json.dumps({ 76 | 'enabled': False, 77 | 'app_id': '20250327002316619', 78 | 'secret_key': '67qaSQg_WdfWqQFvx7ml', 79 | 'api_url': 'https://fanyi-api.baidu.com/api/trans/vip/translate', 80 | 'default_from': 'auto', 81 | 'default_to': 'zh' 82 | }), 83 | 'description': '百度翻译API配置' 84 | } 85 | } 86 | 87 | def parse_value(self): 88 | """根据类型解析值""" 89 | if self.type == 'integer': 90 | return int(self.value) 91 | elif self.type == 'json': 92 | try: 93 | return json.loads(self.value) 94 | except: 95 | return {} 96 | else: 97 | return self.value -------------------------------------------------------------------------------- /backend/app/models/training.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, JSON, DateTime, ForeignKey, Boolean 2 | from sqlalchemy.orm import relationship 3 | from ..database import Base 4 | from ..config import config 5 | from datetime import datetime 6 | 7 | class TrainingMaterial(Base): 8 | __tablename__ = "training_materials" 9 | 10 | id = Column(Integer, primary_key=True, index=True) 11 | folder_name = Column(String, unique=True, index=True) 12 | source_path = Column(String) 13 | status = Column(String) # PENDING, UPLOADED, PROCESSING, COMPLETED, FAILED 14 | extra_info = Column(JSON) # 改用 extra_info 替代 metadata 15 | created_at = Column(DateTime, default=datetime.utcnow) 16 | updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) 17 | 18 | # 关联关系 19 | tasks = relationship("TrainingTask", back_populates="material") 20 | 21 | class TrainingTask(Base): 22 | __tablename__ = "training_tasks" 23 | 24 | id = Column(Integer, primary_key=True, index=True) 25 | material_id = Column(Integer, ForeignKey("training_materials.id")) 26 | model_name = Column(String) 27 | status = Column(String) # PENDING, TRAINING, COMPLETED, FAILED 28 | progress = Column(Integer, default=0) 29 | error_message = Column(String, nullable=True) 30 | config = Column(JSON) # 训练配置参数 31 | asset_id = Column(Integer, ForeignKey("assets.id"), nullable=True) # 关联的资产ID 32 | use_global_config = Column(Boolean, default=True) # 是否使用全局配置 33 | output_path = Column(String, nullable=True) # 模型输出路径 34 | sample_images = Column(JSON, default=[]) # 采样图片路径列表 35 | created_at = Column(DateTime, default=datetime.utcnow) 36 | updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) 37 | started_at = Column(DateTime, nullable=True) 38 | completed_at = Column(DateTime, nullable=True) 39 | 40 | # 关联关系 41 | material = relationship("TrainingMaterial", back_populates="tasks") 42 | asset = relationship("Asset", backref="training_tasks") 43 | 44 | def to_dict(self): 45 | return { 46 | 'id': self.id, 47 | 'material_id': self.material_id, 48 | 'model_name': self.model_name, 49 | 'status': self.status, 50 | 'progress': self.progress, 51 | 'error_message': self.error_message, 52 | 'config': self.config, 53 | 'asset_id': self.asset_id, 54 | 'use_global_config': self.use_global_config, 55 | 'output_path': self.output_path, 56 | 'sample_images': self.sample_images, 57 | 'created_at': self.created_at.isoformat() if self.created_at else None, 58 | 'updated_at': self.updated_at.isoformat() if self.updated_at else None, 59 | 'started_at': self.started_at.isoformat() if self.started_at else None, 60 | 'completed_at': self.completed_at.isoformat() if self.completed_at else None 61 | } 62 | 63 | def get_training_config(self): 64 | """获取训练配置,合并全局配置和任务特定配置""" 65 | if self.use_global_config: 66 | # 合并全局配置和任务特定配置 67 | training_config = config.LORA_TRAINING_CONFIG.copy() 68 | if self.config and isinstance(self.config, dict): 69 | training_config.update(self.config) 70 | 71 | # 如果有关联资产,还需要考虑资产特定配置 72 | if self.asset: 73 | asset_config = self.asset.get_lora_config() 74 | # 资产配置优先级低于任务特定配置 75 | for key, value in asset_config.items(): 76 | if key not in self.config: 77 | training_config[key] = value 78 | 79 | return training_config 80 | else: 81 | # 仅使用任务特定配置 82 | return self.config or {} -------------------------------------------------------------------------------- /backend/app/models/upload_file.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from sqlalchemy import Column, Integer, String, DateTime, Text, Float 3 | from ..database import Base 4 | 5 | class UploadFile(Base): 6 | """上传文件记录模型""" 7 | __tablename__ = 'upload_files' 8 | 9 | id = Column(Integer, primary_key=True, autoincrement=True) 10 | filename = Column(String(255), nullable=False, comment='文件原始名称') 11 | storage_path = Column(String(500), nullable=False, comment='存储路径') 12 | file_type = Column(String(50), comment='文件类型') 13 | file_size = Column(Float, comment='文件大小(KB)') 14 | mime_type = Column(String(100), comment='MIME类型') 15 | md5 = Column(String(32), comment='文件MD5') 16 | description = Column(Text, comment='文件描述') 17 | created_at = Column(DateTime, default=datetime.now, comment='上传时间') 18 | 19 | def to_dict(self): 20 | """转换为字典""" 21 | return { 22 | 'id': self.id, 23 | 'filename': self.filename, 24 | 'storage_path': self.storage_path, 25 | 'file_type': self.file_type, 26 | 'file_size': self.file_size, 27 | 'mime_type': self.mime_type, 28 | 'md5': self.md5, 29 | 'description': self.description, 30 | 'created_at': self.created_at.isoformat() if self.created_at else None 31 | } -------------------------------------------------------------------------------- /backend/app/schemas/setting.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any, Dict 2 | from pydantic import BaseModel 3 | 4 | class SettingUpdate(BaseModel): 5 | source_dir: str 6 | lora_output_path: str 7 | scheduling_minute: int 8 | mark_pan_dir: str 9 | lora_pan_upload_dir: str 10 | 11 | class SettingResponse(BaseModel): 12 | source_dir: str 13 | lora_output_path: str 14 | scheduling_minute: int 15 | mark_pan_dir: str 16 | lora_pan_upload_dir: str 17 | 18 | class Config: 19 | orm_mode = True -------------------------------------------------------------------------------- /backend/app/schemas/task.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Any, List, Union 2 | from pydantic import BaseModel, Field 3 | from datetime import datetime 4 | 5 | class TaskBase(BaseModel): 6 | name: str = Field(..., min_length=2, max_length=100) 7 | description: Optional[str] = None 8 | marking_asset_id: Optional[int] = None 9 | training_asset_id: Optional[int] = None 10 | 11 | # 配置字段 12 | mark_config: Optional[Dict[str, Any]] = None 13 | use_global_mark_config: Optional[bool] = True 14 | 15 | training_config: Optional[Dict[str, Any]] = None 16 | use_global_training_config: Optional[bool] = True 17 | 18 | class TaskCreate(TaskBase): 19 | pass 20 | 21 | class TaskUpdate(BaseModel): 22 | name: Optional[str] = None 23 | description: Optional[str] = None 24 | status: Optional[str] = None 25 | progress: Optional[int] = None 26 | marking_asset_id: Optional[int] = None 27 | training_asset_id: Optional[int] = None 28 | 29 | # 配置字段 30 | mark_config: Optional[Dict[str, Any]] = None 31 | use_global_mark_config: Optional[bool] = None 32 | 33 | training_config: Optional[Dict[str, Any]] = None 34 | use_global_training_config: Optional[bool] = None 35 | 36 | class TaskStatus(BaseModel): 37 | status: str 38 | message: Optional[str] = None 39 | 40 | class TaskResponse(TaskBase): 41 | id: int 42 | status: str 43 | progress: int = 0 44 | created_at: datetime 45 | updated_at: datetime 46 | started_at: Optional[datetime] = None 47 | completed_at: Optional[datetime] = None 48 | 49 | class Config: 50 | orm_mode = True 51 | 52 | class TaskDetail(TaskResponse): 53 | images: Optional[List[Dict[str, Any]]] = [] 54 | status_history: Optional[Dict[str, Any]] = {} 55 | marking_asset: Optional[Dict[str, Any]] = None 56 | training_asset: Optional[Dict[str, Any]] = None 57 | 58 | class TaskImage(BaseModel): 59 | id: int 60 | filename: str 61 | preview_url: Optional[str] = None 62 | size: Optional[int] = None 63 | created_at: datetime 64 | 65 | class Config: 66 | orm_mode = True 67 | 68 | class TaskLog(BaseModel): 69 | id: int 70 | message: str 71 | time: str 72 | status: str -------------------------------------------------------------------------------- /backend/app/schemas/training.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Optional, Dict, Any 3 | from datetime import datetime 4 | 5 | # 训练素材相关模型 6 | class TrainingMaterialBase(BaseModel): 7 | folder_name: str 8 | source_path: str 9 | metadata: Optional[Dict[str, Any]] = None 10 | 11 | class TrainingMaterialCreate(TrainingMaterialBase): 12 | pass 13 | 14 | class TrainingMaterialUpdate(BaseModel): 15 | folder_name: Optional[str] = None 16 | source_path: Optional[str] = None 17 | status: Optional[str] = None 18 | metadata: Optional[Dict[str, Any]] = None 19 | 20 | class TrainingMaterial(TrainingMaterialBase): 21 | id: int 22 | status: str 23 | created_at: datetime 24 | updated_at: datetime 25 | 26 | class Config: 27 | from_attributes = True 28 | 29 | # 训练任务相关模型 30 | class TrainingTaskBase(BaseModel): 31 | material_id: int 32 | node_id: int 33 | 34 | class TrainingTaskCreate(TrainingTaskBase): 35 | pass 36 | 37 | class TrainingTaskUpdate(BaseModel): 38 | status: Optional[str] = None 39 | lora_path: Optional[str] = None 40 | error: Optional[str] = None 41 | 42 | class TrainingTask(TrainingTaskBase): 43 | id: int 44 | status: str 45 | lora_path: Optional[str] = None 46 | error: Optional[str] = None 47 | created_at: datetime 48 | updated_at: datetime 49 | 50 | class Config: 51 | from_attributes = True -------------------------------------------------------------------------------- /backend/app/services/common_service.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Optional 2 | import requests 3 | import random 4 | import hashlib 5 | import json 6 | import time 7 | from ..utils.logger import setup_logger 8 | from ..services.config_service import ConfigService 9 | 10 | logger = setup_logger('common_service') 11 | 12 | class CommonService: 13 | """ 14 | 系统通用接口服务 15 | 提供各种常用功能接口 16 | """ 17 | 18 | @staticmethod 19 | def translate_text(text: str, to_lang: str = None, from_lang: str = None) -> Dict[str, Any]: 20 | """ 21 | 调用百度翻译API翻译文本 22 | 23 | Args: 24 | text: 要翻译的文本 25 | to_lang: 目标语言,默认为系统配置的默认目标语言 26 | from_lang: 源语言,默认为auto(自动检测) 27 | 28 | Returns: 29 | 包含翻译结果的字典 30 | { 31 | 'success': True/False, 32 | 'result': '翻译结果', 33 | 'from': '源语言', 34 | 'to': '目标语言', 35 | 'error': '错误信息(如果失败)' 36 | } 37 | """ 38 | try: 39 | # 获取百度翻译API配置 40 | config = ConfigService.get_value('baidu_translate_config', {}) 41 | 42 | if not config or not isinstance(config, dict): 43 | return { 44 | 'success': False, 45 | 'error': '系统未配置百度翻译API' 46 | } 47 | 48 | app_id = config.get('app_id') 49 | secret_key = config.get('secret_key') 50 | api_url = config.get('api_url', 'https://fanyi-api.baidu.com/api/trans/vip/translate') 51 | 52 | if not app_id or not secret_key: 53 | return { 54 | 'success': False, 55 | 'error': '百度翻译API配置不完整' 56 | } 57 | 58 | # 设置默认值 59 | if not from_lang: 60 | from_lang = config.get('default_from', 'auto') 61 | if not to_lang: 62 | to_lang = config.get('default_to', 'zh') 63 | 64 | # 处理过长的文本 65 | if len(text) > 2000: 66 | text = text[:2000] 67 | logger.warning('翻译文本过长,已截断至2000字符') 68 | 69 | # 准备请求参数 70 | salt = str(random.randint(32768, 65536)) 71 | sign = app_id + text + salt + secret_key 72 | sign = hashlib.md5(sign.encode()).hexdigest() 73 | 74 | params = { 75 | 'q': text, 76 | 'from': from_lang, 77 | 'to': to_lang, 78 | 'appid': app_id, 79 | 'salt': salt, 80 | 'sign': sign 81 | } 82 | 83 | # 发送请求 84 | response = requests.get(api_url, params=params, timeout=10) 85 | result = response.json() 86 | 87 | if 'error_code' in result: 88 | return { 89 | 'success': False, 90 | 'error': f"百度翻译API错误: {result.get('error_code')} - {result.get('error_msg', '未知错误')}" 91 | } 92 | 93 | # 处理翻译结果 94 | translated_text = "" 95 | src_lang = from_lang 96 | 97 | if 'trans_result' in result: 98 | translated_text = ' '.join([item['dst'] for item in result['trans_result']]) 99 | src_lang = result.get('from', from_lang) 100 | 101 | return { 102 | 'success': True, 103 | 'result': translated_text, 104 | 'from': src_lang, 105 | 'to': to_lang 106 | } 107 | 108 | except Exception as e: 109 | logger.error(f"翻译失败: {str(e)}") 110 | return { 111 | 'success': False, 112 | 'error': f"翻译失败: {str(e)}" 113 | } 114 | 115 | @staticmethod 116 | def batch_translate(texts: List[str], to_lang: str = None, from_lang: str = None) -> Dict[str, Any]: 117 | """ 118 | 批量翻译多个文本 119 | 120 | Args: 121 | texts: 要翻译的文本列表 122 | to_lang: 目标语言,默认为系统配置的默认目标语言 123 | from_lang: 源语言,默认为auto(自动检测) 124 | 125 | Returns: 126 | 包含所有翻译结果的字典 127 | """ 128 | results = [] 129 | success_count = 0 130 | 131 | for text in texts: 132 | # 为避免API调用过于频繁,添加短暂延迟 133 | time.sleep(0.2) 134 | result = CommonService.translate_text(text, to_lang, from_lang) 135 | results.append(result) 136 | if result['success']: 137 | success_count += 1 138 | 139 | return { 140 | 'success': success_count == len(texts), 141 | 'results': results, 142 | 'total': len(texts), 143 | 'success_count': success_count, 144 | 'failed_count': len(texts) - success_count 145 | } -------------------------------------------------------------------------------- /backend/app/services/local_asset_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import socket 4 | from typing import Optional 5 | from ..models.asset import Asset 6 | from ..database import get_db 7 | from ..utils.logger import setup_logger 8 | from ..utils import common 9 | logger = setup_logger('local_asset_service') 10 | 11 | class LocalAssetService: 12 | """本地资产服务,用于管理本地资产""" 13 | 14 | LOCAL_ASSET_NAME = "本地系统" 15 | 16 | @staticmethod 17 | def init_local_asset(): 18 | """初始化本地资产""" 19 | try: 20 | with get_db() as db: 21 | # 检查是否已存在本地资产 22 | existing = db.query(Asset).filter(Asset.name == LocalAssetService.LOCAL_ASSET_NAME).first() 23 | if existing: 24 | logger.info(f"本地资产已存在: {existing.id}") 25 | return existing 26 | 27 | # 获取系统信息 28 | # ip = LocalAssetService.get_local_ip() 29 | system_info = common.get_system_info() 30 | 31 | # 创建新的本地资产 32 | local_asset = Asset( 33 | name=LocalAssetService.LOCAL_ASSET_NAME, 34 | ip='127.0.0.1', 35 | ssh_port=22, # 默认值,Windows不会使用 36 | ssh_username=system_info["username"], 37 | ssh_auth_type="KEY", # 默认值,不重要 38 | status="CONNECTED", # 本地资产默认为已连接状态 39 | is_local=True, # 标记为本地资产 40 | enabled=True # 启用本地资产 41 | ) 42 | 43 | # 设置本地资产的能力 44 | local_asset.lora_training = { 45 | 'enabled': True, 46 | 'port': 28000, 47 | 'config_path': '', 48 | 'params': {}, # 高级参数配置,可覆盖全局配置 49 | 'verified': False, 50 | 'headers': { # 请求头配置 51 | 'Content-Type': 'application/json', 52 | 'Authorization': '' 53 | }, 54 | 'use_global_config': True, # 是否使用全局配置 55 | } 56 | 57 | local_asset.ai_engine = { 58 | 'enabled': True, 59 | 'port': 8188, 60 | 'api_url': '', 61 | 'timeout': 300, 62 | 'headers': { # 请求头配置 63 | 'Content-Type': 'application/json', 64 | 'Authorization': '' 65 | }, 66 | 'max_retries': 3, 67 | 'retry_interval': 5, 68 | 'use_global_config': True, # 是否使用全局配置 69 | 'verified': False 70 | } 71 | # 保存到数据库 72 | db.add(local_asset) 73 | db.commit() 74 | db.refresh(local_asset) 75 | 76 | logger.info(f"本地资产创建成功: ID={local_asset.id}, 系统类型={'Windows' if system_info['is_windows'] else 'Linux/Unix'}") 77 | return local_asset 78 | 79 | except Exception as e: 80 | logger.error(f"初始化本地资产失败: {str(e)}", exc_info=True) 81 | return None 82 | 83 | @staticmethod 84 | def is_local_asset(asset_id: int) -> bool: 85 | """判断是否为本地资产""" 86 | try: 87 | with get_db() as db: 88 | asset = db.query(Asset).filter(Asset.id == asset_id).first() 89 | if not asset: 90 | return False 91 | return asset.name == LocalAssetService.LOCAL_ASSET_NAME 92 | except Exception as e: 93 | logger.error(f"检查本地资产失败: {str(e)}") 94 | return False -------------------------------------------------------------------------------- /backend/app/services/task_service.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | from datetime import datetime 3 | from sqlalchemy.orm import Session 4 | from ..models.task import Task, TaskStatus 5 | from ..database import get_db 6 | from ..utils.logger import setup_logger 7 | 8 | # 导入拆分后的服务 9 | from .task_services.base_task_service import BaseTaskService 10 | from .task_services.task_image_service import TaskImageService 11 | from .task_services.marking_service import MarkingService 12 | from .task_services.training_service import TrainingService 13 | from .task_services.result_service import ResultService 14 | from .task_services.scheduler_service import SchedulerService 15 | 16 | logger = setup_logger('task_service') 17 | 18 | class TaskService: 19 | """任务服务,作为统一入口调用各个子服务""" 20 | 21 | # 基本任务管理(委托给BaseTaskService) 22 | list_tasks = BaseTaskService.list_tasks 23 | get_task_by_id = BaseTaskService.get_task_by_id 24 | create_task = BaseTaskService.create_task 25 | update_task = BaseTaskService.update_task 26 | delete_task = BaseTaskService.delete_task 27 | get_task_log = BaseTaskService.get_task_log 28 | get_task_status = BaseTaskService.get_task_status 29 | get_stats = BaseTaskService.get_stats 30 | stop_task = BaseTaskService.stop_task 31 | restart_task = BaseTaskService.restart_task 32 | cancel_task = BaseTaskService.cancel_task 33 | get_task_config = BaseTaskService.get_task_config 34 | update_task_config = BaseTaskService.update_task_config 35 | 36 | # 任务图片管理(委托给TaskImageService) 37 | upload_images = TaskImageService.upload_images 38 | delete_image = TaskImageService.delete_image 39 | batch_delete_images = TaskImageService.batch_delete_images 40 | 41 | # 打标任务管理(委托给MarkingService) 42 | start_marking = MarkingService.start_marking 43 | batch_start_marking = MarkingService.batch_start_marking 44 | get_available_marking_assets = MarkingService.get_available_marking_assets 45 | 46 | # 训练任务管理(委托给TrainingService) 47 | start_training = TrainingService.start_training 48 | get_available_training_assets = TrainingService.get_available_training_assets 49 | 50 | # 结果管理(委托给ResultService) 51 | get_marked_texts = ResultService.get_marked_texts 52 | update_marked_text = ResultService.update_marked_text 53 | batch_update_marked_texts = ResultService.batch_update_marked_texts 54 | get_training_results = ResultService.get_training_results 55 | get_training_loss_data = ResultService.get_training_loss_data 56 | get_execution_history = ResultService.get_execution_history 57 | get_execution_history_by_id = ResultService.get_execution_history_by_id 58 | delete_execution_history = ResultService.delete_execution_history 59 | export_marked_files = ResultService.export_marked_files 60 | import_marked_files = ResultService.import_marked_files 61 | 62 | # 调度器管理(委托给SchedulerService) 63 | init_scheduler = SchedulerService.init_scheduler 64 | start_scheduler = SchedulerService.start_scheduler 65 | stop_scheduler = SchedulerService.stop_scheduler 66 | run_scheduler_once = SchedulerService.run_scheduler_once -------------------------------------------------------------------------------- /backend/app/services/task_services/__init__.py: -------------------------------------------------------------------------------- 1 | # task_services模块 2 | # 包含任务服务的各个子模块 3 | 4 | from .base_task_service import BaseTaskService 5 | from .task_image_service import TaskImageService 6 | from .marking_service import MarkingService 7 | from .training_service import TrainingService 8 | from .result_service import ResultService 9 | from .scheduler_service import SchedulerService 10 | 11 | __all__ = [ 12 | 'BaseTaskService', 13 | 'TaskImageService', 14 | 'MarkingService', 15 | 'TrainingService', 16 | 'ResultService', 17 | 'SchedulerService' 18 | ] -------------------------------------------------------------------------------- /backend/app/services/task_services/task_image_service.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | from sqlalchemy.orm import Session 3 | from ...models.task import Task, TaskImage, TaskStatus 4 | from ...database import get_db 5 | from ...utils.logger import setup_logger 6 | import os 7 | from werkzeug.utils import secure_filename 8 | from ...config import config 9 | 10 | logger = setup_logger('task_image_service') 11 | 12 | class TaskImageService: 13 | @staticmethod 14 | def upload_images(db: Session, task_id: int, files: List) -> Optional[Dict]: 15 | """上传任务图片""" 16 | try: 17 | task = db.query(Task).filter(Task.id == task_id).first() 18 | if not task or task.status != TaskStatus.NEW: 19 | return None 20 | 21 | results = [] 22 | # 创建任务专属的上传目录 23 | task_upload_dir = os.path.join(config.UPLOAD_DIR, str(task_id)) 24 | os.makedirs(task_upload_dir, exist_ok=True) 25 | 26 | for file in files: 27 | filename = secure_filename(file.filename) 28 | file_path = os.path.join(task_upload_dir, filename) 29 | 30 | # 保存文件 31 | file.save(file_path) 32 | 33 | # 创建图片记录 34 | image = TaskImage( 35 | task_id=task_id, 36 | filename=filename, 37 | file_path=file_path, 38 | preview_url=f'/data/uploads/{task_id}/{filename}', 39 | size=os.path.getsize(file_path) 40 | ) 41 | db.add(image) 42 | results.append(image.to_dict()) 43 | 44 | db.commit() 45 | return results 46 | except Exception as e: 47 | logger.error(f"上传图片失败: {e}") 48 | db.rollback() 49 | return None 50 | 51 | @staticmethod 52 | def delete_image(db: Session, task_id: int, image_id: int) -> bool: 53 | """删除任务图片""" 54 | try: 55 | image = db.query(TaskImage).filter( 56 | TaskImage.id == image_id, 57 | TaskImage.task_id == task_id 58 | ).first() 59 | 60 | if not image: 61 | logger.warning(f"未找到图片: task_id={task_id}, image_id={image_id}") 62 | return False 63 | 64 | # 1. 删除原始图片文件 65 | if image.file_path and os.path.exists(image.file_path): 66 | os.remove(image.file_path) 67 | logger.info(f"已删除图片文件: {image.file_path}") 68 | 69 | # 2. 删除对应的打标文本文件(如果存在) 70 | name_without_ext = os.path.splitext(image.filename)[0] 71 | marked_dir = os.path.join(config.MARKED_DIR, str(task_id)) 72 | text_file_path = os.path.join(marked_dir, f"{name_without_ext}.txt") 73 | 74 | if os.path.exists(text_file_path): 75 | os.remove(text_file_path) 76 | logger.info(f"已删除打标文本文件: {text_file_path}") 77 | 78 | # 3. 从数据库中删除图片记录 79 | db.delete(image) 80 | 81 | # 4. 记录操作日志 82 | task = db.query(Task).filter(Task.id == task_id).first() 83 | if task: 84 | task.add_log(f"删除了图片: {image.filename}", db=db) 85 | else: 86 | db.commit() 87 | 88 | return True 89 | except Exception as e: 90 | logger.error(f"删除图片失败: {e}") 91 | db.rollback() 92 | return False 93 | 94 | @staticmethod 95 | def batch_delete_images(db: Session, task_id: int, image_ids: List[int]) -> Dict: 96 | """批量删除任务图片 97 | 98 | Args: 99 | db: 数据库会话 100 | task_id: 任务ID 101 | image_ids: 要删除的图片ID列表 102 | 103 | Returns: 104 | 包含操作结果的字典,包括成功删除的图片和失败的图片 105 | """ 106 | # 检查任务是否存在 107 | task = db.query(Task).filter(Task.id == task_id).first() 108 | if not task: 109 | return { 110 | "success": False, 111 | "message": f"任务 {task_id} 不存在", 112 | "deleted": [], 113 | "failed": [{"id": image_id, "reason": "任务不存在"} for image_id in image_ids] 114 | } 115 | 116 | # 检查任务状态,只有NEW状态的任务可以删除图片 117 | if task.status != TaskStatus.NEW: 118 | return { 119 | "success": False, 120 | "message": f"任务状态为 {task.status},不允许删除图片", 121 | "deleted": [], 122 | "failed": [{"id": image_id, "reason": f"任务状态为 {task.status},不允许删除图片"} for image_id in image_ids] 123 | } 124 | 125 | results = { 126 | "success": True, 127 | "message": "批量删除图片完成", 128 | "deleted": [], 129 | "failed": [] 130 | } 131 | 132 | for image_id in image_ids: 133 | try: 134 | # 查询图片 135 | image = db.query(TaskImage).filter( 136 | TaskImage.id == image_id, 137 | TaskImage.task_id == task_id 138 | ).first() 139 | 140 | if not image: 141 | results["failed"].append({ 142 | "id": image_id, 143 | "reason": f"图片不存在或不属于任务 {task_id}" 144 | }) 145 | continue 146 | 147 | image_info = image.to_dict() 148 | 149 | # 1. 删除原始图片文件 150 | if image.file_path and os.path.exists(image.file_path): 151 | os.remove(image.file_path) 152 | 153 | # 2. 删除对应的打标文本文件(如果存在) 154 | name_without_ext = os.path.splitext(image.filename)[0] 155 | marked_dir = os.path.join(config.MARKED_DIR, str(task_id)) 156 | text_file_path = os.path.join(marked_dir, f"{name_without_ext}.txt") 157 | 158 | if os.path.exists(text_file_path): 159 | os.remove(text_file_path) 160 | 161 | # 3. 从数据库中删除图片记录 162 | db.delete(image) 163 | 164 | # 添加到成功列表 165 | results["deleted"].append(image_info) 166 | 167 | except Exception as e: 168 | logger.error(f"删除图片 {image_id} 失败: {str(e)}") 169 | results["failed"].append({ 170 | "id": image_id, 171 | "reason": str(e) 172 | }) 173 | 174 | # 记录操作日志 175 | if results["deleted"]: 176 | deleted_filenames = [img.get("filename", f"ID:{img.get('id')}") for img in results["deleted"]] 177 | task.add_log(f"批量删除了 {len(results['deleted'])} 个图片: {', '.join(deleted_filenames)}", db=db) 178 | 179 | # 如果全部失败,则整体标记为失败 180 | if not results["deleted"] and results["failed"]: 181 | results["success"] = False 182 | results["message"] = "所有图片删除都失败了" 183 | 184 | return results -------------------------------------------------------------------------------- /backend/app/services/upload_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import uuid 4 | from typing import Optional, Dict, Any 5 | from werkzeug.utils import secure_filename 6 | from ..database import get_db 7 | from ..models.upload_file import UploadFile 8 | from ..config import config 9 | from ..utils.logger import setup_logger 10 | from ..utils.file_handler import calculate_md5 11 | 12 | logger = setup_logger('upload_service') 13 | 14 | class UploadService: 15 | @staticmethod 16 | def allowed_file(filename: str) -> bool: 17 | """检查文件类型是否允许上传""" 18 | return '.' in filename and \ 19 | filename.rsplit('.', 1)[1].lower() in config.ALLOWED_EXTENSIONS 20 | 21 | @staticmethod 22 | def save_file(file, description: str = None) -> Optional[Dict[str, Any]]: 23 | """ 24 | 保存上传文件并记录到数据库 25 | 26 | Args: 27 | file: 上传的文件对象 28 | description: 文件描述 29 | 30 | Returns: 31 | 保存成功返回文件信息字典,失败返回None 32 | """ 33 | try: 34 | if not file: 35 | logger.error("没有接收到文件") 36 | return None 37 | 38 | # 安全处理文件名 39 | original_filename = file.filename 40 | # 检查文件类型 41 | if not UploadService.allowed_file(original_filename): 42 | logger.error(f"不允许的文件类型: {original_filename}") 43 | return None 44 | 45 | # 生成唯一文件名 46 | file_extension = original_filename.rsplit('.', 1)[1].lower() 47 | timestamp = str(int(uuid.uuid1().time_low)) # 生成短时间戳 48 | filename_without_extension = original_filename.rsplit('.', 1)[0] 49 | unique_filename = f"{filename_without_extension}_{timestamp}.{file_extension}" 50 | 51 | # 构建存储路径 52 | relative_path = os.path.join('uploads', unique_filename) 53 | file_path = os.path.join(config.UPLOAD_DIR, unique_filename) 54 | 55 | # 保存文件 56 | file.save(file_path) 57 | file_size = os.path.getsize(file_path) / 1024 # 转换为KB 58 | 59 | # 计算MD5 60 | md5 = calculate_md5(file_path) 61 | 62 | # 记录到数据库 63 | with get_db() as db: 64 | upload_file = UploadFile( 65 | filename=original_filename, 66 | storage_path=relative_path, 67 | file_type=file_extension, 68 | file_size=file_size, 69 | mime_type=file.content_type if hasattr(file, 'content_type') else None, 70 | md5=md5, 71 | description=description 72 | ) 73 | db.add(upload_file) 74 | db.commit() 75 | db.refresh(upload_file) 76 | 77 | logger.info(f"文件上传成功: {original_filename}, ID: {upload_file.id}") 78 | return upload_file.to_dict() 79 | 80 | except Exception as e: 81 | logger.error(f"文件上传失败: {str(e)}") 82 | return None 83 | 84 | @staticmethod 85 | def get_file_by_id(file_id: int) -> Optional[Dict[str, Any]]: 86 | """根据ID获取文件信息""" 87 | try: 88 | with get_db() as db: 89 | file = db.query(UploadFile).filter(UploadFile.id == file_id).first() 90 | if file: 91 | return file.to_dict() 92 | return None 93 | except Exception as e: 94 | logger.error(f"获取文件信息失败, ID: {file_id}, 错误: {str(e)}") 95 | return None 96 | 97 | @staticmethod 98 | def get_all_files() -> list: 99 | """获取所有上传的文件""" 100 | try: 101 | with get_db() as db: 102 | files = db.query(UploadFile).order_by(UploadFile.created_at.desc()).all() 103 | return [file.to_dict() for file in files] 104 | except Exception as e: 105 | logger.error(f"获取所有文件列表失败: {str(e)}") 106 | return [] 107 | 108 | @staticmethod 109 | def delete_file(file_id: int) -> bool: 110 | """删除文件""" 111 | try: 112 | with get_db() as db: 113 | file = db.query(UploadFile).filter(UploadFile.id == file_id).first() 114 | if not file: 115 | logger.error(f"文件不存在, ID: {file_id}") 116 | return False 117 | 118 | # 删除物理文件 119 | file_path = os.path.join(config.PROJECT_ROOT, file.storage_path) 120 | if os.path.exists(file_path): 121 | os.remove(file_path) 122 | 123 | # 删除数据库记录 124 | db.delete(file) 125 | db.commit() 126 | 127 | logger.info(f"文件删除成功, ID: {file_id}") 128 | return True 129 | except Exception as e: 130 | logger.error(f"删除文件失败, ID: {file_id}, 错误: {str(e)}") 131 | return False -------------------------------------------------------------------------------- /backend/app/utils/common.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import platform 3 | import os 4 | from ..utils.logger import setup_logger 5 | 6 | logger = setup_logger('common') 7 | 8 | def get_local_ip(): 9 | """获取本地IP地址""" 10 | try: 11 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 12 | # 使用谷歌DNS服务器地址,不需要真的连接 13 | s.connect(("8.8.8.8", 80)) 14 | ip = s.getsockname()[0] 15 | s.close() 16 | return ip 17 | except Exception as e: 18 | logger.error(f"获取本地IP地址失败: {str(e)}") 19 | return "127.0.0.1" 20 | 21 | def get_system_info(): 22 | """获取系统信息""" 23 | system = platform.system() 24 | if system == "Windows": 25 | username = os.environ.get("USERNAME", "Administrator") 26 | return { 27 | "is_windows": True, 28 | "username": username 29 | } 30 | else: # Linux/Darwin 31 | username = os.environ.get("USER", "root") 32 | return { 33 | "is_windows": False, 34 | "username": username 35 | } 36 | 37 | def copy_attributes(source, target, attributes=None, ignore=None): 38 | """ 39 | 将源对象的属性拷贝到目标对象 40 | 41 | Args: 42 | source: 源对象 43 | target: 目标对象 44 | attributes: 要拷贝的属性列表,如果为None则拷贝所有属性 45 | ignore: 要忽略的属性列表 46 | 47 | Returns: 48 | 拷贝后的目标对象 49 | """ 50 | if ignore is None: 51 | ignore = [] 52 | 53 | # 如果没有指定属性列表,则获取源对象的所有属性 54 | if attributes is None: 55 | if hasattr(source, '__dict__'): 56 | attributes = source.__dict__.keys() 57 | elif hasattr(source, '__slots__'): 58 | attributes = source.__slots__ 59 | elif isinstance(source, dict): 60 | attributes = source.keys() 61 | else: 62 | attributes = [attr for attr in dir(source) if not attr.startswith('_')] 63 | 64 | # 拷贝属性 65 | for attr in attributes: 66 | # 跳过被忽略的属性 67 | if attr in ignore: 68 | continue 69 | 70 | # 获取源对象的属性值 71 | if isinstance(source, dict): 72 | if attr in source: 73 | value = source[attr] 74 | else: 75 | continue 76 | else: 77 | if hasattr(source, attr): 78 | value = getattr(source, attr) 79 | else: 80 | continue 81 | 82 | # 设置目标对象的属性值 83 | if isinstance(target, dict): 84 | target[attr] = value 85 | else: 86 | setattr(target, attr, value) 87 | 88 | return target 89 | 90 | def generate_domain_url(hostname: str, port: int) -> tuple[str, int]: 91 | """ 92 | 根据主机名和端口生成域名URL 93 | 94 | Args: 95 | hostname: 主机名(SSH域名) 96 | port: 服务端口 97 | 98 | Returns: 99 | tuple[str, int]: 生成的域名URL和端口,如果无法解析则返回None 100 | """ 101 | from ..models.constants import DOMAIN_ACCESS_CONFIG 102 | 103 | try: 104 | # 检查是否是SSH域名格式 105 | if DOMAIN_ACCESS_CONFIG['SSH_DOMAIN_SUFFIX'] in hostname: 106 | # 提取主机名部分(移除SSH域名后缀) 107 | base_hostname = hostname.replace(DOMAIN_ACCESS_CONFIG['SSH_DOMAIN_SUFFIX'], '') 108 | 109 | # 根据域名格式模板生成容器访问域名 110 | container_domain = DOMAIN_ACCESS_CONFIG['CONTAINER_DOMAIN_FORMAT'].format( 111 | hostname=base_hostname, 112 | port=port 113 | ) 114 | 115 | # 添加协议前缀 116 | full_url = f"{DOMAIN_ACCESS_CONFIG['DEFAULT_PROTOCOL']}{container_domain}" 117 | 118 | # 根据协议确定返回端口 119 | return_port = 443 if DOMAIN_ACCESS_CONFIG['DEFAULT_PROTOCOL'].startswith('https') else 80 120 | 121 | return full_url, return_port 122 | 123 | return None 124 | except Exception as e: 125 | import logging 126 | logging.error(f"生成域名URL失败: {str(e)}") 127 | return None -------------------------------------------------------------------------------- /backend/app/utils/file_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import hashlib 4 | from typing import Dict, List, Any 5 | from ..utils.logger import setup_logger 6 | 7 | logger = setup_logger('file_handler') 8 | 9 | def load_json(file_path: str, default: Any = None) -> Any: 10 | """加载JSON文件""" 11 | try: 12 | if os.path.exists(file_path): 13 | with open(file_path, 'r', encoding='utf-8') as f: 14 | return json.load(f) 15 | return default 16 | except Exception as e: 17 | logger.error(f"加载文件失败 {file_path}: {e}") 18 | return default 19 | 20 | def save_json(file_path: str, data: Any) -> bool: 21 | """保存JSON文件""" 22 | try: 23 | with open(file_path, 'w', encoding='utf-8') as f: 24 | json.dump(data, f, indent=2, ensure_ascii=False) 25 | return True 26 | except Exception as e: 27 | logger.error(f"保存文件失败 {file_path}: {e}") 28 | return False 29 | 30 | def calculate_md5(file_path: str) -> str: 31 | """ 32 | 计算文件MD5哈希值 33 | 34 | Args: 35 | file_path: 文件路径 36 | 37 | Returns: 38 | 文件的MD5哈希值(十六进制字符串) 39 | """ 40 | try: 41 | md5_hash = hashlib.md5() 42 | with open(file_path, "rb") as f: 43 | # 分块读取文件以处理大文件 44 | for chunk in iter(lambda: f.read(4096), b""): 45 | md5_hash.update(chunk) 46 | return md5_hash.hexdigest() 47 | except Exception as e: 48 | logger.error(f"计算文件MD5失败 {file_path}: {e}") 49 | return "" 50 | 51 | def generate_unique_folder_path(base_dir: str, task_id: int, path_type: str) -> str: 52 | """ 53 | 生成唯一的文件夹路径,格式为 base_dir/task_id_序号 54 | 55 | Args: 56 | base_dir: 基础目录 57 | task_id: 任务ID 58 | path_type: 路径类型,mark或train 59 | 60 | Returns: 61 | 唯一的文件夹路径 62 | """ 63 | # 确保基础目录存在 64 | os.makedirs(base_dir, exist_ok=True) 65 | 66 | # 查找已存在的当前任务的文件夹 67 | existing_folders = [] 68 | prefix = f"{task_id}_" 69 | 70 | if os.path.exists(base_dir): 71 | for item in os.listdir(base_dir): 72 | item_path = os.path.join(base_dir, item) 73 | if os.path.isdir(item_path) and item.startswith(prefix): 74 | try: 75 | # 提取序号部分 76 | sequence = int(item.split('_')[1]) 77 | existing_folders.append(sequence) 78 | except (IndexError, ValueError): 79 | # 如果格式不符,忽略 80 | pass 81 | 82 | # 确定新的序号 83 | sequence_num = 1 84 | if existing_folders: 85 | sequence_num = max(existing_folders) + 1 86 | 87 | # 生成新的文件夹路径 88 | folder_name = f"{task_id}_{sequence_num}" 89 | if path_type == 'mark': 90 | folder_name += "_mark" 91 | elif path_type == 'train': 92 | folder_name += "_train" 93 | 94 | new_folder_path = os.path.join(base_dir, folder_name) 95 | 96 | # 创建目录 97 | os.makedirs(new_folder_path, exist_ok=True) 98 | 99 | return new_folder_path -------------------------------------------------------------------------------- /backend/app/utils/json_encoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import enum 3 | from datetime import datetime, date 4 | from decimal import Decimal 5 | 6 | class CustomJSONEncoder(json.JSONEncoder): 7 | """自定义 JSON 编码器,处理特殊类型的序列化""" 8 | 9 | def default(self, obj): 10 | # 处理枚举类型 11 | if isinstance(obj, enum.Enum): 12 | return obj.value 13 | 14 | # 处理日期时间类型 15 | if isinstance(obj, (datetime, date)): 16 | return obj.isoformat() 17 | 18 | # 处理 Decimal 类型 19 | if isinstance(obj, Decimal): 20 | return float(obj) 21 | 22 | # 尝试将对象转换为字典(如果有 to_dict 方法) 23 | if hasattr(obj, 'to_dict') and callable(getattr(obj, 'to_dict')): 24 | return obj.to_dict() 25 | 26 | # 默认行为 27 | return super().default(obj) -------------------------------------------------------------------------------- /backend/app/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from logging.handlers import TimedRotatingFileHandler 4 | from datetime import datetime 5 | from ..config import config 6 | 7 | def setup_logger(name: str) -> logging.Logger: 8 | """设置日志记录器""" 9 | logger = logging.getLogger(name) 10 | 11 | # 检查是否已经有处理器,如果有则不再添加 12 | if logger.handlers: 13 | return logger 14 | 15 | logger.setLevel(logging.INFO) 16 | 17 | # 使用TimedRotatingFileHandler按周分割日志文件 18 | log_file = os.path.join(config.LOGS_DIR, 'application.log') 19 | # 创建按周轮换的日志处理器 20 | file_handler = TimedRotatingFileHandler( 21 | filename=log_file, 22 | when='W0', # 每周一轮换 23 | interval=1, # 每1周 24 | backupCount=12 # 保留12周的日志 25 | ) 26 | file_handler.setLevel(logging.INFO) 27 | # 设置后缀名格式为 application.log.YYYY-MM-DD 28 | file_handler.suffix = "%Y-%m-%d" 29 | 30 | # 控制台处理器 31 | console_handler = logging.StreamHandler() 32 | console_handler.setLevel(logging.INFO) 33 | 34 | # 格式化器 - 添加模块名称以便于区分不同模块的日志 35 | formatter = logging.Formatter( 36 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 37 | ) 38 | file_handler.setFormatter(formatter) 39 | console_handler.setFormatter(formatter) 40 | 41 | logger.addHandler(file_handler) 42 | logger.addHandler(console_handler) 43 | 44 | return logger -------------------------------------------------------------------------------- /backend/app/utils/response.py: -------------------------------------------------------------------------------- 1 | """ 2 | 响应格式工具模块 3 | 统一API响应格式为 {code, data, msg} 结构 4 | """ 5 | from typing import Any, Dict, Optional, Union, List, Tuple, TypeVar, Callable 6 | from flask import jsonify, Response, current_app 7 | import traceback 8 | import logging 9 | 10 | # 创建logger 11 | logger = logging.getLogger(__name__) 12 | 13 | # 类型定义 14 | T = TypeVar('T') 15 | DataType = Optional[Union[Dict, List, Any]] 16 | ResponseType = Tuple[Response, int] 17 | 18 | def api_response(code: int = 0, 19 | data: DataType = None, 20 | msg: str = "成功") -> Dict: 21 | """ 22 | 生成统一的API响应格式 23 | 24 | 参数: 25 | code: 状态码,0表示成功,非0表示各种错误 26 | data: 响应数据 27 | msg: 状态消息 28 | 29 | 返回: 30 | 符合统一格式的响应字典 31 | """ 32 | response = { 33 | "code": code, 34 | "data": data if data is not None else {}, 35 | "msg": msg 36 | } 37 | return response 38 | 39 | 40 | def success(data: DataType = None, msg: str = "成功") -> Dict: 41 | """成功响应字典""" 42 | return api_response(0, data, msg) 43 | 44 | 45 | def error(code: int = 500, msg: str = "操作失败", data: DataType = None) -> Dict: 46 | """错误响应字典""" 47 | return api_response(code, data, msg) 48 | 49 | 50 | def json_response(code: int = 0, 51 | data: DataType = None, 52 | msg: str = "成功", 53 | status: int = 200) -> ResponseType: 54 | """ 55 | 生成统一的API JSON响应对象 56 | 57 | 参数: 58 | code: 业务码,0表示成功,非0表示各种错误 59 | data: 响应数据 60 | msg: 状态消息 61 | status: HTTP状态码 62 | 63 | 返回: 64 | Flask响应对象 65 | """ 66 | response_data = api_response(code, data, msg) 67 | return jsonify(response_data), status 68 | 69 | 70 | def success_json(data: DataType = None, 71 | msg: str = "成功", 72 | status: int = 200) -> ResponseType: 73 | """ 74 | 生成成功的JSON响应对象 75 | 76 | 参数: 77 | data: 响应数据 78 | msg: 状态消息 79 | status: HTTP状态码 (默认200) 80 | 81 | 返回: 82 | Flask响应对象 83 | """ 84 | return json_response(0, data, msg, status) 85 | 86 | 87 | def error_json(code: int = 500, 88 | msg: str = "操作失败", 89 | data: DataType = None, 90 | status: int = None) -> ResponseType: 91 | """ 92 | 生成错误的JSON响应对象 93 | 94 | 参数: 95 | code: 业务错误码 96 | msg: 错误消息 97 | data: 额外数据 98 | status: HTTP状态码 (默认根据code自动判断) 99 | 100 | 返回: 101 | Flask响应对象 102 | """ 103 | # 如果未指定消息,尝试从错误码定义中获取 104 | if msg == "操作失败" and code in ERROR_CODES: 105 | msg = ERROR_CODES[code] 106 | 107 | # 如果未指定HTTP状态码,根据业务码自动选择 108 | if status is None: 109 | if code >= 1000: # 业务错误码 110 | status = 400 # 默认业务错误返回400 111 | else: # HTTP错误码 112 | status = code if 400 <= code < 600 else 500 113 | 114 | # 记录错误日志(5xx错误) 115 | if status >= 500: 116 | logger.error(f"服务器错误: code={code}, msg={msg}") 117 | 118 | return json_response(code, data, msg, status) 119 | 120 | 121 | def exception_handler(func: Callable[..., T]) -> Callable[..., Union[T, ResponseType]]: 122 | """ 123 | 异常处理装饰器,捕获函数执行中的异常并返回标准错误响应 124 | 125 | 使用示例: 126 | 127 | @exception_handler 128 | def my_api_function(): 129 | # 可能抛出异常的代码 130 | return success_json(...) 131 | """ 132 | def wrapper(*args, **kwargs): 133 | try: 134 | return func(*args, **kwargs) 135 | except ValueError as e: 136 | # 参数验证错误 137 | logger.warning(f"参数验证错误: {str(e)}") 138 | return error_json(400, str(e)) 139 | except Exception as e: 140 | # 服务器内部错误 141 | error_id = generate_error_id() 142 | logger.error(f"未处理异常 [{error_id}]: {str(e)}", exc_info=True) 143 | 144 | if current_app.debug: 145 | # 开发环境返回详细错误 146 | return error_json(500, f"服务器错误: {str(e)}", { 147 | "error_id": error_id, 148 | "traceback": traceback.format_exc() 149 | }) 150 | else: 151 | # 生产环境返回简化错误 152 | return error_json(500, f"服务器错误,请联系管理员。错误ID: {error_id}") 153 | 154 | # 保留原函数名和文档 155 | wrapper.__name__ = func.__name__ 156 | wrapper.__doc__ = func.__doc__ 157 | return wrapper 158 | 159 | 160 | def generate_error_id() -> str: 161 | """生成唯一的错误ID,用于日志跟踪""" 162 | import uuid 163 | import time 164 | return f"{int(time.time())}-{str(uuid.uuid4())[:8]}" 165 | 166 | 167 | def response_template(template_name: str, **kwargs) -> ResponseType: 168 | """ 169 | 使用预定义模板生成响应 170 | 171 | 参数: 172 | template_name: 模板名称 173 | **kwargs: 模板参数 174 | 175 | 返回: 176 | Flask响应对象 177 | """ 178 | templates = { 179 | "created": lambda **kw: success_json( 180 | kw.get("data"), 181 | kw.get("msg", "创建成功"), 182 | 201 183 | ), 184 | "updated": lambda **kw: success_json( 185 | kw.get("data"), 186 | kw.get("msg", "更新成功") 187 | ), 188 | "deleted": lambda **kw: success_json( 189 | None, 190 | kw.get("msg", "删除成功") 191 | ), 192 | "not_found": lambda **kw: error_json( 193 | kw.get("code", 404), 194 | kw.get("msg", "资源不存在"), 195 | kw.get("data") 196 | ), 197 | "bad_request": lambda **kw: error_json( 198 | kw.get("code", 400), 199 | kw.get("msg", "请求参数错误"), 200 | kw.get("data") 201 | ), 202 | "unauthorized": lambda **kw: error_json( 203 | kw.get("code", 401), 204 | kw.get("msg", "未授权访问"), 205 | kw.get("data"), 206 | 401 207 | ), 208 | "forbidden": lambda **kw: error_json( 209 | kw.get("code", 403), 210 | kw.get("msg", "禁止访问"), 211 | kw.get("data"), 212 | 403 213 | ), 214 | } 215 | 216 | if template_name not in templates: 217 | raise ValueError(f"未知的响应模板: {template_name}") 218 | 219 | return templates[template_name](**kwargs) 220 | 221 | 222 | def get_error_message(code: int) -> str: 223 | """获取错误码对应的默认错误消息""" 224 | return ERROR_CODES.get(code, "未知错误") 225 | 226 | 227 | # 常见错误码定义 228 | ERROR_CODES = { 229 | # 通用错误 230 | 400: "请求参数错误", 231 | 401: "未授权访问", 232 | 403: "禁止访问", 233 | 404: "资源不存在", 234 | 405: "方法不允许", 235 | 500: "服务器内部错误", 236 | 237 | # 业务错误码 238 | # 1xxx: 任务相关错误 239 | 1001: "任务不存在", 240 | 1002: "任务状态不允许此操作", 241 | 1003: "创建任务失败", 242 | 1004: "更新任务失败", 243 | 1005: "删除任务失败", 244 | 1006: "获取任务失败", 245 | 246 | # 2xxx: 资产相关错误 247 | 2001: "创建资产失败", 248 | 2002: "资产不存在或更新失败", 249 | 2003: "删除资产失败", 250 | 2004: "标记资产不可用", 251 | 2005: "训练资产不可用", 252 | 2006: "标记请求失败", 253 | 2007: "标记过程异常", 254 | 2008: "验证资产能力失败", 255 | 256 | # 3xxx: 文件操作错误 257 | 3001: "文件上传失败", 258 | 3002: "文件删除失败", 259 | 260 | # 4xxx: 系统设置相关错误 261 | 4001: "更新配置失败", 262 | 4002: "获取配置失败", 263 | 264 | # 5xxx: 训练相关错误 265 | 5001: "启动训练失败", 266 | 5002: "停止训练失败", 267 | 5003: "获取训练状态失败" 268 | } -------------------------------------------------------------------------------- /backend/app/utils/validators.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | from ..middleware.error_handler import ValidationError 3 | from ..config import config 4 | 5 | def validate_task_create(data: dict) -> bool: 6 | """验证任务创建数据""" 7 | required_fields = ['name'] # 修改为只需要 name 字段 8 | 9 | # 检查必填字段 10 | for field in required_fields: 11 | if field not in data or not data[field]: 12 | return False 13 | 14 | return True 15 | 16 | def validate_asset_create(data: Dict[str, Any]) -> None: 17 | """验证创建资产的数据""" 18 | required_fields = ['name', 'folder_name'] 19 | for field in required_fields: 20 | if field not in data: 21 | raise ValidationError(f"Missing required field: {field}") 22 | 23 | def validate_file_upload(files: List[Any]) -> None: 24 | """验证文件上传""" 25 | if not files: 26 | raise ValidationError("没有上传文件") 27 | 28 | for file in files: 29 | # 检查文件扩展名 30 | ext = file.filename.rsplit('.', 1)[1].lower() if '.' in file.filename else '' 31 | if ext not in config.ALLOWED_EXTENSIONS: 32 | raise ValidationError(f"不支持的文件类型。支持的类型: {', '.join(config.ALLOWED_EXTENSIONS)}") 33 | 34 | # 检查文件大小 35 | if file.content_length > config.MAX_CONTENT_LENGTH: 36 | max_size_mb = config.MAX_CONTENT_LENGTH / (1024 * 1024) 37 | raise ValidationError(f"文件过大。最大允许: {max_size_mb}MB") -------------------------------------------------------------------------------- /backend/data/workflow/mark_workflow_api_new.json: -------------------------------------------------------------------------------- 1 | { 2 | "35": { 3 | "inputs": { 4 | "aspect_ratio": "1:1", 5 | "proportional_width": 512, 6 | "proportional_height": 512, 7 | "method": "lanczos", 8 | "scale_to_side": "longest", 9 | "scale_to_length": 1024, 10 | "round_to_multiple": "8", 11 | "image": [ 12 | "217", 13 | 0 14 | ], 15 | "mask": [ 16 | "152", 17 | 0 18 | ] 19 | }, 20 | "class_type": "LayerUtility: ImageAutoCrop V3", 21 | "_meta": { 22 | "title": "LayerUtility: ImageAutoCrop V3" 23 | } 24 | }, 25 | "150": { 26 | "inputs": { 27 | "bbox_threshold": 0.5000000000000001, 28 | "bbox_dilation": 0, 29 | "crop_factor": 3, 30 | "drop_size": 10, 31 | "sub_threshold": 0.5000000000000001, 32 | "sub_dilation": 0, 33 | "sub_bbox_expansion": 0, 34 | "sam_mask_hint_threshold": 0.7000000000000002, 35 | "post_dilation": 0, 36 | "bbox_detector": [ 37 | "151", 38 | 0 39 | ], 40 | "image": [ 41 | "217", 42 | 0 43 | ] 44 | }, 45 | "class_type": "ImpactSimpleDetectorSEGS", 46 | "_meta": { 47 | "title": "简易Seg检测" 48 | } 49 | }, 50 | "151": { 51 | "inputs": { 52 | "model_name": "bbox/face_yolov8m.pt" 53 | }, 54 | "class_type": "UltralyticsDetectorProvider", 55 | "_meta": { 56 | "title": "检测加载器" 57 | } 58 | }, 59 | "152": { 60 | "inputs": { 61 | "segs": [ 62 | "150", 63 | 0 64 | ] 65 | }, 66 | "class_type": "ImpactSEGSToMaskBatch", 67 | "_meta": { 68 | "title": "Seg到遮罩组" 69 | } 70 | }, 71 | "155": { 72 | "inputs": { 73 | "string": "G:\\project\\project\\python\\lora-automatic-traning\\backend\\data\\marked\\2_1_mark" 74 | }, 75 | "class_type": "Simple String", 76 | "_meta": { 77 | "title": "简易字符串" 78 | } 79 | }, 80 | "168": { 81 | "inputs": { 82 | "source": [ 83 | "226", 84 | 0 85 | ], 86 | "to_replace": ". This is ", 87 | "replace_with": "" 88 | }, 89 | "class_type": "JWStringReplace", 90 | "_meta": { 91 | "title": "String Replace" 92 | } 93 | }, 94 | "170": { 95 | "inputs": { 96 | "source": [ 97 | "168", 98 | 0 99 | ], 100 | "to_replace": ". The", 101 | "replace_with": "" 102 | }, 103 | "class_type": "JWStringReplace", 104 | "_meta": { 105 | "title": "String Replace" 106 | } 107 | }, 108 | "171": { 109 | "inputs": { 110 | "source": [ 111 | "170", 112 | 0 113 | ], 114 | "to_replace": ". ", 115 | "replace_with": "" 116 | }, 117 | "class_type": "JWStringReplace", 118 | "_meta": { 119 | "title": "String Replace" 120 | } 121 | }, 122 | "200": { 123 | "inputs": { 124 | "destination": [ 125 | "155", 126 | 0 127 | ], 128 | "save_mode": "Overwrite", 129 | "file_mode": "Windows", 130 | "file_format": "na", 131 | "seed": 126166310172827, 132 | "TextFileNames": [ 133 | "219", 134 | 0 135 | ], 136 | "TextFileContents": [ 137 | "216", 138 | 0 139 | ] 140 | }, 141 | "class_type": "DataSet_TextFilesSave", 142 | "_meta": { 143 | "title": "DataSet_TextFilesSave" 144 | } 145 | }, 146 | "202": { 147 | "inputs": { 148 | "names": [ 149 | "219", 150 | 0 151 | ], 152 | "destination": [ 153 | "155", 154 | 0 155 | ], 156 | "image_format": "png", 157 | "image_quality": 100, 158 | "seed": 654270775057616, 159 | "images": [ 160 | "209", 161 | 0 162 | ] 163 | }, 164 | "class_type": "DataSet_SaveImagePro", 165 | "_meta": { 166 | "title": "DataSet_SaveImagePro" 167 | } 168 | }, 169 | "208": { 170 | "inputs": { 171 | "string": "G:\\project\\project\\python\\lora-automatic-traning\\backend\\data\\uploads\\2" 172 | }, 173 | "class_type": "Simple String", 174 | "_meta": { 175 | "title": "简易字符串" 176 | } 177 | }, 178 | "209": { 179 | "inputs": { 180 | "boolean": true, 181 | "on_true": [ 182 | "35", 183 | 0 184 | ], 185 | "on_false": [ 186 | "217", 187 | 0 188 | ] 189 | }, 190 | "class_type": "easy ifElse", 191 | "_meta": { 192 | "title": "是否判断" 193 | } 194 | }, 195 | "210": { 196 | "inputs": { 197 | "string": "Marking test" 198 | }, 199 | "class_type": "Simple String", 200 | "_meta": { 201 | "title": "简易字符串" 202 | } 203 | }, 204 | "216": { 205 | "inputs": { 206 | "delimiter": ", ", 207 | "clean_whitespace": "true", 208 | "text_a": [ 209 | "210", 210 | 0 211 | ], 212 | "text_b": [ 213 | "171", 214 | 0 215 | ] 216 | }, 217 | "class_type": "Text Concatenate", 218 | "_meta": { 219 | "title": "文本连锁" 220 | } 221 | }, 222 | "217": { 223 | "inputs": { 224 | "directory": [ 225 | "208", 226 | 0 227 | ], 228 | "image_load_cap": 0, 229 | "start_index": 0, 230 | "load_always": true 231 | }, 232 | "class_type": "LoadImageListFromDir //Inspire", 233 | "_meta": { 234 | "title": "加载图像列表(路径)" 235 | } 236 | }, 237 | "218": { 238 | "inputs": { 239 | "source": [ 240 | "223", 241 | 0 242 | ], 243 | "split_by": "/", 244 | "from_right": "true" 245 | }, 246 | "class_type": "JWStringSplit", 247 | "_meta": { 248 | "title": "String Split" 249 | } 250 | }, 251 | "219": { 252 | "inputs": { 253 | "source": [ 254 | "218", 255 | 1 256 | ], 257 | "split_by": ".", 258 | "from_right": "true" 259 | }, 260 | "class_type": "JWStringSplit", 261 | "_meta": { 262 | "title": "String Split" 263 | } 264 | }, 265 | "223": { 266 | "inputs": { 267 | "source": [ 268 | "217", 269 | 2 270 | ], 271 | "to_replace": "\\", 272 | "replace_with": "/" 273 | }, 274 | "class_type": "JWStringReplace", 275 | "_meta": { 276 | "title": "String Replace" 277 | } 278 | }, 279 | "226": { 280 | "inputs": { 281 | "caption_type": "Descriptive", 282 | "caption_length": "long", 283 | "low_vram": true, 284 | "joy_two_pipeline": [ 285 | "227", 286 | 0 287 | ], 288 | "image": [ 289 | "217", 290 | 0 291 | ] 292 | }, 293 | "class_type": "Joy_caption_two", 294 | "_meta": { 295 | "title": "Joy Caption Two" 296 | } 297 | }, 298 | "227": { 299 | "inputs": { 300 | "model": "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4" 301 | }, 302 | "class_type": "Joy_caption_two_load", 303 | "_meta": { 304 | "title": "Joy Caption Two Load" 305 | } 306 | } 307 | } -------------------------------------------------------------------------------- /backend/requirements.txt: -------------------------------------------------------------------------------- 1 | flask==2.0.1 2 | flask-cors==3.0.10 3 | flask_sock==0.7.0 4 | paramiko==2.8.1 5 | python-dotenv==0.19.0 6 | sqlalchemy==1.4.23 7 | werkzeug==2.0.1 8 | gunicorn==20.1.0 9 | cryptography==3.4.7 10 | requests 11 | pydantic==1.10.8 -------------------------------------------------------------------------------- /backend/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # 添加项目根目录到 Python 路径 5 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(ROOT_DIR) 7 | 8 | from app.main import app 9 | from app.config import config 10 | 11 | if __name__ == '__main__': 12 | app.run( 13 | host=config.HOST, 14 | port=config.PORT, 15 | debug=config.DEBUG 16 | ) -------------------------------------------------------------------------------- /backend/start_services.sh: -------------------------------------------------------------------------------- 1 | # 创建日志目录 2 | mkdir -p /var/log/custom_scripts 3 | 4 | # 添加日志记录函数 5 | log() { 6 | echo "$(date '+%Y-%m-%d %H:%M:%S') - $1" >> /var/log/custom_scripts/startup.log 7 | } 8 | 9 | log "开始执行服务启动脚本" 10 | 11 | # 启动nginx 12 | log "正在启动nginx服务" 13 | nginx 14 | if [ $? -eq 0 ]; then 15 | log "nginx启动成功" 16 | else 17 | log "nginx启动失败,错误码: $?" 18 | fi 19 | 20 | # 在screen会话中启动ComfyUI (comfyui环境) 21 | log "正在启动ComfyUI服务" 22 | screen -d -m -D -L -Logfile /tmp/comfyui.log -S comfyui \ 23 | bash -l -c 'source $(conda info --base)/etc/profile.d/conda.sh; \ 24 | conda activate comfyui; \ 25 | cd /root/comfy/ComfyUI && python main.py' 26 | if [ $? -eq 0 ]; then 27 | log "ComfyUI启动成功" 28 | else 29 | log "ComfyUI启动失败,错误码: $?" 30 | fi 31 | 32 | # 在screen会话中启动另一个Python脚本 (base环境) 33 | log "正在启动base环境的Python脚本" 34 | screen -d -m -D -L -Logfile /tmp/base_app.log -S base_app \ 35 | bash -l -c 'source $(conda info --base)/etc/profile.d/conda.sh; \ 36 | conda activate base; \ 37 | cd /home/rlt && python run.py' 38 | if [ $? -eq 0 ]; then 39 | log "base环境Python脚本启动成功" 40 | else 41 | log "base环境Python脚本启动失败,错误码: $?" 42 | fi 43 | 44 | log "所有服务启动完成" 45 | 46 | echo "所有服务已启动" 47 | echo "使用 'screen -r comfyui' 查看ComfyUI日志" 48 | echo "使用 'screen -r base_app' 查看base应用日志" 49 | echo "启动日志位于 /var/log/custom_scripts/startup.log" -------------------------------------------------------------------------------- /backend/task_scheduler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/backend/task_scheduler/__init__.py -------------------------------------------------------------------------------- /backend/tests/README.md: -------------------------------------------------------------------------------- 1 | # 任务状态接口测试 2 | 3 | 本目录包含用于测试任务管理API的测试用例。 4 | 5 | ## 运行测试 6 | 7 | 在项目根目录下运行: 8 | 9 | ```bash 10 | python -m unittest discover tests 11 | ``` 12 | 13 | 或者运行单个测试文件: 14 | 15 | ```bash 16 | python -m unittest tests/test_task_status.py 17 | ``` 18 | 19 | ## 任务状态接口 20 | 21 | ### 获取任务状态 22 | 23 | **URL**: `/api/v1/tasks//status` 24 | **方法**: `GET` 25 | **描述**: 获取指定任务的状态信息 26 | 27 | **响应示例**: 28 | 29 | ```json 30 | { 31 | "id": 1, 32 | "name": "测试任务", 33 | "status": "MARKING", 34 | "progress": 45, 35 | "error_message": null, 36 | "started_at": "2025-03-16T12:30:45", 37 | "updated_at": "2025-03-16T12:35:22", 38 | "completed_at": null, 39 | "recent_logs": [ 40 | { 41 | "time": "2025-03-16T12:30:45", 42 | "message": "任务已开始标记" 43 | }, 44 | { 45 | "time": "2025-03-16T12:35:22", 46 | "message": "标记进度: 45%" 47 | } 48 | ], 49 | "marking_asset_id": 2, 50 | "training_asset_id": null 51 | } 52 | ``` 53 | 54 | **可能的状态**: 55 | - `NEW`: 新建 56 | - `SUBMITTED`: 已提交 57 | - `MARKING`: 标记中 58 | - `MARKED`: 已标记 59 | - `TRAINING`: 训练中 60 | - `COMPLETED`: 已完成 61 | - `ERROR`: 错误 62 | 63 | **错误响应**: 64 | 65 | ```json 66 | { 67 | "error": "任务不存在", 68 | "message": "未找到ID为 999 的任务" 69 | } 70 | ``` 71 | 72 | ## 前端使用示例 73 | 74 | ```javascript 75 | // 获取任务状态 76 | async function getTaskStatus(taskId) { 77 | try { 78 | const response = await fetch(`/api/v1/tasks/${taskId}/status`); 79 | if (!response.ok) { 80 | const error = await response.json(); 81 | throw new Error(error.message || '获取任务状态失败'); 82 | } 83 | return await response.json(); 84 | } catch (error) { 85 | console.error('获取任务状态错误:', error); 86 | throw error; 87 | } 88 | } 89 | 90 | // 定期轮询任务状态 91 | function pollTaskStatus(taskId, callback, interval = 3000) { 92 | const timer = setInterval(async () => { 93 | try { 94 | const status = await getTaskStatus(taskId); 95 | callback(status); 96 | 97 | // 如果任务已完成或出错,停止轮询 98 | if (['COMPLETED', 'ERROR', 'MARKED'].includes(status.status)) { 99 | clearInterval(timer); 100 | } 101 | } catch (error) { 102 | console.error('轮询任务状态失败:', error); 103 | clearInterval(timer); 104 | } 105 | }, interval); 106 | 107 | return timer; // 返回计时器,以便外部可以停止轮询 108 | } 109 | ``` -------------------------------------------------------------------------------- /backend/tests/test_task_status.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import json 3 | from app import create_app 4 | from app.database import get_db 5 | 6 | class TaskStatusTestCase(unittest.TestCase): 7 | """测试任务状态接口""" 8 | 9 | def setUp(self): 10 | """测试前准备""" 11 | self.app = create_app() 12 | self.client = self.app.test_client() 13 | self.app_context = self.app.app_context() 14 | self.app_context.push() 15 | 16 | def tearDown(self): 17 | """测试后清理""" 18 | self.app_context.pop() 19 | 20 | def test_get_task_status(self): 21 | """测试获取任务状态""" 22 | # 1. 创建一个测试任务 23 | task_data = { 24 | 'name': '测试任务', 25 | 'description': '这是一个测试任务' 26 | } 27 | response = self.client.post( 28 | '/api/v1/tasks', 29 | data=json.dumps(task_data), 30 | content_type='application/json' 31 | ) 32 | # 验证HTTP状态码始终为200,因为我们将业务状态码放在code字段中 33 | self.assertEqual(response.status_code, 200) 34 | result = json.loads(response.data) 35 | self.assertEqual(result['code'], 0) # 验证成功响应码 36 | task = result['data'] 37 | task_id = task['id'] 38 | 39 | # 2. 测试获取任务状态 40 | response = self.client.get(f'/api/v1/tasks/{task_id}/status') 41 | self.assertEqual(response.status_code, 200) 42 | 43 | # 验证返回的字段 44 | result = json.loads(response.data) 45 | self.assertEqual(result['code'], 0) # 验证成功响应码 46 | status_data = result['data'] 47 | self.assertEqual(status_data['id'], task_id) 48 | self.assertEqual(status_data['name'], '测试任务') 49 | self.assertEqual(status_data['status'], 'NEW') # 新创建的任务状态应为NEW 50 | self.assertIn('progress', status_data) 51 | self.assertIn('recent_logs', status_data) 52 | 53 | # 3. 测试不存在的任务 54 | response = self.client.get('/api/v1/tasks/9999/status') 55 | # 业务错误也使用200作为HTTP状态码,但code字段表示业务错误 56 | result = json.loads(response.data) 57 | self.assertNotEqual(result['code'], 0) # 验证错误响应码 58 | self.assertEqual(result['code'], 1001) # 任务不存在错误码 59 | 60 | if __name__ == '__main__': 61 | unittest.main() -------------------------------------------------------------------------------- /docs/Snipaste_2025-06-20_19-18-44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/docs/Snipaste_2025-06-20_19-18-44.png -------------------------------------------------------------------------------- /docs/Snipaste_2025-06-20_19-18-55.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/docs/Snipaste_2025-06-20_19-18-55.png -------------------------------------------------------------------------------- /docs/Snipaste_2025-06-20_19-19-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/docs/Snipaste_2025-06-20_19-19-01.png -------------------------------------------------------------------------------- /docs/Snipaste_2025-06-20_19-19-50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/docs/Snipaste_2025-06-20_19-19-50.png -------------------------------------------------------------------------------- /docs/Snipaste_2025-06-20_19-20-00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/docs/Snipaste_2025-06-20_19-20-00.png -------------------------------------------------------------------------------- /fronted-ui/.eslintrc.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | root: true, 3 | env: { 4 | node: true, 5 | 'vue/setup-compiler-macros': true 6 | }, 7 | extends: [ 8 | 'plugin:vue/vue3-essential', 9 | 'eslint:recommended' 10 | ], 11 | parserOptions: { 12 | ecmaVersion: 2020 13 | }, 14 | rules: { 15 | 'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off', 16 | 'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off', 17 | 'vue/multi-word-component-names': 'off', 18 | 'no-undef': ['error', { 19 | 'typeof': true 20 | }], 21 | 'no-unused-vars': ['warn', { 22 | 'vars': 'all', 23 | 'args': 'after-used', 24 | 'ignoreRestSiblings': true, 25 | 'varsIgnorePattern': '^(watch|ref|computed|reactive|onMounted|onBeforeUnmount)$' 26 | }] 27 | } 28 | } -------------------------------------------------------------------------------- /fronted-ui/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | /dist 4 | 5 | 6 | # local env files 7 | .env.local 8 | .env.*.local 9 | 10 | # Log files 11 | npm-debug.log* 12 | yarn-debug.log* 13 | yarn-error.log* 14 | pnpm-debug.log* 15 | 16 | # Editor directories and files 17 | .idea 18 | .vscode 19 | *.suo 20 | *.ntvs* 21 | *.njsproj 22 | *.sln 23 | *.sw? 24 | -------------------------------------------------------------------------------- /fronted-ui/README.md: -------------------------------------------------------------------------------- 1 | # LoRA 自动训练系统 - 前端界面 2 | # RICK lora训练器 3 | 4 | 这是一个用于管理LoRA(Low-Rank Adaptation)模型训练的前端Web应用,提供了直观的用户界面来简化AI模型训练流程。该系统采用现代化的Mac风格UI设计,为用户提供流畅的操作体验。 5 | 6 | ## 📚 主要功能模块 7 | 8 | ### 1. 资产管理模块 9 | - **资产创建与配置**:添加、编辑和删除训练资源,支持本地和远程资产 10 | - **资产连接状态监控**:实时监控资产的连接状态(已连接/待连接/连接错误) 11 | - **资产能力验证**:验证资产是否具备LoRA训练和AI引擎能力 12 | - **资产终端控制**:通过集成的XTerm终端界面直接操作远程资产 13 | - **资产筛选与搜索**:按状态、能力和关键词快速筛选资产 14 | 15 | ### 2. 任务管理模块 16 | - **任务创建与配置**:创建训练任务并配置训练参数 17 | - **任务状态监控**:实时监控任务的执行状态和进度 18 | - **任务操作控制**:提交标记、开始训练、终止任务、重启任务等操作 19 | - **训练图片管理**:上传、预览和删除训练图片 20 | - **标记文本管理**:查看和编辑图片的标记文本 21 | 22 | ### 3. 训练监控模块 23 | - **实时日志查看**:通过终端界面实时查看训练日志 24 | - **训练进度跟踪**:监控训练步骤完成情况和时间估计 25 | - **Loss值曲线图**:使用ECharts实现训练过程中loss值变化的可视化展示 26 | - **训练结果预览**:查看训练完成后的模型效果 27 | - **模型下载**:训练完成后下载生成的模型文件 28 | 29 | ### 4. 系统设置模块 30 | - **全局配置管理**:设置系统级别的默认参数 31 | - 标记轮询间隔 32 | - 任务调度间隔 33 | - 文件目录配置 34 | - 工作流配置 35 | - **标记配置**:设置图像标记的默认参数 36 | - 自动裁剪设置 37 | - 自动标签设置 38 | - 置信度阈值 39 | - 最大标签数量 40 | - 默认裁剪比例 41 | - **AI引擎配置**:配置AI推理引擎参数 42 | - **翻译配置**:配置翻译服务相关参数 43 | - **Lora训练配置**:配置Lora训练的特定参数 44 | 45 | ## 🖥️ UI设计风格 46 | 47 | 项目采用现代化的Mac风格UI设计,具有以下特点: 48 | 49 | 1. **简洁优雅**:界面元素简洁清晰,使用柔和的配色方案 50 | 2. **一致性**:统一的设计语言,包括按钮、表单、卡片等组件 51 | 3. **响应式布局**:适配不同屏幕尺寸,提供良好的移动端体验 52 | 4. **动画过渡**:流畅的过渡动画,增强用户体验 53 | 5. **可访问性**:考虑不同用户需求,提供良好的键盘导航和屏幕阅读支持 54 | 55 | UI组件包括: 56 | - Mac风格的卡片和按钮 57 | - 自定义Modal对话框 58 | - 集成的XTerm终端 59 | - 文本提示工具 60 | - 上下文菜单 61 | - 文件上传组件 62 | - 加载动画 63 | - 开关按钮 64 | - 高亮可编辑文本区域 65 | 66 | ## 🔧 配置优先级系统 67 | 68 | 系统实现了三级配置优先级机制,确保训练参数的灵活性和可控性: 69 | 70 | 1. **任务配置**:针对单个训练任务的特定配置,具有最高优先级 71 | 2. **资产配置**:针对特定数据集或模型的配置,当任务配置未指定时生效 72 | 3. **全局配置**:系统默认配置,当上述两级配置未指定时生效 73 | 74 | 配置优先级:任务配置 > 资产配置 > 全局配置 75 | 76 | ## 🛠️ 技术栈 77 | 78 | - Vue 3 - 前端框架 79 | - Vue Router 4 - 路由管理 80 | - Axios - HTTP请求 81 | - XTerm 5.3 - 终端模拟 82 | - XTerm插件 - fit、web-links、webgl 83 | - ECharts 5.6 - 数据可视化 84 | - Heroicons - 图标库 85 | - Date-fns - 日期时间处理 86 | 87 | ## 📁 项目结构 88 | 89 | ``` 90 | src/ 91 | ├── api/ # API接口封装 92 | │ ├── asset.js # 资产管理API 93 | │ ├── tasks.js # 任务管理API 94 | │ ├── upload.js # 上传文件API 95 | │ ├── terminal.js # 终端操作API 96 | │ ├── settings.js # 系统设置API 97 | │ └── common.js # 通用API方法 98 | ├── assets/ # 静态资源 99 | ├── components/ # 可复用组件 100 | │ ├── assets/ # 资产相关组件 101 | │ ├── common/ # 通用UI组件 102 | │ ├── layout/ # 布局组件 103 | │ ├── tasks/ # 任务相关组件 104 | │ └── terminal/ # 终端相关组件 105 | ├── composables/ # 组合式API函数 106 | ├── router/ # 路由配置 107 | ├── utils/ # 工具函数 108 | ├── views/ # 页面视图 109 | │ ├── Assets.vue # 资产管理页面 110 | │ ├── Tasks.vue # 任务列表页面 111 | │ ├── TaskDetail.vue # 任务详情页面 112 | │ └── Settings.vue # 系统设置页面 113 | ├── App.vue # 根组件 114 | └── main.js # 入口文件 115 | ``` 116 | 117 | ## 🚀 快速开始 118 | 119 | ### 环境要求 120 | 121 | - Node.js 14.0+ 122 | - npm 6.0+ 123 | 124 | ### 安装依赖 125 | 126 | ```bash 127 | npm install 128 | ``` 129 | 130 | ### 开发环境运行 131 | 132 | ```bash 133 | npm run serve 134 | ``` 135 | 136 | ### 生产环境构建 137 | 138 | ```bash 139 | npm run build 140 | ``` 141 | 142 | ## 🔄 与后端集成 143 | 144 | 本项目需要配合后端API使用,确保后端服务已正确配置并运行。系统通过Axios实现与后端的通信,主要包括以下几类API: 145 | 146 | 1. **资产管理API**:创建、更新、删除资产,验证资产能力 147 | 2. **任务管理API**:创建、管理训练任务,监控任务状态 148 | 3. **终端API**:实现远程命令执行和日志监控 149 | 4. **上传API**:处理训练图片和模型文件的上传下载 150 | 5. **设置API**:管理系统全局设置和配置项 151 | 152 | ## 🚧 开发计划 153 | 154 | 当前正在开发的功能: 155 | 156 | 1. **训练结果分析**:完善loss值曲线图表展示,训练结果预览和模型下载功能 157 | 2. **资产管理优化**:改进资产连接状态监控机制和终端控制功能 158 | 3. **任务调度重构**:优化任务队列管理架构,提高系统稳定性和性能 159 | 4. **UI/UX改进**:持续优化用户界面和交互体验 160 | 161 | ## 📄 许可证 162 | 163 | [MIT](LICENSE) 164 | 165 | ## 🤝 贡献指南 166 | 167 | 欢迎提交Issues和Pull Requests来改进项目! 168 | 169 | -------------------------------------------------------------------------------- /fronted-ui/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | presets: [ 3 | '@vue/cli-plugin-babel/preset' 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /fronted-ui/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | AI Studio 8 | 9 | 10 |
11 | 12 | 13 | -------------------------------------------------------------------------------- /fronted-ui/jsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "module": "esnext", 5 | "baseUrl": "./", 6 | "moduleResolution": "node", 7 | "paths": { 8 | "@/*": [ 9 | "src/*" 10 | ] 11 | }, 12 | "lib": [ 13 | "esnext", 14 | "dom", 15 | "dom.iterable", 16 | "scripthost" 17 | ] 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /fronted-ui/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fronted-ui", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "serve": "vue-cli-service serve", 7 | "build": "vue-cli-service build", 8 | "lint": "vue-cli-service lint" 9 | }, 10 | "dependencies": { 11 | "@heroicons/vue": "^2.1.1", 12 | "axios": "^1.7.9", 13 | "core-js": "^3.8.3", 14 | "date-fns": "^4.1.0", 15 | "echarts": "^5.6.0", 16 | "marked": "^15.0.12", 17 | "mitt": "^3.0.1", 18 | "vite-plugin-raw": "^1.0.3", 19 | "vue": "^3.2.13", 20 | "vue-router": "^4.5.0", 21 | "xterm": "^5.3.0", 22 | "xterm-addon-fit": "^0.8.0", 23 | "xterm-addon-web-links": "^0.9.0", 24 | "xterm-addon-webgl": "^0.16.0" 25 | }, 26 | "devDependencies": { 27 | "@babel/core": "^7.12.16", 28 | "@babel/eslint-parser": "^7.12.16", 29 | "@vue/cli-plugin-babel": "~5.0.0", 30 | "@vue/cli-plugin-eslint": "~5.0.0", 31 | "@vue/cli-service": "~5.0.0", 32 | "eslint": "^7.32.0", 33 | "eslint-plugin-vue": "^8.0.3" 34 | }, 35 | "eslintConfig": { 36 | "root": true, 37 | "env": { 38 | "node": true 39 | }, 40 | "extends": [ 41 | "plugin:vue/vue3-essential", 42 | "eslint:recommended" 43 | ], 44 | "parserOptions": { 45 | "parser": "@babel/eslint-parser" 46 | }, 47 | "rules": {} 48 | }, 49 | "browserslist": [ 50 | "> 1%", 51 | "last 2 versions", 52 | "not dead", 53 | "not ie 11" 54 | ] 55 | } 56 | -------------------------------------------------------------------------------- /fronted-ui/public/docs/faq.md: -------------------------------------------------------------------------------- 1 | # Lora训练系统常见问题 2 | 3 | ## 基础问题 4 | 5 | ### 什么是Lora训练? 6 | Lora(Low-Rank Adaptation)是一种高效的模型微调技术,它通过在预训练模型的权重矩阵中插入低秩矩阵来实现参数高效的微调。这种方法可以用很少的参数量捕获特定领域或风格的特征,非常适合个性化AI模型。 7 | 8 | ### 需要多少训练图片? 9 | 一般建议: 10 | - 最少:15-20张高质量图片 11 | - 理想:30-50张图片 12 | - 风格训练:20-30张具有一致风格的图片 13 | - 角色训练:30-50张同一角色的不同角度、姿势的图片 14 | 15 | ### 训练需要多长时间? 16 | 训练时间取决于多个因素: 17 | - 图片数量和分辨率 18 | - 训练轮次 19 | - 硬件配置(特别是GPU性能) 20 | 21 | 一般来说,使用中等配置的GPU,训练30张图片10个轮次大约需要30分钟到2小时。 22 | 23 | ## 技术问题 24 | 25 | ### 如何解决"CUDA out of memory"错误? 26 | 当GPU内存不足时会出现此错误,解决方法: 27 | 1. 减小批量大小(train_batch_size) 28 | 2. 降低图片分辨率(resolution) 29 | 3. 减小网络维度(network_dim) 30 | 4. 使用混合精度训练(mixed_precision设置为fp16或bf16) 31 | 32 | ### 训练过程中断怎么办? 33 | 如果训练意外中断,您可以: 34 | 1. 在任务详情页面查看日志,确认中断原因 35 | 2. 修复问题后,点击"继续训练"按钮从最后一个检查点恢复训练 36 | 3. 如果无法恢复,可以创建新任务重新开始训练 37 | 38 | ### 如何提高训练质量? 39 | 提高训练质量的建议: 40 | 1. 使用高质量、风格一致的训练图片 41 | 2. 合理设置学习率(learning_rate) 42 | 3. 适当增加训练轮次(max_train_epochs) 43 | 4. 使用更大的网络维度(network_dim) 44 | 5. 确保触发词(trigger_words)准确描述了您想要捕捉的特征 45 | 46 | ## 参数问题 47 | 48 | ### 学习率应该设置多少? 49 | 学习率是训练中最重要的超参数之一: 50 | - 一般范围:0.0001-0.001 51 | - 默认推荐值:0.0001 52 | - 如果训练不稳定或结果不理想,可以尝试减小学习率 53 | - 如果训练太慢,可以适当增大学习率 54 | 55 | ### 网络维度(network_dim)和Alpha值如何设置? 56 | - 网络维度(network_dim):控制模型容量,通常设置为8-128之间的值 57 | - 较小的值(8-32):学习简单特征,防止过拟合 58 | - 较大的值(64-128):学习复杂特征,需要更多训练数据 59 | 60 | - Alpha值(network_alpha):控制初始化规模,通常设置为network_dim的一半或相等 61 | 62 | ### 什么是混合精度训练? 63 | 混合精度训练是一种使用较低精度(如fp16或bf16)进行部分计算的技术,可以: 64 | - 减少GPU内存使用 65 | - 加快训练速度 66 | - 在大多数情况下不会显著影响训练质量 67 | 68 | 推荐设置: 69 | - 对于较新的GPU:使用"bf16" 70 | - 对于较旧的GPU:使用"fp16" 71 | - 如果遇到训练不稳定:使用"no"(关闭混合精度) 72 | 73 | ## 使用问题 74 | 75 | ### 如何在其他软件中使用训练好的模型? 76 | 训练完成后,您可以: 77 | 1. 在任务详情页面下载模型文件(.safetensors或.pt格式) 78 | 2. 将模型文件放入相应软件的Lora模型目录 79 | 3. 在生成图片时使用触发词并调整Lora强度(通常为0.6-1.0) 80 | 81 | 支持的软件包括:Stable Diffusion WebUI、ComfyUI、AUTOMATIC1111等。 82 | 83 | ### 如何调整已训练模型的强度? 84 | 在使用模型时,您可以通过以下方式调整强度: 85 | 1. 在提示词中使用语法:``,如`` 86 | 2. 在WebUI等界面中使用滑块调整强度 87 | 3. 通常0.6-0.8是一个好的起点,可以根据需要上下调整 88 | 89 | ### 如何组合多个Lora模型? 90 | 您可以在提示词中组合多个Lora模型: 91 | ``` 92 | 一张风景图 93 | ``` 94 | 注意控制总体强度,避免特征过于混杂。组合时建议降低每个模型的单独强度。 95 | 96 | ## 故障排除 97 | 98 | ### 为什么我的训练结果看起来模糊或不理想? 99 | 可能的原因和解决方法: 100 | 1. 训练图片质量不高 → 使用更高质量的图片 101 | 2. 训练轮次不足 → 增加训练轮次 102 | 3. 学习率设置不当 → 调整学习率 103 | 4. 网络维度太小 → 增加网络维度 104 | 5. 触发词不准确 → 优化触发词 105 | 106 | ### 为什么训练过程中损失值异常? 107 | 损失值异常的常见情况: 108 | - 损失值持续增大:学习率可能太高,尝试降低学习率 109 | - 损失值不下降:可能需要更多训练轮次或调整学习率 110 | - 损失值波动剧烈:尝试更换优化器或调整学习率调度器 111 | 112 | ### 系统无法识别我的GPU怎么办? 113 | 如果系统无法识别您的GPU: 114 | 1. 确保已安装正确的GPU驱动程序 115 | 2. 检查CUDA和cuDNN是否正确安装 116 | 3. 重启系统和应用程序 117 | 4. 在系统设置中检查GPU配置 -------------------------------------------------------------------------------- /fronted-ui/public/docs/guide.md: -------------------------------------------------------------------------------- 1 | # Lora训练系统使用指南 2 | 3 | ## 简介 4 | 5 | 欢迎使用Lora训练系统!本系统是一个专为AI模型训练设计的工具,可以帮助您轻松创建和管理Lora模型训练任务。本指南将帮助您了解系统的基本功能和使用方法。 6 | 7 | ## 快速开始 8 | 9 | ### 1. 创建训练资产 10 | 11 | 首先,您需要在"资产"页面创建一个训练资产: 12 | 13 | 1. 点击"资产"页面中的"创建资产"按钮 14 | 3. 填写资产名称和描述 15 | 5. 点击"创建"按钮保存资产 16 | 17 | ### 2. 配置训练参数 18 | 19 | 在创建资产后,您可以配置训练参数: 20 | 21 | 1. 在资产详情页面,点击"编辑参数"按钮 22 | 2. 选择模型训练类型(Flux-Lora、SD1.5-Lora或SDXL-Lora) 23 | 3. 根据您的需求调整训练参数 24 | 4. 点击"保存"按钮 25 | 26 | ### 3. 创建训练任务 27 | 28 | 配置好参数后,您可以创建训练任务: 29 | 30 | 1. 在"任务"页面,点击"创建任务"按钮 31 | 2. 选择您刚才创建的资产 32 | 3. 确认训练参数 33 | 4. 点击"开始训练"按钮 34 | 35 | ### 4. 监控训练进度 36 | 37 | 您可以在"任务"页面监控训练进度: 38 | 39 | 1. 查看任务状态、进度和预计完成时间 40 | 2. 查看训练日志和中间结果 41 | 3. 如需要,可以暂停或取消训练任务 42 | 43 | ## 高级功能 44 | 45 | ### 自动打标功能 46 | 47 | 系统提供自动打标功能,可以自动识别图片内容并生成标签: 48 | 49 | 50 | ### 批量处理 51 | 52 | 您可以批量处理多个图片: 53 | 54 | 1. 在资产页面,选择多个图片 55 | 2. 点击"批量操作"按钮 56 | 3. 选择要执行的操作(如批量打标、批量裁剪等) 57 | 58 | ### 模型导出 59 | 60 | 训练完成后,您可以导出模型用于推理: 61 | 62 | 1. 在任务详情页面,点击"下载模型"按钮 63 | 64 | ## 参数说明 65 | 66 | ### 基本参数 67 | 68 | - **模型训练类型**:选择要使用的基础模型类型 69 | - Flux-Lora:适用于高质量艺术创作 70 | - SD1.5-Lora:通用型,适合多种场景 71 | - SDXL-Lora:高分辨率,细节更丰富 72 | 73 | - **最大训练轮次**:模型训练的总轮数,通常5-10轮即可 74 | 75 | - **图片重复次数**:每张图片在一个训练轮次中重复使用的次数 76 | 77 | ### 高级参数 78 | 79 | - **学习率**:控制模型学习速度,一般建议0.0001-0.001 80 | 81 | - **网络维度**:LoRA网络的维度,影响模型容量和学习能力 82 | - 较小的值(如4-16):学习简单特征,训练速度快 83 | - 较大的值(如64-128):学习复杂特征,需要更多训练数据 84 | 85 | - **批量大小**:每次更新模型参数时处理的样本数量,受GPU内存限制 86 | 87 | ## 故障排除 88 | 89 | ### 常见错误 90 | 91 | - **CUDA内存不足**:减小批量大小或降低分辨率 92 | - **训练不收敛**:调整学习率或增加训练轮次 93 | - **模型过拟合**:增加正则化或减少训练轮次 94 | 95 | ### 性能优化 96 | 97 | - 使用较小的网络维度可以加快训练速度 98 | - 混合精度训练可以减少内存使用 99 | - 适当调整批量大小可以提高训练效率 100 | 101 | ## 系统要求 102 | 103 | - **GPU**:建议使用NVIDIA GPU,至少8GB显存 104 | - **存储空间**:至少20GB可用空间 105 | - **操作系统**:Windows 10/11、Ubuntu 20.04或更高版本 106 | 107 | ## 更新日志 108 | 109 | ### v1.2.0 (2025-06-20) 110 | - 添加SDXL-Lora、sd1.5-Lora支持 111 | - 优化自动打标算法 112 | - 改进用户界面 113 | 114 | ### v1.1.0 (2025-05-15) 115 | - 添加批量处理功能 116 | - 改进训练稳定性 117 | - 修复多个已知问题 118 | 119 | ### v1.0.0 (2025-04-01) 120 | - 初始版本发布 -------------------------------------------------------------------------------- /fronted-ui/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/fronted-ui/public/favicon.ico -------------------------------------------------------------------------------- /fronted-ui/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | <%= htmlWebpackPlugin.options.title %> 9 | 10 | 11 | 14 |
15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /fronted-ui/src/App.vue: -------------------------------------------------------------------------------- 1 | 8 | 9 | 35 | 36 | 75 | -------------------------------------------------------------------------------- /fronted-ui/src/api/asset.js: -------------------------------------------------------------------------------- 1 | import request from '@/utils/request' 2 | 3 | const BASE_URL = '/assets' 4 | 5 | export const assetApi = { 6 | /** 7 | * 获取资产列表 8 | * @returns {Promise} 资产列表 9 | */ 10 | async getAssets() { 11 | return request.get(BASE_URL) 12 | }, 13 | 14 | /** 15 | * 创建资产 16 | * @param {Object} data - 资产数据 17 | * @returns {Promise} 创建的资产 18 | */ 19 | async createAsset(data) { 20 | return request.post(BASE_URL, data) 21 | }, 22 | 23 | /** 24 | * 更新资产 25 | * @param {number|string} id - 资产ID 26 | * @param {Object} data - 更新数据 27 | * @returns {Promise} 更新后的资产 28 | */ 29 | async updateAsset(id, data) { 30 | return request.put(`${BASE_URL}/${id}`, data) 31 | }, 32 | 33 | /** 34 | * 删除资产 35 | * @param {number|string} id - 资产ID 36 | * @returns {Promise} 37 | */ 38 | async deleteAsset(id) { 39 | return request.delete(`${BASE_URL}/${id}`) 40 | }, 41 | 42 | /** 43 | * 验证资产能力 44 | * @param {number|string} id - 资产ID 45 | * @returns {Promise} 验证结果 46 | */ 47 | async verifyCapabilities(id) { 48 | return request.post(`${BASE_URL}/${id}/verify`) 49 | }, 50 | 51 | /** 52 | * 验证SSH连接 53 | * @param {Object} data - SSH连接参数 54 | * @returns {Promise} 验证结果 55 | */ 56 | async verifySshConnection(data) { 57 | return request.post(`${BASE_URL}/verify-ssh`, data) 58 | }, 59 | 60 | /** 61 | * 切换资产启用状态 62 | * @param {number} assetId 资产ID 63 | * @param {boolean} enabled 是否启用 64 | * @returns {Promise} 65 | */ 66 | async toggleAssetStatus(assetId, enabled) { 67 | return request({ 68 | url: `${BASE_URL}/${assetId}/toggle`, 69 | method: 'post', 70 | data: { enabled } 71 | }) 72 | } 73 | } -------------------------------------------------------------------------------- /fronted-ui/src/api/common.js: -------------------------------------------------------------------------------- 1 | import request from '@/utils/request' 2 | 3 | const BASE_URL = '/common' 4 | 5 | export const commonApi = { 6 | /** 7 | * 翻译文本 8 | * @param {string} text - 需要翻译的文本 9 | * @param {string} [to_lang='zh'] - 目标语言 10 | * @param {string} [from_lang='auto'] - 源语言 11 | * @returns {Promise} 翻译结果 12 | */ 13 | async translateText(text, to_lang = 'zh', from_lang = 'auto') { 14 | return request.post(`${BASE_URL}/translate`, { 15 | text, 16 | to_lang, 17 | from_lang 18 | }) 19 | }, 20 | 21 | /** 22 | * 批量翻译文本 23 | * @param {Array} texts - 需要翻译的文本数组 24 | * @param {string} [to_lang='zh'] - 目标语言 25 | * @param {string} [from_lang='auto'] - 源语言 26 | * @returns {Promise} 翻译结果 27 | */ 28 | async batchTranslate(texts, to_lang = 'zh', from_lang = 'auto') { 29 | return request.post(`${BASE_URL}/batch-translate`, { 30 | texts, 31 | to_lang, 32 | from_lang 33 | }) 34 | }, 35 | 36 | /** 37 | * GET方式翻译文本(适用于简单翻译) 38 | * @param {string} text - 需要翻译的文本 39 | * @param {string} [to_lang='zh'] - 目标语言 40 | * @param {string} [from_lang='auto'] - 源语言 41 | * @returns {Promise} 翻译结果 42 | */ 43 | async translateTextGet(text, to_lang = 'zh', from_lang = 'auto') { 44 | return request.get(`${BASE_URL}/translate`, { 45 | params: { 46 | text, 47 | to_lang, 48 | from_lang 49 | } 50 | }) 51 | } 52 | } -------------------------------------------------------------------------------- /fronted-ui/src/api/settings.js: -------------------------------------------------------------------------------- 1 | import request from '@/utils/request' 2 | const BASE_URL = '/settings' 3 | 4 | export const settingsApi = { 5 | /** 6 | * 获取设置 7 | * @returns {Promise} 系统设置 8 | */ 9 | async getSettings() { 10 | return request.get(BASE_URL) 11 | }, 12 | 13 | /** 14 | * 更新设置 15 | * @param {Object} data - 设置数据 16 | * @returns {Promise} 更新后的设置 17 | */ 18 | async updateSettings(data) { 19 | return request.put(BASE_URL, data) 20 | }, 21 | 22 | /** 23 | * 获取任务打标配置 24 | * @param {number} taskId - 任务ID 25 | * @returns {Promise} 打标配置 26 | */ 27 | async getTaskMarkConfig(taskId) { 28 | return request.get(`${BASE_URL}/tasks/${taskId}/mark-config`) 29 | }, 30 | 31 | /** 32 | * 获取任务训练配置 33 | * @param {number} taskId - 任务ID 34 | * @returns {Promise} 训练配置 35 | */ 36 | async getTaskTrainingConfig(taskId) { 37 | return request.get(`${BASE_URL}/tasks/${taskId}/training-config`) 38 | } 39 | } -------------------------------------------------------------------------------- /fronted-ui/src/api/tasks.js: -------------------------------------------------------------------------------- 1 | import request from '@/utils/request' 2 | 3 | const BASE_URL = '/tasks' 4 | 5 | export const tasksApi = { 6 | // 获取任务列表 7 | async getTasks(params) { 8 | return request.get(BASE_URL, { params }) 9 | }, 10 | 11 | // 创建任务 12 | async createTask(data) { 13 | return request.post(BASE_URL, data) 14 | }, 15 | 16 | // 获取任务详情 17 | async getTaskById(id) { 18 | return request.get(`${BASE_URL}/${id}`) 19 | }, 20 | 21 | // 更新任务 22 | async updateTask(id, data) { 23 | return request.put(`${BASE_URL}/${id}`, data) 24 | }, 25 | 26 | // 删除任务 27 | async deleteTask(id) { 28 | return request.delete(`${BASE_URL}/${id}`) 29 | }, 30 | 31 | // 上传图片 32 | async uploadImages(taskId, formData) { 33 | return request.post(`${BASE_URL}/${taskId}/images`, formData, { 34 | headers: { 35 | 'Content-Type': 'multipart/form-data' 36 | } 37 | }) 38 | }, 39 | 40 | // 删除图片 41 | async deleteImage(taskId, imageId) { 42 | return request.delete(`${BASE_URL}/${taskId}/images/${imageId}`) 43 | }, 44 | 45 | // 开始标记 46 | async startMarking(taskId) { 47 | return request.post(`${BASE_URL}/${taskId}/mark`) 48 | }, 49 | 50 | // 开始训练 51 | async startTraining(taskId) { 52 | return request.post(`${BASE_URL}/${taskId}/train`) 53 | }, 54 | 55 | /** 56 | * 重启任务 57 | * @param {number|string} taskId - 任务ID 58 | * @returns {Promise} 任务信息 59 | */ 60 | async restartTask(taskId) { 61 | return request.post(`${BASE_URL}/${taskId}/restart`) 62 | }, 63 | 64 | /** 65 | * 取消任务 66 | * @param {number|string} taskId - 任务ID 67 | * @returns {Promise} 任务信息 68 | */ 69 | async cancelTask(taskId) { 70 | return request.post(`${BASE_URL}/${taskId}/cancel`) 71 | }, 72 | 73 | /** 74 | * 获取任务状态 75 | * @param {number|string} taskId - 任务ID 76 | * @returns {Promise} 任务状态信息 77 | */ 78 | async getTaskStatus(taskId) { 79 | return request.get(`${BASE_URL}/${taskId}/status`) 80 | }, 81 | 82 | /** 83 | * 获取打标文本 84 | * @param {number|string} taskId - 任务ID 85 | * @returns {Promise} 打标文本信息,key为图片名称,value为打标文本 86 | */ 87 | async getMarkedTexts(taskId) { 88 | return request.get(`${BASE_URL}/${taskId}/marked_texts`) 89 | }, 90 | 91 | /** 92 | * 更新打标文本 93 | * @param {number|string} taskId - 任务ID 94 | * @param {string} filename - 图片文件名 95 | * @param {string} content - 打标文本内容 96 | * @returns {Promise} 更新后的打标文本信息 97 | */ 98 | async updateMarkedText(taskId, filename, content) { 99 | return request.put(`${BASE_URL}/${taskId}/marked_texts`, { 100 | filename, 101 | content 102 | }) 103 | }, 104 | 105 | /** 106 | * 停止任务 107 | * @param {number} taskId 任务ID 108 | * @returns {Promise} 任务对象 109 | */ 110 | async stopTask(taskId) { 111 | return request.post(`${BASE_URL}/${taskId}/stop`) 112 | }, 113 | 114 | /** 115 | * 批量删除任务图片 116 | * @param {number|string} taskId 任务ID 117 | * @param {Array} imageIds 图片ID数组 118 | * @returns {Promise} 119 | */ 120 | async batchDeleteImages(taskId, imageIds) { 121 | return request.delete(`${BASE_URL}/${taskId}/images/batch`, { 122 | data: { image_ids: imageIds } 123 | }) 124 | }, 125 | 126 | /** 127 | * 批量更新打标文本 128 | * @param {number|string} taskId 任务ID 129 | * @param {Object} markedTexts 文件名到文本内容的映射 130 | * @returns {Promise} 131 | */ 132 | async batchUpdateMarkedTexts(taskId, markedTexts) { 133 | return request.put(`${BASE_URL}/${taskId}/marked_texts/batch`, markedTexts) 134 | }, 135 | 136 | /** 137 | * 获取训练结果 138 | * @param {number|string} taskId 任务ID 139 | * @returns {Promise} 训练结果信息,包含模型列表等 140 | */ 141 | async getTrainingResults(taskId) { 142 | return request.get(`${BASE_URL}/${taskId}/training-results`) 143 | }, 144 | 145 | /** 146 | * 获取训练loss曲线数据 147 | * @param {number|string} taskId 任务ID 148 | * @returns {Promise} 训练loss数据,包含数据点和训练进度 149 | */ 150 | async getTrainingLoss(taskId) { 151 | return request.get(`${BASE_URL}/${taskId}/training-loss`) 152 | }, 153 | 154 | /** 155 | * 批量提交任务进行标记 156 | * @param {Array} taskIds 任务ID数组 157 | * @returns {Promise} 成功提交的任务ID列表 158 | */ 159 | async batchStartMarking(taskIds) { 160 | return request.post(`${BASE_URL}/batch/mark`, { task_ids: taskIds }) 161 | }, 162 | 163 | // 获取任务训练历史 164 | async getTaskTrainingHistory(taskId) { 165 | const response = await request.get(`${BASE_URL}/${taskId}/execution-history`) 166 | return response 167 | }, 168 | 169 | // 获取特定历史记录详情 170 | async getTrainingHistoryDetails(historyId) { 171 | const response = await request.get(`${BASE_URL}/execution-history/${historyId}`) 172 | return response 173 | }, 174 | 175 | /** 176 | * 获取任务配置 177 | * @param {number|string} taskId 任务ID 178 | * @returns {Promise} 任务配置,包含打标和训练配置 179 | */ 180 | async getTaskConfig(taskId) { 181 | return request.get(`${BASE_URL}/${taskId}/config`) 182 | }, 183 | 184 | /** 185 | * 更新任务配置 186 | * @param {number|string} taskId 任务ID 187 | * @param {Object} configData 配置数据 188 | * @returns {Promise} 更新后的配置信息 189 | */ 190 | async updateTaskConfig(taskId, configData) { 191 | return request.put(`${BASE_URL}/${taskId}/config`, configData) 192 | } 193 | } -------------------------------------------------------------------------------- /fronted-ui/src/api/terminal.js: -------------------------------------------------------------------------------- 1 | import request from '@/utils/request' 2 | 3 | const BASE_URL = '/terminal' 4 | 5 | export const terminalApi = { 6 | /** 7 | * 列出远程目录文件 8 | * @param {number|string} assetId - 资产ID 9 | * @param {string} path - 远程路径 10 | * @returns {Promise} 文件列表 11 | */ 12 | async listFiles(assetId, path = '/') { 13 | return request.get(`${BASE_URL}/files/list/${assetId}`, { 14 | params: { path } 15 | }) 16 | }, 17 | 18 | /** 19 | * 浏览远程目录(带排序和过滤功能) 20 | * @param {number|string} assetId - 资产ID 21 | * @param {string} path - 远程路径 22 | * @param {string} sortBy - 排序字段 23 | * @param {string} sortOrder - 排序顺序 24 | * @returns {Promise} 文件和目录信息 25 | */ 26 | async browseDirectory(assetId, path = '/', sortBy = 'name', sortOrder = 'asc') { 27 | return request.get(`${BASE_URL}/files/browse/${assetId}`, { 28 | params: { 29 | path, 30 | sort_by: sortBy, 31 | sort_order: sortOrder 32 | } 33 | }) 34 | }, 35 | 36 | /** 37 | * 上传文件到远程服务器 38 | * @param {number|string} assetId - 资产ID 39 | * @param {File} file - 文件对象 40 | * @param {string} remotePath - 远程路径 41 | * @returns {Promise} 上传结果 42 | */ 43 | async uploadFile(assetId, file, remotePath = '/') { 44 | const formData = new FormData() 45 | formData.append('file', file) 46 | formData.append('remote_path', remotePath) 47 | 48 | const response = await request.post(`${BASE_URL}/files/stream-upload/${assetId}`, formData, { 49 | headers: { 50 | 'Content-Type': 'multipart/form-data' 51 | } 52 | }) 53 | console.log("上传返回值response",response) 54 | return response 55 | }, 56 | 57 | /** 58 | * 从远程服务器下载文件 59 | * @param {number|string} assetId - 资产ID 60 | * @param {string} remotePath - 远程文件路径 61 | * @returns {Promise} 文件Blob对象 62 | */ 63 | async downloadFile(assetId, remotePath) { 64 | return fetch(`/api/v1/${BASE_URL}/files/stream-download/${assetId}`, { 65 | method: 'POST', 66 | headers: { 67 | 'Content-Type': 'application/json' 68 | }, 69 | body: JSON.stringify({ remote_path: remotePath }) 70 | }).then(response => { 71 | if (!response.ok) { 72 | throw new Error(`下载失败: ${response.statusText}`) 73 | } 74 | return response.blob() 75 | }) 76 | }, 77 | 78 | /** 79 | * 获取WebSocket连接URL 80 | * @param {number|string} assetId - 资产ID 81 | * @returns {string} WebSocket URL 82 | */ 83 | getWebSocketUrl(assetId) { 84 | const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:' 85 | return `${protocol}//${window.location.host}/api/v1/${BASE_URL}/${assetId}` 86 | }, 87 | 88 | /** 89 | * 删除远程服务器上的文件或目录 90 | * @param {number|string} assetId - 资产ID 91 | * @param {string} remotePath - 远程文件或目录路径 92 | * @returns {Promise} 删除结果 93 | */ 94 | async deleteRemoteFile(assetId, remotePath) { 95 | return request.post(`${BASE_URL}/files/delete/${assetId}`, { 96 | remote_path: remotePath 97 | }) 98 | } 99 | } -------------------------------------------------------------------------------- /fronted-ui/src/api/upload.js: -------------------------------------------------------------------------------- 1 | import request from '@/utils/request' 2 | 3 | const BASE_URL = '/upload' 4 | 5 | export const uploadApi = { 6 | /** 7 | * 上传文件 8 | * @param {File} file - 要上传的文件对象 9 | * @param {string} description - 文件描述 10 | * @returns {Promise} 上传结果 11 | */ 12 | async uploadFile(file, description = '') { 13 | const formData = new FormData() 14 | formData.append('file', file) 15 | formData.append('description', description) 16 | 17 | return request.post(`${BASE_URL}/files`, formData, { 18 | headers: { 19 | 'Content-Type': 'multipart/form-data' 20 | } 21 | }) 22 | }, 23 | 24 | /** 25 | * 获取所有文件列表 26 | * @returns {Promise} 文件列表 27 | */ 28 | async getFiles() { 29 | return request.get(`${BASE_URL}/files`) 30 | }, 31 | 32 | /** 33 | * 获取单个文件信息 34 | * @param {number} fileId - 文件ID 35 | * @returns {Promise} 文件信息 36 | */ 37 | async getFile(fileId) { 38 | return request.get(`${BASE_URL}/files/${fileId}`) 39 | }, 40 | 41 | /** 42 | * 删除文件 43 | * @param {number} fileId - 文件ID 44 | * @returns {Promise} 删除结果 45 | */ 46 | async deleteFile(fileId) { 47 | return request.delete(`${BASE_URL}/files/${fileId}`) 48 | }, 49 | 50 | /** 51 | * 获取文件下载链接 52 | * @param {number} fileId - 文件ID 53 | * @returns {string} 下载链接 54 | */ 55 | getDownloadUrl(fileId) { 56 | return `${request.defaults.baseURL}${BASE_URL}/files/${fileId}/download` 57 | } 58 | } -------------------------------------------------------------------------------- /fronted-ui/src/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartRick/RLT/8a5e2e8fddff01d36d9f6f98750d7376b971f1e4/fronted-ui/src/assets/logo.png -------------------------------------------------------------------------------- /fronted-ui/src/assets/styles/global.css: -------------------------------------------------------------------------------- 1 | :root { 2 | /* 颜色系统 */ 3 | --primary-color: #0A84FF; 4 | --success-color: #32D74B; 5 | --warning-color: #FF9F0A; 6 | --danger-color: #FF453A; 7 | --info-color: #64D2FF; 8 | 9 | /* 文本颜色 */ 10 | --text-primary: #000000; 11 | --text-secondary: #6B6B6B; 12 | --text-tertiary: #8E8E93; 13 | 14 | /* 背景颜色 */ 15 | --background-primary: #F5F5F7; 16 | --background-secondary: #FFFFFF; 17 | --background-tertiary: #F2F2F7; 18 | 19 | /* 边框颜色 */ 20 | --border-color: #D2D2D7; 21 | --border-color-light: #E5E5EA; 22 | 23 | /* 阴影 */ 24 | --shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.05); 25 | --shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); 26 | --shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); 27 | 28 | /* 圆角 */ 29 | --radius-sm: 6px; 30 | --radius-md: 8px; 31 | --radius-lg: 12px; 32 | 33 | /* 动画 */ 34 | --transition-speed: 0.2s; 35 | 36 | /* 间距 */ 37 | --spacing-1: 4px; 38 | --spacing-2: 8px; 39 | --spacing-3: 12px; 40 | --spacing-4: 16px; 41 | --spacing-5: 20px; 42 | --spacing-6: 24px; 43 | } 44 | 45 | /* 基础样式重置 */ 46 | * { 47 | margin: 0; 48 | padding: 0; 49 | box-sizing: border-box; 50 | -webkit-font-smoothing: antialiased; 51 | -moz-osx-font-smoothing: grayscale; 52 | } 53 | 54 | body { 55 | font-family: -apple-system, BlinkMacSystemFont, 'SF Pro Text', 'Helvetica Neue', sans-serif; 56 | background-color: var(--background-primary); 57 | color: var(--text-primary); 58 | line-height: 1.5; 59 | } 60 | 61 | /* 通用卡片样式 */ 62 | .mac-card { 63 | background: var(--background-secondary); 64 | border-radius: var(--radius-md); 65 | box-shadow: var(--shadow-sm); 66 | border: 1px solid var(--border-color-light); 67 | backdrop-filter: blur(20px); 68 | transition: transform var(--transition-speed), 69 | box-shadow var(--transition-speed); 70 | padding: 20px; 71 | transition: all 0.3s ease; 72 | } 73 | 74 | /* .mac-card:hover { */ 75 | /* transform: translateY(-2px); */ 76 | /* box-shadow: var(--shadow-md); */ 77 | /* } */ 78 | 79 | /* 按钮样式 */ 80 | .mac-btn { 81 | padding: var(--spacing-2) var(--spacing-4); 82 | border-radius: var(--radius-sm); 83 | font-size: 14px; 84 | font-weight: 500; 85 | border: 1px solid var(--border-color); 86 | background: var(--background-secondary); 87 | color: var(--text-primary); 88 | cursor: pointer; 89 | display: inline-flex; 90 | align-items: center; 91 | justify-content: center; 92 | gap: var(--spacing-2); 93 | transition: all var(--transition-speed); 94 | } 95 | 96 | .mac-btn:hover { 97 | background: var(--background-tertiary); 98 | border-color: var(--border-color); 99 | } 100 | 101 | .mac-btn.primary { 102 | background: var(--primary-color); 103 | border-color: var(--primary-color); 104 | color: white; 105 | } 106 | 107 | .mac-btn.primary:hover { 108 | background: #0969DA; 109 | border-color: #0969DA; 110 | } 111 | 112 | /* 输入框样式 */ 113 | .mac-input { 114 | padding: var(--spacing-2) var(--spacing-3); 115 | border-radius: var(--radius-sm); 116 | border: 1px solid var(--border-color); 117 | font-size: 14px; 118 | background: var(--background-secondary); 119 | transition: all var(--transition-speed); 120 | width: 100%; 121 | } 122 | 123 | /* textarea样式 */ 124 | textarea, .mac-textarea { 125 | font-family: "微软雅黑", "Heiti SC", "黑体", -apple-system, BlinkMacSystemFont, 'SF Pro Text', 'Helvetica Neue', sans-serif; 126 | padding: var(--spacing-3); 127 | line-height: 1.6; 128 | } 129 | 130 | .mac-input:focus { 131 | outline: none; 132 | border-color: var(--primary-color); 133 | box-shadow: 0 0 0 3px color-mix(in srgb, var(--primary-color) 20%, transparent); 134 | } 135 | 136 | /* 下拉框样式 */ 137 | .mac-select { 138 | padding: var(--spacing-2) var(--spacing-3); 139 | border-radius: var(--radius-sm); 140 | border: 1px solid var(--border-color); 141 | font-size: 14px; 142 | background: var(--background-secondary); 143 | cursor: pointer; 144 | transition: all var(--transition-speed); 145 | } 146 | 147 | .mac-select:focus { 148 | outline: none; 149 | border-color: var(--primary-color); 150 | box-shadow: 0 0 0 3px color-mix(in srgb, var(--primary-color) 20%, transparent); 151 | } 152 | 153 | /* 表单样式 */ 154 | .form-section { 155 | background: var(--background-tertiary); 156 | border-radius: var(--radius-md); 157 | padding: var(--spacing-4); 158 | margin-bottom: var(--spacing-4); 159 | } 160 | 161 | .form-item { 162 | margin-bottom: var(--spacing-4); 163 | } 164 | 165 | .form-item label { 166 | display: block; 167 | margin-bottom: var(--spacing-2); 168 | color: var(--text-secondary); 169 | font-size: 14px; 170 | } 171 | 172 | .form-row { 173 | display: grid; 174 | grid-template-columns: 1fr 1fr; 175 | gap: var(--spacing-4); 176 | } 177 | 178 | /* 动画 */ 179 | @keyframes slideUp { 180 | from { 181 | transform: translateY(20px); 182 | opacity: 0; 183 | } 184 | to { 185 | transform: translateY(0); 186 | opacity: 1; 187 | } 188 | } 189 | 190 | .modal-enter-active { 191 | animation: slideUp var(--transition-speed) ease-out; 192 | } 193 | 194 | .modal-leave-active { 195 | animation: slideUp var(--transition-speed) ease-in reverse; 196 | } -------------------------------------------------------------------------------- /fronted-ui/src/components/common/Checkbox.vue: -------------------------------------------------------------------------------- 1 | 28 | 29 | 54 | 55 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/ContextMenu.vue: -------------------------------------------------------------------------------- 1 | 35 | 36 | 151 | 152 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/KeyValueConfig.vue: -------------------------------------------------------------------------------- 1 | 43 | 44 | 116 | 117 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/LoadingSpinner.vue: -------------------------------------------------------------------------------- 1 | 14 | 15 | 29 | 30 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/LoraTrainingParams.vue: -------------------------------------------------------------------------------- 1 | 21 | 22 | 81 | 82 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/Message.vue: -------------------------------------------------------------------------------- 1 | 14 | 15 | 113 | 114 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/PageTabs.vue: -------------------------------------------------------------------------------- 1 | 14 | 15 | 40 | 41 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/SwitchButton.vue: -------------------------------------------------------------------------------- 1 | 11 | 12 | 29 | 30 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/TaskConfigCard.vue: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fronted-ui/src/components/common/TooltipText.vue: -------------------------------------------------------------------------------- 1 | 17 | 18 | 154 | 155 | -------------------------------------------------------------------------------- /fronted-ui/src/components/layout/AppLayout.vue: -------------------------------------------------------------------------------- 1 | 39 | 40 | 99 | 100 | -------------------------------------------------------------------------------- /fronted-ui/src/components/tasks/TaskForm.vue: -------------------------------------------------------------------------------- 1 |