├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── config.py
├── docs
│   └── README_zh.md
├── examples
│   ├── dolly.wav
│   ├── freddie.wav
│   ├── hayami_saori.mp3
│   ├── kesidi.wav
│   ├── kuidou_cn.wav
│   ├── kuidou_en.wav
│   ├── linbo_cn.wav
│   ├── linbo_en.wav
│   ├── maggie.wav
│   ├── manman3_cn.wav
│   ├── manman4_cn.wav
│   ├── maychelle_cn.wav
│   ├── maychelle_en.wav
│   ├── weilun.wav
│   └── xinbada_cn.wav
├── inference.sh
├── inference_parameters.json
├── inference_parameters.txt
├── install.sh
├── pretrained_models
│   └── .gitignore
├── quick_start.sh
├── requirements.txt
├── run_training.sh
├── server
│   ├── .ipynb_checkpoints
│   │   └── webui-checkpoint.py
│   ├── app.py
│   ├── asr.py
│   ├── example.json
│   ├── modelhandler.py
│   ├── train_webui.py
│   └── webui.py
├── src
│   ├── AR
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── bucket_sampler.py
│   │   │   ├── data_module.py
│   │   │   └── dataset.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── t2s_lightning_module.py
│   │   │   ├── t2s_lightning_module_onnx.py
│   │   │   ├── t2s_model.py
│   │   │   ├── t2s_model_onnx.py
│   │   │   └── utils.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── activation.py
│   │   │   ├── activation_onnx.py
│   │   │   ├── embedding.py
│   │   │   ├── embedding_onnx.py
│   │   │   ├── lr_schedulers.py
│   │   │   ├── optim.py
│   │   │   ├── patched_mha_with_cache.py
│   │   │   ├── patched_mha_with_cache_onnx.py
│   │   │   ├── scaling.py
│   │   │   ├── transformer.py
│   │   │   └── transformer_onnx.py
│   │   ├── text_processing
│   │   │   ├── __init__.py
│   │   │   ├── phonemizer.py
│   │   │   └── symbols.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── initialize.py
│   │       └── io.py
│   ├── configs
│   │   ├── s1.yaml
│   │   ├── s1big.yaml
│   │   ├── s1big2.yaml
│   │   ├── s1longer.yaml
│   │   ├── s1mq.yaml
│   │   ├── sovits.json
│   │   └── train.yaml
│   ├── feature_extractor
│   │   ├── __init__.py
│   │   ├── cnhubert.py
│   │   └── whisper_enc.py
│   ├── inference
│   │   ├── .ipynb_checkpoints
│   │   │   ├── infer_tool-checkpoint.py
│   │   │   └── inference-checkpoint.py
│   │   ├── __init__.py
│   │   ├── infer_tool.py
│   │   └── inference.py
│   ├── module
│   │   ├── __init__.py
│   │   ├── attentions.py
│   │   ├── attentions_onnx.py
│   │   ├── commons.py
│   │   ├── core_vq.py
│   │   ├── data_utils.py
│   │   ├── losses.py
│   │   ├── mel_processing.py
│   │   ├── models.py
│   │   ├── models_onnx.py
│   │   ├── modules.py
│   │   ├── mrte_model.py
│   │   ├── quantize.py
│   │   └── transforms.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   ├── get_phonemes.py
│   │   ├── get_semantic.py
│   │   ├── get_ssl_features.py
│   │   └── process.py
│   ├── text
│   │   ├── __init__.py
│   │   ├── chinese.py
│   │   ├── cleaner.py
│   │   ├── cmudict-fast.rep
│   │   ├── cmudict.rep
│   │   ├── engdict-hot.rep
│   │   ├── engdict_cache.pickle
│   │   ├── english.py
│   │   ├── japanese.py
│   │   ├── namedict_cache.pickle
│   │   ├── opencpop-strict.txt
│   │   ├── symbols.py
│   │   ├── tone_sandhi.py
│   │   └── zh_normalization
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── char_convert.py
│   │       ├── chronology.py
│   │       ├── constants.py
│   │       ├── num.py
│   │       ├── phonecode.py
│   │       ├── quantifier.py
│   │       └── text_normlization.py
│   ├── train
│   │   ├── train_gpt.py
│   │   └── train_sovits.py
│   └── utils
│       ├── __init__.py
│       ├── config.py
│       ├── cut.py
│       ├── onnx_export.py
│       ├── process_ckpt.py
│       └── utils.py
└── tools
    ├── asr
    │   ├── config.py
    │   ├── fasterwhisper_asr.py
    │   ├── funasr_asr.py
    │   └── models
    │       └── .gitignore
    ├── cmd-denoise.py
    ├── denoise-model
    │   └── .gitignore
    ├── my_utils.py
    ├── process_data.py
    ├── slice_audio.py
    ├── slicer2.py
    ├── subfix_webui.py
    └── uvr5
        ├── lib
        │   ├── lib_v5
        │   │   ├── dataset.py
        │   │   ├── layers.py
        │   │   ├── layers_123812KB.py
        │   │   ├── layers_123821KB.py
        │   │   ├── layers_33966KB.py
        │   │   ├── layers_537227KB.py
        │   │   ├── layers_537238KB.py
        │   │   ├── layers_new.py
        │   │   ├── model_param_init.py
        │   │   ├── modelparams
        │   │   │   ├── 1band_sr16000_hl512.json
        │   │   │   ├── 1band_sr32000_hl512.json
        │   │   │   ├── 1band_sr33075_hl384.json
        │   │   │   ├── 1band_sr44100_hl1024.json
        │   │   │   ├── 1band_sr44100_hl256.json
        │   │   │   ├── 1band_sr44100_hl512.json
        │   │   │   ├── 1band_sr44100_hl512_cut.json
        │   │   │   ├── 2band_32000.json
        │   │   │   ├── 2band_44100_lofi.json
        │   │   │   ├── 2band_48000.json
        │   │   │   ├── 3band_44100.json
        │   │   │   ├── 3band_44100_mid.json
        │   │   │   ├── 3band_44100_msb2.json
        │   │   │   ├── 4band_44100.json
        │   │   │   ├── 4band_44100_mid.json
        │   │   │   ├── 4band_44100_msb.json
        │   │   │   ├── 4band_44100_msb2.json
        │   │   │   ├── 4band_44100_reverse.json
        │   │   │   ├── 4band_44100_sw.json
        │   │   │   ├── 4band_v2.json
        │   │   │   ├── 4band_v2_sn.json
        │   │   │   ├── 4band_v3.json
        │   │   │   └── ensemble.json
        │   │   ├── nets.py
        │   │   ├── nets_123812KB.py
        │   │   ├── nets_123821KB.py
        │   │   ├── nets_33966KB.py
        │   │   ├── nets_537227KB.py
        │   │   ├── nets_537238KB.py
        │   │   ├── nets_61968KB.py
        │   │   ├── nets_new.py
        │   │   └── spec_utils.py
        │   ├── name_params.json
        │   └── utils.py
        ├── mdxnet.py
        ├── vr.py
        └── webui.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__
3 | *.pyc
4 | env
5 | runtime
6 | .idea
7 | output
8 | logs
9 | reference
10 | GPT_weights
11 | SoVITS_weights
12 | TEMP
13 | *.ipynb
14 | input_audio/*
15 | output_audio/*
16 | .history
17 | app.log*
18 | Dockerfile
19 | .gitlab-ci.yml
20 | deploy.yaml
21 | Makefile
22 | .ipynb_checkpoints
23 | onnx/
24 | data/
25 | !data/check_audio.py
26 | example.xlsx
27 |
28 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use the Miniconda image as the base image
2 | FROM continuumio/miniconda3
3 |
4 | # Install system dependencies
5 | RUN apt-get update && apt-get install -y --no-install-recommends \
6 | git \
7 | sudo \
8 | build-essential \
9 | wget \
10 | curl \
11 | cmake \
12 | vim \
13 | tmux \
14 | ffmpeg \
15 | libglu1-mesa \
16 | libxi-dev \
17 | libxmu-dev \
18 | libglu1-mesa-dev \
19 | freeglut3-dev && \
20 | apt-get clean && \
21 | rm -rf /var/lib/apt/lists/*
22 |
23 | # # Download and install the CUDA toolkit
24 | # RUN wget -q https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run \
25 | # && sh cuda_11.8.0_520.61.05_linux.run --silent --toolkit > /dev/null \
26 | # && rm cuda_11.8.0_520.61.05_linux.run
27 |
28 | # # Update environment variables
29 | # ENV PATH=$PATH:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/nvidia/lib64
30 | # ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
31 | # ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:$LD_LIBRARY_PATH
32 |
33 | # Create and initialize the conda environment
34 | RUN conda create -n aispeech python=3.10.13 && \
35 | echo "conda activate aispeech" >> ~/.bashrc
36 |
37 | SHELL ["conda", "run", "-n", "aispeech", "/bin/bash", "-c"]
38 |
39 | WORKDIR /app
40 |
41 | # Install Python dependencies
42 |
43 | COPY install.sh /app/install.sh
44 | COPY requirements.txt /app/requirements.txt
45 | RUN chmod +x /app/install.sh && bash /app/install.sh -y
46 |
47 | # Copy the current directory contents into /app in the container
48 | COPY . /app
49 |
50 | # Set environment variables
51 | ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
52 |
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 RVC-Boss
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import sys,os
2 |
3 | import torch
4 |
5 | # Model weights used for inference
6 | sovits_path = ""
7 | gpt_path = ""
8 | is_half_str = os.environ.get("is_half", "True")
9 | is_half = is_half_str.lower() == "true"
10 | is_share_str = os.environ.get("is_share", "False")
11 | is_share = is_share_str.lower() == "true"
12 |
13 | cnhubert_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
14 | bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
15 | pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
16 | pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
17 |
18 | exp_root = "logs"
19 | python_exec = sys.executable or "python"
20 | if torch.cuda.is_available():
21 | infer_device = "cuda"
22 | else:
23 | infer_device = "cpu"
24 |
25 | webui_port_main = 9874
26 | webui_port_uvr5 = 9873
27 | webui_port_infer_tts = 9872
28 | webui_port_subfix = 9871
29 |
30 | api_port = 9880
31 |
32 | if infer_device == "cuda":
33 | gpu_name = torch.cuda.get_device_name(0)
34 | if (
35 | ("16" in gpu_name and "V100" not in gpu_name.upper())
36 | or "P40" in gpu_name.upper()
37 | or "P10" in gpu_name.upper()
38 | or "1060" in gpu_name
39 | or "1070" in gpu_name
40 | or "1080" in gpu_name
41 | ):
42 | is_half=False
43 |
44 | if infer_device == "cpu": is_half = False
45 |
46 | class Config:
47 | def __init__(self):
48 | self.sovits_path = sovits_path
49 | self.gpt_path = gpt_path
50 | self.is_half = is_half
51 |
52 | self.cnhubert_path = cnhubert_path
53 | self.bert_path = bert_path
54 | self.pretrained_sovits_path = pretrained_sovits_path
55 | self.pretrained_gpt_path = pretrained_gpt_path
56 |
57 | self.exp_root = exp_root
58 | self.python_exec = python_exec
59 | self.infer_device = infer_device
60 |
61 | self.webui_port_main = webui_port_main
62 | self.webui_port_uvr5 = webui_port_uvr5
63 | self.webui_port_infer_tts = webui_port_infer_tts
64 | self.webui_port_subfix = webui_port_subfix
65 |
66 | self.api_port = api_port
67 |
--------------------------------------------------------------------------------
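Note: a minimal usage sketch for the Config class above (the script name is hypothetical; attribute names mirror config.py):

    # demo_config.py -- hypothetical consumer of config.Config
    from config import Config

    cfg = Config()

    # is_half is forced to False above when running on CPU or on GPUs without
    # usable fp16 (P40/P10/10xx/16xx), so this reflects the detected device.
    print("inference device:", cfg.infer_device)
    print("fp16 inference:", cfg.is_half)
    print("main WebUI port:", cfg.webui_port_main)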
/examples/dolly.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/dolly.wav
--------------------------------------------------------------------------------
/examples/freddie.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/freddie.wav
--------------------------------------------------------------------------------
/examples/hayami_saori.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/hayami_saori.mp3
--------------------------------------------------------------------------------
/examples/kesidi.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/kesidi.wav
--------------------------------------------------------------------------------
/examples/kuidou_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/kuidou_cn.wav
--------------------------------------------------------------------------------
/examples/kuidou_en.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/kuidou_en.wav
--------------------------------------------------------------------------------
/examples/linbo_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/linbo_cn.wav
--------------------------------------------------------------------------------
/examples/linbo_en.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/linbo_en.wav
--------------------------------------------------------------------------------
/examples/maggie.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/maggie.wav
--------------------------------------------------------------------------------
/examples/manman3_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/manman3_cn.wav
--------------------------------------------------------------------------------
/examples/manman4_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/manman4_cn.wav
--------------------------------------------------------------------------------
/examples/maychelle_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/maychelle_cn.wav
--------------------------------------------------------------------------------
/examples/maychelle_en.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/maychelle_en.wav
--------------------------------------------------------------------------------
/examples/weilun.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/weilun.wav
--------------------------------------------------------------------------------
/examples/xinbada_cn.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/examples/xinbada_cn.wav
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python src/inference/inference.py \
4 | --sovits_weights pretrained_models/sovits_weights/kuidou_cn/kuidou_cn_e30_s540.pth \
5 | --gpt_weights pretrained_models/gpt_weights/kuidou_cn/kuidou_cn-e30.ckpt \
6 | --parameters_file inference_parameters.json \
7 | --output_folder output_audio/kuidou_cn \
8 |
--------------------------------------------------------------------------------
/inference_parameters.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "ref_wav_path": "kuidou_cn.wav",
4 | "prompt_text": "超级空投就在我附近,有人要来抢抢看么?",
5 | "prompt_language": "zh",
6 | "text": "那时候我才意识到我已经迷路了。我想等到红灯时停下自行车看导航,谁知一路都是绿灯,由此一路往前,我冲着众人鞠躬,这便是我迟到十年的原因。",
7 | "text_language": "zh",
8 | "how_to_cut": null
9 | },
10 | {
11 | "ref_wav_path": "kuidou_cn.wav",
12 | "prompt_text": "超级空投就在我附近,有人要来抢抢看么?",
13 | "prompt_language": "zh",
14 | "text": "总的来说对于游戏思维的要求非常高。而且因为规则很多很复杂,所以也不会出现重复游玩带来的枯燥感,非常适合有事没事就拿出来想一想。有时候突然的灵光乍现就可以想到解决问题的办法。",
15 | "text_language": "zh",
16 | "how_to_cut": null
17 | }
18 | ]
--------------------------------------------------------------------------------
/inference_parameters.txt:
--------------------------------------------------------------------------------
1 | 0001.wav|我从来不留情面。|zh|日月之力,引我前行|zh|None
2 | 0002.wav|罪人的哭喊,必将响彻天际。|zh|日月并肩,战场由我主宰|zh|None
3 | 0003.wav|在安静的时候,我眼前浮现出暗裔和人类的回忆。|zh|遵循荣光的指引|zh|None
4 | 0015.wav|百般求饶也不会让我停手。|zh|追随暗影的轨迹|zh|None
5 | 0005.wav|飞升者,太阳给了你力量黑暗给了你归宿。|zh|以平衡之名,出击!|zh|None
6 | 0006.wav|没有仁慈,没有宽恕,人类和怪物毫无区别,都会倒下。|zh|光影之间,真战始于此刻|zh|None
7 | 0007.wav|剑士不该质疑自己的方向。|zh|你的光芒耗尽了!|zh|None
8 | 0017.wav|同情也不过是一声呻吟。|zh|巨兽也无法阻挡光与影的审判!|zh|None
9 | 0018.wav|你希望我们团结,战斗会实现这个愿望吗?|zh|万物湮灭,不过吹灰之力。|zh|None
10 | 0010.wav|没有哪颗心能拴住我。|zh|光暗未定,此非终结。|zh|None
11 | 0011.wav|就连希望也会消亡。|zh|我听见星辰。。。在召唤我?|zh|None
12 | 0012.wav|你期盼已久的终点。|zh|日月轮回,永无止境,但充满希望。|zh|None
13 | 0013.wav|世界埋葬了我,我也必将回报。|zh|光暗交织,现世审判!|zh|None
14 | 0014.wav|享受短暂的喘息吧,艾欧尼亚的孩子,很快就会结束的。|zh|命运相伴,我们终将再会。|zh|None
--------------------------------------------------------------------------------
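Note: inference_parameters.txt appears to be a pipe-separated counterpart of inference_parameters.json. The column meaning is not documented here; the order assumed below (ref_wav_path | prompt_text | prompt_language | text | text_language | how_to_cut) is a guess inferred from the JSON file, as is this small parser sketch:

    # parse_parameters_txt.py -- hypothetical helper; field order assumed from inference_parameters.json
    FIELDS = ["ref_wav_path", "prompt_text", "prompt_language", "text", "text_language", "how_to_cut"]

    def parse_parameters_txt(path):
        """Turn each 'a|b|c|d|e|f' line into a dict keyed like the JSON parameter file."""
        entries = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                entry = dict(zip(FIELDS, line.split("|")))
                # The literal "None" in the last column is treated as "no cutting", like null in the JSON.
                if entry.get("how_to_cut") == "None":
                    entry["how_to_cut"] = None
                entries.append(entry)
        return entries

    if __name__ == "__main__":
        for e in parse_parameters_txt("inference_parameters.txt"):
            print(e["ref_wav_path"], "->", e["text"][:20])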
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | conda install -y -c conda-forge gcc
3 | conda install -y -c conda-forge gxx
4 | conda install -y ffmpeg cmake
5 | conda install -y pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
6 | pip install -r requirements.txt
7 |
--------------------------------------------------------------------------------
/pretrained_models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/quick_start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Name of the voice/character, passed as the first argument
4 | NAME=$1
5 |
6 | # Preprocess
7 | python src/preprocess/process.py \
8 | --data_dir data \
9 | --log_dir logs \
10 | --name $NAME \
11 |
12 | # train sovits model
13 | python src/train/train_sovits.py \
14 | -c src/configs/sovits.json \
15 |     -n $NAME \
16 |     -t sovits \
17 |     -e 30 \
18 |     -lr 0.4 \
19 |     -bs 32 \
20 |     -nw 0 \
21 |     --save_every_epoch 10 \
22 |
23 | # train gpt model
24 | python src/train/train_gpt.py \
25 |     -c src/configs/s1longer.yaml \
26 |     -n $NAME \
27 |     -e 30 \
28 |     -bs 32 \
29 |     -nw 0 \
30 |     --save_every_epoch 10
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.26.0
2 | scipy==1.12.0
3 | tensorboard
4 | librosa
5 | pydub
6 | numba
7 | pytorch-lightning
8 | gradio==4.41.0
9 | gradio_client==1.3.0
10 | fastapi==0.112.1
11 | ffmpeg-python
12 | onnxruntime
13 | tqdm
14 | funasr
15 | cn2an
16 | pypinyin
17 | pyopenjtalk
18 | g2p_en
19 | torchaudio
20 | sentencepiece
21 | transformers
22 | chardet
23 | PyYAML
24 | psutil
25 | jieba_fast
26 | jieba
27 | LangSegment>=0.2.0
28 | Faster_Whisper
29 | einops
30 | pydantic
31 | wordsegment
32 | openpyxl
--------------------------------------------------------------------------------
/run_training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Predefined list of name arguments
4 | NAME_LIST=("maychelle_en" "maychelle_cn")
5 |
6 | # Loop over each name
7 | for NAME in "${NAME_LIST[@]}"; do
8 |     # Run quick_start.sh with this name
9 | ./quick_start.sh $NAME
10 | done
--------------------------------------------------------------------------------
/server/asr.py:
--------------------------------------------------------------------------------
1 | from agent_kernel import get_llm_client, ASRRequest
2 | from pydub import AudioSegment
3 | import os
4 | import pandas as pd
5 | import re
6 | import asyncio
7 |
8 |
9 | async def get_asr_result(audio_path: str) -> str:
10 |     # Load the audio file
11 | audio = AudioSegment.from_wav(audio_path)
12 |
13 |     # Check whether the audio is 16-bit, mono
14 | if audio.sample_width != 2 or audio.channels != 1:
15 |         # If not, convert to 16-bit mono and save to a temporary file
16 | temp_audio_path = audio_path.replace(".wav", "_tmp.wav")
17 | audio = audio.set_sample_width(2).set_channels(1) # 2 bytes = 16 bits, 1 channel = mono
18 | audio.export(temp_audio_path, format="wav")
19 | else:
20 | temp_audio_path = audio_path
21 |
22 |     # Call the agent kernel ASR interface and get the transcript
23 | client = get_llm_client()
24 | r = ASRRequest.from_file(temp_audio_path)
25 | r.lang_code = "cmn-Hans-CN"
26 | res = await client.asr_async(r) # Async API
27 | transcript = res.response.alternatives[0].transcript # type: ignore
28 |
29 |     # Remove the temporary file
30 | if temp_audio_path != audio_path:
31 | os.remove(temp_audio_path)
32 |
33 | return transcript
34 |
35 |
36 | def calculate_edit_distance(hypothesis, reference, lang="zh"):
37 | """
38 | 计算 Translation Edit Rate (TER)
39 | """
40 | # 去除标点符号
41 | hypothesis = re.sub(r'[^\w\s]', '', hypothesis)
42 | reference = re.sub(r'[^\w\s]', '', reference)
43 |
44 |     # Split into tokens: words for English, characters otherwise
45 | if lang == "en":
46 | hyp_words = hypothesis.split(" ")
47 | ref_words = reference.split(" ")
48 | else:
49 | hyp_words = list(hypothesis)
50 | ref_words = list(reference)
51 |
52 |     # Initialize the edit-distance matrix
53 | d = [[0] * (len(ref_words) + 1) for _ in range(len(hyp_words) + 1)]
54 |
55 |     # Initialize the boundary conditions
56 | for i in range(len(hyp_words) + 1):
57 | d[i][0] = i
58 | for j in range(len(ref_words) + 1):
59 | d[0][j] = j
60 |
61 |     # Fill the matrix by dynamic programming
62 | for i in range(1, len(hyp_words) + 1):
63 | for j in range(1, len(ref_words) + 1):
64 | if hyp_words[i - 1] == ref_words[j - 1]:
65 | d[i][j] = d[i - 1][j - 1]
66 | else:
67 |                 d[i][j] = min(d[i - 1][j] + 1,  # deletion
68 |                               d[i][j - 1] + 1,  # insertion
69 |                               d[i - 1][j - 1] + 1)  # substitution
70 |
71 |     # Final edit distance
72 | edit_distance = d[len(hyp_words)][len(ref_words)]
73 | return edit_distance
74 |
75 |
76 | def match_result(asr_result, lines):
77 |     ## Score the match using the manually computed edit distance
78 | edit_distance = calculate_edit_distance(asr_result, lines)
79 |
80 | if edit_distance == 0:
81 | return 0
82 | elif edit_distance <= 2:
83 | return 1
84 | else:
85 | return 2
86 |
87 |
88 | async def process_audio(resource_name, line, audio_folder):
89 |     audio_path = os.path.join(audio_folder, resource_name + '.wav')  # join folder and resource name
90 | if os.path.exists(audio_path):
91 | try:
92 |             # Call the get_asr_result helper defined above
93 | asr_result = await get_asr_result(audio_path)
94 | except Exception as e:
95 | print(f"Error processing {resource_name}: {e}")
96 | asr_result = ""
97 | else:
98 | print(f"Audio file {resource_name} not found in {audio_folder}")
99 | asr_result = ""
100 |
101 | print(f"ASR Result for {resource_name}: {asr_result}")
102 | match_result_value = match_result(asr_result, line)
103 | return asr_result, match_result_value
104 |
105 |
106 | async def audio_line_check(excel_path, audio_folder):
107 |     # Read the Excel file
108 | df = pd.read_excel(excel_path, header=0)
109 |
110 |     # Get the script lines and resource names
111 | lines = df.iloc[:, 0]
112 | resource_names = df.iloc[:, 1]
113 |
114 |     # Process the audio files concurrently with asyncio.gather
115 | tasks = [process_audio(resource_name, line, audio_folder) for resource_name, line in zip(resource_names, lines)]
116 | results = await asyncio.gather(*tasks)
117 |
118 |     # Separate the ASR results from the match results
119 | asr_results, match_results = zip(*results)
120 |
121 |     ## Rebuild the Excel: keep the first two columns, add the ASR transcript column ("AI识别文本") and the match score column ("匹配度")
122 | df['AI识别文本'] = asr_results
123 | df['匹配度'] = match_results
124 |
125 |     # Save the new Excel file
126 | new_excel_path = excel_path.replace("需求表.xlsx", "审查结果.xlsx")
127 | df.to_excel(new_excel_path, index=False)
128 | print(f"Results saved to {new_excel_path}")
129 |
--------------------------------------------------------------------------------
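Note: a quick sanity check of the matching logic above (a sketch; it assumes the repo root is on PYTHONPATH and that the agent_kernel dependency imported by server/asr.py is installed):

    # Hypothetical check of the edit-distance based matching in server/asr.py.
    from server.asr import calculate_edit_distance, match_result

    hyp = "今天天气不错"
    ref = "今天天气很不错"

    # Character-level Levenshtein distance: a single insertion separates the strings.
    assert calculate_edit_distance(hyp, ref, lang="zh") == 1

    # match_result maps distance 0 -> 0 (exact), 1..2 -> 1 (close), >2 -> 2 (mismatch).
    assert match_result(hyp, ref) == 1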
/server/modelhandler.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import os
3 | import sys
4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5 | import json
6 |
7 |
8 | class ModelHandler:
9 | def __init__(self, models_folder, models_info_file='models_info.json'):
10 | self.models_folder = models_folder
11 | self.models_info_file = models_info_file
12 | self.models_info = {}
13 | self.load_models_info()
14 |
15 |
16 | def load_models_info(self):
17 | """加载模型信息JSON文件"""
18 | self.update_models_info()
19 | with open(os.path.join(self.models_folder, self.models_info_file), 'r') as file:
20 | self.models_info = json.load(file)
21 |
22 |
23 | def update_models_info(self):
24 | """更新模型信息,并写入JSON文件"""
25 | for character in os.listdir(self.models_folder):
26 | character_path = os.path.join(self.models_folder, character)
27 | if os.path.isdir(character_path):
28 | self.models_info[character] = {}
29 | for model_file in os.listdir(character_path):
30 | model_path = os.path.join(character_path, model_file)
31 | if os.path.isfile(model_path):
32 | self.models_info[character][model_file] = model_path
33 |
34 | with open(os.path.join(self.models_folder, self.models_info_file), 'w') as file:
35 | json.dump(self.models_info, file, indent=4)
36 |
37 |
--------------------------------------------------------------------------------
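Note: a minimal sketch of how ModelHandler might be driven (the models folder below is an assumption; any folder containing one sub-directory per character works):

    # Hypothetical usage of server/modelhandler.ModelHandler.
    from server.modelhandler import ModelHandler

    # Scans <models_folder>/<character>/<model_file>, writes models_info.json into
    # the folder, and keeps the character -> {file: path} mapping in memory.
    handler = ModelHandler("pretrained_models")

    for character, models in handler.models_info.items():
        for model_file, model_path in models.items():
            print(f"{character}: {model_file} -> {model_path}")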
/src/AR/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/AR/__init__.py
--------------------------------------------------------------------------------
/src/AR/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/AR/data/__init__.py
--------------------------------------------------------------------------------
/src/AR/data/data_module.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
2 | from pytorch_lightning import LightningDataModule
3 | from AR.data.bucket_sampler import DistributedBucketSampler
4 | from AR.data.dataset import Text2SemanticDataset
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class Text2SemanticDataModule(LightningDataModule):
9 | def __init__(
10 | self,
11 | config,
12 | train_semantic_path,
13 | train_phoneme_path,
14 | dev_semantic_path=None,
15 | dev_phoneme_path=None,
16 | ):
17 | super().__init__()
18 | self.config = config
19 | self.train_semantic_path = train_semantic_path
20 | self.train_phoneme_path = train_phoneme_path
21 | self.dev_semantic_path = dev_semantic_path
22 | self.dev_phoneme_path = dev_phoneme_path
23 | self.num_workers = self.config["data"]["num_workers"]
24 |
25 | def prepare_data(self):
26 | pass
27 |
28 | def setup(self, stage=None, output_logs=False):
29 | self._train_dataset = Text2SemanticDataset(
30 | phoneme_path=self.train_phoneme_path,
31 | semantic_path=self.train_semantic_path,
32 | max_sec=self.config["data"]["max_sec"],
33 | pad_val=self.config["data"]["pad_val"],
34 | )
35 | self._dev_dataset = self._train_dataset
36 | # self._dev_dataset = Text2SemanticDataset(
37 | # phoneme_path=self.dev_phoneme_path,
38 | # semantic_path=self.dev_semantic_path,
39 | # max_sample=self.config['data']['max_eval_sample'],
40 | # max_sec=self.config['data']['max_sec'],
41 | # pad_val=self.config['data']['pad_val'])
42 |
43 | def train_dataloader(self):
44 |         batch_size = self.config["train"]["batch_size"] // 2 if self.config["train"].get("if_dpo", False) == True else self.config["train"]["batch_size"]
45 |         batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # cap the batch size so each epoch has enough steps for checkpoints to be saved
46 | sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
47 | return DataLoader(
48 | self._train_dataset,
49 | batch_size=batch_size,
50 | sampler=sampler,
51 | collate_fn=self._train_dataset.collate,
52 | num_workers=self.num_workers,
53 | persistent_workers=True if self.num_workers>0 else False,
54 | prefetch_factor=16 if self.num_workers>0 else None,
55 | )
56 |
57 | def val_dataloader(self):
58 | return DataLoader(
59 | self._dev_dataset,
60 | batch_size=1,
61 | shuffle=False,
62 | collate_fn=self._train_dataset.collate,
63 | num_workers=max(self.num_workers, 12),
64 | persistent_workers=True,
65 | prefetch_factor=16,
66 | )
67 |
68 |     # Is this ever actually used?
69 | def test_dataloader(self):
70 | return DataLoader(
71 | self._dev_dataset,
72 | batch_size=1,
73 | shuffle=False,
74 | collate_fn=self._train_dataset.collate,
75 | )
76 |
--------------------------------------------------------------------------------
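Note: a sketch of the minimal config dict this DataModule reads, based only on the keys accessed above; the dataset paths are placeholders:

    # Hypothetical construction of Text2SemanticDataModule (keys mirror data_module.py).
    from AR.data.data_module import Text2SemanticDataModule

    config = {
        "data": {
            "num_workers": 4,
            "max_sec": 54,    # as in src/configs/s1*.yaml
            "pad_val": 1024,  # same as EOS in the model configs
        },
        "train": {
            "batch_size": 8,
            "if_dpo": False,  # when True, train_dataloader halves the batch size
        },
    }

    dm = Text2SemanticDataModule(
        config,
        train_semantic_path="logs/<name>/semantic.tsv",  # placeholder
        train_phoneme_path="logs/<name>/phoneme.txt",    # placeholder
    )
    dm.setup()
    train_loader = dm.train_dataloader()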
/src/AR/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/AR/models/__init__.py
--------------------------------------------------------------------------------
/src/AR/models/t2s_lightning_module_onnx.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
2 | import os, sys
3 |
4 | now_dir = os.getcwd()
5 | sys.path.append(now_dir)
6 | from typing import Dict
7 |
8 | import torch
9 | from pytorch_lightning import LightningModule
10 | from AR.models.t2s_model_onnx import Text2SemanticDecoder
11 | from AR.modules.lr_schedulers import WarmupCosineLRSchedule
12 | from AR.modules.optim import ScaledAdam
13 |
14 |
15 | class Text2SemanticLightningModule(LightningModule):
16 | def __init__(self, config, output_dir, is_train=True):
17 | super().__init__()
18 | self.config = config
19 | self.top_k = 3
20 | self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
21 | pretrained_s1 = config.get("pretrained_s1")
22 | if pretrained_s1 and is_train:
23 | # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
24 | print(
25 | self.load_state_dict(
26 | torch.load(pretrained_s1, map_location="cpu")["weight"]
27 | )
28 | )
29 | if is_train:
30 | self.automatic_optimization = False
31 | self.save_hyperparameters()
32 | self.eval_dir = output_dir / "eval"
33 | self.eval_dir.mkdir(parents=True, exist_ok=True)
34 |
35 | def training_step(self, batch: Dict, batch_idx: int):
36 | opt = self.optimizers()
37 | scheduler = self.lr_schedulers()
38 | loss, acc = self.model.forward(
39 | batch["phoneme_ids"],
40 | batch["phoneme_ids_len"],
41 | batch["semantic_ids"],
42 | batch["semantic_ids_len"],
43 | batch["bert_feature"],
44 | )
45 | self.manual_backward(loss)
46 | if batch_idx > 0 and batch_idx % 4 == 0:
47 | opt.step()
48 | opt.zero_grad()
49 | scheduler.step()
50 |
51 | self.log(
52 | "total_loss",
53 | loss,
54 | on_step=True,
55 | on_epoch=True,
56 | prog_bar=True,
57 | sync_dist=True,
58 | )
59 | self.log(
60 | "lr",
61 | scheduler.get_last_lr()[0],
62 | on_epoch=True,
63 | prog_bar=True,
64 | sync_dist=True,
65 | )
66 | self.log(
67 | f"top_{self.top_k}_acc",
68 | acc,
69 | on_step=True,
70 | on_epoch=True,
71 | prog_bar=True,
72 | sync_dist=True,
73 | )
74 |
75 | def validation_step(self, batch: Dict, batch_idx: int):
76 | return
77 |
78 | def configure_optimizers(self):
79 | model_parameters = self.model.parameters()
80 | parameters_names = []
81 | parameters_names.append(
82 | [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
83 | )
84 | lm_opt = ScaledAdam(
85 | model_parameters,
86 | lr=0.01,
87 | betas=(0.9, 0.95),
88 | clipping_scale=2.0,
89 | parameters_names=parameters_names,
90 | show_dominant_parameters=False,
91 | clipping_update_period=1000,
92 | )
93 |
94 | return {
95 | "optimizer": lm_opt,
96 | "lr_scheduler": {
97 | "scheduler": WarmupCosineLRSchedule(
98 | lm_opt,
99 | init_lr=self.config["optimizer"]["lr_init"],
100 | peak_lr=self.config["optimizer"]["lr"],
101 | end_lr=self.config["optimizer"]["lr_end"],
102 | warmup_steps=self.config["optimizer"]["warmup_steps"],
103 | total_steps=self.config["optimizer"]["decay_steps"],
104 | )
105 | },
106 | }
107 |
--------------------------------------------------------------------------------
/src/AR/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/AR/modules/__init__.py
--------------------------------------------------------------------------------
/src/AR/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2 | import math
3 |
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class TokenEmbedding(nn.Module):
9 | def __init__(
10 | self,
11 | embedding_dim: int,
12 | vocab_size: int,
13 | dropout: float = 0.0,
14 | ):
15 | super().__init__()
16 |
17 | self.vocab_size = vocab_size
18 | self.embedding_dim = embedding_dim
19 |
20 | self.dropout = torch.nn.Dropout(p=dropout)
21 | self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22 |
23 | @property
24 | def weight(self) -> torch.Tensor:
25 | return self.word_embeddings.weight
26 |
27 | def embedding(self, index: int) -> torch.Tensor:
28 | return self.word_embeddings.weight[index : index + 1]
29 |
30 | def forward(self, x: torch.Tensor):
31 | x = self.word_embeddings(x)
32 | x = self.dropout(x)
33 | return x
34 |
35 |
36 | class SinePositionalEmbedding(nn.Module):
37 | def __init__(
38 | self,
39 | embedding_dim: int,
40 | dropout: float = 0.0,
41 | scale: bool = False,
42 | alpha: bool = False,
43 | ):
44 | super().__init__()
45 | self.embedding_dim = embedding_dim
46 | self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48 | self.dropout = torch.nn.Dropout(p=dropout)
49 |
50 | self.reverse = False
51 | self.pe = None
52 | self.extend_pe(torch.tensor(0.0).expand(1, 4000))
53 |
54 | def extend_pe(self, x):
55 | """Reset the positional encodings."""
56 | if self.pe is not None:
57 | if self.pe.size(1) >= x.size(1):
58 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
59 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
60 | return
61 | pe = torch.zeros(x.size(1), self.embedding_dim)
62 | if self.reverse:
63 | position = torch.arange(
64 | x.size(1) - 1, -1, -1.0, dtype=torch.float32
65 | ).unsqueeze(1)
66 | else:
67 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
68 | div_term = torch.exp(
69 | torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
70 | * -(math.log(10000.0) / self.embedding_dim)
71 | )
72 | pe[:, 0::2] = torch.sin(position * div_term)
73 | pe[:, 1::2] = torch.cos(position * div_term)
74 | pe = pe.unsqueeze(0)
75 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
76 |
77 | def forward(self, x: torch.Tensor) -> torch.Tensor:
78 | self.extend_pe(x)
79 | output = x.unsqueeze(-1) if x.ndim == 2 else x
80 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
81 | return self.dropout(output)
82 |
--------------------------------------------------------------------------------
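Note: extend_pe builds the standard sinusoidal table, PE[pos, 2i] = sin(pos / 10000^(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d)), pre-computed here for 4000 positions. A small shape-check sketch (assumes src/ is on the import path):

    # Hypothetical shape check for AR/modules/embedding.py.
    import torch
    from AR.modules.embedding import TokenEmbedding, SinePositionalEmbedding

    tok = TokenEmbedding(embedding_dim=512, vocab_size=1025)
    pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)

    ids = torch.randint(0, 1025, (2, 100))  # (batch, seq_len)
    x = tok(ids)                            # (2, 100, 512)
    y = pos(x)                              # same shape, sinusoidal positions added (scaled by alpha)
    assert y.shape == (2, 100, 512)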
/src/AR/modules/embedding_onnx.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2 | import math
3 |
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class TokenEmbedding(nn.Module):
9 | def __init__(
10 | self,
11 | embedding_dim: int,
12 | vocab_size: int,
13 | dropout: float = 0.0,
14 | ):
15 | super().__init__()
16 |
17 | self.vocab_size = vocab_size
18 | self.embedding_dim = embedding_dim
19 |
20 | self.dropout = torch.nn.Dropout(p=dropout)
21 | self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22 |
23 | @property
24 | def weight(self) -> torch.Tensor:
25 | return self.word_embeddings.weight
26 |
27 | def embedding(self, index: int) -> torch.Tensor:
28 | return self.word_embeddings.weight[index : index + 1]
29 |
30 | def forward(self, x: torch.Tensor):
31 | x = self.word_embeddings(x)
32 | x = self.dropout(x)
33 | return x
34 |
35 |
36 | class SinePositionalEmbedding(nn.Module):
37 | def __init__(
38 | self,
39 | embedding_dim: int,
40 | dropout: float = 0.0,
41 | scale: bool = False,
42 | alpha: bool = False,
43 | ):
44 | super().__init__()
45 | self.embedding_dim = embedding_dim
46 | self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48 | self.dropout = torch.nn.Dropout(p=dropout)
49 | self.reverse = False
50 | self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
51 |
52 | def extend_pe(self, x):
53 | position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
54 | scpe = (position * self.div_term).unsqueeze(0)
55 | pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
56 | pe = pe.contiguous().view(1, -1, self.embedding_dim)
57 | return pe
58 |
59 | def forward(self, x: torch.Tensor) -> torch.Tensor:
60 | pe = self.extend_pe(x)
61 | output = x.unsqueeze(-1) if x.ndim == 2 else x
62 | output = output * self.x_scale + self.alpha * pe
63 | return self.dropout(output)
64 |
--------------------------------------------------------------------------------
/src/AR/modules/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
2 | import math
3 |
4 | import torch
5 | from matplotlib import pyplot as plt
6 | from torch import nn
7 | from torch.optim import Adam
8 |
9 |
10 | class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
11 | """
12 | Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
13 | """
14 |
15 | def __init__(
16 | self,
17 | optimizer,
18 | init_lr,
19 | peak_lr,
20 | end_lr,
21 | warmup_steps=10000,
22 | total_steps=400000,
23 | current_step=0,
24 | ):
25 | self.init_lr = init_lr
26 | self.peak_lr = peak_lr
27 | self.end_lr = end_lr
28 | self.optimizer = optimizer
29 | self._warmup_rate = (peak_lr - init_lr) / warmup_steps
30 | self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
31 | self._current_step = current_step
32 | self.lr = init_lr
33 | self.warmup_steps = warmup_steps
34 | self.total_steps = total_steps
35 | self._last_lr = [self.lr]
36 |
37 | def set_lr(self, lr):
38 | self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
39 | for g in self.optimizer.param_groups:
40 | # g['lr'] = lr
41 | g["lr"] = self.end_lr ###锁定用线性
42 |
43 | def step(self):
44 | if self._current_step < self.warmup_steps:
45 | lr = self.init_lr + self._warmup_rate * self._current_step
46 |
47 | elif self._current_step > self.total_steps:
48 | lr = self.end_lr
49 |
50 | else:
51 | decay_ratio = (self._current_step - self.warmup_steps) / (
52 | self.total_steps - self.warmup_steps
53 | )
54 | if decay_ratio < 0.0 or decay_ratio > 1.0:
55 | raise RuntimeError(
56 | "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
57 | )
58 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
59 | lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
60 |
61 |         self.lr = lr = self.end_lr = 0.002  ### ignore the schedule above and hard-lock the LR to 0.002
62 | self.set_lr(lr)
63 | self.lr = lr
64 | self._current_step += 1
65 | return self.lr
66 |
67 |
68 | if __name__ == "__main__":
69 | m = nn.Linear(10, 10)
70 | opt = Adam(m.parameters(), lr=1e-4)
71 | s = WarmupCosineLRSchedule(
72 | opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
73 | )
74 | lrs = []
75 | for i in range(25000):
76 | s.step()
77 | lrs.append(s.lr)
78 | print(s.lr)
79 |
80 | plt.plot(lrs)
81 | plt.plot(range(0, 25000), lrs)
82 | plt.show()
83 |
--------------------------------------------------------------------------------
/src/AR/modules/patched_mha_with_cache_onnx.py:
--------------------------------------------------------------------------------
1 | from torch.nn.functional import *
2 | from torch.nn.functional import (
3 | _mha_shape_check,
4 | _canonical_mask,
5 | _none_or_dtype,
6 | _in_projection_packed,
7 | )
8 |
9 | def multi_head_attention_forward_patched(
10 | query,
11 | key,
12 | value,
13 | embed_dim_to_check: int,
14 | num_heads: int,
15 | in_proj_weight,
16 | in_proj_bias: Optional[Tensor],
17 | bias_k: Optional[Tensor],
18 | bias_v: Optional[Tensor],
19 | add_zero_attn: bool,
20 | dropout_p: float,
21 | out_proj_weight: Tensor,
22 | out_proj_bias: Optional[Tensor],
23 | training: bool = True,
24 | key_padding_mask: Optional[Tensor] = None,
25 | need_weights: bool = True,
26 | attn_mask: Optional[Tensor] = None,
27 | use_separate_proj_weight: bool = False,
28 | q_proj_weight: Optional[Tensor] = None,
29 | k_proj_weight: Optional[Tensor] = None,
30 | v_proj_weight: Optional[Tensor] = None,
31 | static_k: Optional[Tensor] = None,
32 | static_v: Optional[Tensor] = None,
33 | average_attn_weights: bool = True,
34 | is_causal: bool = False,
35 | cache=None,
36 | ) -> Tuple[Tensor, Optional[Tensor]]:
37 |
38 | # set up shape vars
39 | _, _, embed_dim = query.shape
40 | attn_mask = _canonical_mask(
41 | mask=attn_mask,
42 | mask_name="attn_mask",
43 | other_type=None,
44 | other_name="",
45 | target_type=query.dtype,
46 | check_other=False,
47 | )
48 | head_dim = embed_dim // num_heads
49 |
50 | proj_qkv = linear(query, in_proj_weight, in_proj_bias)
51 | proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
52 | q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
53 |
54 | if cache["first_infer"] == 1:
55 | cache["k"][cache["stage"]] = k
56 | cache["v"][cache["stage"]] = v
57 | else:
58 | cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
59 | cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
60 | k = cache["k"][cache["stage"]]
61 | v = cache["v"][cache["stage"]]
62 | cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
63 |
64 | attn_mask = _canonical_mask(
65 | mask=attn_mask,
66 | mask_name="attn_mask",
67 | other_type=None,
68 | other_name="",
69 | target_type=q.dtype,
70 | check_other=False,
71 | )
72 | attn_mask = attn_mask.unsqueeze(0)
73 |
74 | q = q.view(-1, num_heads, head_dim).transpose(0, 1)
75 | k = k.view(-1, num_heads, head_dim).transpose(0, 1)
76 | v = v.view(-1, num_heads, head_dim).transpose(0, 1)
77 |
78 | dropout_p = 0.0
79 | attn_mask = attn_mask.unsqueeze(0)
80 | q = q.view(num_heads, -1, head_dim).unsqueeze(0)
81 | k = k.view(num_heads, -1, head_dim).unsqueeze(0)
82 | v = v.view(num_heads, -1, head_dim).unsqueeze(0)
83 | attn_output = scaled_dot_product_attention(
84 | q, k, v, attn_mask, dropout_p, is_causal
85 | )
86 | attn_output = (
87 | attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
88 | )
89 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
90 | attn_output = attn_output.view(-1, 1, attn_output.size(1))
91 |
92 | return attn_output
93 |
--------------------------------------------------------------------------------
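Note: the cache argument above is a plain dict shared across the decoder's attention layers; judging only from the keys read in this function, it is laid out roughly like this (an illustrative sketch, not an API defined here):

    # Illustrative KV-cache layout consumed by multi_head_attention_forward_patched.
    num_layers = 24  # illustrative; matches n_layer in src/configs/s1longer.yaml
    cache = {
        "all_stage": num_layers,    # number of attention layers sharing this dict
        "stage": 0,                 # index of the layer being served; wraps modulo all_stage
        "first_infer": 1,           # 1 on the first decoding step: K/V are stored as-is
        "k": [None] * num_layers,   # per-layer cached keys, roughly (seq_len, batch, embed_dim)
        "v": [None] * num_layers,   # per-layer cached values
    }
    # On later steps (first_infer == 0) the freshly projected K/V row replaces the last
    # cached row (torch.cat([cached[:-1], new], 0)) before attention is computed.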
/src/AR/text_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/AR/text_processing/__init__.py
--------------------------------------------------------------------------------
/src/AR/text_processing/phonemizer.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py
2 | import itertools
3 | import re
4 | from typing import Dict
5 | from typing import List
6 |
7 | import regex
8 | from gruut import sentences
9 | from gruut.const import Sentence
10 | from gruut.const import Word
11 | from AR.text_processing.symbols import SYMBOL_TO_ID
12 |
13 |
14 | class GruutPhonemizer:
15 | def __init__(self, language: str):
16 | self._phonemizer = sentences
17 | self.lang = language
18 | self.symbol_to_id = SYMBOL_TO_ID
19 |         self._special_cases_dict: Dict[str, str] = {
20 | r"\.\.\.": "... ",
21 | ";": "; ",
22 | ":": ": ",
23 | ",": ", ",
24 | r"\.": ". ",
25 | "!": "! ",
26 | r"\?": "? ",
27 | "—": "—",
28 | "…": "… ",
29 | "«": "«",
30 | "»": "»",
31 | }
32 | self._punctuation_regexp: str = (
33 | rf"([{''.join(self._special_cases_dict.keys())}])"
34 | )
35 |
36 | def _normalize_punctuation(self, text: str) -> str:
37 | text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
38 | text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
39 | text = regex.sub(r"\pZ+", r" ", text)
40 | return text.strip()
41 |
42 | def _convert_punctuation(self, word: Word) -> str:
43 | if not word.phonemes:
44 | return ""
45 | if word.phonemes[0] in ["‖", "|"]:
46 | return word.text.strip()
47 |
48 | phonemes = "".join(word.phonemes)
49 | # remove modifier characters ˈˌː with regex
50 | phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
51 | return phonemes.strip()
52 |
53 | def phonemize(self, text: str, espeak: bool = False) -> str:
54 | text_to_phonemize: str = self._normalize_punctuation(text)
55 | sents: List[Sentence] = [
56 | sent
57 | for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
58 | ]
59 | words: List[str] = [
60 | self._convert_punctuation(word) for word in itertools.chain(*sents)
61 | ]
62 | return " ".join(words)
63 |
64 | def transform(self, phonemes):
65 | # convert phonemes to ids
66 | # dictionary is in symbols.py
67 | return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
68 |
69 |
70 | if __name__ == "__main__":
71 | phonemizer = GruutPhonemizer("en-us")
72 | # text -> IPA
73 | phonemes = phonemizer.phonemize("Hello, wor-ld ?")
74 | print("phonemes:", phonemes)
75 | print("len(phonemes):", len(phonemes))
76 | phoneme_ids = phonemizer.transform(phonemes)
77 | print("phoneme_ids:", phoneme_ids)
78 | print("len(phoneme_ids):", len(phoneme_ids))
79 |
--------------------------------------------------------------------------------
/src/AR/text_processing/symbols.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
2 | PAD = "_"
3 | PUNCTUATION = ';:,.!?¡¿—…"«»“” '
4 | LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
5 | IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
6 | SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
7 | SPACE_ID = SYMBOLS.index(" ")
8 | SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
9 | ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
10 |
--------------------------------------------------------------------------------
/src/AR/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def str2bool(s):
5 |     return s.lower() == 'true'
6 |
7 |
8 | def get_newest_ckpt(string_list):
9 |     # Regex pattern matching the epoch/step numbers in a checkpoint filename
10 | pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
11 |
12 |     # Extract the numbers from each filename into a list of (epoch, step, name) tuples
13 | extracted_info = []
14 | for string in string_list:
15 | match = re.match(pattern, string)
16 | if match:
17 | epoch = int(match.group(1))
18 | step = int(match.group(2))
19 | extracted_info.append((epoch, step, string))
20 |     # Sort by epoch, then by step, in descending order
21 | sorted_info = sorted(
22 | extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
23 |     # Take the newest ckpt filename
24 | newest_ckpt = sorted_info[0][2]
25 | return newest_ckpt
26 |
27 |
28 | # Returns the first line of the file if it exists and is non-empty, otherwise False
29 | def check_txt_file(file_path):
30 |     try:
31 |         with open(file_path, 'r') as file:
32 |             text = file.readline().strip()
33 |             assert text != ''
34 |             return text
35 |     except Exception:
36 |         return False
37 |
--------------------------------------------------------------------------------
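Note: a small usage example for get_newest_ckpt, matching the epoch=..-step=...ckpt naming it parses (assumes src/ is on the import path):

    # Example for AR.utils.get_newest_ckpt.
    from AR.utils import get_newest_ckpt

    ckpts = [
        "epoch=10-step=5000.ckpt",
        "epoch=30-step=15000.ckpt",
        "epoch=20-step=10000.ckpt",
    ]
    # Entries are sorted by (epoch, step) in descending order, so the latest wins.
    assert get_newest_ckpt(ckpts) == "epoch=30-step=15000.ckpt"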
/src/AR/utils/initialize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Initialize modules for espnet2 neural networks."""
3 | import torch
4 | from typeguard import check_argument_types
5 |
6 |
7 | def initialize(model: torch.nn.Module, init: str):
8 | """Initialize weights of a neural network module.
9 |
10 | Parameters are initialized using the given method or distribution.
11 |
12 | Custom initialization routines can be implemented into submodules
13 | as function `espnet_initialization_fn` within the custom module.
14 |
15 | Args:
16 | model: Target.
17 | init: Method of initialization.
18 | """
19 | assert check_argument_types()
20 | print("init with", init)
21 |
22 | # weight init
23 | for p in model.parameters():
24 | if p.dim() > 1:
25 | if init == "xavier_uniform":
26 | torch.nn.init.xavier_uniform_(p.data)
27 | elif init == "xavier_normal":
28 | torch.nn.init.xavier_normal_(p.data)
29 | elif init == "kaiming_uniform":
30 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
31 | elif init == "kaiming_normal":
32 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
33 | else:
34 | raise ValueError("Unknown initialization: " + init)
35 | # bias init
36 | for name, p in model.named_parameters():
37 | if ".bias" in name and p.dim() == 1:
38 | p.data.zero_()
39 |
--------------------------------------------------------------------------------
/src/AR/utils/io.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 | import yaml
5 |
6 |
7 | def load_yaml_config(path):
8 | with open(path) as f:
9 | config = yaml.full_load(f)
10 | return config
11 |
12 |
13 | def save_config_to_yaml(config, path):
14 | assert path.endswith(".yaml")
15 | with open(path, "w") as f:
16 | f.write(yaml.dump(config))
17 | f.close()
18 |
19 |
20 | def write_args(args, path):
21 | args_dict = dict(
22 | (name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
23 | )
24 | with open(path, "a") as args_file:
25 | args_file.write("==> torch version: {}\n".format(torch.__version__))
26 | args_file.write(
27 | "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
28 | )
29 | args_file.write("==> Cmd:\n")
30 | args_file.write(str(sys.argv))
31 | args_file.write("\n==> args:\n")
32 | for k, v in sorted(args_dict.items()):
33 | args_file.write(" %s: %s\n" % (str(k), str(v)))
34 | args_file.close()
35 |
--------------------------------------------------------------------------------
/src/configs/s1.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | seed: 1234
3 | epochs: 300
4 | batch_size: 8
5 | gradient_accumulation: 4
6 | save_every_n_epoch: 1
7 | precision: 16
8 | gradient_clip: 1.0
9 | optimizer:
10 | lr: 0.01
11 | lr_init: 0.00001
12 | lr_end: 0.0001
13 | warmup_steps: 2000
14 | decay_steps: 40000
15 | data:
16 | max_eval_sample: 8
17 | max_sec: 54
18 | num_workers: 1
19 | pad_val: 1024 # same with EOS in model
20 | model:
21 | vocab_size: 1025
22 | phoneme_vocab_size: 512
23 | embedding_dim: 512
24 | hidden_dim: 512
25 | head: 16
26 | linear_units: 2048
27 | n_layer: 12
28 | dropout: 0
29 | EOS: 1024
30 | inference:
31 | top_k: 5
32 |
--------------------------------------------------------------------------------
/src/configs/s1big.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | seed: 1234
3 | epochs: 300
4 | batch_size: 8
5 | gradient_accumulation: 4
6 | save_every_n_epoch: 1
7 | precision: 16-mixed
8 | gradient_clip: 1.0
9 | optimizer:
10 | lr: 0.01
11 | lr_init: 0.00001
12 | lr_end: 0.0001
13 | warmup_steps: 2000
14 | decay_steps: 40000
15 | data:
16 | max_eval_sample: 8
17 | max_sec: 54
18 | num_workers: 1
19 | pad_val: 1024 # same with EOS in model
20 | model:
21 | vocab_size: 1025
22 | phoneme_vocab_size: 512
23 | embedding_dim: 1024
24 | hidden_dim: 1024
25 | head: 16
26 | linear_units: 2048
27 | n_layer: 16
28 | dropout: 0
29 | EOS: 1024
30 | inference:
31 | top_k: 5
32 |
--------------------------------------------------------------------------------
/src/configs/s1big2.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | seed: 1234
3 | epochs: 300
4 | batch_size: 12
5 | gradient_accumulation: 4
6 | save_every_n_epoch: 1
7 | precision: 16-mixed
8 | gradient_clip: 1.0
9 | optimizer:
10 | lr: 0.01
11 | lr_init: 0.00001
12 | lr_end: 0.0001
13 | warmup_steps: 2000
14 | decay_steps: 40000
15 | data:
16 | max_eval_sample: 8
17 | max_sec: 54
18 | num_workers: 1
19 | pad_val: 1024 # same with EOS in model
20 | model:
21 | vocab_size: 1025
22 | phoneme_vocab_size: 512
23 | embedding_dim: 1024
24 | hidden_dim: 1024
25 | head: 16
26 | linear_units: 2048
27 | n_layer: 6
28 | dropout: 0
29 | EOS: 1024
30 | inference:
31 | top_k: 5
32 |
--------------------------------------------------------------------------------
/src/configs/s1longer.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | seed: 1234
3 | epochs: 20
4 | batch_size: 8
5 | save_every_n_epoch: 1
6 | precision: 16-mixed
7 | gradient_clip: 1.0
8 | optimizer:
9 | lr: 0.01
10 | lr_init: 0.00001
11 | lr_end: 0.0001
12 | warmup_steps: 2000
13 | decay_steps: 40000
14 | data:
15 | max_eval_sample: 8
16 | max_sec: 54
17 | num_workers: 4
18 | pad_val: 1024 # same with EOS in model
19 | model:
20 | vocab_size: 1025
21 | phoneme_vocab_size: 512
22 | embedding_dim: 512
23 | hidden_dim: 512
24 | head: 16
25 | linear_units: 2048
26 | n_layer: 24
27 | dropout: 0
28 | EOS: 1024
29 | random_bert: 0
30 | inference:
31 | top_k: 5
32 |
--------------------------------------------------------------------------------
/src/configs/s1mq.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | seed: 1234
3 | epochs: 100
4 | batch_size: 6
5 | gradient_accumulation: 4
6 | save_every_n_epoch: 1
7 | precision: 32
8 | gradient_clip: 1.0
9 | optimizer:
10 | lr: 0.01
11 | lr_init: 0.00001
12 | lr_end: 0.0001
13 | warmup_steps: 2000
14 | decay_steps: 40000
15 | data:
16 | max_eval_sample: 8
17 | max_sec: 40
18 | num_workers: 1
19 | pad_val: 1024 # same with EOS in model
20 | model:
21 | saving_path: "ckpt/"
22 | resume_checkpoint: null
23 | vocoder_config_path: "quantizer/new_ckpt/config.json"
24 | vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
25 | datadir: "/home/liweiche/GigaSpeech/wavs"
26 | metapath: "/home/liweiche/GigaSpeech/train2.json"
27 | val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
28 | sampledir: "logs/"
29 | pretrained_path: null
30 | lr: 0.0001
31 | batch_size: 200.0
32 | train_bucket_size: 8192
33 | training_step: 800000
34 | optim_flat_percent: 0.0
35 | warmup_step: 50
36 | adam_beta1: 0.9
37 | adam_beta2: 0.98
38 | ffd_size: 3072
39 | hidden_size: 768
40 | enc_nlayers: 6
41 | dec_nlayers: 6
42 | nheads: 12
43 | ar_layer: 4
44 | ar_ffd_size: 1024
45 | ar_hidden_size: 256
46 | ar_nheads: 4
47 | aligner_softmax_temp: 1.0
48 | layer_norm_eps: 0.00001
49 | speaker_embed_dropout: 0.05
50 | label_smoothing: 0.0
51 | val_check_interval: 5000
52 | check_val_every_n_epoch: 1
53 | precision: "fp16"
54 | nworkers: 16
55 | distributed: true
56 | accelerator: "ddp"
57 | version: null
58 | accumulate_grad_batches: 1
59 | use_repetition_token: true
60 | use_repetition_gating: false
61 | repetition_penalty: 1.0
62 | sampling_temperature: 1.0
63 | top_k: -1
64 | min_top_k: 3
65 | top_p: 0.8
66 | sample_num: 4
67 | length_penalty_max_length: 15000
68 | length_penalty_max_prob: 0.95
69 | max_input_length: 2048
70 | max_output_length: 2000
71 | sample_rate: 16000
72 | n_codes: 1024
73 | n_cluster_groups: 1
74 | phone_context_window: 4
75 | phoneset_size: 1000
76 | inference:
77 | top_k: 5
78 |
--------------------------------------------------------------------------------
/src/configs/sovits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 100,
4 | "eval_interval": 500,
5 | "seed": 1234,
6 | "epochs": 100,
7 | "learning_rate": 0.0001,
8 | "betas": [
9 | 0.8,
10 | 0.99
11 | ],
12 | "eps": 1e-09,
13 | "batch_size": 32,
14 | "fp16_run": true,
15 | "lr_decay": 0.999875,
16 | "segment_size": 20480,
17 | "init_lr_ratio": 1,
18 | "warmup_epochs": 0,
19 | "c_mel": 45,
20 | "c_kl": 1.0,
21 | "text_low_lr_rate": 0.4
22 | },
23 | "data": {
24 | "max_wav_value": 32768.0,
25 | "sampling_rate": 32000,
26 | "filter_length": 2048,
27 | "hop_length": 640,
28 | "win_length": 2048,
29 | "n_mel_channels": 128,
30 | "mel_fmin": 0.0,
31 | "mel_fmax": null,
32 | "add_blank": true,
33 | "n_speakers": 300,
34 | "cleaned_text": true
35 | },
36 | "model": {
37 | "inter_channels": 192,
38 | "hidden_channels": 192,
39 | "filter_channels": 768,
40 | "n_heads": 2,
41 | "n_layers": 6,
42 | "kernel_size": 3,
43 | "p_dropout": 0.1,
44 | "resblock": "1",
45 | "resblock_kernel_sizes": [
46 | 3,
47 | 7,
48 | 11
49 | ],
50 | "resblock_dilation_sizes": [
51 | [
52 | 1,
53 | 3,
54 | 5
55 | ],
56 | [
57 | 1,
58 | 3,
59 | 5
60 | ],
61 | [
62 | 1,
63 | 3,
64 | 5
65 | ]
66 | ],
67 | "upsample_rates": [
68 | 10,
69 | 8,
70 | 2,
71 | 2,
72 | 2
73 | ],
74 | "upsample_initial_channel": 512,
75 | "upsample_kernel_sizes": [
76 | 16,
77 | 16,
78 | 8,
79 | 2,
80 | 2
81 | ],
82 | "n_layers_q": 3,
83 | "use_spectral_norm": false,
84 | "gin_channels": 512,
85 | "semantic_frame_rate": "25hz",
86 | "freeze_quantizer": true
87 | },
88 | "content_module": "cnhubert"
89 | }
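Note (not part of the file above): two values in this config are linked. The vocoder's upsample_rates must multiply back out to the STFT hop_length, and hop_length relative to sampling_rate fixes the feature frame rate. A minimal sketch of that check, assuming the JSON is loaded from the repository root:

```python
import json
import math

# Load the config shown above (path assumes the repository root as working directory).
with open("src/configs/sovits.json", encoding="utf-8") as f:
    cfg = json.load(f)

hop = cfg["data"]["hop_length"]            # 640 samples per frame
ups = cfg["model"]["upsample_rates"]       # [10, 8, 2, 2, 2]
assert math.prod(ups) == hop               # 10*8*2*2*2 == 640: one latent frame upsamples to one hop
print(cfg["data"]["sampling_rate"] / hop)  # 50.0 -> 50 feature frames per second at 32 kHz
```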
--------------------------------------------------------------------------------
/src/configs/train.yaml:
--------------------------------------------------------------------------------
1 | gpu:
2 | n_card: 1
3 | n_process_per_card: 2
4 | io:
5 | text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
6 | save_every_n_epoch: 1
7 | precision: 16-mixed
8 | gradient_clip: 1.0
9 | optimizer:
10 | lr: 0.01
11 | lr_init: 0.00001
12 | lr_end: 0.0001
13 | warmup_steps: 2000
14 | decay_steps: 40000
15 | data:
16 | max_eval_sample: 8
17 | max_sec: 54
18 | num_workers: 1
19 | pad_val: 1024 # same as EOS in the model
20 | model:
21 | vocab_size: 1025
22 | phoneme_vocab_size: 512
23 | embedding_dim: 512
24 | hidden_dim: 512
25 | head: 16
26 | linear_units: 2048
27 | n_layer: 24
28 | dropout: 0
29 | EOS: 1024
30 | random_bert: 0
31 | inference:
32 | top_k: 5
33 |
--------------------------------------------------------------------------------
/src/feature_extractor/__init__.py:
--------------------------------------------------------------------------------
1 | from . import cnhubert, whisper_enc
2 |
3 | content_module_map = {
4 | 'cnhubert': cnhubert,
5 | 'whisper': whisper_enc
6 | }
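Note: the keys of content_module_map line up with the "content_module" field of src/configs/sovits.json shown earlier ("cnhubert"). A minimal dispatch sketch, assuming src/ is on sys.path; both modules expose get_model() and get_content():

```python
from feature_extractor import content_module_map  # assumes src/ is on sys.path

content_module = content_module_map["cnhubert"]   # value of "content_module" in sovits.json
ssl_model = content_module.get_model()            # cnhubert.get_model(); whisper_enc.get_model() works the same way
```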
--------------------------------------------------------------------------------
/src/feature_extractor/cnhubert.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import librosa
4 | import torch
5 | import torch.nn.functional as F
6 | import soundfile as sf
7 | import logging
8 |
9 | # logging.getLogger("numba").setLevel(logging.INFO)
10 |
11 | from transformers import (
12 | Wav2Vec2FeatureExtractor,
13 | HubertModel,
14 | )
15 |
16 | import utils
17 | import torch.nn as nn
18 |
19 |
20 | class CNHubert(nn.Module):
21 | def __init__(self, cnhubert_base_path='pretrained_models/chinese-hubert-base'):
22 | super().__init__()
23 | self.model = HubertModel.from_pretrained(cnhubert_base_path)
24 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
25 | cnhubert_base_path
26 | )
27 |
28 | def forward(self, x):
29 | input_values = self.feature_extractor(
30 | x, return_tensors="pt", sampling_rate=16000
31 | ).input_values.to(x.device)
32 | feats = self.model(input_values)["last_hidden_state"]
33 | return feats
34 |
35 |
36 | # class CNHubertLarge(nn.Module):
37 | # def __init__(self):
38 | # super().__init__()
39 | # self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
40 | # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
41 | # def forward(self, x):
42 | # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
43 | # feats = self.model(input_values)["last_hidden_state"]
44 | # return feats
45 | #
46 | # class CVec(nn.Module):
47 | # def __init__(self):
48 | # super().__init__()
49 | # self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
50 | # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
51 | # def forward(self, x):
52 | # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
53 | # feats = self.model(input_values)["last_hidden_state"]
54 | # return feats
55 | #
56 | # class cnw2v2base(nn.Module):
57 | # def __init__(self):
58 | # super().__init__()
59 | # self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
60 | # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
61 | # def forward(self, x):
62 | # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
63 | # feats = self.model(input_values)["last_hidden_state"]
64 | # return feats
65 |
66 |
67 | def get_model():
68 | model = CNHubert()
69 | model.eval()
70 | return model
71 |
72 |
73 | # def get_large_model():
74 | # model = CNHubertLarge()
75 | # model.eval()
76 | # return model
77 | #
78 | # def get_model_cvec():
79 | # model = CVec()
80 | # model.eval()
81 | # return model
82 | #
83 | # def get_model_cnw2v2base():
84 | # model = cnw2v2base()
85 | # model.eval()
86 | # return model
87 |
88 |
89 | def get_content(hmodel, wav_16k_tensor):
90 | with torch.no_grad():
91 | feats = hmodel(wav_16k_tensor)
92 | return feats.transpose(1, 2)
93 |
94 |
95 | if __name__ == "__main__":
96 | model = get_model()
97 | src_path = "/Users/Shared/原音频2.wav"
98 | wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
99 | model = model
100 | wav_16k_tensor = wav_16k_tensor
101 | feats = get_content(model, wav_16k_tensor)
102 | print(feats.shape)
103 |
--------------------------------------------------------------------------------
/src/feature_extractor/whisper_enc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_model():
5 | import whisper
6 |
7 | model = whisper.load_model("small", device="cpu")
8 |
9 | return model.encoder
10 |
11 |
12 | def get_content(model=None, wav_16k_tensor=None):
13 | from whisper import log_mel_spectrogram, pad_or_trim
14 |
15 | dev = next(model.parameters()).device
16 | mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
17 | # if torch.cuda.is_available():
18 | # mel = mel.to(torch.float16)
19 | feature_len = mel.shape[-1] // 2
20 | assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
21 | with torch.no_grad():
22 | feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
23 | :1, :feature_len, :
24 | ].transpose(1, 2)
25 | return feature
26 |
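Note: whisper's log_mel_spectrogram uses a 10 ms hop at 16 kHz, so the [:, :3000] slice and the assertion cap the input at roughly 30 seconds, and feature_len = mel.shape[-1] // 2 accounts for the encoder's 2x temporal downsampling. A small sanity check, assuming the openai-whisper package imported above:

```python
import torch
import whisper

wav = torch.zeros(16000 * 10)             # 10 seconds of 16 kHz audio
mel = whisper.log_mel_spectrogram(wav)    # 10 ms hop -> 1000 mel frames
print(mel.shape[-1], mel.shape[-1] // 2)  # 1000 mel frames, 500 encoder frames
```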
--------------------------------------------------------------------------------
/src/inference/.ipynb_checkpoints/inference-checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import soundfile as sf
4 | import json
5 |
6 | from infer_tool import TTSInference
7 |
8 |
9 | def generate_audio(sovits_weights,
10 | gpt_weights,
11 | input_folder,
12 | output_folder,
13 | ref_wav_path,
14 | prompt_text,
15 | prompt_language,
16 | text,
17 | text_language,
18 | how_to_cut=None,
19 | save=True
20 | ):
21 | tts = TTSInference(sovits_weights=sovits_weights, gpt_weights=gpt_weights)
22 |
23 | infer_dict = {
24 | 'ref_wav_path': os.path.join(input_folder, ref_wav_path),
25 | 'prompt_text': prompt_text,
26 | 'prompt_language': prompt_language,
27 | 'text': text,
28 | 'text_language': text_language,
29 | 'how_to_cut': how_to_cut
30 | }
31 |
32 | audio_generator = tts.infer(**infer_dict)
33 |
34 | sr, audio = next(audio_generator)
35 |
36 | if save:
37 | ref_wav_name = ''.join(os.path.basename(ref_wav_path).split('.')[:-1])
38 | output_path = os.path.join(output_folder, f"{ref_wav_name}_{text[:6]}.wav")
39 | sf.write(output_path, audio, sr)
40 | print(f"Audio saved to {output_path}")
41 | else:
42 | return sr, audio
43 |
44 |
45 | def process_batch(sovits_weights, gpt_weights, input_folder, output_folder, parameters_file):
46 | _, file_extension = os.path.splitext(parameters_file)
47 |
48 | if file_extension.lower() == '.json':
49 | with open(parameters_file, 'r', encoding='utf-8') as file:
50 | parameters = json.load(file)
51 | for param in parameters:
52 | generate_audio(sovits_weights,
53 | gpt_weights,
54 | input_folder,
55 | output_folder,
56 | param['ref_wav_path'],
57 | param['prompt_text'],
58 | param['prompt_language'],
59 | param['text'],
60 | param['text_language'],
61 | param['how_to_cut']
62 | )
63 | elif file_extension.lower() == '.txt':
64 | with open(parameters_file, 'r', encoding='utf-8') as file:
65 | for line in file:
66 | ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut = line.strip().split('|')
67 | generate_audio(sovits_weights,
68 | gpt_weights,
69 | input_folder,
70 | output_folder,
71 | ref_wav_path,
72 | prompt_text,
73 | prompt_language,
74 | text,
75 | text_language,
76 | how_to_cut
77 | )
78 |
79 |
80 | if __name__ == "__main__":
81 | parser = argparse.ArgumentParser(description="Batch Run TTS Inference")
82 | parser.add_argument("--sovits_weights", required=True, help="Path to sovits weights file")
83 | parser.add_argument("--gpt_weights", required=True, help="Path to gpt weights file")
84 | parser.add_argument("--input_folder", type=str, default='input_audio', help="Folder of the input audio files")
85 | parser.add_argument("--output_folder", type=str, default='output_audio', help="Folder to save the output audio files")
86 | parser.add_argument("--parameters_file", type=str, default='inference_parameters.txt', help="File containing parameters for batch processing")
87 |
88 | args = parser.parse_args()
89 |
90 | if not os.path.exists(args.output_folder):
91 | os.makedirs(args.output_folder)
92 |
93 | process_batch(args.sovits_weights,
94 | args.gpt_weights,
95 | args.input_folder,
96 | args.output_folder,
97 | args.parameters_file)
98 |
--------------------------------------------------------------------------------
/src/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .infer_tool import InferenceModule, TTSInference
--------------------------------------------------------------------------------
/src/inference/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import soundfile as sf
4 | import json
5 |
6 | from infer_tool import TTSInference
7 |
8 |
9 | def generate_audio(sovits_weights,
10 | gpt_weights,
11 | input_folder,
12 | output_folder,
13 | ref_wav_path,
14 | prompt_text,
15 | prompt_language,
16 | text,
17 | text_language,
18 | how_to_cut=None,
19 | save=True
20 | ):
21 | tts = TTSInference(sovits_weights=sovits_weights, gpt_weights=gpt_weights)
22 |
23 | infer_dict = {
24 | 'ref_wav_path': os.path.join(input_folder, ref_wav_path),
25 | 'prompt_text': prompt_text,
26 | 'prompt_language': prompt_language,
27 | 'text': text,
28 | 'text_language': text_language,
29 | 'how_to_cut': how_to_cut
30 | }
31 |
32 | audio_generator = tts.infer(**infer_dict)
33 |
34 | sr, audio = next(audio_generator)
35 |
36 | if save:
37 | ref_wav_name = ''.join(os.path.basename(ref_wav_path).split('.')[:-1])
38 | output_path = os.path.join(output_folder, f"{ref_wav_name}_{text[:6]}.wav")
39 | sf.write(output_path, audio, sr)
40 | print(f"Audio saved to {output_path}")
41 | else:
42 | return sr, audio
43 |
44 |
45 | def process_batch(sovits_weights, gpt_weights, input_folder, output_folder, parameters_file):
46 | _, file_extension = os.path.splitext(parameters_file)
47 |
48 | if file_extension.lower() == '.json':
49 | with open(parameters_file, 'r', encoding='utf-8') as file:
50 | parameters = json.load(file)
51 | for param in parameters:
52 | generate_audio(sovits_weights,
53 | gpt_weights,
54 | input_folder,
55 | output_folder,
56 | param['ref_wav_path'],
57 | param['prompt_text'],
58 | param['prompt_language'],
59 | param['text'],
60 | param['text_language'],
61 | param['how_to_cut']
62 | )
63 | elif file_extension.lower() == '.txt':
64 | with open(parameters_file, 'r', encoding='utf-8') as file:
65 | for line in file:
66 | ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut = line.strip().split('|')
67 | generate_audio(sovits_weights,
68 | gpt_weights,
69 | input_folder,
70 | output_folder,
71 | ref_wav_path,
72 | prompt_text,
73 | prompt_language,
74 | text,
75 | text_language,
76 | how_to_cut
77 | )
78 |
79 |
80 | if __name__ == "__main__":
81 | parser = argparse.ArgumentParser(description="Batch Run TTS Inference")
82 | parser.add_argument("--sovits_weights", required=True, help="Path to sovits weights file")
83 | parser.add_argument("--gpt_weights", required=True, help="Path to gpt weights file")
84 | parser.add_argument("--input_folder", type=str, default='input_audio', help="Folder of the input audio files")
85 | parser.add_argument("--output_folder", type=str, default='output_audio', help="Folder to save the output audio files")
86 | parser.add_argument("--parameters_file", type=str, default='inference_parameters.txt', help="File containing parameters for batch processing")
87 |
88 | args = parser.parse_args()
89 |
90 | if not os.path.exists(args.output_folder):
91 | os.makedirs(args.output_folder)
92 |
93 | process_batch(args.sovits_weights,
94 | args.gpt_weights,
95 | args.input_folder,
96 | args.output_folder,
97 | args.parameters_file)
98 |
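Note on the .txt parameter format parsed above: each line carries six pipe-separated fields, ref_wav_path|prompt_text|prompt_language|text|text_language|how_to_cut, where how_to_cut is presumably one of the keys of CUT_DICT in src/utils/cut.py (shown later in this dump). A hypothetical example line (file name and texts are placeholders):

    dolly.wav|Hello, this is the reference prompt.|en|今天天气不错,我们出去走走吧。|zh|凑四句一切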
--------------------------------------------------------------------------------
/src/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/module/__init__.py
--------------------------------------------------------------------------------
/src/module/losses.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | def feature_loss(fmap_r, fmap_g):
8 | loss = 0
9 | for dr, dg in zip(fmap_r, fmap_g):
10 | for rl, gl in zip(dr, dg):
11 | rl = rl.float().detach()
12 | gl = gl.float()
13 | loss += torch.mean(torch.abs(rl - gl))
14 |
15 | return loss * 2
16 |
17 |
18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19 | loss = 0
20 | r_losses = []
21 | g_losses = []
22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23 | dr = dr.float()
24 | dg = dg.float()
25 | r_loss = torch.mean((1 - dr) ** 2)
26 | g_loss = torch.mean(dg**2)
27 | loss += r_loss + g_loss
28 | r_losses.append(r_loss.item())
29 | g_losses.append(g_loss.item())
30 |
31 | return loss, r_losses, g_losses
32 |
33 |
34 | def generator_loss(disc_outputs):
35 | loss = 0
36 | gen_losses = []
37 | for dg in disc_outputs:
38 | dg = dg.float()
39 | l = torch.mean((1 - dg) ** 2)
40 | gen_losses.append(l)
41 | loss += l
42 |
43 | return loss, gen_losses
44 |
45 |
46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47 | """
48 | z_p, logs_q: [b, h, t_t]
49 | m_p, logs_p: [b, h, t_t]
50 | """
51 | z_p = z_p.float()
52 | logs_q = logs_q.float()
53 | m_p = m_p.float()
54 | logs_p = logs_p.float()
55 | z_mask = z_mask.float()
56 |
57 | kl = logs_p - logs_q - 0.5
58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
59 | kl = torch.sum(kl * z_mask)
60 | l = kl / torch.sum(z_mask)
61 | return l
62 |
63 |
64 | def mle_loss(z, m, logs, logdet, mask):
65 | l = torch.sum(logs) + 0.5 * torch.sum(
66 | torch.exp(-2 * logs) * ((z - m) ** 2)
67 | ) # neg normal likelihood w/o the constant term
68 | l = l - torch.sum(logdet) # log jacobian determinant
69 | l = l / torch.sum(
70 | torch.ones_like(z) * mask
71 | ) # averaging across batch, channel and time axes
72 | l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
73 | return l
74 |
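Note: written out, kl_loss above is the masked, length-normalized KL term between the posterior q and the prior p (diagonal Gaussians), evaluated at the posterior sample z_p; here logs_* stands for log sigma and m_p for the prior mean:

$$
\mathcal{L}_{\mathrm{KL}}
= \frac{\sum_{b,h,t} z^{\mathrm{mask}}_{b,t}\left[\log\sigma_p - \log\sigma_q - \tfrac{1}{2} + \tfrac{(z_p - m_p)^2}{2\sigma_p^{2}}\right]}
{\sum_{b,t} z^{\mathrm{mask}}_{b,t}}
$$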
--------------------------------------------------------------------------------
/src/module/mel_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import numpy as np
9 | import librosa
10 | import librosa.util as librosa_util
11 | from librosa.util import normalize, pad_center, tiny
12 | from scipy.signal import get_window
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | MAX_WAV_VALUE = 32768.0
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 | """
21 | PARAMS
22 | ------
23 | C: compression factor
24 | """
25 | return torch.log(torch.clamp(x, min=clip_val) * C)
26 |
27 |
28 | def dynamic_range_decompression_torch(x, C=1):
29 | """
30 | PARAMS
31 | ------
32 | C: compression factor used to compress
33 | """
34 | return torch.exp(x) / C
35 |
36 |
37 | def spectral_normalize_torch(magnitudes):
38 | output = dynamic_range_compression_torch(magnitudes)
39 | return output
40 |
41 |
42 | def spectral_de_normalize_torch(magnitudes):
43 | output = dynamic_range_decompression_torch(magnitudes)
44 | return output
45 |
46 |
47 | mel_basis = {}
48 | hann_window = {}
49 |
50 |
51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52 | if torch.min(y) < -1.0:
53 | print("min value is ", torch.min(y))
54 | if torch.max(y) > 1.0:
55 | print("max value is ", torch.max(y))
56 |
57 | global hann_window
58 | dtype_device = str(y.dtype) + "_" + str(y.device)
59 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
60 | if wnsize_dtype_device not in hann_window:
61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
62 | dtype=y.dtype, device=y.device
63 | )
64 |
65 | y = torch.nn.functional.pad(
66 | y.unsqueeze(1),
67 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
68 | mode="reflect",
69 | )
70 | y = y.squeeze(1)
71 | spec = torch.stft(
72 | y,
73 | n_fft,
74 | hop_length=hop_size,
75 | win_length=win_size,
76 | window=hann_window[wnsize_dtype_device],
77 | center=center,
78 | pad_mode="reflect",
79 | normalized=False,
80 | onesided=True,
81 | return_complex=False,
82 | )
83 |
84 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
85 | return spec
86 |
87 |
88 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
89 | global mel_basis
90 | dtype_device = str(spec.dtype) + "_" + str(spec.device)
91 | fmax_dtype_device = str(fmax) + "_" + dtype_device
92 | if fmax_dtype_device not in mel_basis:
93 | mel = librosa_mel_fn(
94 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
95 | )
96 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
97 | dtype=spec.dtype, device=spec.device
98 | )
99 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
100 | spec = spectral_normalize_torch(spec)
101 | return spec
102 |
103 |
104 | def mel_spectrogram_torch(
105 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
106 | ):
107 | if torch.min(y) < -1.0:
108 | print("min value is ", torch.min(y))
109 | if torch.max(y) > 1.0:
110 | print("max value is ", torch.max(y))
111 |
112 | global mel_basis, hann_window
113 | dtype_device = str(y.dtype) + "_" + str(y.device)
114 | fmax_dtype_device = str(fmax) + "_" + dtype_device
115 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
116 | if fmax_dtype_device not in mel_basis:
117 | mel = librosa_mel_fn(
118 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
119 | )
120 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
121 | dtype=y.dtype, device=y.device
122 | )
123 | if wnsize_dtype_device not in hann_window:
124 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
125 | dtype=y.dtype, device=y.device
126 | )
127 |
128 | y = torch.nn.functional.pad(
129 | y.unsqueeze(1),
130 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
131 | mode="reflect",
132 | )
133 | y = y.squeeze(1)
134 |
135 | spec = torch.stft(
136 | y,
137 | n_fft,
138 | hop_length=hop_size,
139 | win_length=win_size,
140 | window=hann_window[wnsize_dtype_device],
141 | center=center,
142 | pad_mode="reflect",
143 | normalized=False,
144 | onesided=True,
145 | return_complex=False,
146 | )
147 |
148 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
149 |
150 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
151 | spec = spectral_normalize_torch(spec)
152 |
153 | return spec
154 |
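Note: with center=False plus the explicit reflect padding of (n_fft - hop_size)/2 samples on each side, the STFT ends up with one frame per hop of input. A minimal shape check using the data values from src/configs/sovits.json (the import assumes src/ is on sys.path):

```python
import torch
from module.mel_processing import mel_spectrogram_torch  # assumes src/ is on sys.path

y = torch.rand(1, 32000) * 2 - 1   # one second of audio in [-1, 1) at 32 kHz, shape (batch, samples)
mel = mel_spectrogram_torch(y, n_fft=2048, num_mels=128, sampling_rate=32000,
                            hop_size=640, win_size=2048, fmin=0.0, fmax=None)
print(mel.shape)                   # torch.Size([1, 128, 50]) -> 50 mel frames per second
```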
--------------------------------------------------------------------------------
/src/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_phonemes import get_phonemes
2 | from .get_ssl_features import get_ssl_features
3 | from .get_semantic import get_semantic
--------------------------------------------------------------------------------
/src/preprocess/get_phonemes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4 |
5 | import traceback
6 | import torch
7 | from transformers import AutoModelForMaskedLM, AutoTokenizer
8 |
9 | from text.cleaner import clean_text
10 |
11 |
12 | def get_bert_feature(text, word2ph, tokenizer, bert_model, device):
13 | """
14 | Extract the BERT feature representation for the given text.
15 |
16 | Args:
17 | text (str): input text
18 | word2ph (list): list mapping each character to its number of phonemes
19 | tokenizer (BertTokenizer): BERT tokenizer
20 | bert_model (BertModel): BERT model
21 | device (str): device type ('cuda' or 'cpu')
22 |
23 | Returns:
24 | torch.Tensor: phoneme-level BERT features with shape (bert_hidden_size, num_phonemes)
25 | """
26 | with torch.no_grad():
27 | inputs = tokenizer(text, return_tensors="pt")
28 | for i in inputs:
29 | inputs[i] = inputs[i].to(device)
30 | res = bert_model(**inputs, output_hidden_states=True)
31 | # Take the third-to-last hidden layer (the [-3:-2] slice) and drop the [CLS]/[SEP] positions
32 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
33 |
34 | # Check that word2ph has one entry per character of the input text
35 | assert len(word2ph) == len(text)
36 |
37 | phone_level_feature = []
38 | for i in range(len(word2ph)):
39 | # Repeat each character's BERT feature once per phoneme mapped to it
40 | repeat_feature = res[i].repeat(word2ph[i], 1)
41 | phone_level_feature.append(repeat_feature)
42 |
43 | # Concatenate the per-phoneme BERT features into a single tensor
44 | phone_level_feature = torch.cat(phone_level_feature, dim=0)
45 |
46 | return phone_level_feature.T
47 |
48 | def process(data, save_dir, tokenizer, bert_model, device):
49 | """
50 | Process the input data to obtain phoneme sequences and BERT features.
51 |
52 | Args:
53 | data (list): input items, each of the form [wav_name, text, language]
54 | save_dir (str): directory where the BERT features are saved
55 | tokenizer (BertTokenizer): BERT tokenizer
56 | bert_model (BertModel): BERT model
57 | device (str): device type ('cuda' or 'cpu')
58 |
59 | Returns:
60 | list: processed items, each of the form [wav_name, phoneme sequence, word2ph, norm_text]
61 | """
62 | res = []
63 | os.makedirs(save_dir, exist_ok=True)
64 |
65 | for name, text, lan in data:
66 | try:
67 | name = os.path.basename(name)
68 | # Clean the text and get the phoneme sequence, the word-to-phoneme mapping and the normalized text
69 | phones, word2ph, norm_text = clean_text(
70 | text.replace("%", "-").replace("¥", ","), lan
71 | )
72 | path_bert = f"{save_dir}/{name}.pt"
73 |
74 | # For Chinese text, compute and save the BERT features if they are not cached yet
75 | if not os.path.exists(path_bert) and lan == "zh":
76 | bert_feature = get_bert_feature(norm_text, word2ph, tokenizer, bert_model, device)
77 | assert bert_feature.shape[-1] == len(phones)
78 | torch.save(bert_feature, path_bert)
79 | phones = " ".join(phones)
80 | res.append([name, phones, word2ph, norm_text])
81 | except:
82 | print(name, text, traceback.format_exc())
83 |
84 | return res
85 |
86 | def get_phonemes(input_txt_path: str,
87 | save_path: str,
88 | bert_pretrained_dir: str='pretrained_models/chinese-roberta-wwm-ext-large',
89 | is_half: bool=False,
90 | **kwargs) -> None:
91 | """
92 | Extract phoneme sequences and BERT features from the input annotation file.
93 |
94 | Args:
95 | input_txt_path (str): path to the input annotation text file
96 | save_path (str): directory where the results are saved
97 | bert_pretrained_dir (str, optional): path to the pretrained BERT model. Defaults to 'pretrained_models/chinese-roberta-wwm-ext-large'.
98 | is_half (bool, optional): whether to run in half precision (FP16). Defaults to False.
99 |
100 | Returns:
101 | None
102 | """
103 | os.makedirs(save_path, exist_ok=True)
104 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
105 |
106 | tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
107 | bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
108 |
109 | if is_half:
110 | bert_model = bert_model.half().to(device)
111 | else:
112 | bert_model = bert_model.to(device)
113 |
114 | bert_model.eval()
115 |
116 | todo = []
117 | with open(input_txt_path, "r", encoding="utf8") as f:
118 | lines = f.read().strip("\n").split("\n")
119 | for line in lines:
120 | try:
121 | wav_name, spk_name, language, text = line.split("|")
122 | todo.append([wav_name, text, language.lower()])
123 | except:
124 | print(line, traceback.format_exc())
125 |
126 | res = process(todo, f'{save_path}/bert_features', tokenizer, bert_model, device)
127 |
128 | opt = []
129 | for name, phones, word2ph, norm_text in res:
130 | opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
131 |
132 | with open(f"{save_path}/text2phonemes.txt", "w", encoding="utf8") as f:
133 | f.write("\n".join(opt) + "\n")
134 |
135 | print("文本转音素已完成!")
--------------------------------------------------------------------------------
/src/preprocess/get_semantic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4 |
5 | import traceback
6 | import torch
7 | from module.models import SynthesizerTrn
8 | from utils.utils import get_hparams_from_file
9 |
10 |
11 | def name2go(wav_name: str,
12 | cnhubert_features_path: str,
13 | model: SynthesizerTrn,
14 | device: str,
15 | is_half: bool=False):
16 | """
17 | Extract the semantic tokens for one audio file, identified by its file name.
18 |
19 | Args:
20 | wav_name (str): name of the audio file.
21 | cnhubert_features_path (str): directory holding the cnhubert feature files.
22 | model (SynthesizerTrn): model instance used to extract the semantic tokens.
23 | device (str): device to run the model on (e.g. 'cpu' or 'cuda:0').
24 | is_half (bool, optional): if True, use half-precision floats to speed up computation. Defaults to False.
25 |
26 | Returns:
27 | str: the audio file name and its semantic tokens, separated by a tab ('\t').
28 | """
29 | if not os.path.exists(cnhubert_features_path):
30 | return
31 | ssl_content = torch.load(f"{cnhubert_features_path}/{wav_name}.pt", map_location="cpu")
32 | if is_half:
33 | ssl_content = ssl_content.half().to(device)
34 | else:
35 | ssl_content = ssl_content.to(device)
36 | codes = model.extract_latent(ssl_content)
37 | semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
38 | return f"{wav_name}\t{semantic}"
39 |
40 |
41 | def get_semantic(input_txt_path: str,
42 | save_path: str,
43 | G_path: str='pretrained_models/sovits_weights/pretrained/s2G488k.pth',
44 | config_path: str='src/configs/sovits.json',
45 | is_half: bool=False,
46 | **kwargs
47 | ):
48 | """
49 | Read the audio file names from the annotation file and extract their semantic tokens.
50 |
51 | Args:
52 | input_txt_path (str): path to the input annotation file listing the audio files.
53 | save_path (str): directory where the results are saved.
54 | G_path (str, optional): path to the pretrained SoVITS generator weights. Defaults to the path under 'sovits_weights'.
55 | config_path (str, optional): path to the config file. Defaults to sovits.json under 'src/configs'.
56 | is_half (bool, optional): whether to use half precision. Defaults to False.
57 |
58 | Any additional keyword arguments are accepted via kwargs for future extension.
59 | """
60 | hubert_features_path = os.path.join(save_path, "cnhubert_features")
61 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
62 |
63 | hps = get_hparams_from_file(config_path)
64 | model = SynthesizerTrn(
65 | hps.data.filter_length // 2 + 1,
66 | hps.train.segment_size // hps.data.hop_length,
67 | n_speakers=hps.data.n_speakers,
68 | **hps.model
69 | )
70 | # Load the model weights
71 | print(
72 | model.load_state_dict(
73 | torch.load(G_path, map_location=device)["weight"], strict=False
74 | )
75 | )
76 |
77 | if is_half:
78 | model = model.half()
79 | model = model.to(device)
80 | model.eval()
81 |
82 | with open(input_txt_path, "r", encoding="utf8") as f:
83 | lines = f.read().strip("\n").split("\n")
84 |
85 | semantic_results = []
86 | for line in lines:
87 | try:
88 | wav_name, _, _, _ = line.split("|")
89 | wav_name = os.path.basename(wav_name)
90 | semantic = name2go(wav_name, hubert_features_path, model, device, is_half)
91 | semantic_results.append(semantic)
92 | except Exception as e:
93 | # Report the error for this line
94 | print(line, str(e))
95 |
96 | with open(f"{save_path}/name2semantic.tsv", "w", encoding="utf8") as f:
97 | f.write("\n".join(semantic_results))
98 |
99 | print('语义特征提取完成!')
--------------------------------------------------------------------------------
/src/preprocess/get_ssl_features.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4 |
5 | import traceback
6 | import numpy as np
7 | import librosa
8 | from scipy.io import wavfile
9 | import torch
10 |
11 | from utils.utils import load_audio
12 | from feature_extractor.cnhubert import CNHubert
13 |
14 |
15 | MAXX = 0.95
16 | ALPHA = 0.5
17 |
18 | def name2go(wav_name: str,
19 | wav_path: str,
20 | features_path: str,
21 | wav32k_path: str,
22 | model: CNHubert,
23 | device: str,
24 | is_half: bool=False):
25 | """
26 | Process one audio file and extract its SSL features.
27 |
28 | Args:
29 | wav_name (str): name of the audio file.
30 | wav_path (str): full path to the audio file.
31 | features_path (str): directory where the extracted feature files are saved.
32 | wav32k_path (str): directory where the processed 32 kHz audio files are saved.
33 | model (CNHubert): pretrained CNHubert model.
34 | device (str): device to run the model on, e.g. 'cuda:0' or 'cpu'.
35 | is_half (bool): whether to cast the model and data to half precision to save memory. Defaults to False.
36 |
37 | Returns:
38 | None or str: the audio file name if something went wrong; otherwise nothing.
39 | """
40 | # Load the audio at a 32 kHz sampling rate
41 | tmp_audio = load_audio(wav_path, 32000)
42 | tmp_max = np.abs(tmp_audio).max()
43 | # If the peak absolute amplitude exceeds 2.2, report it and skip this file
44 | if tmp_max > 2.2:
45 | print("%s-filtered,%s" % (wav_name, tmp_max))
46 | return
47 | # Rescale the audio before resampling
48 | tmp_audio32 = (tmp_audio / tmp_max * (MAXX * ALPHA * 32768)) + ((1 - ALPHA) * 32768) * tmp_audio
49 | tmp_audio32b = (tmp_audio / tmp_max * (MAXX * ALPHA * 1145.14)) + ((1 - ALPHA) * 1145.14) * tmp_audio
50 | tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # resample to 16 kHz
51 |
52 | tensor_wav16 = torch.from_numpy(tmp_audio)
53 | if is_half:
54 | tensor_wav16 = tensor_wav16.half().to(device)
55 | else:
56 | tensor_wav16 = tensor_wav16.to(device)
57 | # Extract SSL features with the CNHubert model
58 | ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()
59 | # Check the extracted features for NaN values
60 | if np.isnan(ssl.detach().numpy()).sum() != 0:
61 | print("nan filtered:%s" % wav_name)
62 | return wav_name
63 | # Save the processed 32 kHz audio and the feature file
64 | wavfile.write(os.path.join(wav32k_path, wav_name), 32000, tmp_audio32.astype("int16"))
65 | torch.save(ssl, f"{features_path}/{wav_name}.pt")
66 |
67 |
68 | def get_ssl_features(input_txt_path: str,
69 | save_path: str,
70 | input_wav_path: str = None,
71 | cnhubert_path: str = 'pretrained_models/chinese-hubert-base',
72 | is_half: bool = False,
73 | **kwargs):
74 | """
75 | Read the list of audio files from the annotation file, then process each file and extract its features.
76 |
77 | Args:
78 | input_txt_path (str): path to the annotation text file describing the audio files.
79 | save_path (str): root directory for the processed audio and the feature files.
80 | input_wav_path (str): input directory of the audio files. If not None, audio is read from this directory. Defaults to None.
81 | cnhubert_path (str): path to the CNHubert model. Defaults to the pretrained model path.
82 | is_half (bool): whether to cast the model and data to half precision to save memory. Defaults to False.
83 |
84 | Returns:
85 | None
86 | """
87 | # Create the output directories for the 32 kHz audio and the feature files
88 | wav32k_path = os.path.join(save_path, "wav32k")
89 | features_path = os.path.join(save_path, "cnhubert_features")
90 | os.makedirs(wav32k_path, exist_ok=True)
91 | os.makedirs(features_path, exist_ok=True)
92 |
93 | # Load the model and choose the device
94 | model = CNHubert(cnhubert_path)
95 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
96 | if is_half:
97 | model = model.half()
98 | model = model.to(device)
99 | model.eval()
100 |
101 | # Read the audio file list and process each entry
102 | with open(input_txt_path, "r", encoding="utf8") as f:
103 | lines = f.read().strip("\n").split("\n")
104 |
105 | nan_fails = []
106 | for line in lines:
107 | try:
108 | wav_name, _, _, _ = line.split("|")
109 | if input_wav_path:
110 | wav_name = os.path.basename(wav_name)
111 | wav_path = os.path.join(input_wav_path, wav_name)
112 | else:
113 | wav_path = wav_name
114 | if not os.path.exists(wav_path):
115 | print(f"{wav_path} does not exist")
116 | continue
117 | failed_name = name2go(wav_name, wav_path, features_path, wav32k_path, model, device, is_half)
118 | if failed_name:
119 | nan_fails.append((failed_name, wav_path))
120 | except:
121 | print(line, traceback.format_exc())
122 |
123 | # If any files produced NaN features, retry them in full precision
124 | if len(nan_fails) > 0 and is_half:
125 | is_half = False
126 | model = model.float()
127 | for wav_name, wav_path in nan_fails:
128 | try:
129 | name2go(wav_name, wav_path, features_path, wav32k_path, model, device, is_half)
130 | except:
131 | print(wav_name, traceback.format_exc())
132 |
133 | print('CnHubert特征提取已完成!')
134 |
--------------------------------------------------------------------------------
/src/preprocess/process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from get_phonemes import get_phonemes
4 | from get_ssl_features import get_ssl_features
5 | from get_semantic import get_semantic
6 |
7 | def main(data_dir="../../data/", log_dir="logs/", name="dolly"):
8 | params = {
9 | "input_txt_path": os.path.join(data_dir, f"{name}/{name}.txt"),
10 | "save_path": f"{log_dir}/{name}",
11 | "input_wav_path": os.path.join(data_dir, f"{name}/vocal/")
12 | }
13 | get_phonemes(**params)
14 | get_ssl_features(**params)
15 | get_semantic(**params)
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--data_dir", type=str, default="../../data/", help="Directory to save data")
20 | parser.add_argument("--log_dir", type=str, default="logs/", help="Directory to save logs")
21 | parser.add_argument("--name", type=str, default="dolly", help="Name of the logs")
22 | args = parser.parse_args()
23 |
24 | main(args.data_dir, args.log_dir, args.name)
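Note: with the defaults above, a hypothetical run from src/preprocess/ is `python process.py --data_dir ../../data/ --log_dir logs/ --name dolly`. It expects the annotation file ../../data/dolly/dolly.txt and the reference audio under ../../data/dolly/vocal/, and writes bert_features/, wav32k/, cnhubert_features/, text2phonemes.txt and name2semantic.tsv under logs/dolly/.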
--------------------------------------------------------------------------------
/src/text/__init__.py:
--------------------------------------------------------------------------------
1 | from text.symbols import *
2 |
3 |
4 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5 |
6 | def cleaned_text_to_sequence(cleaned_text):
7 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
8 | Args:
9 | cleaned_text: sequence of cleaned phoneme symbols to convert
10 | Returns:
11 | List of integers corresponding to the symbols in the text
12 | '''
13 | phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
14 | return phones
15 |
16 |
--------------------------------------------------------------------------------
/src/text/cleaner.py:
--------------------------------------------------------------------------------
1 | from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
2 |
3 | language_module_map = {"zh": chinese, "ja": japanese, "en": english}
4 | special = [
5 | # ("%", "zh", "SP"),
6 | ("¥", "zh", "SP2"),
7 | ("^", "zh", "SP3"),
8 | # ('@', 'zh', "SP4")  # dropped this one; keep it consistent with version 2
9 | ]
10 |
11 |
12 | def clean_text(text, language):
13 | if language not in language_module_map:
14 | language = "en"
15 | text = " "
16 | for special_s, special_l, target_symbol in special:
17 | if special_s in text and language == special_l:
18 | return clean_special(text, language, special_s, target_symbol)
19 | language_module = language_module_map[language]
20 | norm_text = language_module.text_normalize(text)
21 | if language == "zh":
22 | phones, word2ph = language_module.g2p(norm_text)
23 | assert len(phones) == sum(word2ph)
24 | assert len(norm_text) == len(word2ph)
25 | else:
26 | phones = language_module.g2p(norm_text)
27 | word2ph = None
28 |
29 | for ph in phones:
30 | assert ph in symbols
31 | return phones, word2ph, norm_text
32 |
33 |
34 | def clean_special(text, language, special_s, target_symbol):
35 | """
36 | Handle the special silence-segment (SP) symbols.
37 | """
38 | text = text.replace(special_s, ",")
39 | language_module = language_module_map[language]
40 | norm_text = language_module.text_normalize(text)
41 | phones = language_module.g2p(norm_text)
42 | new_ph = []
43 | for ph in phones[0]:
44 | assert ph in symbols
45 | if ph == ",":
46 | new_ph.append(target_symbol)
47 | else:
48 | new_ph.append(ph)
49 | return new_ph, phones[1], norm_text
50 |
51 |
52 | def text_to_sequence(text, language):
53 | phones, word2ph, norm_text = clean_text(text, language)
54 | return cleaned_text_to_sequence(phones)
55 |
56 |
57 | if __name__ == "__main__":
58 | print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))
59 |
--------------------------------------------------------------------------------
/src/text/engdict-hot.rep:
--------------------------------------------------------------------------------
1 | CHATGPT CH AE1 T JH IY1 P IY1 T IY1
2 | LLM EH1 L EH1 L EH1 M
--------------------------------------------------------------------------------
/src/text/engdict_cache.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/text/engdict_cache.pickle
--------------------------------------------------------------------------------
/src/text/namedict_cache.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZaVang/GPT-SoVits/2fad70072ea91dae560e3da894c719676496684a/src/text/namedict_cache.pickle
--------------------------------------------------------------------------------
/src/text/zh_normalization/README.md:
--------------------------------------------------------------------------------
1 | ## Supported NSW (Non-Standard-Word) Normalization
2 |
3 | |NSW type|raw|normalized|
4 | |:--|:-|:-|
5 | |serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九|
6 | |cardinal|这块黄金重达324.75克<br>我们班的最高总分为583分|这块黄金重达三百二十四点七五克<br>我们班的最高总分为五百八十三分|
7 | |numeric range |12\~23<br>-1.5\~2|十二到二十三<br>负一点五到二|
8 | |date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日|
9 | |time|等会请在12:05请通知我|等会请在十二点零五分请通知我|
10 | |temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度|
11 | |fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票|
12 | |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨|
13 | |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万|
14 | |telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>这是手机八六一八五四四一三九一二一|
15 | ## References
16 | [Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files)
17 |
--------------------------------------------------------------------------------
/src/text/zh_normalization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from text.zh_normalization.text_normlization import *
15 |
--------------------------------------------------------------------------------
/src/text/zh_normalization/chronology.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 |
16 | from .num import DIGITS
17 | from .num import num2str
18 | from .num import verbalize_cardinal
19 | from .num import verbalize_digit
20 |
21 |
22 | def _time_num2str(num_string: str) -> str:
23 | """A special case for verbalizing number in time."""
24 | result = num2str(num_string.lstrip('0'))
25 | if num_string.startswith('0'):
26 | result = DIGITS['0'] + result
27 | return result
28 |
29 |
30 | # Time-of-day expression
31 | RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
32 | r':([0-5][0-9])'
33 | r'(:([0-5][0-9]))?')
34 |
35 | # Time range, e.g. 8:30-12:30
36 | RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
37 | r':([0-5][0-9])'
38 | r'(:([0-5][0-9]))?'
39 | r'(~|-)'
40 | r'([0-1]?[0-9]|2[0-3])'
41 | r':([0-5][0-9])'
42 | r'(:([0-5][0-9]))?')
43 |
44 |
45 | def replace_time(match) -> str:
46 | """
47 | Args:
48 | match (re.Match)
49 | Returns:
50 | str
51 | """
52 |
53 | is_range = len(match.groups()) > 5
54 |
55 | hour = match.group(1)
56 | minute = match.group(2)
57 | second = match.group(4)
58 |
59 | if is_range:
60 | hour_2 = match.group(6)
61 | minute_2 = match.group(7)
62 | second_2 = match.group(9)
63 |
64 | result = f"{num2str(hour)}点"
65 | if minute.lstrip('0'):
66 | if int(minute) == 30:
67 | result += "半"
68 | else:
69 | result += f"{_time_num2str(minute)}分"
70 | if second and second.lstrip('0'):
71 | result += f"{_time_num2str(second)}秒"
72 |
73 | if is_range:
74 | result += "至"
75 | result += f"{num2str(hour_2)}点"
76 | if minute_2.lstrip('0'):
77 | if int(minute_2) == 30:
78 | result += "半"
79 | else:
80 | result += f"{_time_num2str(minute_2)}分"
81 | if second_2 and second_2.lstrip('0'):
82 | result += f"{_time_num2str(second_2)}秒"
83 |
84 | return result
85 |
86 |
87 | RE_DATE = re.compile(r'(\d{4}|\d{2})年'
88 | r'((0?[1-9]|1[0-2])月)?'
89 | r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
90 |
91 |
92 | def replace_date(match) -> str:
93 | """
94 | Args:
95 | match (re.Match)
96 | Returns:
97 | str
98 | """
99 | year = match.group(1)
100 | month = match.group(3)
101 | day = match.group(5)
102 | result = ""
103 | if year:
104 | result += f"{verbalize_digit(year)}年"
105 | if month:
106 | result += f"{verbalize_cardinal(month)}月"
107 | if day:
108 | result += f"{verbalize_cardinal(day)}{match.group(9)}"
109 | return result
110 |
111 |
112 | # YYYY/MM/DD or YYYY-MM-DD dates separated by / or -
113 | RE_DATE2 = re.compile(
114 | r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
115 |
116 |
117 | def replace_date2(match) -> str:
118 | """
119 | Args:
120 | match (re.Match)
121 | Returns:
122 | str
123 | """
124 | year = match.group(1)
125 | month = match.group(3)
126 | day = match.group(4)
127 | result = ""
128 | if year:
129 | result += f"{verbalize_digit(year)}年"
130 | if month:
131 | result += f"{verbalize_cardinal(month)}月"
132 | if day:
133 | result += f"{verbalize_cardinal(day)}日"
134 | return result
135 |
--------------------------------------------------------------------------------
/src/text/zh_normalization/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 | import string
16 |
17 | from pypinyin.constants import SUPPORT_UCS4
18 |
19 | # Full-width / half-width conversion
20 | # Full-width -> half-width mapping for ASCII letters (num: 52)
21 | F2H_ASCII_LETTERS = {
22 | ord(char) + 65248: ord(char)
23 | for char in string.ascii_letters
24 | }
25 |
26 | # Half-width -> full-width mapping for ASCII letters
27 | H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
28 |
29 | # Full-width -> half-width mapping for digits (num: 10)
30 | F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits}
31 | # Half-width -> full-width mapping for digits
32 | H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}
33 |
34 | # Full-width -> half-width mapping for punctuation (num: 32)
35 | F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
36 | # Half-width -> full-width mapping for punctuation
37 | H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
38 |
39 | # Space (num: 1)
40 | F2H_SPACE = {'\u3000': ' '}
41 | H2F_SPACE = {' ': '\u3000'}
42 |
43 | # Characters that are not Chinese characters with a pinyin reading; used to extract NSW (non-standard word) spans
44 | if SUPPORT_UCS4:
45 | RE_NSW = re.compile(r'(?:[^'
46 | r'\u3007' # 〇
47 | r'\u3400-\u4dbf' # CJK Extension A: [3400-4DBF]
48 | r'\u4e00-\u9fff' # CJK Unified Ideographs: [4E00-9FFF]
49 | r'\uf900-\ufaff' # CJK Compatibility Ideographs: [F900-FAFF]
50 | r'\U00020000-\U0002A6DF' # CJK Extension B: [20000-2A6DF]
51 | r'\U0002A703-\U0002B73F' # CJK Extension C: [2A700-2B73F]
52 | r'\U0002B740-\U0002B81D' # CJK Extension D: [2B740-2B81D]
53 | r'\U0002F80A-\U0002FA1F' # CJK Compatibility Supplement: [2F800-2FA1F]
54 | r'])+')
55 | else:
56 | RE_NSW = re.compile( # pragma: no cover
57 | r'(?:[^'
58 | r'\u3007' # 〇
59 | r'\u3400-\u4dbf' # CJK Extension A: [3400-4DBF]
60 | r'\u4e00-\u9fff' # CJK Unified Ideographs: [4E00-9FFF]
61 | r'\uf900-\ufaff' # CJK Compatibility Ideographs: [F900-FAFF]
62 | r'])+')
63 |
--------------------------------------------------------------------------------
/src/text/zh_normalization/phonecode.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 |
16 | from .num import verbalize_digit
17 |
18 | # Normalize landline / mobile phone numbers
19 | # Mobile number prefixes
20 | # http://www.jihaoba.com/news/show/13680
21 | # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
22 | # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
23 | # China Telecom: 133, 153, 189, 180, 181, 177
24 | RE_MOBILE_PHONE = re.compile(
25 | r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
26 | RE_TELEPHONE = re.compile(
27 | r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
28 |
29 | # Nationwide uniform numbers starting with 400
30 | RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
31 |
32 |
33 | def phone2str(phone_string: str, mobile=True) -> str:
34 | if mobile:
35 | sp_parts = phone_string.strip('+').split()
36 | result = ','.join(
37 | [verbalize_digit(part, alt_one=True) for part in sp_parts])
38 | return result
39 | else:
40 | sil_parts = phone_string.split('-')
41 | result = ','.join(
42 | [verbalize_digit(part, alt_one=True) for part in sil_parts])
43 | return result
44 |
45 |
46 | def replace_phone(match) -> str:
47 | """
48 | Args:
49 | match (re.Match)
50 | Returns:
51 | str
52 | """
53 | return phone2str(match.group(0), mobile=False)
54 |
55 |
56 | def replace_mobile(match) -> str:
57 | """
58 | Args:
59 | match (re.Match)
60 | Returns:
61 | str
62 | """
63 | return phone2str(match.group(0))
64 |
--------------------------------------------------------------------------------
/src/text/zh_normalization/quantifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import re
15 |
16 | from .num import num2str
17 |
18 | # Temperature expression; the temperature context changes how the minus sign is read
19 | # e.g. -3°C -> 零下三度 ("minus three degrees")
20 | RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
21 | measure_dict = {
22 | "cm2": "平方厘米",
23 | "cm²": "平方厘米",
24 | "cm3": "立方厘米",
25 | "cm³": "立方厘米",
26 | "cm": "厘米",
27 | "db": "分贝",
28 | "ds": "毫秒",
29 | "kg": "千克",
30 | "km": "千米",
31 | "m2": "平方米",
32 | "m²": "平方米",
33 | "m³": "立方米",
34 | "m3": "立方米",
35 | "ml": "毫升",
36 | "m": "米",
37 | "mm": "毫米",
38 | "s": "秒"
39 | }
40 |
41 |
42 | def replace_temperature(match) -> str:
43 | """
44 | Args:
45 | match (re.Match)
46 | Returns:
47 | str
48 | """
49 | sign = match.group(1)
50 | temperature = match.group(2)
51 | unit = match.group(3)
52 | sign: str = "零下" if sign else ""
53 | temperature: str = num2str(temperature)
54 | unit: str = "摄氏度" if unit == "摄氏度" else "度"
55 | result = f"{sign}{temperature}{unit}"
56 | return result
57 |
58 |
59 | def replace_measure(sentence) -> str:
60 | for q_notation in measure_dict:
61 | if q_notation in sentence:
62 | sentence = sentence.replace(q_notation, measure_dict[q_notation])
63 | return sentence
64 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import HParams
--------------------------------------------------------------------------------
/src/utils/cut.py:
--------------------------------------------------------------------------------
1 | import re
2 | from enum import Enum
3 |
4 | SPLITS = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
5 |
6 | def split(todo_text):
7 | todo_text = todo_text.replace("……", "。").replace("——", ",")
8 | if todo_text[-1] not in SPLITS:
9 | todo_text += "。"
10 | i_split_head = i_split_tail = 0
11 | len_text = len(todo_text)
12 | todo_texts = []
13 | while 1:
14 | if i_split_head >= len_text:
15 | break # the text is guaranteed to end with punctuation, so just exit; the last segment was appended in the previous iteration
16 | if todo_text[i_split_head] in SPLITS:
17 | i_split_head += 1
18 | todo_texts.append(todo_text[i_split_tail:i_split_head])
19 | i_split_tail = i_split_head
20 | else:
21 | i_split_head += 1
22 | return todo_texts
23 |
24 |
25 | def get_first(text):
26 | pattern = "[" + "".join(re.escape(sep) for sep in SPLITS) + "]"
27 | text = re.split(pattern, text)[0].strip()
28 | return text
29 |
30 |
31 | def cut1(inp):
32 | inp = inp.strip("\n")
33 | inps = split(inp)
34 | split_idx = list(range(0, len(inps), 4))
35 | split_idx[-1] = None
36 | if len(split_idx) > 1:
37 | opts = []
38 | for idx in range(len(split_idx) - 1):
39 | opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
40 | else:
41 | opts = [inp]
42 | return "\n".join(opts)
43 |
44 |
45 | def cut2(inp):
46 | inp = inp.strip("\n")
47 | inps = split(inp)
48 | if len(inps) < 2:
49 | return inp
50 | opts = []
51 | summ = 0
52 | tmp_str = ""
53 | for i in range(len(inps)):
54 | summ += len(inps[i])
55 | tmp_str += inps[i]
56 | if summ > 50:
57 | summ = 0
58 | opts.append(tmp_str)
59 | tmp_str = ""
60 | if tmp_str != "":
61 | opts.append(tmp_str)
62 | # print(opts)
63 | if len(opts) > 1 and len(opts[-1]) < 50: # if the last chunk is too short, merge it into the previous one
64 | opts[-2] = opts[-2] + opts[-1]
65 | opts = opts[:-1]
66 | return "\n".join(opts)
67 |
68 |
69 | def cut3(inp):
70 | inp = inp.strip("\n")
71 | return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
72 |
73 |
74 | def cut4(inp):
75 | inp = inp.strip("\n")
76 | return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
77 |
78 |
79 | # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
80 | def cut5(inp):
81 | # if not re.search(r'[^\w\s]', inp[-1]):
82 | # inp += '。'
83 | inp = inp.strip("\n")
84 | punds = r'[,.;?!、,。?!;:…]'
85 | items = re.split(f'({punds})', inp)
86 | mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
87 | # keep the text intact when there is no punctuation, or no punctuation at the end
88 | if len(items)%2 == 1:
89 | mergeitems.append(items[-1])
90 | opt = "\n".join(mergeitems)
91 | return opt
92 |
93 |
94 | CUT_DICT = {
95 | "凑四句一切": cut1,
96 | "凑50字一切": cut2,
97 | "按中文句号。切": cut3,
98 | "按英文句号.切": cut4,
99 | "按标点符号切": cut5
100 | }
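Note: the Chinese CUT_DICT keys above are the UI-facing names of the cutting modes and are presumably also the values accepted for how_to_cut in the inference parameter files; roughly: 凑四句一切 = batch every four sentences (cut1), 凑50字一切 = batch roughly every 50 characters (cut2), 按中文句号。切 = split on the Chinese full stop (cut3), 按英文句号.切 = split on the English period (cut4), 按标点符号切 = split on any punctuation (cut5).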
--------------------------------------------------------------------------------
/src/utils/process_ckpt.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from collections import OrderedDict
3 | from time import time as ttime
4 | import shutil,os
5 | import torch
6 |
7 | def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
8 | dir=os.path.dirname(path)
9 | name=os.path.basename(path)
10 | tmp_path="%s.pth"%(ttime())
11 | torch.save(fea,tmp_path)
12 | shutil.move(tmp_path,"%s/%s"%(dir,name))
13 |
14 | def savee(ckpt, name, epoch, steps, hps):
15 | try:
16 | opt = OrderedDict()
17 | opt["weight"] = {}
18 | for key in ckpt.keys():
19 | if "enc_q" in key:
20 | continue
21 | opt["weight"][key] = ckpt[key].half()
22 | opt["config"] = hps
23 | opt["info"] = "%sepoch_%siteration" % (epoch, steps)
24 | torch.save(opt, f"{hps.model_dir}/{name}.pth")
25 | # my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
26 | return "Success."
27 | except:
28 | return traceback.format_exc()
29 |
--------------------------------------------------------------------------------
/tools/asr/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def check_fw_local_models():
4 | '''
5 | Check at startup whether Faster Whisper models are present locally.
6 | '''
7 | model_size_list = [
8 | "tiny", "tiny.en",
9 | "base", "base.en",
10 | "small", "small.en",
11 | "medium", "medium.en",
12 | "large", "large-v1",
13 | "large-v2", "large-v3"]
14 | for i, size in enumerate(model_size_list):
15 | if os.path.exists(f'tools/asr/models/faster-whisper-{size}'):
16 | model_size_list[i] = size + '-local'
17 | return model_size_list
18 |
19 | asr_dict = {
20 | "达摩 ASR (中文)": {
21 | 'lang': ['zh'],
22 | 'size': ['large'],
23 | 'path': 'funasr_asr.py',
24 | },
25 | "Faster Whisper (多语种)": {
26 | 'lang': ['auto', 'zh', 'en', 'ja'],
27 | 'size': check_fw_local_models(),
28 | 'path': 'fasterwhisper_asr.py'
29 | }
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/tools/asr/fasterwhisper_asr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | os.environ["HF_ENDPOINT"]="https://hf-mirror.com"
4 | import traceback
5 | import requests
6 | from glob import glob
7 |
8 | from faster_whisper import WhisperModel
9 | from tqdm import tqdm
10 |
11 | from tools.asr.config import check_fw_local_models
12 | from tools.asr.funasr_asr import only_asr
13 |
14 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
15 |
16 | language_code_list = [
17 | "af", "am", "ar", "as", "az",
18 | "ba", "be", "bg", "bn", "bo",
19 | "br", "bs", "ca", "cs", "cy",
20 | "da", "de", "el", "en", "es",
21 | "et", "eu", "fa", "fi", "fo",
22 | "fr", "gl", "gu", "ha", "haw",
23 | "he", "hi", "hr", "ht", "hu",
24 | "hy", "id", "is", "it", "ja",
25 | "jw", "ka", "kk", "km", "kn",
26 | "ko", "la", "lb", "ln", "lo",
27 | "lt", "lv", "mg", "mi", "mk",
28 | "ml", "mn", "mr", "ms", "mt",
29 | "my", "ne", "nl", "nn", "no",
30 | "oc", "pa", "pl", "ps", "pt",
31 | "ro", "ru", "sa", "sd", "si",
32 | "sk", "sl", "sn", "so", "sq",
33 | "sr", "su", "sv", "sw", "ta",
34 | "te", "tg", "th", "tk", "tl",
35 | "tr", "tt", "uk", "ur", "uz",
36 | "vi", "yi", "yo", "zh", "yue",
37 | "auto"]
38 |
39 | def execute_asr(input_folder, output_folder, model_size, language,precision):
40 | if '-local' in model_size:
41 | model_size = model_size[:-6]
42 | model_path = f'tools/asr/models/faster-whisper-{model_size}'
43 | else:
44 | model_path = model_size
45 | if language == 'auto':
46 | language = None # leave the language unset so the model picks the most probable one on its own
47 | print("loading faster whisper model:",model_size,model_path)
48 | try:
49 | model = WhisperModel(model_path, device="cuda", compute_type=precision)
50 | except:
51 | return print(traceback.format_exc())
52 | output = []
53 | output_file_name = os.path.basename(input_folder)
54 | output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list')
55 |
56 | if not os.path.exists(output_folder):
57 | os.makedirs(output_folder)
58 |
59 | for file in tqdm(glob(os.path.join(input_folder, '**/*.wav'), recursive=True)):
60 | try:
61 | segments, info = model.transcribe(
62 | audio = file,
63 | beam_size = 5,
64 | vad_filter = True,
65 | vad_parameters = dict(min_silence_duration_ms=700),
66 | language = language)
67 | text = ''
68 |
69 | if info.language == "zh":
70 |                 print("Chinese speech detected; handing off to FunASR")
71 | text = only_asr(file)
72 |
73 | if text == '':
74 | for segment in segments:
75 | text += segment.text
76 | output.append(f"{file}|{output_file_name}|{info.language.upper()}|{text}")
77 | except:
78 | return print(traceback.format_exc())
79 |
80 | with open(output_file_path, "w", encoding="utf-8") as f:
81 | f.write("\n".join(output))
82 |     print(f"ASR task finished -> transcription file: {output_file_path}\n")
83 | return output_file_path
84 |
85 | if __name__ == '__main__':
86 | parser = argparse.ArgumentParser()
87 | parser.add_argument("-i", "--input_folder", type=str, required=True,
88 | help="Path to the folder containing WAV files.")
89 | parser.add_argument("-o", "--output_folder", type=str, required=True,
90 | help="Output folder to store transcriptions.")
91 | parser.add_argument("-s", "--model_size", type=str, default='large-v3',
92 | choices=check_fw_local_models(),
93 | help="Model Size of Faster Whisper")
94 | parser.add_argument("-l", "--language", type=str, default='ja',
95 | choices=language_code_list,
96 | help="Language of the audio files.")
97 | parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'],
98 | help="fp16 or fp32")
99 |
100 | cmd = parser.parse_args()
101 | output_file_path = execute_asr(
102 | input_folder = cmd.input_folder,
103 | output_folder = cmd.output_folder,
104 | model_size = cmd.model_size,
105 | language = cmd.language,
106 | precision = cmd.precision,
107 | )
108 |
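
The same entry point can be driven from Python rather than the CLI; a minimal sketch, assuming a CUDA device, the faster-whisper and FunASR dependencies, and placeholder folder names:

    from tools.asr.fasterwhisper_asr import execute_asr

    # Placeholder folders; "large-v3-local" instead would require the model under
    # tools/asr/models/faster-whisper-large-v3.
    list_path = execute_asr(
        input_folder="output/slicer_opt",
        output_folder="output/asr_opt",
        model_size="large-v3",
        language="auto",
        precision="float16",
    )
    print(list_path)   # path of the generated .list annotation file (None on failure)
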
--------------------------------------------------------------------------------
/tools/asr/funasr_asr.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import argparse
4 | import os
5 | import traceback
6 | from tqdm import tqdm
7 |
8 | from funasr import AutoModel
9 |
10 | path_asr = 'pretrained_models/asr/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
11 | path_vad = 'pretrained_models/asr/speech_fsmn_vad_zh-cn-16k-common-pytorch'
12 | path_punc = 'pretrained_models/asr/punc_ct-transformer_zh-cn-common-vocab272727-pytorch'
13 |
14 | model = AutoModel(
15 | model = path_asr,
16 | model_revision = "v2.0.4",
17 | vad_model = path_vad,
18 | vad_model_revision = "v2.0.4",
19 | punc_model = path_punc,
20 | punc_model_revision = "v2.0.4",
21 | )
22 |
23 | def only_asr(input_file):
24 | try:
25 | text = model.generate(input=input_file)[0]["text"]
26 | except:
27 | text = ''
28 | print(traceback.format_exc())
29 | return text
30 |
31 | def execute_asr(input_folder, output_folder, model_size, language):
32 | input_file_names = os.listdir(input_folder)
33 | input_file_names.sort()
34 |
35 | output = []
36 | output_file_name = os.path.basename(input_folder)
37 |
38 | for name in tqdm(input_file_names):
39 | try:
40 | text = model.generate(input="%s/%s"%(input_folder, name))[0]["text"]
41 | output.append(f"{input_folder}/{name}|{output_file_name}|{language.upper()}|{text}")
42 | except:
43 | print(traceback.format_exc())
44 |
45 | output_folder = output_folder or "output/asr_opt"
46 | os.makedirs(output_folder, exist_ok=True)
47 | output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list')
48 |
49 | with open(output_file_path, "w", encoding="utf-8") as f:
50 | f.write("\n".join(output))
51 |     print(f"ASR task finished -> transcription file: {output_file_path}\n")
52 | return output_file_path
53 |
54 | if __name__ == '__main__':
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument("-i", "--input_folder", type=str, required=True,
57 | help="Path to the folder containing WAV files.")
58 | parser.add_argument("-o", "--output_folder", type=str, required=True,
59 | help="Output folder to store transcriptions.")
60 | parser.add_argument("-s", "--model_size", type=str, default='large',
61 |                         help="Model size of FunASR (only 'large' is supported)")
62 | parser.add_argument("-l", "--language", type=str, default='zh', choices=['zh'],
63 | help="Language of the audio files.")
64 | parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'],
65 |                         help="fp16 or fp32")  # not wired up yet
66 |
67 | cmd = parser.parse_args()
68 | execute_asr(
69 | input_folder = cmd.input_folder,
70 | output_folder = cmd.output_folder,
71 | model_size = cmd.model_size,
72 | language = cmd.language,
73 | )
74 |
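
only_asr can also be used on its own for a quick single-file check; a sketch, assuming the three FunASR models above are present under pretrained_models/asr/ and using a placeholder WAV path:

    from tools.asr.funasr_asr import only_asr   # loads the FunASR models at import time

    text = only_asr("path/to/clip.wav")         # placeholder path
    print(text)                                 # empty string if recognition failed
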
--------------------------------------------------------------------------------
/tools/asr/models/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------
/tools/cmd-denoise.py:
--------------------------------------------------------------------------------
1 | import os,argparse
2 |
3 | from modelscope.pipelines import pipeline
4 | from modelscope.utils.constant import Tasks
5 | from tqdm import tqdm
6 |
7 | path_denoise = 'tools/denoise-model/speech_frcrn_ans_cirm_16k'
8 | path_denoise = path_denoise if os.path.exists(path_denoise) else "damo/speech_frcrn_ans_cirm_16k"
9 | ans = pipeline(Tasks.acoustic_noise_suppression,model=path_denoise)
10 | def execute_denoise(input_folder,output_folder):
11 | os.makedirs(output_folder,exist_ok=True)
12 | # print(input_folder)
13 | # print(list(os.listdir(input_folder).sort()))
14 | for name in tqdm(os.listdir(input_folder)):
15 | ans("%s/%s"%(input_folder,name),output_path='%s/%s'%(output_folder,name))
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("-i", "--input_folder", type=str, required=True,
20 | help="Path to the folder containing WAV files.")
21 | parser.add_argument("-o", "--output_folder", type=str, required=True,
22 |                         help="Output folder to store denoised audio files.")
23 | parser.add_argument("-p", "--precision", type=str, default='float16', choices=['float16','float32'],
24 |                         help="fp16 or fp32")  # not wired up yet
25 | cmd = parser.parse_args()
26 | execute_denoise(
27 | input_folder = cmd.input_folder,
28 | output_folder = cmd.output_folder,
29 | )
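
Because the file name contains a hyphen it cannot be imported directly, so a caller would typically shell out to it; a sketch with placeholder folders, assuming the repository root as working directory:

    import subprocess

    subprocess.run(
        ["python", "tools/cmd-denoise.py",
         "-i", "output/slicer_opt",        # placeholder input folder of WAVs
         "-o", "output/denoise_opt"],      # placeholder output folder
        check=True,
    )
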
--------------------------------------------------------------------------------
/tools/denoise-model/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/tools/my_utils.py:
--------------------------------------------------------------------------------
1 | import platform,os,traceback
2 | import ffmpeg
3 | import numpy as np
4 |
5 |
6 | def load_audio(file, sr):
7 | try:
8 | # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
9 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
10 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
11 |         file = clean_path(file)  # Guard against pasted paths with stray spaces, quotes, or newlines
12 |         if not os.path.exists(file):
13 |             raise RuntimeError(
14 |                 "The audio path you entered does not exist, please fix it!"
15 |             )
16 | out, _ = (
17 | ffmpeg.input(file, threads=0)
18 | .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
19 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
20 | )
21 | except Exception as e:
22 | traceback.print_exc()
23 | raise RuntimeError(f"Failed to load audio: {e}")
24 |
25 | return np.frombuffer(out, np.float32).flatten()
26 |
27 |
28 | def clean_path(path_str):
29 | if platform.system() == 'Windows':
30 | path_str = path_str.replace('/', '\\')
31 | return path_str.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
32 |
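
load_audio returns mono float32 samples at the requested rate; a minimal sketch, assuming the ffmpeg binary is on PATH, the repository root on sys.path, and a placeholder file path:

    from tools.my_utils import load_audio

    samples = load_audio("path/to/clip.wav", 32000)   # placeholder path, resampled to 32 kHz mono
    print(samples.dtype, samples.shape)               # float32, (n_samples,)
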
--------------------------------------------------------------------------------
/tools/process_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import argparse
4 | import re
5 |
6 | def convert_xlsx_to_txt(name, language) -> None:
7 |     # Read the xlsx file
8 |     xlsx_path = os.path.join('data', name, 'lines.xlsx')
9 |     df = pd.read_excel(xlsx_path, header=None)  # No header row
10 |
11 |     # Build a dict mapping each audio name to its text
12 | audio_text_dict = {}
13 | for index, row in df.iterrows():
14 |         text = row[0]  # First column
15 |         audio_name = row[1]  # Second column
16 | if pd.isna(text):
17 | continue
18 |         # Remove bracketed parts, including full-width Chinese brackets
19 |         text = re.sub(r'\(.*?\)', '', text).strip()
20 |         text = re.sub(r'\[.*?\]', '', text).strip()
21 |         text = re.sub(r'（.*?）', '', text).strip()  # Handle full-width parentheses
22 | if not text:
23 | continue
24 | audio_text_dict[audio_name] = text
25 |
26 |     # Iterate over all wav files under the vocal folder
27 | vocal_dir = os.path.join('data', name, 'vocal')
28 | wav_files = [f for f in os.listdir(vocal_dir) if f.endswith('.wav')]
29 |
30 | total_wav_count = len(wav_files)
31 | no_text_count = 0
32 |
33 |     # Open <name>.txt for writing
34 | txt_path = os.path.join('data', name, f'{name}.txt')
35 | with open(txt_path, 'w', encoding='utf-8') as f:
36 | for wav_file in wav_files:
37 |             # Strip the trailing -00x suffix
38 | base_name = re.sub(r'-\d{3}$', '', wav_file[:-4])
39 |
40 | if base_name in audio_text_dict:
41 | text = audio_text_dict[base_name]
42 | f.write(f"{wav_file}|{name}|{language}|{text}\n")
43 | else:
44 | print(f"Warning: {wav_file} does not have a corresponding text.")
45 | no_text_count += 1
46 |
47 | print(f"Total WAV files: {total_wav_count}")
48 | print(f"WAV files without corresponding text: {no_text_count}")
49 |
50 | def main() -> None:
51 | parser = argparse.ArgumentParser(description='Convert lines.xlsx to name.txt')
52 | parser.add_argument('-n', '--name', type=str, help='Name of the dataset')
53 | parser.add_argument('-l', '--language', type=str, help='Language code (e.g., zh, ja, en)')
54 |
55 | args = parser.parse_args()
56 |
57 | convert_xlsx_to_txt(args.name, args.language)
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
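
The expected layout, inferred from the code, is data/<name>/lines.xlsx (text in column A, audio base names in column B) plus data/<name>/vocal/*.wav; a sketch with a placeholder dataset name:

    from tools.process_data import convert_xlsx_to_txt

    # Writes data/my_dataset/my_dataset.txt with lines of "file.wav|my_dataset|zh|text".
    convert_xlsx_to_txt(name="my_dataset", language="zh")
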
--------------------------------------------------------------------------------
/tools/slice_audio.py:
--------------------------------------------------------------------------------
1 | import os,sys,numpy as np
2 | import traceback
3 | # parent_directory = os.path.dirname(os.path.abspath(__file__))
4 | # sys.path.append(parent_directory)
5 | from my_utils import load_audio
6 | from slicer2 import Slicer
7 | import os.path
8 | from argparse import ArgumentParser
9 | import soundfile
10 | import librosa
11 |
12 | def slice(input_path: str,
13 | output_path: str,
14 | db_threshold: int=-40,
15 | min_length: int=5000,
16 | min_interval: int=300,
17 | hop_size: int=10,
18 | max_sil_kept: int=500,
19 | max_amp: float=1.0,
20 | alpha: float=0):
21 |
22 | os.makedirs(output_path, exist_ok=True)
23 | if os.path.isfile(input_path):
24 | input_files=[input_path]
25 | elif os.path.isdir(input_path):
26 | input_files=[os.path.join(input_path, name) for name in sorted(list(os.listdir(input_path)))]
27 | else:
28 |         return "The input path exists but is neither a file nor a directory"
29 |
30 | max_amp=float(max_amp)
31 | alpha=float(alpha)
32 | for file in input_files:
33 | try:
34 | audio, sr = librosa.load(file, sr=None, mono=False)
35 | # print(audio.shape)
36 |
37 | slicer = Slicer(
38 |                 sr=sr,  # Sample rate of the long audio
39 |                 threshold=int(db_threshold),  # Frames quieter than this (dB) are candidate cut points
40 |                 min_length=int(min_length),  # Minimum clip length; a short segment keeps merging with the next until it exceeds this
41 |                 min_interval=int(min_interval),  # Minimum silence interval for a cut
42 |                 hop_size=int(hop_size),  # Hop size of the volume curve; smaller = finer resolution but more compute (finer is not always better)
43 |                 max_sil_kept=int(max_sil_kept),  # Maximum silence kept around each clip after slicing
44 | )
45 |
46 | chunks = slicer.slice(audio)
47 | for i, (chunk, start, end) in enumerate(chunks):
48 | if len(chunk.shape) > 1:
49 | chunk = chunk.T
50 | tmp_max = np.abs(chunk).max()
51 |                 if tmp_max > 1: chunk /= tmp_max
52 | chunk = (chunk / tmp_max * (max_amp * alpha)) + (1 - alpha) * chunk
53 | soundfile.write(
54 | os.path.join(
55 | output_path,
56 |                         "%s_%d.wav"
57 | % (os.path.basename(file).rsplit(".", maxsplit=1)[0], i),
58 | ),
59 | chunk,
60 | sr,
61 | )
62 | except:
63 | print(file,"->fail->",traceback.format_exc())
64 |     return "Done. Please check the output files."
65 |
66 |
67 | if __name__ == "__main__":
68 | parser = ArgumentParser()
69 | parser.add_argument(
70 | "--input_path",
71 | type=str,
72 |         help="The audio to be sliced; can be a single file or a directory"
73 | )
74 | parser.add_argument(
75 | "--output_path",
76 | type=str,
77 | help="Output directory of the sliced audio clips"
78 | )
79 | parser.add_argument(
80 | "--db_threshold",
81 | type=float,
82 | required=False,
83 | default=-40,
84 | help="The dB threshold for silence detection",
85 | )
86 | parser.add_argument(
87 | "--min_length",
88 | type=int,
89 | required=False,
90 | default=5000,
91 | help="The minimum milliseconds required for each sliced audio clip",
92 | )
93 | parser.add_argument(
94 | "--min_interval",
95 | type=int,
96 | required=False,
97 | default=300,
98 | help="The minimum milliseconds for a silence part to be sliced",
99 | )
100 | parser.add_argument(
101 | "--hop_size",
102 | type=int,
103 | required=False,
104 | default=10,
105 | help="Frame length in milliseconds",
106 | )
107 | parser.add_argument(
108 | "--max_sil_kept",
109 | type=int,
110 | required=False,
111 | default=500,
112 | help="The maximum silence length kept around the sliced clip, presented in milliseconds",
113 | )
114 | parser.add_argument(
115 | "--max_amp",
116 | type=float,
117 | required=False,
118 | default=1.0,
119 | help="The maximum amplitude of the sliced audio clips",
120 | )
121 | parser.add_argument(
122 | "--alpha",
123 | type=float,
124 | required=False,
125 | default=0,
126 | help="The alpha value for amplitude adjustment",
127 | )
128 | args = parser.parse_args()
129 |
130 | slice(**args.__dict__)
131 |
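
slice() can also be called in-process; a sketch with placeholder folders and thresholds. Note the module uses flat imports (my_utils, slicer2), so tools/ must be on sys.path:

    import sys
    sys.path.append("tools")               # so that my_utils / slicer2 resolve
    from slice_audio import slice as slice_audio

    msg = slice_audio(
        input_path="output/denoise_opt",   # placeholder: a file or a folder of WAVs
        output_path="output/slicer_opt",   # placeholder output folder
        db_threshold=-40,
        min_length=5000,
        max_amp=0.9,
        alpha=0.25,
    )
    print(msg)
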
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SeperableConv2DBNActiv(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
31 | super(SeperableConv2DBNActiv, self).__init__()
32 | self.conv = nn.Sequential(
33 | nn.Conv2d(
34 | nin,
35 | nin,
36 | kernel_size=ksize,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | groups=nin,
41 | bias=False,
42 | ),
43 | nn.Conv2d(nin, nout, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(nout),
45 | activ(),
46 | )
47 |
48 | def __call__(self, x):
49 | return self.conv(x)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54 | super(Encoder, self).__init__()
55 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57 |
58 | def __call__(self, x):
59 | skip = self.conv1(x)
60 | h = self.conv2(skip)
61 |
62 | return h, skip
63 |
64 |
65 | class Decoder(nn.Module):
66 | def __init__(
67 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
68 | ):
69 | super(Decoder, self).__init__()
70 | self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
71 | self.dropout = nn.Dropout2d(0.1) if dropout else None
72 |
73 | def __call__(self, x, skip=None):
74 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
75 | if skip is not None:
76 | skip = spec_utils.crop_center(skip, x)
77 | x = torch.cat([x, skip], dim=1)
78 | h = self.conv(x)
79 |
80 | if self.dropout is not None:
81 | h = self.dropout(h)
82 |
83 | return h
84 |
85 |
86 | class ASPPModule(nn.Module):
87 | def __init__(self, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
88 | super(ASPPModule, self).__init__()
89 | self.conv1 = nn.Sequential(
90 | nn.AdaptiveAvgPool2d((1, None)),
91 | Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
92 | )
93 | self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
94 | self.conv3 = SeperableConv2DBNActiv(
95 | nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
96 | )
97 | self.conv4 = SeperableConv2DBNActiv(
98 | nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
99 | )
100 | self.conv5 = SeperableConv2DBNActiv(
101 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
102 | )
103 | self.bottleneck = nn.Sequential(
104 | Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
105 | )
106 |
107 | def forward(self, x):
108 | _, _, h, w = x.size()
109 | feat1 = F.interpolate(
110 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
111 | )
112 | feat2 = self.conv2(x)
113 | feat3 = self.conv3(x)
114 | feat4 = self.conv4(x)
115 | feat5 = self.conv5(x)
116 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
117 | bottle = self.bottleneck(out)
118 | return bottle
119 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers_123812KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SeperableConv2DBNActiv(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
31 | super(SeperableConv2DBNActiv, self).__init__()
32 | self.conv = nn.Sequential(
33 | nn.Conv2d(
34 | nin,
35 | nin,
36 | kernel_size=ksize,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | groups=nin,
41 | bias=False,
42 | ),
43 | nn.Conv2d(nin, nout, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(nout),
45 | activ(),
46 | )
47 |
48 | def __call__(self, x):
49 | return self.conv(x)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54 | super(Encoder, self).__init__()
55 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57 |
58 | def __call__(self, x):
59 | skip = self.conv1(x)
60 | h = self.conv2(skip)
61 |
62 | return h, skip
63 |
64 |
65 | class Decoder(nn.Module):
66 | def __init__(
67 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
68 | ):
69 | super(Decoder, self).__init__()
70 | self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
71 | self.dropout = nn.Dropout2d(0.1) if dropout else None
72 |
73 | def __call__(self, x, skip=None):
74 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
75 | if skip is not None:
76 | skip = spec_utils.crop_center(skip, x)
77 | x = torch.cat([x, skip], dim=1)
78 | h = self.conv(x)
79 |
80 | if self.dropout is not None:
81 | h = self.dropout(h)
82 |
83 | return h
84 |
85 |
86 | class ASPPModule(nn.Module):
87 | def __init__(self, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
88 | super(ASPPModule, self).__init__()
89 | self.conv1 = nn.Sequential(
90 | nn.AdaptiveAvgPool2d((1, None)),
91 | Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
92 | )
93 | self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
94 | self.conv3 = SeperableConv2DBNActiv(
95 | nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
96 | )
97 | self.conv4 = SeperableConv2DBNActiv(
98 | nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
99 | )
100 | self.conv5 = SeperableConv2DBNActiv(
101 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
102 | )
103 | self.bottleneck = nn.Sequential(
104 | Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
105 | )
106 |
107 | def forward(self, x):
108 | _, _, h, w = x.size()
109 | feat1 = F.interpolate(
110 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
111 | )
112 | feat2 = self.conv2(x)
113 | feat3 = self.conv3(x)
114 | feat4 = self.conv4(x)
115 | feat5 = self.conv5(x)
116 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
117 | bottle = self.bottleneck(out)
118 | return bottle
119 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers_123821KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SeperableConv2DBNActiv(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
31 | super(SeperableConv2DBNActiv, self).__init__()
32 | self.conv = nn.Sequential(
33 | nn.Conv2d(
34 | nin,
35 | nin,
36 | kernel_size=ksize,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | groups=nin,
41 | bias=False,
42 | ),
43 | nn.Conv2d(nin, nout, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(nout),
45 | activ(),
46 | )
47 |
48 | def __call__(self, x):
49 | return self.conv(x)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54 | super(Encoder, self).__init__()
55 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57 |
58 | def __call__(self, x):
59 | skip = self.conv1(x)
60 | h = self.conv2(skip)
61 |
62 | return h, skip
63 |
64 |
65 | class Decoder(nn.Module):
66 | def __init__(
67 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
68 | ):
69 | super(Decoder, self).__init__()
70 | self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
71 | self.dropout = nn.Dropout2d(0.1) if dropout else None
72 |
73 | def __call__(self, x, skip=None):
74 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
75 | if skip is not None:
76 | skip = spec_utils.crop_center(skip, x)
77 | x = torch.cat([x, skip], dim=1)
78 | h = self.conv(x)
79 |
80 | if self.dropout is not None:
81 | h = self.dropout(h)
82 |
83 | return h
84 |
85 |
86 | class ASPPModule(nn.Module):
87 | def __init__(self, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
88 | super(ASPPModule, self).__init__()
89 | self.conv1 = nn.Sequential(
90 | nn.AdaptiveAvgPool2d((1, None)),
91 | Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
92 | )
93 | self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
94 | self.conv3 = SeperableConv2DBNActiv(
95 | nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
96 | )
97 | self.conv4 = SeperableConv2DBNActiv(
98 | nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
99 | )
100 | self.conv5 = SeperableConv2DBNActiv(
101 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
102 | )
103 | self.bottleneck = nn.Sequential(
104 | Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
105 | )
106 |
107 | def forward(self, x):
108 | _, _, h, w = x.size()
109 | feat1 = F.interpolate(
110 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
111 | )
112 | feat2 = self.conv2(x)
113 | feat3 = self.conv3(x)
114 | feat4 = self.conv4(x)
115 | feat5 = self.conv5(x)
116 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
117 | bottle = self.bottleneck(out)
118 | return bottle
119 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers_33966KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SeperableConv2DBNActiv(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
31 | super(SeperableConv2DBNActiv, self).__init__()
32 | self.conv = nn.Sequential(
33 | nn.Conv2d(
34 | nin,
35 | nin,
36 | kernel_size=ksize,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | groups=nin,
41 | bias=False,
42 | ),
43 | nn.Conv2d(nin, nout, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(nout),
45 | activ(),
46 | )
47 |
48 | def __call__(self, x):
49 | return self.conv(x)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54 | super(Encoder, self).__init__()
55 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57 |
58 | def __call__(self, x):
59 | skip = self.conv1(x)
60 | h = self.conv2(skip)
61 |
62 | return h, skip
63 |
64 |
65 | class Decoder(nn.Module):
66 | def __init__(
67 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
68 | ):
69 | super(Decoder, self).__init__()
70 | self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
71 | self.dropout = nn.Dropout2d(0.1) if dropout else None
72 |
73 | def __call__(self, x, skip=None):
74 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
75 | if skip is not None:
76 | skip = spec_utils.crop_center(skip, x)
77 | x = torch.cat([x, skip], dim=1)
78 | h = self.conv(x)
79 |
80 | if self.dropout is not None:
81 | h = self.dropout(h)
82 |
83 | return h
84 |
85 |
86 | class ASPPModule(nn.Module):
87 | def __init__(self, nin, nout, dilations=(4, 8, 16, 32, 64), activ=nn.ReLU):
88 | super(ASPPModule, self).__init__()
89 | self.conv1 = nn.Sequential(
90 | nn.AdaptiveAvgPool2d((1, None)),
91 | Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
92 | )
93 | self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
94 | self.conv3 = SeperableConv2DBNActiv(
95 | nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
96 | )
97 | self.conv4 = SeperableConv2DBNActiv(
98 | nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
99 | )
100 | self.conv5 = SeperableConv2DBNActiv(
101 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
102 | )
103 | self.conv6 = SeperableConv2DBNActiv(
104 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
105 | )
106 | self.conv7 = SeperableConv2DBNActiv(
107 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
108 | )
109 | self.bottleneck = nn.Sequential(
110 | Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
111 | )
112 |
113 | def forward(self, x):
114 | _, _, h, w = x.size()
115 | feat1 = F.interpolate(
116 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
117 | )
118 | feat2 = self.conv2(x)
119 | feat3 = self.conv3(x)
120 | feat4 = self.conv4(x)
121 | feat5 = self.conv5(x)
122 | feat6 = self.conv6(x)
123 | feat7 = self.conv7(x)
124 | out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
125 | bottle = self.bottleneck(out)
126 | return bottle
127 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers_537227KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SeperableConv2DBNActiv(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
31 | super(SeperableConv2DBNActiv, self).__init__()
32 | self.conv = nn.Sequential(
33 | nn.Conv2d(
34 | nin,
35 | nin,
36 | kernel_size=ksize,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | groups=nin,
41 | bias=False,
42 | ),
43 | nn.Conv2d(nin, nout, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(nout),
45 | activ(),
46 | )
47 |
48 | def __call__(self, x):
49 | return self.conv(x)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54 | super(Encoder, self).__init__()
55 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57 |
58 | def __call__(self, x):
59 | skip = self.conv1(x)
60 | h = self.conv2(skip)
61 |
62 | return h, skip
63 |
64 |
65 | class Decoder(nn.Module):
66 | def __init__(
67 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
68 | ):
69 | super(Decoder, self).__init__()
70 | self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
71 | self.dropout = nn.Dropout2d(0.1) if dropout else None
72 |
73 | def __call__(self, x, skip=None):
74 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
75 | if skip is not None:
76 | skip = spec_utils.crop_center(skip, x)
77 | x = torch.cat([x, skip], dim=1)
78 | h = self.conv(x)
79 |
80 | if self.dropout is not None:
81 | h = self.dropout(h)
82 |
83 | return h
84 |
85 |
86 | class ASPPModule(nn.Module):
87 | def __init__(self, nin, nout, dilations=(4, 8, 16, 32, 64), activ=nn.ReLU):
88 | super(ASPPModule, self).__init__()
89 | self.conv1 = nn.Sequential(
90 | nn.AdaptiveAvgPool2d((1, None)),
91 | Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
92 | )
93 | self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
94 | self.conv3 = SeperableConv2DBNActiv(
95 | nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
96 | )
97 | self.conv4 = SeperableConv2DBNActiv(
98 | nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
99 | )
100 | self.conv5 = SeperableConv2DBNActiv(
101 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
102 | )
103 | self.conv6 = SeperableConv2DBNActiv(
104 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
105 | )
106 | self.conv7 = SeperableConv2DBNActiv(
107 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
108 | )
109 | self.bottleneck = nn.Sequential(
110 | Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
111 | )
112 |
113 | def forward(self, x):
114 | _, _, h, w = x.size()
115 | feat1 = F.interpolate(
116 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
117 | )
118 | feat2 = self.conv2(x)
119 | feat3 = self.conv3(x)
120 | feat4 = self.conv4(x)
121 | feat5 = self.conv5(x)
122 | feat6 = self.conv6(x)
123 | feat7 = self.conv7(x)
124 | out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
125 | bottle = self.bottleneck(out)
126 | return bottle
127 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers_537238KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SeperableConv2DBNActiv(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
31 | super(SeperableConv2DBNActiv, self).__init__()
32 | self.conv = nn.Sequential(
33 | nn.Conv2d(
34 | nin,
35 | nin,
36 | kernel_size=ksize,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | groups=nin,
41 | bias=False,
42 | ),
43 | nn.Conv2d(nin, nout, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(nout),
45 | activ(),
46 | )
47 |
48 | def __call__(self, x):
49 | return self.conv(x)
50 |
51 |
52 | class Encoder(nn.Module):
53 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
54 | super(Encoder, self).__init__()
55 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
56 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
57 |
58 | def __call__(self, x):
59 | skip = self.conv1(x)
60 | h = self.conv2(skip)
61 |
62 | return h, skip
63 |
64 |
65 | class Decoder(nn.Module):
66 | def __init__(
67 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
68 | ):
69 | super(Decoder, self).__init__()
70 | self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
71 | self.dropout = nn.Dropout2d(0.1) if dropout else None
72 |
73 | def __call__(self, x, skip=None):
74 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
75 | if skip is not None:
76 | skip = spec_utils.crop_center(skip, x)
77 | x = torch.cat([x, skip], dim=1)
78 | h = self.conv(x)
79 |
80 | if self.dropout is not None:
81 | h = self.dropout(h)
82 |
83 | return h
84 |
85 |
86 | class ASPPModule(nn.Module):
87 | def __init__(self, nin, nout, dilations=(4, 8, 16, 32, 64), activ=nn.ReLU):
88 | super(ASPPModule, self).__init__()
89 | self.conv1 = nn.Sequential(
90 | nn.AdaptiveAvgPool2d((1, None)),
91 | Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
92 | )
93 | self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
94 | self.conv3 = SeperableConv2DBNActiv(
95 | nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
96 | )
97 | self.conv4 = SeperableConv2DBNActiv(
98 | nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
99 | )
100 | self.conv5 = SeperableConv2DBNActiv(
101 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
102 | )
103 | self.conv6 = SeperableConv2DBNActiv(
104 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
105 | )
106 | self.conv7 = SeperableConv2DBNActiv(
107 | nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
108 | )
109 | self.bottleneck = nn.Sequential(
110 | Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
111 | )
112 |
113 | def forward(self, x):
114 | _, _, h, w = x.size()
115 | feat1 = F.interpolate(
116 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
117 | )
118 | feat2 = self.conv2(x)
119 | feat3 = self.conv3(x)
120 | feat4 = self.conv4(x)
121 | feat5 = self.conv5(x)
122 | feat6 = self.conv6(x)
123 | feat7 = self.conv7(x)
124 | out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
125 | bottle = self.bottleneck(out)
126 | return bottle
127 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/layers_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import spec_utils
6 |
7 |
8 | class Conv2DBNActiv(nn.Module):
9 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
10 | super(Conv2DBNActiv, self).__init__()
11 | self.conv = nn.Sequential(
12 | nn.Conv2d(
13 | nin,
14 | nout,
15 | kernel_size=ksize,
16 | stride=stride,
17 | padding=pad,
18 | dilation=dilation,
19 | bias=False,
20 | ),
21 | nn.BatchNorm2d(nout),
22 | activ(),
23 | )
24 |
25 | def __call__(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class Encoder(nn.Module):
30 | def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
31 | super(Encoder, self).__init__()
32 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
33 | self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
34 |
35 | def __call__(self, x):
36 | h = self.conv1(x)
37 | h = self.conv2(h)
38 |
39 | return h
40 |
41 |
42 | class Decoder(nn.Module):
43 | def __init__(
44 | self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
45 | ):
46 | super(Decoder, self).__init__()
47 | self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
48 | # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
49 | self.dropout = nn.Dropout2d(0.1) if dropout else None
50 |
51 | def __call__(self, x, skip=None):
52 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
53 |
54 | if skip is not None:
55 | skip = spec_utils.crop_center(skip, x)
56 | x = torch.cat([x, skip], dim=1)
57 |
58 | h = self.conv1(x)
59 | # h = self.conv2(h)
60 |
61 | if self.dropout is not None:
62 | h = self.dropout(h)
63 |
64 | return h
65 |
66 |
67 | class ASPPModule(nn.Module):
68 | def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
69 | super(ASPPModule, self).__init__()
70 | self.conv1 = nn.Sequential(
71 | nn.AdaptiveAvgPool2d((1, None)),
72 | Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
73 | )
74 | self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
75 | self.conv3 = Conv2DBNActiv(
76 | nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
77 | )
78 | self.conv4 = Conv2DBNActiv(
79 | nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
80 | )
81 | self.conv5 = Conv2DBNActiv(
82 | nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
83 | )
84 | self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
85 | self.dropout = nn.Dropout2d(0.1) if dropout else None
86 |
87 | def forward(self, x):
88 | _, _, h, w = x.size()
89 | feat1 = F.interpolate(
90 | self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
91 | )
92 | feat2 = self.conv2(x)
93 | feat3 = self.conv3(x)
94 | feat4 = self.conv4(x)
95 | feat5 = self.conv5(x)
96 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
97 | out = self.bottleneck(out)
98 |
99 | if self.dropout is not None:
100 | out = self.dropout(out)
101 |
102 | return out
103 |
104 |
105 | class LSTMModule(nn.Module):
106 | def __init__(self, nin_conv, nin_lstm, nout_lstm):
107 | super(LSTMModule, self).__init__()
108 | self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
109 | self.lstm = nn.LSTM(
110 | input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True
111 | )
112 | self.dense = nn.Sequential(
113 | nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
114 | )
115 |
116 | def forward(self, x):
117 | N, _, nbins, nframes = x.size()
118 | h = self.conv(x)[:, 0] # N, nbins, nframes
119 | h = h.permute(2, 0, 1) # nframes, N, nbins
120 | h, _ = self.lstm(h)
121 | h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins
122 | h = h.reshape(nframes, N, 1, nbins)
123 | h = h.permute(1, 2, 3, 0)
124 |
125 | return h
126 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/model_param_init.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pathlib
4 |
5 | default_param = {}
6 | default_param["bins"] = 768
7 | default_param["unstable_bins"] = 9 # training only
8 | default_param["reduction_bins"] = 762 # training only
9 | default_param["sr"] = 44100
10 | default_param["pre_filter_start"] = 757
11 | default_param["pre_filter_stop"] = 768
12 | default_param["band"] = {}
13 |
14 |
15 | default_param["band"][1] = {
16 | "sr": 11025,
17 | "hl": 128,
18 | "n_fft": 960,
19 | "crop_start": 0,
20 | "crop_stop": 245,
21 | "lpf_start": 61, # inference only
22 | "res_type": "polyphase",
23 | }
24 |
25 | default_param["band"][2] = {
26 | "sr": 44100,
27 | "hl": 512,
28 | "n_fft": 1536,
29 | "crop_start": 24,
30 | "crop_stop": 547,
31 | "hpf_start": 81, # inference only
32 | "res_type": "sinc_best",
33 | }
34 |
35 |
36 | def int_keys(d):
37 | r = {}
38 | for k, v in d:
39 | if k.isdigit():
40 | k = int(k)
41 | r[k] = v
42 | return r
43 |
44 |
45 | class ModelParameters(object):
46 | def __init__(self, config_path=""):
47 | if ".pth" == pathlib.Path(config_path).suffix:
48 | import zipfile
49 |
50 | with zipfile.ZipFile(config_path, "r") as zip:
51 | self.param = json.loads(
52 | zip.read("param.json"), object_pairs_hook=int_keys
53 | )
54 | elif ".json" == pathlib.Path(config_path).suffix:
55 | with open(config_path, "r") as f:
56 | self.param = json.loads(f.read(), object_pairs_hook=int_keys)
57 | else:
58 | self.param = default_param
59 |
60 | for k in [
61 | "mid_side",
62 | "mid_side_b",
63 | "mid_side_b2",
64 | "stereo_w",
65 | "stereo_n",
66 | "reverse",
67 | ]:
68 |             if k not in self.param:
69 | self.param[k] = False
70 |
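
ModelParameters converts band keys to integers via int_keys; a sketch loading one of the presets from the modelparams folder below, assuming the repository root is on sys.path:

    from tools.uvr5.lib.lib_v5.model_param_init import ModelParameters

    mp = ModelParameters("tools/uvr5/lib/lib_v5/modelparams/4band_v2.json")
    print(mp.param["sr"], sorted(mp.param["band"]))   # 44100 [1, 2, 3, 4]
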
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr16000_hl512.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 16000,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 16000,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 1024
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr32000_hl512.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 32000,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "kaiser_fast"
14 | }
15 | },
16 | "sr": 32000,
17 | "pre_filter_start": 1000,
18 | "pre_filter_stop": 1021
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr33075_hl384.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 33075,
8 | "hl": 384,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 33075,
17 | "pre_filter_start": 1000,
18 | "pre_filter_stop": 1021
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl1024.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 1024,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 1024
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl256.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 256,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 256,
9 | "n_fft": 512,
10 | "crop_start": 0,
11 | "crop_stop": 256,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 256,
18 | "pre_filter_stop": 256
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 1024,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 1024
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/1band_sr44100_hl512_cut.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 1024,
3 | "unstable_bins": 0,
4 | "reduction_bins": 0,
5 | "band": {
6 | "1": {
7 | "sr": 44100,
8 | "hl": 512,
9 | "n_fft": 2048,
10 | "crop_start": 0,
11 | "crop_stop": 700,
12 | "hpf_start": -1,
13 | "res_type": "sinc_best"
14 | }
15 | },
16 | "sr": 44100,
17 | "pre_filter_start": 1023,
18 | "pre_filter_stop": 700
19 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/2band_32000.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "reduction_bins": 705,
5 | "band": {
6 | "1": {
7 | "sr": 6000,
8 | "hl": 66,
9 | "n_fft": 512,
10 | "crop_start": 0,
11 | "crop_stop": 240,
12 | "lpf_start": 60,
13 | "lpf_stop": 118,
14 | "res_type": "sinc_fastest"
15 | },
16 | "2": {
17 | "sr": 32000,
18 | "hl": 352,
19 | "n_fft": 1024,
20 | "crop_start": 22,
21 | "crop_stop": 505,
22 | "hpf_start": 44,
23 | "hpf_stop": 23,
24 | "res_type": "sinc_medium"
25 | }
26 | },
27 | "sr": 32000,
28 | "pre_filter_start": 710,
29 | "pre_filter_stop": 731
30 | }
31 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/2band_44100_lofi.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 512,
3 | "unstable_bins": 7,
4 | "reduction_bins": 510,
5 | "band": {
6 | "1": {
7 | "sr": 11025,
8 | "hl": 160,
9 | "n_fft": 768,
10 | "crop_start": 0,
11 | "crop_stop": 192,
12 | "lpf_start": 41,
13 | "lpf_stop": 139,
14 | "res_type": "sinc_fastest"
15 | },
16 | "2": {
17 | "sr": 44100,
18 | "hl": 640,
19 | "n_fft": 1024,
20 | "crop_start": 10,
21 | "crop_stop": 320,
22 | "hpf_start": 47,
23 | "hpf_stop": 15,
24 | "res_type": "sinc_medium"
25 | }
26 | },
27 | "sr": 44100,
28 | "pre_filter_start": 510,
29 | "pre_filter_stop": 512
30 | }
31 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/2band_48000.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "reduction_bins": 705,
5 | "band": {
6 | "1": {
7 | "sr": 6000,
8 | "hl": 66,
9 | "n_fft": 512,
10 | "crop_start": 0,
11 | "crop_stop": 240,
12 | "lpf_start": 60,
13 | "lpf_stop": 240,
14 | "res_type": "sinc_fastest"
15 | },
16 | "2": {
17 | "sr": 48000,
18 | "hl": 528,
19 | "n_fft": 1536,
20 | "crop_start": 22,
21 | "crop_stop": 505,
22 | "hpf_start": 82,
23 | "hpf_stop": 22,
24 | "res_type": "sinc_medium"
25 | }
26 | },
27 | "sr": 48000,
28 | "pre_filter_start": 710,
29 | "pre_filter_stop": 731
30 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/3band_44100.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 5,
4 | "reduction_bins": 733,
5 | "band": {
6 | "1": {
7 | "sr": 11025,
8 | "hl": 128,
9 | "n_fft": 768,
10 | "crop_start": 0,
11 | "crop_stop": 278,
12 | "lpf_start": 28,
13 | "lpf_stop": 140,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 22050,
18 | "hl": 256,
19 | "n_fft": 768,
20 | "crop_start": 14,
21 | "crop_stop": 322,
22 | "hpf_start": 70,
23 | "hpf_stop": 14,
24 | "lpf_start": 283,
25 | "lpf_stop": 314,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 44100,
30 | "hl": 512,
31 | "n_fft": 768,
32 | "crop_start": 131,
33 | "crop_stop": 313,
34 | "hpf_start": 154,
35 | "hpf_stop": 141,
36 | "res_type": "sinc_medium"
37 | }
38 | },
39 | "sr": 44100,
40 | "pre_filter_start": 757,
41 | "pre_filter_stop": 768
42 | }
43 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/3band_44100_mid.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side": true,
3 | "bins": 768,
4 | "unstable_bins": 5,
5 | "reduction_bins": 733,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 768,
11 | "crop_start": 0,
12 | "crop_stop": 278,
13 | "lpf_start": 28,
14 | "lpf_stop": 140,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 22050,
19 | "hl": 256,
20 | "n_fft": 768,
21 | "crop_start": 14,
22 | "crop_stop": 322,
23 | "hpf_start": 70,
24 | "hpf_stop": 14,
25 | "lpf_start": 283,
26 | "lpf_stop": 314,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 44100,
31 | "hl": 512,
32 | "n_fft": 768,
33 | "crop_start": 131,
34 | "crop_stop": 313,
35 | "hpf_start": 154,
36 | "hpf_stop": 141,
37 | "res_type": "sinc_medium"
38 | }
39 | },
40 | "sr": 44100,
41 | "pre_filter_start": 757,
42 | "pre_filter_stop": 768
43 | }
44 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/3band_44100_msb2.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b2": true,
3 | "bins": 640,
4 | "unstable_bins": 7,
5 | "reduction_bins": 565,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 108,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 187,
13 | "lpf_start": 92,
14 | "lpf_stop": 186,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 22050,
19 | "hl": 216,
20 | "n_fft": 768,
21 | "crop_start": 0,
22 | "crop_stop": 212,
23 | "hpf_start": 68,
24 | "hpf_stop": 34,
25 | "lpf_start": 174,
26 | "lpf_stop": 209,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 44100,
31 | "hl": 432,
32 | "n_fft": 640,
33 | "crop_start": 66,
34 | "crop_stop": 307,
35 | "hpf_start": 86,
36 | "hpf_stop": 72,
37 | "res_type": "kaiser_fast"
38 | }
39 | },
40 | "sr": 44100,
41 | "pre_filter_start": 639,
42 | "pre_filter_stop": 640
43 | }
44 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_44100.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "reduction_bins": 668,
5 | "band": {
6 | "1": {
7 | "sr": 11025,
8 | "hl": 128,
9 | "n_fft": 1024,
10 | "crop_start": 0,
11 | "crop_stop": 186,
12 | "lpf_start": 37,
13 | "lpf_stop": 73,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 11025,
18 | "hl": 128,
19 | "n_fft": 512,
20 | "crop_start": 4,
21 | "crop_stop": 185,
22 | "hpf_start": 36,
23 | "hpf_stop": 18,
24 | "lpf_start": 93,
25 | "lpf_stop": 185,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 22050,
30 | "hl": 256,
31 | "n_fft": 512,
32 | "crop_start": 46,
33 | "crop_stop": 186,
34 | "hpf_start": 93,
35 | "hpf_stop": 46,
36 | "lpf_start": 164,
37 | "lpf_stop": 186,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 512,
43 | "n_fft": 768,
44 | "crop_start": 121,
45 | "crop_stop": 382,
46 | "hpf_start": 138,
47 | "hpf_stop": 123,
48 | "res_type": "sinc_medium"
49 | }
50 | },
51 | "sr": 44100,
52 | "pre_filter_start": 740,
53 | "pre_filter_stop": 768
54 | }
55 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_44100_mid.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 768,
3 | "unstable_bins": 7,
4 | "mid_side": true,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
56 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_44100_msb2.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_44100_reverse.json:
--------------------------------------------------------------------------------
1 | {
2 | "reverse": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_44100_sw.json:
--------------------------------------------------------------------------------
1 | {
2 | "stereo_w": true,
3 | "bins": 768,
4 | "unstable_bins": 7,
5 | "reduction_bins": 668,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 128,
10 | "n_fft": 1024,
11 | "crop_start": 0,
12 | "crop_stop": 186,
13 | "lpf_start": 37,
14 | "lpf_stop": 73,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 11025,
19 | "hl": 128,
20 | "n_fft": 512,
21 | "crop_start": 4,
22 | "crop_stop": 185,
23 | "hpf_start": 36,
24 | "hpf_stop": 18,
25 | "lpf_start": 93,
26 | "lpf_stop": 185,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 22050,
31 | "hl": 256,
32 | "n_fft": 512,
33 | "crop_start": 46,
34 | "crop_stop": 186,
35 | "hpf_start": 93,
36 | "hpf_stop": 46,
37 | "lpf_start": 164,
38 | "lpf_stop": 186,
39 | "res_type": "polyphase"
40 | },
41 | "4": {
42 | "sr": 44100,
43 | "hl": 512,
44 | "n_fft": 768,
45 | "crop_start": 121,
46 | "crop_stop": 382,
47 | "hpf_start": 138,
48 | "hpf_stop": 123,
49 | "res_type": "sinc_medium"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 740,
54 | "pre_filter_stop": 768
55 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_v2.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 672,
3 | "unstable_bins": 8,
4 | "reduction_bins": 637,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "res_type": "kaiser_fast"
49 | }
50 | },
51 | "sr": 44100,
52 | "pre_filter_start": 668,
53 | "pre_filter_stop": 672
54 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_v2_sn.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 672,
3 | "unstable_bins": 8,
4 | "reduction_bins": 637,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "convert_channels": "stereo_n",
49 | "res_type": "kaiser_fast"
50 | }
51 | },
52 | "sr": 44100,
53 | "pre_filter_start": 668,
54 | "pre_filter_stop": 672
55 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/4band_v3.json:
--------------------------------------------------------------------------------
1 | {
2 | "bins": 672,
3 | "unstable_bins": 8,
4 | "reduction_bins": 530,
5 | "band": {
6 | "1": {
7 | "sr": 7350,
8 | "hl": 80,
9 | "n_fft": 640,
10 | "crop_start": 0,
11 | "crop_stop": 85,
12 | "lpf_start": 25,
13 | "lpf_stop": 53,
14 | "res_type": "polyphase"
15 | },
16 | "2": {
17 | "sr": 7350,
18 | "hl": 80,
19 | "n_fft": 320,
20 | "crop_start": 4,
21 | "crop_stop": 87,
22 | "hpf_start": 25,
23 | "hpf_stop": 12,
24 | "lpf_start": 31,
25 | "lpf_stop": 62,
26 | "res_type": "polyphase"
27 | },
28 | "3": {
29 | "sr": 14700,
30 | "hl": 160,
31 | "n_fft": 512,
32 | "crop_start": 17,
33 | "crop_stop": 216,
34 | "hpf_start": 48,
35 | "hpf_stop": 24,
36 | "lpf_start": 139,
37 | "lpf_stop": 210,
38 | "res_type": "polyphase"
39 | },
40 | "4": {
41 | "sr": 44100,
42 | "hl": 480,
43 | "n_fft": 960,
44 | "crop_start": 78,
45 | "crop_stop": 383,
46 | "hpf_start": 130,
47 | "hpf_stop": 86,
48 | "res_type": "kaiser_fast"
49 | }
50 | },
51 | "sr": 44100,
52 | "pre_filter_start": 668,
53 | "pre_filter_stop": 672
54 | }
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/modelparams/ensemble.json:
--------------------------------------------------------------------------------
1 | {
2 | "mid_side_b2": true,
3 | "bins": 1280,
4 | "unstable_bins": 7,
5 | "reduction_bins": 565,
6 | "band": {
7 | "1": {
8 | "sr": 11025,
9 | "hl": 108,
10 | "n_fft": 2048,
11 | "crop_start": 0,
12 | "crop_stop": 374,
13 | "lpf_start": 92,
14 | "lpf_stop": 186,
15 | "res_type": "polyphase"
16 | },
17 | "2": {
18 | "sr": 22050,
19 | "hl": 216,
20 | "n_fft": 1536,
21 | "crop_start": 0,
22 | "crop_stop": 424,
23 | "hpf_start": 68,
24 | "hpf_stop": 34,
25 | "lpf_start": 348,
26 | "lpf_stop": 418,
27 | "res_type": "polyphase"
28 | },
29 | "3": {
30 | "sr": 44100,
31 | "hl": 432,
32 | "n_fft": 1280,
33 | "crop_start": 132,
34 | "crop_stop": 614,
35 | "hpf_start": 172,
36 | "hpf_stop": 144,
37 | "res_type": "polyphase"
38 | }
39 | },
40 | "sr": 44100,
41 | "pre_filter_start": 1280,
42 | "pre_filter_stop": 1280
43 | }
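
Taken together, these modelparams JSON files describe how UVR5 splits the mixture into resampled sub-bands before the spectrogram model runs: each numbered entry under "band" carries its own sample rate ("sr"), hop length ("hl"), FFT size ("n_fft"), the bin range kept after cropping ("crop_start"/"crop_stop"), and the low-pass/high-pass transition bins used where adjacent bands overlap, while the top-level "sr", "bins" and "pre_filter_*" fields apply to the recombined spectrogram; optional flags such as "mid_side_b", "reverse" or "stereo_w" select how the stereo channels are converted first. The snippet below is a minimal, hypothetical sketch that reads one of these files with the standard library and prints the per-band STFT settings; it does not reproduce the project's own ModelParameters loader.

import json

# Hypothetical path; any of the modelparams files shown above would work.
path = "tools/uvr5/lib/lib_v5/modelparams/4band_v2.json"

with open(path, "r") as f:
    params = json.load(f)

print("output sr:", params["sr"], "total bins:", params["bins"])
for idx, band in sorted(params["band"].items(), key=lambda kv: int(kv[0])):
    print(
        f"band {idx}: sr={band['sr']} hop={band['hl']} n_fft={band['n_fft']} "
        f"bins {band['crop_start']}-{band['crop_stop']} res_type={band['res_type']}"
    )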
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets.py:
--------------------------------------------------------------------------------
   1 | from . import layers  # package-relative import; a bare "import layers" fails when lib_v5 is imported as a package
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | from . import spec_utils
7 |
8 |
9 | class BaseASPPNet(nn.Module):
10 | def __init__(self, nin, ch, dilations=(4, 8, 16)):
11 | super(BaseASPPNet, self).__init__()
12 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
13 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
14 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
15 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
16 |
17 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
18 |
19 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
20 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
21 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
22 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
23 |
24 | def __call__(self, x):
25 | h, e1 = self.enc1(x)
26 | h, e2 = self.enc2(h)
27 | h, e3 = self.enc3(h)
28 | h, e4 = self.enc4(h)
29 |
30 | h = self.aspp(h)
31 |
32 | h = self.dec4(h, e4)
33 | h = self.dec3(h, e3)
34 | h = self.dec2(h, e2)
35 | h = self.dec1(h, e1)
36 |
37 | return h
38 |
39 |
40 | class CascadedASPPNet(nn.Module):
41 | def __init__(self, n_fft):
42 | super(CascadedASPPNet, self).__init__()
43 | self.stg1_low_band_net = BaseASPPNet(2, 16)
44 | self.stg1_high_band_net = BaseASPPNet(2, 16)
45 |
46 | self.stg2_bridge = layers.Conv2DBNActiv(18, 8, 1, 1, 0)
47 | self.stg2_full_band_net = BaseASPPNet(8, 16)
48 |
49 | self.stg3_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0)
50 | self.stg3_full_band_net = BaseASPPNet(16, 32)
51 |
52 | self.out = nn.Conv2d(32, 2, 1, bias=False)
53 | self.aux1_out = nn.Conv2d(16, 2, 1, bias=False)
54 | self.aux2_out = nn.Conv2d(16, 2, 1, bias=False)
55 |
56 | self.max_bin = n_fft // 2
57 | self.output_bin = n_fft // 2 + 1
58 |
59 | self.offset = 128
60 |
61 | def forward(self, x, aggressiveness=None):
62 | mix = x.detach()
63 | x = x.clone()
64 |
65 | x = x[:, :, : self.max_bin]
66 |
67 | bandw = x.size()[2] // 2
68 | aux1 = torch.cat(
69 | [
70 | self.stg1_low_band_net(x[:, :, :bandw]),
71 | self.stg1_high_band_net(x[:, :, bandw:]),
72 | ],
73 | dim=2,
74 | )
75 |
76 | h = torch.cat([x, aux1], dim=1)
77 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
78 |
79 | h = torch.cat([x, aux1, aux2], dim=1)
80 | h = self.stg3_full_band_net(self.stg3_bridge(h))
81 |
82 | mask = torch.sigmoid(self.out(h))
83 | mask = F.pad(
84 | input=mask,
85 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
86 | mode="replicate",
87 | )
88 |
89 | if self.training:
90 | aux1 = torch.sigmoid(self.aux1_out(aux1))
91 | aux1 = F.pad(
92 | input=aux1,
93 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
94 | mode="replicate",
95 | )
96 | aux2 = torch.sigmoid(self.aux2_out(aux2))
97 | aux2 = F.pad(
98 | input=aux2,
99 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
100 | mode="replicate",
101 | )
102 | return mask * mix, aux1 * mix, aux2 * mix
103 | else:
104 | if aggressiveness:
105 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
106 | mask[:, :, : aggressiveness["split_bin"]],
107 | 1 + aggressiveness["value"] / 3,
108 | )
109 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
110 | mask[:, :, aggressiveness["split_bin"] :],
111 | 1 + aggressiveness["value"],
112 | )
113 |
114 | return mask * mix
115 |
116 | def predict(self, x_mag, aggressiveness=None):
117 | h = self.forward(x_mag, aggressiveness)
118 |
119 | if self.offset > 0:
120 | h = h[:, :, :, self.offset : -self.offset]
121 | assert h.size()[3] > 0
122 |
123 | return h
124 |
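
CascadedASPPNet here (and in the nets_*KB.py variants that follow, which differ only in channel widths and in which layers_* module they import) takes a magnitude spectrogram shaped (batch, 2, bins, frames) and returns the masked mixture; predict() then crops self.offset (128) frames from both ends, so each inference window has to be wider than 2 * offset. Below is a minimal sketch that runs the untrained network on random data; the import path is hypothetical and assumes the lib_v5 package imports resolve.

import torch

from tools.uvr5.lib.lib_v5 import nets  # hypothetical import path for this repo layout

n_fft = 2048
model = nets.CascadedASPPNet(n_fft).eval()

# stereo magnitude spectrogram: 2 channels, n_fft // 2 + 1 bins, 512 frames
x_mag = torch.rand(1, 2, n_fft // 2 + 1, 512)

with torch.no_grad():
    out = model.predict(x_mag, aggressiveness={"split_bin": 255, "value": 0.1})

print(out.shape)  # frames reduced by 2 * model.offset, i.e. 512 - 256 = 256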
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_123812KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import layers_123821KB as layers
6 |
7 |
8 | class BaseASPPNet(nn.Module):
9 | def __init__(self, nin, ch, dilations=(4, 8, 16)):
10 | super(BaseASPPNet, self).__init__()
11 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
12 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
13 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
14 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
15 |
16 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
17 |
18 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
19 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
20 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
21 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
22 |
23 | def __call__(self, x):
24 | h, e1 = self.enc1(x)
25 | h, e2 = self.enc2(h)
26 | h, e3 = self.enc3(h)
27 | h, e4 = self.enc4(h)
28 |
29 | h = self.aspp(h)
30 |
31 | h = self.dec4(h, e4)
32 | h = self.dec3(h, e3)
33 | h = self.dec2(h, e2)
34 | h = self.dec1(h, e1)
35 |
36 | return h
37 |
38 |
39 | class CascadedASPPNet(nn.Module):
40 | def __init__(self, n_fft):
41 | super(CascadedASPPNet, self).__init__()
42 | self.stg1_low_band_net = BaseASPPNet(2, 32)
43 | self.stg1_high_band_net = BaseASPPNet(2, 32)
44 |
45 | self.stg2_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0)
46 | self.stg2_full_band_net = BaseASPPNet(16, 32)
47 |
48 | self.stg3_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0)
49 | self.stg3_full_band_net = BaseASPPNet(32, 64)
50 |
51 | self.out = nn.Conv2d(64, 2, 1, bias=False)
52 | self.aux1_out = nn.Conv2d(32, 2, 1, bias=False)
53 | self.aux2_out = nn.Conv2d(32, 2, 1, bias=False)
54 |
55 | self.max_bin = n_fft // 2
56 | self.output_bin = n_fft // 2 + 1
57 |
58 | self.offset = 128
59 |
60 | def forward(self, x, aggressiveness=None):
61 | mix = x.detach()
62 | x = x.clone()
63 |
64 | x = x[:, :, : self.max_bin]
65 |
66 | bandw = x.size()[2] // 2
67 | aux1 = torch.cat(
68 | [
69 | self.stg1_low_band_net(x[:, :, :bandw]),
70 | self.stg1_high_band_net(x[:, :, bandw:]),
71 | ],
72 | dim=2,
73 | )
74 |
75 | h = torch.cat([x, aux1], dim=1)
76 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
77 |
78 | h = torch.cat([x, aux1, aux2], dim=1)
79 | h = self.stg3_full_band_net(self.stg3_bridge(h))
80 |
81 | mask = torch.sigmoid(self.out(h))
82 | mask = F.pad(
83 | input=mask,
84 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
85 | mode="replicate",
86 | )
87 |
88 | if self.training:
89 | aux1 = torch.sigmoid(self.aux1_out(aux1))
90 | aux1 = F.pad(
91 | input=aux1,
92 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
93 | mode="replicate",
94 | )
95 | aux2 = torch.sigmoid(self.aux2_out(aux2))
96 | aux2 = F.pad(
97 | input=aux2,
98 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
99 | mode="replicate",
100 | )
101 | return mask * mix, aux1 * mix, aux2 * mix
102 | else:
103 | if aggressiveness:
104 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
105 | mask[:, :, : aggressiveness["split_bin"]],
106 | 1 + aggressiveness["value"] / 3,
107 | )
108 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
109 | mask[:, :, aggressiveness["split_bin"] :],
110 | 1 + aggressiveness["value"],
111 | )
112 |
113 | return mask * mix
114 |
115 | def predict(self, x_mag, aggressiveness=None):
116 | h = self.forward(x_mag, aggressiveness)
117 |
118 | if self.offset > 0:
119 | h = h[:, :, :, self.offset : -self.offset]
120 | assert h.size()[3] > 0
121 |
122 | return h
123 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_123821KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import layers_123821KB as layers
6 |
7 |
8 | class BaseASPPNet(nn.Module):
9 | def __init__(self, nin, ch, dilations=(4, 8, 16)):
10 | super(BaseASPPNet, self).__init__()
11 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
12 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
13 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
14 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
15 |
16 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
17 |
18 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
19 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
20 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
21 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
22 |
23 | def __call__(self, x):
24 | h, e1 = self.enc1(x)
25 | h, e2 = self.enc2(h)
26 | h, e3 = self.enc3(h)
27 | h, e4 = self.enc4(h)
28 |
29 | h = self.aspp(h)
30 |
31 | h = self.dec4(h, e4)
32 | h = self.dec3(h, e3)
33 | h = self.dec2(h, e2)
34 | h = self.dec1(h, e1)
35 |
36 | return h
37 |
38 |
39 | class CascadedASPPNet(nn.Module):
40 | def __init__(self, n_fft):
41 | super(CascadedASPPNet, self).__init__()
42 | self.stg1_low_band_net = BaseASPPNet(2, 32)
43 | self.stg1_high_band_net = BaseASPPNet(2, 32)
44 |
45 | self.stg2_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0)
46 | self.stg2_full_band_net = BaseASPPNet(16, 32)
47 |
48 | self.stg3_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0)
49 | self.stg3_full_band_net = BaseASPPNet(32, 64)
50 |
51 | self.out = nn.Conv2d(64, 2, 1, bias=False)
52 | self.aux1_out = nn.Conv2d(32, 2, 1, bias=False)
53 | self.aux2_out = nn.Conv2d(32, 2, 1, bias=False)
54 |
55 | self.max_bin = n_fft // 2
56 | self.output_bin = n_fft // 2 + 1
57 |
58 | self.offset = 128
59 |
60 | def forward(self, x, aggressiveness=None):
61 | mix = x.detach()
62 | x = x.clone()
63 |
64 | x = x[:, :, : self.max_bin]
65 |
66 | bandw = x.size()[2] // 2
67 | aux1 = torch.cat(
68 | [
69 | self.stg1_low_band_net(x[:, :, :bandw]),
70 | self.stg1_high_band_net(x[:, :, bandw:]),
71 | ],
72 | dim=2,
73 | )
74 |
75 | h = torch.cat([x, aux1], dim=1)
76 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
77 |
78 | h = torch.cat([x, aux1, aux2], dim=1)
79 | h = self.stg3_full_band_net(self.stg3_bridge(h))
80 |
81 | mask = torch.sigmoid(self.out(h))
82 | mask = F.pad(
83 | input=mask,
84 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
85 | mode="replicate",
86 | )
87 |
88 | if self.training:
89 | aux1 = torch.sigmoid(self.aux1_out(aux1))
90 | aux1 = F.pad(
91 | input=aux1,
92 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
93 | mode="replicate",
94 | )
95 | aux2 = torch.sigmoid(self.aux2_out(aux2))
96 | aux2 = F.pad(
97 | input=aux2,
98 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
99 | mode="replicate",
100 | )
101 | return mask * mix, aux1 * mix, aux2 * mix
102 | else:
103 | if aggressiveness:
104 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
105 | mask[:, :, : aggressiveness["split_bin"]],
106 | 1 + aggressiveness["value"] / 3,
107 | )
108 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
109 | mask[:, :, aggressiveness["split_bin"] :],
110 | 1 + aggressiveness["value"],
111 | )
112 |
113 | return mask * mix
114 |
115 | def predict(self, x_mag, aggressiveness=None):
116 | h = self.forward(x_mag, aggressiveness)
117 |
118 | if self.offset > 0:
119 | h = h[:, :, :, self.offset : -self.offset]
120 | assert h.size()[3] > 0
121 |
122 | return h
123 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_33966KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import layers_33966KB as layers
6 |
7 |
8 | class BaseASPPNet(nn.Module):
9 | def __init__(self, nin, ch, dilations=(4, 8, 16, 32)):
10 | super(BaseASPPNet, self).__init__()
11 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
12 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
13 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
14 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
15 |
16 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
17 |
18 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
19 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
20 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
21 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
22 |
23 | def __call__(self, x):
24 | h, e1 = self.enc1(x)
25 | h, e2 = self.enc2(h)
26 | h, e3 = self.enc3(h)
27 | h, e4 = self.enc4(h)
28 |
29 | h = self.aspp(h)
30 |
31 | h = self.dec4(h, e4)
32 | h = self.dec3(h, e3)
33 | h = self.dec2(h, e2)
34 | h = self.dec1(h, e1)
35 |
36 | return h
37 |
38 |
39 | class CascadedASPPNet(nn.Module):
40 | def __init__(self, n_fft):
41 | super(CascadedASPPNet, self).__init__()
42 | self.stg1_low_band_net = BaseASPPNet(2, 16)
43 | self.stg1_high_band_net = BaseASPPNet(2, 16)
44 |
45 | self.stg2_bridge = layers.Conv2DBNActiv(18, 8, 1, 1, 0)
46 | self.stg2_full_band_net = BaseASPPNet(8, 16)
47 |
48 | self.stg3_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0)
49 | self.stg3_full_band_net = BaseASPPNet(16, 32)
50 |
51 | self.out = nn.Conv2d(32, 2, 1, bias=False)
52 | self.aux1_out = nn.Conv2d(16, 2, 1, bias=False)
53 | self.aux2_out = nn.Conv2d(16, 2, 1, bias=False)
54 |
55 | self.max_bin = n_fft // 2
56 | self.output_bin = n_fft // 2 + 1
57 |
58 | self.offset = 128
59 |
60 | def forward(self, x, aggressiveness=None):
61 | mix = x.detach()
62 | x = x.clone()
63 |
64 | x = x[:, :, : self.max_bin]
65 |
66 | bandw = x.size()[2] // 2
67 | aux1 = torch.cat(
68 | [
69 | self.stg1_low_band_net(x[:, :, :bandw]),
70 | self.stg1_high_band_net(x[:, :, bandw:]),
71 | ],
72 | dim=2,
73 | )
74 |
75 | h = torch.cat([x, aux1], dim=1)
76 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
77 |
78 | h = torch.cat([x, aux1, aux2], dim=1)
79 | h = self.stg3_full_band_net(self.stg3_bridge(h))
80 |
81 | mask = torch.sigmoid(self.out(h))
82 | mask = F.pad(
83 | input=mask,
84 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
85 | mode="replicate",
86 | )
87 |
88 | if self.training:
89 | aux1 = torch.sigmoid(self.aux1_out(aux1))
90 | aux1 = F.pad(
91 | input=aux1,
92 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
93 | mode="replicate",
94 | )
95 | aux2 = torch.sigmoid(self.aux2_out(aux2))
96 | aux2 = F.pad(
97 | input=aux2,
98 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
99 | mode="replicate",
100 | )
101 | return mask * mix, aux1 * mix, aux2 * mix
102 | else:
103 | if aggressiveness:
104 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
105 | mask[:, :, : aggressiveness["split_bin"]],
106 | 1 + aggressiveness["value"] / 3,
107 | )
108 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
109 | mask[:, :, aggressiveness["split_bin"] :],
110 | 1 + aggressiveness["value"],
111 | )
112 |
113 | return mask * mix
114 |
115 | def predict(self, x_mag, aggressiveness=None):
116 | h = self.forward(x_mag, aggressiveness)
117 |
118 | if self.offset > 0:
119 | h = h[:, :, :, self.offset : -self.offset]
120 | assert h.size()[3] > 0
121 |
122 | return h
123 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_537227KB.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | from . import layers_537238KB as layers
7 |
8 |
9 | class BaseASPPNet(nn.Module):
10 | def __init__(self, nin, ch, dilations=(4, 8, 16)):
11 | super(BaseASPPNet, self).__init__()
12 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
13 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
14 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
15 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
16 |
17 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
18 |
19 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
20 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
21 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
22 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
23 |
24 | def __call__(self, x):
25 | h, e1 = self.enc1(x)
26 | h, e2 = self.enc2(h)
27 | h, e3 = self.enc3(h)
28 | h, e4 = self.enc4(h)
29 |
30 | h = self.aspp(h)
31 |
32 | h = self.dec4(h, e4)
33 | h = self.dec3(h, e3)
34 | h = self.dec2(h, e2)
35 | h = self.dec1(h, e1)
36 |
37 | return h
38 |
39 |
40 | class CascadedASPPNet(nn.Module):
41 | def __init__(self, n_fft):
42 | super(CascadedASPPNet, self).__init__()
43 | self.stg1_low_band_net = BaseASPPNet(2, 64)
44 | self.stg1_high_band_net = BaseASPPNet(2, 64)
45 |
46 | self.stg2_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0)
47 | self.stg2_full_band_net = BaseASPPNet(32, 64)
48 |
49 | self.stg3_bridge = layers.Conv2DBNActiv(130, 64, 1, 1, 0)
50 | self.stg3_full_band_net = BaseASPPNet(64, 128)
51 |
52 | self.out = nn.Conv2d(128, 2, 1, bias=False)
53 | self.aux1_out = nn.Conv2d(64, 2, 1, bias=False)
54 | self.aux2_out = nn.Conv2d(64, 2, 1, bias=False)
55 |
56 | self.max_bin = n_fft // 2
57 | self.output_bin = n_fft // 2 + 1
58 |
59 | self.offset = 128
60 |
61 | def forward(self, x, aggressiveness=None):
62 | mix = x.detach()
63 | x = x.clone()
64 |
65 | x = x[:, :, : self.max_bin]
66 |
67 | bandw = x.size()[2] // 2
68 | aux1 = torch.cat(
69 | [
70 | self.stg1_low_band_net(x[:, :, :bandw]),
71 | self.stg1_high_band_net(x[:, :, bandw:]),
72 | ],
73 | dim=2,
74 | )
75 |
76 | h = torch.cat([x, aux1], dim=1)
77 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
78 |
79 | h = torch.cat([x, aux1, aux2], dim=1)
80 | h = self.stg3_full_band_net(self.stg3_bridge(h))
81 |
82 | mask = torch.sigmoid(self.out(h))
83 | mask = F.pad(
84 | input=mask,
85 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
86 | mode="replicate",
87 | )
88 |
89 | if self.training:
90 | aux1 = torch.sigmoid(self.aux1_out(aux1))
91 | aux1 = F.pad(
92 | input=aux1,
93 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
94 | mode="replicate",
95 | )
96 | aux2 = torch.sigmoid(self.aux2_out(aux2))
97 | aux2 = F.pad(
98 | input=aux2,
99 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
100 | mode="replicate",
101 | )
102 | return mask * mix, aux1 * mix, aux2 * mix
103 | else:
104 | if aggressiveness:
105 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
106 | mask[:, :, : aggressiveness["split_bin"]],
107 | 1 + aggressiveness["value"] / 3,
108 | )
109 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
110 | mask[:, :, aggressiveness["split_bin"] :],
111 | 1 + aggressiveness["value"],
112 | )
113 |
114 | return mask * mix
115 |
116 | def predict(self, x_mag, aggressiveness=None):
117 | h = self.forward(x_mag, aggressiveness)
118 |
119 | if self.offset > 0:
120 | h = h[:, :, :, self.offset : -self.offset]
121 | assert h.size()[3] > 0
122 |
123 | return h
124 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_537238KB.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | from . import layers_537238KB as layers
7 |
8 |
9 | class BaseASPPNet(nn.Module):
10 | def __init__(self, nin, ch, dilations=(4, 8, 16)):
11 | super(BaseASPPNet, self).__init__()
12 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
13 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
14 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
15 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
16 |
17 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
18 |
19 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
20 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
21 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
22 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
23 |
24 | def __call__(self, x):
25 | h, e1 = self.enc1(x)
26 | h, e2 = self.enc2(h)
27 | h, e3 = self.enc3(h)
28 | h, e4 = self.enc4(h)
29 |
30 | h = self.aspp(h)
31 |
32 | h = self.dec4(h, e4)
33 | h = self.dec3(h, e3)
34 | h = self.dec2(h, e2)
35 | h = self.dec1(h, e1)
36 |
37 | return h
38 |
39 |
40 | class CascadedASPPNet(nn.Module):
41 | def __init__(self, n_fft):
42 | super(CascadedASPPNet, self).__init__()
43 | self.stg1_low_band_net = BaseASPPNet(2, 64)
44 | self.stg1_high_band_net = BaseASPPNet(2, 64)
45 |
46 | self.stg2_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0)
47 | self.stg2_full_band_net = BaseASPPNet(32, 64)
48 |
49 | self.stg3_bridge = layers.Conv2DBNActiv(130, 64, 1, 1, 0)
50 | self.stg3_full_band_net = BaseASPPNet(64, 128)
51 |
52 | self.out = nn.Conv2d(128, 2, 1, bias=False)
53 | self.aux1_out = nn.Conv2d(64, 2, 1, bias=False)
54 | self.aux2_out = nn.Conv2d(64, 2, 1, bias=False)
55 |
56 | self.max_bin = n_fft // 2
57 | self.output_bin = n_fft // 2 + 1
58 |
59 | self.offset = 128
60 |
61 | def forward(self, x, aggressiveness=None):
62 | mix = x.detach()
63 | x = x.clone()
64 |
65 | x = x[:, :, : self.max_bin]
66 |
67 | bandw = x.size()[2] // 2
68 | aux1 = torch.cat(
69 | [
70 | self.stg1_low_band_net(x[:, :, :bandw]),
71 | self.stg1_high_band_net(x[:, :, bandw:]),
72 | ],
73 | dim=2,
74 | )
75 |
76 | h = torch.cat([x, aux1], dim=1)
77 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
78 |
79 | h = torch.cat([x, aux1, aux2], dim=1)
80 | h = self.stg3_full_band_net(self.stg3_bridge(h))
81 |
82 | mask = torch.sigmoid(self.out(h))
83 | mask = F.pad(
84 | input=mask,
85 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
86 | mode="replicate",
87 | )
88 |
89 | if self.training:
90 | aux1 = torch.sigmoid(self.aux1_out(aux1))
91 | aux1 = F.pad(
92 | input=aux1,
93 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
94 | mode="replicate",
95 | )
96 | aux2 = torch.sigmoid(self.aux2_out(aux2))
97 | aux2 = F.pad(
98 | input=aux2,
99 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
100 | mode="replicate",
101 | )
102 | return mask * mix, aux1 * mix, aux2 * mix
103 | else:
104 | if aggressiveness:
105 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
106 | mask[:, :, : aggressiveness["split_bin"]],
107 | 1 + aggressiveness["value"] / 3,
108 | )
109 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
110 | mask[:, :, aggressiveness["split_bin"] :],
111 | 1 + aggressiveness["value"],
112 | )
113 |
114 | return mask * mix
115 |
116 | def predict(self, x_mag, aggressiveness=None):
117 | h = self.forward(x_mag, aggressiveness)
118 |
119 | if self.offset > 0:
120 | h = h[:, :, :, self.offset : -self.offset]
121 | assert h.size()[3] > 0
122 |
123 | return h
124 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_61968KB.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import layers_123821KB as layers
6 |
7 |
8 | class BaseASPPNet(nn.Module):
9 | def __init__(self, nin, ch, dilations=(4, 8, 16)):
10 | super(BaseASPPNet, self).__init__()
11 | self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
12 | self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
13 | self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
14 | self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
15 |
16 | self.aspp = layers.ASPPModule(ch * 8, ch * 16, dilations)
17 |
18 | self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
19 | self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
20 | self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
21 | self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
22 |
23 | def __call__(self, x):
24 | h, e1 = self.enc1(x)
25 | h, e2 = self.enc2(h)
26 | h, e3 = self.enc3(h)
27 | h, e4 = self.enc4(h)
28 |
29 | h = self.aspp(h)
30 |
31 | h = self.dec4(h, e4)
32 | h = self.dec3(h, e3)
33 | h = self.dec2(h, e2)
34 | h = self.dec1(h, e1)
35 |
36 | return h
37 |
38 |
39 | class CascadedASPPNet(nn.Module):
40 | def __init__(self, n_fft):
41 | super(CascadedASPPNet, self).__init__()
42 | self.stg1_low_band_net = BaseASPPNet(2, 32)
43 | self.stg1_high_band_net = BaseASPPNet(2, 32)
44 |
45 | self.stg2_bridge = layers.Conv2DBNActiv(34, 16, 1, 1, 0)
46 | self.stg2_full_band_net = BaseASPPNet(16, 32)
47 |
48 | self.stg3_bridge = layers.Conv2DBNActiv(66, 32, 1, 1, 0)
49 | self.stg3_full_band_net = BaseASPPNet(32, 64)
50 |
51 | self.out = nn.Conv2d(64, 2, 1, bias=False)
52 | self.aux1_out = nn.Conv2d(32, 2, 1, bias=False)
53 | self.aux2_out = nn.Conv2d(32, 2, 1, bias=False)
54 |
55 | self.max_bin = n_fft // 2
56 | self.output_bin = n_fft // 2 + 1
57 |
58 | self.offset = 128
59 |
60 | def forward(self, x, aggressiveness=None):
61 | mix = x.detach()
62 | x = x.clone()
63 |
64 | x = x[:, :, : self.max_bin]
65 |
66 | bandw = x.size()[2] // 2
67 | aux1 = torch.cat(
68 | [
69 | self.stg1_low_band_net(x[:, :, :bandw]),
70 | self.stg1_high_band_net(x[:, :, bandw:]),
71 | ],
72 | dim=2,
73 | )
74 |
75 | h = torch.cat([x, aux1], dim=1)
76 | aux2 = self.stg2_full_band_net(self.stg2_bridge(h))
77 |
78 | h = torch.cat([x, aux1, aux2], dim=1)
79 | h = self.stg3_full_band_net(self.stg3_bridge(h))
80 |
81 | mask = torch.sigmoid(self.out(h))
82 | mask = F.pad(
83 | input=mask,
84 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
85 | mode="replicate",
86 | )
87 |
88 | if self.training:
89 | aux1 = torch.sigmoid(self.aux1_out(aux1))
90 | aux1 = F.pad(
91 | input=aux1,
92 | pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
93 | mode="replicate",
94 | )
95 | aux2 = torch.sigmoid(self.aux2_out(aux2))
96 | aux2 = F.pad(
97 | input=aux2,
98 | pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
99 | mode="replicate",
100 | )
101 | return mask * mix, aux1 * mix, aux2 * mix
102 | else:
103 | if aggressiveness:
104 | mask[:, :, : aggressiveness["split_bin"]] = torch.pow(
105 | mask[:, :, : aggressiveness["split_bin"]],
106 | 1 + aggressiveness["value"] / 3,
107 | )
108 | mask[:, :, aggressiveness["split_bin"] :] = torch.pow(
109 | mask[:, :, aggressiveness["split_bin"] :],
110 | 1 + aggressiveness["value"],
111 | )
112 |
113 | return mask * mix
114 |
115 | def predict(self, x_mag, aggressiveness=None):
116 | h = self.forward(x_mag, aggressiveness)
117 |
118 | if self.offset > 0:
119 | h = h[:, :, :, self.offset : -self.offset]
120 | assert h.size()[3] > 0
121 |
122 | return h
123 |
--------------------------------------------------------------------------------
/tools/uvr5/lib/lib_v5/nets_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from . import layers_new
6 |
7 |
8 | class BaseNet(nn.Module):
9 | def __init__(
10 | self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))
11 | ):
12 | super(BaseNet, self).__init__()
13 | self.enc1 = layers_new.Conv2DBNActiv(nin, nout, 3, 1, 1)
14 | self.enc2 = layers_new.Encoder(nout, nout * 2, 3, 2, 1)
15 | self.enc3 = layers_new.Encoder(nout * 2, nout * 4, 3, 2, 1)
16 | self.enc4 = layers_new.Encoder(nout * 4, nout * 6, 3, 2, 1)
17 | self.enc5 = layers_new.Encoder(nout * 6, nout * 8, 3, 2, 1)
18 |
19 | self.aspp = layers_new.ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
20 |
21 | self.dec4 = layers_new.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
22 | self.dec3 = layers_new.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
23 | self.dec2 = layers_new.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
24 | self.lstm_dec2 = layers_new.LSTMModule(nout * 2, nin_lstm, nout_lstm)
25 | self.dec1 = layers_new.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
26 |
27 | def __call__(self, x):
28 | e1 = self.enc1(x)
29 | e2 = self.enc2(e1)
30 | e3 = self.enc3(e2)
31 | e4 = self.enc4(e3)
32 | e5 = self.enc5(e4)
33 |
34 | h = self.aspp(e5)
35 |
36 | h = self.dec4(h, e4)
37 | h = self.dec3(h, e3)
38 | h = self.dec2(h, e2)
39 | h = torch.cat([h, self.lstm_dec2(h)], dim=1)
40 | h = self.dec1(h, e1)
41 |
42 | return h
43 |
44 |
45 | class CascadedNet(nn.Module):
46 | def __init__(self, n_fft, nout=32, nout_lstm=128):
47 | super(CascadedNet, self).__init__()
48 |
49 | self.max_bin = n_fft // 2
50 | self.output_bin = n_fft // 2 + 1
51 | self.nin_lstm = self.max_bin // 2
52 | self.offset = 64
53 |
54 | self.stg1_low_band_net = nn.Sequential(
55 | BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
56 | layers_new.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
57 | )
58 |
59 | self.stg1_high_band_net = BaseNet(
60 | 2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
61 | )
62 |
63 | self.stg2_low_band_net = nn.Sequential(
64 | BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
65 | layers_new.Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
66 | )
67 | self.stg2_high_band_net = BaseNet(
68 | nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
69 | )
70 |
71 | self.stg3_full_band_net = BaseNet(
72 | 3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
73 | )
74 |
75 | self.out = nn.Conv2d(nout, 2, 1, bias=False)
76 | self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
77 |
78 | def forward(self, x):
79 | x = x[:, :, : self.max_bin]
80 |
81 | bandw = x.size()[2] // 2
82 | l1_in = x[:, :, :bandw]
83 | h1_in = x[:, :, bandw:]
84 | l1 = self.stg1_low_band_net(l1_in)
85 | h1 = self.stg1_high_band_net(h1_in)
86 | aux1 = torch.cat([l1, h1], dim=2)
87 |
88 | l2_in = torch.cat([l1_in, l1], dim=1)
89 | h2_in = torch.cat([h1_in, h1], dim=1)
90 | l2 = self.stg2_low_band_net(l2_in)
91 | h2 = self.stg2_high_band_net(h2_in)
92 | aux2 = torch.cat([l2, h2], dim=2)
93 |
94 | f3_in = torch.cat([x, aux1, aux2], dim=1)
95 | f3 = self.stg3_full_band_net(f3_in)
96 |
97 | mask = torch.sigmoid(self.out(f3))
98 | mask = F.pad(
99 | input=mask,
100 | pad=(0, 0, 0, self.output_bin - mask.size()[2]),
101 | mode="replicate",
102 | )
103 |
104 | if self.training:
105 | aux = torch.cat([aux1, aux2], dim=1)
106 | aux = torch.sigmoid(self.aux_out(aux))
107 | aux = F.pad(
108 | input=aux,
109 | pad=(0, 0, 0, self.output_bin - aux.size()[2]),
110 | mode="replicate",
111 | )
112 | return mask, aux
113 | else:
114 | return mask
115 |
116 | def predict_mask(self, x):
117 | mask = self.forward(x)
118 |
119 | if self.offset > 0:
120 | mask = mask[:, :, :, self.offset : -self.offset]
121 | assert mask.size()[3] > 0
122 |
123 | return mask
124 |
125 | def predict(self, x, aggressiveness=None):
126 | mask = self.forward(x)
127 | pred_mag = x * mask
128 |
129 | if self.offset > 0:
130 | pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
131 | assert pred_mag.size()[3] > 0
132 |
133 | return pred_mag
134 |
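
Unlike the CascadedASPPNet variants above, CascadedNet returns the sigmoid mask itself (plus an auxiliary mask while training) rather than the masked mixture, so the caller multiplies the mask against the magnitude spectrogram; predict() does exactly that and crops self.offset (64) frames from each side. A minimal sketch on random data, again with a hypothetical import path:

import torch

from tools.uvr5.lib.lib_v5 import nets_new  # hypothetical import path for this repo layout

n_fft = 2048
model = nets_new.CascadedNet(n_fft, nout=32, nout_lstm=128).eval()

x_mag = torch.rand(1, 2, n_fft // 2 + 1, 512)

with torch.no_grad():
    mask = model.predict_mask(x_mag)   # mask only, frames cropped by 64 per side
    pred = model.predict(x_mag)        # x_mag * mask, cropped the same way

print(mask.shape, pred.shape)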
--------------------------------------------------------------------------------
/tools/uvr5/lib/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import torch
5 | from tqdm import tqdm
6 |
7 |
8 | def load_data(file_name: str = "./lib/name_params.json") -> dict:
9 | with open(file_name, "r") as f:
10 | data = json.load(f)
11 |
12 | return data
13 |
14 |
15 | def make_padding(width, cropsize, offset):
16 | left = offset
17 | roi_size = cropsize - left * 2
18 | if roi_size == 0:
19 | roi_size = cropsize
20 | right = roi_size - (width % roi_size) + left
21 |
22 | return left, right, roi_size
23 |
24 |
25 | def inference(X_spec, device, model, aggressiveness, data):
26 | """
  27 |     data : dict of configs
28 | """
29 |
30 | def _execute(
31 | X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half=True
32 | ):
33 | model.eval()
34 | with torch.no_grad():
35 | preds = []
36 |
37 | iterations = [n_window]
38 |
39 | total_iterations = sum(iterations)
40 | for i in tqdm(range(n_window)):
41 | start = i * roi_size
42 | X_mag_window = X_mag_pad[
43 | None, :, :, start : start + data["window_size"]
44 | ]
45 | X_mag_window = torch.from_numpy(X_mag_window)
46 | if is_half:
47 | X_mag_window = X_mag_window.half()
48 | X_mag_window = X_mag_window.to(device)
49 |
50 | pred = model.predict(X_mag_window, aggressiveness)
51 |
52 | pred = pred.detach().cpu().numpy()
53 | preds.append(pred[0])
54 |
55 | pred = np.concatenate(preds, axis=2)
56 | return pred
57 |
58 | def preprocess(X_spec):
59 | X_mag = np.abs(X_spec)
60 | X_phase = np.angle(X_spec)
61 |
62 | return X_mag, X_phase
63 |
64 | X_mag, X_phase = preprocess(X_spec)
65 |
66 | coef = X_mag.max()
67 | X_mag_pre = X_mag / coef
68 |
69 | n_frame = X_mag_pre.shape[2]
70 | pad_l, pad_r, roi_size = make_padding(n_frame, data["window_size"], model.offset)
71 | n_window = int(np.ceil(n_frame / roi_size))
72 |
73 | X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
74 |
75 | if list(model.state_dict().values())[0].dtype == torch.float16:
76 | is_half = True
77 | else:
78 | is_half = False
79 | pred = _execute(
80 | X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half
81 | )
82 | pred = pred[:, :, :n_frame]
83 |
84 | if data["tta"]:
85 | pad_l += roi_size // 2
86 | pad_r += roi_size // 2
87 | n_window += 1
88 |
89 | X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
90 |
91 | pred_tta = _execute(
92 | X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half
93 | )
94 | pred_tta = pred_tta[:, :, roi_size // 2 :]
95 | pred_tta = pred_tta[:, :, :n_frame]
96 |
97 | return (pred + pred_tta) * 0.5 * coef, X_mag, np.exp(1.0j * X_phase)
98 | else:
99 | return pred * coef, X_mag, np.exp(1.0j * X_phase)
100 |
101 |
102 | def _get_name_params(model_path, model_hash):
103 | data = load_data()
104 | flag = False
105 | ModelName = model_path
106 | for type in list(data):
107 | for model in list(data[type][0]):
108 | for i in range(len(data[type][0][model])):
109 | if str(data[type][0][model][i]["hash_name"]) == model_hash:
110 | flag = True
111 | elif str(data[type][0][model][i]["hash_name"]) in ModelName:
112 | flag = True
113 |
114 | if flag:
115 | model_params_auto = data[type][0][model][i]["model_params"]
116 | param_name_auto = data[type][0][model][i]["param_name"]
117 | if type == "equivalent":
118 | return param_name_auto, model_params_auto
119 | else:
120 | flag = False
121 | return param_name_auto, model_params_auto
122 |
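
inference() above normalizes the magnitude spectrogram, pads it out to a whole number of analysis windows, runs the model window by window (and, when data["tta"] is set, a second shifted pass whose result is averaged with the first), and returns the prediction together with the original magnitude and its phase as a complex exponential, so the caller can rebuild complex spectrograms. A minimal, hypothetical sketch of calling it with an untrained model and random data (import paths assumed, not taken from the repo's own entry points):

import numpy as np
import torch

from tools.uvr5.lib import utils                       # hypothetical import path
from tools.uvr5.lib.lib_v5 import nets_61968KB as nets

model = nets.CascadedASPPNet(2048).eval()
device = torch.device("cpu")

# random stereo complex spectrogram; complex64 keeps the magnitude float32 for the model
X_spec = (np.random.randn(2, 1025, 2000) + 1j * np.random.randn(2, 1025, 2000)).astype(np.complex64)

aggressiveness = {"value": 0.1, "split_bin": 255}
data = {"window_size": 512, "tta": False}               # the keys inference() reads

pred, X_mag, phase = utils.inference(X_spec, device, model, aggressiveness, data)
# pred * phase gives a complex spectrogram for the predicted stem;
# subtracting it from X_mag * phase leaves the complementary stem.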
--------------------------------------------------------------------------------