├── .gitignore ├── LICENSE ├── README.md ├── app ├── api │ ├── __init__.py │ ├── digital_human_avatars.py │ ├── digital_human_voices.py │ ├── file_deal.py │ ├── font.py │ ├── short_videos.py │ ├── tasks.py │ └── video │ │ ├── crt_video.py │ │ └── h5_crt_video.py ├── database.py ├── main.py ├── models │ ├── __init__.py │ ├── digital_human_avatar.py │ ├── digital_human_voice.py │ ├── font.py │ ├── publishing_plan_detail.py │ ├── short_video.py │ ├── short_video_detail.py │ └── task.py ├── schemas │ ├── __init__.py │ ├── digital_human_avatar.py │ ├── digital_human_voice.py │ ├── font.py │ ├── response.py │ ├── short_video.py │ ├── short_video_detail.py │ └── task.py ├── services │ ├── ffmpeg_service.py │ ├── fishspeech_service.py │ ├── task_service.py │ ├── transcription_service.py │ ├── ultralight_service.py │ ├── upload_service.py │ └── wav2lip_service.py └── utils │ ├── gpu_utils.py │ ├── logger_utils.py │ ├── media_utils.py │ ├── response_utils.py │ └── user_utils.py ├── backend_schema.sql ├── doc ├── pro_01.jpg ├── pro_02.jpg ├── pro_03.jpg ├── pro_04.jpg ├── pro_05.jpg ├── pro_06.jpg ├── pro_07.jpg ├── pro_08.jpg ├── pro_09.jpg ├── pro_10.jpg ├── wx_01.jpg └── wx_02.jpg ├── external_modules ├── fish-speech │ ├── .dockerignore │ ├── .gitignore │ ├── .pre-commit-config.yaml │ ├── .project-root │ ├── .readthedocs.yaml │ ├── API_FLAGS.txt │ ├── README.zh.md │ ├── docker-compose.dev.yml │ ├── dockerfile │ ├── dockerfile.dev │ ├── docs │ │ ├── CNAME │ │ ├── assets │ │ │ └── figs │ │ │ │ ├── VS_1.jpg │ │ │ │ ├── VS_1_pt-BR.png │ │ │ │ ├── diagram.png │ │ │ │ └── diagrama.png │ │ ├── en │ │ │ ├── finetune.md │ │ │ ├── index.md │ │ │ ├── inference.md │ │ │ └── samples.md │ │ ├── ja │ │ │ ├── finetune.md │ │ │ ├── index.md │ │ │ ├── inference.md │ │ │ └── samples.md │ │ ├── pt │ │ │ ├── finetune.md │ │ │ ├── index.md │ │ │ ├── inference.md │ │ │ └── samples.md │ │ ├── requirements.txt │ │ ├── stylesheets │ │ │ └── extra.css │ │ └── zh │ │ │ ├── finetune.md │ │ │ ├── index.md │ │ │ ├── inference.md │ │ │ └── samples.md │ ├── entrypoint.sh │ ├── fish_speech │ │ ├── callbacks │ │ │ ├── __init__.py │ │ │ └── grad_norm.py │ │ ├── configs │ │ │ ├── base.yaml │ │ │ ├── firefly_gan_vq.yaml │ │ │ ├── lora │ │ │ │ └── r_8_alpha_16.yaml │ │ │ └── text2semantic_finetune.yaml │ │ ├── conversation.py │ │ ├── datasets │ │ │ ├── concat_repeat.py │ │ │ ├── protos │ │ │ │ ├── text-data.proto │ │ │ │ ├── text_data_pb2.py │ │ │ │ └── text_data_stream.py │ │ │ ├── semantic.py │ │ │ └── vqgan.py │ │ ├── i18n │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── locale │ │ │ │ ├── en_US.json │ │ │ │ ├── es_ES.json │ │ │ │ ├── ja_JP.json │ │ │ │ ├── pt_BR.json │ │ │ │ └── zh_CN.json │ │ │ └── scan.py │ │ ├── models │ │ │ ├── text2semantic │ │ │ │ ├── __init__.py │ │ │ │ ├── lit_module.py │ │ │ │ ├── llama.py │ │ │ │ └── lora.py │ │ │ └── vqgan │ │ │ │ ├── __init__.py │ │ │ │ ├── modules │ │ │ │ ├── firefly.py │ │ │ │ └── fsq.py │ │ │ │ └── utils.py │ │ ├── scheduler.py │ │ ├── text │ │ │ ├── __init__.py │ │ │ ├── chn_text_norm │ │ │ │ ├── .gitignore │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── basic_class.py │ │ │ │ ├── basic_constant.py │ │ │ │ ├── basic_util.py │ │ │ │ ├── cardinal.py │ │ │ │ ├── date.py │ │ │ │ ├── digit.py │ │ │ │ ├── fraction.py │ │ │ │ ├── money.py │ │ │ │ ├── percentage.py │ │ │ │ ├── telephone.py │ │ │ │ └── text.py │ │ │ ├── clean.py │ │ │ └── spliter.py │ │ ├── train.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── braceexpand.py │ │ │ ├── context.py │ │ │ ├── 
file.py │ │ │ ├── instantiators.py │ │ │ ├── logger.py │ │ │ ├── logging_utils.py │ │ │ ├── rich_utils.py │ │ │ ├── spectrogram.py │ │ │ └── utils.py │ │ └── webui │ │ │ ├── css │ │ │ └── style.css │ │ │ ├── html │ │ │ └── footer.html │ │ │ ├── js │ │ │ └── animate.js │ │ │ ├── launch_utils.py │ │ │ └── manage.py │ ├── inference.ipynb │ ├── install_env.bat │ ├── mkdocs.yml │ ├── pyproject.toml │ ├── pyrightconfig.json │ ├── run_cmd.bat │ ├── start.bat │ └── tools │ │ ├── api.py │ │ ├── commons.py │ │ ├── download_models.py │ │ ├── extract_model.py │ │ ├── file.py │ │ ├── llama │ │ ├── build_dataset.py │ │ ├── eval_in_context.py │ │ ├── generate.py │ │ ├── merge_lora.py │ │ ├── quantize.py │ │ └── rebuild_tokenizer.py │ │ ├── msgpack_api.py │ │ ├── post_api.py │ │ ├── sensevoice │ │ ├── README.md │ │ ├── __init__.py │ │ ├── auto_model.py │ │ ├── fun_asr.py │ │ └── vad_utils.py │ │ ├── smart_pad.py │ │ ├── vqgan │ │ ├── create_train_split.py │ │ ├── extract_vq.py │ │ └── inference.py │ │ ├── webui.py │ │ └── whisper_asr.py ├── ultralight │ ├── .gitignore │ ├── README.md │ ├── data_utils │ │ ├── FeaturePipeline.py │ │ ├── base_module.py │ │ ├── checkpoint_epoch_335.pth.tar │ │ ├── conf │ │ │ ├── READ.md │ │ │ ├── decode_engine.yaml │ │ │ ├── decode_engine_V4.yaml │ │ │ └── wenetspeech_unified_conformer │ │ │ │ ├── global_cmvn │ │ │ │ ├── train.yaml │ │ │ │ └── words.txt │ │ ├── decode_engine_V4.yaml │ │ ├── detect_face.py │ │ ├── get_landmark.py │ │ ├── hubert.py │ │ ├── mean_face.txt │ │ ├── pfld_mobileone.py │ │ ├── process.py │ │ ├── scrfd_2.5g_kps.onnx │ │ ├── wenet │ │ │ ├── bin │ │ │ │ ├── alignment.py │ │ │ │ ├── average_model.py │ │ │ │ ├── export_jit.py │ │ │ │ ├── recognize.py │ │ │ │ ├── recognize_deprecated.py │ │ │ │ ├── recognize_wav.py │ │ │ │ ├── recognize_wav_streaming.py │ │ │ │ ├── train.py │ │ │ │ └── train_deprecated.py │ │ │ ├── transformer │ │ │ │ ├── asr_model.py │ │ │ │ ├── asr_model_streaming.py │ │ │ │ ├── attention.py │ │ │ │ ├── cmvn.py │ │ │ │ ├── convolution.py │ │ │ │ ├── ctc.py │ │ │ │ ├── decoder.py │ │ │ │ ├── decoder_layer.py │ │ │ │ ├── decoder_streaming.py │ │ │ │ ├── embedding.py │ │ │ │ ├── encoder.py │ │ │ │ ├── encoder_layer.py │ │ │ │ ├── encoder_streaming.py │ │ │ │ ├── label_smoothing_loss.py │ │ │ │ ├── positionwise_feed_forward.py │ │ │ │ ├── subsampling.py │ │ │ │ └── swish.py │ │ │ └── utils │ │ │ │ ├── checkpoint.py │ │ │ │ ├── cmvn.py │ │ │ │ ├── common.py │ │ │ │ ├── config.py │ │ │ │ ├── ctc_util.py │ │ │ │ ├── executor.py │ │ │ │ ├── file_utils.py │ │ │ │ ├── mask.py │ │ │ │ └── scheduler.py │ │ └── wenet_infer.py │ ├── datasetsss.py │ ├── inference.py │ ├── pth2onnx.py │ ├── requirements.txt │ ├── syncnet.py │ ├── train.py │ └── unet.py └── wav2lip-onnx-256 │ ├── .gitignore │ ├── README.md │ ├── audio.py │ ├── checkpoints │ └── checkpoints.txt │ ├── convert2onnx_256 │ ├── conv.py │ ├── export.py │ └── wav2lip_256.py │ ├── hparams.py │ ├── inference_onnxModel.py │ ├── insightface_func │ ├── __init__.py │ ├── face_detect_crop_single.py │ └── utils │ │ └── face_align_ffhqandnewarc.py │ ├── requirements.txt │ ├── setup.txt │ └── temp │ └── temp.txt ├── requirements.txt └── resources └── audios ├── default_audio.wav └── default_audio_hu.npy /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Virtual Environment 7 | venv/ 8 | env/ 9 | .env 10 | 11 | # IDE 12 | .vscode/ 13 | .idea/ 14 | *.swp 15 | *.swo 16 | 17 | # Logs 18 | *.log 19 | 
20 | # Database 21 | *.sqlite3 22 | 23 | # OS generated files 24 | .DS_Store 25 | .DS_Store? 26 | ._* 27 | .Spotlight-V100 28 | .Trashes 29 | ehthumbs.db 30 | Thumbs.db 31 | 32 | # FastAPI specific 33 | .pytest_cache/ 34 | 35 | # Dependency directories 36 | node_modules/ 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | pip-wheel-metadata/ 53 | share/python-wheels/ 54 | *.egg-info/ 55 | .installed.cfg 56 | *.egg 57 | 58 | # PyInstaller 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .nox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | *.py,cover 77 | .hypothesis/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # Environments 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .dmypy.json 106 | dmypy.json 107 | 108 | # Pyre type checker 109 | .pyre/ 110 | 111 | # 大型模型文件 112 | *.model 113 | *.pkl 114 | *.h5 115 | *.pt 116 | *.pth 117 | 118 | # 排除.github目录 119 | .github/ 120 | /data 121 | -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/app/api/__init__.py -------------------------------------------------------------------------------- /app/api/digital_human_voices.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException, Query 2 | from sqlalchemy.orm import Session 3 | import logging 4 | from ..database import get_db 5 | from ..models.digital_human_voice import DigitalHumanVoice 6 | from ..schemas.digital_human_voice import DigitalHumanVoice as DigitalHumanVoiceSchema, DigitalHumanVoiceCreate, DigitalHumanVoiceUpdate 7 | from ..utils.response_utils import success_response, error_response 8 | from ..utils import media_utils 9 | from ..schemas.response import ApiResponse, PaginatedResponse 10 | from ..utils.user_utils import get_user_id 11 | 12 | router = APIRouter() 13 | 14 | # 设置日志记录 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | @router.get("/", response_model=ApiResponse[PaginatedResponse[DigitalHumanVoiceSchema]]) 19 | def list_digital_human_voices( 20 | page: int = Query(1, description="当前页码"), 21 | page_size: int = Query(10, description="每页记录数"), 22 | status: int = Query(None, description="状态 0-AI克隆训练中,1-可用,2-失败"), 23 | name: str = Query(None, description="声音名称"), 24 | db: Session = Depends(get_db) 25 | ): 26 | """ 27 | 获取数字人声音列表 28 | 29 | 这个端点返回一个数字人声音列表,支持分页 30 | - skip: 跳过的记录数,用于分页 31 | - limit: 返回的最大记录数,用于分页 32 | - name: 按名称搜索 33 | """ 34 | query = db.query(DigitalHumanVoice).filter(DigitalHumanVoice.is_deleted == False, DigitalHumanVoice.user_id == get_user_id()) 35 | 36 | if status is not None: 37 | query = query.filter(DigitalHumanVoice.status == status) 38 | if name is not None: 39 | 
query = query.filter(DigitalHumanVoice.name.like(f"%{name}%")) 40 | 41 | 42 | total = query.count() 43 | skip = (page - 1) * page_size 44 | voices = query.order_by(DigitalHumanVoice.created_at.desc()).offset(skip).limit(page_size).all() 45 | 46 | # 转换路径为URL 47 | for voice in voices: 48 | voice.sample_audio_url = media_utils.convert_path_to_url(voice.sample_audio_url) 49 | return success_response(data=PaginatedResponse(items=voices, total=total)) 50 | -------------------------------------------------------------------------------- /app/api/file_deal.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Query 2 | from fastapi.responses import FileResponse 3 | import os 4 | 5 | from app.utils import media_utils 6 | 7 | router = APIRouter() 8 | @router.get("/download") 9 | def download_file(url: str = Query(..., description="文件的HTTP地址")): 10 | # 使用 convert_url_to_path 方法将 URL 转换为系统内部路径 11 | file_path = media_utils.convert_url_to_path(url) 12 | print(f"文件路径: {file_path}") 13 | # 检查文件是否存在 14 | if not os.path.exists(file_path): 15 | raise HTTPException(status_code=404, detail="File not found") 16 | 17 | # 返回文件响应 18 | return FileResponse(file_path, media_type='application/octet-stream', filename=os.path.basename(file_path)) 19 | 20 | 21 | @router.get("/get_preSign_url") 22 | def get_presigned_url(filename: str = Query(..., description="文件名称")): 23 | from app.services.upload_service import generate_presigned_url 24 | from app.utils.response_utils import success_response 25 | 26 | # 调用生成预签名URL的方法 27 | result = generate_presigned_url(filename) 28 | 29 | if not result: 30 | raise HTTPException(status_code=500, detail="生成预签名URL失败") 31 | 32 | return success_response( 33 | data=result, 34 | message="获取预签名URL成功" 35 | ) 36 | -------------------------------------------------------------------------------- /app/api/font.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException 2 | from sqlalchemy.orm import Session 3 | from ..database import get_db 4 | from ..models.font import Font 5 | from ..schemas.font import Font as FontSchema, FontCreate, FontUpdate 6 | from ..utils.response_utils import success_response, error_response 7 | from ..schemas.response import ApiResponse, PaginatedResponse 8 | 9 | router = APIRouter() 10 | 11 | @router.get("/", response_model=ApiResponse[PaginatedResponse[FontSchema]]) 12 | def list_fonts(page: int = 1, page_size: int = 10, db: Session = Depends(get_db)): 13 | """获取字体列表,支持分页""" 14 | query = db.query(Font) 15 | total = query.count() 16 | skip = (page - 1) * page_size 17 | fonts = query.offset(skip).limit(page_size).all() 18 | return success_response(data=PaginatedResponse(items=fonts, total=total)) -------------------------------------------------------------------------------- /app/api/short_videos.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from fastapi import APIRouter, Depends, Query 3 | from sqlalchemy.orm import Session 4 | from typing import List 5 | from datetime import datetime, date 6 | 7 | from ..database import get_db 8 | from ..models.short_video import ShortVideo 9 | from ..schemas.short_video import ShortVideo as ShortVideoSchema, ShortVideoCreate, ShortVideoUpdate 10 | from ..services.ffmpeg_service import FFmpegService 11 | from ..utils import media_utils 12 | from ..utils.response_utils import success_response, error_response 13 
| from ..schemas.response import ApiResponse, PaginatedResponse 14 | from ..utils.user_utils import get_user_id 15 | 16 | router = APIRouter() 17 | 18 | @router.get("/", response_model=ApiResponse[PaginatedResponse[ShortVideoSchema]]) 19 | def list_short_videos( 20 | page: int = Query(1, description="当前页码"), 21 | page_size: int = Query(10, description="每页记录数"), 22 | name: str = Query(None, description="视频名称"), 23 | start_time: datetime = Query(None, description="生成时间开始"), 24 | end_time: datetime = Query(None, description="生成时间结束"), 25 | video_type: str = Query(None, description="视频类型"), 26 | status: int = Query(None, description="状态 0-生成中,1-已生成,2-生成失败需要重试"), 27 | db: Session = Depends(get_db) 28 | ): 29 | """获取短视频列表""" 30 | query = db.query(ShortVideo).filter(ShortVideo.is_deleted == False, ShortVideo.user_id == get_user_id()) 31 | 32 | if name: 33 | query = query.filter(ShortVideo.title.ilike(f"%{name}%")) 34 | if start_time: 35 | query = query.filter(ShortVideo.created_at >= start_time) 36 | if end_time: 37 | query = query.filter(ShortVideo.created_at <= end_time) 38 | if video_type: 39 | query = query.filter(ShortVideo.type == video_type) 40 | if status is not None: 41 | query = query.filter(ShortVideo.status == status) 42 | 43 | total = query.count() 44 | skip = (page - 1) * page_size 45 | short_videos = query.order_by(ShortVideo.created_at.desc()).offset(skip).limit(page_size).all() 46 | 47 | # 转换路径为URL 48 | for short_video in short_videos: 49 | short_video.video_cover = media_utils.convert_path_to_url(short_video.video_cover) 50 | short_video.video_url = media_utils.convert_path_to_url(short_video.video_url) 51 | 52 | return success_response(data=PaginatedResponse(items=short_videos, total=total)) 53 | 54 | 55 | @router.delete("/{short_video_id}", response_model=ApiResponse[ShortVideoSchema]) 56 | def delete_short_video(short_video_id: int, db: Session = Depends(get_db)): 57 | """ 58 | 软删除短视频 59 | 60 | 这个端点用于软删除指定 ID 的短视频(将 is_deleted 设置为 True) 61 | """ 62 | db_short_video = db.query(ShortVideo).filter( 63 | ShortVideo.id == short_video_id, 64 | ShortVideo.is_deleted == False, 65 | ShortVideo.user_id == get_user_id() 66 | ).first() 67 | if not db_short_video: 68 | return error_response(code=404, message="短视频不存在") 69 | 70 | db_short_video.is_deleted = True 71 | db.commit() 72 | return success_response(data=db_short_video, message="成功删除短视频") 73 | -------------------------------------------------------------------------------- /app/api/tasks.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, Query 2 | from sqlalchemy.orm import Session 3 | from typing import Optional 4 | from datetime import datetime 5 | from ..database import get_db 6 | from ..models.task import Task 7 | from ..schemas.task import Task as TaskSchema, TaskCreate, TaskUpdate 8 | from ..utils.response_utils import success_response, error_response 9 | from ..schemas.response import ApiResponse, PaginatedResponse 10 | 11 | router = APIRouter() 12 | 13 | @router.get("/", response_model=ApiResponse[PaginatedResponse[TaskSchema]]) 14 | def list_tasks( 15 | page: int = Query(1, description="当前页码"), 16 | page_size: int = Query(10, description="每页记录数"), 17 | name: Optional[str] = Query(None, description="任务名称"), 18 | status: Optional[int] = Query(None, description="任务状态"), 19 | start_time: Optional[datetime] = Query(None, description="计划执行时间开始"), 20 | end_time: Optional[datetime] = Query(None, description="计划执行时间结束"), 21 | task_type: Optional[str] = 
Query(None, description="任务类型"), 22 | db: Session = Depends(get_db) 23 | ): 24 | """获取任务列表""" 25 | query = db.query(Task).filter(Task.is_deleted == False) 26 | 27 | if name: 28 | query = query.filter(Task.name.ilike(f"%{name}%")) 29 | if status is not None: 30 | query = query.filter(Task.status == status) 31 | if start_time: 32 | query = query.filter(Task.scheduled_time >= start_time) 33 | if end_time: 34 | query = query.filter(Task.scheduled_time <= end_time) 35 | if task_type: 36 | query = query.filter(Task.type == task_type) 37 | 38 | total = query.count() 39 | skip = (page - 1) * page_size 40 | tasks = query.offset(skip).limit(page_size).all() 41 | 42 | return success_response(data=PaginatedResponse(items=tasks, total=total)) 43 | 44 | 45 | @router.get("/{task_id}", response_model=ApiResponse[TaskSchema]) 46 | def read_task(task_id: int, db: Session = Depends(get_db)): 47 | """获取特定任务的详细信息""" 48 | task = db.query(Task).filter(Task.id == task_id, Task.is_deleted == False).first() 49 | if not task: 50 | return error_response(code=404, message="任务不存在") 51 | return success_response(data=task) 52 | 53 | 54 | @router.delete("/{task_id}", response_model=ApiResponse[TaskSchema]) 55 | def delete_task(task_id: int, db: Session = Depends(get_db)): 56 | """软删除任务""" 57 | db_task = db.query(Task).filter(Task.id == task_id, Task.is_deleted == False).first() 58 | if not db_task: 59 | return error_response(code=404, message="任务不存在") 60 | 61 | db_task.is_deleted = True 62 | db.commit() 63 | return success_response(data=db_task, message="成功删除任务") 64 | 65 | -------------------------------------------------------------------------------- /app/database.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.ext.declarative import declarative_base 3 | from sqlalchemy.orm import sessionmaker 4 | from sqlalchemy.pool import QueuePool 5 | import os 6 | from pathlib import Path 7 | from dotenv import load_dotenv 8 | 9 | # 加载 .env 文件中的环境变量 10 | load_dotenv() 11 | 12 | # 获取项目根目录 13 | BASE_DIR = Path(os.getenv("PROJECT_ROOT")) 14 | 15 | # 确保 data 目录存在 16 | (BASE_DIR / 'data').mkdir(exist_ok=True) 17 | 18 | # 创建 SQLite 数据库文件的路径 19 | DATABASE_URL = f"sqlite:///{BASE_DIR / 'data' / 'app.db'}" 20 | 21 | # 创建 SQLAlchemy 引擎 22 | engine = create_engine( 23 | DATABASE_URL, 24 | connect_args={"check_same_thread": False, "timeout": 10}, 25 | poolclass=QueuePool, 26 | pool_size=50, 27 | max_overflow=100 28 | ) 29 | 30 | # 创建会话工厂 31 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 32 | 33 | # 创建基本模型类,所有的 ORM 模型都将继承这个类 34 | Base = declarative_base() 35 | 36 | def get_db(): 37 | """ 38 | 创建一个数据库会话的生成器函数 39 | 40 | 每次调用时创建一个新的数据库会话,并在使用完毕后关闭它 41 | 这个函数将被用作 FastAPI 的依赖项 42 | """ 43 | db = SessionLocal() 44 | try: 45 | yield db 46 | finally: 47 | db.close() 48 | -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/app/models/__init__.py -------------------------------------------------------------------------------- /app/models/digital_human_avatar.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, Boolean 2 | from sqlalchemy.sql import func 3 | from ..database import Base 4 
| 5 | class DigitalHumanAvatar(Base): 6 | """数字人形象模型 7 | 8 | 定义数字人形象在数据库中的结构 9 | 10 | 属性: 11 | id (int): 主键,自动递增 12 | name (str): 形象名称 13 | type (int): 形象类型,0表示公共,1表示个人 14 | created_at (datetime): 创建时间,默认为当前时间 15 | finished_at (datetime): 完成时间 16 | status (int): 状态,0表示AI克隆训练中,1表示克隆完成,2表示克隆失败 17 | status_msg (str): 训练中的进度和预计耗时 18 | is_deleted (bool): 删除状态,True表示已删除,False表示未删除 19 | description (str): 形象描述 20 | audio_path (str): 音频文件路径 21 | audio_prompt_npy_path (str): 音频提示npy文件路径 22 | welcome_audio_path (str): 欢迎语音频路径 23 | welcome_video_path (str): 欢迎视频路径 24 | human_id (str): 数字人形象唯一标识符 25 | user_id (str): 用户ID 26 | no_green_video_path (str): 去除绿幕后的视频路径 27 | no_green_cover_image_path (str): 去除绿幕后的封面图片路径 28 | """ 29 | 30 | __tablename__ = "digital_human_avatars" 31 | 32 | id = Column(Integer, primary_key=True, index=True) 33 | name = Column(String(100), index=True, nullable=False) 34 | 35 | type = Column(Integer, default=1, nullable=False) # 0表示公共,1表示个人 36 | created_at = Column(DateTime(timezone=True), nullable=False) 37 | finished_at = Column(DateTime(timezone=True), nullable=True) 38 | status = Column(Integer, default=0, nullable=False) # 0:AI克隆训练中,1:克隆完成,2:克隆失败 39 | status_msg = Column(String(20), default="", nullable=False) #训练中的进度和预计耗时 40 | is_deleted = Column(Boolean, default=False, nullable=False) 41 | description = Column(String(500)) 42 | 43 | video_path = Column(String(255), nullable=False) 44 | audio_path = Column(String(255), nullable=True) 45 | audio_prompt_npy_path = Column(String(255), nullable=True) 46 | 47 | welcome_audio_path = Column(String(255), nullable=True) 48 | welcome_video_path = Column(String(255), nullable=True) 49 | human_id = Column(String(255), nullable=True) 50 | user_id = Column(String(100), nullable=True) 51 | 52 | no_green_video_path = Column(String(255), nullable=True) 53 | no_green_cover_image_path = Column(String(255), nullable=True) 54 | no_green_cover_image_width = Column(Integer, nullable=True) 55 | no_green_cover_image_height = Column(Integer, nullable=True) 56 | 57 | TYPE_MAPPING = {0: "公共", 1: "个人"} 58 | STATUS_MAPPING = {0: "AI克隆训练中", 1: "克隆完成", 2: "克隆失败"} 59 | 60 | @property 61 | def type_name(self): 62 | """返回形象类型的文字描述""" 63 | return self.TYPE_MAPPING.get(self.type, "未知") 64 | 65 | @property 66 | def status_name(self): 67 | """返回状态的文字描述""" 68 | return self.STATUS_MAPPING.get(self.status, "未知") 69 | 70 | def __repr__(self): 71 | """返回数字人形象对象的字符串表示""" 72 | return f"" 73 | -------------------------------------------------------------------------------- /app/models/digital_human_voice.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, Boolean 2 | from sqlalchemy.sql import func 3 | from ..database import Base 4 | 5 | class DigitalHumanVoice(Base): 6 | """数字人声音模型 7 | 8 | 这个模型定义了数字人声音在数据库中的结构 9 | 10 | 属性: 11 | id (int): 主键,自动递增 12 | user_id (str): 用户ID 13 | name (str): 声音名称 14 | file_path (str) : 音频路径 15 | type (int): 类型, 0 代表公共, 1 代表个人 16 | status (int): 状态, 0-AI克隆训练中, 1-克隆完成, 2-克隆失败 17 | is_deleted (bool): 删除状态, False 表示未删除, True 表示已删除 18 | created_at (datetime): 创建时间 19 | finished_at (datetime): 完成时间 20 | voice_id (str): 声音唯一标识符 21 | sample_audio_url (str): 示例音频地址 22 | """ 23 | 24 | __tablename__ = "digital_human_voices" 25 | 26 | id = Column(Integer, primary_key=True, index=True) 27 | user_id = Column(String(100), nullable=False) 28 | name = Column(String(100), nullable=False) 29 | file_path = Column(String(255), nullable=False) 30 | npy_path = 
Column(String(255), nullable=True) 31 | npy_prompt_text = Column(String(255), nullable=True) 32 | type = Column(Integer, default=1, nullable=False) 33 | status = Column(Integer, default=0, nullable=False) 34 | status_msg = Column(String(20), default="", nullable=False) #训练中的进度和预计耗时 35 | is_deleted = Column(Boolean, default=False, nullable=False) 36 | created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) 37 | finished_at = Column(DateTime(timezone=True), nullable=True) 38 | voice_id = Column(String(255), nullable=True) 39 | sample_audio_url = Column(String(255), nullable=True) 40 | 41 | TYPE_MAPPING = {0: "公共", 1: "个人"} 42 | STATUS_MAPPING = {0: "AI克隆训练中", 1: "克隆完成", 2: "克隆失败"} 43 | 44 | @property 45 | def type_name(self): 46 | """返回类型的文字描述""" 47 | return self.TYPE_MAPPING.get(self.type, "未知") 48 | 49 | @property 50 | def status_name(self): 51 | """返回状态的文字描述""" 52 | return self.STATUS_MAPPING.get(self.status, "未知") 53 | 54 | def __repr__(self): 55 | """返回数字人声音对象的字符串表示""" 56 | return f"" 57 | -------------------------------------------------------------------------------- /app/models/font.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String 2 | from ..database import Base 3 | 4 | class Font(Base): 5 | """字体模型""" 6 | __tablename__ = "fonts" 7 | 8 | id = Column(Integer, primary_key=True, autoincrement=True) 9 | name = Column(String, nullable=False, unique=True) 10 | nickname = Column(String, nullable=True) 11 | font_path = Column(String, nullable=False) -------------------------------------------------------------------------------- /app/models/publishing_plan_detail.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, Boolean 2 | from ..database import Base 3 | 4 | class PublishingPlanDetail(Base): 5 | """发布计划详情模型""" 6 | __tablename__ = "publishing_plan_details" 7 | 8 | # 主键ID 9 | id = Column(Integer, primary_key=True, index=True) 10 | # 关联的发布计划ID 11 | publishing_plan_id = Column(Integer, index=True) 12 | # 视频ID 13 | video_id = Column(String, index=True) 14 | # 视频标题 15 | video_title = Column(String) 16 | # 封面URL 17 | cover_url = Column(String) 18 | # 视频URL 19 | video_url = Column(String) 20 | # 计划发布时间 21 | publish_time = Column(DateTime) 22 | # 是否立即发布 23 | is_publish_immediately = Column(Boolean, default=False, nullable=False) 24 | # 发布状态(1: 发布中, 2: 发布失败, 3: 发布成功, 4: 发布取消) 25 | publish_status = Column(Integer, nullable=False) 26 | # 任务ID(可为空) 27 | task_id = Column(String, nullable=True) 28 | # 渠道名称(可能的值:douyin, wechat, xiaohongshu, kuaishou,以逗号分隔的字符串) 29 | channel_names = Column(String, nullable=False) 30 | # 账号列表(以逗号分隔的字符串,非必填) 31 | account_list = Column(String, nullable=True) 32 | # 分组(非必填) 33 | group_name = Column(String, nullable=True) 34 | 35 | # 发布状态常量 36 | STATUS_PUBLISHING = 1 37 | STATUS_FAILED = 2 38 | STATUS_SUCCESS = 3 39 | STATUS_CANCELLED = 4 40 | -------------------------------------------------------------------------------- /app/models/short_video.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey 2 | from sqlalchemy.sql import func 3 | from ..database import Base 4 | 5 | class ShortVideo(Base): 6 | """短视频模型 7 | 8 | 这个模型定义了短视频在数据库中的结构 9 | 10 | 属性: 11 | id (int): 主键,自动递增 12 | title (str): 短视频标题,最大长度20字符 13 | status (int): 状态, 0表示生成中, 1表示已生成, 2表示生成失败需要重试 14 
| video_url (str): 视频文件的URL或本地存储路径,最大长度255字符 15 | video_cover (str): 视频封面文件的URL或本地存储路径,最大长度255字符 16 | type (int): 视频类型, 0表示创作, 1表示混剪 17 | created_at (datetime): 视频创建时间 18 | finished_at (datetime): 视频生成完成时间 19 | is_deleted (bool): 删除标志, False表示未删除, True表示已删除 20 | short_videos_detail_id (int): 视频详情id,外键 21 | user_id (str): 用户ID,最大长度100字符 22 | """ 23 | 24 | __tablename__ = "short_videos" 25 | 26 | id = Column(Integer, primary_key=True, index=True, autoincrement=True) 27 | title = Column(String(100), nullable=False, comment="短视频的标题") 28 | status = Column(Integer, nullable=False, default=0, comment="短视频状态:0表示生成中,1表示已生成,2表示生成失败需要重试") 29 | status_msg = Column(String(20), nullable=True, default="", comment="短视频状态信息") 30 | video_url = Column(String(255), nullable=True, comment="视频文件的URL或本地存储路径") 31 | video_cover = Column(String(255), nullable=True, comment="视频封面文件的URL或本地存储路径") 32 | type = Column(Integer, nullable=False, default=0, comment="视频类型:0表示创作,1表示混剪") 33 | created_at = Column(DateTime, nullable=False, server_default=func.now(), comment="视频创建时间") 34 | finished_at = Column(DateTime, comment="视频生成完成时间") 35 | is_deleted = Column(Boolean, nullable=False, default=False, comment="删除标志:False表示未删除,True表示已删除") 36 | short_videos_detail_id = Column(Integer, nullable=False, comment="视频详情id") 37 | user_id = Column(String(100), nullable=True, comment="用户ID") 38 | 39 | STATUS_MAPPING = {0: "生成中", 1: "已生成", 2: "生成失败,需要重试"} 40 | TYPE_MAPPING = {0: "创作", 1: "混剪"} 41 | 42 | @property 43 | def status_name(self): 44 | """返回状态的文字描述""" 45 | return self.STATUS_MAPPING.get(self.status, "未知") 46 | 47 | @property 48 | def type_name(self): 49 | """返回类型的文字描述""" 50 | return self.TYPE_MAPPING.get(self.type, "未知") 51 | 52 | def __repr__(self): 53 | """返回短视频对象的字符串表示""" 54 | return f"" 55 | -------------------------------------------------------------------------------- /app/models/short_video_detail.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import JSON, Column, Integer, Float, String 2 | from ..database import Base 3 | 4 | class ShortVideoDetail(Base): 5 | """短视频详情模型类""" 6 | 7 | __tablename__ = "short_video_details" 8 | 9 | id = Column(Integer, primary_key=True, index=True) 10 | user_id = Column(Integer, comment="用户ID") 11 | video_title = Column(String(100), nullable=False, comment="视频标题") 12 | script_content = Column(String(255), comment="文案内容") 13 | video_duration = Column(Integer, comment="生成的视频时长(秒)") 14 | generation_count = Column(Integer, default=1, nullable=False, comment="生成数量") 15 | 16 | # 视频设置 17 | video_layout = Column(Integer, default=2, nullable=False, comment="视频布局(1-横屏,2-竖屏)") 18 | video_frame_rate = Column(Integer, default=25, nullable=False, comment="视频帧率(25,30,50,60)") 19 | resolution = Column(Integer, default=3, nullable=False, comment="分辨率(1-480p,2-720p,3-1080p,4-2k,5-4k)") 20 | export_format = Column(Integer, default=1, nullable=False, comment="导出格式(1-mp4,2-mov)") 21 | 22 | # 数字人设置 23 | digital_human_avatars_type = Column(Integer, default=1, nullable=False, comment="数字人形象类型(0远程,1本地)") 24 | digital_human_avatars_download_url = Column(String(255), comment="远程:模型压缩包下载地址") # video_path 25 | digital_human_avatars_id = Column(Integer, comment="人物id") 26 | digital_human_avatars_position = Column(String(20), default='0,0', comment="人物位置") 27 | digital_human_avatars_scale = Column(Float, default=1, nullable=False, comment="人物缩放比例") 28 | digital_human_avatars_human_id = Column(String(20), default='', comment="远程:human_id") # human_id 29 | 
digital_human_avatars_no_green_cover_image_width = Column(Integer, comment="远程:远程数字人宽") # no_green_cover_image_width 30 | digital_human_avatars_no_green_cover_image_height = Column(Integer, comment="远程:远程数字人宽") # no_green_cover_image_height 31 | 32 | # 配音设置 33 | voice_material_type = Column(Integer, default=1, nullable=False, comment="配音素材类型(1本地,0远程)") 34 | voice_switch = Column(Integer, default=0, comment="配音真人录制(0-关闭,1-开启)") 35 | voice_speed = Column(Float, default=1, comment="配音语速") 36 | voice_volume = Column(Float, default=1, comment="配音音量") 37 | voice_id = Column(Integer, comment="配音声音id") 38 | voice_path = Column(String(255), comment="声音文件路径") 39 | voice_download_url = Column(String(255), comment="远程:声音素材模型压缩包下载地址") # file_path 40 | voice_preview_url = Column(String(255), comment="声音素材预览地址") 41 | voice_resource_id = Column(String(50), comment="声音素材资源ID") 42 | voice_npy_prompt_text = Column(String(500), comment="远程:npy提示文本") # npy_prompt_text 43 | voice_voice_id = Column(String(500), comment="远端:voice_id") # voice_id 44 | 45 | 46 | # 字幕设置 47 | subtitle_switch = Column(Integer, default=0, nullable=False, comment="字幕开关(0-关闭,1-开启)") 48 | font_id = Column(Integer, comment="字体id") 49 | font_size = Column(Integer, default=16, nullable=False, comment="字体大小") 50 | font_color = Column(String(20), default='#000000', nullable=False, comment="字体颜色") 51 | font_position = Column(String(20), default='0,0', nullable=False, comment="字幕位置") 52 | font_path = Column(String(255), comment="字体文件路径") 53 | font_name = Column(String(50), comment="字体名称") 54 | -------------------------------------------------------------------------------- /app/models/task.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, DateTime 2 | from sqlalchemy.sql import func 3 | from ..database import Base 4 | 5 | class Task(Base): 6 | """任务模型 7 | 8 | 这个模型定义了简化后的任务在数据库中的结构 9 | 10 | 属性: 11 | id (int): 主键,自动递增 12 | name (str): 任务名称 13 | start_time (datetime): 开始时间 14 | end_time (datetime): 结束时间 15 | result (str): 执行结果 16 | status (int): 任务状态 (0: 执行中, 1: 执行成功, 2: 执行失败, 3: 取消执行) 17 | """ 18 | 19 | __tablename__ = "tasks" 20 | 21 | id = Column(String, primary_key=True, index=True) 22 | name = Column(String, nullable=False) 23 | start_time = Column(DateTime, nullable=False) 24 | end_time = Column(DateTime) 25 | result = Column(String) 26 | status = Column(Integer, nullable=False, default=0) 27 | 28 | # 状态映射 29 | STATUS_MAPPING = { 30 | 0: "执行中", 31 | 1: "执行成功", 32 | 2: "执行失败", 33 | 3: "取消执行" 34 | } 35 | 36 | def __repr__(self): 37 | status_str = self.STATUS_MAPPING.get(self.status, "未知状态") 38 | return f"" 39 | -------------------------------------------------------------------------------- /app/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/app/schemas/__init__.py -------------------------------------------------------------------------------- /app/schemas/digital_human_avatar.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field,validator 2 | from datetime import datetime 3 | from typing import Optional 4 | 5 | class DigitalHumanAvatarBase(BaseModel): 6 | """数字人形象基础模式 7 | 8 | 这个模式定义了创建和更新数字人形象时共用的字段 9 | """ 10 | name: str = Field(..., description="形象名称") 11 | 12 | type: Optional[int] = Field(default=1, 
description="形象类型, 0 代表公共, 1 代表个人") 13 | status: Optional[int] = Field(default=0, description="状态, 0 表示 AI 克隆训练中, 1 表示克隆完成, 2 表示克隆失败") 14 | description: Optional[str] = Field(None, description="形象描述") 15 | audio_path: Optional[str] = Field(None, description="音频文件路径") 16 | audio_prompt_npy_path: Optional[str] = Field(None, description="语音prompt_npy路径") 17 | no_green_video_path: Optional[str] = Field(None, description="去除绿幕后的视频路径") 18 | no_green_cover_image_path: Optional[str] = Field(None, description="去除绿幕后的封面图片路径") 19 | no_green_cover_image_width: Optional[int] = Field(None, description="去除绿幕后的封面图片宽度") 20 | no_green_cover_image_height: Optional[int] = Field(None, description="去除绿幕后的封面图片高度") 21 | welcome_audio_path: Optional[str] = Field(None, description="欢迎语音频路径") 22 | welcome_video_path: Optional[str] = Field(None, description="欢迎视频路径") 23 | human_id: Optional[str] = Field(None, description="数字人形象唯一标识符") 24 | video_path: str = Field(..., description="形象视频文件路径") 25 | 26 | class DigitalHumanAvatarCreate(DigitalHumanAvatarBase): 27 | """创建数字人形象的请求模式""" 28 | pass 29 | 30 | class DigitalHumanAvatar(DigitalHumanAvatarBase): 31 | """数字人形象的响应模 32 | 33 | 这个模式包含了从数据库返回的所有字段 34 | """ 35 | id: int 36 | created_at: datetime 37 | is_deleted: bool = False 38 | no_green_cover_image_width: Optional[int] 39 | no_green_cover_image_height: Optional[int] 40 | no_green_cover_image_path: Optional[str] 41 | 42 | 43 | class Config: 44 | orm_mode = True 45 | json_encoders = { 46 | datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") 47 | } 48 | #orm_mode = True 49 | 50 | class DigitalHumanAvatarUpdate(BaseModel): 51 | """更新数字人形象的请求模式 52 | 53 | 所有字段都是可选的,允许部分更新 54 | """ 55 | name: Optional[str] = None 56 | video_path: Optional[str] = None 57 | type: Optional[int] = None 58 | status: Optional[int] = None 59 | description: Optional[str] = None 60 | audio_path: Optional[str] = None 61 | audio_prompt_npy_path: Optional[str] = None 62 | no_green_video_path: Optional[str] = None 63 | no_green_cover_image_path: Optional[str] = None 64 | welcome_audio_path: Optional[str] = None 65 | welcome_video_path: Optional[str] = None 66 | human_id: Optional[str] = None 67 | -------------------------------------------------------------------------------- /app/schemas/digital_human_voice.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, validator 2 | from datetime import datetime 3 | from typing import Optional 4 | 5 | class DigitalHumanVoiceBase(BaseModel): 6 | """数字人声音基础模式 7 | 8 | 这个模式定义了创建和更新数字人声音时共用的字段 9 | """ 10 | user_id: Optional[str] = Field(None, description="用户ID") 11 | name: str = Field(..., description="声音名称") 12 | file_path: str = Field(..., description="音频路径") 13 | type: Optional[int] = Field(None, description="类型, 0 代表公共, 1 代表个人") 14 | sample_audio_url: Optional[str] = Field(None, description="示例音频地址") 15 | 16 | class DigitalHumanVoiceCreate(DigitalHumanVoiceBase): 17 | """创建数字人声音的请求模式""" 18 | pass 19 | 20 | class DigitalHumanVoice(DigitalHumanVoiceBase): 21 | """数字人声音的响应模式 22 | 23 | 这个模式包含了从数据库返回的所有字段 24 | """ 25 | id: int 26 | status: int 27 | status_msg: str 28 | is_deleted: bool 29 | created_at: datetime 30 | finished_at: Optional[datetime] 31 | voice_id: Optional[str] 32 | type_name: str 33 | status_name: str 34 | 35 | class Config: 36 | orm_mode = True 37 | json_encoders = { 38 | datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S") 39 | } 40 | 41 | class DigitalHumanVoiceUpdate(BaseModel): 42 | """更新数字人声音的请求模式 43 | 44 | 所有字段都是可选的,允许部分更新 
45 | """ 46 | name: Optional[str] = None 47 | file_path: Optional[str] = None 48 | type: Optional[int] = None 49 | status: Optional[int] = None 50 | sample_audio_url: Optional[str] = None 51 | -------------------------------------------------------------------------------- /app/schemas/font.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional 3 | 4 | class FontBase(BaseModel): 5 | """字体基础模式""" 6 | name: str = Field(..., description="字体名称") 7 | nickname: Optional[str] = Field(None, description="字体昵称") 8 | font_path: str = Field(..., description="字体文件路径") 9 | 10 | class FontCreate(FontBase): 11 | """创建字体的请求模式""" 12 | pass 13 | 14 | class FontUpdate(BaseModel): 15 | """更新字体的请求模式""" 16 | name: Optional[str] = None 17 | nickname: Optional[str] = None 18 | font_path: Optional[str] = None 19 | 20 | class Font(FontBase): 21 | """字体的响应模式""" 22 | id: int 23 | 24 | class Config: 25 | orm_mode = True -------------------------------------------------------------------------------- /app/schemas/response.py: -------------------------------------------------------------------------------- 1 | from typing import List,Generic, TypeVar, Optional 2 | from pydantic import BaseModel 3 | 4 | T = TypeVar('T') 5 | 6 | class ApiResponse(BaseModel, Generic[T]): 7 | code: int 8 | data: Optional[T] = None 9 | message: Optional[str] = None 10 | 11 | 12 | class PaginatedResponse(BaseModel, Generic[T]): 13 | items: List[T] 14 | total: int 15 | -------------------------------------------------------------------------------- /app/schemas/short_video_detail.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, List, Dict 3 | 4 | 5 | class ShortVideoDetailBase(BaseModel): 6 | """短视频详情基础模式""" 7 | user_id: Optional[int] = Field(None, description="用户ID") 8 | video_title: Optional[str] = Field(None, description="视频标题") 9 | script_content: Optional[str] = Field(None, description="文案内容") 10 | video_layout: Optional[int] = Field(2, description="视频布局(1-横屏,2-竖屏)") 11 | resolution: Optional[int] = Field(3, description="分辨率(1-480p,2-720p,3-1080p,4-2k,5-4k)") 12 | video_frame_rate: Optional[int] = Field(25, description="视频帧率(1-25fps,2-30fps,3-50fps,4-60fps)") 13 | video_duration: Optional[int] = Field(None, description="生成的视频时长(秒)") 14 | export_format: Optional[int] = Field(1, description="导出格式(1-mp4,2-mov)") 15 | generation_count: Optional[int] = Field(1, description="生成数量") 16 | 17 | digital_human_avatars_type: Optional[int] = Field(1, description="数字人形象类型(0远程,1本地)") 18 | digital_human_avatars_download_url: Optional[str] = Field(default=None, description="远程:模型压缩包下载地址") 19 | digital_human_avatars_id: Optional[int] = Field(None, description="人物id") 20 | digital_human_avatars_position: Optional[str] = Field("0,0", description="人物位置") 21 | digital_human_avatars_scale: Optional[float] = Field(1, description="人物缩放比例") 22 | digital_human_avatars_human_id: Optional[str] = Field(default=None, description="远程:human_id") # human_id 23 | digital_human_avatars_no_green_cover_image_width: Optional[int] = Field(1920, description="远程数字人宽") 24 | digital_human_avatars_no_green_cover_image_height: Optional[int] = Field(1080, description="远程:远程数字人高") 25 | 26 | voice_speed: Optional[float] = Field(1.0, description="配音语速") 27 | voice_volume: Optional[float] = Field(1.0, description="配音音量") 28 | voice_id: Optional[int] = Field(None, 
description="配音声音id") 29 | voice_path: Optional[str] = Field(None, description="声音文件路径") 30 | voice_material_type: Optional[int] = Field(1, description="声音素材库类型(1-本地,0-远程)") 31 | voice_download_url: Optional[str] = Field(None, description="远程:声音素材模型压缩包下载地址") 32 | voice_preview_url: Optional[str] = Field(None, description="声音素材预览地址") 33 | voice_resource_id: Optional[str] = Field(None, description="声音素材资源ID") 34 | voice_npy_prompt_text: Optional[str] = Field(None, description="远程:npy提示文本") 35 | voice_voice_id: Optional[str] = Field(None, description="远端:voice_id") 36 | 37 | subtitle_switch: Optional[int] = Field(0, description="字幕开关(0-关闭,1-开启)") 38 | font_id: Optional[int] = Field(0, description="字体id") 39 | font_size: Optional[int] = Field(16, description="字体大小") 40 | font_color: Optional[str] = Field("#ffffffff", description="字体颜色") 41 | font_position: Optional[str] = Field("0,0", description="字幕位置") 42 | font_path: Optional[str] = Field(None, description="字体文件路径") 43 | font_name: Optional[str] = Field(None, description="字体名称") 44 | 45 | 46 | 47 | 48 | class ShortVideoDetailCreate(ShortVideoDetailBase): 49 | """创建短视频详情的请求模式""" 50 | pass 51 | 52 | class ShortVideoDetail(ShortVideoDetailBase): 53 | """短视频详情的响应模式""" 54 | id: int 55 | 56 | class Config: 57 | orm_mode = True 58 | -------------------------------------------------------------------------------- /app/schemas/task.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/app/schemas/task.py -------------------------------------------------------------------------------- /app/services/ffmpeg_service.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import ffmpeg 3 | from pathlib import Path 4 | import os 5 | from dotenv import load_dotenv 6 | import uuid 7 | from datetime import datetime 8 | import traceback 9 | import logging 10 | import srt 11 | import platform 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class FFmpegService: 16 | def __init__(self): 17 | load_dotenv() 18 | self.project_root = Path(os.getenv("PROJECT_ROOT")) 19 | self.video_dir = self.project_root / 'data' / 'video' 20 | self.image_dir = self.project_root / 'data' / 'image' 21 | self.music_dir = self.project_root / 'data' / 'music' 22 | self.audio_dir = self.project_root / 'data' / 'audio' 23 | self.subtitle_dir = self.project_root / 'data' / 'subtitle' # 新增字幕目录 24 | self.publish_dir = self.project_root / 'data' / 'publish' 25 | 26 | # 创建所有必要的目录 27 | for directory in [self.video_dir, self.image_dir, self.music_dir, self.audio_dir, self.subtitle_dir, self.publish_dir]: 28 | directory.mkdir(parents=True, exist_ok=True) 29 | 30 | self.default_font_path = self.project_root / 'resources' / 'fonts' / 'lipin.ttf' # 默认使用黑体 31 | 32 | # 确保字体文件存在 33 | if not self.default_font_path.exists(): 34 | logger.warning(f"默认字体文件不存在: {self.default_font_path}") 35 | 36 | def convert_video_format(self, input_path: str, output_path: str): 37 | """将视频格式从MOV转换为MP4""" 38 | # 确保输出目录存在 39 | output_dir = Path(output_path).parent 40 | output_dir.mkdir(parents=True, exist_ok=True) 41 | input_stream = ffmpeg.input(input_path) 42 | 43 | output_path = Path(output_path).as_posix() 44 | output_stream = ffmpeg.output(input_stream, output_path, vcodec='libx264', acodec='aac', strict='experimental') 45 | ffmpeg.run(output_stream, overwrite_output=True) 46 | 47 | 
-------------------------------------------------------------------------------- /app/services/fishspeech_service.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | import datetime 4 | import uuid 5 | import os 6 | from dotenv import load_dotenv 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(logging.DEBUG) 10 | 11 | class FishSpeechService: 12 | """FishSpeech服务类,用于语音克隆和生成。""" 13 | 14 | def __init__(self): 15 | """初始化FishSpeechService""" 16 | load_dotenv() 17 | project_root = Path(os.getenv("PROJECT_ROOT")) 18 | 19 | self.base_path = project_root / 'external_modules' / 'fish-speech' 20 | self.checkpoint_path = self.base_path / 'checkpoints' / 'fish-speech-1.4' 21 | self.vqgan_path = self.checkpoint_path / 'firefly-gan-vq-fsq-8x1024-21hz-generator.pth' 22 | self.conda_env = os.getenv("FISH_SPEECH_CONDA_ENV") 23 | 24 | def run_command(self, command): 25 | """在指定的Conda环境中运行命令。""" 26 | full_command = f"conda run -n {self.conda_env} {command}" 27 | logger.debug(f"fishspeech: 执行命令: {command}") 28 | subprocess.run(full_command, shell=True, check=True, cwd=str(self.base_path)) 29 | 30 | def clone_voice(self, audio_path, audio_prompt_wav_path): 31 | """ 32 | 克隆声音。 33 | 34 | 参数: 35 | audio_path (str): 参考音频文件的路径 36 | output_path (str): 克隆npy的路径 37 | 38 | 返回: 39 | str: 克隆声音的唯一标识符 40 | """ 41 | command = f"python tools/vqgan/inference.py -i {audio_path} --checkpoint-path {self.vqgan_path} -o {audio_prompt_wav_path}" 42 | logger.debug(f"fishspeech: 执行命令: {command}") 43 | self.run_command(command) 44 | return audio_prompt_wav_path 45 | 46 | def generate_speech(self, text, prompt_text, prompt_npy_path, output_npy_path, output_wav_path): 47 | """ 48 | 生成语音。 49 | 50 | 参数: 51 | text (str): 要转换为语音的文本 52 | prompt_npy_path (str): 语意prompt_npy的路径 53 | output_npy_path (str): 输出的npy文件路径 54 | output_wav_path (str): 输出的wav文件路径 55 | 56 | 返回: 57 | tuple: 生成的npy文件路径和wav文件路径 58 | """ 59 | # 生成语音特征 60 | self.run_command(f"python tools/llama/generate.py --text \"{text}\" --prompt-text \"{prompt_text}\" --prompt-tokens {prompt_npy_path} " 61 | f"--checkpoint-path {self.checkpoint_path} --num-samples 1 ") 62 | 63 | # 将特征转换为音频 64 | self.run_command(f"python tools/vqgan/inference.py -i codes_0.npy --checkpoint-path {self.vqgan_path} -o {output_wav_path}") 65 | return output_npy_path, output_wav_path 66 | 67 | def process_audio(self, avatar_path, wav_output_path): 68 | """处理音频文件""" 69 | try: 70 | subprocess.run([ 71 | 'ffmpeg', 72 | '-i', avatar_path, 73 | '-acodec', 'pcm_s16le', 74 | '-ar', '16000', 75 | wav_output_path 76 | ], check=True, capture_output=True, text=True) 77 | except subprocess.CalledProcessError as e: 78 | raise 79 | return wav_output_path 80 | 81 | if __name__ == "__main__": 82 | # 使用示例 83 | fish_speech = FishSpeechService() 84 | 85 | unique_id = fish_speech.clone_voice("/path/to/reference_audio.wav", "unique_voice_id") 86 | output_audio = fish_speech.generate_speech("这是一测试文本", unique_id) 87 | print(f"生成的音频文件路径: {output_audio}") 88 | -------------------------------------------------------------------------------- /app/services/upload_service.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | import boto3 3 | import os 4 | from botocore.config import Config 5 | 6 | # 加载 .env 文件 7 | load_dotenv() 8 | 9 | # 从环境变量中获取配置 10 | endpoint = os.getenv('OSS_ENDPOINT') 11 | access_key_id = os.getenv('OSS_ACCESS_KEY_ID') 12 | 
secret_access_key = os.getenv('OSS_SECRET_ACCESS_KEY') 13 | bucket_name = os.getenv('OSS_BUCKET_NAME') 14 | 15 | # 打印以验证加载的值(仅用于调试,生产环境中请移除) 16 | print(f"Endpoint: {endpoint}") 17 | print(f"Access Key ID: {access_key_id}") 18 | print(f"Secret Access Key: {secret_access_key}") 19 | print(f"bucket_name: {bucket_name}") 20 | 21 | # 创建S3客户端 22 | s3 = boto3.client( 23 | 's3', 24 | aws_access_key_id=access_key_id, 25 | aws_secret_access_key=secret_access_key, 26 | endpoint_url=endpoint, 27 | config=Config(s3={"addressing_style": "virtual"}, 28 | signature_version='v4') 29 | ) 30 | 31 | def generate_presigned_url(object_name, expiration=3600): 32 | """ 33 | 生成用于上传文件的预签名URL。 34 | 35 | :param bucket_name: 存储桶名称 36 | :param object_name: 对象名称(即文件名) 37 | :param expiration: URL的有效期(秒),默认为3600秒 38 | :return: 预签名的URL 39 | """ 40 | try: 41 | bucket_name = os.getenv('OSS_BUCKET_NAME') 42 | presigned_url = s3.generate_presigned_url( 43 | 'put_object', 44 | Params={'Bucket': bucket_name, 'Key': object_name}, 45 | ExpiresIn=expiration 46 | ) 47 | return { 48 | "oriFileName": object_name, 49 | "filePath": f"{bucket_name}/{object_name}", 50 | "preSignUrl": presigned_url 51 | } 52 | 53 | except Exception as e: 54 | print(f"生成预签名URL时出错: {e}") 55 | return None -------------------------------------------------------------------------------- /app/services/wav2lip_service.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | import os 4 | from dotenv import load_dotenv 5 | from datetime import datetime 6 | import uuid 7 | 8 | class Wav2LipService: 9 | """Wav2Lip服务类,用于生成唇形同步视频。""" 10 | 11 | def __init__(self): 12 | """初始化Wav2LipService""" 13 | load_dotenv() 14 | project_root = Path(os.getenv("PROJECT_ROOT")) 15 | 16 | self.base_path = project_root / 'external_modules' / 'wav2lip-onnx-256' 17 | self.checkpoint_path = self.base_path / 'checkpoints' / 'wav2lip_256.onnx' 18 | self.conda_env = os.getenv("WAV2LIP_CONDA_ENV") 19 | self.video_dir = project_root / 'data' / 'video' 20 | self.video_dir.mkdir(parents=True, exist_ok=True) 21 | 22 | def run_command(self, command): 23 | """在指定的Conda环境中运行命令。""" 24 | full_command = f"conda run -n {self.conda_env} {command}" 25 | subprocess.run(full_command, shell=True, check=True, cwd=str(self.base_path)) 26 | 27 | def generate_video(self, face_path, audio_path, output_video_path): 28 | """ 29 | 生成唇形同步视频。 30 | 31 | 参数: 32 | face_path (str): 人脸图像或视频的路径 33 | audio_path (str): 音频文件的路径 34 | output_video_path (str): 生成的视频文件的路径 35 | 36 | 返回: 37 | Path: 生成的视频文件的路径 38 | """ 39 | # 生成唯一的文件名和日期目录 40 | current_date = datetime.now().strftime("%Y-%m-%d") 41 | output_dir = self.video_dir / current_date 42 | output_dir.mkdir(parents=True, exist_ok=True) 43 | 44 | 45 | command = (f"python inference_onnxModel.py " 46 | f"--checkpoint_path {self.checkpoint_path} " 47 | f"--face {face_path} " 48 | f"--audio {audio_path} " 49 | f"--outfile {output_video_path}") 50 | 51 | self.run_command(command) 52 | return output_video_path 53 | 54 | if __name__ == "__main__": 55 | # 使用示例 56 | wav2lip = Wav2LipService() 57 | 58 | face_path = "/path/to/face_image.jpg" 59 | audio_path = "/path/to/audio.wav" 60 | unique_id = str(uuid.uuid4()) # 生成唯一标识符 61 | 62 | result_video = wav2lip.generate_video(face_path, audio_path, unique_id) 63 | print(f"生成的视频文件路径: {result_video}") 64 | -------------------------------------------------------------------------------- /app/utils/gpu_utils.py: 
-------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def check_gpu_available(): 5 | """检查是否有可用的 NVIDIA GPU""" 6 | # try: 7 | # result = subprocess.run(['nvidia-smi'], capture_output=True, text=True) 8 | # return result.returncode == 0 9 | # except FileNotFoundError: 10 | return False -------------------------------------------------------------------------------- /app/utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from datetime import datetime 4 | 5 | def setup_logger(): 6 | # 创建日志目录 7 | log_dir = "data/logs" 8 | if not os.path.exists(log_dir): 9 | os.makedirs(log_dir) 10 | 11 | # 获取当前日期作为文件名 12 | current_date = datetime.now().strftime("%Y-%m-%d") 13 | log_file = os.path.join(log_dir, f"app_{current_date}.log") 14 | 15 | # 设置日志格式 16 | formatter = logging.Formatter( 17 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s', 18 | datefmt='%Y-%m-%d %H:%M:%S' 19 | ) 20 | 21 | # 创建 FileHandler,使用追加模式 22 | file_handler = logging.FileHandler( 23 | filename=log_file, 24 | encoding='utf-8', 25 | mode='a' # 追加模式 26 | ) 27 | file_handler.setFormatter(formatter) 28 | 29 | # 创建 StreamHandler 用于控制台输出 30 | stream_handler = logging.StreamHandler() 31 | stream_handler.setFormatter(formatter) 32 | 33 | # 配置根日志记录器 34 | root_logger = logging.getLogger() 35 | root_logger.setLevel(logging.DEBUG) 36 | 37 | # 清除可能存在的旧处理器 38 | root_logger.handlers.clear() 39 | 40 | # 添加处理器 41 | root_logger.addHandler(stream_handler) 42 | root_logger.addHandler(file_handler) 43 | -------------------------------------------------------------------------------- /app/utils/response_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | from ..schemas.response import ApiResponse 3 | 4 | def success_response(data: Any = None, message: Optional[str] = None) -> ApiResponse: 5 | return ApiResponse(code=200, data=data, message=message) 6 | 7 | def error_response(code: int = 400, message: str = "操作失败", data: Any = None) -> ApiResponse: 8 | return ApiResponse(code=code, message=message, data=data) 9 | 10 | -------------------------------------------------------------------------------- /app/utils/user_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def get_user_id(): 3 | return 'admin' -------------------------------------------------------------------------------- /doc/pro_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_01.jpg -------------------------------------------------------------------------------- /doc/pro_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_02.jpg -------------------------------------------------------------------------------- /doc/pro_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_03.jpg -------------------------------------------------------------------------------- /doc/pro_04.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_04.jpg -------------------------------------------------------------------------------- /doc/pro_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_05.jpg -------------------------------------------------------------------------------- /doc/pro_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_06.jpg -------------------------------------------------------------------------------- /doc/pro_07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_07.jpg -------------------------------------------------------------------------------- /doc/pro_08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_08.jpg -------------------------------------------------------------------------------- /doc/pro_09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_09.jpg -------------------------------------------------------------------------------- /doc/pro_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/pro_10.jpg -------------------------------------------------------------------------------- /doc/wx_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/wx_01.jpg -------------------------------------------------------------------------------- /doc/wx_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/doc/wx_02.jpg -------------------------------------------------------------------------------- /external_modules/fish-speech/.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .github 3 | results 4 | data 5 | *.filelist 6 | /data_server/target 7 | checkpoints 8 | -------------------------------------------------------------------------------- /external_modules/fish-speech/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .pgx.* 3 | .pdm-python 4 | /fish_speech.egg-info 5 | __pycache__ 6 | /results 7 | /data 8 | /*.test.sh 9 | *.filelist 10 | filelists 11 | /fish_speech/text/cmudict_cache.pickle 12 | /checkpoints 13 | /.vscode 14 | /data_server/target 15 | /*.npy 16 | /*.wav 17 | /*.mp3 18 | /*.lab 19 | /results 20 | /data 21 | /.idea 
22 | ffmpeg.exe 23 | ffprobe.exe 24 | asr-label* 25 | /.cache 26 | /fishenv 27 | /.locale 28 | /demo-audios 29 | /references 30 | /example 31 | /faster_whisper 32 | 33 | # 大型模型文件 34 | *.model 35 | *.pkl 36 | *.h5 37 | *.onnx 38 | *.pt 39 | *.pth 40 | 41 | # 排除.github目录 42 | .github/ 43 | 44 | -------------------------------------------------------------------------------- /external_modules/fish-speech/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_schedule: monthly 3 | 4 | repos: 5 | - repo: https://github.com/pycqa/isort 6 | rev: 5.13.2 7 | hooks: 8 | - id: isort 9 | args: [--profile=black] 10 | 11 | - repo: https://github.com/psf/black 12 | rev: 24.10.0 13 | hooks: 14 | - id: black 15 | 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v5.0.0 18 | hooks: 19 | - id: end-of-file-fixer 20 | - id: check-yaml 21 | - id: check-json 22 | - id: mixed-line-ending 23 | args: ['--fix=lf'] 24 | - id: check-added-large-files 25 | args: ['--maxkb=5000'] 26 | -------------------------------------------------------------------------------- /external_modules/fish-speech/.project-root: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/.project-root -------------------------------------------------------------------------------- /external_modules/fish-speech/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for MkDocs projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | 13 | mkdocs: 14 | configuration: mkdocs.yml 15 | 16 | # Optionally declare the Python requirements required to build your docs 17 | python: 18 | install: 19 | - requirements: docs/requirements.txt 20 | -------------------------------------------------------------------------------- /external_modules/fish-speech/API_FLAGS.txt: -------------------------------------------------------------------------------- 1 | # --infer 2 | # --api 3 | --listen 0.0.0.0:8080 \ 4 | --llama-checkpoint-path "checkpoints/fish-speech-1.4" \ 5 | --decoder-checkpoint-path "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth" \ 6 | --decoder-config-name firefly_gan_vq 7 | -------------------------------------------------------------------------------- /external_modules/fish-speech/README.zh.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | --- 5 | ## 特性 6 | 7 | 1. **零样本 & 小样本 TTS**:输入 10 到 30 秒的声音样本即可生成高质量的 TTS 输出。**详见 [语音克隆最佳实践指南](https://docs.fish.audio/text-to-speech/voice-clone-best-practices)。** 8 | 2. **多语言 & 跨语言支持**:只需复制并粘贴多语言文本到输入框中,无需担心语言问题。目前支持英语、日语、韩语、中文、法语、德语、阿拉伯语和西班牙语。 9 | 3. **无音素依赖**:模型具备强大的泛化能力,不依赖音素进行 TTS,能够处理任何文字表示的语言。 10 | 4. **高准确率**:在 5 分钟的英文文本上,达到了约 2% 的 CER(字符错误率)和 WER(词错误率)。 11 | 5. **快速**:通过 fish-tech 加速,在 Nvidia RTX 4060 笔记本上的实时因子约为 1:5,在 Nvidia RTX 4090 上约为 1:15。 12 | 6. **WebUI 推理**:提供易于使用的基于 Gradio 的网页用户界面,兼容 Chrome、Firefox、Edge 等浏览器。 13 | 7. **GUI 推理**:提供 PyQt6 图形界面,与 API 服务器无缝协作。支持 Linux、Windows 和 macOS。[查看 GUI](https://github.com/AnyaCoder/fish-speech-gui)。 14 | 8. 
**易于部署**:轻松设置推理服务器,原生支持 Linux、Windows 和 macOS,最大程度减少速度损失。 15 | 16 | 17 | ## 免责声明 18 | 19 | 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规. 20 | 21 | 22 | ## 在线 DEMO 23 | 24 | [Fish Audio](https://fish.audio) 25 | 26 | ## 快速开始本地推理 27 | 28 | [inference.ipynb](/inference.ipynb) 29 | 30 | ## 视频 31 | 32 | #### 1.4 介绍: https://www.bilibili.com/video/BV1pu46eVEk7 33 | 34 | #### 1.2 介绍: https://www.bilibili.com/video/BV1wz421B71D 35 | 36 | #### 1.1 介绍: https://www.bilibili.com/video/BV1zJ4m1K7cj 37 | 38 | ## 文档 39 | 40 | - [English](https://speech.fish.audio/) 41 | - [中文](https://speech.fish.audio/zh/) 42 | - [日本語](https://speech.fish.audio/ja/) 43 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/) 44 | 45 | ## 例子 (2024/10/02 V1.4) 46 | 47 | - [English](https://speech.fish.audio/samples/) 48 | - [中文](https://speech.fish.audio/zh/samples/) 49 | - [日本語](https://speech.fish.audio/ja/samples/) 50 | - [Portuguese (Brazil)](https://speech.fish.audio/pt/samples/) 51 | 52 | ## 鸣谢 53 | 54 | - [VITS2 (daniilrobnikov)](https://github.com/daniilrobnikov/vits2) 55 | - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) 56 | - [GPT VITS](https://github.com/innnky/gpt-vits) 57 | - [MQTTS](https://github.com/b04901014/MQTTS) 58 | - [GPT Fast](https://github.com/pytorch-labs/gpt-fast) 59 | - [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 60 | 61 | 62 | python tools/vqgan/inference.py -i /Users/libn/Desktop/123.wav --checkpoint-path /Users/libn/dev/project/me/auto_lead_backend/external_modules/fish-speech/checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth -o /Users/libn/Desktop/20241016162915_3e3cc3.npy -------------------------------------------------------------------------------- /external_modules/fish-speech/docker-compose.dev.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | fish-speech: 5 | build: 6 | context: . 7 | dockerfile: dockerfile.dev 8 | container_name: fish-speech 9 | volumes: 10 | - ./:/exp 11 | deploy: 12 | resources: 13 | reservations: 14 | devices: 15 | - driver: nvidia 16 | count: all 17 | capabilities: [gpu] 18 | command: tail -f /dev/null 19 | -------------------------------------------------------------------------------- /external_modules/fish-speech/dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim-bookworm AS stage-1 2 | ARG TARGETARCH 3 | 4 | ARG HUGGINGFACE_MODEL=fish-speech-1.4 5 | ARG HF_ENDPOINT=https://huggingface.co 6 | 7 | WORKDIR /opt/fish-speech 8 | 9 | RUN set -ex \ 10 | && pip install huggingface_hub \ 11 | && HF_ENDPOINT=${HF_ENDPOINT} huggingface-cli download --resume-download fishaudio/${HUGGINGFACE_MODEL} --local-dir checkpoints/${HUGGINGFACE_MODEL} 12 | 13 | FROM python:3.12-slim-bookworm 14 | ARG TARGETARCH 15 | 16 | ARG DEPENDENCIES=" \ 17 | ca-certificates \ 18 | libsox-dev \ 19 | ffmpeg" 20 | 21 | RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ 22 | --mount=type=cache,target=/var/lib/apt,sharing=locked \ 23 | set -ex \ 24 | && rm -f /etc/apt/apt.conf.d/docker-clean \ 25 | && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache \ 26 | && apt-get update \ 27 | && apt-get -y install --no-install-recommends ${DEPENDENCIES} \ 28 | && echo "no" | dpkg-reconfigure dash 29 | 30 | WORKDIR /opt/fish-speech 31 | 32 | COPY . . 
33 | 34 | RUN --mount=type=cache,target=/root/.cache,sharing=locked \ 35 | set -ex \ 36 | && pip install -e .[stable] 37 | 38 | COPY --from=stage-1 /opt/fish-speech/checkpoints /opt/fish-speech/checkpoints 39 | 40 | ENV GRADIO_SERVER_NAME="0.0.0.0" 41 | 42 | EXPOSE 7860 43 | 44 | CMD ["./entrypoint.sh"] 45 | -------------------------------------------------------------------------------- /external_modules/fish-speech/dockerfile.dev: -------------------------------------------------------------------------------- 1 | ARG VERSION=dev 2 | ARG BASE_IMAGE=ghcr.io/fishaudio/fish-speech:${VERSION} 3 | 4 | FROM ${BASE_IMAGE} 5 | 6 | ARG TOOLS=" \ 7 | git \ 8 | curl \ 9 | build-essential \ 10 | ffmpeg \ 11 | libsm6 \ 12 | libxext6 \ 13 | libjpeg-dev \ 14 | zlib1g-dev \ 15 | aria2 \ 16 | zsh \ 17 | openssh-server \ 18 | sudo \ 19 | protobuf-compiler \ 20 | cmake" 21 | 22 | RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ 23 | --mount=type=cache,target=/var/lib/apt,sharing=locked \ 24 | set -ex \ 25 | && apt-get update \ 26 | && apt-get -y install --no-install-recommends ${TOOLS} 27 | 28 | # Install oh-my-zsh so your terminal looks nice 29 | RUN sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended 30 | 31 | # Set zsh as default shell 32 | RUN chsh -s /usr/bin/zsh 33 | ENV SHELL=/usr/bin/zsh 34 | -------------------------------------------------------------------------------- /external_modules/fish-speech/docs/CNAME: -------------------------------------------------------------------------------- 1 | speech.fish.audio 2 | -------------------------------------------------------------------------------- /external_modules/fish-speech/docs/assets/figs/VS_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/docs/assets/figs/VS_1.jpg -------------------------------------------------------------------------------- /external_modules/fish-speech/docs/assets/figs/VS_1_pt-BR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/docs/assets/figs/VS_1_pt-BR.png -------------------------------------------------------------------------------- /external_modules/fish-speech/docs/assets/figs/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/docs/assets/figs/diagram.png -------------------------------------------------------------------------------- /external_modules/fish-speech/docs/assets/figs/diagrama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/docs/assets/figs/diagrama.png -------------------------------------------------------------------------------- /external_modules/fish-speech/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs-material 2 | mkdocs-static-i18n[material] 3 | mkdocs[i18n] 4 | 
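API_FLAGS.txt above holds the command-line flags for the TTS API server started from tools/api.py. The launcher below is a hypothetical sketch, not part of the repo: it assumes tools/api.py accepts exactly the flags listed in API_FLAGS.txt and that the fish-speech-1.4 checkpoints have already been downloaded; the relative paths are illustrative.

# Hypothetical launcher: re-reads API_FLAGS.txt and forwards the flags to tools/api.py.
# Assumes the fish-speech checkout lives at external_modules/fish-speech (adjust to taste).
import shlex
import subprocess
from pathlib import Path

flags_file = Path("external_modules/fish-speech/API_FLAGS.txt")

# Keep only real flags: skip blank lines and "#" comments, drop trailing "\" continuations.
flags = []
for line in flags_file.read_text().splitlines():
    line = line.strip().rstrip("\\").strip()
    if line and not line.startswith("#"):
        flags.extend(shlex.split(line))

subprocess.run(
    ["python", "tools/api.py", *flags],
    cwd="external_modules/fish-speech",
    check=True,
)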
-------------------------------------------------------------------------------- /external_modules/fish-speech/docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | .md-grid { 2 | max-width: 1440px; 3 | } 4 | -------------------------------------------------------------------------------- /external_modules/fish-speech/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_ENABLED=${CUDA_ENABLED:-true} 4 | DEVICE="" 5 | 6 | if [ "${CUDA_ENABLED}" != "true" ]; then 7 | DEVICE="--device cpu" 8 | fi 9 | 10 | exec python tools/webui.py ${DEVICE} 11 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .grad_norm import GradNormMonitor 2 | 3 | __all__ = ["GradNormMonitor"] 4 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/configs/base.yaml: -------------------------------------------------------------------------------- 1 | # Base configuration for training a model 2 | paths: 3 | run_dir: results/${project} 4 | ckpt_dir: ${paths.run_dir}/checkpoints 5 | 6 | hydra: 7 | run: 8 | dir: ${paths.run_dir} 9 | 10 | # Lightning Trainer 11 | trainer: 12 | _target_: lightning.pytorch.trainer.Trainer 13 | 14 | default_root_dir: ${paths.run_dir} 15 | accelerator: gpu 16 | num_nodes: 1 17 | devices: auto 18 | strategy: 19 | _target_: lightning.pytorch.strategies.DDPStrategy 20 | process_group_backend: nccl # This should be override when training on windows 21 | 22 | precision: bf16-mixed 23 | 24 | # disable validation by epoch end 25 | check_val_every_n_epoch: null 26 | val_check_interval: 5000 27 | max_steps: 100_000 28 | 29 | # Use torch.backends.cudnn.benchmark to speed up training 30 | benchmark: true 31 | 32 | # Callbacks 33 | callbacks: 34 | model_checkpoint: 35 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 36 | dirpath: ${paths.ckpt_dir} 37 | filename: "step_{step:09d}" 38 | save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt 39 | save_top_k: 5 # save 5 latest checkpoints 40 | monitor: step # use step to monitor checkpoints 41 | mode: max # save the latest checkpoint with the highest global_step 42 | every_n_epochs: null # don't save checkpoints by epoch end 43 | every_n_train_steps: 5000 # save checkpoints every 5000 steps 44 | auto_insert_metric_name: false 45 | 46 | model_summary: 47 | _target_: lightning.pytorch.callbacks.ModelSummary 48 | max_depth: 2 # the maximum depth of layer nesting that the summary will include 49 | 50 | learning_rate_monitor: 51 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 52 | logging_interval: step 53 | log_momentum: false 54 | 55 | grad_norm_monitor: 56 | _target_: fish_speech.callbacks.GradNormMonitor 57 | norm_type: 2 58 | logging_interval: step 59 | 60 | # Logger 61 | logger: 62 | tensorboard: 63 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 64 | save_dir: "${paths.run_dir}/tensorboard/" 65 | name: null 66 | log_graph: false 67 | default_hp_metric: true 68 | prefix: "" 69 | 70 | # wandb: 71 | # _target_: lightning.pytorch.loggers.wandb.WandbLogger 72 | # # name: "" # name of the run (normally generated by wandb) 73 | # save_dir: "${paths.run_dir}" 74 | # offline: False 75 | # id: null # pass 
correct id to resume experiment! 76 | # anonymous: null # enable anonymous logging 77 | # project: "fish-speech" 78 | # log_model: False # upload lightning ckpts 79 | # prefix: "" # a string to put at the beginning of metric keys 80 | # # entity: "" # set to name of your wandb team 81 | # group: "" 82 | # tags: ["vq", "hq", "finetune"] 83 | # job_type: "" 84 | 85 | # Loop 86 | train: true 87 | test: false 88 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/configs/firefly_gan_vq.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture 2 | spec_transform: 3 | _target_: fish_speech.utils.spectrogram.LogMelSpectrogram 4 | sample_rate: 44100 5 | n_mels: 160 6 | n_fft: 2048 7 | hop_length: 512 8 | win_length: 2048 9 | backbone: 10 | _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder 11 | input_channels: 160 12 | depths: [3, 3, 9, 3] 13 | dims: [128, 256, 384, 512] 14 | drop_path_rate: 0.2 15 | kernel_size: 7 16 | head: 17 | _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator 18 | hop_length: 512 19 | upsample_rates: [8, 8, 2, 2, 2] # aka. strides 20 | upsample_kernel_sizes: [16, 16, 4, 4, 4] 21 | resblock_kernel_sizes: [3, 7, 11] 22 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 23 | num_mels: 512 24 | upsample_initial_channel: 512 25 | pre_conv_kernel_size: 13 26 | post_conv_kernel_size: 13 27 | quantizer: 28 | _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize 29 | input_dim: 512 30 | n_groups: 8 31 | n_codebooks: 1 32 | levels: [8, 5, 5, 5] 33 | downsample_factor: [2, 2] 34 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/configs/lora/r_8_alpha_16.yaml: -------------------------------------------------------------------------------- 1 | _target_: fish_speech.models.text2semantic.lora.LoraConfig 2 | r: 8 3 | lora_alpha: 16 4 | lora_dropout: 0.01 5 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/configs/text2semantic_finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | - _self_ 4 | 5 | project: text2semantic_finetune_dual_ar 6 | max_length: 4096 7 | pretrained_ckpt_path: checkpoints/fish-speech-1.4 8 | 9 | # Lightning Trainer 10 | trainer: 11 | accumulate_grad_batches: 1 12 | gradient_clip_val: 1.0 13 | gradient_clip_algorithm: "norm" 14 | max_steps: 1000 15 | precision: bf16-true 16 | limit_val_batches: 10 17 | val_check_interval: 100 18 | 19 | # Dataset Configuration 20 | tokenizer: 21 | _target_: transformers.AutoTokenizer.from_pretrained 22 | pretrained_model_name_or_path: ${pretrained_ckpt_path} 23 | 24 | # Dataset Configuration 25 | train_dataset: 26 | _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset 27 | proto_files: 28 | - data/protos 29 | tokenizer: ${tokenizer} 30 | causal: true 31 | max_length: ${max_length} 32 | use_speaker: false 33 | interactive_prob: 0.7 34 | 35 | val_dataset: 36 | _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset 37 | proto_files: 38 | - data/protos 39 | tokenizer: ${tokenizer} 40 | causal: true 41 | max_length: ${max_length} 42 | use_speaker: false 43 | interactive_prob: 0.7 44 | 45 | data: 46 | _target_: 
fish_speech.datasets.semantic.SemanticDataModule 47 | train_dataset: ${train_dataset} 48 | val_dataset: ${val_dataset} 49 | num_workers: 4 50 | batch_size: 8 51 | tokenizer: ${tokenizer} 52 | max_length: ${max_length} 53 | 54 | # Model Configuration 55 | model: 56 | _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic 57 | model: 58 | _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained 59 | path: ${pretrained_ckpt_path} 60 | load_weights: true 61 | max_length: ${max_length} 62 | lora_config: null 63 | 64 | optimizer: 65 | _target_: torch.optim.AdamW 66 | _partial_: true 67 | lr: 1e-4 68 | weight_decay: 0 69 | betas: [0.9, 0.95] 70 | eps: 1e-5 71 | 72 | lr_scheduler: 73 | _target_: torch.optim.lr_scheduler.LambdaLR 74 | _partial_: true 75 | lr_lambda: 76 | _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda 77 | _partial_: true 78 | num_warmup_steps: 10 79 | 80 | # Callbacks 81 | callbacks: 82 | model_checkpoint: 83 | every_n_train_steps: ${trainer.val_check_interval} 84 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/conversation.py: -------------------------------------------------------------------------------- 1 | SEMANTIC_TOKEN = "<|semantic|>" 2 | CODEBOOK_PAD_TOKEN_ID = 0 3 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/datasets/concat_repeat.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | from typing import Iterable 4 | 5 | from torch.utils.data import Dataset, IterableDataset 6 | 7 | 8 | class ConcatRepeatDataset(Dataset): 9 | datasets: list[Dataset] 10 | cumulative_sizes: list[int] 11 | repeats: list[int] 12 | 13 | @staticmethod 14 | def cumsum(sequence, repeats): 15 | r, s = [], 0 16 | for dataset, repeat in zip(sequence, repeats): 17 | l = len(dataset) * repeat 18 | r.append(l + s) 19 | s += l 20 | return r 21 | 22 | def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): 23 | super().__init__() 24 | 25 | self.datasets = list(datasets) 26 | self.repeats = repeats 27 | 28 | assert len(self.datasets) > 0, "datasets should not be an empty iterable" 29 | assert len(self.datasets) == len( 30 | repeats 31 | ), "datasets and repeats should have the same length" 32 | 33 | for d in self.datasets: 34 | assert not isinstance( 35 | d, IterableDataset 36 | ), "ConcatRepeatDataset does not support IterableDataset" 37 | 38 | self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) 39 | 40 | def __len__(self): 41 | return self.cumulative_sizes[-1] 42 | 43 | def __getitem__(self, idx): 44 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 45 | 46 | if dataset_idx == 0: 47 | sample_idx = idx 48 | else: 49 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 50 | 51 | dataset = self.datasets[dataset_idx] 52 | 53 | return dataset[sample_idx % len(dataset)] 54 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/datasets/protos/text-data.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package text_data; 4 | 5 | message Semantics { 6 | repeated uint32 values = 1; 7 | } 8 | 9 | message Sentence { 10 | repeated string texts = 1; 11 | repeated Semantics semantics = 3; 12 | } 13 | 14 | message TextData { 15 | string source = 1; 
16 | string name = 2; 17 | repeated Sentence sentences = 4; 18 | } 19 | 20 | message SampledData { 21 | string source = 1; 22 | string name = 2; 23 | repeated Sentence samples = 3; 24 | } 25 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/datasets/protos/text_data_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: text-data.proto 4 | # Protobuf Python Version: 4.25.1 5 | """Generated protocol buffer code.""" 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import symbol_database as _symbol_database 9 | from google.protobuf.internal import builder as _builder 10 | 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 17 | b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' 18 | ) 19 | 20 | _globals = globals() 21 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 22 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) 23 | if _descriptor._USE_C_DESCRIPTORS == False: 24 | DESCRIPTOR._options = None 25 | _globals["_SEMANTICS"]._serialized_start = 30 26 | _globals["_SEMANTICS"]._serialized_end = 57 27 | _globals["_SENTENCE"]._serialized_start = 59 28 | _globals["_SENTENCE"]._serialized_end = 125 29 | _globals["_TEXTDATA"]._serialized_start = 127 30 | _globals["_TEXTDATA"]._serialized_end = 207 31 | _globals["_SAMPLEDDATA"]._serialized_start = 209 32 | _globals["_SAMPLEDDATA"]._serialized_end = 290 33 | # @@protoc_insertion_point(module_scope) 34 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/datasets/protos/text_data_stream.py: -------------------------------------------------------------------------------- 1 | import struct 2 | 3 | from .text_data_pb2 import TextData 4 | 5 | 6 | def read_pb_stream(f): 7 | while True: 8 | buf = f.read(4) 9 | if len(buf) == 0: 10 | break 11 | size = struct.unpack("I", buf)[0] 12 | buf = f.read(size) 13 | text_data = TextData() 14 | text_data.ParseFromString(buf) 15 | yield text_data 16 | 17 | 18 | def write_pb_stream(f, text_data): 19 | buf = text_data.SerializeToString() 20 | f.write(struct.pack("I", len(buf))) 21 | f.write(buf) 22 | 23 | 24 | def pack_pb_stream(text_data): 25 | buf = text_data.SerializeToString() 26 | return struct.pack("I", len(buf)) + buf 27 | 28 | 29 | def split_pb_stream(f): 30 | while True: 31 | head = f.read(4) 32 | if len(head) == 0: 33 | break 34 | size = struct.unpack("I", head)[0] 35 | buf = f.read(size) 36 | yield head + buf 37 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/i18n/README.md: 
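text_data_stream.py implements a simple framing on top of the protobuf messages defined above: each record is a 4-byte length prefix followed by one serialized TextData payload. A round-trip sketch, assuming the fish_speech package (and its generated text_data_pb2 module) is importable:

# Write one TextData record into an in-memory stream, then read it back.
import io

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import read_pb_stream, write_pb_stream

record = TextData(
    source="demo",
    name="example",
    sentences=[Sentence(texts=["hello world"], semantics=[Semantics(values=[1, 2, 3])])],
)

buf = io.BytesIO()
write_pb_stream(buf, record)  # 4-byte length prefix, then the serialized message
buf.seek(0)

for text_data in read_pb_stream(buf):  # yields TextData messages until EOF
    print(text_data.name, list(text_data.sentences[0].semantics[0].values))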
-------------------------------------------------------------------------------- 1 | ## i18n Folder Attribution 2 | 3 | The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: 4 | 5 | ### fish_speech/i18n/core.py 6 | 7 | **Related code from RVC:** 8 | [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) 9 | 10 | **Initial commit:** 11 | add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) 12 | 13 | **Initial author:** 14 | [@L4Ph](https://github.com/L4Ph) 15 | 16 | ### fish_speech/i18n/scan.py 17 | 18 | **Related code from RVC:** 19 | [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) 20 | 21 | **Initial commit:** 22 | File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) 23 | 24 | **Initial author:** 25 | [@towzeur](https://github.com/towzeur) 26 | 27 | We appreciate the contributions of the RVC project and its authors. 28 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/i18n/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import i18n 2 | 3 | __all__ = ["i18n"] 4 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/i18n/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import locale 3 | from pathlib import Path 4 | 5 | I18N_FILE_PATH = Path(__file__).parent / "locale" 6 | DEFAULT_LANGUAGE = "en_US" 7 | 8 | 9 | def load_language_list(language): 10 | with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: 11 | language_list = json.load(f) 12 | 13 | return language_list 14 | 15 | 16 | class I18nAuto: 17 | def __init__(self): 18 | i18n_file = Path(".locale") 19 | 20 | if i18n_file.exists(): 21 | with open(i18n_file, "r", encoding="utf-8") as f: 22 | language = f.read().strip() 23 | else: 24 | # getlocale can't identify the system's language ((None, None)) 25 | language = locale.getdefaultlocale()[0] 26 | 27 | if (I18N_FILE_PATH / f"{language}.json").exists() is False: 28 | language = DEFAULT_LANGUAGE 29 | 30 | self.language = language 31 | self.language_map = load_language_list(language) 32 | 33 | def __call__(self, key): 34 | return self.language_map.get(key, key) 35 | 36 | def __repr__(self): 37 | return "Use Language: " + self.language 38 | 39 | 40 | i18n = I18nAuto() 41 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/models/text2semantic/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/fish_speech/models/text2semantic/__init__.py -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/models/text2semantic/lora.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import loralib as lora 4 | 5 | 6 | @dataclass 7 | class LoraConfig: 8 | r: int 9 | lora_alpha: float 10 | lora_dropout: float = 0.0 11 | 12 | 13 | def setup_lora(model, lora_config): 14 | # Replace the embedding layer with a LoRA layer 15 | model.embeddings = lora.Embedding( 16 | num_embeddings=model.embeddings.num_embeddings, 17 | embedding_dim=model.embeddings.embedding_dim, 18 | padding_idx=model.embeddings.padding_idx, 19 | r=lora_config.r, 20 | lora_alpha=lora_config.lora_alpha, 21 | ) 22 | 23 | model.codebook_embeddings = lora.Embedding( 24 | num_embeddings=model.codebook_embeddings.num_embeddings, 25 | embedding_dim=model.codebook_embeddings.embedding_dim, 26 | padding_idx=model.codebook_embeddings.padding_idx, 27 | r=lora_config.r, 28 | lora_alpha=lora_config.lora_alpha, 29 | ) 30 | 31 | # Replace output layer with a LoRA layer 32 | linears = [(model, "output")] 33 | 34 | # Replace all linear layers with LoRA layers 35 | for layer in model.layers: 36 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) 37 | linears.extend( 38 | [ 39 | (layer.feed_forward, "w1"), 40 | (layer.feed_forward, "w2"), 41 | (layer.feed_forward, "w3"), 42 | ] 43 | ) 44 | 45 | if hasattr(model, "fast_layers"): 46 | model.fast_embeddings = lora.Embedding( 47 | num_embeddings=model.fast_embeddings.num_embeddings, 48 | embedding_dim=model.fast_embeddings.embedding_dim, 49 | padding_idx=model.fast_embeddings.padding_idx, 50 | r=lora_config.r, 51 | lora_alpha=lora_config.lora_alpha, 52 | ) 53 | 54 | # Dual-AR model 55 | linears.append((model, "fast_output")) 56 | 57 | for layer in model.fast_layers: 58 | linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) 59 | linears.extend( 60 | [ 61 | (layer.feed_forward, "w1"), 62 | (layer.feed_forward, "w2"), 63 | (layer.feed_forward, "w3"), 64 | ] 65 | ) 66 | 67 | for module, layer in linears: 68 | updated_linear = lora.Linear( 69 | in_features=getattr(module, layer).in_features, 70 | out_features=getattr(module, layer).out_features, 71 | bias=getattr(module, layer).bias, 72 | r=lora_config.r, 73 | lora_alpha=lora_config.lora_alpha, 74 | lora_dropout=lora_config.lora_dropout, 75 | ) 76 | setattr(module, layer, updated_linear) 77 | 78 | # Mark only the LoRA layers as trainable 79 | lora.mark_only_lora_as_trainable(model, bias="none") 80 | 81 | 82 | def get_merged_state_dict(model): 83 | # This line will merge the state dict of the model and the LoRA parameters 84 | model.eval() 85 | 86 | # Then we need to remove the LoRA parameters from the state dict 87 | state_dict = model.state_dict() 88 | for name in list(state_dict.keys()): 89 | if "lora" in name: 90 | state_dict.pop(name) 91 | 92 | return state_dict 93 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/models/vqgan/__init__.py: -------------------------------------------------------------------------------- 
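setup_lora above only relies on a handful of attributes of the text2semantic model (embeddings, codebook_embeddings, layers[*].attention / feed_forward, output). The toy sketch below uses a hypothetical DummyModel stand-in rather than the real BaseTransformer, assumes loralib is installed, and mirrors the r/alpha/dropout values of configs/lora/r_8_alpha_16.yaml; the bias-free linears match the llama-style layers setup_lora expects.

import torch
from torch import nn

from fish_speech.models.text2semantic.lora import (
    LoraConfig,
    get_merged_state_dict,
    setup_lora,
)


class DummyLayer(nn.Module):
    # Stand-in carrying only the attributes setup_lora touches.
    def __init__(self, dim: int):
        super().__init__()
        self.attention = nn.Module()
        self.attention.wqkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attention.wo = nn.Linear(dim, dim, bias=False)
        self.feed_forward = nn.Module()
        self.feed_forward.w1 = nn.Linear(dim, dim, bias=False)
        self.feed_forward.w2 = nn.Linear(dim, dim, bias=False)
        self.feed_forward.w3 = nn.Linear(dim, dim, bias=False)


class DummyModel(nn.Module):
    def __init__(self, vocab: int = 32, dim: int = 16):
        super().__init__()
        self.embeddings = nn.Embedding(vocab, dim, padding_idx=0)
        self.codebook_embeddings = nn.Embedding(vocab, dim, padding_idx=0)
        self.layers = nn.ModuleList(DummyLayer(dim) for _ in range(2))
        self.output = nn.Linear(dim, vocab, bias=False)


model = DummyModel()
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.01)  # mirrors r_8_alpha_16.yaml
setup_lora(model, config)  # swaps embeddings/linears for loralib layers, freezes the rest

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only lora_A / lora_B parameters remain trainable

merged = get_merged_state_dict(model)  # merged weights with the lora_* keys stripped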
https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/fish_speech/models/vqgan/__init__.py -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/models/vqgan/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import torch 3 | from matplotlib import pyplot as plt 4 | 5 | matplotlib.use("Agg") 6 | 7 | 8 | def convert_pad_shape(pad_shape): 9 | l = pad_shape[::-1] 10 | pad_shape = [item for sublist in l for item in sublist] 11 | return pad_shape 12 | 13 | 14 | def sequence_mask(length, max_length=None): 15 | if max_length is None: 16 | max_length = length.max() 17 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 18 | return x.unsqueeze(0) < length.unsqueeze(1) 19 | 20 | 21 | def init_weights(m, mean=0.0, std=0.01): 22 | classname = m.__class__.__name__ 23 | if classname.find("Conv") != -1: 24 | m.weight.data.normal_(mean, std) 25 | 26 | 27 | def get_padding(kernel_size, dilation=1): 28 | return int((kernel_size * dilation - dilation) / 2) 29 | 30 | 31 | def plot_mel(data, titles=None): 32 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 33 | 34 | if titles is None: 35 | titles = [None for i in range(len(data))] 36 | 37 | plt.tight_layout() 38 | 39 | for i in range(len(data)): 40 | mel = data[i] 41 | 42 | if isinstance(mel, torch.Tensor): 43 | mel = mel.float().detach().cpu().numpy() 44 | 45 | axes[i][0].imshow(mel, origin="lower") 46 | axes[i][0].set_aspect(2.5, adjustable="box") 47 | axes[i][0].set_ylim(0, mel.shape[0]) 48 | axes[i][0].set_title(titles[i], fontsize="medium") 49 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 50 | axes[i][0].set_anchor("W") 51 | 52 | return fig 53 | 54 | 55 | def slice_segments(x, ids_str, segment_size=4): 56 | ret = torch.zeros_like(x[:, :, :segment_size]) 57 | for i in range(x.size(0)): 58 | idx_str = ids_str[i] 59 | idx_end = idx_str + segment_size 60 | ret[i] = x[i, :, idx_str:idx_end] 61 | 62 | return ret 63 | 64 | 65 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 66 | b, d, t = x.size() 67 | if x_lengths is None: 68 | x_lengths = t 69 | ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) 70 | ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) 71 | ret = slice_segments(x, ids_str, segment_size) 72 | return ret, ids_str 73 | 74 | 75 | @torch.jit.script 76 | def fused_add_tanh_sigmoid_multiply(in_act, n_channels): 77 | n_channels_int = n_channels[0] 78 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 79 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 80 | acts = t_act * s_act 81 | 82 | return acts 83 | 84 | 85 | def avg_with_mask(x, mask): 86 | assert mask.dtype == torch.float, "Mask should be float" 87 | 88 | if mask.ndim == 2: 89 | mask = mask.unsqueeze(1) 90 | 91 | if mask.shape[1] == 1: 92 | mask = mask.expand_as(x) 93 | 94 | return (x * mask).sum() / mask.sum() 95 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def get_cosine_schedule_with_warmup_lr_lambda( 5 | current_step: int, 6 | *, 7 | num_warmup_steps: int | float, 8 | num_training_steps: int, 9 | num_cycles: float = 0.5, 10 | final_lr_ratio: float = 0.0, 
11 | ): 12 | if 0 < num_warmup_steps < 1: # float mode 13 | num_warmup_steps = int(num_warmup_steps * num_training_steps) 14 | 15 | if current_step < num_warmup_steps: 16 | return float(current_step) / float(max(1, num_warmup_steps)) 17 | 18 | progress = float(current_step - num_warmup_steps) / float( 19 | max(1, num_training_steps - num_warmup_steps) 20 | ) 21 | 22 | return max( 23 | final_lr_ratio, 24 | 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), 25 | ) 26 | 27 | 28 | def get_constant_schedule_with_warmup_lr_lambda( 29 | current_step: int, 30 | *, 31 | num_warmup_steps: int | float, 32 | num_training_steps: int | None = None, 33 | ): 34 | if 0 < num_warmup_steps < 1: # float mode 35 | num_warmup_steps = int(num_warmup_steps * num_training_steps) 36 | 37 | if current_step < num_warmup_steps: 38 | return float(current_step) / float(max(1, num_warmup_steps)) 39 | 40 | return 1.0 41 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .clean import clean_text 2 | from .spliter import split_text 3 | 4 | __all__ = ["clean_text", "split_text"] 5 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
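The warmup lambdas in scheduler.py above are meant to be partially applied and handed to torch.optim.lr_scheduler.LambdaLR, which is exactly the wiring the lr_scheduler block of text2semantic_finetune.yaml expresses through hydra's _partial_ targets. A minimal, self-contained sketch:

# Constant schedule with 10 warmup steps, driven by the repo's lambda.
from functools import partial

import torch

from fish_speech.scheduler import get_constant_schedule_with_warmup_lr_lambda

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4, betas=(0.9, 0.95), eps=1e-5, weight_decay=0)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=partial(get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=10),
)

for step in range(20):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # linear ramp to 1e-4 over 10 steps, then constant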
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # JetBrains PyCharm 107 | .idea 108 | 109 | # Customize 110 | references 111 | url.txt 112 | 113 | # Git 114 | .git 115 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/README.md: -------------------------------------------------------------------------------- 1 | # This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. 2 | 3 | # Chn Text Norm 4 | 5 | this is a repository for chinese text normalization (no longer maintained). 6 | 7 | ## Quick Start ## 8 | 9 | ### Git Clone Repo ### 10 | 11 | git clone this repo to the root directory of your project which need to use it. 12 | 13 | cd /path/to/proj 14 | git clone https://github.com/Joee1995/chn-text-norm.git 15 | 16 | after that, your doc tree should be: 17 | ``` 18 | proj # root of your project 19 | |--- chn_text_norm # this chn-text-norm tool 20 | |--- text.py 21 | |--- ... 22 | |--- text_normalize.py # your text normalization code 23 | |--- ... 24 | ``` 25 | 26 | ### How to Use ? 
### 27 | 28 | # text_normalize.py 29 | from chn_text_norm.text import * 30 | 31 | raw_text = 'your raw text' 32 | text = Text(raw_text=raw_text).normalize() 33 | 34 | ### How to add quantums ### 35 | 36 | 打开test.py,然后你就知道怎么做了。 37 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/fish_speech/text/chn_text_norm/__init__.py -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/basic_constant.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """基本常量 3 | 中文数字/数位/符号字符常量 4 | """ 5 | 6 | __author__ = "Zhiyang Zhou " 7 | __data__ = "2019-05-02" 8 | 9 | CHINESE_DIGIS = "零一二三四五六七八九" 10 | BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" 11 | BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" 12 | SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" 13 | SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" 14 | LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" 15 | LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" 16 | SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" 17 | SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" 18 | 19 | ZERO_ALT = "〇" 20 | ONE_ALT = "幺" 21 | TWO_ALTS = ["两", "兩"] 22 | 23 | POSITIVE = ["正", "正"] 24 | NEGATIVE = ["负", "負"] 25 | POINT = ["点", "點"] 26 | # PLUS = [u'加', u'加'] 27 | # SIL = [u'杠', u'槓'] 28 | 29 | # 中文数字系统类型 30 | NUMBERING_TYPES = ["low", "mid", "high"] 31 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/cardinal.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """CARDINAL类 (包含小数DECIMAL类) 3 | 纯数 <=> 中文字符串 方法 4 | 中文字符串 <=> 纯数 方法 5 | """ 6 | 7 | __author__ = "Zhiyang Zhou " 8 | __data__ = "2019-05-03" 9 | 10 | from fish_speech.text.chn_text_norm.basic_util import * 11 | 12 | 13 | class Cardinal: 14 | """ 15 | CARDINAL类 16 | """ 17 | 18 | def __init__(self, cardinal=None, chntext=None): 19 | self.cardinal = cardinal 20 | self.chntext = chntext 21 | 22 | def chntext2cardinal(self): 23 | return chn2num(self.chntext) 24 | 25 | def cardinal2chntext(self): 26 | return num2chn(self.cardinal) 27 | 28 | 29 | if __name__ == "__main__": 30 | 31 | # 测试程序 32 | print(Cardinal(cardinal="21357.230").cardinal2chntext()) 33 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/date.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """DATE类 3 | 日期 <=> 中文字符串 方法 4 | 中文字符串 <=> 日期 方法 5 | """ 6 | 7 | __author__ = "Zhiyang Zhou " 8 | __data__ = "2019-05-07" 9 | 10 | from fish_speech.text.chn_text_norm.cardinal import Cardinal 11 | from fish_speech.text.chn_text_norm.digit import Digit 12 | 13 | 14 | class Date: 15 | """ 16 | DATE类 17 | """ 18 | 19 | def __init__(self, date=None, chntext=None): 20 | self.date = date 21 | self.chntext = chntext 22 | 23 | # def chntext2date(self): 24 | # chntext = self.chntext 25 | # try: 26 | # year, other = chntext.strip().split('年', maxsplit=1) 27 | # year = 
Digit(chntext=year).digit2chntext() + '年' 28 | # except ValueError: 29 | # other = chntext 30 | # year = '' 31 | # if other: 32 | # try: 33 | # month, day = other.strip().split('月', maxsplit=1) 34 | # month = Cardinal(chntext=month).chntext2cardinal() + '月' 35 | # except ValueError: 36 | # day = chntext 37 | # month = '' 38 | # if day: 39 | # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] 40 | # else: 41 | # month = '' 42 | # day = '' 43 | # date = year + month + day 44 | # self.date = date 45 | # return self.date 46 | 47 | def date2chntext(self): 48 | date = self.date 49 | try: 50 | year, other = date.strip().split("年", maxsplit=1) 51 | year = Digit(digit=year).digit2chntext() + "年" 52 | except ValueError: 53 | other = date 54 | year = "" 55 | if other: 56 | try: 57 | month, day = other.strip().split("月", maxsplit=1) 58 | month = Cardinal(cardinal=month).cardinal2chntext() + "月" 59 | except ValueError: 60 | day = date 61 | month = "" 62 | if day: 63 | day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] 64 | else: 65 | month = "" 66 | day = "" 67 | chntext = year + month + day 68 | self.chntext = chntext 69 | return self.chntext 70 | 71 | 72 | if __name__ == "__main__": 73 | 74 | # 测试 75 | print(Date(date="09年3月16日").date2chntext()) 76 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/digit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """DIGIT类 3 | 数字串 <=> 中文字符串 方法 4 | 中文字符串 <=> 数字串 方法 5 | """ 6 | 7 | __author__ = "Zhiyang Zhou " 8 | __data__ = "2019-05-03" 9 | 10 | from fish_speech.text.chn_text_norm.basic_util import * 11 | 12 | 13 | class Digit: 14 | """ 15 | DIGIT类 16 | """ 17 | 18 | def __init__(self, digit=None, chntext=None): 19 | self.digit = digit 20 | self.chntext = chntext 21 | 22 | # def chntext2digit(self): 23 | # return chn2num(self.chntext) 24 | 25 | def digit2chntext(self): 26 | return num2chn(self.digit, alt_two=False, use_units=False) 27 | 28 | 29 | if __name__ == "__main__": 30 | 31 | # 测试程序 32 | print(Digit(digit="2016").digit2chntext()) 33 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/fraction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """FRACTION类 3 | 分数 <=> 中文字符串 方法 4 | 中文字符串 <=> 分数 方法 5 | """ 6 | 7 | __author__ = "Zhiyang Zhou " 8 | __data__ = "2019-05-03" 9 | 10 | from fish_speech.text.chn_text_norm.basic_util import * 11 | 12 | 13 | class Fraction: 14 | """ 15 | FRACTION类 16 | """ 17 | 18 | def __init__(self, fraction=None, chntext=None): 19 | self.fraction = fraction 20 | self.chntext = chntext 21 | 22 | def chntext2fraction(self): 23 | denominator, numerator = self.chntext.split("分之") 24 | return chn2num(numerator) + "/" + chn2num(denominator) 25 | 26 | def fraction2chntext(self): 27 | numerator, denominator = self.fraction.split("/") 28 | return num2chn(denominator) + "分之" + num2chn(numerator) 29 | 30 | 31 | if __name__ == "__main__": 32 | 33 | # 测试程序 34 | print(Fraction(fraction="2135/7230").fraction2chntext()) 35 | print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) 36 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/money.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MONEY类 3 | 金钱 <=> 中文字符串 方法 4 | 中文字符串 <=> 金钱 方法 5 | """ 6 | import re 7 | 8 | __author__ = "Zhiyang Zhou " 9 | __data__ = "2019-05-08" 10 | 11 | from fish_speech.text.chn_text_norm.cardinal import Cardinal 12 | 13 | 14 | class Money: 15 | """ 16 | MONEY类 17 | """ 18 | 19 | def __init__(self, money=None, chntext=None): 20 | self.money = money 21 | self.chntext = chntext 22 | 23 | # def chntext2money(self): 24 | # return self.money 25 | 26 | def money2chntext(self): 27 | money = self.money 28 | pattern = re.compile(r"(\d+(\.\d+)?)") 29 | matchers = pattern.findall(money) 30 | if matchers: 31 | for matcher in matchers: 32 | money = money.replace( 33 | matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() 34 | ) 35 | self.chntext = money 36 | return self.chntext 37 | 38 | 39 | if __name__ == "__main__": 40 | 41 | # 测试 42 | print(Money(money="21.5万元").money2chntext()) 43 | print(Money(money="230块5毛").money2chntext()) 44 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/percentage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """PERCENTAGE类 3 | 百分数 <=> 中文字符串 方法 4 | 中文字符串 <=> 百分数 方法 5 | """ 6 | 7 | __author__ = "Zhiyang Zhou " 8 | __data__ = "2019-05-06" 9 | 10 | from fish_speech.text.chn_text_norm.basic_util import * 11 | 12 | 13 | class Percentage: 14 | """ 15 | PERCENTAGE类 16 | """ 17 | 18 | def __init__(self, percentage=None, chntext=None): 19 | self.percentage = percentage 20 | self.chntext = chntext 21 | 22 | def chntext2percentage(self): 23 | return chn2num(self.chntext.strip().strip("百分之")) + "%" 24 | 25 | def percentage2chntext(self): 26 | return "百分之" + num2chn(self.percentage.strip().strip("%")) 27 | 28 | 29 | if __name__ == "__main__": 30 | 31 | # 测试程序 32 | print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) 33 | print(Percentage(percentage="65.3%").percentage2chntext()) 34 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/chn_text_norm/telephone.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """TELEPHONE类 3 | 电话号码 <=> 中文字符串 方法 4 | 中文字符串 <=> 电话号码 方法 5 | """ 6 | 7 | __author__ = "Zhiyang Zhou " 8 | __data__ = "2019-05-03" 9 | 10 | from fish_speech.text.chn_text_norm.basic_util import * 11 | 12 | 13 | class TelePhone: 14 | """ 15 | TELEPHONE类 16 | """ 17 | 18 | def __init__(self, telephone=None, raw_chntext=None, chntext=None): 19 | self.telephone = telephone 20 | self.raw_chntext = raw_chntext 21 | self.chntext = chntext 22 | 23 | # def chntext2telephone(self): 24 | # sil_parts = self.raw_chntext.split('') 25 | # self.telephone = '-'.join([ 26 | # str(chn2num(p)) for p in sil_parts 27 | # ]) 28 | # return self.telephone 29 | 30 | def telephone2chntext(self, fixed=False): 31 | 32 | if fixed: 33 | sil_parts = self.telephone.split("-") 34 | self.raw_chntext = "".join( 35 | [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] 36 | ) 37 | self.chntext = self.raw_chntext.replace("", "") 38 | else: 39 | sp_parts = self.telephone.strip("+").split() 40 | self.raw_chntext = "".join( 41 | [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] 42 | ) 43 | self.chntext = self.raw_chntext.replace("", "") 44 | return self.chntext 45 
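Taken together, the chn_text_norm helpers above each wrap one kind of token (cardinals, dates, fractions, money amounts, percentages, phone numbers) and render it as a Chinese reading. A quick tour that simply mirrors the files' own __main__ tests, assuming fish_speech is importable:

from fish_speech.text.chn_text_norm.cardinal import Cardinal
from fish_speech.text.chn_text_norm.date import Date
from fish_speech.text.chn_text_norm.fraction import Fraction
from fish_speech.text.chn_text_norm.money import Money
from fish_speech.text.chn_text_norm.percentage import Percentage
from fish_speech.text.chn_text_norm.telephone import TelePhone

print(Cardinal(cardinal="21357.230").cardinal2chntext())         # plain numbers
print(Date(date="09年3月16日").date2chntext())                    # dates
print(Fraction(fraction="2135/7230").fraction2chntext())         # fractions
print(Money(money="21.5万元").money2chntext())                    # money amounts
print(Percentage(percentage="65.3%").percentage2chntext())       # percentages
print(TelePhone(telephone="0595-23980880").telephone2chntext())  # phone numbers, digit by digit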
| 46 | 47 | if __name__ == "__main__": 48 | 49 | # 测试程序 50 | print(TelePhone(telephone="0595-23980880").telephone2chntext()) 51 | # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) 52 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/text/clean.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | SYMBOLS_MAPPING = { 4 | "\n": "", 5 | "…": ".", 6 | "“": "'", 7 | "”": "'", 8 | "‘": "'", 9 | "’": "'", 10 | "【": "", 11 | "】": "", 12 | "[": "", 13 | "]": "", 14 | "(": "", 15 | ")": "", 16 | "(": "", 17 | ")": "", 18 | "・": "", 19 | "·": "", 20 | "「": "'", 21 | "」": "'", 22 | "《": "'", 23 | "》": "'", 24 | "—": "", 25 | "~": "", 26 | "~": "", 27 | ":": ",", 28 | ";": ",", 29 | ";": ",", 30 | ":": ",", 31 | } 32 | 33 | REPLACE_SYMBOL_REGEX = re.compile( 34 | "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) 35 | ) 36 | 37 | 38 | EMOJI_REGEX = re.compile( 39 | "[" 40 | "\U0001F600-\U0001F64F" # emoticons 41 | "\U0001F300-\U0001F5FF" # symbols & pictographs 42 | "\U0001F680-\U0001F6FF" # transport & map symbols 43 | "\U0001F1E0-\U0001F1FF" # flags (iOS) 44 | "]+", 45 | flags=re.UNICODE, 46 | ) 47 | 48 | 49 | def clean_text(text): 50 | # Clean the text 51 | text = text.strip() 52 | 53 | # Replace all chinese symbols with their english counterparts 54 | text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) 55 | 56 | # Remove emojis 57 | text = EMOJI_REGEX.sub(r"", text) 58 | 59 | # Remove continuous periods (...) and commas (,,,) 60 | text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text) 61 | 62 | return text 63 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .braceexpand import braceexpand 2 | from .context import autocast_exclude_mps 3 | from .file import get_latest_checkpoint 4 | from .instantiators import instantiate_callbacks, instantiate_loggers 5 | from .logger import RankedLogger 6 | from .logging_utils import log_hyperparameters 7 | from .rich_utils import enforce_tags, print_config_tree 8 | from .utils import extras, get_metric_value, task_wrapper 9 | 10 | __all__ = [ 11 | "enforce_tags", 12 | "extras", 13 | "get_metric_value", 14 | "RankedLogger", 15 | "instantiate_callbacks", 16 | "instantiate_loggers", 17 | "log_hyperparameters", 18 | "print_config_tree", 19 | "task_wrapper", 20 | "braceexpand", 21 | "get_latest_checkpoint", 22 | "autocast_exclude_mps", 23 | ] 24 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/context.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | 3 | import torch 4 | 5 | 6 | def autocast_exclude_mps( 7 | device_type: str, dtype: torch.dtype 8 | ) -> nullcontext | torch.autocast: 9 | return ( 10 | nullcontext() 11 | if torch.backends.mps.is_available() 12 | else torch.autocast(device_type, dtype) 13 | ) 14 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | def get_latest_checkpoint(path: Path | str) -> Path | None: 6 | # Find the 
latest checkpoint 7 | ckpt_dir = Path(path) 8 | 9 | if ckpt_dir.exists() is False: 10 | return None 11 | 12 | ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) 13 | if len(ckpts) == 0: 14 | return None 15 | 16 | return ckpts[-1] 17 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pytorch_lightning import Callback 6 | from pytorch_lightning.loggers import Logger 7 | 8 | from .logger import RankedLogger 9 | 10 | log = RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config.""" 15 | 16 | callbacks: List[Callback] = [] 17 | 18 | if not callbacks_cfg: 19 | log.warning("No callback configs found! Skipping..") 20 | return callbacks 21 | 22 | if not isinstance(callbacks_cfg, DictConfig): 23 | raise TypeError("Callbacks config must be a DictConfig!") 24 | 25 | for _, cb_conf in callbacks_cfg.items(): 26 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 27 | log.info(f"Instantiating callback <{cb_conf._target_}>") 28 | callbacks.append(hydra.utils.instantiate(cb_conf)) 29 | 30 | return callbacks 31 | 32 | 33 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 34 | """Instantiates loggers from config.""" 35 | 36 | logger: List[Logger] = [] 37 | 38 | if not logger_cfg: 39 | log.warning("No logger configs found! Skipping...") 40 | return logger 41 | 42 | if not isinstance(logger_cfg, DictConfig): 43 | raise TypeError("Logger config must be a DictConfig!") 44 | 45 | for _, lg_conf in logger_cfg.items(): 46 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 47 | log.info(f"Instantiating logger <{lg_conf._target_}>") 48 | logger.append(hydra.utils.instantiate(lg_conf)) 49 | 50 | return logger 51 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = True, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log( 28 | self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs 29 | ) -> None: 30 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 31 | of the process it's being logged from. 
If `'rank'` is provided, then the log will only 32 | occur on that rank/process. 33 | 34 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 35 | :param msg: The message to log. 36 | :param rank: The rank to log at. 37 | :param args: Additional args to pass to the underlying logging function. 38 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 39 | """ 40 | if self.isEnabledFor(level): 41 | msg, kwargs = self.process(msg, kwargs) 42 | current_rank = getattr(rank_zero_only, "rank", None) 43 | if current_rank is None: 44 | raise RuntimeError( 45 | "The `rank_zero_only.rank` needs to be set before use" 46 | ) 47 | msg = rank_prefixed_message(msg, current_rank) 48 | if self.rank_zero_only: 49 | if current_rank == 0: 50 | self.logger.log(level, msg, *args, **kwargs) 51 | else: 52 | if rank is None: 53 | self.logger.log(level, msg, *args, **kwargs) 54 | elif current_rank == rank: 55 | self.logger.log(level, msg, *args, **kwargs) 56 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | 3 | from fish_speech.utils import logger as log 4 | 5 | 6 | @rank_zero_only 7 | def log_hyperparameters(object_dict: dict) -> None: 8 | """Controls which config parts are saved by lightning loggers. 9 | 10 | Additionally saves: 11 | - Number of model parameters 12 | """ 13 | 14 | hparams = {} 15 | 16 | cfg = object_dict["cfg"] 17 | model = object_dict["model"] 18 | trainer = object_dict["trainer"] 19 | 20 | if not trainer.logger: 21 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 22 | return 23 | 24 | hparams["model"] = cfg["model"] 25 | 26 | # save number of model parameters 27 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 28 | hparams["model/params/trainable"] = sum( 29 | p.numel() for p in model.parameters() if p.requires_grad 30 | ) 31 | hparams["model/params/non_trainable"] = sum( 32 | p.numel() for p in model.parameters() if not p.requires_grad 33 | ) 34 | 35 | hparams["data"] = cfg["data"] 36 | hparams["trainer"] = cfg["trainer"] 37 | 38 | hparams["callbacks"] = cfg.get("callbacks") 39 | hparams["extras"] = cfg.get("extras") 40 | 41 | hparams["task_name"] = cfg.get("task_name") 42 | hparams["tags"] = cfg.get("tags") 43 | hparams["ckpt_path"] = cfg.get("ckpt_path") 44 | hparams["seed"] = cfg.get("seed") 45 | 46 | # send hparams to all loggers 47 | for logger in trainer.loggers: 48 | logger.log_hyperparams(hparams) 49 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from fish_speech.utils import logger as log 13 | 14 | 15 | @rank_zero_only 16 | def print_config_tree( 17 | cfg: DictConfig, 18 | print_order: Sequence[str] = ( 19 | "data", 20 | "model", 21 | "callbacks", 22 | "logger", 23 | "trainer", 24 | "paths", 25 | "extras", 26 | ), 27 | resolve: bool = False, 28 | save_to_file: bool = False, 29 | ) -> None: 30 | """Prints content of DictConfig using Rich library and its tree structure. 31 | 32 | Args: 33 | cfg (DictConfig): Configuration composed by Hydra. 34 | print_order (Sequence[str], optional): Determines in what order config components are printed. 35 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 36 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 37 | """ # noqa: E501 38 | 39 | style = "dim" 40 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 41 | 42 | queue = [] 43 | 44 | # add fields from `print_order` to queue 45 | for field in print_order: 46 | ( 47 | queue.append(field) 48 | if field in cfg 49 | else log.warning( 50 | f"Field '{field}' not found in config. " 51 | + f"Skipping '{field}' config printing..." 
52 | ) 53 | ) 54 | 55 | # add all the other fields to queue (not specified in `print_order`) 56 | for field in cfg: 57 | if field not in queue: 58 | queue.append(field) 59 | 60 | # generate config tree from queue 61 | for field in queue: 62 | branch = tree.add(field, style=style, guide_style=style) 63 | 64 | config_group = cfg[field] 65 | if isinstance(config_group, DictConfig): 66 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 67 | else: 68 | branch_content = str(config_group) 69 | 70 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 71 | 72 | # print config tree 73 | rich.print(tree) 74 | 75 | # save config tree to file 76 | if save_to_file: 77 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 78 | rich.print(tree, file=file) 79 | 80 | 81 | @rank_zero_only 82 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 83 | """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 84 | 85 | if not cfg.get("tags"): 86 | if "id" in HydraConfig().cfg.hydra.job: 87 | raise ValueError("Specify tags before launching a multirun!") 88 | 89 | log.warning("No tags provided in config. Prompting user to input tags...") 90 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 91 | tags = [t.strip() for t in tags.split(",") if t != ""] 92 | 93 | with open_dict(cfg): 94 | cfg.tags = tags 95 | 96 | log.info(f"Tags: {cfg.tags}") 97 | 98 | if save_to_file: 99 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 100 | rich.print(cfg.tags, file=file) 101 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/webui/css/style.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --my-200: #80eeee; 3 | --my-50: #ecfdf5; 4 | --water-width: 300px; 5 | --water-heigh: 300px; 6 | } 7 | 8 | 9 | /* general styled components */ 10 | .tools { 11 | align-items: center; 12 | justify-content: center; 13 | } 14 | 15 | .gradio-button { 16 | max-width: 2.2em; 17 | min-width: 2.2em !important; 18 | height: 2.4em; 19 | align-self: end; 20 | line-height: 1em; 21 | border-radius: 0.5em; 22 | 23 | } 24 | 25 | .gradio-button.secondary-down, .gradio-button.secondary-down:hover{ 26 | box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; 27 | } 28 | 29 | /* replace original footer with ours */ 30 | a{ 31 | font-weight: bold; 32 | cursor: pointer; 33 | color: #030C14 !important; 34 | } 35 | 36 | footer { 37 | display: none !important; 38 | } 39 | 40 | #footer{ 41 | text-align: center; 42 | } 43 | 44 | #footer div{ 45 | display: inline-block; 46 | } 47 | 48 | #footer .versions{ 49 | font-size: 85%; 50 | opacity: 0.85; 51 | } 52 | 53 | /*@keyframes moveBackground {*/ 54 | /* 0% {*/ 55 | /* background-position: 0 0;*/ 56 | /* }*/ 57 | /* 100% {*/ 58 | /* background-position: -100px 100px;*/ 59 | /* }*/ 60 | /*}*/ 61 | @keyframes moveJellyBackground { 62 | 0% { 63 | background-position: 0% 50%; 64 | } 65 | 50% { 66 | background-position: 100% 50%; 67 | } 68 | 100% { 69 | background-position: 0% 50%; 70 | } 71 | } 72 | 73 | .gradio-container { 74 | position: absolute; 75 | z-index: 10; 76 | } 77 | 78 | 79 | .quan { 80 | position: absolute; 81 | bottom: 0; 82 | width: var(--water-width); 83 | height: var(--water-heigh); 84 | border-radius: 0; 85 | /*border: 3px solid rgb(246, 247, 248);*/ 86 | /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ 87 | z-index: 0; 88 | 
89 | } 90 | 91 | .quan:last-child { 92 | margin-right: 0; 93 | } 94 | 95 | .shui { 96 | position: absolute; 97 | top: 0; 98 | left: 0; 99 | width: 100%; 100 | height: 100%; 101 | background-color: rgb(23, 106, 201); 102 | border-radius: 0; 103 | overflow: hidden; 104 | z-index: 0; 105 | } 106 | 107 | .shui::after { 108 | 109 | content: ''; 110 | position: absolute; 111 | top: 20%; 112 | left: 50%; 113 | width: 150%; 114 | height: 150%; 115 | border-radius: 40%; 116 | background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); 117 | animation: shi 5s linear infinite; 118 | } 119 | 120 | @keyframes shi { 121 | 0% { 122 | transform: translate(-50%, -65%) rotate(0deg); 123 | } 124 | 100% { 125 | transform: translate(-50%, -65%) rotate(360deg); 126 | } 127 | } 128 | 129 | .shui::before { 130 | content: ''; 131 | position: absolute; 132 | top: 20%; 133 | left: 50%; 134 | width: 150%; 135 | height: 150%; 136 | border-radius: 42%; 137 | background-color: rgb(240, 228, 228, 0.2); 138 | animation: xu 7s linear infinite; 139 | } 140 | 141 | @keyframes xu { 142 | 0% { 143 | transform: translate(-50%, -60%) rotate(0deg); 144 | } 145 | 100% { 146 | transform: translate(-50%, -60%) rotate(360deg); 147 | } 148 | } 149 | 150 | fieldset.data_src div.wrap label { 151 | background: #f8bffee0 !important; 152 | } 153 | 154 | .scrollable-component { 155 | max-height: 100px; 156 | overflow-y: auto; 157 | } 158 | 159 | #file_accordion { 160 | max-height: 220px !important; 161 | } 162 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/webui/html/footer.html: -------------------------------------------------------------------------------- 1 |
2 | API 3 |  •  4 | Github 5 |  •  6 | Gradio 7 |
8 |
9 |
10 | {versions} 11 |
12 | -------------------------------------------------------------------------------- /external_modules/fish-speech/fish_speech/webui/js/animate.js: -------------------------------------------------------------------------------- 1 | 2 | function createGradioAnimation() { 3 | const params = new URLSearchParams(window.location.search); 4 | if (!params.has('__theme')) { 5 | params.set('__theme', 'light'); 6 | window.location.search = params.toString(); 7 | } 8 | 9 | var gradioApp = document.querySelector('gradio-app'); 10 | if (gradioApp) { 11 | 12 | document.documentElement.style.setProperty('--my-200', '#80eeee'); 13 | document.documentElement.style.setProperty('--my-50', '#ecfdf5'); 14 | 15 | // gradioApp.style.position = 'relative'; 16 | // gradioApp.style.backgroundSize = '200% 200%'; 17 | // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; 18 | // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; 19 | // gradioApp.style.display = 'flex'; 20 | // gradioApp.style.justifyContent = 'flex-start'; 21 | // gradioApp.style.flexWrap = 'nowrap'; 22 | // gradioApp.style.overflowX = 'auto'; 23 | 24 | // for (let i = 0; i < 6; i++) { 25 | // var quan = document.createElement('div'); 26 | // quan.className = 'quan'; 27 | // gradioApp.insertBefore(quan, gradioApp.firstChild); 28 | // quan.id = 'quan' + i.toString(); 29 | // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; 30 | // var quanContainer = document.querySelector('.quan'); 31 | // if (quanContainer) { 32 | // var shui = document.createElement('div'); 33 | // shui.className = 'shui'; 34 | // quanContainer.insertBefore(shui, quanContainer.firstChild) 35 | // } 36 | // } 37 | } 38 | 39 | var container = document.createElement('div'); 40 | container.id = 'gradio-animation'; 41 | container.style.fontSize = '2em'; 42 | container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; 43 | container.style.fontWeight = 'bold'; 44 | container.style.textAlign = 'center'; 45 | container.style.marginBottom = '20px'; 46 | 47 | var text = 'Welcome to Fish-Speech!'; 48 | for (var i = 0; i < text.length; i++) { 49 | (function(i){ 50 | setTimeout(function(){ 51 | var letter = document.createElement('span'); 52 | letter.style.opacity = '0'; 53 | letter.style.transition = 'opacity 0.5s'; 54 | letter.innerText = text[i]; 55 | 56 | container.appendChild(letter); 57 | 58 | setTimeout(function() { 59 | letter.style.opacity = '1'; 60 | }, 50); 61 | }, i * 200); 62 | })(i); 63 | } 64 | 65 | var gradioContainer = document.querySelector('.gradio-container'); 66 | gradioContainer.insertBefore(container, gradioContainer.firstChild); 67 | 68 | return 'Animation created'; 69 | } 70 | -------------------------------------------------------------------------------- /external_modules/fish-speech/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Fish Speech 2 | site_description: Targeting SOTA TTS solutions. 
3 | site_url: https://speech.fish.audio 4 | 5 | # Repository 6 | repo_name: fishaudio/fish-speech 7 | repo_url: https://github.com/fishaudio/fish-speech 8 | edit_uri: blob/main/docs 9 | 10 | # Copyright 11 | copyright: Copyright © 2023-2024 by Fish Audio 12 | 13 | theme: 14 | name: material 15 | language: en 16 | features: 17 | - content.action.edit 18 | - content.action.view 19 | - navigation.tracking 20 | - navigation.footer 21 | # - navigation.tabs 22 | - search 23 | - search.suggest 24 | - search.highlight 25 | - search.share 26 | - content.code.copy 27 | icon: 28 | logo: fontawesome/solid/fish 29 | 30 | palette: 31 | # Palette toggle for automatic mode 32 | - media: "(prefers-color-scheme)" 33 | toggle: 34 | icon: material/brightness-auto 35 | name: Switch to light mode 36 | 37 | # Palette toggle for light mode 38 | - media: "(prefers-color-scheme: light)" 39 | scheme: default 40 | toggle: 41 | icon: material/brightness-7 42 | name: Switch to dark mode 43 | primary: black 44 | font: 45 | code: Roboto Mono 46 | 47 | # Palette toggle for dark mode 48 | - media: "(prefers-color-scheme: dark)" 49 | scheme: slate 50 | toggle: 51 | icon: material/brightness-4 52 | name: Switch to light mode 53 | primary: black 54 | font: 55 | code: Roboto Mono 56 | 57 | # Plugins 58 | plugins: 59 | - search: 60 | separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' 61 | lang: 62 | - en 63 | - zh 64 | - ja 65 | - pt 66 | - i18n: 67 | docs_structure: folder 68 | languages: 69 | - locale: en 70 | name: English 71 | default: true 72 | build: true 73 | - locale: zh 74 | name: 简体中文 75 | build: true 76 | - locale: ja 77 | name: 日本語 78 | build: true 79 | - locale: pt 80 | name: Português (Brasil) 81 | build: true 82 | 83 | markdown_extensions: 84 | - pymdownx.highlight: 85 | anchor_linenums: true 86 | line_spans: __span 87 | pygments_lang_class: true 88 | - pymdownx.inlinehilite 89 | - pymdownx.snippets 90 | - pymdownx.superfences 91 | - admonition 92 | - pymdownx.details 93 | - pymdownx.superfences 94 | - attr_list 95 | - md_in_html 96 | - pymdownx.superfences 97 | 98 | extra_css: 99 | - stylesheets/extra.css 100 | 101 | extra: 102 | social: 103 | - icon: fontawesome/brands/discord 104 | link: https://discord.gg/Es5qTB9BcN 105 | - icon: fontawesome/brands/docker 106 | link: https://hub.docker.com/r/fishaudio/fish-speech 107 | - icon: fontawesome/brands/qq 108 | link: http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=jCKlUP7QgSm9kh95UlBoYv6s1I-Apl1M&authKey=xI5ttVAp3do68IpEYEalwXSYZFdfxZSkah%2BctF5FIMyN2NqAa003vFtLqJyAVRfF&noverify=0&group_code=593946093 109 | homepage: https://speech.fish.audio 110 | -------------------------------------------------------------------------------- /external_modules/fish-speech/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fish-speech" 3 | version = "0.1.0" 4 | authors = [ 5 | {name = "Lengyue", email = "lengyue@lengyue.me"}, 6 | ] 7 | description = "Fish Speech" 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | keywords = ["TTS", "Speech"] 11 | license = {text = "CC BY-NC-SA 4.0"} 12 | classifiers = [ 13 | "Programming Language :: Python :: 3", 14 | ] 15 | dependencies = [ 16 | "numpy<=1.26.4", 17 | "transformers>=4.35.2", 18 | "datasets==2.18.0", 19 | "lightning>=2.1.0", 20 | "hydra-core>=1.3.2", 21 | "tensorboard>=2.14.1", 22 | "natsort>=8.4.0", 23 | "einops>=0.7.0", 24 | "librosa>=0.10.1", 25 | "rich>=13.5.3", 26 | "gradio>=4.0.0", 27 | "wandb>=0.15.11", 28 | "grpcio>=1.58.0", 29 | 
"kui>=1.6.0", 30 | "uvicorn>=0.30.0", 31 | "loguru>=0.6.0", 32 | "loralib>=0.1.2", 33 | "natsort>=8.4.0", 34 | "pyrootutils>=1.0.4", 35 | "vector_quantize_pytorch>=1.14.24", 36 | "resampy>=0.4.3", 37 | "einx[torch]==0.2.2", 38 | "zstandard>=0.22.0", 39 | "pydub", 40 | "faster_whisper", 41 | "modelscope==1.17.1", 42 | "funasr==1.1.5", 43 | "opencc-python-reimplemented==0.1.7", 44 | "silero-vad", 45 | "ormsgpack", 46 | ] 47 | 48 | [project.optional-dependencies] 49 | stable = [ 50 | "torch>=2.3.1", 51 | "torchaudio", 52 | ] 53 | 54 | [build-system] 55 | requires = ["setuptools", "setuptools-scm"] 56 | build-backend = "setuptools.build_meta" 57 | 58 | [tool.setuptools] 59 | packages = ["fish_speech", "tools"] 60 | -------------------------------------------------------------------------------- /external_modules/fish-speech/pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "data", 4 | "filelists" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /external_modules/fish-speech/run_cmd.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 3 | 4 | set no_proxy="127.0.0.1, 0.0.0.0, localhost" 5 | setlocal enabledelayedexpansion 6 | 7 | cd /D "%~dp0" 8 | 9 | set PATH="%PATH%";%SystemRoot%\system32 10 | 11 | 12 | echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && ( 13 | echo. 14 | echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. && ( 15 | goto end 16 | ) 17 | ) 18 | 19 | 20 | set TMP=%CD%\fishenv 21 | set TEMP=%CD%\fishenv 22 | 23 | 24 | (call conda deactivate && call conda deactivate && call conda deactivate) 2>nul 25 | 26 | 27 | set CONDA_ROOT_PREFIX=%cd%\fishenv\conda 28 | set INSTALL_ENV_DIR=%cd%\fishenv\env 29 | 30 | 31 | set PYTHONNOUSERSITE=1 32 | set PYTHONPATH=%~dp0 33 | set PYTHONHOME= 34 | 35 | 36 | call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" 37 | 38 | if errorlevel 1 ( 39 | echo. 40 | echo Environment activation failed. 41 | goto end 42 | ) else ( 43 | echo. 44 | echo Environment activation succeeded. 45 | ) 46 | 47 | cmd /k "%*" 48 | 49 | :end 50 | pause 51 | -------------------------------------------------------------------------------- /external_modules/fish-speech/start.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 3 | 4 | set USE_MIRROR=true 5 | set PYTHONPATH=%~dp0 6 | set PYTHON_CMD=python 7 | if exist "fishenv" ( 8 | set PYTHON_CMD=%cd%\fishenv\env\python 9 | ) 10 | 11 | set API_FLAG_PATH=%~dp0API_FLAGS.txt 12 | set KMP_DUPLICATE_LIB_OK=TRUE 13 | 14 | setlocal enabledelayedexpansion 15 | 16 | set "HF_ENDPOINT=https://huggingface.co" 17 | set "no_proxy=" 18 | if "%USE_MIRROR%" == "true" ( 19 | set "HF_ENDPOINT=https://hf-mirror.com" 20 | set "no_proxy=localhost, 127.0.0.1, 0.0.0.0" 21 | ) 22 | echo "HF_ENDPOINT: !HF_ENDPOINT!" 23 | echo "NO_PROXY: !no_proxy!" 24 | 25 | echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~\u4E00-\u9FFF ] " >nul && ( 26 | echo. 27 | echo There are special characters in the current path, please make the path of fish-speech free of special characters before running. 
&& ( 28 | goto end 29 | ) 30 | ) 31 | 32 | %PYTHON_CMD% .\tools\download_models.py 33 | 34 | set "API_FLAGS=" 35 | set "flags=" 36 | 37 | if exist "%API_FLAG_PATH%" ( 38 | for /f "usebackq tokens=*" %%a in ("%API_FLAG_PATH%") do ( 39 | set "line=%%a" 40 | if not "!line:~0,1!"=="#" ( 41 | set "line=!line: =!" 42 | set "line=!line:\=!" 43 | set "line=!line:= !" 44 | if not "!line!"=="" ( 45 | set "API_FLAGS=!API_FLAGS!!line! " 46 | ) 47 | ) 48 | ) 49 | ) 50 | 51 | 52 | if not "!API_FLAGS!"=="" set "API_FLAGS=!API_FLAGS:~0,-1!" 53 | 54 | set "flags=" 55 | 56 | echo !API_FLAGS! | findstr /C:"--api" >nul 2>&1 57 | if !errorlevel! equ 0 ( 58 | echo. 59 | echo Start HTTP API... 60 | set "mode=api" 61 | goto process_flags 62 | ) 63 | 64 | echo !API_FLAGS! | findstr /C:"--infer" >nul 2>&1 65 | if !errorlevel! equ 0 ( 66 | echo. 67 | echo Start WebUI Inference... 68 | set "mode=infer" 69 | goto process_flags 70 | ) 71 | 72 | 73 | :process_flags 74 | for %%p in (!API_FLAGS!) do ( 75 | if not "%%p"=="--!mode!" ( 76 | set "flags=!flags! %%p" 77 | ) 78 | ) 79 | 80 | if not "!flags!"=="" set "flags=!flags:~1!" 81 | 82 | echo Debug: flags = !flags! 83 | 84 | if "!mode!"=="api" ( 85 | %PYTHON_CMD% -m tools.api !flags! 86 | ) else if "!mode!"=="infer" ( 87 | %PYTHON_CMD% -m tools.webui !flags! 88 | ) 89 | 90 | echo. 91 | echo Next launch the page... 92 | %PYTHON_CMD% fish_speech\webui\manage.py 93 | 94 | 95 | :end 96 | endlocal 97 | pause 98 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/commons.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Literal, Optional 2 | 3 | from pydantic import BaseModel, Field, conint 4 | 5 | 6 | class ServeReferenceAudio(BaseModel): 7 | audio: bytes 8 | text: str 9 | 10 | 11 | class ServeTTSRequest(BaseModel): 12 | text: str 13 | chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 14 | # Audio format 15 | format: Literal["wav", "pcm", "mp3"] = "wav" 16 | mp3_bitrate: Literal[64, 128, 192] = 128 17 | # References audios for in-context learning 18 | references: list[ServeReferenceAudio] = [] 19 | # Reference id 20 | # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ 21 | # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 22 | reference_id: str | None = None 23 | # Normalize text for en & zh, this increase stability for numbers 24 | normalize: bool = True 25 | mp3_bitrate: Optional[int] = 64 26 | opus_bitrate: Optional[int] = -1000 27 | # Balance mode will reduce latency to 300ms, but may decrease stability 28 | latency: Literal["normal", "balanced"] = "normal" 29 | # not usually used below 30 | streaming: bool = False 31 | emotion: Optional[str] = None 32 | max_new_tokens: int = 1024 33 | top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 34 | repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 35 | temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 36 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/download_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from huggingface_hub import hf_hub_download 4 | 5 | 6 | # Download 7 | def check_and_download_files(repo_id, file_list, local_dir): 8 | os.makedirs(local_dir, exist_ok=True) 9 | for file in file_list: 10 | file_path = os.path.join(local_dir, 
file) 11 | if not os.path.exists(file_path): 12 | print(f"{file} 不存在,从 Hugging Face 仓库下载...") 13 | hf_hub_download( 14 | repo_id=repo_id, 15 | filename=file, 16 | resume_download=True, 17 | local_dir=local_dir, 18 | local_dir_use_symlinks=False, 19 | ) 20 | else: 21 | print(f"{file} 已存在,跳过下载。") 22 | 23 | 24 | # 1st 25 | repo_id_1 = "fishaudio/fish-speech-1.4" 26 | local_dir_1 = "./checkpoints/fish-speech-1.4" 27 | files_1 = [ 28 | "model.pth", 29 | "README.md", 30 | "special_tokens_map.json", 31 | "tokenizer_config.json", 32 | "tokenizer.json", 33 | "config.json", 34 | "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", 35 | ] 36 | 37 | # 3rd 38 | repo_id_3 = "fishaudio/fish-speech-1" 39 | local_dir_3 = "./" 40 | files_3 = [ 41 | "ffmpeg.exe", 42 | "ffprobe.exe", 43 | ] 44 | 45 | # 4th 46 | repo_id_4 = "SpicyqSama007/fish-speech-packed" 47 | local_dir_4 = "./" 48 | files_4 = [ 49 | "asr-label-win-x64.exe", 50 | ] 51 | 52 | check_and_download_files(repo_id_1, files_1, local_dir_1) 53 | 54 | check_and_download_files(repo_id_3, files_3, local_dir_3) 55 | check_and_download_files(repo_id_4, files_4, local_dir_4) 56 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/extract_model.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | from loguru import logger 4 | 5 | 6 | @click.command() 7 | @click.argument("model_path") 8 | @click.argument("output_path") 9 | def main(model_path, output_path): 10 | if model_path == output_path: 11 | logger.error("Model path and output path are the same") 12 | return 13 | 14 | logger.info(f"Loading model from {model_path}") 15 | state_dict = torch.load(model_path, map_location="cpu")["state_dict"] 16 | torch.save(state_dict, output_path) 17 | logger.info(f"Model saved to {output_path}") 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/file.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from loguru import logger 6 | from natsort import natsorted 7 | 8 | AUDIO_EXTENSIONS = { 9 | ".mp3", 10 | ".wav", 11 | ".flac", 12 | ".ogg", 13 | ".m4a", 14 | ".wma", 15 | ".aac", 16 | ".aiff", 17 | ".aif", 18 | ".aifc", 19 | } 20 | 21 | VIDEO_EXTENSIONS = { 22 | ".mp4", 23 | ".avi", 24 | } 25 | 26 | 27 | def audio_to_bytes(file_path): 28 | if not file_path or not Path(file_path).exists(): 29 | return None 30 | with open(file_path, "rb") as wav_file: 31 | wav = wav_file.read() 32 | return wav 33 | 34 | 35 | def read_ref_text(ref_text): 36 | path = Path(ref_text) 37 | if path.exists() and path.is_file(): 38 | with path.open("r", encoding="utf-8") as file: 39 | return file.read() 40 | return ref_text 41 | 42 | 43 | def list_files( 44 | path: Union[Path, str], 45 | extensions: set[str] = None, 46 | recursive: bool = False, 47 | sort: bool = True, 48 | ) -> list[Path]: 49 | """List files in a directory. 50 | 51 | Args: 52 | path (Path): Path to the directory. 53 | extensions (set, optional): Extensions to filter. Defaults to None. 54 | recursive (bool, optional): Whether to search recursively. Defaults to False. 55 | sort (bool, optional): Whether to sort the files. Defaults to True. 56 | 57 | Returns: 58 | list: List of files. 
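    Example (illustrative sketch; the directory and file names below are hypothetical):
        >>> list_files("dataset/audio", {".wav"}, recursive=True)
        [PosixPath('dataset/audio/a.wav'), PosixPath('dataset/audio/b.wav')]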
59 | """ 60 | 61 | if isinstance(path, str): 62 | path = Path(path) 63 | 64 | if not path.exists(): 65 | raise FileNotFoundError(f"Directory {path} does not exist.") 66 | 67 | files = [file for ext in extensions for file in path.rglob(f"*{ext}")] 68 | 69 | if sort: 70 | files = natsorted(files) 71 | 72 | return files 73 | 74 | 75 | def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: 76 | """ 77 | Load a Bert-VITS2 style filelist. 78 | """ 79 | 80 | files = set() 81 | results = [] 82 | count_duplicated, count_not_found = 0, 0 83 | 84 | LANGUAGE_TO_LANGUAGES = { 85 | "zh": ["zh", "en"], 86 | "jp": ["jp", "en"], 87 | "en": ["en"], 88 | } 89 | 90 | with open(path, "r", encoding="utf-8") as f: 91 | for line in f.readlines(): 92 | splits = line.strip().split("|", maxsplit=3) 93 | if len(splits) != 4: 94 | logger.warning(f"Invalid line: {line}") 95 | continue 96 | 97 | filename, speaker, language, text = splits 98 | file = Path(filename) 99 | language = language.strip().lower() 100 | 101 | if language == "ja": 102 | language = "jp" 103 | 104 | assert language in ["zh", "jp", "en"], f"Invalid language {language}" 105 | languages = LANGUAGE_TO_LANGUAGES[language] 106 | 107 | if file in files: 108 | logger.warning(f"Duplicated file: {file}") 109 | count_duplicated += 1 110 | continue 111 | 112 | if not file.exists(): 113 | logger.warning(f"File not found: {file}") 114 | count_not_found += 1 115 | continue 116 | 117 | results.append((file, speaker, languages, text)) 118 | 119 | if count_duplicated > 0: 120 | logger.warning(f"Total duplicated files: {count_duplicated}") 121 | 122 | if count_not_found > 0: 123 | logger.warning(f"Total files not found: {count_not_found}") 124 | 125 | return results 126 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/llama/merge_lora.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from copy import deepcopy 3 | from pathlib import Path 4 | 5 | import click 6 | import hydra 7 | import torch 8 | from hydra import compose, initialize 9 | from hydra.utils import instantiate 10 | from loguru import logger 11 | 12 | from fish_speech.models.text2semantic.llama import BaseTransformer 13 | from fish_speech.models.text2semantic.lora import get_merged_state_dict 14 | 15 | 16 | @click.command() 17 | @click.option("--lora-config", type=str, default="r_8_alpha_16") 18 | @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4") 19 | @click.option("--lora-weight", type=str, required=True) 20 | @click.option("--output", type=str, required=True) 21 | def merge(lora_config, base_weight, lora_weight, output): 22 | output = Path(output) 23 | logger.info( 24 | f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}" 25 | ) 26 | 27 | with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): 28 | cfg = compose(config_name=lora_config) 29 | 30 | lora_config = instantiate(cfg) 31 | logger.info(f"Loaded lora model with config {lora_config}") 32 | 33 | llama_model = BaseTransformer.from_pretrained( 34 | path=base_weight, 35 | load_weights=True, 36 | lora_config=lora_config, 37 | ) 38 | logger.info(f"Loaded llama model") 39 | 40 | llama_state_dict = llama_model.state_dict() 41 | llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k} 42 | llama_state_dict_copy = deepcopy(llama_state_dict) 43 | lora_state_dict = torch.load(lora_weight, map_location="cpu") 44 | 45 
| if "state_dict" in llama_state_dict: 46 | llama_state_dict = llama_state_dict["state_dict"] 47 | 48 | if "state_dict" in lora_state_dict: 49 | lora_state_dict = lora_state_dict["state_dict"] 50 | 51 | # remove prefix model. 52 | if any(k.startswith("model.") for k in llama_state_dict.keys()): 53 | llama_state_dict = { 54 | k.replace("model.", ""): v 55 | for k, v in llama_state_dict.items() 56 | if k.startswith("model.") 57 | } 58 | if any(k.startswith("model.") for k in lora_state_dict.keys()): 59 | lora_state_dict = { 60 | k.replace("model.", ""): v 61 | for k, v in lora_state_dict.items() 62 | if k.startswith("model.") 63 | } 64 | 65 | logger.info(f"Found {len(llama_state_dict)} keys in llama model") 66 | logger.info(f"Found {len(lora_state_dict)} keys in lora model") 67 | 68 | merged_state_dict = llama_state_dict | lora_state_dict 69 | llama_model.load_state_dict(merged_state_dict, strict=True) 70 | logger.info(f"Merged model loaded") 71 | 72 | # Trigger eval mode to merge lora 73 | llama_model.eval() 74 | llama_model.save_pretrained(output, drop_lora=True) 75 | logger.info(f"Saved merged model to {output}, validating") 76 | 77 | new_state_dict = torch.load(output / "model.pth", map_location="cpu") 78 | original_keys = set(llama_state_dict_copy.keys()) 79 | merged_keys = set(new_state_dict.keys()) 80 | 81 | assert original_keys == merged_keys, "Keys should be same" 82 | 83 | for key in original_keys: 84 | diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() 85 | if diff_l1 != 0: 86 | break 87 | else: 88 | logger.error("Merged model is same as the original model") 89 | exit(1) 90 | 91 | logger.info("Merged model is different from the original model, check passed") 92 | 93 | 94 | if __name__ == "__main__": 95 | merge() 96 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/llama/rebuild_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers 2 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 3 | 4 | # Initialize a tokenizer 5 | tokenizer = Tokenizer(models.BPE()) 6 | 7 | # Customize pre-tokenization and decoding 8 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) 9 | tokenizer.decoder = decoders.ByteLevel() 10 | tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) 11 | 12 | # Don't train the tokenizer 13 | trainer = trainers.BpeTrainer( 14 | vocab_size=0, 15 | min_frequency=2, 16 | initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), 17 | special_tokens=[ 18 | "<|begin_of_sequence|>", 19 | "<|end_of_sequence|>", 20 | "<|im_start|>", 21 | "<|im_sep|>", # system, user, assistant, etc. 22 | "<|im_end|>", 23 | "<|semantic|>", # audio features 24 | "<|pad|>", 25 | ], 26 | ) 27 | 28 | # <|im_start|>user<|im_sep|>...<|im_end|> 29 | # <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|> 30 | tokenizer.train_from_iterator([], trainer=trainer) 31 | 32 | print(len(tokenizer.get_vocab())) 33 | x = tokenizer.encode( 34 | "Hello, how are you? 
dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>" 35 | ).ids 36 | print(x, len(x)) 37 | print(tokenizer.decode(x, skip_special_tokens=True)) 38 | 39 | 40 | tokenizer = PreTrainedTokenizerFast( 41 | tokenizer_object=tokenizer, 42 | pad_token="<|pad|>", 43 | bos_token="<|begin_of_sequence|>", 44 | eos_token="<|end_of_sequence|>", 45 | ) 46 | 47 | # Try tokenizing a new sequence 48 | sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>" 49 | encoded = tokenizer(sequence).input_ids 50 | 51 | print("Test encoding....") 52 | print(f"\tSentence: {sequence}") 53 | print(f"\tEncoded: {encoded}") 54 | print(f"\tDecoded: {tokenizer.batch_decode(encoded)}") 55 | print(f"\tDecoded: {tokenizer.decode(encoded)}") 56 | 57 | tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True) 58 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/msgpack_api.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import ormsgpack 3 | 4 | from tools.commons import ServeReferenceAudio, ServeTTSRequest 5 | 6 | 7 | def audio_request(): 8 | # priority: ref_id > references 9 | request = ServeTTSRequest( 10 | text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", 11 | # reference_id="114514", 12 | references=[ 13 | ServeReferenceAudio( 14 | audio=open("lengyue.wav", "rb").read(), 15 | text=open("lengyue.lab", "r", encoding="utf-8").read(), 16 | ) 17 | ], 18 | streaming=True, 19 | ) 20 | 21 | with ( 22 | httpx.Client() as client, 23 | open("hello.wav", "wb") as f, 24 | ): 25 | with client.stream( 26 | "POST", 27 | "http://127.0.0.1:8080/v1/tts", 28 | content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), 29 | headers={ 30 | "authorization": "Bearer YOUR_API_KEY", 31 | "content-type": "application/msgpack", 32 | }, 33 | timeout=None, 34 | ) as response: 35 | for chunk in response.iter_bytes(): 36 | f.write(chunk) 37 | 38 | 39 | def asr_request(): 40 | 41 | # Read the audio file 42 | with open( 43 | r"D:\PythonProject\fish-speech\.cache\test_audios\prompts\2648200402409733590.wav", 44 | "rb", 45 | ) as audio_file: 46 | audio_data = audio_file.read() 47 | 48 | # Prepare the request data 49 | request_data = { 50 | "audio": audio_data, 51 | "language": "en", # Optional: specify the language 52 | "ignore_timestamps": False, # Optional: set to True to ignore precise timestamps 53 | } 54 | 55 | # Send the request 56 | with httpx.Client() as client: 57 | response = client.post( 58 | "https://api.fish.audio/v1/asr", 59 | headers={ 60 | "Authorization": "Bearer 8eda4aeed2bc4aec9489b3efad003799", 61 | "Content-Type": "application/msgpack", 62 | }, 63 | content=ormsgpack.packb(request_data), 64 | ) 65 | 66 | # Parse the response 67 | result = response.json() 68 | 69 | print(f"Transcribed text: {result['text']}") 70 | print(f"Audio duration: {result['duration']} seconds") 71 | 72 | for segment in result["segments"]: 73 | print(f"Segment: {segment['text']}") 74 | print(f"Start time: {segment['start']}, End time: {segment['end']}") 75 | 76 | 77 | if __name__ == "__main__": 78 | asr_request() 79 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/sensevoice/README.md: -------------------------------------------------------------------------------- 1 | # FunASR Command Line Interface 2 | 3 
| This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files. 4 | 5 | ## Requirements 6 | 7 | - Python >= 3.10 8 | - PyTorch <= 2.3.1 9 | - ffmpeg, pydub, audio-separator[gpu]. 10 | 11 | ## Installation 12 | 13 | Install the required packages: 14 | 15 | ```bash 16 | pip install -e .[stable] 17 | ``` 18 | 19 | Make sure you have `ffmpeg` installed and available in your `PATH`. 20 | 21 | ## Usage 22 | 23 | ### Basic Usage 24 | 25 | To run the tool with default settings: 26 | 27 | ```bash 28 | python tools/sensevoice/fun_asr.py --audio-dir --save-dir 29 | ``` 30 | 31 | ## Options 32 | 33 | | Option | Description | 34 | | :-----------------------: | :---------------------------------------------------------------------------: | 35 | | --audio-dir | Directory containing audio or video files. | 36 | | --save-dir | Directory to save processed audio files. | 37 | | --device | Device to use for processing. Options: cuda (default) or cpu. | 38 | | --language | Language of the transcription. Default is auto. | 39 | | --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. | 40 | | --punc | Enable punctuation prediction. | 41 | | --denoise | Enable noise reduction (vocal separation). | 42 | 43 | ## Example 44 | 45 | To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled: 46 | 47 | ```bash 48 | python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise 49 | ``` 50 | 51 | ## Additional Notes 52 | 53 | - The tool supports `both audio and video files`. Videos will be converted to audio automatically. 54 | - If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks. 55 | - The script will automatically create necessary directories in the `--save-dir`. 56 | 57 | ## Troubleshooting 58 | 59 | If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency. 
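For example, to re-run a problematic batch on the CPU with punctuation enabled, an invocation like the following can be used (an illustrative combination of the options documented above; adjust the paths to your setup):

```bash
python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --device cpu --punc
```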
60 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/sensevoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/fish-speech/tools/sensevoice/__init__.py -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/sensevoice/vad_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | def slice_padding_fbank(speech, speech_lengths, vad_segments): 6 | speech_list = [] 7 | speech_lengths_list = [] 8 | for i, segment in enumerate(vad_segments): 9 | 10 | bed_idx = int(segment[0][0] * 16) 11 | end_idx = min(int(segment[0][1] * 16), speech_lengths[0]) 12 | speech_i = speech[0, bed_idx:end_idx] 13 | speech_lengths_i = end_idx - bed_idx 14 | speech_list.append(speech_i) 15 | speech_lengths_list.append(speech_lengths_i) 16 | feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0) 17 | speech_lengths_pad = torch.Tensor(speech_lengths_list).int() 18 | return feats_pad, speech_lengths_pad 19 | 20 | 21 | def slice_padding_audio_samples(speech, speech_lengths, vad_segments): 22 | speech_list = [] 23 | speech_lengths_list = [] 24 | intervals = [] 25 | for i, segment in enumerate(vad_segments): 26 | bed_idx = int(segment[0][0] * 16) 27 | end_idx = min(int(segment[0][1] * 16), speech_lengths) 28 | speech_i = speech[bed_idx:end_idx] 29 | speech_lengths_i = end_idx - bed_idx 30 | speech_list.append(speech_i) 31 | speech_lengths_list.append(speech_lengths_i) 32 | intervals.append([bed_idx // 16, end_idx // 16]) 33 | 34 | return speech_list, speech_lengths_list, intervals 35 | 36 | 37 | def merge_vad(vad_result, max_length=15000, min_length=0): 38 | new_result = [] 39 | if len(vad_result) <= 1: 40 | return vad_result 41 | time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result] 42 | time_step = sorted(list(set(time_step))) 43 | if len(time_step) == 0: 44 | return [] 45 | bg = 0 46 | for i in range(len(time_step) - 1): 47 | time = time_step[i] 48 | if time_step[i + 1] - bg < max_length: 49 | continue 50 | if time - bg > min_length: 51 | new_result.append([bg, time]) 52 | # if time - bg < max_length * 1.5: 53 | # new_result.append([bg, time]) 54 | # else: 55 | # split_num = int(time - bg) // max_length + 1 56 | # spl_l = int(time - bg) // split_num 57 | # for j in range(split_num): 58 | # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l]) 59 | bg = time 60 | new_result.append([bg, time_step[-1]]) 61 | return new_result 62 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/smart_pad.py: -------------------------------------------------------------------------------- 1 | import random 2 | from multiprocessing import Pool 3 | from pathlib import Path 4 | 5 | import click 6 | import librosa 7 | import torch.nn.functional as F 8 | import torchaudio 9 | from tqdm import tqdm 10 | 11 | from tools.file import AUDIO_EXTENSIONS, list_files 12 | 13 | threshold = 10 ** (-50 / 20.0) 14 | 15 | 16 | def process(file): 17 | waveform, sample_rate = torchaudio.load(str(file), backend="sox") 18 | if waveform.size(0) > 1: 19 | waveform = waveform.mean(dim=0, keepdim=True) 20 | 21 | loudness = librosa.feature.rms( 22 | 
y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True 23 | )[0] 24 | 25 | for i in range(len(loudness) - 1, 0, -1): 26 | if loudness[i] > threshold: 27 | break 28 | 29 | end_silent_time = (len(loudness) - i) * 512 / sample_rate 30 | 31 | if end_silent_time <= 0.3: 32 | random_time = random.uniform(0.3, 0.7) - end_silent_time 33 | waveform = F.pad( 34 | waveform, (0, int(random_time * sample_rate)), mode="constant", value=0 35 | ) 36 | 37 | for i in range(len(loudness)): 38 | if loudness[i] > threshold: 39 | break 40 | 41 | start_silent_time = i * 512 / sample_rate 42 | 43 | if start_silent_time > 0.02: 44 | waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :] 45 | 46 | torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate) 47 | 48 | 49 | @click.command() 50 | @click.argument("source", type=Path) 51 | @click.option("--num-workers", type=int, default=12) 52 | def main(source, num_workers): 53 | files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True)) 54 | 55 | with Pool(num_workers) as p: 56 | list(tqdm(p.imap_unordered(process, files), total=len(files))) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /external_modules/fish-speech/tools/vqgan/create_train_split.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | from random import Random 4 | 5 | import click 6 | from loguru import logger 7 | from pydub import AudioSegment 8 | from tqdm import tqdm 9 | 10 | from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist 11 | 12 | 13 | @click.command() 14 | @click.argument("root", type=click.Path(exists=True, path_type=Path)) 15 | @click.option("--val-ratio", type=float, default=None) 16 | @click.option("--val-count", type=int, default=None) 17 | @click.option("--filelist", default=None, type=Path) 18 | @click.option("--min-duration", default=None, type=float) 19 | @click.option("--max-duration", default=None, type=float) 20 | def main(root, val_ratio, val_count, filelist, min_duration, max_duration): 21 | if filelist: 22 | files = [i[0] for i in load_filelist(filelist)] 23 | else: 24 | files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) 25 | 26 | if min_duration is None and max_duration is None: 27 | filtered_files = list(map(str, [file.relative_to(root) for file in files])) 28 | else: 29 | filtered_files = [] 30 | for file in tqdm(files): 31 | try: 32 | audio = AudioSegment.from_file(str(file)) 33 | duration = len(audio) / 1000.0 34 | 35 | if min_duration is not None and duration < min_duration: 36 | logger.info( 37 | f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}" 38 | ) 39 | continue 40 | 41 | if max_duration is not None and duration > max_duration: 42 | logger.info( 43 | f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}" 44 | ) 45 | continue 46 | 47 | filtered_files.append(str(file.relative_to(root))) 48 | except Exception as e: 49 | logger.info(f"Error processing {file}: {e}") 50 | 51 | logger.info( 52 | f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering" 53 | ) 54 | 55 | Random(42).shuffle(filtered_files) 56 | 57 | if val_count is None and val_ratio is None: 58 | logger.info("Validation ratio and count not specified, using min(20%, 100)") 59 | val_size = min(100, math.ceil(len(filtered_files) * 0.2)) 60 | elif val_count is not None and val_ratio is not 
None: 61 | logger.error("Cannot specify both val_count and val_ratio") 62 | return 63 | elif val_count is not None: 64 | if val_count < 1 or val_count > len(filtered_files): 65 | logger.error("val_count must be between 1 and number of files") 66 | return 67 | val_size = val_count 68 | else: 69 | val_size = math.ceil(len(filtered_files) * val_ratio) 70 | 71 | logger.info(f"Using {val_size} files for validation") 72 | 73 | with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: 74 | f.write("\n".join(filtered_files[val_size:])) 75 | 76 | with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: 77 | f.write("\n".join(filtered_files[:val_size])) 78 | 79 | logger.info("Done") 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /external_modules/ultralight/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | dataset 3 | *.wav 4 | *.npy 5 | __pycache__ 6 | *.mp4 7 | checkpoint 8 | syncnet_checkpoint 9 | *.jpg 10 | data_utils/encoder.onnx 11 | __MACOSX 12 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/FeaturePipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import yaml 3 | import torch 4 | import torchaudio.compliance.kaldi as kaldi 5 | 6 | class Feature_Pipeline(): 7 | 8 | def __init__(self, engine_config): 9 | #with open(model_config_path, 'r') as fin: 10 | # self.configs = yaml.load(fin, Loader=yaml.FullLoader) 11 | self.configs = engine_config 12 | self.num_mel_bins = self.configs['data_conf']['fbank_conf']['num_mel_bins'] # 80 13 | self.frame_length = self.configs['data_conf']['fbank_conf']['frame_length'] # 25 14 | # self.frame_shift = 1 15 | self.frame_shift = self.configs['data_conf']['fbank_conf']['frame_shift'] # 10 16 | self.dither = self.configs['data_conf']['fbank_conf']['dither'] # 0.0 17 | self.sample_rate = self.configs['engine_sample_rate_hertz'] # 16000 18 | #self.feature_queue_ = torch.tensor([[0]]) 19 | self.first_wav_ = b'' 20 | self.remained_wav_ = b'' 21 | self._waveform = b'' #torch.tensor([[0]]) 22 | self.exist_endpoint = False 23 | 24 | def AcceptWaveform(self, audio): # audio: b'' 25 | first, self.remained_wav_, ExitEndpoint = self.vad.endpoint_detect(audio) 26 | self._waveform += first 27 | #if ExitEndpoint: 28 | #self._waveform = self._waveform + self.first_wav_ 29 | #else: 30 | # self._waveform 31 | #self.remained_wav_ = second_wav 32 | ''' 33 | feat, feat_length = self._extact_feature(waveform) 34 | if self.feature_queue_.shape[1]==1: 35 | self.feature_queue_ = feat 36 | else: 37 | self.feature_queue_ = torch.cat((self.feature_queue_, feat), 1) 38 | ''' 39 | self.exist_endpoint = ExitEndpoint 40 | #self.mutex.release() 41 | #print('待计算音频长度:', len(self._waveform)) 42 | return ExitEndpoint #, self.feature_queue_.shape[1] 43 | 44 | def _extract_feature(self, waveform_int16): 45 | #print(waveform_int16.shape) 46 | #assert max(waveform_int16.shape) > 512 47 | feat = kaldi.fbank(waveform_int16, 48 | num_mel_bins=self.num_mel_bins, 49 | frame_length=self.frame_length, 50 | frame_shift=self.frame_shift, 51 | dither=self.dither, 52 | energy_floor=0.0, 53 | sample_frequency=self.sample_rate 54 | ) 55 | feat = feat.unsqueeze(0) #.to(device) 56 | feat_length = torch.IntTensor([feat.size()[1]]) 57 | return feat, feat_length 58 | 59 | 60 | def Reset(self): 
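        # Clear any audio buffered from the previous utterance so the next
        # AcceptWaveform() call starts from a clean state.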
61 | self.remained_wav_ = b'' 62 | self._waveform = b'' 63 | #self.feature_queue_ = torch.tensor([[0]]) 64 | 65 | def get_waveform_len(self): 66 | return len(self._waveform) 67 | 68 | def ReadFeats(self): 69 | if len(self._waveform) < 512: 70 | return None 71 | waveform = np.frombuffer(self._waveform, dtype=np.int16) 72 | waveform = torch.from_numpy(waveform).float().unsqueeze(0) 73 | feat, feat_length = self._extract_feature(waveform) 74 | #self._waveform = b'' 75 | if self.exist_endpoint: 76 | self._waveform = self.remained_wav_ 77 | else: self._waveform = b'' 78 | return feat #, self.exist_endpoint 79 | 80 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/checkpoint_epoch_335.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/ultralight/data_utils/checkpoint_epoch_335.pth.tar -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/conf/READ.md: -------------------------------------------------------------------------------- 1 | V1:multicn数据集,unified conformer 2 | 3 | V2:wenetspeech ,conformer 自己训练的 4 | 5 | V3:wenetspeech ,conformer 开源 6 | 7 | V4:wenetspeech, unified conformer 自己训练的 8 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/conf/decode_engine.yaml: -------------------------------------------------------------------------------- 1 | cmvn_file: conf/unified_conformer/global_cmvn 2 | #dataset_conf: 3 | data_conf: 4 | batch_conf: 5 | batch_size: 1 6 | batch_type: static 7 | fbank_conf: 8 | dither: 0.0 9 | frame_length: 25 10 | frame_shift: 10 11 | num_mel_bins: 80 12 | filter_conf: 13 | max_length: 102400 #40960 14 | min_length: 0 15 | token_max_length: 102400 #200 16 | token_min_length: 0 #1 17 | resample_conf: 18 | resample_rate: 16000 19 | shuffle: False 20 | shuffle_conf: 21 | shuffle_size: 1500 22 | sort: False 23 | sort_conf: 24 | sort_size: 500 25 | spec_aug: False 26 | spec_aug_conf: 27 | max_f: 10 28 | max_t: 50 29 | num_f_mask: 2 30 | num_t_mask: 2 31 | speed_perturb: False 32 | decoder: transformer 33 | decoder_conf: 34 | attention_heads: 4 35 | dropout_rate: 0.1 36 | linear_units: 2048 37 | num_blocks: 6 38 | positional_dropout_rate: 0.1 39 | self_attention_dropout_rate: 0.0 40 | src_attention_dropout_rate: 0.0 41 | encoder: conformer 42 | encoder_conf: 43 | activation_type: swish 44 | attention_dropout_rate: 0.0 45 | attention_heads: 4 46 | causal: true 47 | cnn_module_kernel: 15 48 | cnn_module_norm: layer_norm 49 | dropout_rate: 0.1 50 | input_layer: conv2d 51 | linear_units: 2048 52 | normalize_before: true 53 | num_blocks: 12 54 | output_size: 256 55 | pos_enc_layer_type: rel_pos 56 | positional_dropout_rate: 0.1 57 | selfattention_layer_type: rel_selfattn 58 | use_cnn_module: true 59 | use_dynamic_chunk: true 60 | use_dynamic_left_chunk: false 61 | input_dim: 80 62 | is_json_cmvn: true 63 | log_interval: 100 64 | model_conf: 65 | ctc_weight: 0.5 66 | length_normalized_loss: false 67 | lsm_weight: 0.1 68 | #output_dim: 4233 69 | output_dim: 11008 70 | raw_wav: true 71 | 72 | reverse_weight: 0 73 | 74 | engine_sample_rate_hertz: 16000 75 | engine_max_decoders: 20 76 | engine_max_inactivity_secs: 3 77 | engine_vad_aggressiveness: 2 78 | 79 | #model_path: conf/unified_conformer_tel/19.pt 80 
| model_path: conf/unified_conformer/final.pt 81 | dict_path: conf/unified_conformer/lang_char.txt 82 | 83 | beam_size: 10 84 | mode: ctc_greedy_search 85 | decoding_chunk_size: 16 #11 86 | num_decoding_left_chunks: -1 87 | #override_config: 88 | #penalty: 89 | save_wave: True 90 | audio_save_path: /data/kzx/datasets/server_collected_wav 91 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/conf/decode_engine_V4.yaml: -------------------------------------------------------------------------------- 1 | accum_grad: 16 2 | cmvn_file: conf/wenetspeech_unified_conformer/global_cmvn 3 | #dataset_conf: 4 | data_conf: 5 | batch_conf: 6 | batch_size: 10 7 | batch_type: static 8 | fbank_conf: 9 | dither: 0 #1.0 10 | frame_length: 25 11 | frame_shift: 10 12 | num_mel_bins: 80 13 | filter_conf: 14 | max_length: 1200 15 | min_length: 10 16 | token_max_length: 100 17 | token_min_length: 1 18 | resample_conf: 19 | resample_rate: 16000 20 | shuffle: true 21 | shuffle_conf: 22 | shuffle_size: 1500 23 | sort: true 24 | sort_conf: 25 | sort_size: 1000 26 | spec_aug: true 27 | spec_aug_conf: 28 | max_f: 30 29 | max_t: 50 30 | num_f_mask: 2 31 | num_t_mask: 2 32 | speed_perturb: false 33 | decoder: transformer 34 | decoder_conf: 35 | attention_heads: 8 36 | dropout_rate: 0.1 37 | linear_units: 2048 38 | num_blocks: 6 39 | positional_dropout_rate: 0.1 40 | self_attention_dropout_rate: 0.0 41 | src_attention_dropout_rate: 0.0 42 | encoder: conformer 43 | encoder_conf: 44 | activation_type: swish 45 | attention_dropout_rate: 0.0 46 | attention_heads: 8 47 | causal: true 48 | cnn_module_kernel: 15 49 | cnn_module_norm: layer_norm 50 | dropout_rate: 0.1 51 | input_layer: conv2d 52 | linear_units: 2048 53 | normalize_before: true 54 | num_blocks: 12 55 | output_size: 512 56 | pos_enc_layer_type: rel_pos 57 | positional_dropout_rate: 0.1 58 | selfattention_layer_type: rel_selfattn 59 | use_cnn_module: true 60 | use_dynamic_chunk: true 61 | use_dynamic_left_chunk: false 62 | grad_clip: 5 63 | input_dim: 80 64 | is_json_cmvn: true 65 | log_interval: 100 66 | max_epoch: 80 67 | model_conf: 68 | ctc_weight: 0.3 69 | length_normalized_loss: false 70 | lsm_weight: 0.1 71 | optim: adam 72 | optim_conf: 73 | lr: 0.001 74 | output_dim: 5538 75 | scheduler: warmuplr 76 | scheduler_conf: 77 | warmup_steps: 5000 78 | 79 | reverse_weight: 0 80 | 81 | engine_sample_rate_hertz: 16000 82 | engine_max_decoders: 2 83 | engine_max_inactivity_secs: 3 84 | engine_vad_aggressiveness: 1 85 | 86 | model_path: conf/wenetspeech_unified_conformer/avg10.pt 87 | dict_path: conf/wenetspeech_unified_conformer/words.txt 88 | 89 | beam_size: 10 90 | mode: ctc_greedy_search 91 | decoding_chunk_size: 16 92 | num_decoding_left_chunks: 1 #-1 93 | save_wave: True 94 | audio_save_path: /data/kzx/datasets/server_collected_wav 95 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/conf/wenetspeech_unified_conformer/global_cmvn: -------------------------------------------------------------------------------- 1 | {"mean_stat": [94093639680.0, 99601817600.0, 105945178112.0, 109973512192.0, 113765081088.0, 116411883520.0, 118275350528.0, 119089381376.0, 120471224320.0, 120857878528.0, 122253713408.0, 122959314944.0, 124104908800.0, 124943925248.0, 125086113792.0, 125526827008.0, 125499613184.0, 125373145088.0, 126521786368.0, 125454884864.0, 124538224640.0, 126333992960.0, 125070983168.0, 126151393280.0, 125602439168.0, 
126395072512.0, 125509885952.0, 126410891264.0, 125762912256.0, 125595451392.0, 126002593792.0, 125733101568.0, 125873864704.0, 126414536704.0, 126228545536.0, 126687256576.0, 127307743232.0, 127302467584.0, 127395962880.0, 127164047360.0, 127362924544.0, 127042977792.0, 127896125440.0, 127576309760.0, 127937052672.0, 128511803392.0, 129259085824.0, 129270480896.0, 129519296512.0, 130338086912.0, 129834475520.0, 129682448384.0, 129816797184.0, 130217508864.0, 130445729792.0, 130589614080.0, 131225165824.0, 131323748352.0, 130621440000.0, 130219769856.0, 129976082432.0, 129137131520.0, 128767655936.0, 128216735744.0, 127710003200.0, 126975082496.0, 126253596672.0, 126021091328.0, 125921501184.0, 125329408000.0, 124847611904.0, 124577431552.0, 123957854208.0, 123273756672.0, 122255400960.0, 121058467840.0, 119279566848.0, 115862519808.0, 110746214400.0, 103604412416.0], "var_stat": [1208517656576.0, 1353929457664.0, 1517070712832.0, 1619039092736.0, 1723396784128.0, 1800768585728.0, 1858094366720.0, 1886653906944.0, 1926771113984.0, 1935567486976.0, 1975271030784.0, 1996454100992.0, 2030957756416.0, 2056672378880.0, 2060925927424.0, 2073670451200.0, 2071954718720.0, 2066790088704.0, 2101332672512.0, 2067731447808.0, 2040147476480.0, 2093775847424.0, 2054495010816.0, 2086421921792.0, 2068710555648.0, 2093016547328.0, 2064243752960.0, 2093339901952.0, 2073059655680.0, 2066486788096.0, 2079567642624.0, 2069787049984.0, 2074533953536.0, 2091364253696.0, 2084653105152.0, 2099241680896.0, 2119383777280.0, 2118917423104.0, 2121075523584.0, 2115089596416.0, 2120015020032.0, 2110308483072.0, 2137911721984.0, 2127944482816.0, 2139145502720.0, 2158215168000.0, 2184207400960.0, 2184671133696.0, 2193423466496.0, 2220860637184.0, 2204302049280.0, 2199531814912.0, 2204159180800.0, 2217718054912.0, 2226764906496.0, 2232647942144.0, 2253489700864.0, 2256605020160.0, 2233431228416.0, 2219367202816.0, 2211371286528.0, 2183647330304.0, 2171116978176.0, 2153920987136.0, 2137535152128.0, 2114695462912.0, 2092549144576.0, 2085652529152.0, 2083979395072.0, 2066295947264.0, 2052445962240.0, 2045596139520.0, 2027642421248.0, 2008141135872.0, 1977202114560.0, 1940280967168.0, 1887087362048.0, 1787704639488.0, 1642100162560.0, 1448020017152.0], "frame_num": 8038121679} -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/conf/wenetspeech_unified_conformer/train.yaml: -------------------------------------------------------------------------------- 1 | accum_grad: 16 2 | cmvn_file: exp/unified_conformer/global_cmvn 3 | dataset_conf: 4 | batch_conf: 5 | batch_size: 10 6 | batch_type: static 7 | fbank_conf: 8 | dither: 1.0 9 | frame_length: 25 10 | frame_shift: 10 11 | num_mel_bins: 80 12 | filter_conf: 13 | max_length: 1200 14 | min_length: 10 15 | token_max_length: 100 16 | token_min_length: 1 17 | resample_conf: 18 | resample_rate: 16000 19 | shuffle: true 20 | shuffle_conf: 21 | shuffle_size: 1500 22 | sort: true 23 | sort_conf: 24 | sort_size: 1000 25 | spec_aug: true 26 | spec_aug_conf: 27 | max_f: 30 28 | max_t: 50 29 | num_f_mask: 2 30 | num_t_mask: 2 31 | speed_perturb: false 32 | decoder: transformer 33 | decoder_conf: 34 | attention_heads: 8 35 | dropout_rate: 0.1 36 | linear_units: 2048 37 | num_blocks: 6 38 | positional_dropout_rate: 0.1 39 | self_attention_dropout_rate: 0.0 40 | src_attention_dropout_rate: 0.0 41 | encoder: conformer 42 | encoder_conf: 43 | activation_type: swish 44 | attention_dropout_rate: 0.0 45 | attention_heads: 8 46 | 
causal: true 47 | cnn_module_kernel: 15 48 | cnn_module_norm: layer_norm 49 | dropout_rate: 0.1 50 | input_layer: conv2d 51 | linear_units: 2048 52 | normalize_before: true 53 | num_blocks: 12 54 | output_size: 512 55 | pos_enc_layer_type: rel_pos 56 | positional_dropout_rate: 0.1 57 | selfattention_layer_type: rel_selfattn 58 | use_cnn_module: true 59 | use_dynamic_chunk: true 60 | use_dynamic_left_chunk: false 61 | grad_clip: 5 62 | input_dim: 80 63 | is_json_cmvn: true 64 | log_interval: 100 65 | max_epoch: 80 66 | model_conf: 67 | ctc_weight: 0.3 68 | length_normalized_loss: false 69 | lsm_weight: 0.1 70 | optim: adam 71 | optim_conf: 72 | lr: 0.001 73 | output_dim: 5538 74 | scheduler: warmuplr 75 | scheduler_conf: 76 | warmup_steps: 5000 77 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/decode_engine_V4.yaml: -------------------------------------------------------------------------------- 1 | accum_grad: 16 2 | cmvn_file: conf/wenetspeech_unified_conformer/global_cmvn 3 | #dataset_conf: 4 | data_conf: 5 | batch_conf: 6 | batch_size: 10 7 | batch_type: static 8 | fbank_conf: 9 | dither: 0 #1.0 10 | frame_length: 25 11 | frame_shift: 10 12 | num_mel_bins: 80 13 | filter_conf: 14 | max_length: 1200 15 | min_length: 10 16 | token_max_length: 100 17 | token_min_length: 1 18 | resample_conf: 19 | resample_rate: 16000 20 | shuffle: true 21 | shuffle_conf: 22 | shuffle_size: 1500 23 | sort: true 24 | sort_conf: 25 | sort_size: 1000 26 | spec_aug: true 27 | spec_aug_conf: 28 | max_f: 30 29 | max_t: 50 30 | num_f_mask: 2 31 | num_t_mask: 2 32 | speed_perturb: false 33 | decoder: transformer 34 | decoder_conf: 35 | attention_heads: 8 36 | dropout_rate: 0.1 37 | linear_units: 2048 38 | num_blocks: 6 39 | positional_dropout_rate: 0.1 40 | self_attention_dropout_rate: 0.0 41 | src_attention_dropout_rate: 0.0 42 | encoder: conformer 43 | encoder_conf: 44 | activation_type: swish 45 | attention_dropout_rate: 0.0 46 | attention_heads: 8 47 | causal: true 48 | cnn_module_kernel: 15 49 | cnn_module_norm: layer_norm 50 | dropout_rate: 0.1 51 | input_layer: conv2d 52 | linear_units: 2048 53 | normalize_before: true 54 | num_blocks: 12 55 | output_size: 512 56 | pos_enc_layer_type: rel_pos 57 | positional_dropout_rate: 0.1 58 | selfattention_layer_type: rel_selfattn 59 | use_cnn_module: true 60 | use_dynamic_chunk: true 61 | use_dynamic_left_chunk: false 62 | grad_clip: 5 63 | input_dim: 80 64 | is_json_cmvn: true 65 | log_interval: 100 66 | max_epoch: 80 67 | model_conf: 68 | ctc_weight: 0.3 69 | length_normalized_loss: false 70 | lsm_weight: 0.1 71 | optim: adam 72 | optim_conf: 73 | lr: 0.001 74 | output_dim: 5538 75 | scheduler: warmuplr 76 | scheduler_conf: 77 | warmup_steps: 5000 78 | 79 | reverse_weight: 0 80 | 81 | engine_sample_rate_hertz: 16000 82 | engine_max_decoders: 2 83 | engine_max_inactivity_secs: 3 84 | engine_vad_aggressiveness: 1 85 | 86 | model_path: conf/wenetspeech_unified_conformer/avg10.pt 87 | dict_path: conf/wenetspeech_unified_conformer/words.txt 88 | 89 | beam_size: 10 90 | mode: ctc_greedy_search 91 | decoding_chunk_size: 16 92 | num_decoding_left_chunks: 1 #-1 93 | save_wave: True 94 | audio_save_path: /data/kzx/datasets/server_collected_wav 95 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/mean_face.txt: -------------------------------------------------------------------------------- 1 | 0.07823661 0.22561455 0.07775262 
0.28360514 0.07767719 0.34125846 0.07962388 0.39897107 0.0852785 0.45675877 0.0948296 0.51397081 0.10821601 0.57026014 0.12654839 0.624922 0.15092454 0.67696214 0.18117697 0.72501614 0.21636663 0.76926954 0.25593645 0.80971635 0.29881339 0.84644004 0.34358275 0.88036131 0.39073567 0.9098104 0.44371907 0.92960952 0.50159897 0.93640387 0.55961423 0.92988122 0.61295041 0.91019805 0.66039048 0.88071415 0.70537286 0.84665546 0.74849443 0.80988414 0.78839783 0.7694265 0.82393673 0.72515046 0.85446429 0.67704921 0.87904849 0.62493643 0.89749134 0.57016489 0.91091031 0.51374694 0.92042848 0.45636828 0.92597962 0.39841539 0.9278124 0.34047619 0.92761708 0.28261676 0.9269087 0.22448013 0.18333619 0.14799449 0.23011323 0.10008151 0.29151757 0.09043833 0.35420108 0.09683785 0.41239033 0.11419665 0.40433308 0.15312312 0.34988543 0.14151342 0.29250138 0.13487604 0.23594365 0.13835329 0.81760045 0.14719518 0.77032645 0.09917382 0.70833021 0.08949374 0.64518572 0.09615324 0.58674485 0.11388598 0.5950631 0.15290701 0.64974596 0.14094804 0.70746059 0.13400588 0.76445238 0.13742621 0.49899366 0.44126486 0.49926835 0.36584839 0.49954545 0.29154656 0.4997629 0.21590981 0.44175713 0.23971188 0.41748147 0.40751618 0.38250881 0.46847228 0.41250806 0.50733037 0.44348765 0.49384309 0.49968962 0.52071968 0.55587692 0.49389505 0.58745985 0.50735921 0.61725701 0.46843178 0.58200092 0.40750888 0.5582543 0.23969454 0.24420926 0.24232664 0.27853367 0.22262796 0.32139181 0.21649367 0.36352965 0.22694906 0.39399409 0.2539187 0.35570468 0.2607508 0.31553149 0.26475557 0.27670887 0.25815946 0.32108413 0.2401999 0.32039199 0.24018672 0.28731028 0.24066018 0.35345833 0.23970116 0.75291827 0.24135213 0.71813078 0.22180503 0.67499031 0.21574654 0.63271 0.2262944 0.60232031 0.25348639 0.64072191 0.26024382 0.68106141 0.26409874 0.72013163 0.25732273 0.67941506 0.23945947 0.67611672 0.23945686 0.64289896 0.23914819 0.70935429 0.23975898 0.3542491 0.64762846 0.39788583 0.61425828 0.45907503 0.59267412 0.49816209 0.60023844 0.53725534 0.5924438 0.59894311 0.61360978 0.64317025 0.6471479 0.61072743 0.69408156 0.5629294 0.72856417 0.49846957 0.74046476 0.43417159 0.72868171 0.38664602 0.69431213 0.37477622 0.64829514 0.41275304 0.63913694 0.49825266 0.63758627 0.5843274 0.63877947 0.62280633 0.64783215 0.58603199 0.66379551 0.49850656 0.6787626 0.4114343 0.66389504 -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import numpy as np 5 | 6 | def extract_audio(path, out_path, sample_rate=16000): 7 | 8 | print(f'[INFO] ===== extract audio from {path} to {out_path} =====') 9 | cmd = f'ffmpeg -y -i {path} -f wav -ar {sample_rate} {out_path}' 10 | os.system(cmd) 11 | print(f'[INFO] ===== extracted audio =====') 12 | 13 | def extract_images(path, mode): 14 | 15 | base_name = os.path.basename(path) 16 | full_body_dir = os.path.join(os.path.dirname(path), "full_body_img") 17 | if not os.path.exists(full_body_dir): 18 | os.mkdir(full_body_dir) 19 | 20 | counter = 0 21 | cap = cv2.VideoCapture(path) 22 | fps = cap.get(cv2.CAP_PROP_FPS) 23 | if mode == "hubert" and fps != 25: 24 | raise ValueError("Using hubert,your video fps should be 25!!!") 25 | if mode == "wenet" and fps != 20: 26 | raise ValueError("Using wenet,your video fps should be 20!!!") 27 | 28 | print("extracting images...") 29 | while True: 30 | ret, frame = cap.read() 
31 | if not ret: 32 | break 33 | cv2.imwrite(full_body_dir+"/"+str(counter)+'.jpg', frame) 34 | counter += 1 35 | 36 | def get_audio_feature(wav_path, mode): 37 | 38 | print("extracting audio feature...") 39 | 40 | if mode == "wenet": 41 | os.system("python wenet_infer.py "+wav_path) 42 | if mode == "hubert": 43 | os.system("python hubert.py --wav "+wav_path) 44 | 45 | def get_landmark(path, landmarks_dir): 46 | print("detecting landmarks...") 47 | base_name = os.path.basename(path) 48 | full_img_dir = os.path.join(os.path.dirname(path), "full_body_img") 49 | 50 | from get_landmark import Landmark 51 | landmark = Landmark() 52 | 53 | for img_name in os.listdir(full_img_dir): 54 | if not img_name.endswith(".jpg"): 55 | continue 56 | img_path = os.path.join(full_img_dir, img_name) 57 | lms_path = os.path.join(landmarks_dir, img_name.replace(".jpg", ".lms")) 58 | pre_landmark, x1, y1 = landmark.detect(img_path) 59 | with open(lms_path, "w") as f: 60 | for p in pre_landmark: 61 | x, y = p[0]+x1, p[1]+y1 62 | f.write(str(x)) 63 | f.write(" ") 64 | f.write(str(y)) 65 | f.write("\n") 66 | 67 | if __name__ == "__main__": 68 | 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('path', type=str, help="path to video file") 71 | parser.add_argument('--asr', type=str, default='hubert', help="wenet or hubert") 72 | opt = parser.parse_args() 73 | asr_mode = opt.asr 74 | 75 | base_dir = os.path.dirname(opt.path) 76 | wav_path = os.path.join(base_dir, 'aud.wav') 77 | landmarks_dir = os.path.join(base_dir, 'landmarks') 78 | 79 | os.makedirs(landmarks_dir, exist_ok=True) 80 | 81 | extract_audio(opt.path, wav_path) 82 | extract_images(opt.path, asr_mode) 83 | get_landmark(opt.path, landmarks_dir) 84 | get_audio_feature(wav_path, asr_mode) 85 | 86 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/scrfd_2.5g_kps.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/ultralight/data_utils/scrfd_2.5g_kps.onnx -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. All Rights Reserved. 
2 | # Author: di.wu@mobvoi.com (DI WU) 3 | import os 4 | import argparse 5 | import glob 6 | 7 | import yaml 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser(description='average model') 14 | parser.add_argument('--dst_model', required=True, help='averaged model') 15 | parser.add_argument('--src_path', 16 | required=True, 17 | help='src model path for average') 18 | parser.add_argument('--val_best', 19 | action="store_true", 20 | help='averaged model') 21 | parser.add_argument('--num', 22 | default=5, 23 | type=int, 24 | help='nums for averaged model') 25 | parser.add_argument('--min_epoch', 26 | default=0, 27 | type=int, 28 | help='min epoch used for averaging model') 29 | parser.add_argument('--max_epoch', 30 | default=65536, 31 | type=int, 32 | help='max epoch used for averaging model') 33 | 34 | args = parser.parse_args() 35 | print(args) 36 | return args 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | checkpoints = [] 42 | val_scores = [] 43 | if args.val_best: 44 | yamls = glob.glob('{}/[!train]*.yaml'.format(args.src_path)) 45 | for y in yamls: 46 | with open(y, 'r') as f: 47 | dic_yaml = yaml.load(f, Loader=yaml.FullLoader) 48 | loss = dic_yaml['cv_loss'] 49 | epoch = dic_yaml['epoch'] 50 | if epoch >= args.min_epoch and epoch <= args.max_epoch: 51 | val_scores += [[epoch, loss]] 52 | val_scores = np.array(val_scores) 53 | sort_idx = np.argsort(val_scores[:, -1]) 54 | sorted_val_scores = val_scores[sort_idx][::1] 55 | print("best val scores = " + str(sorted_val_scores[:args.num, 1])) 56 | print("selected epochs = " + 57 | str(sorted_val_scores[:args.num, 0].astype(np.int64))) 58 | path_list = [ 59 | args.src_path + '/{}.pt'.format(int(epoch)) 60 | for epoch in sorted_val_scores[:args.num, 0] 61 | ] 62 | else: 63 | path_list = glob.glob('{}/[!avg][!final]*.pt'.format(args.src_path)) 64 | path_list = sorted(path_list, key=os.path.getmtime) 65 | path_list = path_list[-args.num:] 66 | print(path_list) 67 | avg = None 68 | num = args.num 69 | assert num == len(path_list) 70 | for path in path_list: 71 | print('Processing {}'.format(path)) 72 | states = torch.load(path, map_location=torch.device('cpu')) 73 | if avg is None: 74 | avg = states 75 | else: 76 | for k in avg.keys(): 77 | avg[k] += states[k] 78 | # average 79 | for k in avg.keys(): 80 | if avg[k] is not None: 81 | # pytorch 1.6 use true_divide instead of /= 82 | avg[k] = torch.true_divide(avg[k], num) 83 | print('Saving to {}'.format(args.dst_model)) 84 | torch.save(avg, args.dst_model) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import os 19 | 20 | import torch 21 | import yaml 22 | 23 | from wenet.transformer.asr_model import init_asr_model 24 | from wenet.utils.checkpoint import load_checkpoint 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description='export your script model') 29 | parser.add_argument('--config', required=True, help='config file') 30 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 31 | parser.add_argument('--output_file', required=True, help='output file') 32 | parser.add_argument('--output_quant_file', 33 | default=None, 34 | help='output quantized model file') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | # No need gpu for model export 42 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 43 | 44 | with open(args.config, 'r') as fin: 45 | configs = yaml.load(fin, Loader=yaml.FullLoader) 46 | model = init_asr_model(configs) 47 | print(model) 48 | 49 | load_checkpoint(model, args.checkpoint) 50 | # Export jit torch script model 51 | 52 | script_model = torch.jit.script(model) 53 | script_model.save(args.output_file) 54 | print('Export model successfully, see {}'.format(args.output_file)) 55 | 56 | # Export quantized jit torch script model 57 | if args.output_quant_file: 58 | quantized_model = torch.quantization.quantize_dynamic( 59 | model, {torch.nn.Linear}, dtype=torch.qint8 60 | ) 61 | print(quantized_model) 62 | script_quant_model = torch.jit.script(quantized_model) 63 | script_quant_model.save(args.output_quant_file) 64 | print('Export quantized model successfully, ' 65 | 'see {}'.format(args.output_quant_file)) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/transformer/cmvn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | 18 | 19 | class GlobalCMVN(torch.nn.Module): 20 | def __init__(self, 21 | mean: torch.Tensor, 22 | istd: torch.Tensor, 23 | norm_var: bool = True): 24 | """ 25 | Args: 26 | mean (torch.Tensor): mean stats 27 | istd (torch.Tensor): inverse std, std which is 1.0 / std 28 | """ 29 | super().__init__() 30 | assert mean.shape == istd.shape 31 | self.norm_var = norm_var 32 | # The buffer can be accessed from this module using self.mean 33 | self.register_buffer("mean", mean) 34 | self.register_buffer("istd", istd) 35 | 36 | def forward(self, x: torch.Tensor): 37 | """ 38 | Args: 39 | x (torch.Tensor): (batch, max_len, feat_dim) 40 | 41 | Returns: 42 | (torch.Tensor): normalized feature 43 | """ 44 | x = x - self.mean 45 | if self.norm_var: 46 | x = x * self.istd 47 | return x 48 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/transformer/ctc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typeguard import check_argument_types 4 | 5 | 6 | class CTC(torch.nn.Module): 7 | """CTC module""" 8 | def __init__( 9 | self, 10 | odim: int, 11 | encoder_output_size: int, 12 | dropout_rate: float = 0.0, 13 | reduce: bool = True, 14 | ): 15 | """ Construct CTC module 16 | Args: 17 | odim: dimension of outputs 18 | encoder_output_size: number of encoder projection units 19 | dropout_rate: dropout rate (0.0 ~ 1.0) 20 | reduce: reduce the CTC loss into a scalar 21 | """ 22 | assert check_argument_types() 23 | super().__init__() 24 | eprojs = encoder_output_size 25 | self.dropout_rate = dropout_rate 26 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 27 | 28 | reduction_type = "sum" if reduce else "none" 29 | self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) 30 | 31 | def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, 32 | ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor: 33 | """Calculate CTC loss. 
34 | 35 | Args: 36 | hs_pad: batch of padded hidden state sequences (B, Tmax, D) 37 | hlens: batch of lengths of hidden state sequences (B) 38 | ys_pad: batch of padded character id sequence tensor (B, Lmax) 39 | ys_lens: batch of lengths of character sequence (B) 40 | """ 41 | # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) 42 | ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) 43 | # ys_hat: (B, L, D) -> (L, B, D) 44 | ys_hat = ys_hat.transpose(0, 1) 45 | ys_hat = ys_hat.log_softmax(2) 46 | loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) 47 | # Batch-size average 48 | loss = loss / ys_hat.size(1) 49 | return loss 50 | 51 | def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 52 | """log_softmax of frame activations 53 | 54 | Args: 55 | Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 56 | Returns: 57 | torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) 58 | """ 59 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 60 | 61 | def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 62 | """argmax of frame activations 63 | 64 | Args: 65 | torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 66 | Returns: 67 | torch.Tensor: argmax applied 2d tensor (B, Tmax) 68 | """ 69 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 70 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Label smoothing module.""" 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class LabelSmoothingLoss(nn.Module): 13 | """Label-smoothing loss. 14 | 15 | In a standard CE loss, the label's data distribution is: 16 | [0,1,2] -> 17 | [ 18 | [1.0, 0.0, 0.0], 19 | [0.0, 1.0, 0.0], 20 | [0.0, 0.0, 1.0], 21 | ] 22 | 23 | In the smoothing version CE Loss,some probabilities 24 | are taken from the true label prob (1.0) and are divided 25 | among other labels. 26 | 27 | e.g. 28 | smoothing=0.1 29 | [0,1,2] -> 30 | [ 31 | [0.9, 0.05, 0.05], 32 | [0.05, 0.9, 0.05], 33 | [0.05, 0.05, 0.9], 34 | ] 35 | 36 | Args: 37 | size (int): the number of class 38 | padding_idx (int): padding class id which will be ignored for loss 39 | smoothing (float): smoothing rate (0.0 means the conventional CE) 40 | normalize_length (bool): 41 | normalize loss by sequence length if True 42 | normalize loss by batch size if False 43 | """ 44 | def __init__(self, 45 | size: int, 46 | padding_idx: int, 47 | smoothing: float, 48 | normalize_length: bool = False): 49 | """Construct an LabelSmoothingLoss object.""" 50 | super(LabelSmoothingLoss, self).__init__() 51 | self.criterion = nn.KLDivLoss(reduction="none") 52 | self.padding_idx = padding_idx 53 | self.confidence = 1.0 - smoothing 54 | self.smoothing = smoothing 55 | self.size = size 56 | self.normalize_length = normalize_length 57 | 58 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 59 | """Compute loss between x and target. 60 | 61 | The model outputs and data labels tensors are flatten to 62 | (batch*seqlen, class) shape and a mask is applied to the 63 | padding part which should not be calculated for loss. 
64 | 65 | Args: 66 | x (torch.Tensor): prediction (batch, seqlen, class) 67 | target (torch.Tensor): 68 | target signal masked with self.padding_id (batch, seqlen) 69 | Returns: 70 | loss (torch.Tensor) : The KL loss, scalar float value 71 | """ 72 | assert x.size(2) == self.size 73 | batch_size = x.size(0) 74 | x = x.view(-1, self.size) 75 | target = target.view(-1) 76 | # use zeros_like instead of torch.no_grad() for true_dist, 77 | # since no_grad() can not be exported by JIT 78 | true_dist = torch.zeros_like(x) 79 | true_dist.fill_(self.smoothing / (self.size - 1)) 80 | ignore = target == self.padding_idx # (B,) 81 | total = len(target) - ignore.sum().item() 82 | target = target.masked_fill(ignore, 0) # avoid -1 index 83 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 84 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 85 | denom = total if self.normalize_length else batch_size 86 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 87 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Positionwise feed forward layer definition.""" 7 | 8 | import torch 9 | 10 | 11 | class PositionwiseFeedForward(torch.nn.Module): 12 | """Positionwise feed forward layer. 13 | 14 | FeedForward are appied on each position of the sequence. 15 | The output dim is same with the input dim. 16 | 17 | Args: 18 | idim (int): Input dimenstion. 19 | hidden_units (int): The number of hidden units. 20 | dropout_rate (float): Dropout rate. 21 | activation (torch.nn.Module): Activation function 22 | """ 23 | def __init__(self, 24 | idim: int, 25 | hidden_units: int, 26 | dropout_rate: float, 27 | activation: torch.nn.Module = torch.nn.ReLU()): 28 | """Construct a PositionwiseFeedForward object.""" 29 | super(PositionwiseFeedForward, self).__init__() 30 | self.w_1 = torch.nn.Linear(idim, hidden_units) 31 | self.activation = activation 32 | self.dropout = torch.nn.Dropout(dropout_rate) 33 | self.w_2 = torch.nn.Linear(hidden_units, idim) 34 | 35 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 36 | """Forward function. 
37 | 38 | Args: 39 | xs: input tensor (B, L, D) 40 | Returns: 41 | output tensor, (B, L, D) 42 | """ 43 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 44 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/transformer/swish.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | """Swish() activation function for Conformer.""" 8 | 9 | import torch 10 | 11 | 12 | class Swish(torch.nn.Module): 13 | """Construct an Swish object.""" 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | """Return Swish activation function.""" 16 | return x * torch.sigmoid(x) 17 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang) 3 | 4 | import logging 5 | import os 6 | import re 7 | 8 | import yaml 9 | import torch 10 | 11 | 12 | def load_checkpoint(model: torch.nn.Module, path: str) -> dict: 13 | if torch.cuda.is_available(): 14 | logging.info('Checkpoint: loading from checkpoint %s for GPU' % path) 15 | checkpoint = torch.load(path) 16 | else: 17 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) 18 | checkpoint = torch.load(path, map_location='cpu') 19 | model.load_state_dict(checkpoint) 20 | info_path = re.sub('.pt$', '.yaml', path) 21 | configs = {} 22 | if os.path.exists(info_path): 23 | with open(info_path, 'r') as fin: 24 | configs = yaml.load(fin, Loader=yaml.FullLoader) 25 | return configs 26 | 27 | 28 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None): 29 | ''' 30 | Args: 31 | infos (dict or None): any info you want to save. 32 | ''' 33 | logging.info('Checkpoint: save to checkpoint %s' % path) 34 | if isinstance(model, torch.nn.DataParallel): 35 | state_dict = model.module.state_dict() 36 | elif isinstance(model, torch.nn.parallel.DistributedDataParallel): 37 | state_dict = model.module.state_dict() 38 | else: 39 | state_dict = model.state_dict() 40 | torch.save(state_dict, path) 41 | info_path = re.sub('.pt$', '.yaml', path) 42 | if infos is None: 43 | infos = {} 44 | with open(info_path, 'w') as fout: 45 | data = yaml.dump(infos) 46 | fout.write(data) 47 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/utils/cmvn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | import logging 18 | import math 19 | import sys 20 | 21 | import numpy as np 22 | def _load_json_cmvn(json_cmvn_file): 23 | """ Load the json format cmvn stats file and calculate cmvn 24 | 25 | Args: 26 | json_cmvn_file: cmvn stats file in json format 27 | 28 | Returns: 29 | a numpy array of [means, vars] 30 | """ 31 | with open(json_cmvn_file) as f: 32 | cmvn_stats = json.load(f) 33 | 34 | means = cmvn_stats['mean_stat'] 35 | variance = cmvn_stats['var_stat'] 36 | count = cmvn_stats['frame_num'] 37 | for i in range(len(means)): 38 | means[i] /= count 39 | variance[i] = variance[i] / count - means[i] * means[i] 40 | if variance[i] < 1.0e-20: 41 | variance[i] = 1.0e-20 42 | variance[i] = 1.0 / math.sqrt(variance[i]) 43 | cmvn = np.array([means, variance]) 44 | return cmvn 45 | 46 | 47 | def _load_kaldi_cmvn(kaldi_cmvn_file): 48 | """ Load the kaldi format cmvn stats file and calculate cmvn 49 | 50 | Args: 51 | kaldi_cmvn_file: kaldi text style global cmvn file, which 52 | is generated by: 53 | compute-cmvn-stats --binary=false scp:feats.scp global_cmvn 54 | 55 | Returns: 56 | a numpy array of [means, vars] 57 | """ 58 | means = [] 59 | variance = [] 60 | with open(kaldi_cmvn_file, 'r') as fid: 61 | # kaldi binary file start with '\0B' 62 | if fid.read(2) == '\0B': 63 | logging.error('kaldi cmvn binary file is not supported, please ' 64 | 'recompute it by: compute-cmvn-stats --binary=false ' 65 | ' scp:feats.scp global_cmvn') 66 | sys.exit(1) 67 | fid.seek(0) 68 | arr = fid.read().split() 69 | assert (arr[0] == '[') 70 | assert (arr[-2] == '0') 71 | assert (arr[-1] == ']') 72 | feat_dim = int((len(arr) - 2 - 2) / 2) 73 | for i in range(1, feat_dim + 1): 74 | means.append(float(arr[i])) 75 | count = float(arr[feat_dim + 1]) 76 | for i in range(feat_dim + 2, 2 * feat_dim + 2): 77 | variance.append(float(arr[i])) 78 | 79 | for i in range(len(means)): 80 | means[i] /= count 81 | variance[i] = variance[i] / count - means[i] * means[i] 82 | if variance[i] < 1.0e-20: 83 | variance[i] = 1.0e-20 84 | variance[i] = 1.0 / math.sqrt(variance[i]) 85 | cmvn = np.array([means, variance]) 86 | return cmvn 87 | 88 | 89 | def load_cmvn(cmvn_file, is_json): 90 | if is_json: 91 | cmvn = _load_json_cmvn(cmvn_file) 92 | else: 93 | cmvn = _load_kaldi_cmvn(cmvn_file) 94 | return cmvn[0], cmvn[1] 95 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/utils/config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | def override_config(configs, override_list): 4 | new_configs = copy.deepcopy(configs) 5 | for item in override_list: 6 | arr = item.split() 7 | if len(arr) != 2: 8 | print(f"the override {item} format is not correct, skip it") 9 | continue 10 | keys = arr[0].split('.') 11 | s_configs = new_configs 12 | for i, key in enumerate(keys): 13 | if key not in s_configs: 14 | print(f"the override {item} format is not correct, skip it") 15 | if i == len(keys) - 1: 16 | param_type = type(s_configs[key]) 17 | s_configs[key] = param_type(arr[1]) 18 | print(f"override {arr[0]} with {arr[1]}") 19 | else: 20 | s_configs = s_configs[key] 21 | return new_configs 22 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/utils/ctc_util.py:
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Mobvoi Inc. All Rights Reserved. 2 | # Author: binbinzhang@mobvoi.com (Di Wu) 3 | 4 | import numpy as np 5 | import torch 6 | 7 | def insert_blank(label, blank_id=0): 8 | """Insert blank token between every two label token.""" 9 | label = np.expand_dims(label, 1) 10 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 11 | label = np.concatenate([blanks, label], axis=1) 12 | label = label.reshape(-1) 13 | label = np.append(label, label[0]) 14 | return label 15 | 16 | def forced_align(ctc_probs: torch.Tensor, 17 | y: torch.Tensor, 18 | blank_id=0) -> list: 19 | """ctc forced alignment. 20 | 21 | Args: 22 | torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) 23 | torch.Tensor y: id sequence tensor 1d tensor (L) 24 | int blank_id: blank symbol index 25 | Returns: 26 | torch.Tensor: alignment result 27 | """ 28 | y_insert_blank = insert_blank(y, blank_id) 29 | 30 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) 31 | log_alpha = log_alpha - float('inf') # log of zero 32 | state_path = (torch.zeros( 33 | (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1 34 | ) # state path 35 | 36 | # init start state 37 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] 38 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] 39 | 40 | for t in range(1, ctc_probs.size(0)): 41 | for s in range(len(y_insert_blank)): 42 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ 43 | s] == y_insert_blank[s - 2]: 44 | candidates = torch.tensor( 45 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) 46 | prev_state = [s, s - 1] 47 | else: 48 | candidates = torch.tensor([ 49 | log_alpha[t - 1, s], 50 | log_alpha[t - 1, s - 1], 51 | log_alpha[t - 1, s - 2], 52 | ]) 53 | prev_state = [s, s - 1, s - 2] 54 | log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] 55 | state_path[t, s] = prev_state[torch.argmax(candidates)] 56 | 57 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) 58 | 59 | candidates = torch.tensor([ 60 | log_alpha[-1, len(y_insert_blank) - 1], 61 | log_alpha[-1, len(y_insert_blank) - 2] 62 | ]) 63 | prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] 64 | state_seq[-1] = prev_state[torch.argmax(candidates)] 65 | for t in range(ctc_probs.size(0) - 2, -1, -1): 66 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 67 | 68 | output_alignment = [] 69 | for t in range(0, ctc_probs.size(0)): 70 | output_alignment.append(y_insert_blank[state_seq[t, 0]]) 71 | 72 | return output_alignment 73 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | def read_lists(list_file): 17 | lists = [] 18 | with open(list_file, 'r', encoding='utf8') as fin: 19 | for line in fin: 20 | lists.append(line.strip()) 21 | return lists 22 | 23 | 24 | def read_symbol_table(symbol_table_file): 25 | symbol_table = {} 26 | with open(symbol_table_file, 'r', encoding='utf8') as fin: 27 | for line in fin: 28 | arr = line.strip().split() 29 | assert len(arr) == 2 30 | symbol_table[arr[0]] = int(arr[1]) 31 | return symbol_table 32 | -------------------------------------------------------------------------------- /external_modules/ultralight/data_utils/wenet/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | from typeguard import check_argument_types 7 | 8 | 9 | class WarmupLR(_LRScheduler): 10 | """The WarmupLR scheduler 11 | 12 | This scheduler is almost same as NoamLR Scheduler except for following 13 | difference: 14 | 15 | NoamLR: 16 | lr = optimizer.lr * model_size ** -0.5 17 | * min(step ** -0.5, step * warmup_step ** -1.5) 18 | WarmupLR: 19 | lr = optimizer.lr * warmup_step ** 0.5 20 | * min(step ** -0.5, step * warmup_step ** -1.5) 21 | 22 | Note that the maximum lr equals to optimizer.lr in this scheduler. 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | optimizer: torch.optim.Optimizer, 29 | warmup_steps: Union[int, float] = 25000, 30 | last_epoch: int = -1, 31 | ): 32 | assert check_argument_types() 33 | self.warmup_steps = warmup_steps 34 | 35 | # __init__() must be invoked before setting field 36 | # because step() is also invoked in __init__() 37 | super().__init__(optimizer, last_epoch) 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" 41 | 42 | def get_lr(self): 43 | step_num = self.last_epoch + 1 44 | return [ 45 | lr 46 | * self.warmup_steps ** 0.5 47 | * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) 48 | for lr in self.base_lrs 49 | ] 50 | 51 | def set_step(self, step: int): 52 | self.last_epoch = step 53 | -------------------------------------------------------------------------------- /external_modules/ultralight/pth2onnx.py: -------------------------------------------------------------------------------- 1 | from unet import Model 2 | import onnx 3 | import torch 4 | 5 | import onnxruntime 6 | import numpy as np 7 | import time 8 | onnx_path = "./dihuman.onnx" 9 | 10 | def check_onnx(torch_out, torch_in, audio): 11 | onnx_model = onnx.load(onnx_path) 12 | onnx.checker.check_model(onnx_model) 13 | import onnxruntime 14 | providers = ["CUDAExecutionProvider"] 15 | ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers) 16 | print(ort_session.get_providers()) 17 | ort_inputs = {ort_session.get_inputs()[0].name: torch_in.cpu().numpy(), ort_session.get_inputs()[1].name: audio.cpu().numpy()} 18 | for i in range(1): 19 | t1 = time.time() 20 | ort_outs = ort_session.run(None, ort_inputs) 21 | t2 = time.time() 22 | print("onnx time cost::", t2 - t1) 23 | 24 | np.testing.assert_allclose(torch_out[0].cpu().numpy(), ort_outs[0][0], rtol=1e-03, atol=1e-05) 25 | print("Exported model has been tested with ONNXRuntime, and the result looks good!") 26 | 27 | 28 | net = Model(6).eval() 29 | net.load_state_dict(torch.load("20.pth")) 30 | img = torch.zeros([1, 6, 160, 160]) 31 | audio = torch.zeros([1, 128, 16, 32]) 32 | 33 | input_dict = {"input": img, "audio": audio} 34 | 35 | with 
torch.no_grad(): 36 | torch_out = net(img, audio) 37 | print(torch_out.shape) 38 | torch.onnx.export(net, (img, audio), onnx_path, input_names=['input', "audio"], 39 | output_names=['output'], 40 | # dynamic_axes=dynamic_axes, 41 | # example_outputs=torch_out, 42 | opset_version=11, 43 | export_params=True) 44 | check_onnx(torch_out, img, audio) 45 | -------------------------------------------------------------------------------- /external_modules/ultralight/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | transformers 3 | numpy==1.23.5 4 | soundfile 5 | librosa 6 | onnxruntime -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Virtual Environment 7 | venv/ 8 | env/ 9 | .env 10 | 11 | # IDE 12 | .vscode/ 13 | .idea/ 14 | *.swp 15 | *.swo 16 | 17 | # Logs 18 | *.log 19 | 20 | # Database 21 | *.sqlite3 22 | 23 | # OS generated files 24 | .DS_Store 25 | .DS_Store? 26 | ._* 27 | .Spotlight-V100 28 | .Trashes 29 | ehthumbs.db 30 | Thumbs.db 31 | 32 | # FastAPI specific 33 | .pytest_cache/ 34 | 35 | # Dependency directories 36 | node_modules/ 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | pip-wheel-metadata/ 53 | share/python-wheels/ 54 | *.egg-info/ 55 | .installed.cfg 56 | *.egg 57 | 58 | # PyInstaller 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .nox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | *.py,cover 77 | .hypothesis/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # Environments 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .dmypy.json 106 | dmypy.json 107 | 108 | # Pyre type checker 109 | .pyre/ 110 | 111 | # Large model files 112 | *.model 113 | *.pkl 114 | *.h5 115 | *.onnx 116 | *.pt 117 | *.pth 118 | 119 | # Exclude the .github directory 120 | .github/ 121 | -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/README.md: -------------------------------------------------------------------------------- 1 | # wav2lip-onnx-256x256 model 2 | Simple and fast Wav2Lip inference using a newly trained 256x256-resolution, ONNX-converted model 3 | 4 | Minimal version: no extra features such as face enhancement or face alignment, just the same functionality as the original repository 5 | 6 | Inference is quite fast on CPU using the converted Wav2Lip ONNX models and antelope face detection. It can also run on an Nvidia GPU; tested on an RTX 3060 and, later, on a GTX 1050 7 | 8 | Some results: 9 | 10 | wav2lip 96x96 - wav2lip_gan 96x96 - wav2lip 256x256 11 | 12 | https://github.com/user-attachments/assets/bdd186f6-6a79-4cbd-824f-74108392d390 13 | 14 | * Installation: Clone this repository and read setup.txt 15 | 16 | * Download models from releases.
17 | 18 | * Don't forget to install ffmpeg and set path variable. 19 | 20 | * Face detection checkpoint already in insightface_func/models/antelope 21 | 22 | Original 256x256 pretrained checkpoint taken from: 23 | 24 | https://github.com/Kedreamix/Linly-Talker/blob/main/README.md 25 | 26 | 27 | -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/checkpoints/checkpoints.txt: -------------------------------------------------------------------------------- 1 | put onnx models here 2 | -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/convert2onnx_256/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class nonorm_Conv2d(nn.Module): 22 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.conv_block = nn.Sequential( 25 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 26 | ) 27 | self.act = nn.LeakyReLU(0.01, inplace=True) 28 | 29 | def forward(self, x): 30 | out = self.conv_block(x) 31 | return self.act(out) 32 | 33 | class Conv2dTranspose(nn.Module): 34 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.conv_block = nn.Sequential( 37 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 38 | nn.BatchNorm2d(cout) 39 | ) 40 | self.act = nn.ReLU() 41 | 42 | def forward(self, x): 43 | out = self.conv_block(x) 44 | return self.act(out) 45 | -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/convert2onnx_256/export.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wav2lip_256 import Wav2Lip 3 | 4 | model = Wav2Lip() 5 | checkpoint = torch.load('wav2lipv2.pth', map_location='cpu', weights_only=True) 6 | s = checkpoint["state_dict"] 7 | new_s = {} 8 | for k, v in s.items(): 9 | new_s[k.replace('module.', '')] = v 10 | model.load_state_dict(new_s) 11 | model = model.to('cpu') 12 | model.eval() 13 | 14 | x = torch.randn(1,1,80,16) 15 | y = torch.randn(1,6,256,256) 16 | 17 | torch.onnx.export(model, (x,y), 'wav2lip_256.onnx', input_names=['mel_spectrogram', 'video_frames'], output_names=['predicted_frames'], opset_version=15) 18 | 19 | out = model(x,y) 20 | print(out.shape) -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/insightface_func/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/external_modules/wav2lip-onnx-256/insightface_func/__init__.py 
-------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/insightface_func/face_detect_crop_single.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-23 17:03:58 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-24 16:46:04 7 | Description: 8 | ''' 9 | from __future__ import division 10 | import collections 11 | import numpy as np 12 | import glob 13 | import os 14 | import os.path as osp 15 | import cv2 16 | from insightface.model_zoo import model_zoo 17 | #from insightface_func.utils import face_align_ffhqandnewarc as face_align 18 | 19 | __all__ = ['Face_detect_crop', 'Face'] 20 | 21 | Face = collections.namedtuple('Face', [ 22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age', 23 | 'embedding_norm', 'normed_embedding', 24 | 'landmark' 25 | ]) 26 | 27 | Face.__new__.__defaults__ = (None, ) * len(Face._fields) 28 | 29 | 30 | class Face_detect_crop: 31 | def __init__(self, name, root='~/.insightface_func/models'): 32 | self.models = {} 33 | root = os.path.expanduser(root) 34 | onnx_files = glob.glob(osp.join(root, name, '*.onnx')) 35 | onnx_files = sorted(onnx_files) 36 | for onnx_file in onnx_files: 37 | if onnx_file.find('_selfgen_')>0: 38 | #print('ignore:', onnx_file) 39 | continue 40 | model = model_zoo.get_model(onnx_file) 41 | if model.taskname not in self.models: 42 | print('find model:', onnx_file, model.taskname) 43 | self.models[model.taskname] = model 44 | else: 45 | print('duplicated model task type, ignore:', onnx_file, model.taskname) 46 | del model 47 | assert 'detection' in self.models 48 | self.det_model = self.models['detection'] 49 | 50 | 51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode ='None'): 52 | self.det_thresh = det_thresh 53 | self.mode = mode 54 | assert det_size is not None 55 | print('set det-size:', det_size) 56 | self.det_size = det_size 57 | for taskname, model in self.models.items(): 58 | if taskname=='detection': 59 | model.prepare(ctx_id, input_size=det_size) 60 | else: 61 | model.prepare(ctx_id) 62 | 63 | 64 | 65 | def getBox(self, img, max_num=0): 66 | bboxes, kpss = self.det_model.detect(img, 67 | max_num=max_num, 68 | metric='default') 69 | if bboxes.shape[0] == 0: 70 | return None 71 | 72 | x1 = int(bboxes[0, 0:1]) 73 | y1 = int(bboxes[0, 1:2]) 74 | x2 = int(bboxes[0, 2:3]) 75 | y2 = int(bboxes[0, 3:4]) 76 | 77 | 78 | return (x1,y1,x2,y2) 79 | -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | numpy 3 | tqdm 4 | librosa 5 | numba 6 | insightface==0.2.1 7 | onnxruntime -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/setup.txt: -------------------------------------------------------------------------------- 1 | conda create -n wav2lip_onnx python=3.7 2 | conda activate wav2lip_onnx 3 | cd c:\tutorial\wav2lip_onnx 4 | pip install -r requirements.txt 5 | 6 | for use with Nvidia GPU: 7 | conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0 (version depending on your graphic card model) 8 | pip uninstall onnxruntime 9 | pip install onnxruntime-gpu 10 | 11 | maybe it's neccessary to also 12 | pip install opencv-python 13 | 14 | --------------------------- 15 | if you get some "onnx 1.9 
providers" error: 16 | 17 | Edit this file: 18 | e.g. File "C:\Users\.conda\envs\ENVname\lib\site-packages\insightface\model_zoo\model_zoo.py" 19 | line 23, in get_model 20 | 21 | change: 22 | session = onnxruntime.InferenceSession(self.onnx_file, None) 23 | 24 | to: 25 | session = onnxruntime.InferenceSession(self.onnx_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 26 | --------------------------- 27 | 28 | inference: 29 | python -W ignore inference_onnxModel.py --checkpoint_path "checkpoints\wav2lip_256.onnx" --face "D:\some.mp4" --audio "D:\some.wav" --outfile "D:\output.mp4" --nosmooth --pads 0 10 0 0 --fps 29.97 30 | 31 | python -W ignore inference_onnxModel.py --checkpoint_path checkpoints/wav2lip_256.onnx --face /Users/libn/Desktop/726_1727606032.mp4 --audio /Users/libn/Desktop/123.wav --outfile /Users/libn/Desktop/output.mp4 --nosmooth --pads 0 10 0 0 --fps 29.97 -------------------------------------------------------------------------------- /external_modules/wav2lip-onnx-256/temp/temp.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.110.0 2 | uvicorn==0.27.1 3 | sqlalchemy==2.0.28 4 | pydantic==2.6.3 5 | python-dotenv==1.0.1 6 | alembic==1.13.1 7 | torch 8 | ffmpeg-python 9 | # Ideally run this as a separate service 10 | faster-whisper 11 | playwright 12 | apscheduler 13 | srt 14 | cachetools 15 | Pillow 16 | oss2 17 | boto3 -------------------------------------------------------------------------------- /resources/audios/default_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/resources/audios/default_audio.wav -------------------------------------------------------------------------------- /resources/audios/default_audio_hu.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/libn-net/marketing_creator_pro_max_backend/71de2515fc91b42a10cbcef7d43f017cfda81e39/resources/audios/default_audio_hu.npy --------------------------------------------------------------------------------
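
A minimal smoke test for the exported 256x256 lip-sync model, given here only as a hedged usage sketch. The input/output names and tensor shapes are taken from external_modules/wav2lip-onnx-256/convert2onnx_256/export.py above; the checkpoint path and the choice of execution provider are assumptions, not something the repository prescribes.

# Hypothetical smoke test; assumes the converted model has been placed in
# external_modules/wav2lip-onnx-256/checkpoints/ as described in checkpoints.txt.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    "external_modules/wav2lip-onnx-256/checkpoints/wav2lip_256.onnx",
    providers=["CPUExecutionProvider"],
)

# Dummy inputs mirroring the dummy tensors used for export in export.py:
# a 1x1x80x16 mel-spectrogram window and a 1x6x256x256 frame pair
# (masked frame + reference frame stacked along the channel axis).
mel = np.random.randn(1, 1, 80, 16).astype(np.float32)
frames = np.random.randn(1, 6, 256, 256).astype(np.float32)

(predicted,) = session.run(
    ["predicted_frames"],
    {"mel_spectrogram": mel, "video_frames": frames},
)
print(predicted.shape)  # generated face frame(s) for the given window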