├── .gitattributes
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── ask_for_help.yaml
│   │   ├── ask_for_help_en_US.yaml
│   │   ├── bug_report.yaml
│   │   ├── bug_report_en_US.yaml
│   │   ├── config.yml
│   │   └── default.md
│   └── workflows
│       ├── reviewdog.yml
│       └── ruff.yml
├── .gitignore
├── .ruff.toml
├── LICENSE
├── README.md
├── README_zh_CN.md
├── cluster
│   ├── __init__.py
│   ├── kmeans.py
│   └── train_cluster.py
├── compress_model.py
├── configs
│   ├── config.json
│   └── diffusion.yaml
├── configs_template
│   ├── config_template.json
│   ├── config_tiny_template.json
│   └── diffusion_template.yaml
├── data_utils.py
├── dataset_raw
│   └── wav_structure.txt
├── diffusion
│   ├── __init__.py
│   ├── data_loaders.py
│   ├── diffusion.py
│   ├── diffusion_onnx.py
│   ├── dpm_solver_pytorch.py
│   ├── how to export onnx.md
│   ├── infer_gt_mel.py
│   ├── logger
│   │   ├── __init__.py
│   │   ├── saver.py
│   │   └── utils.py
│   ├── onnx_export.py
│   ├── solver.py
│   ├── uni_pc.py
│   ├── unit2mel.py
│   ├── vocoder.py
│   └── wavenet.py
├── edgetts
│   ├── tts.py
│   └── tts_voices.py
├── export_index_for_onnx.py
├── filelists
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── flask_api.py
├── flask_api_full_song.py
├── inference
│   ├── __init__.py
│   ├── infer_tool.py
│   ├── infer_tool_grad.py
│   └── slicer.py
├── inference_main.py
├── logs
│   └── 44k
│       ├── diffusion
│       │   └── put_diffusion_pretrained_model_here
│       └── put_pretrained_model_here
├── models.py
├── modules
│   ├── DSConv.py
│   ├── F0Predictor
│   │   ├── CrepeF0Predictor.py
│   │   ├── DioF0Predictor.py
│   │   ├── F0Predictor.py
│   │   ├── FCPEF0Predictor.py
│   │   ├── HarvestF0Predictor.py
│   │   ├── PMF0Predictor.py
│   │   ├── RMVPEF0Predictor.py
│   │   ├── __init__.py
│   │   ├── crepe.py
│   │   ├── fcpe
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── nvSTFT.py
│   │   │   └── pcmer.py
│   │   └── rmvpe
│   │       ├── __init__.py
│   │       ├── constants.py
│   │       ├── deepunet.py
│   │       ├── inference.py
│   │       ├── model.py
│   │       ├── seq.py
│   │       ├── spec.py
│   │       └── utils.py
│   ├── __init__.py
│   ├── attentions.py
│   ├── commons.py
│   ├── enhancer.py
│   ├── losses.py
│   ├── mel_processing.py
│   └── modules.py
├── onnx_export.py
├── onnx_export_old.py
├── onnxexport
│   ├── model_onnx.py
│   └── model_onnx_speaker_mix.py
├── preprocess_flist_config.py
├── preprocess_hubert_f0.py
├── pretrain
│   ├── __init__.py
│   ├── meta.py
│   ├── nsf_hifigan
│   │   └── put_nsf_hifigan_ckpt_here
│   └── put_hubert_ckpt_here
├── raw
│   └── put_raw_wav_here
├── requirements.txt
├── requirements_onnx_encoder.txt
├── requirements_win.txt
├── resample.py
├── shadowdiffusion.png
├── sovits4_for_colab.ipynb
├── spkmix.py
├── train.py
├── train_diff.py
├── train_index.py
├── trained
│   └── put_trained_checkpoints_here
├── utils.py
├── vdecoder
│   ├── __init__.py
│   ├── hifigan
│   │   ├── env.py
│   │   ├── models.py
│   │   ├── nvSTFT.py
│   │   └── utils.py
│   ├── hifiganwithsnake
│   │   ├── alias
│   │   │   ├── __init__.py
│   │   │   ├── act.py
│   │   │   ├── filter.py
│   │   │   └── resample.py
│   │   ├── env.py
│   │   ├── models.py
│   │   ├── nvSTFT.py
│   │   └── utils.py
│   └── nsf_hifigan
│       ├── env.py
│       ├── models.py
│       ├── nvSTFT.py
│       └── utils.py
├── vencoder
│   ├── CNHubertLarge.py
│   ├── ContentVec256L12_Onnx.py
│   ├── ContentVec256L9.py
│   ├── ContentVec256L9_Onnx.py
│   ├── ContentVec768L12.py
│   ├── ContentVec768L12_Onnx.py
│   ├── ContentVec768L9_Onnx.py
│   ├── DPHubert.py
│   ├── HubertSoft.py
│   ├── HubertSoft_Onnx.py
│   ├── WavLMBasePlus.py
│   ├── WhisperPPG.py
│   ├── WhisperPPGLarge.py
│   ├── __init__.py
│   ├── dphubert
│   │   ├── __init__.py
│   │   ├── components.py
│   │   ├── hardconcrete.py
│   │   ├── model.py
│   │   ├── pruning_utils.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── import_huggingface_wavlm.py
│   ├── encoder.py
│   ├── hubert
│   │   ├── __init__.py
│   │   ├── hubert_model.py
│   │   └── hubert_model_onnx.py
│   ├── wavlm
│   │   ├── WavLM.py
│   │   └── modules.py
│   └── whisper
│       ├── __init__.py
│       ├── audio.py
│       ├── decoding.py
│       ├── model.py
│       ├── tokenizer.py
│       └── utils.py
├── wav_upload.py
└── webUI.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto eol=lf
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/ask_for_help.yaml:
--------------------------------------------------------------------------------
1 | name: 请求帮助
2 | description: 遇到了无法自行解决的错误
3 | title: '[Help]: '
4 | labels: [ "help wanted" ]
5 |
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | #### 提问前请先自己去尝试解决,比如查看[本仓库wiki](https://github.com/svc-develop-team/so-vits-svc/wiki),也可以借助chatgpt或一些搜索引擎(谷歌/必应/New Bing/StackOverflow等等)。如果实在无法自己解决再发issue,在提issue之前,请先了解《[提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md)》。
11 | ---
12 | ### 什么样的issue会被直接close
13 | 1. 伸手党
14 | 2. 一键包/环境包相关
15 | 3. 提供的信息不全
16 | 4. 低级的如缺少依赖而导致无法运行的问题
17 |         5. 所用的数据集是无授权数据集(游戏角色/二次元人物暂不归为此类,但是训练时候也要小心谨慎。如果能联系到官方,必须先和官方联系并核实清楚)
18 | ---
19 |
20 | - type: checkboxes
21 | id: Clause
22 | attributes:
23 | label: 请勾选下方的确认框。
24 | options:
25 | - label: "我已仔细阅读[README.md](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/README_zh_CN.md)和[wiki中的Quick solution](https://github.com/svc-develop-team/so-vits-svc/wiki/Quick-solution)。"
26 | required: true
27 | - label: "我已通过各种搜索引擎排查问题,我要提出的问题并不常见。"
28 | required: true
29 | - label: "我未在使用由第三方用户提供的一键包/环境包。"
30 | required: true
31 |
32 | - type: markdown
33 | attributes:
34 | value: |
35 | # 请根据实际使用环境填写以下信息
36 |
37 | - type: input
38 | id: System
39 | attributes:
40 | label: 系统平台版本号
41 | description: Windows执行`winver` | Linux执行`uname -a`
42 | validations:
43 | required: true
44 |
45 | - type: input
46 | id: GPU
47 | attributes:
48 | label: GPU 型号
49 | description: 执行`nvidia-smi`
50 | validations:
51 | required: true
52 |
53 | - type: input
54 | id: PythonVersion
55 | attributes:
56 | label: Python版本
57 | description: 执行`python -V`
58 | validations:
59 | required: true
60 |
61 | - type: input
62 | id: PyTorchVersion
63 | attributes:
64 | label: PyTorch版本
65 | description: 执行`pip show torch`
66 | validations:
67 | required: true
68 |
69 | - type: dropdown
70 | id: Branch
71 | attributes:
72 | label: sovits分支
73 | options:
74 | - 4.0(默认)
75 | - 4.0-v2
76 | - 3.0-32k
77 | - 3.0-48k
78 | validations:
79 | required: true
80 |
81 | - type: input
82 | id: DatasetSource
83 | attributes:
84 | label: 数据集来源(用于判断数据集质量)
85 | description: 如:UVR处理过的vtb直播音频、录音棚录制
86 | validations:
87 | required: true
88 |
89 | - type: input
90 | id: WhereOccurs
91 | attributes:
92 | label: 出现问题的环节或执行的命令
93 | description: 如:预处理、训练、`python preprocess_hubert_f0.py`
94 | validations:
95 | required: true
96 |
97 | - type: textarea
98 | id: Description
99 | attributes:
100 | label: 问题描述
101 | description: 在这里描述自己的问题,越详细越好
102 | validations:
103 | required: true
104 |
105 | - type: textarea
106 | id: Log
107 | attributes:
108 | label: 日志
109 | description: 将从执行命令到执行完毕输出的所有信息(包括你所执行的命令)粘贴到[pastebin.com](https://pastebin.com/)并把剪贴板链接贴到这里,日志量少的话也可以直接贴在下面
110 | render: python
111 | validations:
112 | required: true
113 |
114 | - type: textarea
115 | id: ValidOneClick
116 | attributes:
117 | label: 截图`so-vits-svc`、`logs/44k`文件夹并粘贴到此处
118 | validations:
119 | required: true
120 |
121 | - type: textarea
122 | id: Supplementary
123 | attributes:
124 | label: 补充说明
125 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/ask_for_help_en_US.yaml:
--------------------------------------------------------------------------------
1 | name: Ask for help
2 | description: Encountered an error that cannot be resolved by yourself
3 | title: '[Help]: '
4 | labels: [ "help wanted" ]
5 |
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 |         #### Please try to solve the problem yourself before asking for help. First, read the *[repo wiki](https://github.com/svc-develop-team/so-vits-svc/wiki)*. Then try ChatGPT or search engines (Google, Bing, New Bing, StackOverflow, etc.) until you are sure you cannot solve it yourself. Before you raise an issue, please read *[How To Ask Questions The Smart Way](http://www.catb.org/~esr/faqs/smart-questions.html)*.
11 | ---
12 |         ### What kinds of issues will be closed immediately
13 |         1. Beggars or Free Riders
14 |         2. One click package / Environment package (Not using `pip install -r requirements.txt`)
15 |         3. Incomplete information
16 |         4. Stupid issues such as missing a dependency package
17 |         5. Using an unlicensed dataset (Game characters / anime characters are not included in this category for now, but you still need to pay attention. If you can contact the official rights holder, you must contact them and verify it first.)
18 | ---
19 |
20 | - type: checkboxes
21 | id: Clause
22 | attributes:
23 | label: Please check the checkboxes below.
24 | options:
25 | - label: "I have read *[README.md](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/README.md)* and *[Quick solution in wiki](https://github.com/svc-develop-team/so-vits-svc/wiki/Quick-solution)* carefully."
26 | required: true
27 | - label: "I have been troubleshooting issues through various search engines. The questions I want to ask are not common."
28 | required: true
29 | - label: "I am NOT using one click package / environment package."
30 | required: true
31 |
32 | - type: markdown
33 | attributes:
34 | value: |
35 | # Please fill in the following information according to your actual environment
36 |
37 | - type: input
38 | id: System
39 | attributes:
40 | label: OS version
41 | description: Windows run `winver` | Linux run `uname -a`
42 | validations:
43 | required: true
44 |
45 | - type: input
46 | id: GPU
47 | attributes:
48 | label: GPU
49 | description: Run `nvidia-smi`
50 | validations:
51 | required: true
52 |
53 | - type: input
54 | id: PythonVersion
55 | attributes:
56 | label: Python version
57 | description: Run `python -V`
58 | validations:
59 | required: true
60 |
61 | - type: input
62 | id: PyTorchVersion
63 | attributes:
64 | label: PyTorch version
65 | description: Run `pip show torch`
66 | validations:
67 | required: true
68 |
69 | - type: dropdown
70 | id: Branch
71 | attributes:
72 | label: Branch of sovits
73 | options:
74 | - 4.0(Default)
75 | - 4.0-v2
76 | - 3.0-32k
77 | - 3.0-48k
78 | validations:
79 | required: true
80 |
81 | - type: input
82 | id: DatasetSource
83 | attributes:
84 | label: Dataset source (Used to judge the dataset quality)
85 | description: Such as UVR-processed streaming audio / Recorded in recording studio
86 | validations:
87 | required: true
88 |
89 | - type: input
90 | id: WhereOccurs
91 | attributes:
92 |       label: Where the problem occurs or what command you executed
93 | description: Such as Preprocessing / Training / `python preprocess_hubert_f0.py`
94 | validations:
95 | required: true
96 |
97 | - type: textarea
98 | id: Description
99 | attributes:
100 | label: Problem description
101 | description: Describe your problem here, the more detailed the better.
102 | validations:
103 | required: true
104 |
105 | - type: textarea
106 | id: Log
107 | attributes:
108 | label: Log
109 |       description: All information output from the command you executed up to the end of execution (including the command). It can also be posted directly below if there is only a little text.
110 | render: python
111 | validations:
112 | required: true
113 |
114 | - type: textarea
115 | id: ValidOneClick
116 | attributes:
117 |       label: Screenshot the `so-vits-svc` and `logs/44k` folders and paste them here
118 | validations:
119 | required: true
120 |
121 | - type: textarea
122 | id: Supplementary
123 | attributes:
124 | label: Supplementary description
125 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: 问题回报
2 | description: 遇到了BUG?!
3 | title: '[Bug]: '
4 | labels: [ "bug?" ]
5 |
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | # 请根据实际使用环境填写以下信息
11 |
12 | - type: input
13 | id: System
14 | attributes:
15 | label: 系统平台版本号
16 | description: Windows执行`winver` | Linux执行`uname -a`
17 | validations:
18 | required: true
19 |
20 | - type: input
21 | id: GPU
22 | attributes:
23 | label: GPU 型号
24 | description: 执行`nvidia-smi`
25 | validations:
26 | required: true
27 |
28 | - type: input
29 | id: PythonVersion
30 | attributes:
31 | label: Python版本
32 | description: 执行`python -V`
33 | validations:
34 | required: true
35 |
36 | - type: input
37 | id: PyTorchVersion
38 | attributes:
39 | label: PyTorch版本
40 | description: 执行`pip show torch`
41 | validations:
42 | required: true
43 |
44 | - type: dropdown
45 | id: Branch
46 | attributes:
47 | label: sovits分支
48 | options:
49 | - 4.0(默认)
50 | - 4.0-v2
51 | - 3.0-32k
52 | - 3.0-48k
53 | validations:
54 | required: true
55 |
56 | - type: input
57 | id: DatasetSource
58 | attributes:
59 | label: 数据集来源(用于判断数据集质量)
60 | description: 如:UVR处理过的vtb直播音频、录音棚录制
61 | validations:
62 | required: true
63 |
64 | - type: input
65 | id: WhereOccurs
66 | attributes:
67 | label: 出现问题的环节或执行的命令
68 | description: 如:预处理、训练、`python preprocess_hubert_f0.py`
69 | validations:
70 | required: true
71 |
72 | - type: textarea
73 | id: Description
74 | attributes:
75 | label: 情况描述
76 | description: 在这里描述遇到的情况,越详细越好
77 | validations:
78 | required: true
79 |
80 | - type: textarea
81 | id: Log
82 | attributes:
83 | label: 日志
84 | description: 将从执行命令到执行完毕输出的所有信息(包括你所执行的命令)粘贴到[pastebin.com](https://pastebin.com/)并把剪贴板链接贴到这里,日志量少的话也可以直接贴在下面
85 | render: python
86 | validations:
87 | required: true
88 |
89 | - type: textarea
90 | id: Supplementary
91 | attributes:
92 | label: 补充说明
93 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report_en_US.yaml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Encountered a bug?!
3 | title: '[Bug]: '
4 | labels: [ "bug?" ]
5 |
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | # Please fill in the following information according to your actual environment
11 |
12 | - type: input
13 | id: System
14 | attributes:
15 | label: OS version
16 | description: Windows run `winver` | Linux run `uname -a`
17 | validations:
18 | required: true
19 |
20 | - type: input
21 | id: GPU
22 | attributes:
23 | label: GPU
24 | description: Run `nvidia-smi`
25 | validations:
26 | required: true
27 |
28 | - type: input
29 | id: PythonVersion
30 | attributes:
31 | label: Python version
32 | description: Run `python -V`
33 | validations:
34 | required: true
35 |
36 | - type: input
37 | id: PyTorchVersion
38 | attributes:
39 | label: PyTorch version
40 | description: Run `pip show torch`
41 | validations:
42 | required: true
43 |
44 | - type: dropdown
45 | id: Branch
46 | attributes:
47 | label: Branch of sovits
48 | options:
49 | - 4.0(Default)
50 | - 4.0-v2
51 | - 3.0-32k
52 | - 3.0-48k
53 | validations:
54 | required: true
55 |
56 | - type: input
57 | id: DatasetSource
58 | attributes:
59 | label: Dataset source (Used to judge the dataset quality)
60 | description: Such as UVR-processed streaming audio / Recorded in recording studio
61 | validations:
62 | required: true
63 |
64 | - type: input
65 | id: WhereOccurs
66 | attributes:
67 |       label: Where the problem occurs or what command you executed
68 | description: Such as Preprocessing / Training / `python preprocess_hubert_f0.py`
69 | validations:
70 | required: true
71 |
72 | - type: textarea
73 | id: Description
74 | attributes:
75 | label: Situation description
76 | description: Describe your situation here, the more detailed the better.
77 | validations:
78 | required: true
79 |
80 | - type: textarea
81 | id: Log
82 | attributes:
83 | label: Log
84 |       description: All information output from the command you executed up to the end of execution (including the command). You can paste it to [pastebin.com](https://pastebin.com/) and put the link here. It can also be posted directly below if there is only a little text.
85 | render: python
86 | validations:
87 | required: true
88 |
89 | - type: textarea
90 | id: Supplementary
91 | attributes:
92 | label: Supplementary description
93 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: 讨论区 / Discussions
4 | url: https://github.com/svc-develop-team/so-vits-svc/discussions
5 | about: 简单的询问/讨论请转至讨论区或发起一个低优先级的Default issue / For simple inquiries / discussions, please go to the discussions or raise a low priority Default issue
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/default.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Default issue
3 | about: 如果模板中没有你想发起的issue类型,可以选择此项,但这个issue也许会获得一个较低的处理优先级 / If none of the templates fits the issue you want to raise, you can use this one, but it may get a lower processing priority.
4 | title: ''
5 | labels: 'not urgent'
6 | assignees: ''
7 | ---
8 |
--------------------------------------------------------------------------------
/.github/workflows/reviewdog.yml:
--------------------------------------------------------------------------------
1 | name: Ruff Autofix
2 | on: [pull_request]
3 | jobs:
4 | ruff:
5 | permissions:
6 | checks: write
7 | contents: read
8 | pull-requests: write
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v3
12 | - uses: chartboost/ruff-action@v1
13 | with:
14 | args: --fix -e
15 | - uses: reviewdog/action-suggester@v1
16 | with:
17 | tool_name: ruff
18 |
19 |
--------------------------------------------------------------------------------
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 | on: [push, pull_request]
3 | jobs:
4 | ruff:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v3
8 | - uses: chartboost/ruff-action@v1
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 | checkpoints/
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | pytestdebug.log
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 | doc/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
138 | # pytype static type analyzer
139 | .pytype/
140 |
141 | # End of https://www.toptal.com/developers/gitignore/api/python
142 |
143 | /shelf/
144 | /workspace.xml
145 |
146 | dataset
147 | dataset_raw
148 | raw
149 | results
150 | inference/chunks_temp.json
151 | logs
152 | hubert/checkpoint_best_legacy_500.pt
153 | configs/config.json
154 | filelists/test.txt
155 | filelists/train.txt
156 | filelists/val.txt
157 | .idea/
158 | .vscode/
159 | .idea/modules.xml
160 | .idea/so-vits-svc.iml
161 | .idea/vcs.xml
162 | .idea/inspectionProfiles/profiles_settings.xml
163 | .idea/inspectionProfiles/Project_Default.xml
164 | pretrain/
165 | .vscode/launch.json
166 |
167 | trained/**/
168 |
--------------------------------------------------------------------------------
/.ruff.toml:
--------------------------------------------------------------------------------
1 | select = ["E", "F", "I"]
2 |
3 | # Never enforce `E501` (line length violations).
4 | ignore = ["E501", "E741"]
5 |
--------------------------------------------------------------------------------
/cluster/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.cluster import KMeans
3 |
4 |
5 | def get_cluster_model(ckpt_path):
6 | checkpoint = torch.load(ckpt_path)
7 | kmeans_dict = {}
8 | for spk, ckpt in checkpoint.items():
9 | km = KMeans(ckpt["n_features_in_"])
10 | km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
11 | km.__dict__["_n_threads"] = ckpt["_n_threads"]
12 | km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"]
13 | kmeans_dict[spk] = km
14 | return kmeans_dict
15 |
16 | def get_cluster_result(model, x, speaker):
17 | """
18 | x: np.array [t, 256]
19 | return cluster class result
20 | """
21 | return model[speaker].predict(x)
22 |
23 | def get_cluster_center_result(model, x,speaker):
24 | """x: np.array [t, 256]"""
25 | predict = model[speaker].predict(x)
26 | return model[speaker].cluster_centers_[predict]
27 |
28 | def get_center(model, x,speaker):
29 | return model[speaker].cluster_centers_[x]
30 |
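
A hedged usage sketch for the helpers above. The checkpoint path and speaker name are placeholders (a k-means checkpoint such as `logs/44k/kmeans_10000.pt` is produced by `cluster/train_cluster.py`), and the feature dimension must match whatever the model was trained on (256 or 768 depending on the speech encoder):

```python
import numpy as np

from cluster import get_cluster_model, get_cluster_center_result

# Placeholder paths/names -- adjust to your own training run
kmeans_dict = get_cluster_model("logs/44k/kmeans_10000.pt")   # {speaker: fitted KMeans}
units = np.random.randn(100, 768).astype(np.float32)          # [t, dim] content features; dim must match training
centers = get_cluster_center_result(kmeans_dict, units, "speaker0")
print(centers.shape)  # (100, 768): each frame replaced by its nearest cluster centroid
```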
--------------------------------------------------------------------------------
/cluster/train_cluster.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import time
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | import tqdm
10 | from kmeans import KMeansGPU
11 | from sklearn.cluster import KMeans, MiniBatchKMeans
12 |
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 | def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):  # GPU minibatch is really poor; the library supports it, but we do not consider it
17 | if str(in_dir).endswith(".ipynb_checkpoints"):
18 | logger.info(f"Ignore {in_dir}")
19 |
20 | logger.info(f"Loading features from {in_dir}")
21 | features = []
22 | nums = 0
23 | for path in tqdm.tqdm(in_dir.glob("*.soft.pt")):
24 | # for name in os.listdir(in_dir):
25 | # path="%s/%s"%(in_dir,name)
26 | features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T)
27 | # print(features[-1].shape)
28 | features = np.concatenate(features, axis=0)
29 | print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype)
30 | features = features.astype(np.float32)
31 | logger.info(f"Clustering features of shape: {features.shape}")
32 | t = time.time()
33 | if(use_gpu is False):
34 | if use_minibatch:
35 | kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features)
36 | else:
37 | kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features)
38 | else:
39 | kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)#
40 | features=torch.from_numpy(features)#.to(device)
41 | kmeans.fit_predict(features)#
42 |
43 | print(time.time()-t, "s")
44 |
45 | x = {
46 | "n_features_in_": kmeans.n_features_in_ if use_gpu is False else features.shape[1],
47 | "_n_threads": kmeans._n_threads if use_gpu is False else 4,
48 | "cluster_centers_": kmeans.cluster_centers_ if use_gpu is False else kmeans.centroids.cpu().numpy(),
49 | }
50 | print("end")
51 |
52 | return x
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument('--dataset', type=Path, default="./dataset/44k",
57 | help='path of training data directory')
58 | parser.add_argument('--output', type=Path, default="logs/44k",
59 | help='path of model output directory')
60 | parser.add_argument('--gpu',action='store_true', default=False ,
61 | help='to use GPU')
62 |
63 |
64 | args = parser.parse_args()
65 |
66 | checkpoint_dir = args.output
67 | dataset = args.dataset
68 | use_gpu = args.gpu
69 | n_clusters = 10000
70 |
71 | ckpt = {}
72 | for spk in os.listdir(dataset):
73 | if os.path.isdir(dataset/spk):
74 | print(f"train kmeans for {spk}...")
75 | in_dir = dataset/spk
76 | x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=use_gpu)
77 | ckpt[spk] = x
78 |
79 | checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt"
80 | checkpoint_path.parent.mkdir(exist_ok=True, parents=True)
81 | torch.save(
82 | ckpt,
83 | checkpoint_path,
84 | )
85 |
86 |
--------------------------------------------------------------------------------
/compress_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 | import utils
6 | from models import SynthesizerTrn
7 |
8 |
9 | def copyStateDict(state_dict):
10 | if list(state_dict.keys())[0].startswith('module'):
11 | start_idx = 1
12 | else:
13 | start_idx = 0
14 | new_state_dict = OrderedDict()
15 | for k, v in state_dict.items():
16 |         name = '.'.join(k.split('.')[start_idx:])
17 | new_state_dict[name] = v
18 | return new_state_dict
19 |
20 |
21 | def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
22 | hps = utils.get_hparams_from_file(config)
23 |
24 | net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
25 | hps.train.segment_size // hps.data.hop_length,
26 | **hps.model)
27 |
28 | optim_g = torch.optim.AdamW(net_g.parameters(),
29 | hps.train.learning_rate,
30 | betas=hps.train.betas,
31 | eps=hps.train.eps)
32 |
33 | state_dict_g = torch.load(input_model, map_location="cpu")
34 | new_dict_g = copyStateDict(state_dict_g)
35 | keys = []
36 | for k, v in new_dict_g['model'].items():
37 | if "enc_q" in k: continue # noqa: E701
38 | keys.append(k)
39 |
40 | new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys}
41 |
42 | torch.save(
43 | {
44 | 'model': new_dict_g,
45 | 'iteration': 0,
46 | 'optimizer': optim_g.state_dict(),
47 | 'learning_rate': 0.0001
48 | }, output_model)
49 |
50 |
51 | if __name__ == "__main__":
52 | import argparse
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("-c",
55 | "--config",
56 | type=str,
57 | default='configs/config.json')
58 | parser.add_argument("-i", "--input", type=str)
59 | parser.add_argument("-o", "--output", type=str, default=None)
60 | parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')
61 |
62 | args = parser.parse_args()
63 |
64 | output = args.output
65 |
66 | if output is None:
67 | import os.path
68 | filename, ext = os.path.splitext(args.input)
69 | half = "_half" if args.half else ""
70 | output = filename + "_release" + half + ext
71 |
72 | removeOptimizer(args.config, args.input, args.half, output)
--------------------------------------------------------------------------------
/configs/config.json:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/configs/config.json
--------------------------------------------------------------------------------
/configs/diffusion.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/configs/diffusion.yaml
--------------------------------------------------------------------------------
/configs_template/config_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 800,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 0.0001,
8 | "betas": [
9 | 0.8,
10 | 0.99
11 | ],
12 | "eps": 1e-09,
13 | "batch_size": 6,
14 | "fp16_run": false,
15 | "half_type": "fp16",
16 | "lr_decay": 0.999875,
17 | "segment_size": 10240,
18 | "init_lr_ratio": 1,
19 | "warmup_epochs": 0,
20 | "c_mel": 45,
21 | "c_kl": 1.0,
22 | "use_sr": true,
23 | "max_speclen": 512,
24 | "port": "8001",
25 | "keep_ckpts": 3,
26 | "all_in_mem": false,
27 | "vol_aug":false
28 | },
29 | "data": {
30 | "training_files": "filelists/train.txt",
31 | "validation_files": "filelists/val.txt",
32 | "max_wav_value": 32768.0,
33 | "sampling_rate": 44100,
34 | "filter_length": 2048,
35 | "hop_length": 512,
36 | "win_length": 2048,
37 | "n_mel_channels": 80,
38 | "mel_fmin": 0.0,
39 | "mel_fmax": 22050,
40 | "unit_interpolate_mode":"nearest"
41 | },
42 | "model": {
43 | "inter_channels": 192,
44 | "hidden_channels": 192,
45 | "filter_channels": 768,
46 | "n_heads": 2,
47 | "n_layers": 6,
48 | "kernel_size": 3,
49 | "p_dropout": 0.1,
50 | "resblock": "1",
51 | "resblock_kernel_sizes": [3,7,11],
52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
53 | "upsample_rates": [ 8, 8, 2, 2, 2],
54 | "upsample_initial_channel": 512,
55 | "upsample_kernel_sizes": [16,16, 4, 4, 4],
56 | "n_layers_q": 3,
57 | "n_layers_trans_flow": 3,
58 | "n_flow_layer": 4,
59 | "use_spectral_norm": false,
60 | "gin_channels": 768,
61 | "ssl_dim": 768,
62 | "n_speakers": 200,
63 | "vocoder_name":"nsf-hifigan",
64 | "speech_encoder":"vec768l12",
65 | "speaker_embedding":false,
66 | "vol_embedding":false,
67 | "use_depthwise_conv":false,
68 | "flow_share_parameter": false,
69 | "use_automatic_f0_prediction": true,
70 | "use_transformer_flow": false
71 | },
72 | "spk": {
73 | "nyaru": 0,
74 | "huiyu": 1,
75 | "nen": 2,
76 | "paimon": 3,
77 | "yunhao": 4
78 | }
79 | }
--------------------------------------------------------------------------------
/configs_template/config_tiny_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 800,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 0.0001,
8 | "betas": [
9 | 0.8,
10 | 0.99
11 | ],
12 | "eps": 1e-09,
13 | "batch_size": 6,
14 | "fp16_run": false,
15 | "half_type": "fp16",
16 | "lr_decay": 0.999875,
17 | "segment_size": 10240,
18 | "init_lr_ratio": 1,
19 | "warmup_epochs": 0,
20 | "c_mel": 45,
21 | "c_kl": 1.0,
22 | "use_sr": true,
23 | "max_speclen": 512,
24 | "port": "8001",
25 | "keep_ckpts": 3,
26 | "all_in_mem": false,
27 | "vol_aug":false
28 | },
29 | "data": {
30 | "training_files": "filelists/train.txt",
31 | "validation_files": "filelists/val.txt",
32 | "max_wav_value": 32768.0,
33 | "sampling_rate": 44100,
34 | "filter_length": 2048,
35 | "hop_length": 512,
36 | "win_length": 2048,
37 | "n_mel_channels": 80,
38 | "mel_fmin": 0.0,
39 | "mel_fmax": 22050,
40 | "unit_interpolate_mode":"nearest"
41 | },
42 | "model": {
43 | "inter_channels": 192,
44 | "hidden_channels": 192,
45 | "filter_channels": 512,
46 | "n_heads": 2,
47 | "n_layers": 6,
48 | "kernel_size": 3,
49 | "p_dropout": 0.1,
50 | "resblock": "1",
51 | "resblock_kernel_sizes": [3,7,11],
52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
53 | "upsample_rates": [ 8, 8, 2, 2, 2],
54 | "upsample_initial_channel": 400,
55 | "upsample_kernel_sizes": [16,16, 4, 4, 4],
56 | "n_layers_q": 3,
57 | "n_layers_trans_flow": 3,
58 | "n_flow_layer": 4,
59 | "use_spectral_norm": false,
60 | "gin_channels": 768,
61 | "ssl_dim": 768,
62 | "n_speakers": 200,
63 | "vocoder_name":"nsf-hifigan",
64 | "speech_encoder":"vec768l12",
65 | "speaker_embedding":false,
66 | "vol_embedding":false,
67 | "use_depthwise_conv":true,
68 | "flow_share_parameter": true,
69 | "use_automatic_f0_prediction": true,
70 | "use_transformer_flow": false
71 | },
72 | "spk": {
73 | "nyaru": 0,
74 | "huiyu": 1,
75 | "nen": 2,
76 | "paimon": 3,
77 | "yunhao": 4
78 | }
79 | }
--------------------------------------------------------------------------------
/configs_template/diffusion_template.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | sampling_rate: 44100
3 | block_size: 512 # Equal to hop_length
4 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip
5 | encoder: 'vec768l12' # 'hubertsoft', 'vec256l9', 'vec768l12'
6 | cnhubertsoft_gate: 10
7 | encoder_sample_rate: 16000
8 | encoder_hop_size: 320
9 | encoder_out_channels: 768 # 256 if using 'hubertsoft'
10 | training_files: "filelists/train.txt"
11 | validation_files: "filelists/val.txt"
12 | extensions: # List of extension included in the data collection
13 | - wav
14 | unit_interpolate_mode: "nearest"
15 | model:
16 | type: 'Diffusion'
17 | n_layers: 20
18 | n_chans: 512
19 | n_hidden: 256
20 | use_pitch_aug: true
21 | timesteps : 1000
22 |   k_step_max: 0 # must be <= timesteps; if 0, train the full number of diffusion steps
23 | n_spk: 1 # max number of different speakers
24 | device: cuda
25 | vocoder:
26 | type: 'nsf-hifigan'
27 | ckpt: 'pretrain/nsf_hifigan/model'
28 | infer:
29 | speedup: 10
30 | method: 'dpm-solver++' # 'pndm' or 'dpm-solver' or 'ddim' or 'unipc' or 'dpm-solver++'
31 | env:
32 | expdir: logs/44k/diffusion
33 | gpu_id: 0
34 | train:
35 |   num_workers: 4 # If both your CPU and GPU are very strong, setting this to 0 may be faster!
36 | amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu)
37 | batch_size: 48
38 |   cache_all_data: true # Setting this to false saves RAM or VRAM, but may be slower
39 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu
40 | cache_fp16: true
41 | epochs: 100000
42 | interval_log: 10
43 | interval_val: 2000
44 | interval_force_save: 5000
45 | lr: 0.0001
46 | decay_step: 100000
47 | gamma: 0.5
48 | weight_decay: 0
49 | save_opt: false
50 | spk:
51 | 'nyaru': 0
--------------------------------------------------------------------------------
/dataset_raw/wav_structure.txt:
--------------------------------------------------------------------------------
1 | Dataset preparation
2 |
3 | raw
4 | ├───speaker0
5 | │ ├───xxx1-xxx1.wav
6 | │ ├───...
7 | │ └───Lxx-0xx8.wav
8 | └───speaker1
9 | ├───xx2-0xxx2.wav
10 | ├───...
11 | └───xxx7-xxx007.wav
12 |
13 | In addition, you need to edit config.json
14 |
15 | "n_speakers": 10
16 |
17 | "spk":{
18 | "speaker0": 0,
19 |     "speaker1": 1
20 | }
21 |
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/diffusion/__init__.py
--------------------------------------------------------------------------------
/diffusion/how to export onnx.md:
--------------------------------------------------------------------------------
1 | - Open [onnx_export](onnx_export.py)
2 | - Change `project_name = "dddsp"` so that `project_name` is your own project name
3 | - Change `model_path = f'{project_name}/model_500000.pt'` so that `model_path` points to your model checkpoint
4 | - Run the script (a sketch of the two edits is shown below)
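
A minimal sketch of the edits above, assuming the two module-level variables in `diffusion/onnx_export.py` keep the names shown; `my_project` and the checkpoint filename are placeholders:

```python
# Hypothetical values -- edit to match your own project and checkpoint
project_name = "my_project"                      # was "dddsp"
model_path = f'{project_name}/model_500000.pt'   # path to the diffusion checkpoint you trained
```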
--------------------------------------------------------------------------------
/diffusion/infer_gt_mel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from diffusion.unit2mel import load_model_vocoder
5 |
6 |
7 | class DiffGtMel:
8 | def __init__(self, project_path=None, device=None):
9 | self.project_path = project_path
10 | if device is not None:
11 | self.device = device
12 | else:
13 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
14 | self.model = None
15 | self.vocoder = None
16 | self.args = None
17 |
18 | def flush_model(self, project_path, ddsp_config=None):
19 | if (self.model is None) or (project_path != self.project_path):
20 | model, vocoder, args = load_model_vocoder(project_path, device=self.device)
21 | if self.check_args(ddsp_config, args):
22 | self.model = model
23 | self.vocoder = vocoder
24 | self.args = args
25 |
26 | def check_args(self, args1, args2):
27 |         if args1.data.block_size != args2.data.block_size:
28 |             raise ValueError("block_size of the DDSP and DIFF models does not match")
29 |         if args1.data.sampling_rate != args2.data.sampling_rate:
30 |             raise ValueError("sampling_rate of the DDSP and DIFF models does not match")
31 |         if args1.data.encoder != args2.data.encoder:
32 |             raise ValueError("encoder of the DDSP and DIFF models does not match")
33 | return True
34 |
35 | def __call__(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm',
36 | spk_mix_dict=None, start_frame=0):
37 | input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate)
38 | out_mel = self.model(
39 | hubert,
40 | f0,
41 | volume,
42 | spk_id=spk_id,
43 | spk_mix_dict=spk_mix_dict,
44 | gt_spec=input_mel,
45 | infer=True,
46 | infer_speedup=acc,
47 | method=method,
48 | k_step=k_step,
49 | use_tqdm=False)
50 | if start_frame > 0:
51 | out_mel = out_mel[:, start_frame:, :]
52 | f0 = f0[:, start_frame:, :]
53 | output = self.vocoder.infer(out_mel, f0)
54 | if start_frame > 0:
55 | output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0))
56 | return output
57 |
58 | def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', silence_front=0,
59 | use_silence=False, spk_mix_dict=None):
60 | start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size)
61 | if use_silence:
62 | audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:]
63 | f0 = f0[:, start_frame:, :]
64 | hubert = hubert[:, start_frame:, :]
65 | volume = volume[:, start_frame:, :]
66 | _start_frame = 0
67 | else:
68 | _start_frame = start_frame
69 | audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step,
70 | method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame)
71 | if use_silence:
72 | if start_frame > 0:
73 | audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0))
74 | return audio
75 |
--------------------------------------------------------------------------------
/diffusion/logger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/diffusion/logger/__init__.py
--------------------------------------------------------------------------------
/diffusion/logger/saver.py:
--------------------------------------------------------------------------------
1 | '''
2 | author: wayn391@mastertones
3 | '''
4 |
5 | import datetime
6 | import os
7 | import time
8 |
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import yaml
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 |
15 | class Saver(object):
16 | def __init__(
17 | self,
18 | args,
19 | initial_global_step=-1):
20 |
21 | self.expdir = args.env.expdir
22 | self.sample_rate = args.data.sampling_rate
23 |
24 | # cold start
25 | self.global_step = initial_global_step
26 | self.init_time = time.time()
27 | self.last_time = time.time()
28 |
29 | # makedirs
30 | os.makedirs(self.expdir, exist_ok=True)
31 |
32 | # path
33 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt')
34 |
35 | # ckpt
36 | os.makedirs(self.expdir, exist_ok=True)
37 |
38 | # writer
39 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))
40 |
41 | # save config
42 | path_config = os.path.join(self.expdir, 'config.yaml')
43 | with open(path_config, "w") as out_config:
44 | yaml.dump(dict(args), out_config)
45 |
46 |
47 | def log_info(self, msg):
48 | '''log method'''
49 | if isinstance(msg, dict):
50 | msg_list = []
51 | for k, v in msg.items():
52 | tmp_str = ''
53 | if isinstance(v, int):
54 | tmp_str = '{}: {:,}'.format(k, v)
55 | else:
56 | tmp_str = '{}: {}'.format(k, v)
57 |
58 | msg_list.append(tmp_str)
59 | msg_str = '\n'.join(msg_list)
60 | else:
61 | msg_str = msg
62 |
63 |         # display
64 | print(msg_str)
65 |
66 | # save
67 | with open(self.path_log_info, 'a') as fp:
68 | fp.write(msg_str+'\n')
69 |
70 | def log_value(self, dict):
71 | for k, v in dict.items():
72 | self.writer.add_scalar(k, v, self.global_step)
73 |
74 | def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
75 | spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
76 | spec = spec_cat[0]
77 | if isinstance(spec, torch.Tensor):
78 | spec = spec.cpu().numpy()
79 | fig = plt.figure(figsize=(12, 9))
80 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
81 | plt.tight_layout()
82 | self.writer.add_figure(name, fig, self.global_step)
83 |
84 | def log_audio(self, dict):
85 | for k, v in dict.items():
86 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)
87 |
88 | def get_interval_time(self, update=True):
89 | cur_time = time.time()
90 | time_interval = cur_time - self.last_time
91 | if update:
92 | self.last_time = cur_time
93 | return time_interval
94 |
95 | def get_total_time(self, to_str=True):
96 | total_time = time.time() - self.init_time
97 | if to_str:
98 | total_time = str(datetime.timedelta(
99 | seconds=total_time))[:-5]
100 | return total_time
101 |
102 | def save_model(
103 | self,
104 | model,
105 | optimizer,
106 | name='model',
107 | postfix='',
108 | to_json=False):
109 | # path
110 | if postfix:
111 | postfix = '_' + postfix
112 | path_pt = os.path.join(
113 | self.expdir , name+postfix+'.pt')
114 |
115 | # check
116 | print(' [*] model checkpoint saved: {}'.format(path_pt))
117 |
118 | # save
119 | if optimizer is not None:
120 | torch.save({
121 | 'global_step': self.global_step,
122 | 'model': model.state_dict(),
123 | 'optimizer': optimizer.state_dict()}, path_pt)
124 | else:
125 | torch.save({
126 | 'global_step': self.global_step,
127 | 'model': model.state_dict()}, path_pt)
128 |
129 |
130 | def delete_model(self, name='model', postfix=''):
131 | # path
132 | if postfix:
133 | postfix = '_' + postfix
134 | path_pt = os.path.join(
135 | self.expdir , name+postfix+'.pt')
136 |
137 | # delete
138 | if os.path.exists(path_pt):
139 | os.remove(path_pt)
140 | print(' [*] model checkpoint deleted: {}'.format(path_pt))
141 |
142 | def global_step_increment(self):
143 | self.global_step += 1
144 |
145 |
146 |
--------------------------------------------------------------------------------
/diffusion/logger/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import torch
5 | import yaml
6 |
7 |
8 | def traverse_dir(
9 | root_dir,
10 | extensions,
11 | amount=None,
12 | str_include=None,
13 | str_exclude=None,
14 | is_pure=False,
15 | is_sort=False,
16 | is_ext=True):
17 |
18 | file_list = []
19 | cnt = 0
20 | for root, _, files in os.walk(root_dir):
21 | for file in files:
22 | if any([file.endswith(f".{ext}") for ext in extensions]):
23 | # path
24 | mix_path = os.path.join(root, file)
25 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
26 |
27 | # amount
28 | if (amount is not None) and (cnt == amount):
29 | if is_sort:
30 | file_list.sort()
31 | return file_list
32 |
33 | # check string
34 | if (str_include is not None) and (str_include not in pure_path):
35 | continue
36 | if (str_exclude is not None) and (str_exclude in pure_path):
37 | continue
38 |
39 | if not is_ext:
40 | ext = pure_path.split('.')[-1]
41 | pure_path = pure_path[:-(len(ext)+1)]
42 | file_list.append(pure_path)
43 | cnt += 1
44 | if is_sort:
45 | file_list.sort()
46 | return file_list
47 |
48 |
49 |
50 | class DotDict(dict):
51 | def __getattr__(*args):
52 | val = dict.get(*args)
53 | return DotDict(val) if type(val) is dict else val
54 |
55 | __setattr__ = dict.__setitem__
56 | __delattr__ = dict.__delitem__
57 |
58 |
59 | def get_network_paras_amount(model_dict):
60 | info = dict()
61 | for model_name, model in model_dict.items():
62 | # all_params = sum(p.numel() for p in model.parameters())
63 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
64 |
65 | info[model_name] = trainable_params
66 | return info
67 |
68 |
69 | def load_config(path_config):
70 | with open(path_config, "r") as config:
71 | args = yaml.safe_load(config)
72 | args = DotDict(args)
73 | # print(args)
74 | return args
75 |
76 | def save_config(path_config,config):
77 | config = dict(config)
78 | with open(path_config, "w") as f:
79 | yaml.dump(config, f)
80 |
81 | def to_json(path_params, path_json):
82 | params = torch.load(path_params, map_location=torch.device('cpu'))
83 | raw_state_dict = {}
84 | for k, v in params.items():
85 | val = v.flatten().numpy().tolist()
86 | raw_state_dict[k] = val
87 |
88 | with open(path_json, 'w') as outfile:
89 | json.dump(raw_state_dict, outfile,indent= "\t")
90 |
91 |
92 | def convert_tensor_to_numpy(tensor, is_squeeze=True):
93 | if is_squeeze:
94 | tensor = tensor.squeeze()
95 | if tensor.requires_grad:
96 | tensor = tensor.detach()
97 | if tensor.is_cuda:
98 | tensor = tensor.cpu()
99 | return tensor.numpy()
100 |
101 |
102 | def load_model(
103 | expdir,
104 | model,
105 | optimizer,
106 | name='model',
107 | postfix='',
108 | device='cpu'):
109 | if postfix == '':
110 | postfix = '_' + postfix
111 | path = os.path.join(expdir, name+postfix)
112 | path_pt = traverse_dir(expdir, ['pt'], is_ext=False)
113 | global_step = 0
114 | if len(path_pt) > 0:
115 | steps = [s[len(path):] for s in path_pt]
116 | maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
117 | if maxstep >= 0:
118 | path_pt = path+str(maxstep)+'.pt'
119 | else:
120 | path_pt = path+'best.pt'
121 | print(' [*] restoring model from', path_pt)
122 | ckpt = torch.load(path_pt, map_location=torch.device(device))
123 | global_step = ckpt['global_step']
124 | model.load_state_dict(ckpt['model'], strict=False)
125 | if ckpt.get("optimizer") is not None:
126 | optimizer.load_state_dict(ckpt['optimizer'])
127 | return global_step, model, optimizer
128 |
--------------------------------------------------------------------------------
/diffusion/vocoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchaudio.transforms import Resample
3 |
4 | from vdecoder.nsf_hifigan.models import load_config, load_model
5 | from vdecoder.nsf_hifigan.nvSTFT import STFT
6 |
7 |
8 | class Vocoder:
9 | def __init__(self, vocoder_type, vocoder_ckpt, device = None):
10 | if device is None:
11 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
12 | self.device = device
13 |
14 | if vocoder_type == 'nsf-hifigan':
15 | self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device)
16 | elif vocoder_type == 'nsf-hifigan-log10':
17 | self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device)
18 | else:
19 | raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
20 |
21 | self.resample_kernel = {}
22 | self.vocoder_sample_rate = self.vocoder.sample_rate()
23 | self.vocoder_hop_size = self.vocoder.hop_size()
24 | self.dimension = self.vocoder.dimension()
25 |
26 | def extract(self, audio, sample_rate, keyshift=0):
27 |
28 | # resample
29 | if sample_rate == self.vocoder_sample_rate:
30 | audio_res = audio
31 | else:
32 | key_str = str(sample_rate)
33 | if key_str not in self.resample_kernel:
34 | self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
35 | audio_res = self.resample_kernel[key_str](audio)
36 |
37 | # extract
38 | mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
39 | return mel
40 |
41 | def infer(self, mel, f0):
42 | f0 = f0[:,:mel.size(1),0] # B, n_frames
43 | audio = self.vocoder(mel, f0)
44 | return audio
45 |
46 |
47 | class NsfHifiGAN(torch.nn.Module):
48 | def __init__(self, model_path, device=None):
49 | super().__init__()
50 | if device is None:
51 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
52 | self.device = device
53 | self.model_path = model_path
54 | self.model = None
55 | self.h = load_config(model_path)
56 | self.stft = STFT(
57 | self.h.sampling_rate,
58 | self.h.num_mels,
59 | self.h.n_fft,
60 | self.h.win_size,
61 | self.h.hop_size,
62 | self.h.fmin,
63 | self.h.fmax)
64 |
65 | def sample_rate(self):
66 | return self.h.sampling_rate
67 |
68 | def hop_size(self):
69 | return self.h.hop_size
70 |
71 | def dimension(self):
72 | return self.h.num_mels
73 |
74 | def extract(self, audio, keyshift=0):
75 | mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins
76 | return mel
77 |
78 | def forward(self, mel, f0):
79 | if self.model is None:
80 | print('| Load HifiGAN: ', self.model_path)
81 | self.model, self.h = load_model(self.model_path, device=self.device)
82 | with torch.no_grad():
83 | c = mel.transpose(1, 2)
84 | audio = self.model(c, f0)
85 | return audio
86 |
87 | class NsfHifiGANLog10(NsfHifiGAN):
88 | def forward(self, mel, f0):
89 | if self.model is None:
90 | print('| Load HifiGAN: ', self.model_path)
91 | self.model, self.h = load_model(self.model_path, device=self.device)
92 | with torch.no_grad():
93 | c = 0.434294 * mel.transpose(1, 2)
94 | audio = self.model(c, f0)
95 | return audio
--------------------------------------------------------------------------------
/diffusion/wavenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from math import sqrt
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn import Mish
8 |
9 |
10 | class Conv1d(torch.nn.Conv1d):
11 | def __init__(self, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | nn.init.kaiming_normal_(self.weight)
14 |
15 |
16 | class SinusoidalPosEmb(nn.Module):
17 | def __init__(self, dim):
18 | super().__init__()
19 | self.dim = dim
20 |
21 | def forward(self, x):
22 | device = x.device
23 | half_dim = self.dim // 2
24 | emb = math.log(10000) / (half_dim - 1)
25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
26 | emb = x[:, None] * emb[None, :]
27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
28 | return emb
29 |
30 |
31 | class ResidualBlock(nn.Module):
32 | def __init__(self, encoder_hidden, residual_channels, dilation):
33 | super().__init__()
34 | self.residual_channels = residual_channels
35 | self.dilated_conv = nn.Conv1d(
36 | residual_channels,
37 | 2 * residual_channels,
38 | kernel_size=3,
39 | padding=dilation,
40 | dilation=dilation
41 | )
42 | self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
43 | self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
44 | self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)
45 |
46 | def forward(self, x, conditioner, diffusion_step):
47 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
48 | conditioner = self.conditioner_projection(conditioner)
49 | y = x + diffusion_step
50 |
51 | y = self.dilated_conv(y) + conditioner
52 |
53 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice
54 | gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
55 | y = torch.sigmoid(gate) * torch.tanh(filter)
56 |
57 | y = self.output_projection(y)
58 |
59 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice
60 | residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
61 | return (x + residual) / math.sqrt(2.0), skip
62 |
63 |
64 | class WaveNet(nn.Module):
65 | def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
66 | super().__init__()
67 | self.input_projection = Conv1d(in_dims, n_chans, 1)
68 | self.diffusion_embedding = SinusoidalPosEmb(n_chans)
69 | self.mlp = nn.Sequential(
70 | nn.Linear(n_chans, n_chans * 4),
71 | Mish(),
72 | nn.Linear(n_chans * 4, n_chans)
73 | )
74 | self.residual_layers = nn.ModuleList([
75 | ResidualBlock(
76 | encoder_hidden=n_hidden,
77 | residual_channels=n_chans,
78 | dilation=1
79 | )
80 | for i in range(n_layers)
81 | ])
82 | self.skip_projection = Conv1d(n_chans, n_chans, 1)
83 | self.output_projection = Conv1d(n_chans, in_dims, 1)
84 | nn.init.zeros_(self.output_projection.weight)
85 |
86 | def forward(self, spec, diffusion_step, cond):
87 | """
88 | :param spec: [B, 1, M, T]
89 | :param diffusion_step: [B, 1]
90 | :param cond: [B, M, T]
91 | :return:
92 | """
93 | x = spec.squeeze(1)
94 | x = self.input_projection(x) # [B, residual_channel, T]
95 |
96 | x = F.relu(x)
97 | diffusion_step = self.diffusion_embedding(diffusion_step)
98 | diffusion_step = self.mlp(diffusion_step)
99 | skip = []
100 | for layer in self.residual_layers:
101 | x, skip_connection = layer(x, cond, diffusion_step)
102 | skip.append(skip_connection)
103 |
104 | x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
105 | x = self.skip_projection(x)
106 | x = F.relu(x)
107 | x = self.output_projection(x) # [B, mel_bins, T]
108 | return x[:, None, :, :]
109 |
--------------------------------------------------------------------------------
/edgetts/tts.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import random
3 | import sys
4 |
5 | import edge_tts
6 | from edge_tts import VoicesManager
7 | from langdetect import DetectorFactory, detect
8 |
9 | DetectorFactory.seed = 0
10 |
11 | TEXT = sys.argv[1]
12 | LANG = detect(TEXT) if sys.argv[2] == "Auto" else sys.argv[2]
13 | RATE = sys.argv[3]
14 | VOLUME = sys.argv[4]
15 | GENDER = sys.argv[5] if len(sys.argv) == 6 else None
16 | OUTPUT_FILE = "tts.wav"
17 |
18 | print("Running TTS...")
19 | print(f"Text: {TEXT}, Language: {LANG}, Gender: {GENDER}, Rate: {RATE}, Volume: {VOLUME}")
20 |
21 | async def _main() -> None:
22 | voices = await VoicesManager.create()
23 | if GENDER is not None:
24 | # From "zh-cn" to "zh-CN" etc.
25 | if LANG == "zh-cn" or LANG == "zh-tw":
26 | LOCALE = LANG[:-2] + LANG[-2:].upper()
27 | voice = voices.find(Gender=GENDER, Locale=LOCALE)
28 | else:
29 | voice = voices.find(Gender=GENDER, Language=LANG)
30 | VOICE = random.choice(voice)["Name"]
31 | print(f"Using random {LANG} voice: {VOICE}")
32 | else:
33 | VOICE = LANG
34 |
35 | communicate = edge_tts.Communicate(text = TEXT, voice = VOICE, rate = RATE, volume = VOLUME)
36 | await communicate.save(OUTPUT_FILE)
37 |
38 | if __name__ == "__main__":
39 | if sys.platform.startswith("win"):
40 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
41 | asyncio.run(_main())
42 | else:
43 | loop = asyncio.get_event_loop_policy().get_event_loop()
44 | try:
45 | loop.run_until_complete(_main())
46 | finally:
47 | loop.close()
48 |
--------------------------------------------------------------------------------
/export_index_for_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import faiss
5 |
6 | path = "crs"
7 | indexs_file_path = f"checkpoints/{path}/feature_and_index.pkl"
8 | indexs_out_dir = f"checkpoints/{path}/"
9 |
10 | with open(indexs_file_path, mode="rb") as f:
11 | indexs = pickle.load(f)
12 |
13 | for k in indexs:
14 | print(f"Save {k} index")
15 | faiss.write_index(
16 | indexs[k],
17 | os.path.join(indexs_out_dir,f"Index-{k}.index")
18 | )
19 |
20 | print("Saved all index")
--------------------------------------------------------------------------------
/filelists/test.txt:
--------------------------------------------------------------------------------
1 | ./dataset/44k/taffy/000562.wav
2 | ./dataset/44k/nyaru/000011.wav
3 | ./dataset/44k/nyaru/000008.wav
4 | ./dataset/44k/taffy/000563.wav
5 |
--------------------------------------------------------------------------------
/filelists/train.txt:
--------------------------------------------------------------------------------
1 | ./dataset/44k/taffy/000549.wav
2 | ./dataset/44k/nyaru/000004.wav
3 | ./dataset/44k/nyaru/000006.wav
4 | ./dataset/44k/taffy/000551.wav
5 | ./dataset/44k/nyaru/000009.wav
6 | ./dataset/44k/taffy/000561.wav
7 | ./dataset/44k/nyaru/000001.wav
8 | ./dataset/44k/taffy/000553.wav
9 | ./dataset/44k/nyaru/000002.wav
10 | ./dataset/44k/taffy/000560.wav
11 | ./dataset/44k/taffy/000557.wav
12 | ./dataset/44k/nyaru/000005.wav
13 | ./dataset/44k/taffy/000554.wav
14 | ./dataset/44k/taffy/000550.wav
15 | ./dataset/44k/taffy/000559.wav
16 |
--------------------------------------------------------------------------------
/filelists/val.txt:
--------------------------------------------------------------------------------
1 | ./dataset/44k/nyaru/000003.wav
2 | ./dataset/44k/nyaru/000007.wav
3 | ./dataset/44k/taffy/000558.wav
4 | ./dataset/44k/taffy/000556.wav
5 |
--------------------------------------------------------------------------------
/flask_api.py:
--------------------------------------------------------------------------------
1 | import io
2 | import logging
3 |
4 | import soundfile
5 | import torch
6 | import torchaudio
7 | from flask import Flask, request, send_file
8 | from flask_cors import CORS
9 |
10 | from inference.infer_tool import RealTimeVC, Svc
11 |
12 | app = Flask(__name__)
13 |
14 | CORS(app)
15 |
16 | logging.getLogger('numba').setLevel(logging.WARNING)
17 |
18 |
19 | @app.route("/voiceChangeModel", methods=["POST"])
20 | def voice_change_model():
21 | request_form = request.form
22 | wave_file = request.files.get("sample", None)
23 |     # pitch shift (in semitones)
24 | f_pitch_change = float(request_form.get("fPitchChange", 0))
25 |     # sample rate required by the DAW
26 | daw_sample = int(float(request_form.get("sampleRate", 0)))
27 | speaker_id = int(float(request_form.get("sSpeakId", 0)))
28 |     # read the wav file from the HTTP request
29 | input_wav_path = io.BytesIO(wave_file.read())
30 |
31 |     # model inference
32 | if raw_infer:
33 | # out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path)
34 | out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0,
35 | auto_predict_f0=False, noice_scale=0.4, f0_filter=False)
36 | tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample)
37 | else:
38 | out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0,
39 | auto_predict_f0=False, noice_scale=0.4, f0_filter=False)
40 | tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample)
41 |     # return the audio
42 | out_wav_path = io.BytesIO()
43 | soundfile.write(out_wav_path, tar_audio.cpu().numpy(), daw_sample, format="wav")
44 | out_wav_path.seek(0)
45 | return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
46 |
47 |
48 | if __name__ == '__main__':
49 |     # True: synthesize each slice directly; False: cross-fade between slices
50 |     # Setting the VST plugin slice length to 0.3-0.5 s lowers latency; direct slicing can pop at slice joins, while cross-fading leaves a slight overlap
51 |     # Pick whichever trade-off is acceptable, or raise the VST maximum slice length to 1 s and keep this True for higher latency but more stable quality
52 | raw_infer = True
53 |     # each model is paired with exactly one config
54 | model_name = "logs/32k/G_174000-Copy1.pth"
55 | config_name = "configs/config.json"
56 | cluster_model_path = "logs/44k/kmeans_10000.pt"
57 | svc_model = Svc(model_name, config_name, cluster_model_path=cluster_model_path)
58 | svc = RealTimeVC()
59 |     # must match the VST plugin settings; changing it is not recommended
60 | app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
61 |
--------------------------------------------------------------------------------
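A minimal client sketch for the /voiceChangeModel endpoint above (assumes the server is running locally on port 6842 and that sample.wav exists; requests is not a listed project dependency):

import requests

with open("sample.wav", "rb") as f:
    response = requests.post(
        "http://127.0.0.1:6842/voiceChangeModel",
        files={"sample": ("sample.wav", f, "audio/wav")},
        data={"fPitchChange": 0, "sampleRate": 44100, "sSpeakId": 0},
    )
with open("converted.wav", "wb") as out:
    out.write(response.content)
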
/flask_api_full_song.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | import numpy as np
4 | import soundfile
5 | from flask import Flask, request, send_file
6 |
7 | from inference import infer_tool, slicer
8 |
9 | app = Flask(__name__)
10 |
11 |
12 | @app.route("/wav2wav", methods=["POST"])
13 | def wav2wav():
14 | request_form = request.form
15 |     audio_path = request_form.get("audio_path", None)  # path to the wav file (on the server)
16 |     tran = int(float(request_form.get("tran", 0)))  # pitch shift in semitones
17 |     spk = request_form.get("spk", 0)  # speaker (id or name, depending on your config)
18 |     wav_format = request_form.get("wav_format", 'wav')  # output file format
19 | infer_tool.format_wav(audio_path)
20 | chunks = slicer.cut(audio_path, db_thresh=-40)
21 | audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks)
22 |
23 | audio = []
24 | for (slice_tag, data) in audio_data:
25 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
26 |
27 | length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
28 | if slice_tag:
29 |             print('skip empty segment')
30 | _audio = np.zeros(length)
31 | else:
32 |             # pad with 0.5 s of silence on both sides
33 | pad_len = int(audio_sr * 0.5)
34 | data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
35 | raw_path = io.BytesIO()
36 | soundfile.write(raw_path, data, audio_sr, format="wav")
37 | raw_path.seek(0)
38 | out_audio, out_sr = svc_model.infer(spk, tran, raw_path)
39 | svc_model.clear_empty()
40 | _audio = out_audio.cpu().numpy()
41 | pad_len = int(svc_model.target_sample * 0.5)
42 | _audio = _audio[pad_len:-pad_len]
43 |
44 | audio.extend(list(infer_tool.pad_array(_audio, length)))
45 | out_wav_path = io.BytesIO()
46 | soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format)
47 | out_wav_path.seek(0)
48 | return send_file(out_wav_path, download_name=f"temp.{wav_format}", as_attachment=True)
49 |
50 |
51 | if __name__ == '__main__':
52 |     model_name = "logs/44k/G_60000.pth"  # model path
53 |     config_name = "configs/config.json"  # config path
54 | svc_model = infer_tool.Svc(model_name, config_name)
55 | app.run(port=1145, host="0.0.0.0", debug=False, threaded=False)
56 |
--------------------------------------------------------------------------------
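A matching client sketch for the /wav2wav endpoint above; note that audio_path is a path on the machine running the server, not an upload, and the speaker must exist in your config (the values below are illustrative):

import requests

response = requests.post(
    "http://127.0.0.1:1145/wav2wav",
    data={"audio_path": "raw/test.wav", "tran": 0, "spk": "nyaru", "wav_format": "wav"},
)
with open("converted.wav", "wb") as out:
    out.write(response.content)
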
/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/inference/__init__.py
--------------------------------------------------------------------------------
/logs/44k/diffusion/put_diffusion_pretrained_model_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/logs/44k/diffusion/put_diffusion_pretrained_model_here
--------------------------------------------------------------------------------
/logs/44k/put_pretrained_model_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/logs/44k/put_pretrained_model_here
--------------------------------------------------------------------------------
/modules/DSConv.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.utils import remove_weight_norm, weight_norm
3 |
4 |
5 | class Depthwise_Separable_Conv1D(nn.Module):
6 | def __init__(
7 | self,
8 | in_channels,
9 | out_channels,
10 | kernel_size,
11 | stride = 1,
12 | padding = 0,
13 | dilation = 1,
14 | bias = True,
15 | padding_mode = 'zeros', # TODO: refine this type
16 | device=None,
17 | dtype=None
18 | ):
19 | super().__init__()
20 | self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
21 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
22 |
23 | def forward(self, input):
24 | return self.point_conv(self.depth_conv(input))
25 |
26 | def weight_norm(self):
27 | self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
28 | self.point_conv = weight_norm(self.point_conv, name = 'weight')
29 |
30 | def remove_weight_norm(self):
31 | self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight')
32 | self.point_conv = remove_weight_norm(self.point_conv, name = 'weight')
33 |
34 | class Depthwise_Separable_TransposeConv1D(nn.Module):
35 | def __init__(
36 | self,
37 | in_channels,
38 | out_channels,
39 | kernel_size,
40 | stride = 1,
41 | padding = 0,
42 | output_padding = 0,
43 | bias = True,
44 | dilation = 1,
45 | padding_mode = 'zeros', # TODO: refine this type
46 | device=None,
47 | dtype=None
48 | ):
49 | super().__init__()
50 | self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
51 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
52 |
53 | def forward(self, input):
54 | return self.point_conv(self.depth_conv(input))
55 |
56 | def weight_norm(self):
57 | self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
58 | self.point_conv = weight_norm(self.point_conv, name = 'weight')
59 |
60 | def remove_weight_norm(self):
61 | remove_weight_norm(self.depth_conv, name = 'weight')
62 | remove_weight_norm(self.point_conv, name = 'weight')
63 |
64 |
65 | def weight_norm_modules(module, name = 'weight', dim = 0):
66 | if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
67 | module.weight_norm()
68 | return module
69 | else:
70 | return weight_norm(module,name,dim)
71 |
72 | def remove_weight_norm_modules(module, name = 'weight'):
73 | if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
74 | module.remove_weight_norm()
75 | else:
76 | remove_weight_norm(module,name)
--------------------------------------------------------------------------------
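A small sketch contrasting a dense Conv1d with the depthwise-separable variant above; the channel and length values are arbitrary, the point is the identical output shape at a much smaller parameter count:

import torch
from torch import nn
from modules.DSConv import Depthwise_Separable_Conv1D

def n_params(m):
    return sum(p.numel() for p in m.parameters())

dense = nn.Conv1d(192, 192, kernel_size=5, padding=2)
separable = Depthwise_Separable_Conv1D(192, 192, kernel_size=5, padding=2)

x = torch.randn(1, 192, 100)
assert dense(x).shape == separable(x).shape   # both give [1, 192, 100]
print(n_params(dense), n_params(separable))   # roughly 184k vs 38k weights
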
/modules/F0Predictor/CrepeF0Predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from modules.F0Predictor.crepe import CrepePitchExtractor
4 | from modules.F0Predictor.F0Predictor import F0Predictor
5 |
6 |
7 | class CrepeF0Predictor(F0Predictor):
8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"):
9 | self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model)
10 | self.hop_length = hop_length
11 | self.f0_min = f0_min
12 | self.f0_max = f0_max
13 | self.device = device
14 | self.threshold = threshold
15 | self.sampling_rate = sampling_rate
16 | self.name = "crepe"
17 |
18 | def compute_f0(self,wav,p_len=None):
19 | x = torch.FloatTensor(wav).to(self.device)
20 | if p_len is None:
21 | p_len = x.shape[0]//self.hop_length
22 | else:
23 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
24 | f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len)
25 | return f0
26 |
27 | def compute_f0_uv(self,wav,p_len=None):
28 | x = torch.FloatTensor(wav).to(self.device)
29 | if p_len is None:
30 | p_len = x.shape[0]//self.hop_length
31 | else:
32 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
33 | f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len)
34 | return f0,uv
--------------------------------------------------------------------------------
/modules/F0Predictor/DioF0Predictor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyworld
3 |
4 | from modules.F0Predictor.F0Predictor import F0Predictor
5 |
6 |
7 | class DioF0Predictor(F0Predictor):
8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
9 | self.hop_length = hop_length
10 | self.f0_min = f0_min
11 | self.f0_max = f0_max
12 | self.sampling_rate = sampling_rate
13 | self.name = "dio"
14 |
15 | def interpolate_f0(self,f0):
16 | '''
17 |         Interpolate F0 across unvoiced (zero) frames and return the voiced/unvoiced mask
18 | '''
19 | vuv_vector = np.zeros_like(f0, dtype=np.float32)
20 | vuv_vector[f0 > 0.0] = 1.0
21 | vuv_vector[f0 <= 0.0] = 0.0
22 |
23 | nzindex = np.nonzero(f0)[0]
24 | data = f0[nzindex]
25 | nzindex = nzindex.astype(np.float32)
26 | time_org = self.hop_length / self.sampling_rate * nzindex
27 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
28 |
29 | if data.shape[0] <= 0:
30 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
31 |
32 | if data.shape[0] == 1:
33 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
34 |
35 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
36 |
37 | return f0,vuv_vector
38 |
39 | def resize_f0(self,x, target_len):
40 | source = np.array(x)
41 | source[source<0.001] = np.nan
42 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
43 | res = np.nan_to_num(target)
44 | return res
45 |
46 | def compute_f0(self,wav,p_len=None):
47 | if p_len is None:
48 | p_len = wav.shape[0]//self.hop_length
49 | f0, t = pyworld.dio(
50 | wav.astype(np.double),
51 | fs=self.sampling_rate,
52 | f0_floor=self.f0_min,
53 | f0_ceil=self.f0_max,
54 | frame_period=1000 * self.hop_length / self.sampling_rate,
55 | )
56 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
57 | for index, pitch in enumerate(f0):
58 | f0[index] = round(pitch, 1)
59 | return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
60 |
61 | def compute_f0_uv(self,wav,p_len=None):
62 | if p_len is None:
63 | p_len = wav.shape[0]//self.hop_length
64 | f0, t = pyworld.dio(
65 | wav.astype(np.double),
66 | fs=self.sampling_rate,
67 | f0_floor=self.f0_min,
68 | f0_ceil=self.f0_max,
69 | frame_period=1000 * self.hop_length / self.sampling_rate,
70 | )
71 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
72 | for index, pitch in enumerate(f0):
73 | f0[index] = round(pitch, 1)
74 | return self.interpolate_f0(self.resize_f0(f0, p_len))
75 |
--------------------------------------------------------------------------------
/modules/F0Predictor/F0Predictor.py:
--------------------------------------------------------------------------------
1 | class F0Predictor(object):
2 | def compute_f0(self,wav,p_len):
3 | '''
4 | input: wav:[signal_length]
5 | p_len:int
6 | output: f0:[signal_length//hop_length]
7 | '''
8 | pass
9 |
10 | def compute_f0_uv(self,wav,p_len):
11 | '''
12 | input: wav:[signal_length]
13 | p_len:int
14 | output: f0:[signal_length//hop_length],uv:[signal_length//hop_length]
15 | '''
16 | pass
--------------------------------------------------------------------------------
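Every predictor below implements this interface; a quick sketch with the Parselmouth backend on a synthetic 220 Hz tone (numbers are illustrative, and any other backend with its weights in place is used the same way):

import numpy as np
from modules.F0Predictor.PMF0Predictor import PMF0Predictor

sr = 44100
t = np.arange(sr) / sr                         # one second
wav = 0.5 * np.sin(2 * np.pi * 220 * t)

predictor = PMF0Predictor(hop_length=512, sampling_rate=sr)
f0, uv = predictor.compute_f0_uv(wav)          # one F0 value per hop, plus a voiced/unvoiced mask
print(f0.shape, float(np.median(f0[uv > 0])))  # close to 220 on voiced frames
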
/modules/F0Predictor/FCPEF0Predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from modules.F0Predictor.F0Predictor import F0Predictor
8 |
9 | from .fcpe.model import FCPEInfer
10 |
11 |
12 | class FCPEF0Predictor(F0Predictor):
13 | def __init__(self, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sampling_rate=44100,
14 | threshold=0.05):
15 | self.fcpe = FCPEInfer(model_path="pretrain/fcpe.pt", device=device, dtype=dtype)
16 | self.hop_length = hop_length
17 | self.f0_min = f0_min
18 | self.f0_max = f0_max
19 | if device is None:
20 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
21 | else:
22 | self.device = device
23 | self.threshold = threshold
24 | self.sampling_rate = sampling_rate
25 | self.dtype = dtype
26 | self.name = "fcpe"
27 |
28 | def repeat_expand(
29 | self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
30 | ):
31 | ndim = content.ndim
32 |
33 | if content.ndim == 1:
34 | content = content[None, None]
35 | elif content.ndim == 2:
36 | content = content[None]
37 |
38 | assert content.ndim == 3
39 |
40 | is_np = isinstance(content, np.ndarray)
41 | if is_np:
42 | content = torch.from_numpy(content)
43 |
44 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
45 |
46 | if is_np:
47 | results = results.numpy()
48 |
49 | if ndim == 1:
50 | return results[0, 0]
51 | elif ndim == 2:
52 | return results[0]
53 |
54 | def post_process(self, x, sampling_rate, f0, pad_to):
55 | if isinstance(f0, np.ndarray):
56 | f0 = torch.from_numpy(f0).float().to(x.device)
57 |
58 | if pad_to is None:
59 | return f0
60 |
61 | f0 = self.repeat_expand(f0, pad_to)
62 |
63 | vuv_vector = torch.zeros_like(f0)
64 | vuv_vector[f0 > 0.0] = 1.0
65 | vuv_vector[f0 <= 0.0] = 0.0
66 |
67 |         # drop zero-frequency frames and interpolate linearly
68 | nzindex = torch.nonzero(f0).squeeze()
69 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
70 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
71 | time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
72 |
73 | vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
74 |
75 | if f0.shape[0] <= 0:
76 | return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(), vuv_vector.cpu().numpy()
77 | if f0.shape[0] == 1:
78 | return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[
79 | 0]).cpu().numpy(), vuv_vector.cpu().numpy()
80 |
81 |         # could probably be rewritten with torch?
82 | f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
83 | # vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
84 |
85 | return f0, vuv_vector.cpu().numpy()
86 |
87 | def compute_f0(self, wav, p_len=None):
88 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
89 | if p_len is None:
90 | p_len = x.shape[0] // self.hop_length
91 | else:
92 | assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
93 | f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0]
94 | if torch.all(f0 == 0):
95 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
96 |             return rtn
97 | return self.post_process(x, self.sampling_rate, f0, p_len)[0]
98 |
99 | def compute_f0_uv(self, wav, p_len=None):
100 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
101 | if p_len is None:
102 | p_len = x.shape[0] // self.hop_length
103 | else:
104 | assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
105 | f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0]
106 | if torch.all(f0 == 0):
107 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
108 | return rtn, rtn
109 | return self.post_process(x, self.sampling_rate, f0, p_len)
--------------------------------------------------------------------------------
/modules/F0Predictor/HarvestF0Predictor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyworld
3 |
4 | from modules.F0Predictor.F0Predictor import F0Predictor
5 |
6 |
7 | class HarvestF0Predictor(F0Predictor):
8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
9 | self.hop_length = hop_length
10 | self.f0_min = f0_min
11 | self.f0_max = f0_max
12 | self.sampling_rate = sampling_rate
13 | self.name = "harvest"
14 |
15 | def interpolate_f0(self,f0):
16 | '''
17 |         Interpolate F0 across unvoiced (zero) frames and return the voiced/unvoiced mask
18 | '''
19 | vuv_vector = np.zeros_like(f0, dtype=np.float32)
20 | vuv_vector[f0 > 0.0] = 1.0
21 | vuv_vector[f0 <= 0.0] = 0.0
22 |
23 | nzindex = np.nonzero(f0)[0]
24 | data = f0[nzindex]
25 | nzindex = nzindex.astype(np.float32)
26 | time_org = self.hop_length / self.sampling_rate * nzindex
27 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
28 |
29 | if data.shape[0] <= 0:
30 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
31 |
32 | if data.shape[0] == 1:
33 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
34 |
35 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
36 |
37 | return f0,vuv_vector
38 | def resize_f0(self,x, target_len):
39 | source = np.array(x)
40 | source[source<0.001] = np.nan
41 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
42 | res = np.nan_to_num(target)
43 | return res
44 |
45 | def compute_f0(self,wav,p_len=None):
46 | if p_len is None:
47 | p_len = wav.shape[0]//self.hop_length
48 | f0, t = pyworld.harvest(
49 | wav.astype(np.double),
50 |             fs=self.sampling_rate,
51 | f0_ceil=self.f0_max,
52 | f0_floor=self.f0_min,
53 | frame_period=1000 * self.hop_length / self.sampling_rate,
54 | )
55 |         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
56 | return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
57 |
58 | def compute_f0_uv(self,wav,p_len=None):
59 | if p_len is None:
60 | p_len = wav.shape[0]//self.hop_length
61 | f0, t = pyworld.harvest(
62 | wav.astype(np.double),
63 | fs=self.sampling_rate,
64 | f0_floor=self.f0_min,
65 | f0_ceil=self.f0_max,
66 | frame_period=1000 * self.hop_length / self.sampling_rate,
67 | )
68 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
69 | return self.interpolate_f0(self.resize_f0(f0, p_len))
70 |
--------------------------------------------------------------------------------
/modules/F0Predictor/PMF0Predictor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import parselmouth
3 |
4 | from modules.F0Predictor.F0Predictor import F0Predictor
5 |
6 |
7 | class PMF0Predictor(F0Predictor):
8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
9 | self.hop_length = hop_length
10 | self.f0_min = f0_min
11 | self.f0_max = f0_max
12 | self.sampling_rate = sampling_rate
13 | self.name = "pm"
14 |
15 | def interpolate_f0(self,f0):
16 | '''
17 |         Interpolate F0 across unvoiced (zero) frames and return the voiced/unvoiced mask
18 | '''
19 | vuv_vector = np.zeros_like(f0, dtype=np.float32)
20 | vuv_vector[f0 > 0.0] = 1.0
21 | vuv_vector[f0 <= 0.0] = 0.0
22 |
23 | nzindex = np.nonzero(f0)[0]
24 | data = f0[nzindex]
25 | nzindex = nzindex.astype(np.float32)
26 | time_org = self.hop_length / self.sampling_rate * nzindex
27 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
28 |
29 | if data.shape[0] <= 0:
30 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
31 |
32 | if data.shape[0] == 1:
33 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
34 |
35 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
36 |
37 | return f0,vuv_vector
38 |
39 |
40 | def compute_f0(self,wav,p_len=None):
41 | x = wav
42 | if p_len is None:
43 | p_len = x.shape[0]//self.hop_length
44 | else:
45 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
46 | time_step = self.hop_length / self.sampling_rate * 1000
47 | f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
48 | time_step=time_step / 1000, voicing_threshold=0.6,
49 | pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
50 |
51 | pad_size=(p_len - len(f0) + 1) // 2
52 | if(pad_size>0 or p_len - len(f0) - pad_size>0):
53 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
54 | f0,uv = self.interpolate_f0(f0)
55 | return f0
56 |
57 | def compute_f0_uv(self,wav,p_len=None):
58 | x = wav
59 | if p_len is None:
60 | p_len = x.shape[0]//self.hop_length
61 | else:
62 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
63 | time_step = self.hop_length / self.sampling_rate * 1000
64 | f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
65 | time_step=time_step / 1000, voicing_threshold=0.6,
66 | pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
67 |
68 | pad_size=(p_len - len(f0) + 1) // 2
69 | if(pad_size>0 or p_len - len(f0) - pad_size>0):
70 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
71 | f0,uv = self.interpolate_f0(f0)
72 | return f0,uv
73 |
--------------------------------------------------------------------------------
/modules/F0Predictor/RMVPEF0Predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from modules.F0Predictor.F0Predictor import F0Predictor
8 |
9 | from .rmvpe import RMVPE
10 |
11 |
12 | class RMVPEF0Predictor(F0Predictor):
13 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
14 | self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
15 | self.hop_length = hop_length
16 | self.f0_min = f0_min
17 | self.f0_max = f0_max
18 | if device is None:
19 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20 | else:
21 | self.device = device
22 | self.threshold = threshold
23 | self.sampling_rate = sampling_rate
24 | self.dtype = dtype
25 | self.name = "rmvpe"
26 |
27 | def repeat_expand(
28 | self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
29 | ):
30 | ndim = content.ndim
31 |
32 | if content.ndim == 1:
33 | content = content[None, None]
34 | elif content.ndim == 2:
35 | content = content[None]
36 |
37 | assert content.ndim == 3
38 |
39 | is_np = isinstance(content, np.ndarray)
40 | if is_np:
41 | content = torch.from_numpy(content)
42 |
43 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
44 |
45 | if is_np:
46 | results = results.numpy()
47 |
48 | if ndim == 1:
49 | return results[0, 0]
50 | elif ndim == 2:
51 | return results[0]
52 |
53 | def post_process(self, x, sampling_rate, f0, pad_to):
54 | if isinstance(f0, np.ndarray):
55 | f0 = torch.from_numpy(f0).float().to(x.device)
56 |
57 | if pad_to is None:
58 | return f0
59 |
60 | f0 = self.repeat_expand(f0, pad_to)
61 |
62 | vuv_vector = torch.zeros_like(f0)
63 | vuv_vector[f0 > 0.0] = 1.0
64 | vuv_vector[f0 <= 0.0] = 0.0
65 |
66 |         # drop zero-frequency frames and interpolate linearly
67 | nzindex = torch.nonzero(f0).squeeze()
68 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
69 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
70 | time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
71 |
72 | vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
73 |
74 | if f0.shape[0] <= 0:
75 | return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(),vuv_vector.cpu().numpy()
76 | if f0.shape[0] == 1:
77 | return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0]).cpu().numpy() ,vuv_vector.cpu().numpy()
78 |
79 |         # could probably be rewritten with torch?
80 | f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
81 | #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
82 |
83 | return f0,vuv_vector.cpu().numpy()
84 |
85 | def compute_f0(self,wav,p_len=None):
86 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
87 | if p_len is None:
88 | p_len = x.shape[0]//self.hop_length
89 | else:
90 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
91 | f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
92 | if torch.all(f0 == 0):
93 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
94 |             return rtn
95 | return self.post_process(x,self.sampling_rate,f0,p_len)[0]
96 |
97 | def compute_f0_uv(self,wav,p_len=None):
98 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
99 | if p_len is None:
100 | p_len = x.shape[0]//self.hop_length
101 | else:
102 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
103 | f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
104 | if torch.all(f0 == 0):
105 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
106 | return rtn,rtn
107 | return self.post_process(x,self.sampling_rate,f0,p_len)
--------------------------------------------------------------------------------
/modules/F0Predictor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/modules/F0Predictor/__init__.py
--------------------------------------------------------------------------------
/modules/F0Predictor/fcpe/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import FCPEInfer # noqa: F401
2 | from .nvSTFT import STFT # noqa: F401
3 | from .pcmer import PCmer # noqa: F401
4 |
--------------------------------------------------------------------------------
/modules/F0Predictor/fcpe/nvSTFT.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import soundfile as sf
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.utils.data
9 | from librosa.filters import mel as librosa_mel_fn
10 |
11 | os.environ["LRU_CACHE_CAPACITY"] = "3"
12 |
13 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
14 | sampling_rate = None
15 | try:
16 |         data, sampling_rate = sf.read(full_path, always_2d=True)  # load with soundfile
17 | except Exception as ex:
18 | print(f"'{full_path}' failed to load.\nException:")
19 | print(ex)
20 | if return_empty_on_exception:
21 | return [], sampling_rate or target_sr or 48000
22 | else:
23 | raise Exception(ex)
24 |
25 | if len(data.shape) > 1:
26 | data = data[:, 0]
27 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
28 |
29 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int
30 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
31 | else: # if audio data is type fp32
32 | max_mag = max(np.amax(data), -np.amin(data))
33 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
34 |
35 | data = torch.FloatTensor(data.astype(np.float32))/max_mag
36 |
37 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
38 | return [], sampling_rate or target_sr or 48000
39 | if target_sr is not None and sampling_rate != target_sr:
40 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
41 | sampling_rate = target_sr
42 |
43 | return data, sampling_rate
44 |
45 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
46 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
47 |
48 | def dynamic_range_decompression(x, C=1):
49 | return np.exp(x) / C
50 |
51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52 | return torch.log(torch.clamp(x, min=clip_val) * C)
53 |
54 | def dynamic_range_decompression_torch(x, C=1):
55 | return torch.exp(x) / C
56 |
57 | class STFT():
58 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
59 | self.target_sr = sr
60 |
61 | self.n_mels = n_mels
62 | self.n_fft = n_fft
63 | self.win_size = win_size
64 | self.hop_length = hop_length
65 | self.fmin = fmin
66 | self.fmax = fmax
67 | self.clip_val = clip_val
68 | self.mel_basis = {}
69 | self.hann_window = {}
70 |
71 | def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
72 | sampling_rate = self.target_sr
73 | n_mels = self.n_mels
74 | n_fft = self.n_fft
75 | win_size = self.win_size
76 | hop_length = self.hop_length
77 | fmin = self.fmin
78 | fmax = self.fmax
79 | clip_val = self.clip_val
80 |
81 | factor = 2 ** (keyshift / 12)
82 | n_fft_new = int(np.round(n_fft * factor))
83 | win_size_new = int(np.round(win_size * factor))
84 | hop_length_new = int(np.round(hop_length * speed))
85 | if not train:
86 | mel_basis = self.mel_basis
87 | hann_window = self.hann_window
88 | else:
89 | mel_basis = {}
90 | hann_window = {}
91 |
92 | if torch.min(y) < -1.:
93 | print('min value is ', torch.min(y))
94 | if torch.max(y) > 1.:
95 | print('max value is ', torch.max(y))
96 |
97 | mel_basis_key = str(fmax)+'_'+str(y.device)
98 | if mel_basis_key not in mel_basis:
99 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
100 | mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
101 |
102 | keyshift_key = str(keyshift)+'_'+str(y.device)
103 | if keyshift_key not in hann_window:
104 | hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
105 |
106 | pad_left = (win_size_new - hop_length_new) //2
107 | pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
108 | if pad_right < y.size(-1):
109 | mode = 'reflect'
110 | else:
111 | mode = 'constant'
112 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
113 | y = y.squeeze(1)
114 |
115 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key],
116 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
117 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
118 | if keyshift != 0:
119 | size = n_fft // 2 + 1
120 | resize = spec.size(1)
121 | if resize < size:
122 | spec = F.pad(spec, (0, 0, 0, size-resize))
123 | spec = spec[:, :size, :] * win_size / win_size_new
124 | spec = torch.matmul(mel_basis[mel_basis_key], spec)
125 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
126 | return spec
127 |
128 | def __call__(self, audiopath):
129 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
130 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
131 | return spect
132 |
133 | stft = STFT()
134 |
--------------------------------------------------------------------------------
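A short sketch of the mel front end above on a dummy signal, using the default 22.05 kHz settings (the shape comment reflects what get_mel produces with center=False):

import torch
from modules.F0Predictor.fcpe.nvSTFT import STFT

mel_extractor = STFT(sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256)
audio = torch.rand(1, 22050) * 2 - 1           # one second in [-1, 1], shape [batch, samples]
mel = mel_extractor.get_mel(audio)             # [batch, n_mels, n_frames]
print(mel.shape)
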
/modules/F0Predictor/rmvpe/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import * # noqa: F403
2 | from .inference import RMVPE # noqa: F401
3 | from .model import E2E, E2E0 # noqa: F401
4 | from .spec import MelSpectrogram # noqa: F401
5 | from .utils import ( # noqa: F401
6 | cycle,
7 | summary,
8 | to_local_average_cents,
9 | to_viterbi_cents,
10 | )
11 |
--------------------------------------------------------------------------------
/modules/F0Predictor/rmvpe/constants.py:
--------------------------------------------------------------------------------
1 | SAMPLE_RATE = 16000
2 |
3 | N_CLASS = 360
4 |
5 | N_MELS = 128
6 | MEL_FMIN = 30
7 | MEL_FMAX = SAMPLE_RATE // 2
8 | WINDOW_LENGTH = 1024
9 | CONST = 1997.3794084376191
10 |
--------------------------------------------------------------------------------
/modules/F0Predictor/rmvpe/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchaudio.transforms import Resample
4 |
5 | from .constants import * # noqa: F403
6 | from .model import E2E0
7 | from .spec import MelSpectrogram
8 | from .utils import to_local_average_cents, to_viterbi_cents
9 |
10 |
11 | class RMVPE:
12 | def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160):
13 | self.resample_kernel = {}
14 | if device is None:
15 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
16 | else:
17 | self.device = device
18 | model = E2E0(4, 1, (2, 2))
19 | ckpt = torch.load(model_path, map_location=torch.device(self.device))
20 | model.load_state_dict(ckpt['model'])
21 | model = model.to(dtype).to(self.device)
22 | model.eval()
23 | self.model = model
24 | self.dtype = dtype
25 | self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405
26 | self.resample_kernel = {}
27 |
28 | def mel2hidden(self, mel):
29 | with torch.no_grad():
30 | n_frames = mel.shape[-1]
31 | mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant')
32 | hidden = self.model(mel)
33 | return hidden[:, :n_frames]
34 |
35 | def decode(self, hidden, thred=0.03, use_viterbi=False):
36 | if use_viterbi:
37 | cents_pred = to_viterbi_cents(hidden, thred=thred)
38 | else:
39 | cents_pred = to_local_average_cents(hidden, thred=thred)
40 | f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device)
41 | return f0
42 |
43 | def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False):
44 | audio = audio.unsqueeze(0).to(self.dtype).to(self.device)
45 | if sample_rate == 16000:
46 | audio_res = audio
47 | else:
48 | key_str = str(sample_rate)
49 | if key_str not in self.resample_kernel:
50 | self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
51 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
52 | audio_res = self.resample_kernel[key_str](audio)
53 | mel_extractor = self.mel_extractor.to(self.device)
54 | mel = mel_extractor(audio_res, center=True).to(self.dtype)
55 | hidden = self.mel2hidden(mel)
56 | f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi)
57 | return f0
58 |
--------------------------------------------------------------------------------
/modules/F0Predictor/rmvpe/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from .constants import * # noqa: F403
4 | from .deepunet import DeepUnet, DeepUnet0
5 | from .seq import BiGRU
6 | from .spec import MelSpectrogram
7 |
8 |
9 | class E2E(nn.Module):
10 | def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
11 | en_out_channels=16):
12 | super(E2E, self).__init__()
13 | self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405
14 | self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
15 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
16 | if n_gru:
17 | self.fc = nn.Sequential(
18 | BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405
19 | nn.Linear(512, N_CLASS), # noqa: F405
20 | nn.Dropout(0.25),
21 | nn.Sigmoid()
22 | )
23 | else:
24 | self.fc = nn.Sequential(
25 | nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405
26 | nn.Dropout(0.25),
27 | nn.Sigmoid()
28 | )
29 |
30 | def forward(self, x):
31 | mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1)
32 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
33 | # x = self.fc(x)
34 | hidden_vec = 0
35 | if len(self.fc) == 4:
36 | for i in range(len(self.fc)):
37 | x = self.fc[i](x)
38 | if i == 0:
39 | hidden_vec = x
40 | return hidden_vec, x
41 |
42 |
43 | class E2E0(nn.Module):
44 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
45 | en_out_channels=16):
46 | super(E2E0, self).__init__()
47 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
48 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
49 | if n_gru:
50 | self.fc = nn.Sequential(
51 | BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405
52 | nn.Linear(512, N_CLASS), # noqa: F405
53 | nn.Dropout(0.25),
54 | nn.Sigmoid()
55 | )
56 | else:
57 | self.fc = nn.Sequential(
58 | nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405
59 | nn.Dropout(0.25),
60 | nn.Sigmoid()
61 | )
62 |
63 | def forward(self, mel):
64 | mel = mel.transpose(-1, -2).unsqueeze(1)
65 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
66 | x = self.fc(x)
67 | return x
68 |
--------------------------------------------------------------------------------
/modules/F0Predictor/rmvpe/seq.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BiGRU(nn.Module):
5 | def __init__(self, input_features, hidden_features, num_layers):
6 | super(BiGRU, self).__init__()
7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
8 |
9 | def forward(self, x):
10 | return self.gru(x)[0]
11 |
12 |
13 | class BiLSTM(nn.Module):
14 | def __init__(self, input_features, hidden_features, num_layers):
15 | super(BiLSTM, self).__init__()
16 | self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
17 |
18 | def forward(self, x):
19 | return self.lstm(x)[0]
20 |
21 |
--------------------------------------------------------------------------------
/modules/F0Predictor/rmvpe/spec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from librosa.filters import mel
5 |
6 |
7 | class MelSpectrogram(torch.nn.Module):
8 | def __init__(
9 | self,
10 | n_mel_channels,
11 | sampling_rate,
12 | win_length,
13 | hop_length,
14 | n_fft=None,
15 | mel_fmin=0,
16 | mel_fmax=None,
17 | clamp = 1e-5
18 | ):
19 | super().__init__()
20 | n_fft = win_length if n_fft is None else n_fft
21 | self.hann_window = {}
22 | mel_basis = mel(
23 | sr=sampling_rate,
24 | n_fft=n_fft,
25 | n_mels=n_mel_channels,
26 | fmin=mel_fmin,
27 | fmax=mel_fmax,
28 | htk=True)
29 | mel_basis = torch.from_numpy(mel_basis).float()
30 | self.register_buffer("mel_basis", mel_basis)
31 | self.n_fft = win_length if n_fft is None else n_fft
32 | self.hop_length = hop_length
33 | self.win_length = win_length
34 | self.sampling_rate = sampling_rate
35 | self.n_mel_channels = n_mel_channels
36 | self.clamp = clamp
37 |
38 | def forward(self, audio, keyshift=0, speed=1, center=True):
39 | factor = 2 ** (keyshift / 12)
40 | n_fft_new = int(np.round(self.n_fft * factor))
41 | win_length_new = int(np.round(self.win_length * factor))
42 | hop_length_new = int(np.round(self.hop_length * speed))
43 |
44 | keyshift_key = str(keyshift)+'_'+str(audio.device)
45 | if keyshift_key not in self.hann_window:
46 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
47 |
48 | fft = torch.stft(
49 | audio,
50 | n_fft=n_fft_new,
51 | hop_length=hop_length_new,
52 | win_length=win_length_new,
53 | window=self.hann_window[keyshift_key],
54 | center=center,
55 | return_complex=True)
56 | magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
57 |
58 | if keyshift != 0:
59 | size = self.n_fft // 2 + 1
60 | resize = magnitude.size(1)
61 | if resize < size:
62 | magnitude = F.pad(magnitude, (0, 0, 0, size-resize))
63 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
64 |
65 | mel_output = torch.matmul(self.mel_basis, magnitude)
66 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
67 | return log_mel_spec
--------------------------------------------------------------------------------
/modules/F0Predictor/rmvpe/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from functools import reduce
3 |
4 | import librosa
5 | import numpy as np
6 | import torch
7 | from torch.nn.modules.module import _addindent
8 |
9 | from .constants import * # noqa: F403
10 |
11 |
12 | def cycle(iterable):
13 | while True:
14 | for item in iterable:
15 | yield item
16 |
17 |
18 | def summary(model, file=sys.stdout):
19 | def repr(model):
20 | # We treat the extra repr like the sub-module, one item per line
21 | extra_lines = []
22 | extra_repr = model.extra_repr()
23 | # empty string will be split into list ['']
24 | if extra_repr:
25 | extra_lines = extra_repr.split('\n')
26 | child_lines = []
27 | total_params = 0
28 | for key, module in model._modules.items():
29 | mod_str, num_params = repr(module)
30 | mod_str = _addindent(mod_str, 2)
31 | child_lines.append('(' + key + '): ' + mod_str)
32 | total_params += num_params
33 | lines = extra_lines + child_lines
34 |
35 | for name, p in model._parameters.items():
36 | if hasattr(p, 'shape'):
37 | total_params += reduce(lambda x, y: x * y, p.shape)
38 |
39 | main_str = model._get_name() + '('
40 | if lines:
41 | # simple one-liner info, which most builtin Modules will use
42 | if len(extra_lines) == 1 and not child_lines:
43 | main_str += extra_lines[0]
44 | else:
45 | main_str += '\n ' + '\n '.join(lines) + '\n'
46 |
47 | main_str += ')'
48 | if file is sys.stdout:
49 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
50 | else:
51 | main_str += ', {:,} params'.format(total_params)
52 | return main_str, total_params
53 |
54 | string, count = repr(model)
55 | if file is not None:
56 | if isinstance(file, str):
57 | file = open(file, 'w')
58 | print(string, file=file)
59 | file.flush()
60 |
61 | return count
62 |
63 |
64 | def to_local_average_cents(salience, center=None, thred=0.05):
65 | """
66 | find the weighted average cents near the argmax bin
67 | """
68 |
69 | if not hasattr(to_local_average_cents, 'cents_mapping'):
70 | # the bin number-to-cents mapping
71 | to_local_average_cents.cents_mapping = (
72 | 20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405
73 |
74 | if salience.ndim == 1:
75 | if center is None:
76 | center = int(torch.argmax(salience))
77 | start = max(0, center - 4)
78 | end = min(len(salience), center + 5)
79 | salience = salience[start:end]
80 | product_sum = torch.sum(
81 | salience * to_local_average_cents.cents_mapping[start:end])
82 | weight_sum = torch.sum(salience)
83 | return product_sum / weight_sum if torch.max(salience) > thred else 0
84 | if salience.ndim == 2:
85 | return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in
86 | range(salience.shape[0])]).to(salience.device)
87 |
88 | raise Exception("label should be either 1d or 2d ndarray")
89 |
90 | def to_viterbi_cents(salience, thred=0.05):
91 | # Create viterbi transition matrix
92 | if not hasattr(to_viterbi_cents, 'transition'):
93 |         xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))  # noqa: F405
94 |         transition = np.maximum(30 - abs(xx - yy), 0)
95 |         transition = transition / transition.sum(axis=1, keepdims=True)
96 | to_viterbi_cents.transition = transition
97 |
98 | # Convert to probability
99 | prob = salience.T
100 | prob = prob / prob.sum(axis=0)
101 |
102 | # Perform viterbi decoding
103 | path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64)
104 |
105 | return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in
106 | range(len(path))]).to(salience.device)
107 |
--------------------------------------------------------------------------------
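A tiny sketch of the cents decoding above: a one-hot salience peak at bin k decodes to 20*k + CONST cents, and RMVPE.decode then maps cents to Hz with 10 * 2 ** (cents / 1200):

import torch
from modules.F0Predictor.rmvpe.constants import CONST, N_CLASS
from modules.F0Predictor.rmvpe.utils import to_local_average_cents

salience = torch.zeros(N_CLASS)
salience[240] = 1.0                            # pretend the network fired on bin 240
cents = to_local_average_cents(salience, thred=0.05)
print(float(cents), 20 * 240 + CONST)          # the two values agree
print(float(10 * (2 ** (cents / 1200))))       # the corresponding F0 in Hz
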
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/modules/__init__.py
--------------------------------------------------------------------------------
/modules/enhancer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torchaudio.transforms import Resample
5 |
6 | from vdecoder.nsf_hifigan.models import load_model
7 | from vdecoder.nsf_hifigan.nvSTFT import STFT
8 |
9 |
10 | class Enhancer:
11 | def __init__(self, enhancer_type, enhancer_ckpt, device=None):
12 | if device is None:
13 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
14 | self.device = device
15 |
16 | if enhancer_type == 'nsf-hifigan':
17 | self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device)
18 | else:
19 | raise ValueError(f" [x] Unknown enhancer: {enhancer_type}")
20 |
21 | self.resample_kernel = {}
22 | self.enhancer_sample_rate = self.enhancer.sample_rate()
23 | self.enhancer_hop_size = self.enhancer.hop_size()
24 |
25 | def enhance(self,
26 | audio, # 1, T
27 | sample_rate,
28 | f0, # 1, n_frames, 1
29 | hop_size,
30 | adaptive_key = 0,
31 | silence_front = 0
32 | ):
33 | # enhancer start time
34 | start_frame = int(silence_front * sample_rate / hop_size)
35 | real_silence_front = start_frame * hop_size / sample_rate
36 | audio = audio[:, int(np.round(real_silence_front * sample_rate)) : ]
37 | f0 = f0[: , start_frame :, :]
38 |
39 | # adaptive parameters
40 | adaptive_factor = 2 ** ( -adaptive_key / 12)
41 | adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100))
42 | real_factor = self.enhancer_sample_rate / adaptive_sample_rate
43 |
44 | # resample the ddsp output
45 | if sample_rate == adaptive_sample_rate:
46 | audio_res = audio
47 | else:
48 | key_str = str(sample_rate) + str(adaptive_sample_rate)
49 | if key_str not in self.resample_kernel:
50 | self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width = 128).to(self.device)
51 | audio_res = self.resample_kernel[key_str](audio)
52 |
53 | n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1)
54 |
55 | # resample f0
56 | f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy()
57 | f0_np *= real_factor
58 | time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor
59 | time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames)
60 | f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1])
61 | f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames
62 |
63 | # enhance
64 | enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res)
65 |
66 | # resample the enhanced output
67 | if adaptive_factor != 0:
68 | key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate)
69 | if key_str not in self.resample_kernel:
70 | self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width = 128).to(self.device)
71 | enhanced_audio = self.resample_kernel[key_str](enhanced_audio)
72 |
73 | # pad the silence frames
74 | if start_frame > 0:
75 | enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0))
76 |
77 | return enhanced_audio, enhancer_sample_rate
78 |
79 |
80 | class NsfHifiGAN(torch.nn.Module):
81 | def __init__(self, model_path, device=None):
82 | super().__init__()
83 | if device is None:
84 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
85 | self.device = device
86 | print('| Load HifiGAN: ', model_path)
87 | self.model, self.h = load_model(model_path, device=self.device)
88 |
89 | def sample_rate(self):
90 | return self.h.sampling_rate
91 |
92 | def hop_size(self):
93 | return self.h.hop_size
94 |
95 | def forward(self, audio, f0):
96 | stft = STFT(
97 | self.h.sampling_rate,
98 | self.h.num_mels,
99 | self.h.n_fft,
100 | self.h.win_size,
101 | self.h.hop_size,
102 | self.h.fmin,
103 | self.h.fmax)
104 | with torch.no_grad():
105 | mel = stft.get_mel(audio)
106 | enhanced_audio = self.model(mel, f0[:,:mel.size(-1)]).view(-1)
107 | return enhanced_audio, self.h.sampling_rate
--------------------------------------------------------------------------------
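A sketch of running the enhancer above on a dummy chunk. It assumes the pretrained NSF-HiFiGAN vocoder has been downloaded to pretrain/nsf_hifigan/model (the path the diffusion config expects); the frame count is approximate:

import torch
from modules.enhancer import Enhancer

enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model', device='cpu')
audio = 0.1 * torch.randn(1, 44100)            # 1 s of dummy audio at 44.1 kHz
f0 = torch.full((1, 86, 1), 220.0)             # [1, n_frames, 1]; 44100 / 512 is about 86 frames
enhanced, sr = enhancer.enhance(audio, 44100, f0, hop_size=512)
print(enhanced.shape, sr)
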
/modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def feature_loss(fmap_r, fmap_g):
5 | loss = 0
6 | for dr, dg in zip(fmap_r, fmap_g):
7 | for rl, gl in zip(dr, dg):
8 | rl = rl.float().detach()
9 | gl = gl.float()
10 | loss += torch.mean(torch.abs(rl - gl))
11 |
12 | return loss * 2
13 |
14 |
15 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16 | loss = 0
17 | r_losses = []
18 | g_losses = []
19 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20 | dr = dr.float()
21 | dg = dg.float()
22 | r_loss = torch.mean((1-dr)**2)
23 | g_loss = torch.mean(dg**2)
24 | loss += (r_loss + g_loss)
25 | r_losses.append(r_loss.item())
26 | g_losses.append(g_loss.item())
27 |
28 | return loss, r_losses, g_losses
29 |
30 |
31 | def generator_loss(disc_outputs):
32 | loss = 0
33 | gen_losses = []
34 | for dg in disc_outputs:
35 | dg = dg.float()
36 | l = torch.mean((1-dg)**2)
37 | gen_losses.append(l)
38 | loss += l
39 |
40 | return loss, gen_losses
41 |
42 |
43 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44 | """
45 | z_p, logs_q: [b, h, t_t]
46 | m_p, logs_p: [b, h, t_t]
47 | """
48 | z_p = z_p.float()
49 | logs_q = logs_q.float()
50 | m_p = m_p.float()
51 | logs_p = logs_p.float()
52 | z_mask = z_mask.float()
53 | #print(logs_p)
54 | kl = logs_p - logs_q - 0.5
55 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
56 | kl = torch.sum(kl * z_mask)
57 | l = kl / torch.sum(z_mask)
58 | return l
59 |
--------------------------------------------------------------------------------
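A tiny sketch of how the LSGAN-style losses above are called, with random tensors standing in for the real discriminator outputs and feature maps:

import torch
from modules.losses import discriminator_loss, feature_loss, generator_loss

real_outs = [torch.rand(2, 10)]                # one discriminator, batch of 2
fake_outs = [torch.rand(2, 10)]
fmap_r = [[torch.rand(2, 4, 10)]]              # one feature map per discriminator
fmap_g = [[torch.rand(2, 4, 10)]]

d_loss, r_losses, g_losses = discriminator_loss(real_outs, fake_outs)
g_loss, gen_losses = generator_loss(fake_outs)
fm_loss = feature_loss(fmap_r, fmap_g)
print(d_loss.item(), g_loss.item(), fm_loss.item())
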
/modules/mel_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from librosa.filters import mel as librosa_mel_fn
4 |
5 | MAX_WAV_VALUE = 32768.0
6 |
7 |
8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9 | """
10 | PARAMS
11 | ------
12 | C: compression factor
13 | """
14 | return torch.log(torch.clamp(x, min=clip_val) * C)
15 |
16 |
17 | def dynamic_range_decompression_torch(x, C=1):
18 | """
19 | PARAMS
20 | ------
21 | C: compression factor used to compress
22 | """
23 | return torch.exp(x) / C
24 |
25 |
26 | def spectral_normalize_torch(magnitudes):
27 | output = dynamic_range_compression_torch(magnitudes)
28 | return output
29 |
30 |
31 | def spectral_de_normalize_torch(magnitudes):
32 | output = dynamic_range_decompression_torch(magnitudes)
33 | return output
34 |
35 |
36 | mel_basis = {}
37 | hann_window = {}
38 |
39 |
40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41 | if torch.min(y) < -1.:
42 | print('min value is ', torch.min(y))
43 | if torch.max(y) > 1.:
44 | print('max value is ', torch.max(y))
45 |
46 | global hann_window
47 | dtype_device = str(y.dtype) + '_' + str(y.device)
48 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
49 | if wnsize_dtype_device not in hann_window:
50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
51 |
52 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
53 | y = y.squeeze(1)
54 |
55 | y_dtype = y.dtype
56 | if y.dtype == torch.bfloat16:
57 | y = y.to(torch.float32)
58 |
59 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
60 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
61 | spec = torch.view_as_real(spec).to(y_dtype)
62 |
63 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
64 | return spec
65 |
66 |
67 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
68 | global mel_basis
69 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
70 | fmax_dtype_device = str(fmax) + '_' + dtype_device
71 | if fmax_dtype_device not in mel_basis:
72 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
73 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
74 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
75 | spec = spectral_normalize_torch(spec)
76 | return spec
77 |
78 |
79 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
80 | spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
81 | spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
82 |
83 | return spec
84 |
--------------------------------------------------------------------------------
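A short sketch of mel extraction with values similar to the 44k setup (the n_fft/hop/win/mel numbers here are illustrative rather than read from configs/config.json):

import torch
from modules.mel_processing import mel_spectrogram_torch

y = torch.rand(1, 44100) * 2 - 1               # one second of audio in [-1, 1]
mel = mel_spectrogram_torch(
    y, n_fft=2048, num_mels=80, sampling_rate=44100,
    hop_size=512, win_size=2048, fmin=0, fmax=22050,
)
print(mel.shape)                                # [1, 80, n_frames]
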
/onnx_export.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import torch
5 |
6 | import utils
7 | from onnxexport.model_onnx_speaker_mix import SynthesizerTrn
8 |
9 | parser = argparse.ArgumentParser(description='SoVitsSvc OnnxExport')
10 |
11 | def OnnxExport(path=None):
12 | device = torch.device("cpu")
13 | hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
14 | SVCVITS = SynthesizerTrn(
15 | hps.data.filter_length // 2 + 1,
16 | hps.train.segment_size // hps.data.hop_length,
17 | **hps.model)
18 | _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
19 | _ = SVCVITS.eval().to(device)
20 | for i in SVCVITS.parameters():
21 | i.requires_grad = False
22 |
23 | num_frames = 200
24 |
25 | test_hidden_unit = torch.rand(1, num_frames, SVCVITS.gin_channels)
26 | test_pitch = torch.rand(1, num_frames)
27 | test_vol = torch.rand(1, num_frames)
28 | test_mel2ph = torch.LongTensor(torch.arange(0, num_frames)).unsqueeze(0)
29 | test_uv = torch.ones(1, num_frames, dtype=torch.float32)
30 | test_noise = torch.randn(1, 192, num_frames)
31 | test_sid = torch.LongTensor([0])
32 | export_mix = True
33 | if len(hps.spk) < 2:
34 | export_mix = False
35 |
36 | if export_mix:
37 | spk_mix = []
38 | n_spk = len(hps.spk)
39 | for i in range(n_spk):
40 | spk_mix.append(1.0/float(n_spk))
41 | test_sid = torch.tensor(spk_mix)
42 | SVCVITS.export_chara_mix(hps.spk)
43 | test_sid = test_sid.unsqueeze(0)
44 | test_sid = test_sid.repeat(num_frames, 1)
45 |
46 | SVCVITS.eval()
47 |
48 | if export_mix:
49 | daxes = {
50 | "c": [0, 1],
51 | "f0": [1],
52 | "mel2ph": [1],
53 | "uv": [1],
54 | "noise": [2],
55 | "sid":[0]
56 | }
57 | else:
58 | daxes = {
59 | "c": [0, 1],
60 | "f0": [1],
61 | "mel2ph": [1],
62 | "uv": [1],
63 | "noise": [2]
64 | }
65 |
66 | input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
67 | output_names = ["audio", ]
68 |
69 | if SVCVITS.vol_embedding:
70 | input_names.append("vol")
71 | vol_dadict = {"vol" : [1]}
72 | daxes.update(vol_dadict)
73 | test_inputs = (
74 | test_hidden_unit.to(device),
75 | test_pitch.to(device),
76 | test_mel2ph.to(device),
77 | test_uv.to(device),
78 | test_noise.to(device),
79 | test_sid.to(device),
80 | test_vol.to(device)
81 | )
82 | else:
83 | test_inputs = (
84 | test_hidden_unit.to(device),
85 | test_pitch.to(device),
86 | test_mel2ph.to(device),
87 | test_uv.to(device),
88 | test_noise.to(device),
89 | test_sid.to(device)
90 | )
91 |
92 | # SVCVITS = torch.jit.script(SVCVITS)
93 | SVCVITS(test_hidden_unit.to(device),
94 | test_pitch.to(device),
95 | test_mel2ph.to(device),
96 | test_uv.to(device),
97 | test_noise.to(device),
98 | test_sid.to(device),
99 | test_vol.to(device))
100 |
101 | SVCVITS.dec.OnnxExport()
102 |
103 | torch.onnx.export(
104 | SVCVITS,
105 | test_inputs,
106 | f"checkpoints/{path}/{path}_SoVits.onnx",
107 | dynamic_axes=daxes,
108 | do_constant_folding=False,
109 | opset_version=16,
110 | verbose=False,
111 | input_names=input_names,
112 | output_names=output_names
113 | )
114 |
115 | vec_lay = "layer-12" if SVCVITS.gin_channels == 768 else "layer-9"
116 | spklist = []
117 | for key in hps.spk.keys():
118 | spklist.append(key)
119 |
120 | MoeVSConf = {
121 | "Folder" : f"{path}",
122 | "Name" : f"{path}",
123 | "Type" : "SoVits",
124 | "Rate" : hps.data.sampling_rate,
125 | "Hop" : hps.data.hop_length,
126 | "Hubert": f"vec-{SVCVITS.gin_channels}-{vec_lay}",
127 | "SoVits4": True,
128 | "SoVits3": False,
129 | "CharaMix": export_mix,
130 | "Volume": SVCVITS.vol_embedding,
131 | "HiddenSize": SVCVITS.gin_channels,
132 | "Characters": spklist,
133 | "Cluster": ""
134 | }
135 |
136 | with open(f"checkpoints/{path}.json", 'w') as MoeVsConfFile:
137 | json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
138 |
139 |
140 | if __name__ == '__main__':
141 |     parser.add_argument('-n', '--model_name', type=str, default="TransformerFlow", help='model folder name (create a "checkpoints" folder in the project root, create a new sub-folder inside it and place the model there; that sub-folder name is this argument)')
142 | args = parser.parse_args()
143 | path = args.model_name
144 | OnnxExport(path)
145 |
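A minimal sketch of running the exported graph with onnxruntime, assuming a single-speaker model exported without volume embedding; the folder name "MyModel", the 768-dim content features and the 200-frame length are placeholders mirroring the dummy inputs above:

    import numpy as np
    import onnxruntime as ort

    n = 200
    sess = ort.InferenceSession("checkpoints/MyModel/MyModel_SoVits.onnx",
                                providers=["CPUExecutionProvider"])
    audio = sess.run(["audio"], {
        "c":      np.random.rand(1, n, 768).astype(np.float32),  # content features (gin_channels wide)
        "f0":     np.random.rand(1, n).astype(np.float32),       # dummy F0 curve
        "mel2ph": np.arange(n, dtype=np.int64)[None, :],
        "uv":     np.ones((1, n), dtype=np.float32),             # voiced/unvoiced mask
        "noise":  np.random.randn(1, 192, n).astype(np.float32),
        "sid":    np.array([0], dtype=np.int64),                 # speaker id
    })[0]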
--------------------------------------------------------------------------------
/onnx_export_old.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import utils
4 | from onnxexport.model_onnx import SynthesizerTrn
5 |
6 |
7 | def main(NetExport):
8 | path = "SoVits4.0"
9 | if NetExport:
10 | device = torch.device("cpu")
11 | hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
12 | SVCVITS = SynthesizerTrn(
13 | hps.data.filter_length // 2 + 1,
14 | hps.train.segment_size // hps.data.hop_length,
15 | **hps.model)
16 | _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None)
17 | _ = SVCVITS.eval().to(device)
18 | for i in SVCVITS.parameters():
19 | i.requires_grad = False
20 |
21 | n_frame = 10
22 | test_hidden_unit = torch.rand(1, n_frame, 256)
23 | test_pitch = torch.rand(1, n_frame)
24 | test_mel2ph = torch.arange(0, n_frame, dtype=torch.int64)[None] # torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0)
25 | test_uv = torch.ones(1, n_frame, dtype=torch.float32)
26 | test_noise = torch.randn(1, 192, n_frame)
27 | test_sid = torch.LongTensor([0])
28 | input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"]
29 | output_names = ["audio", ]
30 |
31 | torch.onnx.export(SVCVITS,
32 | (
33 | test_hidden_unit.to(device),
34 | test_pitch.to(device),
35 | test_mel2ph.to(device),
36 | test_uv.to(device),
37 | test_noise.to(device),
38 | test_sid.to(device)
39 | ),
40 | f"checkpoints/{path}/model.onnx",
41 | dynamic_axes={
42 | "c": [0, 1],
43 | "f0": [1],
44 | "mel2ph": [1],
45 | "uv": [1],
46 | "noise": [2],
47 | },
48 | do_constant_folding=False,
49 | opset_version=16,
50 | verbose=False,
51 | input_names=input_names,
52 | output_names=output_names)
53 |
54 |
55 | if __name__ == '__main__':
56 | main(True)
57 |
--------------------------------------------------------------------------------
/preprocess_flist_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import wave
6 | from random import shuffle
7 |
8 | from loguru import logger
9 | from tqdm import tqdm
10 |
11 | import diffusion.logger.utils as du
12 |
13 | pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
14 |
15 | def get_wav_duration(file_path):
16 | try:
17 | with wave.open(file_path, 'rb') as wav_file:
18 |             # number of audio frames
19 |             n_frames = wav_file.getnframes()
20 |             # sample rate
21 |             framerate = wav_file.getframerate()
22 |             # duration in seconds
23 |             return n_frames / float(framerate)
24 | except Exception as e:
25 |         logger.error(f"Failed to read {file_path}")
26 | raise e
27 |
28 | if __name__ == "__main__":
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
31 | parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
32 | parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
33 |     parser.add_argument("--speech_encoder", type=str, default="vec768l12", help="choose a speech encoder|'vec768l12','vec256l9','hubertsoft','whisper-ppg','cnhubertlarge','dphubert','whisper-ppg-large','wavlmbase+'")
34 | parser.add_argument("--vol_aug", action="store_true", help="Whether to use volume embedding and volume augmentation")
35 | parser.add_argument("--tiny", action="store_true", help="Whether to train sovits tiny")
36 | args = parser.parse_args()
37 |
38 | config_template = json.load(open("configs_template/config_tiny_template.json")) if args.tiny else json.load(open("configs_template/config_template.json"))
39 | train = []
40 | val = []
41 | idx = 0
42 | spk_dict = {}
43 | spk_id = 0
44 |
45 | for speaker in tqdm(os.listdir(args.source_dir)):
46 | spk_dict[speaker] = spk_id
47 | spk_id += 1
48 | wavs = []
49 |
50 | for file_name in os.listdir(os.path.join(args.source_dir, speaker)):
51 | if not file_name.endswith("wav"):
52 | continue
53 | if file_name.startswith("."):
54 | continue
55 |
56 | file_path = "/".join([args.source_dir, speaker, file_name])
57 |
58 | if not pattern.match(file_name):
59 | logger.warning("Detected non-ASCII file name: " + file_path)
60 |
61 | if get_wav_duration(file_path) < 0.3:
62 | logger.info("Skip too short audio: " + file_path)
63 | continue
64 |
65 | wavs.append(file_path)
66 |
67 | shuffle(wavs)
68 | train += wavs[2:]
69 | val += wavs[:2]
70 |
71 | shuffle(train)
72 | shuffle(val)
73 |
74 | logger.info("Writing " + args.train_list)
75 | with open(args.train_list, "w") as f:
76 | for fname in tqdm(train):
77 | wavpath = fname
78 | f.write(wavpath + "\n")
79 |
80 | logger.info("Writing " + args.val_list)
81 | with open(args.val_list, "w") as f:
82 | for fname in tqdm(val):
83 | wavpath = fname
84 | f.write(wavpath + "\n")
85 |
86 |
87 | d_config_template = du.load_config("configs_template/diffusion_template.yaml")
88 | d_config_template["model"]["n_spk"] = spk_id
89 | d_config_template["data"]["encoder"] = args.speech_encoder
90 | d_config_template["spk"] = spk_dict
91 |
92 | config_template["spk"] = spk_dict
93 | config_template["model"]["n_speakers"] = spk_id
94 | config_template["model"]["speech_encoder"] = args.speech_encoder
95 |
96 | if args.speech_encoder == "vec768l12" or args.speech_encoder == "dphubert" or args.speech_encoder == "wavlmbase+":
97 | config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768
98 | d_config_template["data"]["encoder_out_channels"] = 768
99 | elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft':
100 | config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256
101 | d_config_template["data"]["encoder_out_channels"] = 256
102 | elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge':
103 | config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024
104 | d_config_template["data"]["encoder_out_channels"] = 1024
105 | elif args.speech_encoder == "whisper-ppg-large":
106 | config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1280
107 | d_config_template["data"]["encoder_out_channels"] = 1280
108 |
109 | if args.vol_aug:
110 | config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True
111 |
112 | if args.tiny:
113 | config_template["model"]["filter_channels"] = 512
114 |
115 | logger.info("Writing to configs/config.json")
116 | with open("configs/config.json", "w") as f:
117 | json.dump(config_template, f, indent=2)
118 | logger.info("Writing to configs/diffusion.yaml")
119 | du.save_config("configs/diffusion.yaml",d_config_template)
120 |
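A quick, illustrative check that the generated config carries the intended encoder dimensions and speaker map (these keys are written by the script above):

    import json

    with open("configs/config.json") as f:
        cfg = json.load(f)
    print(cfg["model"]["speech_encoder"])                         # e.g. "vec768l12"
    print(cfg["model"]["ssl_dim"], cfg["model"]["gin_channels"])  # 768 for vec768l12
    print(cfg["model"]["n_speakers"], cfg["spk"])                 # speaker count and {name: id} map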
--------------------------------------------------------------------------------
/pretrain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/pretrain/__init__.py
--------------------------------------------------------------------------------
/pretrain/meta.py:
--------------------------------------------------------------------------------
1 | def download_dict():
2 | return {
3 | "vec768l12": {
4 | "url": "https://ibm.ent.box.com/shared/static/z1wgl1stco8ffooyatzdwsqn2psd9lrr",
5 | "output": "./pretrain/checkpoint_best_legacy_500.pt"
6 | },
7 | "vec256l9": {
8 | "url": "https://ibm.ent.box.com/shared/static/z1wgl1stco8ffooyatzdwsqn2psd9lrr",
9 | "output": "./pretrain/checkpoint_best_legacy_500.pt"
10 | },
11 | "hubertsoft": {
12 | "url": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
13 | "output": "./pretrain/hubert-soft-0d54a1f4.pt"
14 | },
15 | "whisper-ppg-small": {
16 | "url": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
17 | "output": "./pretrain/small.pt"
18 | },
19 | "whisper-ppg": {
20 | "url": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
21 | "output": "./pretrain/medium.pt"
22 | },
23 | "whisper-ppg-large": {
24 | "url": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
25 | "output": "./pretrain/large-v2.pt"
26 | }
27 | }
28 |
29 |
30 | def get_speech_encoder(config_path="configs/config.json"):
31 | import json
32 |
33 | with open(config_path, "r") as f:
34 | data = f.read()
35 | config = json.loads(data)
36 | speech_encoder = config["model"]["speech_encoder"]
37 | dict = download_dict()
38 |
39 | return dict[speech_encoder]["url"], dict[speech_encoder]["output"]
40 |
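A minimal download sketch using the returned pair, assuming the requests package is available; the chunked-download code below is a generic illustration, not a helper shipped with the repository:

    import os
    import requests

    url, output = get_speech_encoder("configs/config.json")
    os.makedirs(os.path.dirname(output), exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(output, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)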
--------------------------------------------------------------------------------
/pretrain/nsf_hifigan/put_nsf_hifigan_ckpt_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/pretrain/nsf_hifigan/put_nsf_hifigan_ckpt_here
--------------------------------------------------------------------------------
/pretrain/put_hubert_ckpt_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/pretrain/put_hubert_ckpt_here
--------------------------------------------------------------------------------
/raw/put_raw_wav_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/raw/put_raw_wav_here
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ffmpeg-python
2 | Flask
3 | Flask_Cors
4 | gradio>=3.7.0
5 | numpy==1.23.5
6 | pyworld
7 | scipy==1.10.0
8 | SoundFile==0.12.1
9 | torch
10 | torchaudio
11 | torchcrepe
12 | tqdm
13 | rich
14 | loguru
15 | scikit-maad
16 | praat-parselmouth
17 | onnx
18 | onnxsim
19 | onnxoptimizer
20 | fairseq==0.12.2
21 | librosa==0.9.1
22 | tensorboard
23 | tensorboardX
24 | transformers
25 | edge_tts
26 | langdetect
27 | pyyaml
28 | pynvml
29 | faiss-cpu
30 | einops
31 | local_attention
--------------------------------------------------------------------------------
/requirements_onnx_encoder.txt:
--------------------------------------------------------------------------------
1 | Flask
2 | Flask_Cors
3 | gradio>=3.7.0
4 | numpy==1.23.0
5 | pyworld==0.2.5
6 | scipy==1.10.0
7 | SoundFile==0.12.1
8 | torch==1.13.1
9 | torchaudio==0.13.1
10 | torchcrepe
11 | tqdm
12 | rich
13 | loguru
14 | scikit-maad
15 | praat-parselmouth
16 | onnx
17 | onnxsim
18 | onnxoptimizer
19 | onnxruntime-gpu
20 | librosa==0.9.1
21 | tensorboard
22 | tensorboardX
23 | edge_tts
24 | langdetect
25 | pyyaml
26 | pynvml
27 | transformers
28 | ffmpeg-python
29 | faiss-cpu
--------------------------------------------------------------------------------
/requirements_win.txt:
--------------------------------------------------------------------------------
1 | librosa==0.9.1
2 | fairseq==0.12.2
3 | ffmpeg-python
4 | Flask==2.1.2
5 | Flask_Cors==3.0.10
6 | gradio>=3.7.0
7 | numpy
8 | playsound==1.3.0
9 | PyAudio==0.2.12
10 | pydub==0.25.1
11 | pyworld==0.3.0
12 | requests==2.28.1
13 | scipy==1.7.3
14 | sounddevice==0.4.5
15 | SoundFile==0.10.3.post1
16 | starlette==0.19.1
17 | tqdm==4.63.0
18 | rich
19 | loguru
20 | torchcrepe
21 | scikit-maad
22 | praat-parselmouth
23 | onnx
24 | onnxsim
25 | onnxoptimizer
26 | tensorboard
27 | tensorboardX
28 | transformers
29 | edge_tts
30 | langdetect
31 | pyyaml
32 | pynvml
33 | faiss-cpu
34 |
--------------------------------------------------------------------------------
/resample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import concurrent.futures
3 | import os
4 | from concurrent.futures import ProcessPoolExecutor
5 | from multiprocessing import cpu_count
6 |
7 | import librosa
8 | import numpy as np
9 | from rich.progress import track
10 | from scipy.io import wavfile
11 |
12 |
13 | def load_wav(wav_path):
14 | return librosa.load(wav_path, sr=None)
15 |
16 |
17 | def trim_wav(wav, top_db=40):
18 | return librosa.effects.trim(wav, top_db=top_db)
19 |
20 |
21 | def normalize_peak(wav, threshold=1.0):
22 | peak = np.abs(wav).max()
23 | if peak > threshold:
24 | wav = 0.98 * wav / peak
25 | return wav
26 |
27 |
28 | def resample_wav(wav, sr, target_sr):
29 | return librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
30 |
31 |
32 | def save_wav_to_path(wav, save_path, sr):
33 | wavfile.write(
34 | save_path,
35 | sr,
36 | (wav * np.iinfo(np.int16).max).astype(np.int16)
37 | )
38 |
39 |
40 | def process(item):
41 | spkdir, wav_name, args = item
42 | speaker = spkdir.replace("\\", "/").split("/")[-1]
43 |
44 | wav_path = os.path.join(args.in_dir, speaker, wav_name)
45 | if os.path.exists(wav_path) and '.wav' in wav_path:
46 | os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
47 |
48 | wav, sr = load_wav(wav_path)
49 | wav, _ = trim_wav(wav)
50 | wav = normalize_peak(wav)
51 | resampled_wav = resample_wav(wav, sr, args.sr2)
52 |
53 | if not args.skip_loudnorm:
54 | resampled_wav /= np.max(np.abs(resampled_wav))
55 |
56 | save_path2 = os.path.join(args.out_dir2, speaker, wav_name)
57 | save_wav_to_path(resampled_wav, save_path2, args.sr2)
58 |
59 |
60 | """
61 | def process_all_speakers():
62 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
63 |
64 | with ThreadPoolExecutor(max_workers=process_count) as executor:
65 | for speaker in speakers:
66 | spk_dir = os.path.join(args.in_dir, speaker)
67 | if os.path.isdir(spk_dir):
68 | print(spk_dir)
69 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
70 | for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
71 | pass
72 | """
73 | # multi process
74 |
75 |
76 | def process_all_speakers():
77 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
78 | with ProcessPoolExecutor(max_workers=process_count) as executor:
79 | for speaker in speakers:
80 | spk_dir = os.path.join(args.in_dir, speaker)
81 | if os.path.isdir(spk_dir):
82 | print(spk_dir)
83 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
84 | for _ in track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"):
85 | pass
86 |
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument("--sr2", type=int, default=44100, help="sampling rate")
91 | parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir")
92 | parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
93 | parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it")
94 | args = parser.parse_args()
95 |
96 | print(f"CPU count: {cpu_count()}")
97 | speakers = os.listdir(args.in_dir)
98 | process_all_speakers()
99 |
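For a single file, the per-item pipeline inside process() reduces to the following sketch; the paths and the 44100 Hz target are illustrative:

    wav, sr = load_wav("dataset_raw/speaker0/sample.wav")
    wav, _ = trim_wav(wav)                  # strip leading/trailing silence (top_db=40)
    wav = normalize_peak(wav)               # scale down only if the peak exceeds 1.0
    wav = resample_wav(wav, sr, 44100)
    save_wav_to_path(wav, "dataset/44k/speaker0/sample.wav", 44100)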
--------------------------------------------------------------------------------
/shadowdiffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/shadowdiffusion.png
--------------------------------------------------------------------------------
/spkmix.py:
--------------------------------------------------------------------------------
1 | # Speaker mix track rules:
2 | # speaker ID : [[start_time_1, end_time_1, start_value_1, end_value_1], [start_time_2, end_time_2, start_value_2, end_value_2]]
3 | # Each start time must equal the previous segment's end time; the first start time must be 0 and the last end time must be 1 (times range from 0 to 1)
4 | # Every speaker must be listed; for speakers that are not used, just fill in [[0., 1., 0., 0.]]
5 | # The mix values can be chosen freely: within each segment the weight changes linearly from the start value to the end value, and the combination is automatically renormalized to sum to 1, so any values are safe
6 |
7 | spk_mix_map = {
8 | 0 : [[0., 0.5, 1, 0.5], [0.5, 1, 0.5, 1]],
9 | 1 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]],
10 | 2 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]]
11 | }
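A small sketch (not part of this file) of how one speaker's segment list can be read: at a normalized time t the raw weight interpolates linearly inside the segment that contains t, and the inference code then rescales all speakers' weights so they sum to 1; the helper name is arbitrary:

    def raw_weight(segments, t):
        for start, end, v_start, v_end in segments:
            if start <= t <= end:
                return v_start + (v_end - v_start) * (t - start) / (end - start)
        return 0.0

    print(raw_weight(spk_mix_map[0], 0.25))  # 0.75: halfway through speaker 0's first segment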
--------------------------------------------------------------------------------
/train_diff.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from loguru import logger
5 | from torch.optim import lr_scheduler
6 |
7 | from diffusion.data_loaders import get_data_loaders
8 | from diffusion.logger import utils
9 | from diffusion.solver import train
10 | from diffusion.unit2mel import Unit2Mel
11 | from diffusion.vocoder import Vocoder
12 |
13 |
14 | def parse_args(args=None, namespace=None):
15 | """Parse command-line arguments."""
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 | "-c",
19 | "--config",
20 | type=str,
21 | required=True,
22 | help="path to the config file")
23 | return parser.parse_args(args=args, namespace=namespace)
24 |
25 |
26 | if __name__ == '__main__':
27 | # parse commands
28 | cmd = parse_args()
29 |
30 | # load config
31 | args = utils.load_config(cmd.config)
32 | logger.info(' > config:'+ cmd.config)
33 | logger.info(' > exp:'+ args.env.expdir)
34 |
35 | # load vocoder
36 | vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
37 |
38 | # load model
39 | model = Unit2Mel(
40 | args.data.encoder_out_channels,
41 | args.model.n_spk,
42 | args.model.use_pitch_aug,
43 | vocoder.dimension,
44 | args.model.n_layers,
45 | args.model.n_chans,
46 | args.model.n_hidden,
47 | args.model.timesteps,
48 | args.model.k_step_max
49 | )
50 |
51 | logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
52 |
53 | # load parameters
54 | optimizer = torch.optim.AdamW(model.parameters())
55 | initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
56 | for param_group in optimizer.param_groups:
57 | param_group['initial_lr'] = args.train.lr
58 | param_group['lr'] = args.train.lr * (args.train.gamma ** max(((initial_global_step-2)//args.train.decay_step),0) )
59 | param_group['weight_decay'] = args.train.weight_decay
60 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma,last_epoch=initial_global_step-2)
61 |
62 | # device
63 | if args.device == 'cuda':
64 | torch.cuda.set_device(args.env.gpu_id)
65 | model.to(args.device)
66 |
67 | for state in optimizer.state.values():
68 | for k, v in state.items():
69 | if torch.is_tensor(v):
70 | state[k] = v.to(args.device)
71 |
72 | # datas
73 | loader_train, loader_valid = get_data_loaders(args, whole_audio=False)
74 |
75 | # run
76 | train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid)
77 |
78 |
--------------------------------------------------------------------------------
/train_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | import utils
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument(
10 | "--root_dir", type=str, default="dataset/44k", help="path to root dir"
11 | )
12 | parser.add_argument('-c', '--config', type=str, default="./configs/config.json",
13 | help='JSON file for configuration')
14 | parser.add_argument(
15 | "--output_dir", type=str, default="logs/44k", help="path to output dir"
16 | )
17 |
18 | args = parser.parse_args()
19 |
20 | hps = utils.get_hparams_from_file(args.config)
21 | spk_dic = hps.spk
22 | result = {}
23 |
24 | for k,v in spk_dic.items():
25 |         print(f"Now building the feature index for speaker {k}...")
26 | index = utils.train_index(k,args.root_dir)
27 | result[v] = index
28 |
29 | with open(os.path.join(args.output_dir,"feature_and_index.pkl"),"wb") as f:
30 | pickle.dump(result,f)
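An illustrative read-back of the generated file; the indices themselves (one per speaker id) are consumed by the inference code through its own loader:

    import pickle

    with open("logs/44k/feature_and_index.pkl", "rb") as f:
        result = pickle.load(f)
    print(list(result.keys()))  # speaker ids, each mapped to the index trained from that speaker's features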
--------------------------------------------------------------------------------
/trained/put_trained_checkpoints_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/trained/put_trained_checkpoints_here
--------------------------------------------------------------------------------
/vdecoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vdecoder/__init__.py
--------------------------------------------------------------------------------
/vdecoder/hifigan/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
/vdecoder/hifigan/nvSTFT.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import soundfile as sf
6 | import torch
7 | import torch.utils.data
8 | from librosa.filters import mel as librosa_mel_fn
9 |
10 | os.environ["LRU_CACHE_CAPACITY"] = "3"
11 |
12 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
13 | sampling_rate = None
14 | try:
15 |         data, sampling_rate = sf.read(full_path, always_2d=True)
16 | except Exception as ex:
17 | print(f"'{full_path}' failed to load.\nException:")
18 | print(ex)
19 | if return_empty_on_exception:
20 | return [], sampling_rate or target_sr or 32000
21 | else:
22 | raise Exception(ex)
23 |
24 | if len(data.shape) > 1:
25 | data = data[:, 0]
26 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
27 |
28 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int
29 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
30 | else: # if audio data is type fp32
31 | max_mag = max(np.amax(data), -np.amin(data))
32 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
33 |
34 | data = torch.FloatTensor(data.astype(np.float32))/max_mag
35 |
36 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
37 | return [], sampling_rate or target_sr or 32000
38 | if target_sr is not None and sampling_rate != target_sr:
39 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
40 | sampling_rate = target_sr
41 |
42 | return data, sampling_rate
43 |
44 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
45 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
46 |
47 | def dynamic_range_decompression(x, C=1):
48 | return np.exp(x) / C
49 |
50 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
51 | return torch.log(torch.clamp(x, min=clip_val) * C)
52 |
53 | def dynamic_range_decompression_torch(x, C=1):
54 | return torch.exp(x) / C
55 |
56 | class STFT():
57 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
58 | self.target_sr = sr
59 |
60 | self.n_mels = n_mels
61 | self.n_fft = n_fft
62 | self.win_size = win_size
63 | self.hop_length = hop_length
64 | self.fmin = fmin
65 | self.fmax = fmax
66 | self.clip_val = clip_val
67 | self.mel_basis = {}
68 | self.hann_window = {}
69 |
70 | def get_mel(self, y, center=False):
71 | sampling_rate = self.target_sr
72 | n_mels = self.n_mels
73 | n_fft = self.n_fft
74 | win_size = self.win_size
75 | hop_length = self.hop_length
76 | fmin = self.fmin
77 | fmax = self.fmax
78 | clip_val = self.clip_val
79 |
80 | if torch.min(y) < -1.:
81 | print('min value is ', torch.min(y))
82 | if torch.max(y) > 1.:
83 | print('max value is ', torch.max(y))
84 |
85 |         if str(fmax)+'_'+str(y.device) not in self.mel_basis:
86 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
87 | self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
88 | self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device)
89 |
90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect')
91 | y = y.squeeze(1)
92 |
93 |         spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)],
94 |                             center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
95 | # print(111,spec)
96 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
97 | # print(222,spec)
98 | spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec)
99 | # print(333,spec)
100 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
101 | # print(444,spec)
102 | return spec
103 |
104 | def __call__(self, audiopath):
105 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
106 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
107 | return spect
108 |
109 | stft = STFT()
110 |
--------------------------------------------------------------------------------
/vdecoder/hifigan/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | # matplotlib.use("Agg")
5 | import matplotlib.pylab as plt
6 | import torch
7 | from torch.nn.utils import weight_norm
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def del_old_checkpoints(cp_dir, prefix, n_models=2):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern) # get checkpoint paths
55 | cp_list = sorted(cp_list)# sort by iter
56 | if len(cp_list) > n_models: # if more than n_models models are found
57 |         for cp in cp_list[:-n_models]:# delete the oldest models other than latest n_models
58 | open(cp, 'w').close()# empty file contents
59 | os.unlink(cp)# delete file (move to trash when using Colab)
60 |
61 |
62 | def scan_checkpoint(cp_dir, prefix):
63 | pattern = os.path.join(cp_dir, prefix + '????????')
64 | cp_list = glob.glob(pattern)
65 | if len(cp_list) == 0:
66 | return None
67 | return sorted(cp_list)[-1]
68 |
69 |
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/alias/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | from .act import * # noqa: F403
5 | from .filter import * # noqa: F403
6 | from .resample import * # noqa: F403
7 |
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/alias/act.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import pow, sin
8 | from torch.nn import Parameter
9 |
10 | from .resample import DownSample1d, UpSample1d
11 |
12 |
13 | class Activation1d(nn.Module):
14 | def __init__(self,
15 | activation,
16 | up_ratio: int = 2,
17 | down_ratio: int = 2,
18 | up_kernel_size: int = 12,
19 | down_kernel_size: int = 12):
20 | super().__init__()
21 | self.up_ratio = up_ratio
22 | self.down_ratio = down_ratio
23 | self.act = activation
24 | self.upsample = UpSample1d(up_ratio, up_kernel_size)
25 | self.downsample = DownSample1d(down_ratio, down_kernel_size)
26 |
27 | # x: [B,C,T]
28 | def forward(self, x):
29 | x = self.upsample(x)
30 | x = self.act(x)
31 | x = self.downsample(x)
32 |
33 | return x
34 |
35 |
36 | class SnakeBeta(nn.Module):
37 | '''
38 | A modified Snake function which uses separate parameters for the magnitude of the periodic components
39 | Shape:
40 | - Input: (B, C, T)
41 | - Output: (B, C, T), same shape as the input
42 | Parameters:
43 | - alpha - trainable parameter that controls frequency
44 | - beta - trainable parameter that controls magnitude
45 | References:
46 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
47 | https://arxiv.org/abs/2006.08195
48 | Examples:
49 | >>> a1 = snakebeta(256)
50 | >>> x = torch.randn(256)
51 | >>> x = a1(x)
52 | '''
53 |
54 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
55 | '''
56 | Initialization.
57 | INPUT:
58 | - in_features: shape of the input
59 | - alpha - trainable parameter that controls frequency
60 | - beta - trainable parameter that controls magnitude
61 | alpha is initialized to 1 by default, higher values = higher-frequency.
62 | beta is initialized to 1 by default, higher values = higher-magnitude.
63 | alpha will be trained along with the rest of your model.
64 | '''
65 | super(SnakeBeta, self).__init__()
66 | self.in_features = in_features
67 | # initialize alpha
68 | self.alpha_logscale = alpha_logscale
69 | if self.alpha_logscale: # log scale alphas initialized to zeros
70 | self.alpha = Parameter(torch.zeros(in_features) * alpha)
71 | self.beta = Parameter(torch.zeros(in_features) * alpha)
72 | else: # linear scale alphas initialized to ones
73 | self.alpha = Parameter(torch.ones(in_features) * alpha)
74 | self.beta = Parameter(torch.ones(in_features) * alpha)
75 | self.alpha.requires_grad = alpha_trainable
76 | self.beta.requires_grad = alpha_trainable
77 | self.no_div_by_zero = 0.000000001
78 |
79 | def forward(self, x):
80 | '''
81 | Forward pass of the function.
82 | Applies the function to the input elementwise.
83 | SnakeBeta = x + 1/b * sin^2 (xa)
84 | '''
85 | alpha = self.alpha.unsqueeze(
86 | 0).unsqueeze(-1) # line up with x to [B, C, T]
87 | beta = self.beta.unsqueeze(0).unsqueeze(-1)
88 | if self.alpha_logscale:
89 | alpha = torch.exp(alpha)
90 | beta = torch.exp(beta)
91 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
92 | return x
93 |
94 |
95 | class Mish(nn.Module):
96 | """
97 | Mish activation function is proposed in "Mish: A Self
98 | Regularized Non-Monotonic Neural Activation Function"
99 | paper, https://arxiv.org/abs/1908.08681.
100 | """
101 |
102 | def __init__(self):
103 | super().__init__()
104 |
105 | def forward(self, x):
106 | return x * torch.tanh(F.softplus(x))
107 |
108 |
109 | class SnakeAlias(nn.Module):
110 | def __init__(self,
111 | channels,
112 | up_ratio: int = 2,
113 | down_ratio: int = 2,
114 | up_kernel_size: int = 12,
115 | down_kernel_size: int = 12,
116 | C = None):
117 | super().__init__()
118 | self.up_ratio = up_ratio
119 | self.down_ratio = down_ratio
120 | self.act = SnakeBeta(channels, alpha_logscale=True)
121 | self.upsample = UpSample1d(up_ratio, up_kernel_size, C)
122 | self.downsample = DownSample1d(down_ratio, down_kernel_size, C)
123 |
124 | # x: [B,C,T]
125 | def forward(self, x, C=None):
126 | x = self.upsample(x, C)
127 | x = self.act(x)
128 | x = self.downsample(x)
129 |
130 | return x
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/alias/filter.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | if 'sinc' in dir(torch):
11 | sinc = torch.sinc
12 | else:
13 | # This code is adopted from adefossez's julius.core.sinc under the MIT License
14 | # https://adefossez.github.io/julius/julius/core.html
15 | # LICENSE is in incl_licenses directory.
16 | def sinc(x: torch.Tensor):
17 | """
18 | Implementation of sinc, i.e. sin(pi * x) / (pi * x)
19 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
20 | """
21 | return torch.where(x == 0,
22 | torch.tensor(1., device=x.device, dtype=x.dtype),
23 | torch.sin(math.pi * x) / math.pi / x)
24 |
25 |
26 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
27 | # https://adefossez.github.io/julius/julius/lowpass.html
28 | # LICENSE is in incl_licenses directory.
29 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
30 | even = (kernel_size % 2 == 0)
31 | half_size = kernel_size // 2
32 |
33 | #For kaiser window
34 | delta_f = 4 * half_width
35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36 | if A > 50.:
37 | beta = 0.1102 * (A - 8.7)
38 | elif A >= 21.:
39 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
40 | else:
41 | beta = 0.
42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43 |
44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45 | if even:
46 | time = (torch.arange(-half_size, half_size) + 0.5)
47 | else:
48 | time = torch.arange(kernel_size) - half_size
49 | if cutoff == 0:
50 | filter_ = torch.zeros_like(time)
51 | else:
52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage
54 | # of the constant component in the input signal.
55 | filter_ /= filter_.sum()
56 | filter = filter_.view(1, 1, kernel_size)
57 |
58 | return filter
59 |
60 |
61 | class LowPassFilter1d(nn.Module):
62 | def __init__(self,
63 | cutoff=0.5,
64 | half_width=0.6,
65 | stride: int = 1,
66 | padding: bool = True,
67 | padding_mode: str = 'replicate',
68 | kernel_size: int = 12,
69 | C=None):
70 | # kernel_size should be even number for stylegan3 setup,
71 | # in this implementation, odd number is also possible.
72 | super().__init__()
73 | if cutoff < -0.:
74 | raise ValueError("Minimum cutoff must be larger than zero.")
75 | if cutoff > 0.5:
76 | raise ValueError("A cutoff above 0.5 does not make sense.")
77 | self.kernel_size = kernel_size
78 | self.even = (kernel_size % 2 == 0)
79 | self.pad_left = kernel_size // 2 - int(self.even)
80 | self.pad_right = kernel_size // 2
81 | self.stride = stride
82 | self.padding = padding
83 | self.padding_mode = padding_mode
84 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
85 | self.register_buffer("filter", filter)
86 | self.conv1d_block = None
87 | if C is not None:
88 | self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),]
89 | self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1))
90 | self.conv1d_block[0].requires_grad_(False)
91 |
92 | #input [B, C, T]
93 | def forward(self, x):
94 |         if self.conv1d_block is not None and self.conv1d_block[0].weight.device != x.device:
95 | self.conv1d_block[0] = self.conv1d_block[0].to(x.device)
96 | if self.conv1d_block is None:
97 | _, C, _ = x.shape
98 |
99 | if self.padding:
100 | x = F.pad(x, (self.pad_left, self.pad_right),
101 | mode=self.padding_mode)
102 | out = F.conv1d(x, self.filter.expand(C, -1, -1),
103 | stride=self.stride, groups=C)
104 | else:
105 | if self.padding:
106 | x = F.pad(x, (self.pad_left, self.pad_right),
107 | mode=self.padding_mode)
108 | out = self.conv1d_block[0](x)
109 |
110 | return out
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/alias/resample.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | from .filter import LowPassFilter1d, kaiser_sinc_filter1d
8 |
9 |
10 | class UpSample1d(nn.Module):
11 | def __init__(self, ratio=2, kernel_size=None, C=None):
12 | super().__init__()
13 | self.ratio = ratio
14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15 | self.stride = ratio
16 | self.pad = self.kernel_size // ratio - 1
17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20 | half_width=0.6 / ratio,
21 | kernel_size=self.kernel_size)
22 | self.register_buffer("filter", filter)
23 | self.conv_transpose1d_block = None
24 | if C is not None:
25 | self.conv_transpose1d_block = [nn.ConvTranspose1d(C,
26 | C,
27 | kernel_size=self.kernel_size,
28 | stride=self.stride,
29 | groups=C,
30 | bias=False
31 | ),]
32 | self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone())
33 | self.conv_transpose1d_block[0].requires_grad_(False)
34 |
35 |
36 |
37 | # x: [B, C, T]
38 | def forward(self, x, C=None):
39 |         if self.conv_transpose1d_block is not None and self.conv_transpose1d_block[0].weight.device != x.device:
40 | self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device)
41 | if self.conv_transpose1d_block is None:
42 | if C is None:
43 | _, C, _ = x.shape
44 | # print("snake.conv_t.in:",x.shape)
45 | x = F.pad(x, (self.pad, self.pad), mode='replicate')
46 | x = self.ratio * F.conv_transpose1d(
47 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
48 | # print("snake.conv_t.out:",x.shape)
49 | x = x[..., self.pad_left:-self.pad_right]
50 | else:
51 | x = F.pad(x, (self.pad, self.pad), mode='replicate')
52 | x = self.ratio * self.conv_transpose1d_block[0](x)
53 | x = x[..., self.pad_left:-self.pad_right]
54 | return x
55 |
56 |
57 | class DownSample1d(nn.Module):
58 | def __init__(self, ratio=2, kernel_size=None, C=None):
59 | super().__init__()
60 | self.ratio = ratio
61 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
62 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
63 | half_width=0.6 / ratio,
64 | stride=ratio,
65 | kernel_size=self.kernel_size,
66 | C=C)
67 |
68 |
69 | def forward(self, x):
70 | xx = self.lowpass(x)
71 |
72 | return xx
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/nvSTFT.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import soundfile as sf
6 | import torch
7 | import torch.utils.data
8 | from librosa.filters import mel as librosa_mel_fn
9 |
10 | os.environ["LRU_CACHE_CAPACITY"] = "3"
11 |
12 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
13 | sampling_rate = None
14 | try:
15 |         data, sampling_rate = sf.read(full_path, always_2d=True)
16 | except Exception as ex:
17 | print(f"'{full_path}' failed to load.\nException:")
18 | print(ex)
19 | if return_empty_on_exception:
20 | return [], sampling_rate or target_sr or 32000
21 | else:
22 | raise Exception(ex)
23 |
24 | if len(data.shape) > 1:
25 | data = data[:, 0]
26 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
27 |
28 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int
29 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
30 | else: # if audio data is type fp32
31 | max_mag = max(np.amax(data), -np.amin(data))
32 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
33 |
34 | data = torch.FloatTensor(data.astype(np.float32))/max_mag
35 |
36 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
37 | return [], sampling_rate or target_sr or 32000
38 | if target_sr is not None and sampling_rate != target_sr:
39 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
40 | sampling_rate = target_sr
41 |
42 | return data, sampling_rate
43 |
44 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
45 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
46 |
47 | def dynamic_range_decompression(x, C=1):
48 | return np.exp(x) / C
49 |
50 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
51 | return torch.log(torch.clamp(x, min=clip_val) * C)
52 |
53 | def dynamic_range_decompression_torch(x, C=1):
54 | return torch.exp(x) / C
55 |
56 | class STFT():
57 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
58 | self.target_sr = sr
59 |
60 | self.n_mels = n_mels
61 | self.n_fft = n_fft
62 | self.win_size = win_size
63 | self.hop_length = hop_length
64 | self.fmin = fmin
65 | self.fmax = fmax
66 | self.clip_val = clip_val
67 | self.mel_basis = {}
68 | self.hann_window = {}
69 |
70 | def get_mel(self, y, center=False):
71 | sampling_rate = self.target_sr
72 | n_mels = self.n_mels
73 | n_fft = self.n_fft
74 | win_size = self.win_size
75 | hop_length = self.hop_length
76 | fmin = self.fmin
77 | fmax = self.fmax
78 | clip_val = self.clip_val
79 |
80 | if torch.min(y) < -1.:
81 | print('min value is ', torch.min(y))
82 | if torch.max(y) > 1.:
83 | print('max value is ', torch.max(y))
84 |
85 |         if str(fmax)+'_'+str(y.device) not in self.mel_basis:
86 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
87 | self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
88 | self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device)
89 |
90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect')
91 | y = y.squeeze(1)
92 |
93 |         spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)],
94 |                             center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
95 | # print(111,spec)
96 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
97 | # print(222,spec)
98 | spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec)
99 | # print(333,spec)
100 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
101 | # print(444,spec)
102 | return spec
103 |
104 | def __call__(self, audiopath):
105 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
106 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
107 | return spect
108 |
109 | stft = STFT()
110 |
--------------------------------------------------------------------------------
/vdecoder/hifiganwithsnake/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | # matplotlib.use("Agg")
5 | import matplotlib.pylab as plt
6 | import torch
7 | from torch.nn.utils import weight_norm
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def del_old_checkpoints(cp_dir, prefix, n_models=2):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern) # get checkpoint paths
55 | cp_list = sorted(cp_list)# sort by iter
56 | if len(cp_list) > n_models: # if more than n_models models are found
57 |         for cp in cp_list[:-n_models]:# delete the oldest models other than latest n_models
58 | open(cp, 'w').close()# empty file contents
59 | os.unlink(cp)# delete file (move to trash when using Colab)
60 |
61 |
62 | def scan_checkpoint(cp_dir, prefix):
63 | pattern = os.path.join(cp_dir, prefix + '????????')
64 | cp_list = glob.glob(pattern)
65 | if len(cp_list) == 0:
66 | return None
67 | return sorted(cp_list)[-1]
68 |
69 |
--------------------------------------------------------------------------------
/vdecoder/nsf_hifigan/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
/vdecoder/nsf_hifigan/nvSTFT.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import soundfile as sf
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.utils.data
9 | from librosa.filters import mel as librosa_mel_fn
10 |
11 | os.environ["LRU_CACHE_CAPACITY"] = "3"
12 |
13 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
14 | sampling_rate = None
15 | try:
16 |         data, sampling_rate = sf.read(full_path, always_2d=True)
17 | except Exception as ex:
18 | print(f"'{full_path}' failed to load.\nException:")
19 | print(ex)
20 | if return_empty_on_exception:
21 | return [], sampling_rate or target_sr or 48000
22 | else:
23 | raise Exception(ex)
24 |
25 | if len(data.shape) > 1:
26 | data = data[:, 0]
27 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
28 |
29 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int
30 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
31 | else: # if audio data is type fp32
32 | max_mag = max(np.amax(data), -np.amin(data))
33 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
34 |
35 | data = torch.FloatTensor(data.astype(np.float32))/max_mag
36 |
37 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
38 | return [], sampling_rate or target_sr or 48000
39 | if target_sr is not None and sampling_rate != target_sr:
40 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
41 | sampling_rate = target_sr
42 |
43 | return data, sampling_rate
44 |
45 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
46 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
47 |
48 | def dynamic_range_decompression(x, C=1):
49 | return np.exp(x) / C
50 |
51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52 | return torch.log(torch.clamp(x, min=clip_val) * C)
53 |
54 | def dynamic_range_decompression_torch(x, C=1):
55 | return torch.exp(x) / C
56 |
57 | class STFT():
58 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
59 | self.target_sr = sr
60 |
61 | self.n_mels = n_mels
62 | self.n_fft = n_fft
63 | self.win_size = win_size
64 | self.hop_length = hop_length
65 | self.fmin = fmin
66 | self.fmax = fmax
67 | self.clip_val = clip_val
68 | self.mel_basis = {}
69 | self.hann_window = {}
70 |
71 | def get_mel(self, y, keyshift=0, speed=1, center=False):
72 | sampling_rate = self.target_sr
73 | n_mels = self.n_mels
74 | n_fft = self.n_fft
75 | win_size = self.win_size
76 | hop_length = self.hop_length
77 | fmin = self.fmin
78 | fmax = self.fmax
79 | clip_val = self.clip_val
80 |
81 | factor = 2 ** (keyshift / 12)
82 | n_fft_new = int(np.round(n_fft * factor))
83 | win_size_new = int(np.round(win_size * factor))
84 | hop_length_new = int(np.round(hop_length * speed))
85 |
86 | if torch.min(y) < -1.:
87 | print('min value is ', torch.min(y))
88 | if torch.max(y) > 1.:
89 | print('max value is ', torch.max(y))
90 |
91 | mel_basis_key = str(fmax)+'_'+str(y.device)
92 | if mel_basis_key not in self.mel_basis:
93 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
94 | self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
95 |
96 | keyshift_key = str(keyshift)+'_'+str(y.device)
97 | if keyshift_key not in self.hann_window:
98 | self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
99 |
100 | pad_left = (win_size_new - hop_length_new) //2
101 | pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
102 | if pad_right < y.size(-1):
103 | mode = 'reflect'
104 | else:
105 | mode = 'constant'
106 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
107 | y = y.squeeze(1)
108 |
109 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key],
110 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
111 | # print(111,spec)
112 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
113 | if keyshift != 0:
114 | size = n_fft // 2 + 1
115 | resize = spec.size(1)
116 | if resize < size:
117 | spec = F.pad(spec, (0, 0, 0, size-resize))
118 | spec = spec[:, :size, :] * win_size / win_size_new
119 |
120 | # print(222,spec)
121 | spec = torch.matmul(self.mel_basis[mel_basis_key], spec)
122 | # print(333,spec)
123 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
124 | # print(444,spec)
125 | return spec
126 |
127 | def __call__(self, audiopath):
128 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
129 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
130 | return spect
131 |
132 | stft = STFT()
133 |
--------------------------------------------------------------------------------
/vdecoder/nsf_hifigan/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import matplotlib
5 | import matplotlib.pylab as plt
6 | import torch
7 | from torch.nn.utils import weight_norm
8 |
9 | matplotlib.use("Agg")
10 |
11 |
12 | def plot_spectrogram(spectrogram):
13 | fig, ax = plt.subplots(figsize=(10, 2))
14 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
15 | interpolation='none')
16 | plt.colorbar(im, ax=ax)
17 |
18 | fig.canvas.draw()
19 | plt.close()
20 |
21 | return fig
22 |
23 |
24 | def init_weights(m, mean=0.0, std=0.01):
25 | classname = m.__class__.__name__
26 | if classname.find("Conv") != -1:
27 | m.weight.data.normal_(mean, std)
28 |
29 |
30 | def apply_weight_norm(m):
31 | classname = m.__class__.__name__
32 | if classname.find("Conv") != -1:
33 | weight_norm(m)
34 |
35 |
36 | def get_padding(kernel_size, dilation=1):
37 | return int((kernel_size*dilation - dilation)/2)
38 |
39 |
40 | def load_checkpoint(filepath, device):
41 | assert os.path.isfile(filepath)
42 | print("Loading '{}'".format(filepath))
43 | checkpoint_dict = torch.load(filepath, map_location=device)
44 | print("Complete.")
45 | return checkpoint_dict
46 |
47 |
48 | def save_checkpoint(filepath, obj):
49 | print("Saving checkpoint to {}".format(filepath))
50 | torch.save(obj, filepath)
51 | print("Complete.")
52 |
53 |
54 | def del_old_checkpoints(cp_dir, prefix, n_models=2):
55 | pattern = os.path.join(cp_dir, prefix + '????????')
56 | cp_list = glob.glob(pattern) # get checkpoint paths
57 | cp_list = sorted(cp_list)# sort by iter
58 | if len(cp_list) > n_models: # if more than n_models models are found
59 | for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models
60 | open(cp, 'w').close()# empty file contents
61 | os.unlink(cp)# delete file (move to trash when using Colab)
62 |
63 |
64 | def scan_checkpoint(cp_dir, prefix):
65 | pattern = os.path.join(cp_dir, prefix + '????????')
66 | cp_list = glob.glob(pattern)
67 | if len(cp_list) == 0:
68 | return None
69 | return sorted(cp_list)[-1]
70 |
71 |
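A hedged sketch of how these helpers fit together. The `g_` prefix and `logs/44k` directory are assumptions for illustration; the naming convention (prefix followed by an eight-digit step, matching the `'????????'` glob above) comes from the code itself.

import os

from vdecoder.nsf_hifigan.utils import (del_old_checkpoints, load_checkpoint,
                                         save_checkpoint, scan_checkpoint)

ckpt_dir = "logs/44k"                                       # assumed directory
save_checkpoint(os.path.join(ckpt_dir, "g_{:08d}".format(1000)), {"step": 1000})
latest = scan_checkpoint(ckpt_dir, "g_")                    # highest-numbered "g_????????" file, or None
if latest is not None:
    state = load_checkpoint(latest, "cpu")                  # plain torch.load under the hood
del_old_checkpoints(ckpt_dir, "g_", n_models=2)             # keep only the two newest checkpoints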
--------------------------------------------------------------------------------
/vencoder/CNHubertLarge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fairseq import checkpoint_utils
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class CNHubertLarge(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 1024
12 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
13 | [vec_path],
14 | suffix="",
15 | )
16 | if device is None:
17 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | else:
19 | self.dev = torch.device(device)
20 | self.model = models[0].to(self.dev)
21 | self.model.eval()
22 |
23 | def encoder(self, wav):
24 | feats = wav
25 | if feats.dim() == 2: # double channels
26 | feats = feats.mean(-1)
27 | assert feats.dim() == 1, feats.dim()
28 | feats = feats.view(1, -1)
29 | padding_mask = torch.BoolTensor(feats.shape).fill_(False)
30 | inputs = {
31 | "source": feats.to(wav.device),
32 | "padding_mask": padding_mask.to(wav.device)
33 | }
34 | with torch.no_grad():
35 | logits = self.model.extract_features(**inputs)
36 | return logits[0].transpose(1, 2)
--------------------------------------------------------------------------------
/vencoder/ContentVec256L12_Onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import torch
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class ContentVec256L12_Onnx(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 256
12 | if device is None:
13 | self.dev = torch.device("cpu")
14 | else:
15 | self.dev = torch.device(device)
16 |
17 | if device == 'cuda' or device == torch.device("cuda"):
18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
19 | else:
20 | providers = ['CPUExecutionProvider']
21 |
22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers)
23 |
24 | def encoder(self, wav):
25 | feats = wav
26 | if feats.dim() == 2: # double channels
27 | feats = feats.mean(-1)
28 | assert feats.dim() == 1, feats.dim()
29 | feats = feats.view(1, -1)
30 | feats = feats.unsqueeze(0).cpu().detach().numpy()
31 | onnx_input = {self.model.get_inputs()[0].name: feats}
32 | logits = self.model.run(None, onnx_input)
33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev)
34 |
--------------------------------------------------------------------------------
/vencoder/ContentVec256L9.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fairseq import checkpoint_utils
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class ContentVec256L9(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
12 | [vec_path],
13 | suffix="",
14 | )
15 | self.hidden_dim = 256
16 | if device is None:
17 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | else:
19 | self.dev = torch.device(device)
20 | self.model = models[0].to(self.dev)
21 | self.model.eval()
22 |
23 | def encoder(self, wav):
24 | feats = wav
25 | if feats.dim() == 2: # double channels
26 | feats = feats.mean(-1)
27 | assert feats.dim() == 1, feats.dim()
28 | feats = feats.view(1, -1)
29 | padding_mask = torch.BoolTensor(feats.shape).fill_(False)
30 | inputs = {
31 | "source": feats.to(wav.device),
32 | "padding_mask": padding_mask.to(wav.device),
33 | "output_layer": 9, # layer 9
34 | }
35 | with torch.no_grad():
36 | logits = self.model.extract_features(**inputs)
37 | feats = self.model.final_proj(logits[0])
38 | return feats.transpose(1, 2)
39 |
--------------------------------------------------------------------------------
/vencoder/ContentVec256L9_Onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import torch
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class ContentVec256L9_Onnx(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 256
12 | if device is None:
13 | self.dev = torch.device("cpu")
14 | else:
15 | self.dev = torch.device(device)
16 | if device == 'cuda' or device == torch.device("cuda"):
17 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
18 | else:
19 | providers = ['CPUExecutionProvider']
20 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers)
21 |
22 | def encoder(self, wav):
23 | feats = wav
24 | if feats.dim() == 2: # double channels
25 | feats = feats.mean(-1)
26 | assert feats.dim() == 1, feats.dim()
27 | feats = feats.view(1, -1)
28 | feats = feats.unsqueeze(0).cpu().detach().numpy()
29 | onnx_input = {self.model.get_inputs()[0].name: feats}
30 | logits = self.model.run(None, onnx_input)
31 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev)
32 |
--------------------------------------------------------------------------------
/vencoder/ContentVec768L12.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fairseq import checkpoint_utils
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class ContentVec768L12(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 768
12 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
13 | [vec_path],
14 | suffix="",
15 | )
16 | if device is None:
17 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | else:
19 | self.dev = torch.device(device)
20 | self.model = models[0].to(self.dev)
21 | self.model.eval()
22 |
23 | def encoder(self, wav):
24 | feats = wav
25 | if feats.dim() == 2: # double channels
26 | feats = feats.mean(-1)
27 | assert feats.dim() == 1, feats.dim()
28 | feats = feats.view(1, -1)
29 | padding_mask = torch.BoolTensor(feats.shape).fill_(False)
30 | inputs = {
31 | "source": feats.to(wav.device),
32 | "padding_mask": padding_mask.to(wav.device),
33 | "output_layer": 12, # layer 12
34 | }
35 | with torch.no_grad():
36 | logits = self.model.extract_features(**inputs)
37 | return logits[0].transpose(1, 2)
38 |
--------------------------------------------------------------------------------
/vencoder/ContentVec768L12_Onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import torch
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class ContentVec768L12_Onnx(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 768
12 | if device is None:
13 | self.dev = torch.device("cpu")
14 | else:
15 | self.dev = torch.device(device)
16 |
17 | if device == 'cuda' or device == torch.device("cuda"):
18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
19 | else:
20 | providers = ['CPUExecutionProvider']
21 |
22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers)
23 |
24 | def encoder(self, wav):
25 | feats = wav
26 | if feats.dim() == 2: # double channels
27 | feats = feats.mean(-1)
28 | assert feats.dim() == 1, feats.dim()
29 | feats = feats.view(1, -1)
30 | feats = feats.unsqueeze(0).cpu().detach().numpy()
31 | onnx_input = {self.model.get_inputs()[0].name: feats}
32 | logits = self.model.run(None, onnx_input)
33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev)
34 |
--------------------------------------------------------------------------------
/vencoder/ContentVec768L9_Onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import torch
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class ContentVec768L9_Onnx(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/vec-768-layer-9.onnx", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 768
12 | if device is None:
13 | self.dev = torch.device("cpu")
14 | else:
15 | self.dev = torch.device(device)
16 |
17 | if device == 'cuda' or device == torch.device("cuda"):
18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
19 | else:
20 | providers = ['CPUExecutionProvider']
21 |
22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers)
23 |
24 | def encoder(self, wav):
25 | feats = wav
26 | if feats.dim() == 2: # double channels
27 | feats = feats.mean(-1)
28 | assert feats.dim() == 1, feats.dim()
29 | feats = feats.view(1, -1)
30 | feats = feats.unsqueeze(0).cpu().detach().numpy()
31 | onnx_input = {self.model.get_inputs()[0].name: feats}
32 | logits = self.model.run(None, onnx_input)
33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev)
34 |
--------------------------------------------------------------------------------
/vencoder/DPHubert.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from vencoder.dphubert.model import wav2vec2_model
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class DPHubert(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | if device is None:
12 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | else:
14 | self.dev = torch.device(device)
15 | ckpt = torch.load(vec_path)
16 | self.hidden_dim = 768
17 | self.model = wav2vec2_model(**ckpt["config"]).to(self.dev)
18 | self.model.load_state_dict(ckpt["state_dict"], strict=False)
19 |
20 | def encoder(self, wav):
21 | feats = wav
22 | if feats.dim() == 2: # double channels
23 | feats = feats.mean(-1)
24 | assert feats.dim() == 1, feats.dim()
25 | feats = feats[None, :]
26 | with torch.no_grad():
27 | with torch.inference_mode():
28 | units = self.model(feats)[0]
29 | return units.transpose(1,2)
30 |
--------------------------------------------------------------------------------
/vencoder/HubertSoft.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from vencoder.encoder import SpeechEncoder
4 | from vencoder.hubert import hubert_model
5 |
6 |
7 | class HubertSoft(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | hubert_soft = hubert_model.hubert_soft(vec_path)
12 | if device is None:
13 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | else:
15 | self.dev = torch.device(device)
16 | self.hidden_dim = 256
17 | self.model = hubert_soft.to(self.dev)
18 |
19 | def encoder(self, wav):
20 | feats = wav
21 | if feats.dim() == 2: # double channels
22 | feats = feats.mean(-1)
23 | assert feats.dim() == 1, feats.dim()
24 | feats = feats[None,None,:]
25 | with torch.no_grad():
26 | with torch.inference_mode():
27 | units = self.model.units(feats)
28 | return units.transpose(1,2)
29 |
--------------------------------------------------------------------------------
/vencoder/HubertSoft_Onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import torch
3 |
4 | from vencoder.encoder import SpeechEncoder
5 |
6 |
7 | class HubertSoft_Onnx(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | self.hidden_dim = 256
12 | if device is None:
13 | self.dev = torch.device("cpu")
14 | else:
15 | self.dev = torch.device(device)
16 |
17 | if device == 'cuda' or device == torch.device("cuda"):
18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
19 | else:
20 | providers = ['CPUExecutionProvider']
21 |
22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers)
23 |
24 | def encoder(self, wav):
25 | feats = wav
26 | if feats.dim() == 2: # double channels
27 | feats = feats.mean(-1)
28 | assert feats.dim() == 1, feats.dim()
29 | feats = feats.view(1, -1)
30 | feats = feats.unsqueeze(0).cpu().detach().numpy()
31 | onnx_input = {self.model.get_inputs()[0].name: feats}
32 | logits = self.model.run(None, onnx_input)
33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev)
34 |
--------------------------------------------------------------------------------
/vencoder/WavLMBasePlus.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from vencoder.encoder import SpeechEncoder
4 | from vencoder.wavlm.WavLM import WavLM, WavLMConfig
5 |
6 |
7 | class WavLMBasePlus(SpeechEncoder):
8 | def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None):
9 | super().__init__()
10 | print("load model(s) from {}".format(vec_path))
11 | checkpoint = torch.load(vec_path)
12 | self.cfg = WavLMConfig(checkpoint['cfg'])
13 | if device is None:
14 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | else:
16 | self.dev = torch.device(device)
17 | self.hidden_dim = self.cfg.encoder_embed_dim
18 | self.model = WavLM(self.cfg)
19 | self.model.load_state_dict(checkpoint['model'])
20 | self.model.to(self.dev).eval()
21 |
22 | def encoder(self, wav):
23 | feats = wav
24 | if feats.dim() == 2: # double channels
25 | feats = feats.mean(-1)
26 | assert feats.dim() == 1, feats.dim()
27 | if self.cfg.normalize:
28 | feats = torch.nn.functional.layer_norm(feats, feats.shape)
29 | with torch.no_grad():
30 | with torch.inference_mode():
31 | units = self.model.extract_features(feats[None, :])[0]
32 | return units.transpose(1, 2)
33 |
--------------------------------------------------------------------------------
/vencoder/WhisperPPG.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from vencoder.encoder import SpeechEncoder
4 | from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim
5 | from vencoder.whisper.model import ModelDimensions, Whisper
6 |
7 |
8 | class WhisperPPG(SpeechEncoder):
9 | def __init__(self, vec_path="pretrain/medium.pt", device=None):
10 | super().__init__()
11 | if device is None:
12 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | else:
14 | self.dev = torch.device(device)
15 | checkpoint = torch.load(vec_path, map_location=device)
16 | dims = ModelDimensions(**checkpoint["dims"])
17 | model = Whisper(dims)
18 | model.load_state_dict(checkpoint["model_state_dict"])
19 | self.hidden_dim = dims.n_audio_state  # encoder feature size (1024 for the medium model)
20 | self.model = model.to(self.dev)
21 |
22 | def encoder(self, wav):
23 | audio = wav
24 | audln = audio.shape[0]
25 | ppgln = audln // 320
26 | audio = pad_or_trim(audio)
27 | mel = log_mel_spectrogram(audio).to(self.dev)
28 | with torch.no_grad():
29 | ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
30 | ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev)
31 | return ppg[None, :, :].transpose(1, 2)
32 |
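A brief note on the frame bookkeeping above: Whisper uses a 160-sample hop at 16 kHz and its encoder halves the time axis, so one PPG frame corresponds to 320 samples (20 ms); `ppgln` keeps only the frames that cover the original, unpadded audio. For example, for a 5-second clip:

audln = 16000 * 5        # 80000 samples of 16 kHz audio
ppgln = audln // 320     # 250 PPG frames retained from the padded 30 s window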
--------------------------------------------------------------------------------
/vencoder/WhisperPPGLarge.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from vencoder.encoder import SpeechEncoder
4 | from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim
5 | from vencoder.whisper.model import ModelDimensions, Whisper
6 |
7 |
8 | class WhisperPPGLarge(SpeechEncoder):
9 | def __init__(self, vec_path="pretrain/large-v2.pt", device=None):
10 | super().__init__()
11 | if device is None:
12 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | else:
14 | self.dev = torch.device(device)
15 | checkpoint = torch.load(vec_path, map_location=device)
16 | dims = ModelDimensions(**checkpoint["dims"])
17 | model = Whisper(dims)
18 | model.load_state_dict(checkpoint["model_state_dict"])
19 | self.hidden_dim = dims.n_audio_state  # encoder feature size (1280 for large-v2)
20 | self.model = model.to(self.dev)
21 |
22 | def encoder(self, wav):
23 | audio = wav
24 | audln = audio.shape[0]
25 | ppgln = audln // 320
26 | audio = pad_or_trim(audio)
27 | mel = log_mel_spectrogram(audio).to(self.dev)
28 | with torch.no_grad():
29 | ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
30 | ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev)
31 | return ppg[None, :, :].transpose(1, 2)
32 |
--------------------------------------------------------------------------------
/vencoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/__init__.py
--------------------------------------------------------------------------------
/vencoder/dphubert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/dphubert/__init__.py
--------------------------------------------------------------------------------
/vencoder/dphubert/hardconcrete.py:
--------------------------------------------------------------------------------
1 | """Implementation of the hard Concrete distribution.
2 |
3 | Originally from:
4 | https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py
5 |
6 | """
7 |
8 | import math
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | class HardConcrete(nn.Module):
15 | """A HarcConcrete module.
16 | Use this module to create a mask of size N, which you can
17 | then use to perform L0 regularization.
18 |
19 | To obtain a mask, simply run a forward pass through the module
20 | with no input data. The mask is sampled in training mode, and
21 | fixed during evaluation mode, e.g.:
22 |
23 | >>> module = HardConcrete(n_in=100)
24 | >>> mask = module()
25 | >>> norm = module.l0_norm()
26 | """
27 |
28 | def __init__(
29 | self,
30 | n_in: int,
31 | init_mean: float = 0.5,
32 | init_std: float = 0.01,
33 | temperature: float = 2/3, # from CoFi
34 | stretch: float = 0.1,
35 | eps: float = 1e-6
36 | ) -> None:
37 | """Initialize the HardConcrete module.
38 | Parameters
39 | ----------
40 | n_in : int
41 | The number of hard concrete variables in this mask.
42 | init_mean : float, optional
43 | Initial drop rate for hard concrete parameter,
44 | by default 0.5.
45 | init_std: float, optional
46 | Used to initialize the hard concrete parameters,
47 | by default 0.01.
48 | temperature : float, optional
49 | Temperature used to control the sharpness of the
50 | distribution, by default 2/3.
51 | stretch : float, optional
52 | Stretch the sampled value from [0, 1] to the interval
53 | [-stretch, 1 + stretch], by default 0.1.
54 | """
55 | super().__init__()
56 |
57 | self.n_in = n_in
58 | self.limit_l = -stretch
59 | self.limit_r = 1.0 + stretch
60 | self.log_alpha = nn.Parameter(torch.zeros(n_in))
61 | self.beta = temperature
62 | self.init_mean = init_mean
63 | self.init_std = init_std
64 | self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)
65 |
66 | self.eps = eps
67 | self.compiled_mask = None
68 | self.reset_parameters()
69 |
70 | def reset_parameters(self):
71 | """Reset the parameters of this module."""
72 | self.compiled_mask = None
73 | mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
74 | self.log_alpha.data.normal_(mean, self.init_std)
75 |
76 | def l0_norm(self) -> torch.Tensor:
77 | """Compute the expected L0 norm of this mask.
78 | Returns
79 | -------
80 | torch.Tensor
81 | The expected L0 norm.
82 | """
83 | return (self.log_alpha + self.bias).sigmoid().sum()
84 |
85 | def forward(self) -> torch.Tensor:
86 | """Sample a hard concrete mask.
87 | Returns
88 | -------
89 | torch.Tensor
90 | The sampled binary mask
91 | """
92 | if self.training:
93 | # Reset the compiled mask
94 | self.compiled_mask = None
95 | # Sample mask dynamically
96 | u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps)
97 | s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta)
98 | s = s * (self.limit_r - self.limit_l) + self.limit_l
99 | mask = s.clamp(min=0., max=1.)
100 |
101 | else:
102 | # Compile new mask if not cached
103 | if self.compiled_mask is None:
104 | # Get expected sparsity
105 | expected_num_zeros = self.n_in - self.l0_norm().item()
106 | num_zeros = round(expected_num_zeros)
107 | # Approximate expected value of each mask variable z;
108 | # We use an empirically validated magic number 0.8
109 | soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8)
110 | # Prune small values to set to 0
111 | _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
112 | soft_mask[indices] = 0.
113 | self.compiled_mask = soft_mask
114 | mask = self.compiled_mask
115 |
116 | return mask
117 |
118 | def extra_repr(self) -> str:
119 | return str(self.n_in)
120 |
121 | def __repr__(self) -> str:
122 | return "{}({})".format(self.__class__.__name__, self.extra_repr())
123 |
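Following the class docstring, a minimal sketch of using the sampled mask for L0 regularization; the toy linear layer, squared-output "task loss", and 1e-3 penalty weight are arbitrary assumptions for illustration.

import torch
import torch.nn as nn

from vencoder.dphubert.hardconcrete import HardConcrete

layer = nn.Linear(100, 10)
gate = HardConcrete(n_in=100)                        # one gate per input feature
x = torch.randn(8, 100)

gate.train()
mask = gate()                                        # stochastic mask while training, shape (100,)
out = layer(x * mask)                                # gate the inputs before the layer
loss = out.pow(2).mean() + 1e-3 * gate.l0_norm()     # task loss + expected-L0 penalty
loss.backward()

gate.eval()
fixed_mask = gate()                                  # deterministic compiled mask at evaluation time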
--------------------------------------------------------------------------------
/vencoder/dphubert/pruning_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for pruning."""
2 |
3 | from typing import Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str):
10 | "Prune linear layer in place."
11 | # NOTE: weight: (out_features, in_features), bias: (out_features,)
12 | if dim == "input":
13 | dim = 1
14 | layer.in_features = len(index)
15 | elif dim == "output":
16 | dim = 0
17 | layer.out_features = len(index)
18 | else:
19 | raise ValueError(f"dim must be 'input' or 'output', got {dim!r}")
20 |
21 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach())
22 | if layer.bias is not None and dim == 0:
23 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach())
24 |
25 |
26 | def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str):
27 | """Prune conv1d in place."""
28 | # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,)
29 | if dim == "input":
30 | dim = 1
31 | layer.in_channels = len(index)
32 | elif dim == "output":
33 | dim = 0
34 | layer.out_channels = len(index)
35 | else:
36 | raise ValueError(f"dim must be 'input' or 'output', got {dim!r}")
37 |
38 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach())
39 | if layer.bias is not None and dim == 0:
40 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach())
41 |
42 |
43 | def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor):
44 | """Prune layer norm or group norm in place."""
45 | layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach())
46 | layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach())
47 | if isinstance(layernorm, nn.LayerNorm):
48 | layernorm.normalized_shape = (len(index),)
49 | elif isinstance(layernorm, nn.GroupNorm):
50 | layernorm.num_groups = len(index)
51 | layernorm.num_channels = len(index)
52 |
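A small sketch of the in-place pruning these helpers perform; the toy layer sizes and the kept indices are assumptions for illustration. Keeping output units 0, 2 and 4 shrinks the weight and bias via index_select, exactly as described above.

import torch
import torch.nn as nn

from vencoder.dphubert.pruning_utils import prune_linear_layer

layer = nn.Linear(8, 6)
keep = torch.LongTensor([0, 2, 4])                   # output units to keep
prune_linear_layer(layer, keep, dim="output")
print(layer.out_features, tuple(layer.weight.shape))  # 3 (3, 8)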
--------------------------------------------------------------------------------
/vencoder/dphubert/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/dphubert/utils/__init__.py
--------------------------------------------------------------------------------
/vencoder/dphubert/utils/import_huggingface_wavlm.py:
--------------------------------------------------------------------------------
1 | """Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format.
2 |
3 | Originally from:
4 | https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/utils/import_huggingface.py
5 |
6 | """
7 |
8 | import logging
9 | from typing import Any, Dict
10 |
11 | from torch.nn import Module
12 |
13 | from ..model import Wav2Vec2Model, wav2vec2_model, wavlm_model
14 |
15 | _LG = logging.getLogger(__name__)
16 |
17 |
18 | def _get_config(cfg):
19 | config = {
20 | "extractor_mode": f"{cfg.feat_extract_norm}_norm",
21 | "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
22 | "extractor_conv_bias": cfg.conv_bias,
23 | "encoder_embed_dim": cfg.hidden_size,
24 | "encoder_projection_dropout": cfg.feat_proj_dropout,
25 | "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
26 | "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
27 | "encoder_num_layers": cfg.num_hidden_layers,
28 | "encoder_num_heads": cfg.num_attention_heads,
29 | "encoder_attention_dropout": cfg.attention_dropout,
30 | "encoder_ff_interm_features": cfg.intermediate_size,
31 | "encoder_ff_interm_dropout": cfg.activation_dropout,
32 | "encoder_dropout": cfg.hidden_dropout,
33 | "encoder_layer_norm_first": cfg.do_stable_layer_norm,
34 | "encoder_layer_drop": cfg.layerdrop,
35 | }
36 | return config
37 |
38 |
39 | def _get_config_wavlm(cfg):
40 | config = {
41 | "extractor_mode": f"{cfg.feat_extract_norm}_norm",
42 | "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
43 | "extractor_conv_bias": cfg.conv_bias,
44 | "encoder_embed_dim": cfg.hidden_size,
45 | "encoder_projection_dropout": cfg.feat_proj_dropout,
46 | "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
47 | "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
48 | "encoder_num_layers": cfg.num_hidden_layers,
49 | "encoder_use_attention": [True] * cfg.num_hidden_layers,
50 | "encoder_use_feed_forward": [True] * cfg.num_hidden_layers,
51 | "encoder_total_num_heads": [cfg.num_attention_heads for _ in range(cfg.num_hidden_layers)],
52 | "encoder_remaining_heads": [list(range(cfg.num_attention_heads)) for _ in range(cfg.num_hidden_layers)],
53 | "encoder_num_buckets": cfg.num_buckets,
54 | "encoder_max_distance": cfg.max_bucket_distance,
55 | "encoder_attention_dropout": cfg.attention_dropout,
56 | "encoder_ff_interm_features": [cfg.intermediate_size for _ in range(cfg.num_hidden_layers)],
57 | "encoder_ff_interm_dropout": cfg.activation_dropout,
58 | "encoder_dropout": cfg.hidden_dropout,
59 | "encoder_layer_norm_first": cfg.do_stable_layer_norm,
60 | "encoder_layer_drop": cfg.layerdrop,
61 | "normalize_waveform": cfg.feat_extract_norm == "layer",
62 | }
63 | return config
64 |
65 |
66 | def _build(config, original):
67 | is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"]
68 | if is_for_ctc:
69 | aux_num_out = original.config.vocab_size
70 | wav2vec2 = original.wav2vec2
71 | else:
72 | _LG.warning(
73 | "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.'
74 | )
75 | aux_num_out = None
76 | wav2vec2 = original
77 | is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
78 | if is_wavlm:
79 | imported = wavlm_model(**config, aux_num_out=aux_num_out)
80 | else:
81 | imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
82 | print(imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict(), strict=False))
83 | print(imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict(), strict=False))
84 | encoder_state_dict = wav2vec2.encoder.state_dict()
85 | if is_wavlm: # Rename parameters of linear transformations for compatibility with the HF model
86 | transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"])
87 | print(imported.encoder.transformer.load_state_dict(encoder_state_dict, strict=False))
88 | if is_for_ctc:
89 | imported.aux.load_state_dict(original.lm_head.state_dict())
90 | return imported
91 |
92 |
93 | def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int):
94 | """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and
95 | biases to align with the structure of ``torch.nn.MultiheadAttention``.
96 | """
97 | pass
98 |
99 |
100 | def import_huggingface_model(original: Module) -> Wav2Vec2Model:
101 | """Builds :class:`Wav2Vec2Model` from the corresponding model object of
102 | `Transformers <https://huggingface.co/transformers/>`_.
103 |
104 | Args:
105 | original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``.
106 |
107 | Returns:
108 | Wav2Vec2Model: Imported model.
109 |
110 | Example
111 | >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
112 | >>>
113 | >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
114 | >>> model = import_huggingface_model(original)
115 | >>>
116 | >>> waveforms, _ = torchaudio.load("audio.wav")
117 | >>> logits, _ = model(waveforms)
118 | """
119 | _LG.info("Importing model.")
120 | _LG.info("Loading model configuration.")
121 | is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
122 | if is_wavlm:
123 | config = _get_config_wavlm(original.config)
124 | else:
125 | config = _get_config(original.config)
126 | _LG.debug(" - config: %s", config)
127 | _LG.info("Building model.")
128 | imported = _build(config, original)
129 | return imported
130 |
--------------------------------------------------------------------------------
/vencoder/encoder.py:
--------------------------------------------------------------------------------
1 | class SpeechEncoder(object):
2 | def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None):
3 | self.model = None  # the concrete subclass loads the actual model here
4 | self.hidden_dim = 768
5 | pass
6 |
7 |
8 | def encoder(self, wav):
9 | """
10 | input: wav:[signal_length]
11 | output: embedding:[batchsize,hidden_dim,wav_frame]
12 | """
13 | pass
14 |
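A hedged sketch of the contract this base class defines, using ContentVec768L12 from earlier in this listing; it assumes the fairseq checkpoint at the default path is available locally, and the dummy waveform is fabricated for illustration.

import torch

from vencoder.ContentVec768L12 import ContentVec768L12

enc = ContentVec768L12(device="cpu")           # loads pretrain/checkpoint_best_legacy_500.pt
wav = torch.randn(16000)                       # fake 1 s of 16 kHz mono audio
units = enc.encoder(wav.to(enc.dev))           # -> [1, hidden_dim, n_frames]
assert units.shape[1] == enc.hidden_dim        # hidden_dim == 768 for this encoder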
--------------------------------------------------------------------------------
/vencoder/hubert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/hubert/__init__.py
--------------------------------------------------------------------------------
/vencoder/whisper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/whisper/__init__.py
--------------------------------------------------------------------------------
/vencoder/whisper/audio.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | from typing import Union
3 |
4 | import ffmpeg
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from librosa.filters import mel as librosa_mel_fn
9 |
10 | from .utils import exact_div
11 |
12 | # hard-coded audio hyperparameters
13 | SAMPLE_RATE = 16000
14 | N_FFT = 400
15 | N_MELS = 80
16 | HOP_LENGTH = 160
17 | CHUNK_LENGTH = 30
18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
20 |
21 |
22 | def load_audio(file: str, sr: int = SAMPLE_RATE):
23 | """
24 | Open an audio file and read as mono waveform, resampling as necessary
25 |
26 | Parameters
27 | ----------
28 | file: str
29 | The audio file to open
30 |
31 | sr: int
32 | The sample rate to resample the audio if necessary
33 |
34 | Returns
35 | -------
36 | A NumPy array containing the audio waveform, in float32 dtype.
37 | """
38 | try:
39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
41 | out, _ = (
42 | ffmpeg.input(file, threads=0)
43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
45 | )
46 | except ffmpeg.Error as e:
47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
48 |
49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
50 |
51 |
52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
53 | """
54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
55 | """
56 | if torch.is_tensor(array):
57 | if array.shape[axis] > length:
58 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
59 |
60 | if array.shape[axis] < length:
61 | pad_widths = [(0, 0)] * array.ndim
62 | pad_widths[axis] = (0, length - array.shape[axis])
63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
64 | else:
65 | if array.shape[axis] > length:
66 | array = array.take(indices=range(length), axis=axis)
67 |
68 | if array.shape[axis] < length:
69 | pad_widths = [(0, 0)] * array.ndim
70 | pad_widths[axis] = (0, length - array.shape[axis])
71 | array = np.pad(array, pad_widths)
72 |
73 | return array
74 |
75 |
76 | @lru_cache(maxsize=None)
77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
78 | """
79 | Load the mel filterbank matrix for projecting an STFT onto Mel bins.
80 | The upstream Whisper code decouples the librosa dependency by loading a
81 | precomputed filterbank saved with:
82 | np.savez_compressed(
83 | "mel_filters.npz",
84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
85 | )
86 | but this version builds the same matrix directly with librosa.filters.mel and caches it per device."""
87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
88 | return torch.from_numpy(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)).to(device)
89 |
90 |
91 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
92 | """
93 | Compute the log-Mel spectrogram of the given audio.
94 |
95 | Parameters
96 | ----------
97 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
98 | The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
99 |
100 | n_mels: int
101 | The number of Mel-frequency filters, only 80 is supported
102 |
103 | Returns
104 | -------
105 | torch.Tensor, shape = (80, n_frames)
106 | A Tensor that contains the Mel spectrogram
107 | """
108 | if not torch.is_tensor(audio):
109 | if isinstance(audio, str):
110 | audio = load_audio(audio)
111 | audio = torch.from_numpy(audio)
112 |
113 | window = torch.hann_window(N_FFT).to(audio.device)
114 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
115 | magnitudes = stft[..., :-1].abs() ** 2
116 |
117 | filters = mel_filters(audio.device, n_mels)
118 | mel_spec = filters @ magnitudes
119 |
120 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
121 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
122 | log_spec = (log_spec + 4.0) / 4.0
123 | return log_spec
124 |
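Mirroring how the WhisperPPG encoders earlier in this listing call these helpers, a minimal sketch; the dummy 5-second waveform is an assumption for illustration.

import torch

from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim

wav = torch.randn(16000 * 5)                   # fake 5 s of 16 kHz audio
wav = pad_or_trim(wav)                         # padded/trimmed to N_SAMPLES (30 s)
mel = log_mel_spectrogram(wav)                 # -> torch.Size([80, 3000])
print(mel.shape)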
--------------------------------------------------------------------------------
/vencoder/whisper/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import zlib
5 | from typing import Callable, TextIO
6 |
7 | system_encoding = sys.getdefaultencoding()
8 |
9 | if system_encoding != "utf-8":
10 | def make_safe(string):
11 | # replaces any character not representable using the system default encoding with an '?',
12 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
13 | return string.encode(system_encoding, errors="replace").decode(system_encoding)
14 | else:
15 | def make_safe(string):
16 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
17 | return string
18 |
19 |
20 | def exact_div(x, y):
21 | assert x % y == 0
22 | return x // y
23 |
24 |
25 | def str2bool(string):
26 | str2val = {"True": True, "False": False}
27 | if string in str2val:
28 | return str2val[string]
29 | else:
30 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
31 |
32 |
33 | def optional_int(string):
34 | return None if string == "None" else int(string)
35 |
36 |
37 | def optional_float(string):
38 | return None if string == "None" else float(string)
39 |
40 |
41 | def compression_ratio(text) -> float:
42 | text_bytes = text.encode("utf-8")
43 | return len(text_bytes) / len(zlib.compress(text_bytes))
44 |
45 |
46 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
47 | assert seconds >= 0, "non-negative timestamp expected"
48 | milliseconds = round(seconds * 1000.0)
49 |
50 | hours = milliseconds // 3_600_000
51 | milliseconds -= hours * 3_600_000
52 |
53 | minutes = milliseconds // 60_000
54 | milliseconds -= minutes * 60_000
55 |
56 | seconds = milliseconds // 1_000
57 | milliseconds -= seconds * 1_000
58 |
59 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
60 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
61 |
62 |
63 | class ResultWriter:
64 | extension: str
65 |
66 | def __init__(self, output_dir: str):
67 | self.output_dir = output_dir
68 |
69 | def __call__(self, result: dict, audio_path: str):
70 | audio_basename = os.path.basename(audio_path)
71 | output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
72 |
73 | with open(output_path, "w", encoding="utf-8") as f:
74 | self.write_result(result, file=f)
75 |
76 | def write_result(self, result: dict, file: TextIO):
77 | raise NotImplementedError
78 |
79 |
80 | class WriteTXT(ResultWriter):
81 | extension: str = "txt"
82 |
83 | def write_result(self, result: dict, file: TextIO):
84 | for segment in result["segments"]:
85 | print(segment['text'].strip(), file=file, flush=True)
86 |
87 |
88 | class WriteVTT(ResultWriter):
89 | extension: str = "vtt"
90 |
91 | def write_result(self, result: dict, file: TextIO):
92 | print("WEBVTT\n", file=file)
93 | for segment in result["segments"]:
94 | print(
95 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
96 | f"{segment['text'].strip().replace('-->', '->')}\n",
97 | file=file,
98 | flush=True,
99 | )
100 |
101 |
102 | class WriteSRT(ResultWriter):
103 | extension: str = "srt"
104 |
105 | def write_result(self, result: dict, file: TextIO):
106 | for i, segment in enumerate(result["segments"], start=1):
107 | # write srt lines
108 | print(
109 | f"{i}\n"
110 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
111 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
112 | f"{segment['text'].strip().replace('-->', '->')}\n",
113 | file=file,
114 | flush=True,
115 | )
116 |
117 |
118 | class WriteTSV(ResultWriter):
119 | """
120 | Write a transcript to a file in TSV (tab-separated values) format containing lines like:
121 | <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
122 |
123 | Using integer milliseconds as start and end times means there's no chance of interference from
124 | an environment setting a language encoding that causes the decimal in a floating point number
125 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
126 | """
127 | extension: str = "tsv"
128 |
129 | def write_result(self, result: dict, file: TextIO):
130 | print("start", "end", "text", sep="\t", file=file)
131 | for segment in result["segments"]:
132 | print(round(1000 * segment['start']), file=file, end="\t")
133 | print(round(1000 * segment['end']), file=file, end="\t")
134 | print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
135 |
136 |
137 | class WriteJSON(ResultWriter):
138 | extension: str = "json"
139 |
140 | def write_result(self, result: dict, file: TextIO):
141 | json.dump(result, file)
142 |
143 |
144 | def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
145 | writers = {
146 | "txt": WriteTXT,
147 | "vtt": WriteVTT,
148 | "srt": WriteSRT,
149 | "tsv": WriteTSV,
150 | "json": WriteJSON,
151 | }
152 |
153 | if output_format == "all":
154 | all_writers = [writer(output_dir) for writer in writers.values()]
155 |
156 | def write_all(result: dict, file: TextIO):
157 | for writer in all_writers:
158 | writer(result, file)
159 |
160 | return write_all
161 |
162 | return writers[output_format](output_dir)
163 |
164 |
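A small sketch of the writer interface above; the result dict, output directory, and file names below are fabricated purely for illustration.

from vencoder.whisper.utils import format_timestamp, get_writer

result = {"segments": [{"start": 0.0, "end": 1.5, "text": "hello"}]}
writer = get_writer("srt", ".")                # WriteSRT writing into the current directory
writer(result, "example.wav")                  # creates ./example.wav.srt
print(format_timestamp(3723.456, always_include_hours=True))  # 01:02:03.456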
--------------------------------------------------------------------------------
/wav_upload.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | from google.colab import files
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--type", type=str, required=True, help="type of file to upload")
10 | args = parser.parse_args()
11 | file_type = args.type
12 |
13 | basepath = os.getcwd()
14 | uploaded = files.upload()  # upload files via the Colab file picker
15 | assert(file_type in ['zip', 'audio'])
16 | if file_type == "zip":
17 | upload_path = "./upload/"
18 | for filename in uploaded.keys():
19 | # move the uploaded file to the target location
20 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, "userzip.zip"))
21 | elif file_type == "audio":
22 | upload_path = "./raw/"
23 | for filename in uploaded.keys():
24 | # move the uploaded file to the target location
25 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, filename))
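For reference, this script is meant to be run from a Colab cell at the repository root: `!python wav_upload.py --type audio` moves the uploaded files into ./raw/, while `!python wav_upload.py --type zip` moves the uploaded archive to ./upload/userzip.zip.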
--------------------------------------------------------------------------------