├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── ask_for_help.yaml │ ├── ask_for_help_en_US.yaml │ ├── bug_report.yaml │ ├── bug_report_en_US.yaml │ ├── config.yml │ └── default.md └── workflows │ ├── reviewdog.yml │ └── ruff.yml ├── .gitignore ├── .ruff.toml ├── LICENSE ├── README.md ├── README_zh_CN.md ├── cluster ├── __init__.py ├── kmeans.py └── train_cluster.py ├── compress_model.py ├── configs ├── config.json └── diffusion.yaml ├── configs_template ├── config_template.json ├── config_tiny_template.json └── diffusion_template.yaml ├── data_utils.py ├── dataset_raw └── wav_structure.txt ├── diffusion ├── __init__.py ├── data_loaders.py ├── diffusion.py ├── diffusion_onnx.py ├── dpm_solver_pytorch.py ├── how to export onnx.md ├── infer_gt_mel.py ├── logger │ ├── __init__.py │ ├── saver.py │ └── utils.py ├── onnx_export.py ├── solver.py ├── uni_pc.py ├── unit2mel.py ├── vocoder.py └── wavenet.py ├── edgetts ├── tts.py └── tts_voices.py ├── export_index_for_onnx.py ├── filelists ├── test.txt ├── train.txt └── val.txt ├── flask_api.py ├── flask_api_full_song.py ├── inference ├── __init__.py ├── infer_tool.py ├── infer_tool_grad.py └── slicer.py ├── inference_main.py ├── logs └── 44k │ ├── diffusion │ └── put_diffusion_pretrained_model_here │ └── put_pretrained_model_here ├── models.py ├── modules ├── DSConv.py ├── F0Predictor │ ├── CrepeF0Predictor.py │ ├── DioF0Predictor.py │ ├── F0Predictor.py │ ├── FCPEF0Predictor.py │ ├── HarvestF0Predictor.py │ ├── PMF0Predictor.py │ ├── RMVPEF0Predictor.py │ ├── __init__.py │ ├── crepe.py │ ├── fcpe │ │ ├── __init__.py │ │ ├── model.py │ │ ├── nvSTFT.py │ │ └── pcmer.py │ └── rmvpe │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── deepunet.py │ │ ├── inference.py │ │ ├── model.py │ │ ├── seq.py │ │ ├── spec.py │ │ └── utils.py ├── __init__.py ├── attentions.py ├── commons.py ├── enhancer.py ├── losses.py ├── mel_processing.py └── modules.py ├── onnx_export.py ├── onnx_export_old.py ├── onnxexport ├── model_onnx.py └── model_onnx_speaker_mix.py ├── preprocess_flist_config.py ├── preprocess_hubert_f0.py ├── pretrain ├── __init__.py ├── meta.py ├── nsf_hifigan │ └── put_nsf_hifigan_ckpt_here └── put_hubert_ckpt_here ├── raw └── put_raw_wav_here ├── requirements.txt ├── requirements_onnx_encoder.txt ├── requirements_win.txt ├── resample.py ├── shadowdiffusion.png ├── sovits4_for_colab.ipynb ├── spkmix.py ├── train.py ├── train_diff.py ├── train_index.py ├── trained └── put_trained_checkpoints_here ├── utils.py ├── vdecoder ├── __init__.py ├── hifigan │ ├── env.py │ ├── models.py │ ├── nvSTFT.py │ └── utils.py ├── hifiganwithsnake │ ├── alias │ │ ├── __init__.py │ │ ├── act.py │ │ ├── filter.py │ │ └── resample.py │ ├── env.py │ ├── models.py │ ├── nvSTFT.py │ └── utils.py └── nsf_hifigan │ ├── env.py │ ├── models.py │ ├── nvSTFT.py │ └── utils.py ├── vencoder ├── CNHubertLarge.py ├── ContentVec256L12_Onnx.py ├── ContentVec256L9.py ├── ContentVec256L9_Onnx.py ├── ContentVec768L12.py ├── ContentVec768L12_Onnx.py ├── ContentVec768L9_Onnx.py ├── DPHubert.py ├── HubertSoft.py ├── HubertSoft_Onnx.py ├── WavLMBasePlus.py ├── WhisperPPG.py ├── WhisperPPGLarge.py ├── __init__.py ├── dphubert │ ├── __init__.py │ ├── components.py │ ├── hardconcrete.py │ ├── model.py │ ├── pruning_utils.py │ └── utils │ │ ├── __init__.py │ │ └── import_huggingface_wavlm.py ├── encoder.py ├── hubert │ ├── __init__.py │ ├── hubert_model.py │ └── hubert_model_onnx.py ├── wavlm │ ├── WavLM.py │ └── modules.py └── whisper │ ├── __init__.py │ ├── audio.py │ ├── decoding.py │ ├── 
model.py │ ├── tokenizer.py │ └── utils.py ├── wav_upload.py └── webUI.py /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/ask_for_help.yaml: -------------------------------------------------------------------------------- 1 | name: 请求帮助 2 | description: 遇到了无法自行解决的错误 3 | title: '[Help]: ' 4 | labels: [ "help wanted" ] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | #### 提问前请先自己去尝试解决,比如查看[本仓库wiki](https://github.com/svc-develop-team/so-vits-svc/wiki),也可以借助chatgpt或一些搜索引擎(谷歌/必应/New Bing/StackOverflow等等)。如果实在无法自己解决再发issue,在提issue之前,请先了解《[提问的智慧](https://github.com/ryanhanwu/How-To-Ask-Questions-The-Smart-Way/blob/main/README-zh_CN.md)》。 11 | --- 12 | ### 什么样的issue会被直接close 13 | 1. 伸手党 14 | 2. 一键包/环境包相关 15 | 3. 提供的信息不全 16 | 4. 低级的如缺少依赖而导致无法运行的问题 17 | 4. 所用的数据集是无授权数据集(游戏角色/二次元人物暂不归为此类,但是训练时候也要小心谨慎。如果能联系到官方,必须先和官方联系并核实清楚) 18 | --- 19 | 20 | - type: checkboxes 21 | id: Clause 22 | attributes: 23 | label: 请勾选下方的确认框。 24 | options: 25 | - label: "我已仔细阅读[README.md](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/README_zh_CN.md)和[wiki中的Quick solution](https://github.com/svc-develop-team/so-vits-svc/wiki/Quick-solution)。" 26 | required: true 27 | - label: "我已通过各种搜索引擎排查问题,我要提出的问题并不常见。" 28 | required: true 29 | - label: "我未在使用由第三方用户提供的一键包/环境包。" 30 | required: true 31 | 32 | - type: markdown 33 | attributes: 34 | value: | 35 | # 请根据实际使用环境填写以下信息 36 | 37 | - type: input 38 | id: System 39 | attributes: 40 | label: 系统平台版本号 41 | description: Windows执行`winver` | Linux执行`uname -a` 42 | validations: 43 | required: true 44 | 45 | - type: input 46 | id: GPU 47 | attributes: 48 | label: GPU 型号 49 | description: 执行`nvidia-smi` 50 | validations: 51 | required: true 52 | 53 | - type: input 54 | id: PythonVersion 55 | attributes: 56 | label: Python版本 57 | description: 执行`python -V` 58 | validations: 59 | required: true 60 | 61 | - type: input 62 | id: PyTorchVersion 63 | attributes: 64 | label: PyTorch版本 65 | description: 执行`pip show torch` 66 | validations: 67 | required: true 68 | 69 | - type: dropdown 70 | id: Branch 71 | attributes: 72 | label: sovits分支 73 | options: 74 | - 4.0(默认) 75 | - 4.0-v2 76 | - 3.0-32k 77 | - 3.0-48k 78 | validations: 79 | required: true 80 | 81 | - type: input 82 | id: DatasetSource 83 | attributes: 84 | label: 数据集来源(用于判断数据集质量) 85 | description: 如:UVR处理过的vtb直播音频、录音棚录制 86 | validations: 87 | required: true 88 | 89 | - type: input 90 | id: WhereOccurs 91 | attributes: 92 | label: 出现问题的环节或执行的命令 93 | description: 如:预处理、训练、`python preprocess_hubert_f0.py` 94 | validations: 95 | required: true 96 | 97 | - type: textarea 98 | id: Description 99 | attributes: 100 | label: 问题描述 101 | description: 在这里描述自己的问题,越详细越好 102 | validations: 103 | required: true 104 | 105 | - type: textarea 106 | id: Log 107 | attributes: 108 | label: 日志 109 | description: 将从执行命令到执行完毕输出的所有信息(包括你所执行的命令)粘贴到[pastebin.com](https://pastebin.com/)并把剪贴板链接贴到这里,日志量少的话也可以直接贴在下面 110 | render: python 111 | validations: 112 | required: true 113 | 114 | - type: textarea 115 | id: ValidOneClick 116 | attributes: 117 | label: 截图`so-vits-svc`、`logs/44k`文件夹并粘贴到此处 118 | validations: 119 | required: true 120 | 121 | - type: textarea 122 | id: Supplementary 123 | attributes: 124 | label: 补充说明 125 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/ask_for_help_en_US.yaml: 
-------------------------------------------------------------------------------- 1 | name: Ask for help 2 | description: Encountered an error cannot be resolved by self 3 | title: '[Help]: ' 4 | labels: [ "help wanted" ] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | #### Please try to solve the problem yourself before asking for help. At first you can read *[repo wiki](https://github.com/svc-develop-team/so-vits-svc/wiki)*. Then you can use chatgpt or some search engines like google, bing, new bing and StackOverflow until you really find that you can't solve it by yourself. And before you raise an issue, please understand *[How To Ask Questions The Smart Way](http://www.catb.org/~esr/faqs/smart-questions.html)* in advance. 11 | --- 12 | ### What kind of issue will be closed immediately 13 | 1. Beggars or Free Riders 14 | 2. One click package / Environment package (Not using `pip install -r requirement.txt`) 15 | 3. Incomplete information 16 | 4. Stupid issues such as miss a dependency package 17 | 4. Using unlicenced dataset (Game characters / anime characters are not included in this category temporarily but you still need to pay attention. If you can contact the official, you must contact the official and verify it at first.) 18 | --- 19 | 20 | - type: checkboxes 21 | id: Clause 22 | attributes: 23 | label: Please check the checkboxes below. 24 | options: 25 | - label: "I have read *[README.md](https://github.com/svc-develop-team/so-vits-svc/blob/4.0/README.md)* and *[Quick solution in wiki](https://github.com/svc-develop-team/so-vits-svc/wiki/Quick-solution)* carefully." 26 | required: true 27 | - label: "I have been troubleshooting issues through various search engines. The questions I want to ask are not common." 28 | required: true 29 | - label: "I am NOT using one click package / environment package." 
30 | required: true 31 | 32 | - type: markdown 33 | attributes: 34 | value: | 35 | # Please fill in the following information according to your actual environment 36 | 37 | - type: input 38 | id: System 39 | attributes: 40 | label: OS version 41 | description: Windows run `winver` | Linux run `uname -a` 42 | validations: 43 | required: true 44 | 45 | - type: input 46 | id: GPU 47 | attributes: 48 | label: GPU 49 | description: Run `nvidia-smi` 50 | validations: 51 | required: true 52 | 53 | - type: input 54 | id: PythonVersion 55 | attributes: 56 | label: Python version 57 | description: Run `python -V` 58 | validations: 59 | required: true 60 | 61 | - type: input 62 | id: PyTorchVersion 63 | attributes: 64 | label: PyTorch version 65 | description: Run `pip show torch` 66 | validations: 67 | required: true 68 | 69 | - type: dropdown 70 | id: Branch 71 | attributes: 72 | label: Branch of sovits 73 | options: 74 | - 4.0(Default) 75 | - 4.0-v2 76 | - 3.0-32k 77 | - 3.0-48k 78 | validations: 79 | required: true 80 | 81 | - type: input 82 | id: DatasetSource 83 | attributes: 84 | label: Dataset source (Used to judge the dataset quality) 85 | description: Such as UVR-processed streaming audio / Recorded in recording studio 86 | validations: 87 | required: true 88 | 89 | - type: input 90 | id: WhereOccurs 91 | attributes: 92 | label: Where thr problem occurs or what command you executed 93 | description: Such as Preprocessing / Training / `python preprocess_hubert_f0.py` 94 | validations: 95 | required: true 96 | 97 | - type: textarea 98 | id: Description 99 | attributes: 100 | label: Problem description 101 | description: Describe your problem here, the more detailed the better. 102 | validations: 103 | required: true 104 | 105 | - type: textarea 106 | id: Log 107 | attributes: 108 | label: Log 109 | description: All information output from the command you executed to the end of execution (include the command). It can also be directly posted below if there is only few text. 110 | render: python 111 | validations: 112 | required: true 113 | 114 | - type: textarea 115 | id: ValidOneClick 116 | attributes: 117 | label: Screenshot `so-vits-svc` and `logs/44k` folders and paste here 118 | validations: 119 | required: true 120 | 121 | - type: textarea 122 | id: Supplementary 123 | attributes: 124 | label: Supplementary description 125 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: 问题回报 2 | description: 遇到了BUG?! 3 | title: '[Bug]: ' 4 | labels: [ "bug?" 
] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | # 请根据实际使用环境填写以下信息 11 | 12 | - type: input 13 | id: System 14 | attributes: 15 | label: 系统平台版本号 16 | description: Windows执行`winver` | Linux执行`uname -a` 17 | validations: 18 | required: true 19 | 20 | - type: input 21 | id: GPU 22 | attributes: 23 | label: GPU 型号 24 | description: 执行`nvidia-smi` 25 | validations: 26 | required: true 27 | 28 | - type: input 29 | id: PythonVersion 30 | attributes: 31 | label: Python版本 32 | description: 执行`python -V` 33 | validations: 34 | required: true 35 | 36 | - type: input 37 | id: PyTorchVersion 38 | attributes: 39 | label: PyTorch版本 40 | description: 执行`pip show torch` 41 | validations: 42 | required: true 43 | 44 | - type: dropdown 45 | id: Branch 46 | attributes: 47 | label: sovits分支 48 | options: 49 | - 4.0(默认) 50 | - 4.0-v2 51 | - 3.0-32k 52 | - 3.0-48k 53 | validations: 54 | required: true 55 | 56 | - type: input 57 | id: DatasetSource 58 | attributes: 59 | label: 数据集来源(用于判断数据集质量) 60 | description: 如:UVR处理过的vtb直播音频、录音棚录制 61 | validations: 62 | required: true 63 | 64 | - type: input 65 | id: WhereOccurs 66 | attributes: 67 | label: 出现问题的环节或执行的命令 68 | description: 如:预处理、训练、`python preprocess_hubert_f0.py` 69 | validations: 70 | required: true 71 | 72 | - type: textarea 73 | id: Description 74 | attributes: 75 | label: 情况描述 76 | description: 在这里描述遇到的情况,越详细越好 77 | validations: 78 | required: true 79 | 80 | - type: textarea 81 | id: Log 82 | attributes: 83 | label: 日志 84 | description: 将从执行命令到执行完毕输出的所有信息(包括你所执行的命令)粘贴到[pastebin.com](https://pastebin.com/)并把剪贴板链接贴到这里,日志量少的话也可以直接贴在下面 85 | render: python 86 | validations: 87 | required: true 88 | 89 | - type: textarea 90 | id: Supplementary 91 | attributes: 92 | label: 补充说明 93 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report_en_US.yaml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Encountered an bug?! 3 | title: '[Bug]: ' 4 | labels: [ "bug?" 
] 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: | 10 | # Please fill in the following information according to your actual environment 11 | 12 | - type: input 13 | id: System 14 | attributes: 15 | label: OS version 16 | description: Windows run `winver` | Linux run `uname -a` 17 | validations: 18 | required: true 19 | 20 | - type: input 21 | id: GPU 22 | attributes: 23 | label: GPU 24 | description: Run `nvidia-smi` 25 | validations: 26 | required: true 27 | 28 | - type: input 29 | id: PythonVersion 30 | attributes: 31 | label: Python version 32 | description: Run `python -V` 33 | validations: 34 | required: true 35 | 36 | - type: input 37 | id: PyTorchVersion 38 | attributes: 39 | label: PyTorch version 40 | description: Run `pip show torch` 41 | validations: 42 | required: true 43 | 44 | - type: dropdown 45 | id: Branch 46 | attributes: 47 | label: Branch of sovits 48 | options: 49 | - 4.0(Default) 50 | - 4.0-v2 51 | - 3.0-32k 52 | - 3.0-48k 53 | validations: 54 | required: true 55 | 56 | - type: input 57 | id: DatasetSource 58 | attributes: 59 | label: Dataset source (Used to judge the dataset quality) 60 | description: Such as UVR-processed streaming audio / Recorded in recording studio 61 | validations: 62 | required: true 63 | 64 | - type: input 65 | id: WhereOccurs 66 | attributes: 67 | label: Where thr problem occurs or what command you executed 68 | description: Such as Preprocessing / Training / `python preprocess_hubert_f0.py` 69 | validations: 70 | required: true 71 | 72 | - type: textarea 73 | id: Description 74 | attributes: 75 | label: Situation description 76 | description: Describe your situation here, the more detailed the better. 77 | validations: 78 | required: true 79 | 80 | - type: textarea 81 | id: Log 82 | attributes: 83 | label: Log 84 | description: All information output from the command you executed to the end of execution (include the command). You can paste them to [pastebin.com](https://pastebin.com/) then paste the short link here. It can also be directly posted below if there is only few text. 85 | render: python 86 | validations: 87 | required: true 88 | 89 | - type: textarea 90 | id: Supplementary 91 | attributes: 92 | label: Supplementary description 93 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: 讨论区 / Discussions 4 | url: https://github.com/svc-develop-team/so-vits-svc/discussions 5 | about: 简单的询问/讨论请转至讨论区或发起一个低优先级的Default issue / For simple inquiries / discussions, please go to the discussions or raise a low priority Default issue 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/default.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Default issue 3 | about: 如果模板中没有你想发起的issue类型,可以选择此项,但这个issue也许会获得一个较低的处理优先级 / If there is no issue type you want to raise, you can start with this one. But this issue maybe will get a lower priority to deal with. 
4 | title: '' 5 | labels: 'not urgent' 6 | assignees: '' 7 | --- 8 | -------------------------------------------------------------------------------- /.github/workflows/reviewdog.yml: -------------------------------------------------------------------------------- 1 | name: Ruff Autofix 2 | on: [pull_request] 3 | jobs: 4 | ruff: 5 | permissions: 6 | checks: write 7 | contents: read 8 | pull-requests: write 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: chartboost/ruff-action@v1 13 | with: 14 | args: --fix -e 15 | - uses: reviewdog/action-suggester@v1 16 | with: 17 | tool_name: ruff 18 | 19 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [push, pull_request] 3 | jobs: 4 | ruff: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v3 8 | - uses: chartboost/ruff-action@v1 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | checkpoints/ 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | pytestdebug.log 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | doc/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # End of https://www.toptal.com/developers/gitignore/api/python 142 | 143 | /shelf/ 144 | /workspace.xml 145 | 146 | dataset 147 | dataset_raw 148 | raw 149 | results 150 | inference/chunks_temp.json 151 | logs 152 | hubert/checkpoint_best_legacy_500.pt 153 | configs/config.json 154 | filelists/test.txt 155 | filelists/train.txt 156 | filelists/val.txt 157 | .idea/ 158 | .vscode/ 159 | .idea/modules.xml 160 | .idea/so-vits-svc.iml 161 | .idea/vcs.xml 162 | .idea/inspectionProfiles/profiles_settings.xml 163 | .idea/inspectionProfiles/Project_Default.xml 164 | pretrain/ 165 | .vscode/launch.json 166 | 167 | trained/**/ 168 | -------------------------------------------------------------------------------- /.ruff.toml: -------------------------------------------------------------------------------- 1 | select = ["E", "F", "I"] 2 | 3 | # Never enforce `E501` (line length violations). 4 | ignore = ["E501", "E741"] 5 | -------------------------------------------------------------------------------- /cluster/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.cluster import KMeans 3 | 4 | 5 | def get_cluster_model(ckpt_path): 6 | checkpoint = torch.load(ckpt_path) 7 | kmeans_dict = {} 8 | for spk, ckpt in checkpoint.items(): 9 | km = KMeans(ckpt["n_features_in_"]) 10 | km.__dict__["n_features_in_"] = ckpt["n_features_in_"] 11 | km.__dict__["_n_threads"] = ckpt["_n_threads"] 12 | km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"] 13 | kmeans_dict[spk] = km 14 | return kmeans_dict 15 | 16 | def get_cluster_result(model, x, speaker): 17 | """ 18 | x: np.array [t, 256] 19 | return cluster class result 20 | """ 21 | return model[speaker].predict(x) 22 | 23 | def get_cluster_center_result(model, x,speaker): 24 | """x: np.array [t, 256]""" 25 | predict = model[speaker].predict(x) 26 | return model[speaker].cluster_centers_[predict] 27 | 28 | def get_center(model, x,speaker): 29 | return model[speaker].cluster_centers_[x] 30 | -------------------------------------------------------------------------------- /cluster/train_cluster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | import tqdm 10 | from kmeans import KMeansGPU 11 | from sklearn.cluster import KMeans, MiniBatchKMeans 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | def train_cluster(in_dir, n_clusters, use_minibatch=True, verbose=False,use_gpu=False):#gpu_minibatch真拉,虽然库支持但是也不考虑 17 | if str(in_dir).endswith(".ipynb_checkpoints"): 18 | logger.info(f"Ignore {in_dir}") 19 | 20 | logger.info(f"Loading features from {in_dir}") 21 | features = [] 22 | nums = 0 23 | for path in 
tqdm.tqdm(in_dir.glob("*.soft.pt")): 24 | # for name in os.listdir(in_dir): 25 | # path="%s/%s"%(in_dir,name) 26 | features.append(torch.load(path,map_location="cpu").squeeze(0).numpy().T) 27 | # print(features[-1].shape) 28 | features = np.concatenate(features, axis=0) 29 | print(nums, features.nbytes/ 1024**2, "MB , shape:",features.shape, features.dtype) 30 | features = features.astype(np.float32) 31 | logger.info(f"Clustering features of shape: {features.shape}") 32 | t = time.time() 33 | if(use_gpu is False): 34 | if use_minibatch: 35 | kmeans = MiniBatchKMeans(n_clusters=n_clusters,verbose=verbose, batch_size=4096, max_iter=80).fit(features) 36 | else: 37 | kmeans = KMeans(n_clusters=n_clusters,verbose=verbose).fit(features) 38 | else: 39 | kmeans = KMeansGPU(n_clusters=n_clusters, mode='euclidean', verbose=2 if verbose else 0,max_iter=500,tol=1e-2)# 40 | features=torch.from_numpy(features)#.to(device) 41 | kmeans.fit_predict(features)# 42 | 43 | print(time.time()-t, "s") 44 | 45 | x = { 46 | "n_features_in_": kmeans.n_features_in_ if use_gpu is False else features.shape[1], 47 | "_n_threads": kmeans._n_threads if use_gpu is False else 4, 48 | "cluster_centers_": kmeans.cluster_centers_ if use_gpu is False else kmeans.centroids.cpu().numpy(), 49 | } 50 | print("end") 51 | 52 | return x 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--dataset', type=Path, default="./dataset/44k", 57 | help='path of training data directory') 58 | parser.add_argument('--output', type=Path, default="logs/44k", 59 | help='path of model output directory') 60 | parser.add_argument('--gpu',action='store_true', default=False , 61 | help='to use GPU') 62 | 63 | 64 | args = parser.parse_args() 65 | 66 | checkpoint_dir = args.output 67 | dataset = args.dataset 68 | use_gpu = args.gpu 69 | n_clusters = 10000 70 | 71 | ckpt = {} 72 | for spk in os.listdir(dataset): 73 | if os.path.isdir(dataset/spk): 74 | print(f"train kmeans for {spk}...") 75 | in_dir = dataset/spk 76 | x = train_cluster(in_dir, n_clusters,use_minibatch=False,verbose=False,use_gpu=use_gpu) 77 | ckpt[spk] = x 78 | 79 | checkpoint_path = checkpoint_dir / f"kmeans_{n_clusters}.pt" 80 | checkpoint_path.parent.mkdir(exist_ok=True, parents=True) 81 | torch.save( 82 | ckpt, 83 | checkpoint_path, 84 | ) 85 | 86 | -------------------------------------------------------------------------------- /compress_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | import utils 6 | from models import SynthesizerTrn 7 | 8 | 9 | def copyStateDict(state_dict): 10 | if list(state_dict.keys())[0].startswith('module'): 11 | start_idx = 1 12 | else: 13 | start_idx = 0 14 | new_state_dict = OrderedDict() 15 | for k, v in state_dict.items(): 16 | name = ','.join(k.split('.')[start_idx:]) 17 | new_state_dict[name] = v 18 | return new_state_dict 19 | 20 | 21 | def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str): 22 | hps = utils.get_hparams_from_file(config) 23 | 24 | net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1, 25 | hps.train.segment_size // hps.data.hop_length, 26 | **hps.model) 27 | 28 | optim_g = torch.optim.AdamW(net_g.parameters(), 29 | hps.train.learning_rate, 30 | betas=hps.train.betas, 31 | eps=hps.train.eps) 32 | 33 | state_dict_g = torch.load(input_model, map_location="cpu") 34 | new_dict_g = copyStateDict(state_dict_g) 35 | keys = [] 36 | for k, v in 
new_dict_g['model'].items(): 37 | if "enc_q" in k: continue # noqa: E701 38 | keys.append(k) 39 | 40 | new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys} 41 | 42 | torch.save( 43 | { 44 | 'model': new_dict_g, 45 | 'iteration': 0, 46 | 'optimizer': optim_g.state_dict(), 47 | 'learning_rate': 0.0001 48 | }, output_model) 49 | 50 | 51 | if __name__ == "__main__": 52 | import argparse 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("-c", 55 | "--config", 56 | type=str, 57 | default='configs/config.json') 58 | parser.add_argument("-i", "--input", type=str) 59 | parser.add_argument("-o", "--output", type=str, default=None) 60 | parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16') 61 | 62 | args = parser.parse_args() 63 | 64 | output = args.output 65 | 66 | if output is None: 67 | import os.path 68 | filename, ext = os.path.splitext(args.input) 69 | half = "_half" if args.half else "" 70 | output = filename + "_release" + half + ext 71 | 72 | removeOptimizer(args.config, args.input, args.half, output) -------------------------------------------------------------------------------- /configs/config.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/configs/config.json -------------------------------------------------------------------------------- /configs/diffusion.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/configs/diffusion.yaml -------------------------------------------------------------------------------- /configs_template/config_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 800, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 0.0001, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-09, 13 | "batch_size": 6, 14 | "fp16_run": false, 15 | "half_type": "fp16", 16 | "lr_decay": 0.999875, 17 | "segment_size": 10240, 18 | "init_lr_ratio": 1, 19 | "warmup_epochs": 0, 20 | "c_mel": 45, 21 | "c_kl": 1.0, 22 | "use_sr": true, 23 | "max_speclen": 512, 24 | "port": "8001", 25 | "keep_ckpts": 3, 26 | "all_in_mem": false, 27 | "vol_aug":false 28 | }, 29 | "data": { 30 | "training_files": "filelists/train.txt", 31 | "validation_files": "filelists/val.txt", 32 | "max_wav_value": 32768.0, 33 | "sampling_rate": 44100, 34 | "filter_length": 2048, 35 | "hop_length": 512, 36 | "win_length": 2048, 37 | "n_mel_channels": 80, 38 | "mel_fmin": 0.0, 39 | "mel_fmax": 22050, 40 | "unit_interpolate_mode":"nearest" 41 | }, 42 | "model": { 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 768, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [ 8, 8, 2, 2, 2], 54 | "upsample_initial_channel": 512, 55 | "upsample_kernel_sizes": [16,16, 4, 4, 4], 56 | "n_layers_q": 3, 57 | "n_layers_trans_flow": 3, 58 | "n_flow_layer": 4, 59 | "use_spectral_norm": false, 60 | "gin_channels": 768, 61 | "ssl_dim": 768, 62 | "n_speakers": 200, 63 | "vocoder_name":"nsf-hifigan", 64 | 
"speech_encoder":"vec768l12", 65 | "speaker_embedding":false, 66 | "vol_embedding":false, 67 | "use_depthwise_conv":false, 68 | "flow_share_parameter": false, 69 | "use_automatic_f0_prediction": true, 70 | "use_transformer_flow": false 71 | }, 72 | "spk": { 73 | "nyaru": 0, 74 | "huiyu": 1, 75 | "nen": 2, 76 | "paimon": 3, 77 | "yunhao": 4 78 | } 79 | } -------------------------------------------------------------------------------- /configs_template/config_tiny_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 800, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 0.0001, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-09, 13 | "batch_size": 6, 14 | "fp16_run": false, 15 | "half_type": "fp16", 16 | "lr_decay": 0.999875, 17 | "segment_size": 10240, 18 | "init_lr_ratio": 1, 19 | "warmup_epochs": 0, 20 | "c_mel": 45, 21 | "c_kl": 1.0, 22 | "use_sr": true, 23 | "max_speclen": 512, 24 | "port": "8001", 25 | "keep_ckpts": 3, 26 | "all_in_mem": false, 27 | "vol_aug":false 28 | }, 29 | "data": { 30 | "training_files": "filelists/train.txt", 31 | "validation_files": "filelists/val.txt", 32 | "max_wav_value": 32768.0, 33 | "sampling_rate": 44100, 34 | "filter_length": 2048, 35 | "hop_length": 512, 36 | "win_length": 2048, 37 | "n_mel_channels": 80, 38 | "mel_fmin": 0.0, 39 | "mel_fmax": 22050, 40 | "unit_interpolate_mode":"nearest" 41 | }, 42 | "model": { 43 | "inter_channels": 192, 44 | "hidden_channels": 192, 45 | "filter_channels": 512, 46 | "n_heads": 2, 47 | "n_layers": 6, 48 | "kernel_size": 3, 49 | "p_dropout": 0.1, 50 | "resblock": "1", 51 | "resblock_kernel_sizes": [3,7,11], 52 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 53 | "upsample_rates": [ 8, 8, 2, 2, 2], 54 | "upsample_initial_channel": 400, 55 | "upsample_kernel_sizes": [16,16, 4, 4, 4], 56 | "n_layers_q": 3, 57 | "n_layers_trans_flow": 3, 58 | "n_flow_layer": 4, 59 | "use_spectral_norm": false, 60 | "gin_channels": 768, 61 | "ssl_dim": 768, 62 | "n_speakers": 200, 63 | "vocoder_name":"nsf-hifigan", 64 | "speech_encoder":"vec768l12", 65 | "speaker_embedding":false, 66 | "vol_embedding":false, 67 | "use_depthwise_conv":true, 68 | "flow_share_parameter": true, 69 | "use_automatic_f0_prediction": true, 70 | "use_transformer_flow": false 71 | }, 72 | "spk": { 73 | "nyaru": 0, 74 | "huiyu": 1, 75 | "nen": 2, 76 | "paimon": 3, 77 | "yunhao": 4 78 | } 79 | } -------------------------------------------------------------------------------- /configs_template/diffusion_template.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | sampling_rate: 44100 3 | block_size: 512 # Equal to hop_length 4 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip 5 | encoder: 'vec768l12' # 'hubertsoft', 'vec256l9', 'vec768l12' 6 | cnhubertsoft_gate: 10 7 | encoder_sample_rate: 16000 8 | encoder_hop_size: 320 9 | encoder_out_channels: 768 # 256 if using 'hubertsoft' 10 | training_files: "filelists/train.txt" 11 | validation_files: "filelists/val.txt" 12 | extensions: # List of extension included in the data collection 13 | - wav 14 | unit_interpolate_mode: "nearest" 15 | model: 16 | type: 'Diffusion' 17 | n_layers: 20 18 | n_chans: 512 19 | n_hidden: 256 20 | use_pitch_aug: true 21 | timesteps : 1000 22 | k_step_max: 0 # must <= timesteps, If it is 0, train all 23 | n_spk: 1 # max number of different speakers 24 | 
device: cuda 25 | vocoder: 26 | type: 'nsf-hifigan' 27 | ckpt: 'pretrain/nsf_hifigan/model' 28 | infer: 29 | speedup: 10 30 | method: 'dpm-solver++' # 'pndm' or 'dpm-solver' or 'ddim' or 'unipc' or 'dpm-solver++' 31 | env: 32 | expdir: logs/44k/diffusion 33 | gpu_id: 0 34 | train: 35 | num_workers: 4 # If your cpu and gpu are both very strong, set to 0 may be faster! 36 | amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu) 37 | batch_size: 48 38 | cache_all_data: true # Save Internal-Memory or Graphics-Memory if it is false, but may be slow 39 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu 40 | cache_fp16: true 41 | epochs: 100000 42 | interval_log: 10 43 | interval_val: 2000 44 | interval_force_save: 5000 45 | lr: 0.0001 46 | decay_step: 100000 47 | gamma: 0.5 48 | weight_decay: 0 49 | save_opt: false 50 | spk: 51 | 'nyaru': 0 -------------------------------------------------------------------------------- /dataset_raw/wav_structure.txt: -------------------------------------------------------------------------------- 1 | 数据集准备 2 | 3 | raw 4 | ├───speaker0 5 | │ ├───xxx1-xxx1.wav 6 | │ ├───... 7 | │ └───Lxx-0xx8.wav 8 | └───speaker1 9 | ├───xx2-0xxx2.wav 10 | ├───... 11 | └───xxx7-xxx007.wav 12 | 13 | 此外还需要编辑config.json 14 | 15 | "n_speakers": 10 16 | 17 | "spk":{ 18 | "speaker0": 0, 19 | "speaker1": 1, 20 | } 21 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/diffusion/__init__.py -------------------------------------------------------------------------------- /diffusion/how to export onnx.md: -------------------------------------------------------------------------------- 1 | - Open [onnx_export](onnx_export.py) 2 | - project_name = "dddsp" change "project_name" to your project name 3 | - model_path = f'{project_name}/model_500000.pt' change "model_path" to your model path 4 | - Run -------------------------------------------------------------------------------- /diffusion/infer_gt_mel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from diffusion.unit2mel import load_model_vocoder 5 | 6 | 7 | class DiffGtMel: 8 | def __init__(self, project_path=None, device=None): 9 | self.project_path = project_path 10 | if device is not None: 11 | self.device = device 12 | else: 13 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | self.model = None 15 | self.vocoder = None 16 | self.args = None 17 | 18 | def flush_model(self, project_path, ddsp_config=None): 19 | if (self.model is None) or (project_path != self.project_path): 20 | model, vocoder, args = load_model_vocoder(project_path, device=self.device) 21 | if self.check_args(ddsp_config, args): 22 | self.model = model 23 | self.vocoder = vocoder 24 | self.args = args 25 | 26 | def check_args(self, args1, args2): 27 | if args1.data.block_size != args2.data.block_size: 28 | raise ValueError("DDSP与DIFF模型的block_size不一致") 29 | if args1.data.sampling_rate != args2.data.sampling_rate: 30 | raise ValueError("DDSP与DIFF模型的sampling_rate不一致") 31 | if args1.data.encoder != args2.data.encoder: 32 | raise ValueError("DDSP与DIFF模型的encoder不一致") 33 | return True 34 | 35 | def __call__(self, audio, f0, 
hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', 36 | spk_mix_dict=None, start_frame=0): 37 | input_mel = self.vocoder.extract(audio, self.args.data.sampling_rate) 38 | out_mel = self.model( 39 | hubert, 40 | f0, 41 | volume, 42 | spk_id=spk_id, 43 | spk_mix_dict=spk_mix_dict, 44 | gt_spec=input_mel, 45 | infer=True, 46 | infer_speedup=acc, 47 | method=method, 48 | k_step=k_step, 49 | use_tqdm=False) 50 | if start_frame > 0: 51 | out_mel = out_mel[:, start_frame:, :] 52 | f0 = f0[:, start_frame:, :] 53 | output = self.vocoder.infer(out_mel, f0) 54 | if start_frame > 0: 55 | output = F.pad(output, (start_frame * self.vocoder.vocoder_hop_size, 0)) 56 | return output 57 | 58 | def infer(self, audio, f0, hubert, volume, acc=1, spk_id=1, k_step=0, method='pndm', silence_front=0, 59 | use_silence=False, spk_mix_dict=None): 60 | start_frame = int(silence_front * self.vocoder.vocoder_sample_rate / self.vocoder.vocoder_hop_size) 61 | if use_silence: 62 | audio = audio[:, start_frame * self.vocoder.vocoder_hop_size:] 63 | f0 = f0[:, start_frame:, :] 64 | hubert = hubert[:, start_frame:, :] 65 | volume = volume[:, start_frame:, :] 66 | _start_frame = 0 67 | else: 68 | _start_frame = start_frame 69 | audio = self.__call__(audio, f0, hubert, volume, acc=acc, spk_id=spk_id, k_step=k_step, 70 | method=method, spk_mix_dict=spk_mix_dict, start_frame=_start_frame) 71 | if use_silence: 72 | if start_frame > 0: 73 | audio = F.pad(audio, (start_frame * self.vocoder.vocoder_hop_size, 0)) 74 | return audio 75 | -------------------------------------------------------------------------------- /diffusion/logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/diffusion/logger/__init__.py -------------------------------------------------------------------------------- /diffusion/logger/saver.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: wayn391@mastertones 3 | ''' 4 | 5 | import datetime 6 | import os 7 | import time 8 | 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import yaml 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | 15 | class Saver(object): 16 | def __init__( 17 | self, 18 | args, 19 | initial_global_step=-1): 20 | 21 | self.expdir = args.env.expdir 22 | self.sample_rate = args.data.sampling_rate 23 | 24 | # cold start 25 | self.global_step = initial_global_step 26 | self.init_time = time.time() 27 | self.last_time = time.time() 28 | 29 | # makedirs 30 | os.makedirs(self.expdir, exist_ok=True) 31 | 32 | # path 33 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt') 34 | 35 | # ckpt 36 | os.makedirs(self.expdir, exist_ok=True) 37 | 38 | # writer 39 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) 40 | 41 | # save config 42 | path_config = os.path.join(self.expdir, 'config.yaml') 43 | with open(path_config, "w") as out_config: 44 | yaml.dump(dict(args), out_config) 45 | 46 | 47 | def log_info(self, msg): 48 | '''log method''' 49 | if isinstance(msg, dict): 50 | msg_list = [] 51 | for k, v in msg.items(): 52 | tmp_str = '' 53 | if isinstance(v, int): 54 | tmp_str = '{}: {:,}'.format(k, v) 55 | else: 56 | tmp_str = '{}: {}'.format(k, v) 57 | 58 | msg_list.append(tmp_str) 59 | msg_str = '\n'.join(msg_list) 60 | else: 61 | msg_str = msg 62 | 63 | # dsplay 64 | print(msg_str) 65 | 66 | # save 67 | with open(self.path_log_info, 'a') as fp: 68 | 
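# append the same message to log_info.txt in the experiment directory,
# so everything printed to the console above is also kept on disk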
fp.write(msg_str+'\n') 69 | 70 | def log_value(self, dict): 71 | for k, v in dict.items(): 72 | self.writer.add_scalar(k, v, self.global_step) 73 | 74 | def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): 75 | spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) 76 | spec = spec_cat[0] 77 | if isinstance(spec, torch.Tensor): 78 | spec = spec.cpu().numpy() 79 | fig = plt.figure(figsize=(12, 9)) 80 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 81 | plt.tight_layout() 82 | self.writer.add_figure(name, fig, self.global_step) 83 | 84 | def log_audio(self, dict): 85 | for k, v in dict.items(): 86 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) 87 | 88 | def get_interval_time(self, update=True): 89 | cur_time = time.time() 90 | time_interval = cur_time - self.last_time 91 | if update: 92 | self.last_time = cur_time 93 | return time_interval 94 | 95 | def get_total_time(self, to_str=True): 96 | total_time = time.time() - self.init_time 97 | if to_str: 98 | total_time = str(datetime.timedelta( 99 | seconds=total_time))[:-5] 100 | return total_time 101 | 102 | def save_model( 103 | self, 104 | model, 105 | optimizer, 106 | name='model', 107 | postfix='', 108 | to_json=False): 109 | # path 110 | if postfix: 111 | postfix = '_' + postfix 112 | path_pt = os.path.join( 113 | self.expdir , name+postfix+'.pt') 114 | 115 | # check 116 | print(' [*] model checkpoint saved: {}'.format(path_pt)) 117 | 118 | # save 119 | if optimizer is not None: 120 | torch.save({ 121 | 'global_step': self.global_step, 122 | 'model': model.state_dict(), 123 | 'optimizer': optimizer.state_dict()}, path_pt) 124 | else: 125 | torch.save({ 126 | 'global_step': self.global_step, 127 | 'model': model.state_dict()}, path_pt) 128 | 129 | 130 | def delete_model(self, name='model', postfix=''): 131 | # path 132 | if postfix: 133 | postfix = '_' + postfix 134 | path_pt = os.path.join( 135 | self.expdir , name+postfix+'.pt') 136 | 137 | # delete 138 | if os.path.exists(path_pt): 139 | os.remove(path_pt) 140 | print(' [*] model checkpoint deleted: {}'.format(path_pt)) 141 | 142 | def global_step_increment(self): 143 | self.global_step += 1 144 | 145 | 146 | -------------------------------------------------------------------------------- /diffusion/logger/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import yaml 6 | 7 | 8 | def traverse_dir( 9 | root_dir, 10 | extensions, 11 | amount=None, 12 | str_include=None, 13 | str_exclude=None, 14 | is_pure=False, 15 | is_sort=False, 16 | is_ext=True): 17 | 18 | file_list = [] 19 | cnt = 0 20 | for root, _, files in os.walk(root_dir): 21 | for file in files: 22 | if any([file.endswith(f".{ext}") for ext in extensions]): 23 | # path 24 | mix_path = os.path.join(root, file) 25 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 26 | 27 | # amount 28 | if (amount is not None) and (cnt == amount): 29 | if is_sort: 30 | file_list.sort() 31 | return file_list 32 | 33 | # check string 34 | if (str_include is not None) and (str_include not in pure_path): 35 | continue 36 | if (str_exclude is not None) and (str_exclude in pure_path): 37 | continue 38 | 39 | if not is_ext: 40 | ext = pure_path.split('.')[-1] 41 | pure_path = pure_path[:-(len(ext)+1)] 42 | file_list.append(pure_path) 43 | cnt += 1 44 | if is_sort: 45 | file_list.sort() 46 | return file_list 47 | 48 | 49 | 50 | class DotDict(dict): 51 | def __getattr__(*args): 52 
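# attribute-style access falls through to a plain dict lookup; nested dicts are
# re-wrapped as DotDict below, so config values can be read as e.g. args.data.sampling_rate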
| val = dict.get(*args) 53 | return DotDict(val) if type(val) is dict else val 54 | 55 | __setattr__ = dict.__setitem__ 56 | __delattr__ = dict.__delitem__ 57 | 58 | 59 | def get_network_paras_amount(model_dict): 60 | info = dict() 61 | for model_name, model in model_dict.items(): 62 | # all_params = sum(p.numel() for p in model.parameters()) 63 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 64 | 65 | info[model_name] = trainable_params 66 | return info 67 | 68 | 69 | def load_config(path_config): 70 | with open(path_config, "r") as config: 71 | args = yaml.safe_load(config) 72 | args = DotDict(args) 73 | # print(args) 74 | return args 75 | 76 | def save_config(path_config,config): 77 | config = dict(config) 78 | with open(path_config, "w") as f: 79 | yaml.dump(config, f) 80 | 81 | def to_json(path_params, path_json): 82 | params = torch.load(path_params, map_location=torch.device('cpu')) 83 | raw_state_dict = {} 84 | for k, v in params.items(): 85 | val = v.flatten().numpy().tolist() 86 | raw_state_dict[k] = val 87 | 88 | with open(path_json, 'w') as outfile: 89 | json.dump(raw_state_dict, outfile,indent= "\t") 90 | 91 | 92 | def convert_tensor_to_numpy(tensor, is_squeeze=True): 93 | if is_squeeze: 94 | tensor = tensor.squeeze() 95 | if tensor.requires_grad: 96 | tensor = tensor.detach() 97 | if tensor.is_cuda: 98 | tensor = tensor.cpu() 99 | return tensor.numpy() 100 | 101 | 102 | def load_model( 103 | expdir, 104 | model, 105 | optimizer, 106 | name='model', 107 | postfix='', 108 | device='cpu'): 109 | if postfix == '': 110 | postfix = '_' + postfix 111 | path = os.path.join(expdir, name+postfix) 112 | path_pt = traverse_dir(expdir, ['pt'], is_ext=False) 113 | global_step = 0 114 | if len(path_pt) > 0: 115 | steps = [s[len(path):] for s in path_pt] 116 | maxstep = max([int(s) if s.isdigit() else 0 for s in steps]) 117 | if maxstep >= 0: 118 | path_pt = path+str(maxstep)+'.pt' 119 | else: 120 | path_pt = path+'best.pt' 121 | print(' [*] restoring model from', path_pt) 122 | ckpt = torch.load(path_pt, map_location=torch.device(device)) 123 | global_step = ckpt['global_step'] 124 | model.load_state_dict(ckpt['model'], strict=False) 125 | if ckpt.get("optimizer") is not None: 126 | optimizer.load_state_dict(ckpt['optimizer']) 127 | return global_step, model, optimizer 128 | -------------------------------------------------------------------------------- /diffusion/vocoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchaudio.transforms import Resample 3 | 4 | from vdecoder.nsf_hifigan.models import load_config, load_model 5 | from vdecoder.nsf_hifigan.nvSTFT import STFT 6 | 7 | 8 | class Vocoder: 9 | def __init__(self, vocoder_type, vocoder_ckpt, device = None): 10 | if device is None: 11 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 12 | self.device = device 13 | 14 | if vocoder_type == 'nsf-hifigan': 15 | self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device) 16 | elif vocoder_type == 'nsf-hifigan-log10': 17 | self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device) 18 | else: 19 | raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") 20 | 21 | self.resample_kernel = {} 22 | self.vocoder_sample_rate = self.vocoder.sample_rate() 23 | self.vocoder_hop_size = self.vocoder.hop_size() 24 | self.dimension = self.vocoder.dimension() 25 | 26 | def extract(self, audio, sample_rate, keyshift=0): 27 | 28 | # resample 29 | if sample_rate == self.vocoder_sample_rate: 30 | 
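# the input audio already matches the vocoder's sample rate, so it is used as-is;
# otherwise a torchaudio Resample kernel is built once per source rate and cached in self.resample_kernel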
audio_res = audio 31 | else: 32 | key_str = str(sample_rate) 33 | if key_str not in self.resample_kernel: 34 | self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device) 35 | audio_res = self.resample_kernel[key_str](audio) 36 | 37 | # extract 38 | mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins 39 | return mel 40 | 41 | def infer(self, mel, f0): 42 | f0 = f0[:,:mel.size(1),0] # B, n_frames 43 | audio = self.vocoder(mel, f0) 44 | return audio 45 | 46 | 47 | class NsfHifiGAN(torch.nn.Module): 48 | def __init__(self, model_path, device=None): 49 | super().__init__() 50 | if device is None: 51 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 52 | self.device = device 53 | self.model_path = model_path 54 | self.model = None 55 | self.h = load_config(model_path) 56 | self.stft = STFT( 57 | self.h.sampling_rate, 58 | self.h.num_mels, 59 | self.h.n_fft, 60 | self.h.win_size, 61 | self.h.hop_size, 62 | self.h.fmin, 63 | self.h.fmax) 64 | 65 | def sample_rate(self): 66 | return self.h.sampling_rate 67 | 68 | def hop_size(self): 69 | return self.h.hop_size 70 | 71 | def dimension(self): 72 | return self.h.num_mels 73 | 74 | def extract(self, audio, keyshift=0): 75 | mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins 76 | return mel 77 | 78 | def forward(self, mel, f0): 79 | if self.model is None: 80 | print('| Load HifiGAN: ', self.model_path) 81 | self.model, self.h = load_model(self.model_path, device=self.device) 82 | with torch.no_grad(): 83 | c = mel.transpose(1, 2) 84 | audio = self.model(c, f0) 85 | return audio 86 | 87 | class NsfHifiGANLog10(NsfHifiGAN): 88 | def forward(self, mel, f0): 89 | if self.model is None: 90 | print('| Load HifiGAN: ', self.model_path) 91 | self.model, self.h = load_model(self.model_path, device=self.device) 92 | with torch.no_grad(): 93 | c = 0.434294 * mel.transpose(1, 2) 94 | audio = self.model(c, f0) 95 | return audio -------------------------------------------------------------------------------- /diffusion/wavenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from math import sqrt 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import Mish 8 | 9 | 10 | class Conv1d(torch.nn.Conv1d): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | nn.init.kaiming_normal_(self.weight) 14 | 15 | 16 | class SinusoidalPosEmb(nn.Module): 17 | def __init__(self, dim): 18 | super().__init__() 19 | self.dim = dim 20 | 21 | def forward(self, x): 22 | device = x.device 23 | half_dim = self.dim // 2 24 | emb = math.log(10000) / (half_dim - 1) 25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 26 | emb = x[:, None] * emb[None, :] 27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 28 | return emb 29 | 30 | 31 | class ResidualBlock(nn.Module): 32 | def __init__(self, encoder_hidden, residual_channels, dilation): 33 | super().__init__() 34 | self.residual_channels = residual_channels 35 | self.dilated_conv = nn.Conv1d( 36 | residual_channels, 37 | 2 * residual_channels, 38 | kernel_size=3, 39 | padding=dilation, 40 | dilation=dilation 41 | ) 42 | self.diffusion_projection = nn.Linear(residual_channels, residual_channels) 43 | self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) 44 | self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) 
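# The three projections above set up a WaveNet-style gated residual block:
# the diffusion-step embedding is broadcast-added to the input, the dilated conv
# doubles the channel count, the conditioner is added on top, and forward() splits
# the result into gate/filter halves combined as sigmoid(gate) * tanh(filter)
# before producing the residual and skip outputs.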
45 | 46 | def forward(self, x, conditioner, diffusion_step): 47 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 48 | conditioner = self.conditioner_projection(conditioner) 49 | y = x + diffusion_step 50 | 51 | y = self.dilated_conv(y) + conditioner 52 | 53 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice 54 | gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) 55 | y = torch.sigmoid(gate) * torch.tanh(filter) 56 | 57 | y = self.output_projection(y) 58 | 59 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice 60 | residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) 61 | return (x + residual) / math.sqrt(2.0), skip 62 | 63 | 64 | class WaveNet(nn.Module): 65 | def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): 66 | super().__init__() 67 | self.input_projection = Conv1d(in_dims, n_chans, 1) 68 | self.diffusion_embedding = SinusoidalPosEmb(n_chans) 69 | self.mlp = nn.Sequential( 70 | nn.Linear(n_chans, n_chans * 4), 71 | Mish(), 72 | nn.Linear(n_chans * 4, n_chans) 73 | ) 74 | self.residual_layers = nn.ModuleList([ 75 | ResidualBlock( 76 | encoder_hidden=n_hidden, 77 | residual_channels=n_chans, 78 | dilation=1 79 | ) 80 | for i in range(n_layers) 81 | ]) 82 | self.skip_projection = Conv1d(n_chans, n_chans, 1) 83 | self.output_projection = Conv1d(n_chans, in_dims, 1) 84 | nn.init.zeros_(self.output_projection.weight) 85 | 86 | def forward(self, spec, diffusion_step, cond): 87 | """ 88 | :param spec: [B, 1, M, T] 89 | :param diffusion_step: [B, 1] 90 | :param cond: [B, M, T] 91 | :return: 92 | """ 93 | x = spec.squeeze(1) 94 | x = self.input_projection(x) # [B, residual_channel, T] 95 | 96 | x = F.relu(x) 97 | diffusion_step = self.diffusion_embedding(diffusion_step) 98 | diffusion_step = self.mlp(diffusion_step) 99 | skip = [] 100 | for layer in self.residual_layers: 101 | x, skip_connection = layer(x, cond, diffusion_step) 102 | skip.append(skip_connection) 103 | 104 | x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) 105 | x = self.skip_projection(x) 106 | x = F.relu(x) 107 | x = self.output_projection(x) # [B, mel_bins, T] 108 | return x[:, None, :, :] 109 | -------------------------------------------------------------------------------- /edgetts/tts.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | import sys 4 | 5 | import edge_tts 6 | from edge_tts import VoicesManager 7 | from langdetect import DetectorFactory, detect 8 | 9 | DetectorFactory.seed = 0 10 | 11 | TEXT = sys.argv[1] 12 | LANG = detect(TEXT) if sys.argv[2] == "Auto" else sys.argv[2] 13 | RATE = sys.argv[3] 14 | VOLUME = sys.argv[4] 15 | GENDER = sys.argv[5] if len(sys.argv) == 6 else None 16 | OUTPUT_FILE = "tts.wav" 17 | 18 | print("Running TTS...") 19 | print(f"Text: {TEXT}, Language: {LANG}, Gender: {GENDER}, Rate: {RATE}, Volume: {VOLUME}") 20 | 21 | async def _main() -> None: 22 | voices = await VoicesManager.create() 23 | if GENDER is not None: 24 | # From "zh-cn" to "zh-CN" etc. 
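# langdetect reports lowercase tags such as "zh-cn", while the edge-tts voice list
# uses locale codes like "zh-CN", so the region suffix is upper-cased before matching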
25 | if LANG == "zh-cn" or LANG == "zh-tw": 26 | LOCALE = LANG[:-2] + LANG[-2:].upper() 27 | voice = voices.find(Gender=GENDER, Locale=LOCALE) 28 | else: 29 | voice = voices.find(Gender=GENDER, Language=LANG) 30 | VOICE = random.choice(voice)["Name"] 31 | print(f"Using random {LANG} voice: {VOICE}") 32 | else: 33 | VOICE = LANG 34 | 35 | communicate = edge_tts.Communicate(text = TEXT, voice = VOICE, rate = RATE, volume = VOLUME) 36 | await communicate.save(OUTPUT_FILE) 37 | 38 | if __name__ == "__main__": 39 | if sys.platform.startswith("win"): 40 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) 41 | asyncio.run(_main()) 42 | else: 43 | loop = asyncio.get_event_loop_policy().get_event_loop() 44 | try: 45 | loop.run_until_complete(_main()) 46 | finally: 47 | loop.close() 48 | -------------------------------------------------------------------------------- /export_index_for_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import faiss 5 | 6 | path = "crs" 7 | indexs_file_path = f"checkpoints/{path}/feature_and_index.pkl" 8 | indexs_out_dir = f"checkpoints/{path}/" 9 | 10 | with open("feature_and_index.pkl",mode="rb") as f: 11 | indexs = pickle.load(f) 12 | 13 | for k in indexs: 14 | print(f"Save {k} index") 15 | faiss.write_index( 16 | indexs[k], 17 | os.path.join(indexs_out_dir,f"Index-{k}.index") 18 | ) 19 | 20 | print("Saved all index") -------------------------------------------------------------------------------- /filelists/test.txt: -------------------------------------------------------------------------------- 1 | ./dataset/44k/taffy/000562.wav 2 | ./dataset/44k/nyaru/000011.wav 3 | ./dataset/44k/nyaru/000008.wav 4 | ./dataset/44k/taffy/000563.wav 5 | -------------------------------------------------------------------------------- /filelists/train.txt: -------------------------------------------------------------------------------- 1 | ./dataset/44k/taffy/000549.wav 2 | ./dataset/44k/nyaru/000004.wav 3 | ./dataset/44k/nyaru/000006.wav 4 | ./dataset/44k/taffy/000551.wav 5 | ./dataset/44k/nyaru/000009.wav 6 | ./dataset/44k/taffy/000561.wav 7 | ./dataset/44k/nyaru/000001.wav 8 | ./dataset/44k/taffy/000553.wav 9 | ./dataset/44k/nyaru/000002.wav 10 | ./dataset/44k/taffy/000560.wav 11 | ./dataset/44k/taffy/000557.wav 12 | ./dataset/44k/nyaru/000005.wav 13 | ./dataset/44k/taffy/000554.wav 14 | ./dataset/44k/taffy/000550.wav 15 | ./dataset/44k/taffy/000559.wav 16 | -------------------------------------------------------------------------------- /filelists/val.txt: -------------------------------------------------------------------------------- 1 | ./dataset/44k/nyaru/000003.wav 2 | ./dataset/44k/nyaru/000007.wav 3 | ./dataset/44k/taffy/000558.wav 4 | ./dataset/44k/taffy/000556.wav 5 | -------------------------------------------------------------------------------- /flask_api.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | 4 | import soundfile 5 | import torch 6 | import torchaudio 7 | from flask import Flask, request, send_file 8 | from flask_cors import CORS 9 | 10 | from inference.infer_tool import RealTimeVC, Svc 11 | 12 | app = Flask(__name__) 13 | 14 | CORS(app) 15 | 16 | logging.getLogger('numba').setLevel(logging.WARNING) 17 | 18 | 19 | @app.route("/voiceChangeModel", methods=["POST"]) 20 | def voice_change_model(): 21 | request_form = request.form 22 | wave_file = request.files.get("sample", None) 23 | # 
变调信息 24 | f_pitch_change = float(request_form.get("fPitchChange", 0)) 25 | # DAW所需的采样率 26 | daw_sample = int(float(request_form.get("sampleRate", 0))) 27 | speaker_id = int(float(request_form.get("sSpeakId", 0))) 28 | # http获得wav文件并转换 29 | input_wav_path = io.BytesIO(wave_file.read()) 30 | 31 | # 模型推理 32 | if raw_infer: 33 | # out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path) 34 | out_audio, out_sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0, 35 | auto_predict_f0=False, noice_scale=0.4, f0_filter=False) 36 | tar_audio = torchaudio.functional.resample(out_audio, svc_model.target_sample, daw_sample) 37 | else: 38 | out_audio = svc.process(svc_model, speaker_id, f_pitch_change, input_wav_path, cluster_infer_ratio=0, 39 | auto_predict_f0=False, noice_scale=0.4, f0_filter=False) 40 | tar_audio = torchaudio.functional.resample(torch.from_numpy(out_audio), svc_model.target_sample, daw_sample) 41 | # 返回音频 42 | out_wav_path = io.BytesIO() 43 | soundfile.write(out_wav_path, tar_audio.cpu().numpy(), daw_sample, format="wav") 44 | out_wav_path.seek(0) 45 | return send_file(out_wav_path, download_name="temp.wav", as_attachment=True) 46 | 47 | 48 | if __name__ == '__main__': 49 | # 启用则为直接切片合成,False为交叉淡化方式 50 | # vst插件调整0.3-0.5s切片时间可以降低延迟,直接切片方法会有连接处爆音、交叉淡化会有轻微重叠声音 51 | # 自行选择能接受的方法,或将vst最大切片时间调整为1s,此处设为Ture,延迟大音质稳定一些 52 | raw_infer = True 53 | # 每个模型和config是唯一对应的 54 | model_name = "logs/32k/G_174000-Copy1.pth" 55 | config_name = "configs/config.json" 56 | cluster_model_path = "logs/44k/kmeans_10000.pt" 57 | svc_model = Svc(model_name, config_name, cluster_model_path=cluster_model_path) 58 | svc = RealTimeVC() 59 | # 此处与vst插件对应,不建议更改 60 | app.run(port=6842, host="0.0.0.0", debug=False, threaded=False) 61 | -------------------------------------------------------------------------------- /flask_api_full_song.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import numpy as np 4 | import soundfile 5 | from flask import Flask, request, send_file 6 | 7 | from inference import infer_tool, slicer 8 | 9 | app = Flask(__name__) 10 | 11 | 12 | @app.route("/wav2wav", methods=["POST"]) 13 | def wav2wav(): 14 | request_form = request.form 15 | audio_path = request_form.get("audio_path", None) # wav文件地址 16 | tran = int(float(request_form.get("tran", 0))) # 音调 17 | spk = request_form.get("spk", 0) # 说话人(id或者name都可以,具体看你的config) 18 | wav_format = request_form.get("wav_format", 'wav') # 范围文件格式 19 | infer_tool.format_wav(audio_path) 20 | chunks = slicer.cut(audio_path, db_thresh=-40) 21 | audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks) 22 | 23 | audio = [] 24 | for (slice_tag, data) in audio_data: 25 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======') 26 | 27 | length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample)) 28 | if slice_tag: 29 | print('jump empty segment') 30 | _audio = np.zeros(length) 31 | else: 32 | # padd 33 | pad_len = int(audio_sr * 0.5) 34 | data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])]) 35 | raw_path = io.BytesIO() 36 | soundfile.write(raw_path, data, audio_sr, format="wav") 37 | raw_path.seek(0) 38 | out_audio, out_sr = svc_model.infer(spk, tran, raw_path) 39 | svc_model.clear_empty() 40 | _audio = out_audio.cpu().numpy() 41 | pad_len = int(svc_model.target_sample * 0.5) 42 | _audio = _audio[pad_len:-pad_len] 43 | 44 | audio.extend(list(infer_tool.pad_array(_audio, length))) 45 | out_wav_path = 
io.BytesIO() 46 | soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format) 47 | out_wav_path.seek(0) 48 | return send_file(out_wav_path, download_name=f"temp.{wav_format}", as_attachment=True) 49 | 50 | 51 | if __name__ == '__main__': 52 | model_name = "logs/44k/G_60000.pth" # 模型地址 53 | config_name = "configs/config.json" # config地址 54 | svc_model = infer_tool.Svc(model_name, config_name) 55 | app.run(port=1145, host="0.0.0.0", debug=False, threaded=False) 56 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/inference/__init__.py -------------------------------------------------------------------------------- /logs/44k/diffusion/put_diffusion_pretrained_model_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/logs/44k/diffusion/put_diffusion_pretrained_model_here -------------------------------------------------------------------------------- /logs/44k/put_pretrained_model_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/logs/44k/put_pretrained_model_here -------------------------------------------------------------------------------- /modules/DSConv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils import remove_weight_norm, weight_norm 3 | 4 | 5 | class Depthwise_Separable_Conv1D(nn.Module): 6 | def __init__( 7 | self, 8 | in_channels, 9 | out_channels, 10 | kernel_size, 11 | stride = 1, 12 | padding = 0, 13 | dilation = 1, 14 | bias = True, 15 | padding_mode = 'zeros', # TODO: refine this type 16 | device=None, 17 | dtype=None 18 | ): 19 | super().__init__() 20 | self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) 21 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) 22 | 23 | def forward(self, input): 24 | return self.point_conv(self.depth_conv(input)) 25 | 26 | def weight_norm(self): 27 | self.depth_conv = weight_norm(self.depth_conv, name = 'weight') 28 | self.point_conv = weight_norm(self.point_conv, name = 'weight') 29 | 30 | def remove_weight_norm(self): 31 | self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight') 32 | self.point_conv = remove_weight_norm(self.point_conv, name = 'weight') 33 | 34 | class Depthwise_Separable_TransposeConv1D(nn.Module): 35 | def __init__( 36 | self, 37 | in_channels, 38 | out_channels, 39 | kernel_size, 40 | stride = 1, 41 | padding = 0, 42 | output_padding = 0, 43 | bias = True, 44 | dilation = 1, 45 | padding_mode = 'zeros', # TODO: refine this type 46 | device=None, 47 | dtype=None 48 | ): 49 | super().__init__() 50 | self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = 
stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype) 51 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype) 52 | 53 | def forward(self, input): 54 | return self.point_conv(self.depth_conv(input)) 55 | 56 | def weight_norm(self): 57 | self.depth_conv = weight_norm(self.depth_conv, name = 'weight') 58 | self.point_conv = weight_norm(self.point_conv, name = 'weight') 59 | 60 | def remove_weight_norm(self): 61 | remove_weight_norm(self.depth_conv, name = 'weight') 62 | remove_weight_norm(self.point_conv, name = 'weight') 63 | 64 | 65 | def weight_norm_modules(module, name = 'weight', dim = 0): 66 | if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): 67 | module.weight_norm() 68 | return module 69 | else: 70 | return weight_norm(module,name,dim) 71 | 72 | def remove_weight_norm_modules(module, name = 'weight'): 73 | if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D): 74 | module.remove_weight_norm() 75 | else: 76 | remove_weight_norm(module,name) -------------------------------------------------------------------------------- /modules/F0Predictor/CrepeF0Predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from modules.F0Predictor.crepe import CrepePitchExtractor 4 | from modules.F0Predictor.F0Predictor import F0Predictor 5 | 6 | 7 | class CrepeF0Predictor(F0Predictor): 8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"): 9 | self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model) 10 | self.hop_length = hop_length 11 | self.f0_min = f0_min 12 | self.f0_max = f0_max 13 | self.device = device 14 | self.threshold = threshold 15 | self.sampling_rate = sampling_rate 16 | self.name = "crepe" 17 | 18 | def compute_f0(self,wav,p_len=None): 19 | x = torch.FloatTensor(wav).to(self.device) 20 | if p_len is None: 21 | p_len = x.shape[0]//self.hop_length 22 | else: 23 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" 24 | f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len) 25 | return f0 26 | 27 | def compute_f0_uv(self,wav,p_len=None): 28 | x = torch.FloatTensor(wav).to(self.device) 29 | if p_len is None: 30 | p_len = x.shape[0]//self.hop_length 31 | else: 32 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" 33 | f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len) 34 | return f0,uv -------------------------------------------------------------------------------- /modules/F0Predictor/DioF0Predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyworld 3 | 4 | from modules.F0Predictor.F0Predictor import F0Predictor 5 | 6 | 7 | class DioF0Predictor(F0Predictor): 8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): 9 | self.hop_length = hop_length 10 | self.f0_min = f0_min 11 | self.f0_max = f0_max 12 | self.sampling_rate = sampling_rate 13 | self.name = "dio" 14 | 15 | def interpolate_f0(self,f0): 16 | ''' 17 | 对F0进行插值处理 18 | ''' 19 | vuv_vector = np.zeros_like(f0, dtype=np.float32) 20 | vuv_vector[f0 > 0.0] = 1.0 21 | vuv_vector[f0 <= 0.0] = 0.0 22 
| 23 | nzindex = np.nonzero(f0)[0] 24 | data = f0[nzindex] 25 | nzindex = nzindex.astype(np.float32) 26 | time_org = self.hop_length / self.sampling_rate * nzindex 27 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate 28 | 29 | if data.shape[0] <= 0: 30 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector 31 | 32 | if data.shape[0] == 1: 33 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector 34 | 35 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) 36 | 37 | return f0,vuv_vector 38 | 39 | def resize_f0(self,x, target_len): 40 | source = np.array(x) 41 | source[source<0.001] = np.nan 42 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) 43 | res = np.nan_to_num(target) 44 | return res 45 | 46 | def compute_f0(self,wav,p_len=None): 47 | if p_len is None: 48 | p_len = wav.shape[0]//self.hop_length 49 | f0, t = pyworld.dio( 50 | wav.astype(np.double), 51 | fs=self.sampling_rate, 52 | f0_floor=self.f0_min, 53 | f0_ceil=self.f0_max, 54 | frame_period=1000 * self.hop_length / self.sampling_rate, 55 | ) 56 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) 57 | for index, pitch in enumerate(f0): 58 | f0[index] = round(pitch, 1) 59 | return self.interpolate_f0(self.resize_f0(f0, p_len))[0] 60 | 61 | def compute_f0_uv(self,wav,p_len=None): 62 | if p_len is None: 63 | p_len = wav.shape[0]//self.hop_length 64 | f0, t = pyworld.dio( 65 | wav.astype(np.double), 66 | fs=self.sampling_rate, 67 | f0_floor=self.f0_min, 68 | f0_ceil=self.f0_max, 69 | frame_period=1000 * self.hop_length / self.sampling_rate, 70 | ) 71 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) 72 | for index, pitch in enumerate(f0): 73 | f0[index] = round(pitch, 1) 74 | return self.interpolate_f0(self.resize_f0(f0, p_len)) 75 | -------------------------------------------------------------------------------- /modules/F0Predictor/F0Predictor.py: -------------------------------------------------------------------------------- 1 | class F0Predictor(object): 2 | def compute_f0(self,wav,p_len): 3 | ''' 4 | input: wav:[signal_length] 5 | p_len:int 6 | output: f0:[signal_length//hop_length] 7 | ''' 8 | pass 9 | 10 | def compute_f0_uv(self,wav,p_len): 11 | ''' 12 | input: wav:[signal_length] 13 | p_len:int 14 | output: f0:[signal_length//hop_length],uv:[signal_length//hop_length] 15 | ''' 16 | pass -------------------------------------------------------------------------------- /modules/F0Predictor/FCPEF0Predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from modules.F0Predictor.F0Predictor import F0Predictor 8 | 9 | from .fcpe.model import FCPEInfer 10 | 11 | 12 | class FCPEF0Predictor(F0Predictor): 13 | def __init__(self, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sampling_rate=44100, 14 | threshold=0.05): 15 | self.fcpe = FCPEInfer(model_path="pretrain/fcpe.pt", device=device, dtype=dtype) 16 | self.hop_length = hop_length 17 | self.f0_min = f0_min 18 | self.f0_max = f0_max 19 | if device is None: 20 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | else: 22 | self.device = device 23 | self.threshold = threshold 24 | self.sampling_rate = sampling_rate 25 | self.dtype = dtype 26 | self.name = "fcpe" 27 | 28 | def repeat_expand( 29 | self, content: 
Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" 30 | ): 31 | ndim = content.ndim 32 | 33 | if content.ndim == 1: 34 | content = content[None, None] 35 | elif content.ndim == 2: 36 | content = content[None] 37 | 38 | assert content.ndim == 3 39 | 40 | is_np = isinstance(content, np.ndarray) 41 | if is_np: 42 | content = torch.from_numpy(content) 43 | 44 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) 45 | 46 | if is_np: 47 | results = results.numpy() 48 | 49 | if ndim == 1: 50 | return results[0, 0] 51 | elif ndim == 2: 52 | return results[0] 53 | 54 | def post_process(self, x, sampling_rate, f0, pad_to): 55 | if isinstance(f0, np.ndarray): 56 | f0 = torch.from_numpy(f0).float().to(x.device) 57 | 58 | if pad_to is None: 59 | return f0 60 | 61 | f0 = self.repeat_expand(f0, pad_to) 62 | 63 | vuv_vector = torch.zeros_like(f0) 64 | vuv_vector[f0 > 0.0] = 1.0 65 | vuv_vector[f0 <= 0.0] = 0.0 66 | 67 | # 去掉0频率, 并线性插值 68 | nzindex = torch.nonzero(f0).squeeze() 69 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() 70 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() 71 | time_frame = np.arange(pad_to) * self.hop_length / sampling_rate 72 | 73 | vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0] 74 | 75 | if f0.shape[0] <= 0: 76 | return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(), vuv_vector.cpu().numpy() 77 | if f0.shape[0] == 1: 78 | return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[ 79 | 0]).cpu().numpy(), vuv_vector.cpu().numpy() 80 | 81 | # 大概可以用 torch 重写? 82 | f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) 83 | # vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) 84 | 85 | return f0, vuv_vector.cpu().numpy() 86 | 87 | def compute_f0(self, wav, p_len=None): 88 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device) 89 | if p_len is None: 90 | p_len = x.shape[0] // self.hop_length 91 | else: 92 | assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" 93 | f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] 94 | if torch.all(f0 == 0): 95 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) 96 | return rtn, rtn 97 | return self.post_process(x, self.sampling_rate, f0, p_len)[0] 98 | 99 | def compute_f0_uv(self, wav, p_len=None): 100 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device) 101 | if p_len is None: 102 | p_len = x.shape[0] // self.hop_length 103 | else: 104 | assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error" 105 | f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0] 106 | if torch.all(f0 == 0): 107 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) 108 | return rtn, rtn 109 | return self.post_process(x, self.sampling_rate, f0, p_len) -------------------------------------------------------------------------------- /modules/F0Predictor/HarvestF0Predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyworld 3 | 4 | from modules.F0Predictor.F0Predictor import F0Predictor 5 | 6 | 7 | class HarvestF0Predictor(F0Predictor): 8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): 9 | self.hop_length = hop_length 10 | self.f0_min = f0_min 11 | self.f0_max = f0_max 12 | self.sampling_rate = sampling_rate 13 | self.name = "harvest" 14 | 15 | def interpolate_f0(self,f0): 16 | 
''' 17 | 对F0进行插值处理 18 | ''' 19 | vuv_vector = np.zeros_like(f0, dtype=np.float32) 20 | vuv_vector[f0 > 0.0] = 1.0 21 | vuv_vector[f0 <= 0.0] = 0.0 22 | 23 | nzindex = np.nonzero(f0)[0] 24 | data = f0[nzindex] 25 | nzindex = nzindex.astype(np.float32) 26 | time_org = self.hop_length / self.sampling_rate * nzindex 27 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate 28 | 29 | if data.shape[0] <= 0: 30 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector 31 | 32 | if data.shape[0] == 1: 33 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector 34 | 35 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) 36 | 37 | return f0,vuv_vector 38 | def resize_f0(self,x, target_len): 39 | source = np.array(x) 40 | source[source<0.001] = np.nan 41 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) 42 | res = np.nan_to_num(target) 43 | return res 44 | 45 | def compute_f0(self,wav,p_len=None): 46 | if p_len is None: 47 | p_len = wav.shape[0]//self.hop_length 48 | f0, t = pyworld.harvest( 49 | wav.astype(np.double), 50 | fs=self.sampling_rate, 51 | f0_ceil=self.f0_max, 52 | f0_floor=self.f0_min, 53 | frame_period=1000 * self.hop_length / self.sampling_rate, 54 | ) 55 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) 56 | return self.interpolate_f0(self.resize_f0(f0, p_len))[0] 57 | 58 | def compute_f0_uv(self,wav,p_len=None): 59 | if p_len is None: 60 | p_len = wav.shape[0]//self.hop_length 61 | f0, t = pyworld.harvest( 62 | wav.astype(np.double), 63 | fs=self.sampling_rate, 64 | f0_floor=self.f0_min, 65 | f0_ceil=self.f0_max, 66 | frame_period=1000 * self.hop_length / self.sampling_rate, 67 | ) 68 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate) 69 | return self.interpolate_f0(self.resize_f0(f0, p_len)) 70 | -------------------------------------------------------------------------------- /modules/F0Predictor/PMF0Predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import parselmouth 3 | 4 | from modules.F0Predictor.F0Predictor import F0Predictor 5 | 6 | 7 | class PMF0Predictor(F0Predictor): 8 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100): 9 | self.hop_length = hop_length 10 | self.f0_min = f0_min 11 | self.f0_max = f0_max 12 | self.sampling_rate = sampling_rate 13 | self.name = "pm" 14 | 15 | def interpolate_f0(self,f0): 16 | ''' 17 | 对F0进行插值处理 18 | ''' 19 | vuv_vector = np.zeros_like(f0, dtype=np.float32) 20 | vuv_vector[f0 > 0.0] = 1.0 21 | vuv_vector[f0 <= 0.0] = 0.0 22 | 23 | nzindex = np.nonzero(f0)[0] 24 | data = f0[nzindex] 25 | nzindex = nzindex.astype(np.float32) 26 | time_org = self.hop_length / self.sampling_rate * nzindex 27 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate 28 | 29 | if data.shape[0] <= 0: 30 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector 31 | 32 | if data.shape[0] == 1: 33 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector 34 | 35 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1]) 36 | 37 | return f0,vuv_vector 38 | 39 | 40 | def compute_f0(self,wav,p_len=None): 41 | x = wav 42 | if p_len is None: 43 | p_len = x.shape[0]//self.hop_length 44 | else: 45 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" 46 | time_step = self.hop_length / self.sampling_rate * 1000 47 | f0 = parselmouth.Sound(x, 
self.sampling_rate).to_pitch_ac( 48 | time_step=time_step / 1000, voicing_threshold=0.6, 49 | pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency'] 50 | 51 | pad_size=(p_len - len(f0) + 1) // 2 52 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 53 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 54 | f0,uv = self.interpolate_f0(f0) 55 | return f0 56 | 57 | def compute_f0_uv(self,wav,p_len=None): 58 | x = wav 59 | if p_len is None: 60 | p_len = x.shape[0]//self.hop_length 61 | else: 62 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" 63 | time_step = self.hop_length / self.sampling_rate * 1000 64 | f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac( 65 | time_step=time_step / 1000, voicing_threshold=0.6, 66 | pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency'] 67 | 68 | pad_size=(p_len - len(f0) + 1) // 2 69 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 70 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 71 | f0,uv = self.interpolate_f0(f0) 72 | return f0,uv 73 | -------------------------------------------------------------------------------- /modules/F0Predictor/RMVPEF0Predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from modules.F0Predictor.F0Predictor import F0Predictor 8 | 9 | from .rmvpe import RMVPE 10 | 11 | 12 | class RMVPEF0Predictor(F0Predictor): 13 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05): 14 | self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device) 15 | self.hop_length = hop_length 16 | self.f0_min = f0_min 17 | self.f0_max = f0_max 18 | if device is None: 19 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 20 | else: 21 | self.device = device 22 | self.threshold = threshold 23 | self.sampling_rate = sampling_rate 24 | self.dtype = dtype 25 | self.name = "rmvpe" 26 | 27 | def repeat_expand( 28 | self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" 29 | ): 30 | ndim = content.ndim 31 | 32 | if content.ndim == 1: 33 | content = content[None, None] 34 | elif content.ndim == 2: 35 | content = content[None] 36 | 37 | assert content.ndim == 3 38 | 39 | is_np = isinstance(content, np.ndarray) 40 | if is_np: 41 | content = torch.from_numpy(content) 42 | 43 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) 44 | 45 | if is_np: 46 | results = results.numpy() 47 | 48 | if ndim == 1: 49 | return results[0, 0] 50 | elif ndim == 2: 51 | return results[0] 52 | 53 | def post_process(self, x, sampling_rate, f0, pad_to): 54 | if isinstance(f0, np.ndarray): 55 | f0 = torch.from_numpy(f0).float().to(x.device) 56 | 57 | if pad_to is None: 58 | return f0 59 | 60 | f0 = self.repeat_expand(f0, pad_to) 61 | 62 | vuv_vector = torch.zeros_like(f0) 63 | vuv_vector[f0 > 0.0] = 1.0 64 | vuv_vector[f0 <= 0.0] = 0.0 65 | 66 | # 去掉0频率, 并线性插值 67 | nzindex = torch.nonzero(f0).squeeze() 68 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() 69 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() 70 | time_frame = np.arange(pad_to) * self.hop_length / sampling_rate 71 | 72 | vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0] 73 | 74 | if f0.shape[0] <= 0: 75 | return torch.zeros(pad_to, 
dtype=torch.float, device=x.device).cpu().numpy(),vuv_vector.cpu().numpy() 76 | if f0.shape[0] == 1: 77 | return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0]).cpu().numpy() ,vuv_vector.cpu().numpy() 78 | 79 | # 大概可以用 torch 重写? 80 | f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) 81 | #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) 82 | 83 | return f0,vuv_vector.cpu().numpy() 84 | 85 | def compute_f0(self,wav,p_len=None): 86 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device) 87 | if p_len is None: 88 | p_len = x.shape[0]//self.hop_length 89 | else: 90 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" 91 | f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) 92 | if torch.all(f0 == 0): 93 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) 94 | return rtn,rtn 95 | return self.post_process(x,self.sampling_rate,f0,p_len)[0] 96 | 97 | def compute_f0_uv(self,wav,p_len=None): 98 | x = torch.FloatTensor(wav).to(self.dtype).to(self.device) 99 | if p_len is None: 100 | p_len = x.shape[0]//self.hop_length 101 | else: 102 | assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error" 103 | f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold) 104 | if torch.all(f0 == 0): 105 | rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len) 106 | return rtn,rtn 107 | return self.post_process(x,self.sampling_rate,f0,p_len) -------------------------------------------------------------------------------- /modules/F0Predictor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/modules/F0Predictor/__init__.py -------------------------------------------------------------------------------- /modules/F0Predictor/fcpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FCPEInfer # noqa: F401 2 | from .nvSTFT import STFT # noqa: F401 3 | from .pcmer import PCmer # noqa: F401 4 | -------------------------------------------------------------------------------- /modules/F0Predictor/fcpe/nvSTFT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.utils.data 9 | from librosa.filters import mel as librosa_mel_fn 10 | 11 | os.environ["LRU_CACHE_CAPACITY"] = "3" 12 | 13 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): 14 | sampling_rate = None 15 | try: 16 | data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 
17 | except Exception as ex: 18 | print(f"'{full_path}' failed to load.\nException:") 19 | print(ex) 20 | if return_empty_on_exception: 21 | return [], sampling_rate or target_sr or 48000 22 | else: 23 | raise Exception(ex) 24 | 25 | if len(data.shape) > 1: 26 | data = data[:, 0] 27 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) 28 | 29 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int 30 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX 31 | else: # if audio data is type fp32 32 | max_mag = max(np.amax(data), -np.amin(data)) 33 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 34 | 35 | data = torch.FloatTensor(data.astype(np.float32))/max_mag 36 | 37 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except 38 | return [], sampling_rate or target_sr or 48000 39 | if target_sr is not None and sampling_rate != target_sr: 40 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) 41 | sampling_rate = target_sr 42 | 43 | return data, sampling_rate 44 | 45 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 46 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 47 | 48 | def dynamic_range_decompression(x, C=1): 49 | return np.exp(x) / C 50 | 51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 52 | return torch.log(torch.clamp(x, min=clip_val) * C) 53 | 54 | def dynamic_range_decompression_torch(x, C=1): 55 | return torch.exp(x) / C 56 | 57 | class STFT(): 58 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): 59 | self.target_sr = sr 60 | 61 | self.n_mels = n_mels 62 | self.n_fft = n_fft 63 | self.win_size = win_size 64 | self.hop_length = hop_length 65 | self.fmin = fmin 66 | self.fmax = fmax 67 | self.clip_val = clip_val 68 | self.mel_basis = {} 69 | self.hann_window = {} 70 | 71 | def get_mel(self, y, keyshift=0, speed=1, center=False, train=False): 72 | sampling_rate = self.target_sr 73 | n_mels = self.n_mels 74 | n_fft = self.n_fft 75 | win_size = self.win_size 76 | hop_length = self.hop_length 77 | fmin = self.fmin 78 | fmax = self.fmax 79 | clip_val = self.clip_val 80 | 81 | factor = 2 ** (keyshift / 12) 82 | n_fft_new = int(np.round(n_fft * factor)) 83 | win_size_new = int(np.round(win_size * factor)) 84 | hop_length_new = int(np.round(hop_length * speed)) 85 | if not train: 86 | mel_basis = self.mel_basis 87 | hann_window = self.hann_window 88 | else: 89 | mel_basis = {} 90 | hann_window = {} 91 | 92 | if torch.min(y) < -1.: 93 | print('min value is ', torch.min(y)) 94 | if torch.max(y) > 1.: 95 | print('max value is ', torch.max(y)) 96 | 97 | mel_basis_key = str(fmax)+'_'+str(y.device) 98 | if mel_basis_key not in mel_basis: 99 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 100 | mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) 101 | 102 | keyshift_key = str(keyshift)+'_'+str(y.device) 103 | if keyshift_key not in hann_window: 104 | hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) 105 | 106 | pad_left = (win_size_new - hop_length_new) //2 107 | pad_right = max((win_size_new- 
hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) 108 | if pad_right < y.size(-1): 109 | mode = 'reflect' 110 | else: 111 | mode = 'constant' 112 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) 113 | y = y.squeeze(1) 114 | 115 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], 116 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 117 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9)) 118 | if keyshift != 0: 119 | size = n_fft // 2 + 1 120 | resize = spec.size(1) 121 | if resize < size: 122 | spec = F.pad(spec, (0, 0, 0, size-resize)) 123 | spec = spec[:, :size, :] * win_size / win_size_new 124 | spec = torch.matmul(mel_basis[mel_basis_key], spec) 125 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 126 | return spec 127 | 128 | def __call__(self, audiopath): 129 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) 130 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) 131 | return spect 132 | 133 | stft = STFT() 134 | -------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * # noqa: F403 2 | from .inference import RMVPE # noqa: F401 3 | from .model import E2E, E2E0 # noqa: F401 4 | from .spec import MelSpectrogram # noqa: F401 5 | from .utils import ( # noqa: F401 6 | cycle, 7 | summary, 8 | to_local_average_cents, 9 | to_viterbi_cents, 10 | ) 11 | -------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/constants.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | 3 | N_CLASS = 360 4 | 5 | N_MELS = 128 6 | MEL_FMIN = 30 7 | MEL_FMAX = SAMPLE_RATE // 2 8 | WINDOW_LENGTH = 1024 9 | CONST = 1997.3794084376191 10 | -------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchaudio.transforms import Resample 4 | 5 | from .constants import * # noqa: F403 6 | from .model import E2E0 7 | from .spec import MelSpectrogram 8 | from .utils import to_local_average_cents, to_viterbi_cents 9 | 10 | 11 | class RMVPE: 12 | def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160): 13 | self.resample_kernel = {} 14 | if device is None: 15 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | else: 17 | self.device = device 18 | model = E2E0(4, 1, (2, 2)) 19 | ckpt = torch.load(model_path, map_location=torch.device(self.device)) 20 | model.load_state_dict(ckpt['model']) 21 | model = model.to(dtype).to(self.device) 22 | model.eval() 23 | self.model = model 24 | self.dtype = dtype 25 | self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 26 | self.resample_kernel = {} 27 | 28 | def mel2hidden(self, mel): 29 | with torch.no_grad(): 30 | n_frames = mel.shape[-1] 31 | mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') 32 | hidden = self.model(mel) 33 | return hidden[:, :n_frames] 34 | 35 | def decode(self, hidden, thred=0.03, use_viterbi=False): 36 | if use_viterbi: 37 | cents_pred = 
to_viterbi_cents(hidden, thred=thred) 38 | else: 39 | cents_pred = to_local_average_cents(hidden, thred=thred) 40 | f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device) 41 | return f0 42 | 43 | def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False): 44 | audio = audio.unsqueeze(0).to(self.dtype).to(self.device) 45 | if sample_rate == 16000: 46 | audio_res = audio 47 | else: 48 | key_str = str(sample_rate) 49 | if key_str not in self.resample_kernel: 50 | self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) 51 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device) 52 | audio_res = self.resample_kernel[key_str](audio) 53 | mel_extractor = self.mel_extractor.to(self.device) 54 | mel = mel_extractor(audio_res, center=True).to(self.dtype) 55 | hidden = self.mel2hidden(mel) 56 | f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi) 57 | return f0 58 | -------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .constants import * # noqa: F403 4 | from .deepunet import DeepUnet, DeepUnet0 5 | from .seq import BiGRU 6 | from .spec import MelSpectrogram 7 | 8 | 9 | class E2E(nn.Module): 10 | def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 11 | en_out_channels=16): 12 | super(E2E, self).__init__() 13 | self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) # noqa: F405 14 | self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 15 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 16 | if n_gru: 17 | self.fc = nn.Sequential( 18 | BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 19 | nn.Linear(512, N_CLASS), # noqa: F405 20 | nn.Dropout(0.25), 21 | nn.Sigmoid() 22 | ) 23 | else: 24 | self.fc = nn.Sequential( 25 | nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 26 | nn.Dropout(0.25), 27 | nn.Sigmoid() 28 | ) 29 | 30 | def forward(self, x): 31 | mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1) 32 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 33 | # x = self.fc(x) 34 | hidden_vec = 0 35 | if len(self.fc) == 4: 36 | for i in range(len(self.fc)): 37 | x = self.fc[i](x) 38 | if i == 0: 39 | hidden_vec = x 40 | return hidden_vec, x 41 | 42 | 43 | class E2E0(nn.Module): 44 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 45 | en_out_channels=16): 46 | super(E2E0, self).__init__() 47 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 48 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 49 | if n_gru: 50 | self.fc = nn.Sequential( 51 | BiGRU(3 * N_MELS, 256, n_gru), # noqa: F405 52 | nn.Linear(512, N_CLASS), # noqa: F405 53 | nn.Dropout(0.25), 54 | nn.Sigmoid() 55 | ) 56 | else: 57 | self.fc = nn.Sequential( 58 | nn.Linear(3 * N_MELS, N_CLASS), # noqa: F405 59 | nn.Dropout(0.25), 60 | nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, mel): 64 | mel = mel.transpose(-1, -2).unsqueeze(1) 65 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 66 | x = self.fc(x) 67 | return x 68 | 
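The E2E0 network above is normally driven through the RMVPE wrapper defined in inference.py. A minimal usage sketch (assuming the checkpoint lives at pretrain/rmvpe.pt, the path already used by RMVPEF0Predictor) might look like this:

import torch

from modules.F0Predictor.rmvpe import RMVPE

# Load the pretrained pitch estimator and run it on one second of 16 kHz audio.
# The input here is silence and only serves to exercise the full pipeline
# (optional resample -> mel -> E2E0 -> cents decoding).
rmvpe = RMVPE(model_path="pretrain/rmvpe.pt", hop_length=160)
audio = torch.zeros(16000)
f0 = rmvpe.infer_from_audio(audio, sample_rate=16000, thred=0.05)
print(f0.shape)  # roughly one F0 value per mel hop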
-------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/seq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BiGRU(nn.Module): 5 | def __init__(self, input_features, hidden_features, num_layers): 6 | super(BiGRU, self).__init__() 7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 8 | 9 | def forward(self, x): 10 | return self.gru(x)[0] 11 | 12 | 13 | class BiLSTM(nn.Module): 14 | def __init__(self, input_features, hidden_features, num_layers): 15 | super(BiLSTM, self).__init__() 16 | self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 17 | 18 | def forward(self, x): 19 | return self.lstm(x)[0] 20 | 21 | -------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/spec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from librosa.filters import mel 5 | 6 | 7 | class MelSpectrogram(torch.nn.Module): 8 | def __init__( 9 | self, 10 | n_mel_channels, 11 | sampling_rate, 12 | win_length, 13 | hop_length, 14 | n_fft=None, 15 | mel_fmin=0, 16 | mel_fmax=None, 17 | clamp = 1e-5 18 | ): 19 | super().__init__() 20 | n_fft = win_length if n_fft is None else n_fft 21 | self.hann_window = {} 22 | mel_basis = mel( 23 | sr=sampling_rate, 24 | n_fft=n_fft, 25 | n_mels=n_mel_channels, 26 | fmin=mel_fmin, 27 | fmax=mel_fmax, 28 | htk=True) 29 | mel_basis = torch.from_numpy(mel_basis).float() 30 | self.register_buffer("mel_basis", mel_basis) 31 | self.n_fft = win_length if n_fft is None else n_fft 32 | self.hop_length = hop_length 33 | self.win_length = win_length 34 | self.sampling_rate = sampling_rate 35 | self.n_mel_channels = n_mel_channels 36 | self.clamp = clamp 37 | 38 | def forward(self, audio, keyshift=0, speed=1, center=True): 39 | factor = 2 ** (keyshift / 12) 40 | n_fft_new = int(np.round(self.n_fft * factor)) 41 | win_length_new = int(np.round(self.win_length * factor)) 42 | hop_length_new = int(np.round(self.hop_length * speed)) 43 | 44 | keyshift_key = str(keyshift)+'_'+str(audio.device) 45 | if keyshift_key not in self.hann_window: 46 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) 47 | 48 | fft = torch.stft( 49 | audio, 50 | n_fft=n_fft_new, 51 | hop_length=hop_length_new, 52 | win_length=win_length_new, 53 | window=self.hann_window[keyshift_key], 54 | center=center, 55 | return_complex=True) 56 | magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) 57 | 58 | if keyshift != 0: 59 | size = self.n_fft // 2 + 1 60 | resize = magnitude.size(1) 61 | if resize < size: 62 | magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) 63 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new 64 | 65 | mel_output = torch.matmul(self.mel_basis, magnitude) 66 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 67 | return log_mel_spec -------------------------------------------------------------------------------- /modules/F0Predictor/rmvpe/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from functools import reduce 3 | 4 | import librosa 5 | import numpy as np 6 | import torch 7 | from torch.nn.modules.module import _addindent 8 | 9 | from .constants import * # 
noqa: F403 10 | 11 | 12 | def cycle(iterable): 13 | while True: 14 | for item in iterable: 15 | yield item 16 | 17 | 18 | def summary(model, file=sys.stdout): 19 | def repr(model): 20 | # We treat the extra repr like the sub-module, one item per line 21 | extra_lines = [] 22 | extra_repr = model.extra_repr() 23 | # empty string will be split into list [''] 24 | if extra_repr: 25 | extra_lines = extra_repr.split('\n') 26 | child_lines = [] 27 | total_params = 0 28 | for key, module in model._modules.items(): 29 | mod_str, num_params = repr(module) 30 | mod_str = _addindent(mod_str, 2) 31 | child_lines.append('(' + key + '): ' + mod_str) 32 | total_params += num_params 33 | lines = extra_lines + child_lines 34 | 35 | for name, p in model._parameters.items(): 36 | if hasattr(p, 'shape'): 37 | total_params += reduce(lambda x, y: x * y, p.shape) 38 | 39 | main_str = model._get_name() + '(' 40 | if lines: 41 | # simple one-liner info, which most builtin Modules will use 42 | if len(extra_lines) == 1 and not child_lines: 43 | main_str += extra_lines[0] 44 | else: 45 | main_str += '\n ' + '\n '.join(lines) + '\n' 46 | 47 | main_str += ')' 48 | if file is sys.stdout: 49 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params) 50 | else: 51 | main_str += ', {:,} params'.format(total_params) 52 | return main_str, total_params 53 | 54 | string, count = repr(model) 55 | if file is not None: 56 | if isinstance(file, str): 57 | file = open(file, 'w') 58 | print(string, file=file) 59 | file.flush() 60 | 61 | return count 62 | 63 | 64 | def to_local_average_cents(salience, center=None, thred=0.05): 65 | """ 66 | find the weighted average cents near the argmax bin 67 | """ 68 | 69 | if not hasattr(to_local_average_cents, 'cents_mapping'): 70 | # the bin number-to-cents mapping 71 | to_local_average_cents.cents_mapping = ( 72 | 20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405 73 | 74 | if salience.ndim == 1: 75 | if center is None: 76 | center = int(torch.argmax(salience)) 77 | start = max(0, center - 4) 78 | end = min(len(salience), center + 5) 79 | salience = salience[start:end] 80 | product_sum = torch.sum( 81 | salience * to_local_average_cents.cents_mapping[start:end]) 82 | weight_sum = torch.sum(salience) 83 | return product_sum / weight_sum if torch.max(salience) > thred else 0 84 | if salience.ndim == 2: 85 | return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in 86 | range(salience.shape[0])]).to(salience.device) 87 | 88 | raise Exception("label should be either 1d or 2d ndarray") 89 | 90 | def to_viterbi_cents(salience, thred=0.05): 91 | # Create viterbi transition matrix 92 | if not hasattr(to_viterbi_cents, 'transition'): 93 | xx, yy = torch.meshgrid(range(N_CLASS), range(N_CLASS)) # noqa: F405 94 | transition = torch.maximum(30 - abs(xx - yy), 0) 95 | transition = transition / transition.sum(axis=1, keepdims=True) 96 | to_viterbi_cents.transition = transition 97 | 98 | # Convert to probability 99 | prob = salience.T 100 | prob = prob / prob.sum(axis=0) 101 | 102 | # Perform viterbi decoding 103 | path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64) 104 | 105 | return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in 106 | range(len(path))]).to(salience.device) 107 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/modules/__init__.py -------------------------------------------------------------------------------- /modules/enhancer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchaudio.transforms import Resample 5 | 6 | from vdecoder.nsf_hifigan.models import load_model 7 | from vdecoder.nsf_hifigan.nvSTFT import STFT 8 | 9 | 10 | class Enhancer: 11 | def __init__(self, enhancer_type, enhancer_ckpt, device=None): 12 | if device is None: 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | self.device = device 15 | 16 | if enhancer_type == 'nsf-hifigan': 17 | self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device) 18 | else: 19 | raise ValueError(f" [x] Unknown enhancer: {enhancer_type}") 20 | 21 | self.resample_kernel = {} 22 | self.enhancer_sample_rate = self.enhancer.sample_rate() 23 | self.enhancer_hop_size = self.enhancer.hop_size() 24 | 25 | def enhance(self, 26 | audio, # 1, T 27 | sample_rate, 28 | f0, # 1, n_frames, 1 29 | hop_size, 30 | adaptive_key = 0, 31 | silence_front = 0 32 | ): 33 | # enhancer start time 34 | start_frame = int(silence_front * sample_rate / hop_size) 35 | real_silence_front = start_frame * hop_size / sample_rate 36 | audio = audio[:, int(np.round(real_silence_front * sample_rate)) : ] 37 | f0 = f0[: , start_frame :, :] 38 | 39 | # adaptive parameters 40 | adaptive_factor = 2 ** ( -adaptive_key / 12) 41 | adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100)) 42 | real_factor = self.enhancer_sample_rate / adaptive_sample_rate 43 | 44 | # resample the ddsp output 45 | if sample_rate == adaptive_sample_rate: 46 | audio_res = audio 47 | else: 48 | key_str = str(sample_rate) + str(adaptive_sample_rate) 49 | if key_str not in self.resample_kernel: 50 | self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width = 128).to(self.device) 51 | audio_res = self.resample_kernel[key_str](audio) 52 | 53 | n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1) 54 | 55 | # resample f0 56 | f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy() 57 | f0_np *= real_factor 58 | time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor 59 | time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames) 60 | f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1]) 61 | f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames 62 | 63 | # enhance 64 | enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res) 65 | 66 | # resample the enhanced output 67 | if adaptive_factor != 0: 68 | key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate) 69 | if key_str not in self.resample_kernel: 70 | self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width = 128).to(self.device) 71 | enhanced_audio = self.resample_kernel[key_str](enhanced_audio) 72 | 73 | # pad the silence frames 74 | if start_frame > 0: 75 | enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0)) 76 | 77 | return enhanced_audio, enhancer_sample_rate 78 | 79 | 80 | class NsfHifiGAN(torch.nn.Module): 81 | def __init__(self, model_path, device=None): 82 | super().__init__() 83 | if device is None: 84 | device = 
'cuda' if torch.cuda.is_available() else 'cpu' 85 | self.device = device 86 | print('| Load HifiGAN: ', model_path) 87 | self.model, self.h = load_model(model_path, device=self.device) 88 | 89 | def sample_rate(self): 90 | return self.h.sampling_rate 91 | 92 | def hop_size(self): 93 | return self.h.hop_size 94 | 95 | def forward(self, audio, f0): 96 | stft = STFT( 97 | self.h.sampling_rate, 98 | self.h.num_mels, 99 | self.h.n_fft, 100 | self.h.win_size, 101 | self.h.hop_size, 102 | self.h.fmin, 103 | self.h.fmax) 104 | with torch.no_grad(): 105 | mel = stft.get_mel(audio) 106 | enhanced_audio = self.model(mel, f0[:,:mel.size(-1)]).view(-1) 107 | return enhanced_audio, self.h.sampling_rate -------------------------------------------------------------------------------- /modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def feature_loss(fmap_r, fmap_g): 5 | loss = 0 6 | for dr, dg in zip(fmap_r, fmap_g): 7 | for rl, gl in zip(dr, dg): 8 | rl = rl.float().detach() 9 | gl = gl.float() 10 | loss += torch.mean(torch.abs(rl - gl)) 11 | 12 | return loss * 2 13 | 14 | 15 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 16 | loss = 0 17 | r_losses = [] 18 | g_losses = [] 19 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 20 | dr = dr.float() 21 | dg = dg.float() 22 | r_loss = torch.mean((1-dr)**2) 23 | g_loss = torch.mean(dg**2) 24 | loss += (r_loss + g_loss) 25 | r_losses.append(r_loss.item()) 26 | g_losses.append(g_loss.item()) 27 | 28 | return loss, r_losses, g_losses 29 | 30 | 31 | def generator_loss(disc_outputs): 32 | loss = 0 33 | gen_losses = [] 34 | for dg in disc_outputs: 35 | dg = dg.float() 36 | l = torch.mean((1-dg)**2) 37 | gen_losses.append(l) 38 | loss += l 39 | 40 | return loss, gen_losses 41 | 42 | 43 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 44 | """ 45 | z_p, logs_q: [b, h, t_t] 46 | m_p, logs_p: [b, h, t_t] 47 | """ 48 | z_p = z_p.float() 49 | logs_q = logs_q.float() 50 | m_p = m_p.float() 51 | logs_p = logs_p.float() 52 | z_mask = z_mask.float() 53 | #print(logs_p) 54 | kl = logs_p - logs_q - 0.5 55 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. 
* logs_p) 56 | kl = torch.sum(kl * z_mask) 57 | l = kl / torch.sum(z_mask) 58 | return l 59 | -------------------------------------------------------------------------------- /modules/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | 8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 9 | """ 10 | PARAMS 11 | ------ 12 | C: compression factor 13 | """ 14 | return torch.log(torch.clamp(x, min=clip_val) * C) 15 | 16 | 17 | def dynamic_range_decompression_torch(x, C=1): 18 | """ 19 | PARAMS 20 | ------ 21 | C: compression factor used to compress 22 | """ 23 | return torch.exp(x) / C 24 | 25 | 26 | def spectral_normalize_torch(magnitudes): 27 | output = dynamic_range_compression_torch(magnitudes) 28 | return output 29 | 30 | 31 | def spectral_de_normalize_torch(magnitudes): 32 | output = dynamic_range_decompression_torch(magnitudes) 33 | return output 34 | 35 | 36 | mel_basis = {} 37 | hann_window = {} 38 | 39 | 40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 41 | if torch.min(y) < -1.: 42 | print('min value is ', torch.min(y)) 43 | if torch.max(y) > 1.: 44 | print('max value is ', torch.max(y)) 45 | 46 | global hann_window 47 | dtype_device = str(y.dtype) + '_' + str(y.device) 48 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 49 | if wnsize_dtype_device not in hann_window: 50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 51 | 52 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 53 | y = y.squeeze(1) 54 | 55 | y_dtype = y.dtype 56 | if y.dtype == torch.bfloat16: 57 | y = y.to(torch.float32) 58 | 59 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 60 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 61 | spec = torch.view_as_real(spec).to(y_dtype) 62 | 63 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 64 | return spec 65 | 66 | 67 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 68 | global mel_basis 69 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 70 | fmax_dtype_device = str(fmax) + '_' + dtype_device 71 | if fmax_dtype_device not in mel_basis: 72 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 73 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 74 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 75 | spec = spectral_normalize_torch(spec) 76 | return spec 77 | 78 | 79 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 80 | spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) 81 | spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) 82 | 83 | return spec 84 | -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import torch 5 | 6 | import utils 7 | from onnxexport.model_onnx_speaker_mix import SynthesizerTrn 8 | 9 | parser = argparse.ArgumentParser(description='SoVitsSvc OnnxExport') 10 | 11 | def OnnxExport(path=None): 
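    # Expected on-disk layout, inferred from the paths used below (the checkpoint
    # folder name doubles as the model name passed via --model_name):
    #     checkpoints/<model_name>/config.json
    #     checkpoints/<model_name>/model.pth
    # The exported graph and a MoeVS-style description file are written back to
    # checkpoints/<model_name>/<model_name>_SoVits.onnx and checkpoints/<model_name>.json.
    # A typical invocation (the folder name "MySpeaker" is only illustrative):
    #     python onnx_export.py -n MySpeaker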
12 | device = torch.device("cpu") 13 | hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") 14 | SVCVITS = SynthesizerTrn( 15 | hps.data.filter_length // 2 + 1, 16 | hps.train.segment_size // hps.data.hop_length, 17 | **hps.model) 18 | _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None) 19 | _ = SVCVITS.eval().to(device) 20 | for i in SVCVITS.parameters(): 21 | i.requires_grad = False 22 | 23 | num_frames = 200 24 | 25 | test_hidden_unit = torch.rand(1, num_frames, SVCVITS.gin_channels) 26 | test_pitch = torch.rand(1, num_frames) 27 | test_vol = torch.rand(1, num_frames) 28 | test_mel2ph = torch.LongTensor(torch.arange(0, num_frames)).unsqueeze(0) 29 | test_uv = torch.ones(1, num_frames, dtype=torch.float32) 30 | test_noise = torch.randn(1, 192, num_frames) 31 | test_sid = torch.LongTensor([0]) 32 | export_mix = True 33 | if len(hps.spk) < 2: 34 | export_mix = False 35 | 36 | if export_mix: 37 | spk_mix = [] 38 | n_spk = len(hps.spk) 39 | for i in range(n_spk): 40 | spk_mix.append(1.0/float(n_spk)) 41 | test_sid = torch.tensor(spk_mix) 42 | SVCVITS.export_chara_mix(hps.spk) 43 | test_sid = test_sid.unsqueeze(0) 44 | test_sid = test_sid.repeat(num_frames, 1) 45 | 46 | SVCVITS.eval() 47 | 48 | if export_mix: 49 | daxes = { 50 | "c": [0, 1], 51 | "f0": [1], 52 | "mel2ph": [1], 53 | "uv": [1], 54 | "noise": [2], 55 | "sid":[0] 56 | } 57 | else: 58 | daxes = { 59 | "c": [0, 1], 60 | "f0": [1], 61 | "mel2ph": [1], 62 | "uv": [1], 63 | "noise": [2] 64 | } 65 | 66 | input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"] 67 | output_names = ["audio", ] 68 | 69 | if SVCVITS.vol_embedding: 70 | input_names.append("vol") 71 | vol_dadict = {"vol" : [1]} 72 | daxes.update(vol_dadict) 73 | test_inputs = ( 74 | test_hidden_unit.to(device), 75 | test_pitch.to(device), 76 | test_mel2ph.to(device), 77 | test_uv.to(device), 78 | test_noise.to(device), 79 | test_sid.to(device), 80 | test_vol.to(device) 81 | ) 82 | else: 83 | test_inputs = ( 84 | test_hidden_unit.to(device), 85 | test_pitch.to(device), 86 | test_mel2ph.to(device), 87 | test_uv.to(device), 88 | test_noise.to(device), 89 | test_sid.to(device) 90 | ) 91 | 92 | # SVCVITS = torch.jit.script(SVCVITS) 93 | SVCVITS(test_hidden_unit.to(device), 94 | test_pitch.to(device), 95 | test_mel2ph.to(device), 96 | test_uv.to(device), 97 | test_noise.to(device), 98 | test_sid.to(device), 99 | test_vol.to(device)) 100 | 101 | SVCVITS.dec.OnnxExport() 102 | 103 | torch.onnx.export( 104 | SVCVITS, 105 | test_inputs, 106 | f"checkpoints/{path}/{path}_SoVits.onnx", 107 | dynamic_axes=daxes, 108 | do_constant_folding=False, 109 | opset_version=16, 110 | verbose=False, 111 | input_names=input_names, 112 | output_names=output_names 113 | ) 114 | 115 | vec_lay = "layer-12" if SVCVITS.gin_channels == 768 else "layer-9" 116 | spklist = [] 117 | for key in hps.spk.keys(): 118 | spklist.append(key) 119 | 120 | MoeVSConf = { 121 | "Folder" : f"{path}", 122 | "Name" : f"{path}", 123 | "Type" : "SoVits", 124 | "Rate" : hps.data.sampling_rate, 125 | "Hop" : hps.data.hop_length, 126 | "Hubert": f"vec-{SVCVITS.gin_channels}-{vec_lay}", 127 | "SoVits4": True, 128 | "SoVits3": False, 129 | "CharaMix": export_mix, 130 | "Volume": SVCVITS.vol_embedding, 131 | "HiddenSize": SVCVITS.gin_channels, 132 | "Characters": spklist, 133 | "Cluster": "" 134 | } 135 | 136 | with open(f"checkpoints/{path}.json", 'w') as MoeVsConfFile: 137 | json.dump(MoeVSConf, MoeVsConfFile, indent = 4) 138 | 139 | 140 | if __name__ == '__main__': 141 | 
parser.add_argument('-n', '--model_name', type=str, default="TransformerFlow", help='Model folder name (create a checkpoints folder under the project root, make a new folder inside it and put the model there; that inner folder name is this option)') 142 | args = parser.parse_args() 143 | path = args.model_name 144 | OnnxExport(path) 145 | -------------------------------------------------------------------------------- /onnx_export_old.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utils 4 | from onnxexport.model_onnx import SynthesizerTrn 5 | 6 | 7 | def main(NetExport): 8 | path = "SoVits4.0" 9 | if NetExport: 10 | device = torch.device("cpu") 11 | hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json") 12 | SVCVITS = SynthesizerTrn( 13 | hps.data.filter_length // 2 + 1, 14 | hps.train.segment_size // hps.data.hop_length, 15 | **hps.model) 16 | _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", SVCVITS, None) 17 | _ = SVCVITS.eval().to(device) 18 | for i in SVCVITS.parameters(): 19 | i.requires_grad = False 20 | 21 | n_frame = 10 22 | test_hidden_unit = torch.rand(1, n_frame, 256) 23 | test_pitch = torch.rand(1, n_frame) 24 | test_mel2ph = torch.arange(0, n_frame, dtype=torch.int64)[None] # torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0) 25 | test_uv = torch.ones(1, n_frame, dtype=torch.float32) 26 | test_noise = torch.randn(1, 192, n_frame) 27 | test_sid = torch.LongTensor([0]) 28 | input_names = ["c", "f0", "mel2ph", "uv", "noise", "sid"] 29 | output_names = ["audio", ] 30 | 31 | torch.onnx.export(SVCVITS, 32 | ( 33 | test_hidden_unit.to(device), 34 | test_pitch.to(device), 35 | test_mel2ph.to(device), 36 | test_uv.to(device), 37 | test_noise.to(device), 38 | test_sid.to(device) 39 | ), 40 | f"checkpoints/{path}/model.onnx", 41 | dynamic_axes={ 42 | "c": [0, 1], 43 | "f0": [1], 44 | "mel2ph": [1], 45 | "uv": [1], 46 | "noise": [2], 47 | }, 48 | do_constant_folding=False, 49 | opset_version=16, 50 | verbose=False, 51 | input_names=input_names, 52 | output_names=output_names) 53 | 54 | 55 | if __name__ == '__main__': 56 | main(True) 57 | -------------------------------------------------------------------------------- /preprocess_flist_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import wave 6 | from random import shuffle 7 | 8 | from loguru import logger 9 | from tqdm import tqdm 10 | 11 | import diffusion.logger.utils as du 12 | 13 | pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$') 14 | 15 | def get_wav_duration(file_path): 16 | try: 17 | with wave.open(file_path, 'rb') as wav_file: 18 | # number of audio frames 19 | n_frames = wav_file.getnframes() 20 | # sampling rate 21 | framerate = wav_file.getframerate() 22 | # duration in seconds 23 | return n_frames / float(framerate) 24 | except Exception as e: 25 | logger.error(f"Failed to read {file_path}") 26 | raise e 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list") 31 | parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list") 32 | parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir") 33 | parser.add_argument("--speech_encoder", type=str, default="vec768l12", help="choose a speech encoder|'vec768l12','vec256l9','hubertsoft','whisper-ppg','cnhubertlarge','dphubert','whisper-ppg-large','wavlmbase+'") 34 | parser.add_argument("--vol_aug",
action="store_true", help="Whether to use volume embedding and volume augmentation") 35 | parser.add_argument("--tiny", action="store_true", help="Whether to train sovits tiny") 36 | args = parser.parse_args() 37 | 38 | config_template = json.load(open("configs_template/config_tiny_template.json")) if args.tiny else json.load(open("configs_template/config_template.json")) 39 | train = [] 40 | val = [] 41 | idx = 0 42 | spk_dict = {} 43 | spk_id = 0 44 | 45 | for speaker in tqdm(os.listdir(args.source_dir)): 46 | spk_dict[speaker] = spk_id 47 | spk_id += 1 48 | wavs = [] 49 | 50 | for file_name in os.listdir(os.path.join(args.source_dir, speaker)): 51 | if not file_name.endswith("wav"): 52 | continue 53 | if file_name.startswith("."): 54 | continue 55 | 56 | file_path = "/".join([args.source_dir, speaker, file_name]) 57 | 58 | if not pattern.match(file_name): 59 | logger.warning("Detected non-ASCII file name: " + file_path) 60 | 61 | if get_wav_duration(file_path) < 0.3: 62 | logger.info("Skip too short audio: " + file_path) 63 | continue 64 | 65 | wavs.append(file_path) 66 | 67 | shuffle(wavs) 68 | train += wavs[2:] 69 | val += wavs[:2] 70 | 71 | shuffle(train) 72 | shuffle(val) 73 | 74 | logger.info("Writing " + args.train_list) 75 | with open(args.train_list, "w") as f: 76 | for fname in tqdm(train): 77 | wavpath = fname 78 | f.write(wavpath + "\n") 79 | 80 | logger.info("Writing " + args.val_list) 81 | with open(args.val_list, "w") as f: 82 | for fname in tqdm(val): 83 | wavpath = fname 84 | f.write(wavpath + "\n") 85 | 86 | 87 | d_config_template = du.load_config("configs_template/diffusion_template.yaml") 88 | d_config_template["model"]["n_spk"] = spk_id 89 | d_config_template["data"]["encoder"] = args.speech_encoder 90 | d_config_template["spk"] = spk_dict 91 | 92 | config_template["spk"] = spk_dict 93 | config_template["model"]["n_speakers"] = spk_id 94 | config_template["model"]["speech_encoder"] = args.speech_encoder 95 | 96 | if args.speech_encoder == "vec768l12" or args.speech_encoder == "dphubert" or args.speech_encoder == "wavlmbase+": 97 | config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768 98 | d_config_template["data"]["encoder_out_channels"] = 768 99 | elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft': 100 | config_template["model"]["ssl_dim"] = config_template["model"]["gin_channels"] = 256 101 | d_config_template["data"]["encoder_out_channels"] = 256 102 | elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge': 103 | config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024 104 | d_config_template["data"]["encoder_out_channels"] = 1024 105 | elif args.speech_encoder == "whisper-ppg-large": 106 | config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1280 107 | d_config_template["data"]["encoder_out_channels"] = 1280 108 | 109 | if args.vol_aug: 110 | config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True 111 | 112 | if args.tiny: 113 | config_template["model"]["filter_channels"] = 512 114 | 115 | logger.info("Writing to configs/config.json") 116 | with open("configs/config.json", "w") as f: 117 | json.dump(config_template, f, indent=2) 118 | logger.info("Writing to configs/diffusion.yaml") 119 | du.save_config("configs/diffusion.yaml",d_config_template) 120 | 
-------------------------------------------------------------------------------- /pretrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/pretrain/__init__.py -------------------------------------------------------------------------------- /pretrain/meta.py: -------------------------------------------------------------------------------- 1 | def download_dict(): 2 | return { 3 | "vec768l12": { 4 | "url": "https://ibm.ent.box.com/shared/static/z1wgl1stco8ffooyatzdwsqn2psd9lrr", 5 | "output": "./pretrain/checkpoint_best_legacy_500.pt" 6 | }, 7 | "vec256l9": { 8 | "url": "https://ibm.ent.box.com/shared/static/z1wgl1stco8ffooyatzdwsqn2psd9lrr", 9 | "output": "./pretrain/checkpoint_best_legacy_500.pt" 10 | }, 11 | "hubertsoft": { 12 | "url": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt", 13 | "output": "./pretrain/hubert-soft-0d54a1f4.pt" 14 | }, 15 | "whisper-ppg-small": { 16 | "url": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 17 | "output": "./pretrain/small.pt" 18 | }, 19 | "whisper-ppg": { 20 | "url": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 21 | "output": "./pretrain/medium.pt" 22 | }, 23 | "whisper-ppg-large": { 24 | "url": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 25 | "output": "./pretrain/large-v2.pt" 26 | } 27 | } 28 | 29 | 30 | def get_speech_encoder(config_path="configs/config.json"): 31 | import json 32 | 33 | with open(config_path, "r") as f: 34 | data = f.read() 35 | config = json.loads(data) 36 | speech_encoder = config["model"]["speech_encoder"] 37 | dict = download_dict() 38 | 39 | return dict[speech_encoder]["url"], dict[speech_encoder]["output"] 40 | -------------------------------------------------------------------------------- /pretrain/nsf_hifigan/put_nsf_hifigan_ckpt_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/pretrain/nsf_hifigan/put_nsf_hifigan_ckpt_here -------------------------------------------------------------------------------- /pretrain/put_hubert_ckpt_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/pretrain/put_hubert_ckpt_here -------------------------------------------------------------------------------- /raw/put_raw_wav_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/raw/put_raw_wav_here -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python 2 | Flask 3 | Flask_Cors 4 | gradio>=3.7.0 5 | numpy==1.23.5 6 | pyworld 7 | scipy==1.10.0 8 | SoundFile==0.12.1 9 | torch 10 | torchaudio 11 | torchcrepe 12 | tqdm 13 | rich 14 | loguru 15 | scikit-maad 16 | praat-parselmouth 17 | onnx 18 | onnxsim 19 | onnxoptimizer 20 | fairseq==0.12.2 21 | 
librosa==0.9.1 22 | tensorboard 23 | tensorboardX 24 | transformers 25 | edge_tts 26 | langdetect 27 | pyyaml 28 | pynvml 29 | faiss-cpu 30 | einops 31 | local_attention -------------------------------------------------------------------------------- /requirements_onnx_encoder.txt: -------------------------------------------------------------------------------- 1 | Flask 2 | Flask_Cors 3 | gradio>=3.7.0 4 | numpy==1.23.0 5 | pyworld==0.2.5 6 | scipy==1.10.0 7 | SoundFile==0.12.1 8 | torch==1.13.1 9 | torchaudio==0.13.1 10 | torchcrepe 11 | tqdm 12 | rich.progress 13 | loguru 14 | scikit-maad 15 | praat-parselmouth 16 | onnx 17 | onnxsim 18 | onnxoptimizer 19 | onnxruntime-gpu 20 | librosa==0.9.1 21 | tensorboard 22 | tensorboardX 23 | edge_tts 24 | langdetect 25 | pyyaml 26 | pynvml 27 | transformers 28 | ffmpeg-python 29 | faiss-cpu -------------------------------------------------------------------------------- /requirements_win.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.1 2 | fairseq==0.12.2 3 | ffmpeg-python 4 | Flask==2.1.2 5 | Flask_Cors==3.0.10 6 | gradio>=3.7.0 7 | numpy 8 | playsound==1.3.0 9 | PyAudio==0.2.12 10 | pydub==0.25.1 11 | pyworld==0.3.0 12 | requests==2.28.1 13 | scipy==1.7.3 14 | sounddevice==0.4.5 15 | SoundFile==0.10.3.post1 16 | starlette==0.19.1 17 | tqdm==4.63.0 18 | rich 19 | loguru 20 | torchcrepe 21 | scikit-maad 22 | praat-parselmouth 23 | onnx 24 | onnxsim 25 | onnxoptimizer 26 | tensorboard 27 | tensorboardX 28 | transformers 29 | edge_tts 30 | langdetect 31 | pyyaml 32 | pynvml 33 | faiss-cpu 34 | -------------------------------------------------------------------------------- /resample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import os 4 | from concurrent.futures import ProcessPoolExecutor 5 | from multiprocessing import cpu_count 6 | 7 | import librosa 8 | import numpy as np 9 | from rich.progress import track 10 | from scipy.io import wavfile 11 | 12 | 13 | def load_wav(wav_path): 14 | return librosa.load(wav_path, sr=None) 15 | 16 | 17 | def trim_wav(wav, top_db=40): 18 | return librosa.effects.trim(wav, top_db=top_db) 19 | 20 | 21 | def normalize_peak(wav, threshold=1.0): 22 | peak = np.abs(wav).max() 23 | if peak > threshold: 24 | wav = 0.98 * wav / peak 25 | return wav 26 | 27 | 28 | def resample_wav(wav, sr, target_sr): 29 | return librosa.resample(wav, orig_sr=sr, target_sr=target_sr) 30 | 31 | 32 | def save_wav_to_path(wav, save_path, sr): 33 | wavfile.write( 34 | save_path, 35 | sr, 36 | (wav * np.iinfo(np.int16).max).astype(np.int16) 37 | ) 38 | 39 | 40 | def process(item): 41 | spkdir, wav_name, args = item 42 | speaker = spkdir.replace("\\", "/").split("/")[-1] 43 | 44 | wav_path = os.path.join(args.in_dir, speaker, wav_name) 45 | if os.path.exists(wav_path) and '.wav' in wav_path: 46 | os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True) 47 | 48 | wav, sr = load_wav(wav_path) 49 | wav, _ = trim_wav(wav) 50 | wav = normalize_peak(wav) 51 | resampled_wav = resample_wav(wav, sr, args.sr2) 52 | 53 | if not args.skip_loudnorm: 54 | resampled_wav /= np.max(np.abs(resampled_wav)) 55 | 56 | save_path2 = os.path.join(args.out_dir2, speaker, wav_name) 57 | save_wav_to_path(resampled_wav, save_path2, args.sr2) 58 | 59 | 60 | """ 61 | def process_all_speakers(): 62 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) 63 | 64 | with 
ThreadPoolExecutor(max_workers=process_count) as executor: 65 | for speaker in speakers: 66 | spk_dir = os.path.join(args.in_dir, speaker) 67 | if os.path.isdir(spk_dir): 68 | print(spk_dir) 69 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] 70 | for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 71 | pass 72 | """ 73 | # multi process 74 | 75 | 76 | def process_all_speakers(): 77 | process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1) 78 | with ProcessPoolExecutor(max_workers=process_count) as executor: 79 | for speaker in speakers: 80 | spk_dir = os.path.join(args.in_dir, speaker) 81 | if os.path.isdir(spk_dir): 82 | print(spk_dir) 83 | futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")] 84 | for _ in track(concurrent.futures.as_completed(futures), total=len(futures), description="resampling:"): 85 | pass 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--sr2", type=int, default=44100, help="sampling rate") 91 | parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir") 92 | parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir") 93 | parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it") 94 | args = parser.parse_args() 95 | 96 | print(f"CPU count: {cpu_count()}") 97 | speakers = os.listdir(args.in_dir) 98 | process_all_speakers() 99 | -------------------------------------------------------------------------------- /shadowdiffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/shadowdiffusion.png -------------------------------------------------------------------------------- /spkmix.py: -------------------------------------------------------------------------------- 1 | # Speaker mixing tracks, rules: 2 | # speaker_id : [[start_time_1, end_time_1, start_value_1, end_value_1], [start_time_2, end_time_2, start_value_2, end_value_2]] 3 | # Each start time must equal the previous segment's end time; the first start time must be 0 and the last end time must be 1 (times are in the range 0-1) 4 | # Every speaker must be filled in; for unused speakers just use [[0., 1., 0., 0.]] 5 | # Mixing values can be set freely; within each segment the value changes linearly from the start value to the end value, and the linear combination is normalized to sum to 1 internally, so it is safe to use 6 | 7 | spk_mix_map = { 8 | 0 : [[0., 0.5, 1, 0.5], [0.5, 1, 0.5, 1]], 9 | 1 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]], 10 | 2 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]] 11 | } -------------------------------------------------------------------------------- /train_diff.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from loguru import logger 5 | from torch.optim import lr_scheduler 6 | 7 | from diffusion.data_loaders import get_data_loaders 8 | from diffusion.logger import utils 9 | from diffusion.solver import train 10 | from diffusion.unit2mel import Unit2Mel 11 | from diffusion.vocoder import Vocoder 12 | 13 | 14 | def parse_args(args=None, namespace=None): 15 | """Parse command-line arguments.""" 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "-c", 19 | "--config", 20 | type=str, 21 | required=True, 22 | help="path to the config file") 23 | return parser.parse_args(args=args, namespace=namespace) 24 | 25 | 26 | if __name__ == '__main__': 27 | # parse commands 28 | cmd = parse_args() 29 | 30 | # load config 31
| args = utils.load_config(cmd.config) 32 | logger.info(' > config:'+ cmd.config) 33 | logger.info(' > exp:'+ args.env.expdir) 34 | 35 | # load vocoder 36 | vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device) 37 | 38 | # load model 39 | model = Unit2Mel( 40 | args.data.encoder_out_channels, 41 | args.model.n_spk, 42 | args.model.use_pitch_aug, 43 | vocoder.dimension, 44 | args.model.n_layers, 45 | args.model.n_chans, 46 | args.model.n_hidden, 47 | args.model.timesteps, 48 | args.model.k_step_max 49 | ) 50 | 51 | logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}') 52 | 53 | # load parameters 54 | optimizer = torch.optim.AdamW(model.parameters()) 55 | initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device) 56 | for param_group in optimizer.param_groups: 57 | param_group['initial_lr'] = args.train.lr 58 | param_group['lr'] = args.train.lr * (args.train.gamma ** max(((initial_global_step-2)//args.train.decay_step),0) ) 59 | param_group['weight_decay'] = args.train.weight_decay 60 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma,last_epoch=initial_global_step-2) 61 | 62 | # device 63 | if args.device == 'cuda': 64 | torch.cuda.set_device(args.env.gpu_id) 65 | model.to(args.device) 66 | 67 | for state in optimizer.state.values(): 68 | for k, v in state.items(): 69 | if torch.is_tensor(v): 70 | state[k] = v.to(args.device) 71 | 72 | # datas 73 | loader_train, loader_valid = get_data_loaders(args, whole_audio=False) 74 | 75 | # run 76 | train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid) 77 | 78 | -------------------------------------------------------------------------------- /train_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | import utils 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--root_dir", type=str, default="dataset/44k", help="path to root dir" 11 | ) 12 | parser.add_argument('-c', '--config', type=str, default="./configs/config.json", 13 | help='JSON file for configuration') 14 | parser.add_argument( 15 | "--output_dir", type=str, default="logs/44k", help="path to output dir" 16 | ) 17 | 18 | args = parser.parse_args() 19 | 20 | hps = utils.get_hparams_from_file(args.config) 21 | spk_dic = hps.spk 22 | result = {} 23 | 24 | for k,v in spk_dic.items(): 25 | print(f"now, index {k} feature...") 26 | index = utils.train_index(k,args.root_dir) 27 | result[v] = index 28 | 29 | with open(os.path.join(args.output_dir,"feature_and_index.pkl"),"wb") as f: 30 | pickle.dump(result,f) -------------------------------------------------------------------------------- /trained/put_trained_checkpoints_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/trained/put_trained_checkpoints_here -------------------------------------------------------------------------------- /vdecoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vdecoder/__init__.py -------------------------------------------------------------------------------- 
/vdecoder/hifigan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /vdecoder/hifigan/nvSTFT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | import torch 7 | import torch.utils.data 8 | from librosa.filters import mel as librosa_mel_fn 9 | 10 | os.environ["LRU_CACHE_CAPACITY"] = "3" 11 | 12 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): 13 | sampling_rate = None 14 | try: 15 | data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 16 | except Exception as ex: 17 | print(f"'{full_path}' failed to load.\nException:") 18 | print(ex) 19 | if return_empty_on_exception: 20 | return [], sampling_rate or target_sr or 32000 21 | else: 22 | raise Exception(ex) 23 | 24 | if len(data.shape) > 1: 25 | data = data[:, 0] 26 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) 27 | 28 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int 29 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX 30 | else: # if audio data is type fp32 31 | max_mag = max(np.amax(data), -np.amin(data)) 32 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 33 | 34 | data = torch.FloatTensor(data.astype(np.float32))/max_mag 35 | 36 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. 
return_empty_on_exception will return empty arr instead of except 37 | return [], sampling_rate or target_sr or 32000 38 | if target_sr is not None and sampling_rate != target_sr: 39 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) 40 | sampling_rate = target_sr 41 | 42 | return data, sampling_rate 43 | 44 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 45 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 46 | 47 | def dynamic_range_decompression(x, C=1): 48 | return np.exp(x) / C 49 | 50 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 51 | return torch.log(torch.clamp(x, min=clip_val) * C) 52 | 53 | def dynamic_range_decompression_torch(x, C=1): 54 | return torch.exp(x) / C 55 | 56 | class STFT(): 57 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): 58 | self.target_sr = sr 59 | 60 | self.n_mels = n_mels 61 | self.n_fft = n_fft 62 | self.win_size = win_size 63 | self.hop_length = hop_length 64 | self.fmin = fmin 65 | self.fmax = fmax 66 | self.clip_val = clip_val 67 | self.mel_basis = {} 68 | self.hann_window = {} 69 | 70 | def get_mel(self, y, center=False): 71 | sampling_rate = self.target_sr 72 | n_mels = self.n_mels 73 | n_fft = self.n_fft 74 | win_size = self.win_size 75 | hop_length = self.hop_length 76 | fmin = self.fmin 77 | fmax = self.fmax 78 | clip_val = self.clip_val 79 | 80 | if torch.min(y) < -1.: 81 | print('min value is ', torch.min(y)) 82 | if torch.max(y) > 1.: 83 | print('max value is ', torch.max(y)) 84 | 85 | if fmax not in self.mel_basis: 86 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 87 | self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 88 | self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) 89 | 90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') 91 | y = y.squeeze(1) 92 | 93 | spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], 94 | center=center, pad_mode='reflect', normalized=False, onesided=True) 95 | # print(111,spec) 96 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 97 | # print(222,spec) 98 | spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) 99 | # print(333,spec) 100 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 101 | # print(444,spec) 102 | return spec 103 | 104 | def __call__(self, audiopath): 105 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) 106 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) 107 | return spect 108 | 109 | stft = STFT() 110 | -------------------------------------------------------------------------------- /vdecoder/hifigan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | # matplotlib.use("Agg") 5 | import matplotlib.pylab as plt 6 | import torch 7 | from torch.nn.utils import weight_norm 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 
25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def del_old_checkpoints(cp_dir, prefix, n_models=2): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) # get checkpoint paths 55 | cp_list = sorted(cp_list)# sort by iter 56 | if len(cp_list) > n_models: # if more than n_models models are found 57 | for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models 58 | open(cp, 'w').close()# empty file contents 59 | os.unlink(cp)# delete file (move to trash when using Colab) 60 | 61 | 62 | def scan_checkpoint(cp_dir, prefix): 63 | pattern = os.path.join(cp_dir, prefix + '????????') 64 | cp_list = glob.glob(pattern) 65 | if len(cp_list) == 0: 66 | return None 67 | return sorted(cp_list)[-1] 68 | 69 | -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/alias/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .act import * # noqa: F403 5 | from .filter import * # noqa: F403 6 | from .resample import * # noqa: F403 7 | -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/alias/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import pow, sin 8 | from torch.nn import Parameter 9 | 10 | from .resample import DownSample1d, UpSample1d 11 | 12 | 13 | class Activation1d(nn.Module): 14 | def __init__(self, 15 | activation, 16 | up_ratio: int = 2, 17 | down_ratio: int = 2, 18 | up_kernel_size: int = 12, 19 | down_kernel_size: int = 12): 20 | super().__init__() 21 | self.up_ratio = up_ratio 22 | self.down_ratio = down_ratio 23 | self.act = activation 24 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 25 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 26 | 27 | # x: [B,C,T] 28 | def forward(self, x): 29 | x = self.upsample(x) 30 | x = self.act(x) 31 | x = self.downsample(x) 32 | 33 | return x 34 | 35 | 36 | class SnakeBeta(nn.Module): 37 | ''' 38 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 39 | Shape: 40 | - Input: (B, C, T) 41 | - Output: (B, C, T), same shape as the input 42 | Parameters: 43 | - alpha - trainable parameter that controls frequency 44 | - beta - trainable parameter that controls magnitude 45 | References: 46 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 47 | https://arxiv.org/abs/2006.08195 48 | Examples: 49 | >>> a1 = snakebeta(256) 50 | >>> x = torch.randn(256) 51 | >>> x = a1(x) 52 | ''' 53 | 54 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 55 | ''' 56 | Initialization. 57 | INPUT: 58 | - in_features: shape of the input 59 | - alpha - trainable parameter that controls frequency 60 | - beta - trainable parameter that controls magnitude 61 | alpha is initialized to 1 by default, higher values = higher-frequency. 62 | beta is initialized to 1 by default, higher values = higher-magnitude. 63 | alpha will be trained along with the rest of your model. 64 | ''' 65 | super(SnakeBeta, self).__init__() 66 | self.in_features = in_features 67 | # initialize alpha 68 | self.alpha_logscale = alpha_logscale 69 | if self.alpha_logscale: # log scale alphas initialized to zeros 70 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 71 | self.beta = Parameter(torch.zeros(in_features) * alpha) 72 | else: # linear scale alphas initialized to ones 73 | self.alpha = Parameter(torch.ones(in_features) * alpha) 74 | self.beta = Parameter(torch.ones(in_features) * alpha) 75 | self.alpha.requires_grad = alpha_trainable 76 | self.beta.requires_grad = alpha_trainable 77 | self.no_div_by_zero = 0.000000001 78 | 79 | def forward(self, x): 80 | ''' 81 | Forward pass of the function. 82 | Applies the function to the input elementwise. 83 | SnakeBeta = x + 1/b * sin^2 (xa) 84 | ''' 85 | alpha = self.alpha.unsqueeze( 86 | 0).unsqueeze(-1) # line up with x to [B, C, T] 87 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 88 | if self.alpha_logscale: 89 | alpha = torch.exp(alpha) 90 | beta = torch.exp(beta) 91 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 92 | return x 93 | 94 | 95 | class Mish(nn.Module): 96 | """ 97 | Mish activation function is proposed in "Mish: A Self 98 | Regularized Non-Monotonic Neural Activation Function" 99 | paper, https://arxiv.org/abs/1908.08681. 
100 | """ 101 | 102 | def __init__(self): 103 | super().__init__() 104 | 105 | def forward(self, x): 106 | return x * torch.tanh(F.softplus(x)) 107 | 108 | 109 | class SnakeAlias(nn.Module): 110 | def __init__(self, 111 | channels, 112 | up_ratio: int = 2, 113 | down_ratio: int = 2, 114 | up_kernel_size: int = 12, 115 | down_kernel_size: int = 12, 116 | C = None): 117 | super().__init__() 118 | self.up_ratio = up_ratio 119 | self.down_ratio = down_ratio 120 | self.act = SnakeBeta(channels, alpha_logscale=True) 121 | self.upsample = UpSample1d(up_ratio, up_kernel_size, C) 122 | self.downsample = DownSample1d(down_ratio, down_kernel_size, C) 123 | 124 | # x: [B,C,T] 125 | def forward(self, x, C=None): 126 | x = self.upsample(x, C) 127 | x = self.act(x) 128 | x = self.downsample(x) 129 | 130 | return x -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/alias/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | if 'sinc' in dir(torch): 11 | sinc = torch.sinc 12 | else: 13 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 14 | # https://adefossez.github.io/julius/julius/core.html 15 | # LICENSE is in incl_licenses directory. 16 | def sinc(x: torch.Tensor): 17 | """ 18 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 19 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 20 | """ 21 | return torch.where(x == 0, 22 | torch.tensor(1., device=x.device, dtype=x.dtype), 23 | torch.sin(math.pi * x) / math.pi / x) 24 | 25 | 26 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 27 | # https://adefossez.github.io/julius/julius/lowpass.html 28 | # LICENSE is in incl_licenses directory. 29 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 30 | even = (kernel_size % 2 == 0) 31 | half_size = kernel_size // 2 32 | 33 | #For kaiser window 34 | delta_f = 4 * half_width 35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 36 | if A > 50.: 37 | beta = 0.1102 * (A - 8.7) 38 | elif A >= 21.: 39 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 40 | else: 41 | beta = 0. 42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 43 | 44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 45 | if even: 46 | time = (torch.arange(-half_size, half_size) + 0.5) 47 | else: 48 | time = torch.arange(kernel_size) - half_size 49 | if cutoff == 0: 50 | filter_ = torch.zeros_like(time) 51 | else: 52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 54 | # of the constant component in the input signal. 55 | filter_ /= filter_.sum() 56 | filter = filter_.view(1, 1, kernel_size) 57 | 58 | return filter 59 | 60 | 61 | class LowPassFilter1d(nn.Module): 62 | def __init__(self, 63 | cutoff=0.5, 64 | half_width=0.6, 65 | stride: int = 1, 66 | padding: bool = True, 67 | padding_mode: str = 'replicate', 68 | kernel_size: int = 12, 69 | C=None): 70 | # kernel_size should be even number for stylegan3 setup, 71 | # in this implementation, odd number is also possible. 
72 | super().__init__() 73 | if cutoff < -0.: 74 | raise ValueError("Minimum cutoff must be larger than zero.") 75 | if cutoff > 0.5: 76 | raise ValueError("A cutoff above 0.5 does not make sense.") 77 | self.kernel_size = kernel_size 78 | self.even = (kernel_size % 2 == 0) 79 | self.pad_left = kernel_size // 2 - int(self.even) 80 | self.pad_right = kernel_size // 2 81 | self.stride = stride 82 | self.padding = padding 83 | self.padding_mode = padding_mode 84 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 85 | self.register_buffer("filter", filter) 86 | self.conv1d_block = None 87 | if C is not None: 88 | self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),] 89 | self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1)) 90 | self.conv1d_block[0].requires_grad_(False) 91 | 92 | #input [B, C, T] 93 | def forward(self, x): 94 | if self.conv1d_block[0].weight.device != x.device: 95 | self.conv1d_block[0] = self.conv1d_block[0].to(x.device) 96 | if self.conv1d_block is None: 97 | _, C, _ = x.shape 98 | 99 | if self.padding: 100 | x = F.pad(x, (self.pad_left, self.pad_right), 101 | mode=self.padding_mode) 102 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 103 | stride=self.stride, groups=C) 104 | else: 105 | if self.padding: 106 | x = F.pad(x, (self.pad_left, self.pad_right), 107 | mode=self.padding_mode) 108 | out = self.conv1d_block[0](x) 109 | 110 | return out -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/alias/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .filter import LowPassFilter1d, kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None, C=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | self.conv_transpose1d_block = None 24 | if C is not None: 25 | self.conv_transpose1d_block = [nn.ConvTranspose1d(C, 26 | C, 27 | kernel_size=self.kernel_size, 28 | stride=self.stride, 29 | groups=C, 30 | bias=False 31 | ),] 32 | self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone()) 33 | self.conv_transpose1d_block[0].requires_grad_(False) 34 | 35 | 36 | 37 | # x: [B, C, T] 38 | def forward(self, x, C=None): 39 | if self.conv_transpose1d_block[0].weight.device != x.device: 40 | self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device) 41 | if self.conv_transpose1d_block is None: 42 | if C is None: 43 | _, C, _ = x.shape 44 | # print("snake.conv_t.in:",x.shape) 45 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 46 | x = self.ratio * F.conv_transpose1d( 47 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 48 | # print("snake.conv_t.out:",x.shape) 49 | x = x[..., self.pad_left:-self.pad_right] 50 | else: 51 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 52 | x = self.ratio * self.conv_transpose1d_block[0](x) 53 | x = x[..., self.pad_left:-self.pad_right] 54 | return x 55 | 56 | 57 | class DownSample1d(nn.Module): 58 | def __init__(self, ratio=2, kernel_size=None, C=None): 59 | super().__init__() 60 | self.ratio = ratio 61 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 62 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 63 | half_width=0.6 / ratio, 64 | stride=ratio, 65 | kernel_size=self.kernel_size, 66 | C=C) 67 | 68 | 69 | def forward(self, x): 70 | xx = self.lowpass(x) 71 | 72 | return xx -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/nvSTFT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | import torch 7 | import torch.utils.data 8 | from librosa.filters import mel as librosa_mel_fn 9 | 10 | os.environ["LRU_CACHE_CAPACITY"] = "3" 11 | 12 | def load_wav_to_torch(full_path, target_sr=None, 
return_empty_on_exception=False): 13 | sampling_rate = None 14 | try: 15 | data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 16 | except Exception as ex: 17 | print(f"'{full_path}' failed to load.\nException:") 18 | print(ex) 19 | if return_empty_on_exception: 20 | return [], sampling_rate or target_sr or 32000 21 | else: 22 | raise Exception(ex) 23 | 24 | if len(data.shape) > 1: 25 | data = data[:, 0] 26 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) 27 | 28 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int 29 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX 30 | else: # if audio data is type fp32 31 | max_mag = max(np.amax(data), -np.amin(data)) 32 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 33 | 34 | data = torch.FloatTensor(data.astype(np.float32))/max_mag 35 | 36 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except 37 | return [], sampling_rate or target_sr or 32000 38 | if target_sr is not None and sampling_rate != target_sr: 39 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) 40 | sampling_rate = target_sr 41 | 42 | return data, sampling_rate 43 | 44 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 45 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 46 | 47 | def dynamic_range_decompression(x, C=1): 48 | return np.exp(x) / C 49 | 50 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 51 | return torch.log(torch.clamp(x, min=clip_val) * C) 52 | 53 | def dynamic_range_decompression_torch(x, C=1): 54 | return torch.exp(x) / C 55 | 56 | class STFT(): 57 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): 58 | self.target_sr = sr 59 | 60 | self.n_mels = n_mels 61 | self.n_fft = n_fft 62 | self.win_size = win_size 63 | self.hop_length = hop_length 64 | self.fmin = fmin 65 | self.fmax = fmax 66 | self.clip_val = clip_val 67 | self.mel_basis = {} 68 | self.hann_window = {} 69 | 70 | def get_mel(self, y, center=False): 71 | sampling_rate = self.target_sr 72 | n_mels = self.n_mels 73 | n_fft = self.n_fft 74 | win_size = self.win_size 75 | hop_length = self.hop_length 76 | fmin = self.fmin 77 | fmax = self.fmax 78 | clip_val = self.clip_val 79 | 80 | if torch.min(y) < -1.: 81 | print('min value is ', torch.min(y)) 82 | if torch.max(y) > 1.: 83 | print('max value is ', torch.max(y)) 84 | 85 | if fmax not in self.mel_basis: 86 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 87 | self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 88 | self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) 89 | 90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') 91 | y = y.squeeze(1) 92 | 93 | spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], 94 | center=center, pad_mode='reflect', normalized=False, onesided=True) 95 | # print(111,spec) 96 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 97 | # print(222,spec) 98 | 
spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) 99 | # print(333,spec) 100 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 101 | # print(444,spec) 102 | return spec 103 | 104 | def __call__(self, audiopath): 105 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) 106 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) 107 | return spect 108 | 109 | stft = STFT() 110 | -------------------------------------------------------------------------------- /vdecoder/hifiganwithsnake/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | # matplotlib.use("Agg") 5 | import matplotlib.pylab as plt 6 | import torch 7 | from torch.nn.utils import weight_norm 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def del_old_checkpoints(cp_dir, prefix, n_models=2): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) # get checkpoint paths 55 | cp_list = sorted(cp_list)# sort by iter 56 | if len(cp_list) > n_models: # if more than n_models models are found 57 | for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models 58 | open(cp, 'w').close()# empty file contents 59 | os.unlink(cp)# delete file (move to trash when using Colab) 60 | 61 | 62 | def scan_checkpoint(cp_dir, prefix): 63 | pattern = os.path.join(cp_dir, prefix + '????????') 64 | cp_list = glob.glob(pattern) 65 | if len(cp_list) == 0: 66 | return None 67 | return sorted(cp_list)[-1] 68 | 69 | -------------------------------------------------------------------------------- /vdecoder/nsf_hifigan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /vdecoder/nsf_hifigan/nvSTFT.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | import torch 7 | import 
torch.nn.functional as F 8 | import torch.utils.data 9 | from librosa.filters import mel as librosa_mel_fn 10 | 11 | os.environ["LRU_CACHE_CAPACITY"] = "3" 12 | 13 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): 14 | sampling_rate = None 15 | try: 16 | data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 17 | except Exception as ex: 18 | print(f"'{full_path}' failed to load.\nException:") 19 | print(ex) 20 | if return_empty_on_exception: 21 | return [], sampling_rate or target_sr or 48000 22 | else: 23 | raise Exception(ex) 24 | 25 | if len(data.shape) > 1: 26 | data = data[:, 0] 27 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) 28 | 29 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int 30 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX 31 | else: # if audio data is type fp32 32 | max_mag = max(np.amax(data), -np.amin(data)) 33 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 34 | 35 | data = torch.FloatTensor(data.astype(np.float32))/max_mag 36 | 37 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except 38 | return [], sampling_rate or target_sr or 48000 39 | if target_sr is not None and sampling_rate != target_sr: 40 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) 41 | sampling_rate = target_sr 42 | 43 | return data, sampling_rate 44 | 45 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 46 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 47 | 48 | def dynamic_range_decompression(x, C=1): 49 | return np.exp(x) / C 50 | 51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 52 | return torch.log(torch.clamp(x, min=clip_val) * C) 53 | 54 | def dynamic_range_decompression_torch(x, C=1): 55 | return torch.exp(x) / C 56 | 57 | class STFT(): 58 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): 59 | self.target_sr = sr 60 | 61 | self.n_mels = n_mels 62 | self.n_fft = n_fft 63 | self.win_size = win_size 64 | self.hop_length = hop_length 65 | self.fmin = fmin 66 | self.fmax = fmax 67 | self.clip_val = clip_val 68 | self.mel_basis = {} 69 | self.hann_window = {} 70 | 71 | def get_mel(self, y, keyshift=0, speed=1, center=False): 72 | sampling_rate = self.target_sr 73 | n_mels = self.n_mels 74 | n_fft = self.n_fft 75 | win_size = self.win_size 76 | hop_length = self.hop_length 77 | fmin = self.fmin 78 | fmax = self.fmax 79 | clip_val = self.clip_val 80 | 81 | factor = 2 ** (keyshift / 12) 82 | n_fft_new = int(np.round(n_fft * factor)) 83 | win_size_new = int(np.round(win_size * factor)) 84 | hop_length_new = int(np.round(hop_length * speed)) 85 | 86 | if torch.min(y) < -1.: 87 | print('min value is ', torch.min(y)) 88 | if torch.max(y) > 1.: 89 | print('max value is ', torch.max(y)) 90 | 91 | mel_basis_key = str(fmax)+'_'+str(y.device) 92 | if mel_basis_key not in self.mel_basis: 93 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 94 | self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) 95 | 96 | keyshift_key = 
str(keyshift)+'_'+str(y.device) 97 | if keyshift_key not in self.hann_window: 98 | self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) 99 | 100 | pad_left = (win_size_new - hop_length_new) //2 101 | pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) 102 | if pad_right < y.size(-1): 103 | mode = 'reflect' 104 | else: 105 | mode = 'constant' 106 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) 107 | y = y.squeeze(1) 108 | 109 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key], 110 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 111 | # print(111,spec) 112 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 113 | if keyshift != 0: 114 | size = n_fft // 2 + 1 115 | resize = spec.size(1) 116 | if resize < size: 117 | spec = F.pad(spec, (0, 0, 0, size-resize)) 118 | spec = spec[:, :size, :] * win_size / win_size_new 119 | 120 | # print(222,spec) 121 | spec = torch.matmul(self.mel_basis[mel_basis_key], spec) 122 | # print(333,spec) 123 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 124 | # print(444,spec) 125 | return spec 126 | 127 | def __call__(self, audiopath): 128 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) 129 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) 130 | return spect 131 | 132 | stft = STFT() 133 | -------------------------------------------------------------------------------- /vdecoder/nsf_hifigan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import matplotlib 5 | import matplotlib.pylab as plt 6 | import torch 7 | from torch.nn.utils import weight_norm 8 | 9 | matplotlib.use("Agg") 10 | 11 | 12 | def plot_spectrogram(spectrogram): 13 | fig, ax = plt.subplots(figsize=(10, 2)) 14 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 15 | interpolation='none') 16 | plt.colorbar(im, ax=ax) 17 | 18 | fig.canvas.draw() 19 | plt.close() 20 | 21 | return fig 22 | 23 | 24 | def init_weights(m, mean=0.0, std=0.01): 25 | classname = m.__class__.__name__ 26 | if classname.find("Conv") != -1: 27 | m.weight.data.normal_(mean, std) 28 | 29 | 30 | def apply_weight_norm(m): 31 | classname = m.__class__.__name__ 32 | if classname.find("Conv") != -1: 33 | weight_norm(m) 34 | 35 | 36 | def get_padding(kernel_size, dilation=1): 37 | return int((kernel_size*dilation - dilation)/2) 38 | 39 | 40 | def load_checkpoint(filepath, device): 41 | assert os.path.isfile(filepath) 42 | print("Loading '{}'".format(filepath)) 43 | checkpoint_dict = torch.load(filepath, map_location=device) 44 | print("Complete.") 45 | return checkpoint_dict 46 | 47 | 48 | def save_checkpoint(filepath, obj): 49 | print("Saving checkpoint to {}".format(filepath)) 50 | torch.save(obj, filepath) 51 | print("Complete.") 52 | 53 | 54 | def del_old_checkpoints(cp_dir, prefix, n_models=2): 55 | pattern = os.path.join(cp_dir, prefix + '????????') 56 | cp_list = glob.glob(pattern) # get checkpoint paths 57 | cp_list = sorted(cp_list)# sort by iter 58 | if len(cp_list) > n_models: # if more than n_models models are found 59 | for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models 60 | open(cp, 'w').close()# empty file contents 61 | os.unlink(cp)# delete file (move to trash when using Colab) 62 | 63 | 64 | def scan_checkpoint(cp_dir, prefix): 65 | pattern = 
os.path.join(cp_dir, prefix + '????????') 66 | cp_list = glob.glob(pattern) 67 | if len(cp_list) == 0: 68 | return None 69 | return sorted(cp_list)[-1] 70 | 71 | -------------------------------------------------------------------------------- /vencoder/CNHubertLarge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fairseq import checkpoint_utils 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class CNHubertLarge(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 1024 12 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( 13 | [vec_path], 14 | suffix="", 15 | ) 16 | if device is None: 17 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | else: 19 | self.dev = torch.device(device) 20 | self.model = models[0].to(self.dev) 21 | self.model.eval() 22 | 23 | def encoder(self, wav): 24 | feats = wav 25 | if feats.dim() == 2: # double channels 26 | feats = feats.mean(-1) 27 | assert feats.dim() == 1, feats.dim() 28 | feats = feats.view(1, -1) 29 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 30 | inputs = { 31 | "source": feats.to(wav.device), 32 | "padding_mask": padding_mask.to(wav.device) 33 | } 34 | with torch.no_grad(): 35 | logits = self.model.extract_features(**inputs) 36 | return logits[0].transpose(1, 2) -------------------------------------------------------------------------------- /vencoder/ContentVec256L12_Onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class ContentVec256L12_Onnx(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/vec-256-layer-12.onnx", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 256 12 | if device is None: 13 | self.dev = torch.device("cpu") 14 | else: 15 | self.dev = torch.device(device) 16 | 17 | if device == 'cuda' or device == torch.device("cuda"): 18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 19 | else: 20 | providers = ['CPUExecutionProvider'] 21 | 22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers) 23 | 24 | def encoder(self, wav): 25 | feats = wav 26 | if feats.dim() == 2: # double channels 27 | feats = feats.mean(-1) 28 | assert feats.dim() == 1, feats.dim() 29 | feats = feats.view(1, -1) 30 | feats = feats.unsqueeze(0).cpu().detach().numpy() 31 | onnx_input = {self.model.get_inputs()[0].name: feats} 32 | logits = self.model.run(None, onnx_input) 33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) 34 | -------------------------------------------------------------------------------- /vencoder/ContentVec256L9.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fairseq import checkpoint_utils 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class ContentVec256L9(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( 12 | [vec_path], 13 | suffix="", 14 | ) 15 | self.hidden_dim = 256 16 | if device is None: 17 | self.dev = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | else: 19 | self.dev = torch.device(device) 20 | self.model = models[0].to(self.dev) 21 | self.model.eval() 22 | 23 | def encoder(self, wav): 24 | feats = wav 25 | if feats.dim() == 2: # double channels 26 | feats = feats.mean(-1) 27 | assert feats.dim() == 1, feats.dim() 28 | feats = feats.view(1, -1) 29 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 30 | inputs = { 31 | "source": feats.to(wav.device), 32 | "padding_mask": padding_mask.to(wav.device), 33 | "output_layer": 9, # layer 9 34 | } 35 | with torch.no_grad(): 36 | logits = self.model.extract_features(**inputs) 37 | feats = self.model.final_proj(logits[0]) 38 | return feats.transpose(1, 2) 39 | -------------------------------------------------------------------------------- /vencoder/ContentVec256L9_Onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class ContentVec256L9_Onnx(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 256 12 | if device is None: 13 | self.dev = torch.device("cpu") 14 | else: 15 | self.dev = torch.device(device) 16 | if device == 'cpu' or device == torch.device("cpu") or device is None: 17 | providers = ['CPUExecutionProvider'] 18 | elif device == 'cuda' or device == torch.device("cuda"): 19 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 20 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers) 21 | 22 | def encoder(self, wav): 23 | feats = wav 24 | if feats.dim() == 2: # double channels 25 | feats = feats.mean(-1) 26 | assert feats.dim() == 1, feats.dim() 27 | feats = feats.view(1, -1) 28 | feats = feats.unsqueeze(0).cpu().detach().numpy() 29 | onnx_input = {self.model.get_inputs()[0].name: feats} 30 | logits = self.model.run(None, onnx_input) 31 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) 32 | -------------------------------------------------------------------------------- /vencoder/ContentVec768L12.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fairseq import checkpoint_utils 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class ContentVec768L12(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 768 12 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( 13 | [vec_path], 14 | suffix="", 15 | ) 16 | if device is None: 17 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | else: 19 | self.dev = torch.device(device) 20 | self.model = models[0].to(self.dev) 21 | self.model.eval() 22 | 23 | def encoder(self, wav): 24 | feats = wav 25 | if feats.dim() == 2: # double channels 26 | feats = feats.mean(-1) 27 | assert feats.dim() == 1, feats.dim() 28 | feats = feats.view(1, -1) 29 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 30 | inputs = { 31 | "source": feats.to(wav.device), 32 | "padding_mask": padding_mask.to(wav.device), 33 | "output_layer": 12, # layer 12 34 | } 35 | with torch.no_grad(): 36 | logits = self.model.extract_features(**inputs) 37 | return logits[0].transpose(1, 2) 38 | 
-------------------------------------------------------------------------------- /vencoder/ContentVec768L12_Onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class ContentVec768L12_Onnx(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 768 12 | if device is None: 13 | self.dev = torch.device("cpu") 14 | else: 15 | self.dev = torch.device(device) 16 | 17 | if device == 'cuda' or device == torch.device("cuda"): 18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 19 | else: 20 | providers = ['CPUExecutionProvider'] 21 | 22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers) 23 | 24 | def encoder(self, wav): 25 | feats = wav 26 | if feats.dim() == 2: # double channels 27 | feats = feats.mean(-1) 28 | assert feats.dim() == 1, feats.dim() 29 | feats = feats.view(1, -1) 30 | feats = feats.unsqueeze(0).cpu().detach().numpy() 31 | onnx_input = {self.model.get_inputs()[0].name: feats} 32 | logits = self.model.run(None, onnx_input) 33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) 34 | -------------------------------------------------------------------------------- /vencoder/ContentVec768L9_Onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class ContentVec768L9_Onnx(SpeechEncoder): 8 | def __init__(self,vec_path = "pretrain/vec-768-layer-9.onnx",device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 768 12 | if device is None: 13 | self.dev = torch.device("cpu") 14 | else: 15 | self.dev = torch.device(device) 16 | 17 | if device == 'cuda' or device == torch.device("cuda"): 18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 19 | else: 20 | providers = ['CPUExecutionProvider'] 21 | 22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers) 23 | 24 | def encoder(self, wav): 25 | feats = wav 26 | if feats.dim() == 2: # double channels 27 | feats = feats.mean(-1) 28 | assert feats.dim() == 1, feats.dim() 29 | feats = feats.view(1, -1) 30 | feats = feats.unsqueeze(0).cpu().detach().numpy() 31 | onnx_input = {self.model.get_inputs()[0].name: feats} 32 | logits = self.model.run(None, onnx_input) 33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) 34 | -------------------------------------------------------------------------------- /vencoder/DPHubert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vencoder.dphubert.model import wav2vec2_model 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class DPHubert(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | if device is None: 12 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | else: 14 | self.dev = torch.device(device) 15 | ckpt = torch.load(vec_path) 16 | self.hidden_dim = 768 17 | self.model = wav2vec2_model(**ckpt["config"]).to(self.dev) 18 | self.model.load_state_dict(ckpt["state_dict"], strict=False) 19 | 20 | def encoder(self, wav): 21 | feats = wav 
22 | if feats.dim() == 2: # double channels 23 | feats = feats.mean(-1) 24 | assert feats.dim() == 1, feats.dim() 25 | feats = feats[None, :] 26 | with torch.no_grad(): 27 | with torch.inference_mode(): 28 | units = self.model(feats)[0] 29 | return units.transpose(1,2) 30 | -------------------------------------------------------------------------------- /vencoder/HubertSoft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vencoder.encoder import SpeechEncoder 4 | from vencoder.hubert import hubert_model 5 | 6 | 7 | class HubertSoft(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | hubert_soft = hubert_model.hubert_soft(vec_path) 12 | if device is None: 13 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | else: 15 | self.dev = torch.device(device) 16 | self.hidden_dim = 256 17 | self.model = hubert_soft.to(self.dev) 18 | 19 | def encoder(self, wav): 20 | feats = wav 21 | if feats.dim() == 2: # double channels 22 | feats = feats.mean(-1) 23 | assert feats.dim() == 1, feats.dim() 24 | feats = feats[None,None,:] 25 | with torch.no_grad(): 26 | with torch.inference_mode(): 27 | units = self.model.units(feats) 28 | return units.transpose(1,2) 29 | -------------------------------------------------------------------------------- /vencoder/HubertSoft_Onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import torch 3 | 4 | from vencoder.encoder import SpeechEncoder 5 | 6 | 7 | class HubertSoft_Onnx(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | self.hidden_dim = 256 12 | if device is None: 13 | self.dev = torch.device("cpu") 14 | else: 15 | self.dev = torch.device(device) 16 | 17 | if device == 'cuda' or device == torch.device("cuda"): 18 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 19 | else: 20 | providers = ['CPUExecutionProvider'] 21 | 22 | self.model = onnxruntime.InferenceSession(vec_path, providers=providers) 23 | 24 | def encoder(self, wav): 25 | feats = wav 26 | if feats.dim() == 2: # double channels 27 | feats = feats.mean(-1) 28 | assert feats.dim() == 1, feats.dim() 29 | feats = feats.view(1, -1) 30 | feats = feats.unsqueeze(0).cpu().detach().numpy() 31 | onnx_input = {self.model.get_inputs()[0].name: feats} 32 | logits = self.model.run(None, onnx_input) 33 | return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) 34 | -------------------------------------------------------------------------------- /vencoder/WavLMBasePlus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vencoder.encoder import SpeechEncoder 4 | from vencoder.wavlm.WavLM import WavLM, WavLMConfig 5 | 6 | 7 | class WavLMBasePlus(SpeechEncoder): 8 | def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None): 9 | super().__init__() 10 | print("load model(s) from {}".format(vec_path)) 11 | checkpoint = torch.load(vec_path) 12 | self.cfg = WavLMConfig(checkpoint['cfg']) 13 | if device is None: 14 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | else: 16 | self.dev = torch.device(device) 17 | self.hidden_dim = self.cfg.encoder_embed_dim 18 | self.model = WavLM(self.cfg) 19 | 
self.model.load_state_dict(checkpoint['model']) 20 | self.model.to(self.dev).eval() 21 | 22 | def encoder(self, wav): 23 | feats = wav 24 | if feats.dim() == 2: # double channels 25 | feats = feats.mean(-1) 26 | assert feats.dim() == 1, feats.dim() 27 | if self.cfg.normalize: 28 | feats = torch.nn.functional.layer_norm(feats, feats.shape) 29 | with torch.no_grad(): 30 | with torch.inference_mode(): 31 | units = self.model.extract_features(feats[None, :])[0] 32 | return units.transpose(1, 2) 33 | -------------------------------------------------------------------------------- /vencoder/WhisperPPG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vencoder.encoder import SpeechEncoder 4 | from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim 5 | from vencoder.whisper.model import ModelDimensions, Whisper 6 | 7 | 8 | class WhisperPPG(SpeechEncoder): 9 | def __init__(self, vec_path="pretrain/medium.pt", device=None): 10 | super().__init__() 11 | if device is None: 12 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | else: 14 | self.dev = torch.device(device) 15 | checkpoint = torch.load(vec_path, map_location=device) 16 | dims = ModelDimensions(**checkpoint["dims"]) 17 | model = Whisper(dims) 18 | model.load_state_dict(checkpoint["model_state_dict"]) 19 | self.hidden_dim = dims 20 | self.model = model.to(self.dev) 21 | 22 | def encoder(self, wav): 23 | audio = wav 24 | audln = audio.shape[0] 25 | ppgln = audln // 320 26 | audio = pad_or_trim(audio) 27 | mel = log_mel_spectrogram(audio).to(self.dev) 28 | with torch.no_grad(): 29 | ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() 30 | ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) 31 | return ppg[None, :, :].transpose(1, 2) 32 | -------------------------------------------------------------------------------- /vencoder/WhisperPPGLarge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vencoder.encoder import SpeechEncoder 4 | from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim 5 | from vencoder.whisper.model import ModelDimensions, Whisper 6 | 7 | 8 | class WhisperPPGLarge(SpeechEncoder): 9 | def __init__(self, vec_path="pretrain/large-v2.pt", device=None): 10 | super().__init__() 11 | if device is None: 12 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | else: 14 | self.dev = torch.device(device) 15 | checkpoint = torch.load(vec_path, map_location=device) 16 | dims = ModelDimensions(**checkpoint["dims"]) 17 | model = Whisper(dims) 18 | model.load_state_dict(checkpoint["model_state_dict"]) 19 | self.hidden_dim = dims 20 | self.model = model.to(self.dev) 21 | 22 | def encoder(self, wav): 23 | audio = wav 24 | audln = audio.shape[0] 25 | ppgln = audln // 320 26 | audio = pad_or_trim(audio) 27 | mel = log_mel_spectrogram(audio).to(self.dev) 28 | with torch.no_grad(): 29 | ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() 30 | ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) 31 | return ppg[None, :, :].transpose(1, 2) 32 | -------------------------------------------------------------------------------- /vencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/__init__.py 
-------------------------------------------------------------------------------- /vencoder/dphubert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/dphubert/__init__.py -------------------------------------------------------------------------------- /vencoder/dphubert/hardconcrete.py: -------------------------------------------------------------------------------- 1 | """Implementation of the hard Concrete distribution. 2 | 3 | Originally from: 4 | https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py 5 | 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class HardConcrete(nn.Module): 15 | """A HarcConcrete module. 16 | Use this module to create a mask of size N, which you can 17 | then use to perform L0 regularization. 18 | 19 | To obtain a mask, simply run a forward pass through the module 20 | with no input data. The mask is sampled in training mode, and 21 | fixed during evaluation mode, e.g.: 22 | 23 | >>> module = HardConcrete(n_in=100) 24 | >>> mask = module() 25 | >>> norm = module.l0_norm() 26 | """ 27 | 28 | def __init__( 29 | self, 30 | n_in: int, 31 | init_mean: float = 0.5, 32 | init_std: float = 0.01, 33 | temperature: float = 2/3, # from CoFi 34 | stretch: float = 0.1, 35 | eps: float = 1e-6 36 | ) -> None: 37 | """Initialize the HardConcrete module. 38 | Parameters 39 | ---------- 40 | n_in : int 41 | The number of hard concrete variables in this mask. 42 | init_mean : float, optional 43 | Initial drop rate for hard concrete parameter, 44 | by default 0.5., 45 | init_std: float, optional 46 | Used to initialize the hard concrete parameters, 47 | by default 0.01. 48 | temperature : float, optional 49 | Temperature used to control the sharpness of the 50 | distribution, by default 1.0 51 | stretch : float, optional 52 | Stretch the sampled value from [0, 1] to the interval 53 | [-stretch, 1 + stretch], by default 0.1. 54 | """ 55 | super().__init__() 56 | 57 | self.n_in = n_in 58 | self.limit_l = -stretch 59 | self.limit_r = 1.0 + stretch 60 | self.log_alpha = nn.Parameter(torch.zeros(n_in)) 61 | self.beta = temperature 62 | self.init_mean = init_mean 63 | self.init_std = init_std 64 | self.bias = -self.beta * math.log(-self.limit_l / self.limit_r) 65 | 66 | self.eps = eps 67 | self.compiled_mask = None 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | """Reset the parameters of this module.""" 72 | self.compiled_mask = None 73 | mean = math.log(1 - self.init_mean) - math.log(self.init_mean) 74 | self.log_alpha.data.normal_(mean, self.init_std) 75 | 76 | def l0_norm(self) -> torch.Tensor: 77 | """Compute the expected L0 norm of this mask. 78 | Returns 79 | ------- 80 | torch.Tensor 81 | The expected L0 norm. 82 | """ 83 | return (self.log_alpha + self.bias).sigmoid().sum() 84 | 85 | def forward(self) -> torch.Tensor: 86 | """Sample a hard concrete mask. 87 | Returns 88 | ------- 89 | torch.Tensor 90 | The sampled binary mask 91 | """ 92 | if self.training: 93 | # Reset the compiled mask 94 | self.compiled_mask = None 95 | # Sample mask dynamically 96 | u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps) 97 | s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta) 98 | s = s * (self.limit_r - self.limit_l) + self.limit_l 99 | mask = s.clamp(min=0., max=1.) 
100 | 101 | else: 102 | # Compile new mask if not cached 103 | if self.compiled_mask is None: 104 | # Get expected sparsity 105 | expected_num_zeros = self.n_in - self.l0_norm().item() 106 | num_zeros = round(expected_num_zeros) 107 | # Approximate expected value of each mask variable z; 108 | # We use an empirically validated magic number 0.8 109 | soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8) 110 | # Prune small values to set to 0 111 | _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) 112 | soft_mask[indices] = 0. 113 | self.compiled_mask = soft_mask 114 | mask = self.compiled_mask 115 | 116 | return mask 117 | 118 | def extra_repr(self) -> str: 119 | return str(self.n_in) 120 | 121 | def __repr__(self) -> str: 122 | return "{}({})".format(self.__class__.__name__, self.extra_repr()) 123 | -------------------------------------------------------------------------------- /vencoder/dphubert/pruning_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for pruning.""" 2 | 3 | from typing import Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str): 10 | "Prune linear layer in place." 11 | # NOTE: weight: (out_features, in_features), bias: (out_features,) 12 | if dim == "input": 13 | dim = 1 14 | layer.in_features = len(index) 15 | elif dim == "output": 16 | dim = 0 17 | layer.out_features = len(index) 18 | else: 19 | raise ValueError 20 | 21 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) 22 | if layer.bias is not None and dim == 0: 23 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) 24 | 25 | 26 | def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str): 27 | """Prune conv1d in place.""" 28 | # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,) 29 | if dim == "input": 30 | dim = 1 31 | layer.in_channels = len(index) 32 | elif dim == "output": 33 | dim = 0 34 | layer.out_channels = len(index) 35 | else: 36 | raise ValueError 37 | 38 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) 39 | if layer.bias is not None and dim == 0: 40 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) 41 | 42 | 43 | def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor): 44 | """Prune layer norm or group norm in place.""" 45 | layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach()) 46 | layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach()) 47 | if isinstance(layernorm, nn.LayerNorm): 48 | layernorm.normalized_shape = (len(index),) 49 | elif isinstance(layernorm, nn.GroupNorm): 50 | layernorm.num_groups = len(index) 51 | layernorm.num_channels = len(index) 52 | -------------------------------------------------------------------------------- /vencoder/dphubert/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/dphubert/utils/__init__.py -------------------------------------------------------------------------------- /vencoder/dphubert/utils/import_huggingface_wavlm.py: -------------------------------------------------------------------------------- 1 | """Import Hugging Face transformers's wav2vec2.0 
pretrained weights to torchaudios's format. 2 | 3 | Originally from: 4 | https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/utils/import_huggingface.py 5 | 6 | """ 7 | 8 | import logging 9 | from typing import Any, Dict 10 | 11 | from torch.nn import Module 12 | 13 | from ..model import Wav2Vec2Model, wav2vec2_model, wavlm_model 14 | 15 | _LG = logging.getLogger(__name__) 16 | 17 | 18 | def _get_config(cfg): 19 | config = { 20 | "extractor_mode": f"{cfg.feat_extract_norm}_norm", 21 | "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), 22 | "extractor_conv_bias": cfg.conv_bias, 23 | "encoder_embed_dim": cfg.hidden_size, 24 | "encoder_projection_dropout": cfg.feat_proj_dropout, 25 | "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, 26 | "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, 27 | "encoder_num_layers": cfg.num_hidden_layers, 28 | "encoder_num_heads": cfg.num_attention_heads, 29 | "encoder_attention_dropout": cfg.attention_dropout, 30 | "encoder_ff_interm_features": cfg.intermediate_size, 31 | "encoder_ff_interm_dropout": cfg.activation_dropout, 32 | "encoder_dropout": cfg.hidden_dropout, 33 | "encoder_layer_norm_first": cfg.do_stable_layer_norm, 34 | "encoder_layer_drop": cfg.layerdrop, 35 | } 36 | return config 37 | 38 | 39 | def _get_config_wavlm(cfg): 40 | config = { 41 | "extractor_mode": f"{cfg.feat_extract_norm}_norm", 42 | "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), 43 | "extractor_conv_bias": cfg.conv_bias, 44 | "encoder_embed_dim": cfg.hidden_size, 45 | "encoder_projection_dropout": cfg.feat_proj_dropout, 46 | "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, 47 | "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, 48 | "encoder_num_layers": cfg.num_hidden_layers, 49 | "encoder_use_attention": [True] * cfg.num_hidden_layers, 50 | "encoder_use_feed_forward": [True] * cfg.num_hidden_layers, 51 | "encoder_total_num_heads": [cfg.num_attention_heads for _ in range(cfg.num_hidden_layers)], 52 | "encoder_remaining_heads": [list(range(cfg.num_attention_heads)) for _ in range(cfg.num_hidden_layers)], 53 | "encoder_num_buckets": cfg.num_buckets, 54 | "encoder_max_distance": cfg.max_bucket_distance, 55 | "encoder_attention_dropout": cfg.attention_dropout, 56 | "encoder_ff_interm_features": [cfg.intermediate_size for _ in range(cfg.num_hidden_layers)], 57 | "encoder_ff_interm_dropout": cfg.activation_dropout, 58 | "encoder_dropout": cfg.hidden_dropout, 59 | "encoder_layer_norm_first": cfg.do_stable_layer_norm, 60 | "encoder_layer_drop": cfg.layerdrop, 61 | "normalize_waveform": cfg.feat_extract_norm == "layer", 62 | } 63 | return config 64 | 65 | 66 | def _build(config, original): 67 | is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"] 68 | if is_for_ctc: 69 | aux_num_out = original.config.vocab_size 70 | wav2vec2 = original.wav2vec2 71 | else: 72 | _LG.warning( 73 | "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.' 
74 | ) 75 | aux_num_out = None 76 | wav2vec2 = original 77 | is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] 78 | if is_wavlm: 79 | imported = wavlm_model(**config, aux_num_out=aux_num_out) 80 | else: 81 | imported = wav2vec2_model(**config, aux_num_out=aux_num_out) 82 | print(imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict(), strict=False)) 83 | print(imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict(), strict=False)) 84 | encoder_state_dict = wav2vec2.encoder.state_dict() 85 | if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model 86 | transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"]) 87 | print(imported.encoder.transformer.load_state_dict(encoder_state_dict, strict=False)) 88 | if is_for_ctc: 89 | imported.aux.load_state_dict(original.lm_head.state_dict()) 90 | return imported 91 | 92 | 93 | def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int): 94 | """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and 95 | biases to align with the structure of ``torch.nn.MultiheadAttention``. 96 | """ 97 | pass 98 | 99 | 100 | def import_huggingface_model(original: Module) -> Wav2Vec2Model: 101 | """Builds :class:`Wav2Vec2Model` from the corresponding model object of 102 | `Transformers `_. 103 | 104 | Args: 105 | original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``. 106 | 107 | Returns: 108 | Wav2Vec2Model: Imported model. 109 | 110 | Example 111 | >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model 112 | >>> 113 | >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") 114 | >>> model = import_huggingface_model(original) 115 | >>> 116 | >>> waveforms, _ = torchaudio.load("audio.wav") 117 | >>> logits, _ = model(waveforms) 118 | """ 119 | _LG.info("Importing model.") 120 | _LG.info("Loading model configuration.") 121 | is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] 122 | if is_wavlm: 123 | config = _get_config_wavlm(original.config) 124 | else: 125 | config = _get_config(original.config) 126 | _LG.debug(" - config: %s", config) 127 | _LG.info("Building model.") 128 | imported = _build(config, original) 129 | return imported 130 | -------------------------------------------------------------------------------- /vencoder/encoder.py: -------------------------------------------------------------------------------- 1 | class SpeechEncoder(object): 2 | def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): 3 | self.model = None # This is Model 4 | self.hidden_dim = 768 5 | pass 6 | 7 | 8 | def encoder(self, wav): 9 | """ 10 | input: wav:[signal_length] 11 | output: embedding:[batchsize,hidden_dim,wav_frame] 12 | """ 13 | pass 14 | -------------------------------------------------------------------------------- /vencoder/hubert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/hubert/__init__.py -------------------------------------------------------------------------------- /vencoder/whisper/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/svc-develop-team/so-vits-svc/730930d337d171479eadf305f96cbed4bb393e77/vencoder/whisper/__init__.py -------------------------------------------------------------------------------- /vencoder/whisper/audio.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import Union 3 | 4 | import ffmpeg 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from librosa.filters import mel as librosa_mel_fn 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 30 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) 59 | 60 | if array.shape[axis] < length: 61 | pad_widths = [(0, 0)] * array.ndim 62 | pad_widths[axis] = (0, length - array.shape[axis]) 63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 64 | else: 65 | if array.shape[axis] > length: 66 | array = array.take(indices=range(length), axis=axis) 67 | 68 | if array.shape[axis] < length: 69 | pad_widths = [(0, 0)] * array.ndim 70 | pad_widths[axis] = (0, length - array.shape[axis]) 71 | array = np.pad(array, pad_widths) 72 | 73 | return array 74 | 75 | 76 | @lru_cache(maxsize=None) 77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 78 | """ 79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
80 | Allows decoupling librosa dependency; saved using: 81 | 82 | np.savez_compressed( 83 | "mel_filters.npz", 84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 85 | ) 86 | """ 87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 88 | return torch.from_numpy(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)).to(device) 89 | 90 | 91 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 92 | """ 93 | Compute the log-Mel spectrogram of 94 | 95 | Parameters 96 | ---------- 97 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 98 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 99 | 100 | n_mels: int 101 | The number of Mel-frequency filters, only 80 is supported 102 | 103 | Returns 104 | ------- 105 | torch.Tensor, shape = (80, n_frames) 106 | A Tensor that contains the Mel spectrogram 107 | """ 108 | if not torch.is_tensor(audio): 109 | if isinstance(audio, str): 110 | audio = load_audio(audio) 111 | audio = torch.from_numpy(audio) 112 | 113 | window = torch.hann_window(N_FFT).to(audio.device) 114 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 115 | magnitudes = stft[..., :-1].abs() ** 2 116 | 117 | filters = mel_filters(audio.device, n_mels) 118 | mel_spec = filters @ magnitudes 119 | 120 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 121 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 122 | log_spec = (log_spec + 4.0) / 4.0 123 | return log_spec 124 | -------------------------------------------------------------------------------- /vencoder/whisper/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import zlib 5 | from typing import Callable, TextIO 6 | 7 | system_encoding = sys.getdefaultencoding() 8 | 9 | if system_encoding != "utf-8": 10 | def make_safe(string): 11 | # replaces any character not representable using the system default encoding with an '?', 12 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 
13 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 14 | else: 15 | def make_safe(string): 16 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 17 | return string 18 | 19 | 20 | def exact_div(x, y): 21 | assert x % y == 0 22 | return x // y 23 | 24 | 25 | def str2bool(string): 26 | str2val = {"True": True, "False": False} 27 | if string in str2val: 28 | return str2val[string] 29 | else: 30 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 31 | 32 | 33 | def optional_int(string): 34 | return None if string == "None" else int(string) 35 | 36 | 37 | def optional_float(string): 38 | return None if string == "None" else float(string) 39 | 40 | 41 | def compression_ratio(text) -> float: 42 | text_bytes = text.encode("utf-8") 43 | return len(text_bytes) / len(zlib.compress(text_bytes)) 44 | 45 | 46 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 47 | assert seconds >= 0, "non-negative timestamp expected" 48 | milliseconds = round(seconds * 1000.0) 49 | 50 | hours = milliseconds // 3_600_000 51 | milliseconds -= hours * 3_600_000 52 | 53 | minutes = milliseconds // 60_000 54 | milliseconds -= minutes * 60_000 55 | 56 | seconds = milliseconds // 1_000 57 | milliseconds -= seconds * 1_000 58 | 59 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 60 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 61 | 62 | 63 | class ResultWriter: 64 | extension: str 65 | 66 | def __init__(self, output_dir: str): 67 | self.output_dir = output_dir 68 | 69 | def __call__(self, result: dict, audio_path: str): 70 | audio_basename = os.path.basename(audio_path) 71 | output_path = os.path.join(self.output_dir, audio_basename + "." 
+ self.extension) 72 | 73 | with open(output_path, "w", encoding="utf-8") as f: 74 | self.write_result(result, file=f) 75 | 76 | def write_result(self, result: dict, file: TextIO): 77 | raise NotImplementedError 78 | 79 | 80 | class WriteTXT(ResultWriter): 81 | extension: str = "txt" 82 | 83 | def write_result(self, result: dict, file: TextIO): 84 | for segment in result["segments"]: 85 | print(segment['text'].strip(), file=file, flush=True) 86 | 87 | 88 | class WriteVTT(ResultWriter): 89 | extension: str = "vtt" 90 | 91 | def write_result(self, result: dict, file: TextIO): 92 | print("WEBVTT\n", file=file) 93 | for segment in result["segments"]: 94 | print( 95 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 96 | f"{segment['text'].strip().replace('-->', '->')}\n", 97 | file=file, 98 | flush=True, 99 | ) 100 | 101 | 102 | class WriteSRT(ResultWriter): 103 | extension: str = "srt" 104 | 105 | def write_result(self, result: dict, file: TextIO): 106 | for i, segment in enumerate(result["segments"], start=1): 107 | # write srt lines 108 | print( 109 | f"{i}\n" 110 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 111 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 112 | f"{segment['text'].strip().replace('-->', '->')}\n", 113 | file=file, 114 | flush=True, 115 | ) 116 | 117 | 118 | class WriteTSV(ResultWriter): 119 | """ 120 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 121 | <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text> 122 | 123 | Using integer milliseconds as start and end times means there's no chance of interference from 124 | an environment setting a language encoding that causes the decimal in a floating point number 125 | to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
126 | """ 127 | extension: str = "tsv" 128 | 129 | def write_result(self, result: dict, file: TextIO): 130 | print("start", "end", "text", sep="\t", file=file) 131 | for segment in result["segments"]: 132 | print(round(1000 * segment['start']), file=file, end="\t") 133 | print(round(1000 * segment['end']), file=file, end="\t") 134 | print(segment['text'].strip().replace("\t", " "), file=file, flush=True) 135 | 136 | 137 | class WriteJSON(ResultWriter): 138 | extension: str = "json" 139 | 140 | def write_result(self, result: dict, file: TextIO): 141 | json.dump(result, file) 142 | 143 | 144 | def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: 145 | writers = { 146 | "txt": WriteTXT, 147 | "vtt": WriteVTT, 148 | "srt": WriteSRT, 149 | "tsv": WriteTSV, 150 | "json": WriteJSON, 151 | } 152 | 153 | if output_format == "all": 154 | all_writers = [writer(output_dir) for writer in writers.values()] 155 | 156 | def write_all(result: dict, file: TextIO): 157 | for writer in all_writers: 158 | writer(result, file) 159 | 160 | return write_all 161 | 162 | return writers[output_format](output_dir) 163 | 164 | -------------------------------------------------------------------------------- /wav_upload.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | from google.colab import files 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--type", type=str, required=True, help="type of file to upload") 10 | args = parser.parse_args() 11 | file_type = args.type 12 | 13 | basepath = os.getcwd() 14 | uploaded = files.upload() # 上传文件 15 | assert(file_type in ['zip', 'audio']) 16 | if file_type == "zip": 17 | upload_path = "./upload/" 18 | for filename in uploaded.keys(): 19 | #将上传的文件移动到指定的位置上 20 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, "userzip.zip")) 21 | elif file_type == "audio": 22 | upload_path = "./raw/" 23 | for filename in uploaded.keys(): 24 | #将上传的文件移动到指定的位置上 25 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, filename)) --------------------------------------------------------------------------------