├── .gitignore ├── I18n.py ├── LICENSE ├── README.md ├── configs ├── reflow-vae-lynxnet.yaml ├── reflow-vae-musa.yaml └── reflow-vae-wavenet.yaml ├── data ├── train │ ├── .gitignore │ └── audio │ │ └── .gitignore └── val │ ├── .gitignore │ └── audio │ └── .gitignore ├── draw.py ├── encoder ├── hubert │ └── model.py └── rmvpe │ ├── __init__.py │ ├── constants.py │ ├── deepunet.py │ ├── inference.py │ ├── model.py │ ├── seq.py │ ├── spec.py │ └── utils.py ├── exp └── .gitignore ├── logger ├── __init__.py ├── saver.py └── utils.py ├── main.py ├── nsf_hifigan ├── env.py ├── models.py ├── nvSTFT.py └── utils.py ├── preprocess.py ├── pretrain ├── contentvec │ └── .gitignore ├── nsf_hifigan │ └── .gitignore └── rmvpe │ └── .gitignore ├── realtime.py ├── reflow ├── data_loaders.py ├── extractors.py ├── model_conformer_naive.py ├── naive_v2_diff.py ├── reflow.py ├── solver.py ├── vocoder.py └── wavenet.py ├── requirements.txt ├── slicer.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | 165 | 166 | .DS_Store 167 | checkpoint_best_legacy_500.pt 168 | pretrain/nsf_hifigan/config.json 169 | pretrain/nsf_hifigan/model 170 | data/* 171 | __pycache__ 172 | exp/* 173 | dataset_raw/* -------------------------------------------------------------------------------- /I18n.py: -------------------------------------------------------------------------------- 1 | import locale 2 | ''' 3 | 本地化方式如下所示 4 | ''' 5 | 6 | LANGUAGE_LIST = ['zh_CN', 'en_US', 'ja_JP'] 7 | LANGUAGE_ALL = { 8 | 'zh_CN': { 9 | 'SUPER': 'END', 10 | 'LANGUAGE': 'zh_CN', 11 | '选择模型文件': '选择模型文件', 12 | '选择配置文件所在目录': '选择配置文件所在目录', 13 | '模型文件:.pt格式(自动识别同目录下config.yaml)':'模型文件:.pt格式(自动识别同目录下config.yaml)', 14 | '打开文件夹': '打开文件夹', 15 | '读取配置文件': '读取配置文件', 16 | '保存配置文件': '保存配置文件', 17 | '快速配置文件': '快速配置文件', 18 | '输入设备': '输入设备', 19 | '输出设备': '输出设备', 20 | '音频设备': '音频设备', 21 | '说话人id': '说话人id', 22 | '响应阈值': '响应阈值', 23 | '变调': '变调', 24 | '采样率': '采样率', 25 | '启用捏音色功能': '启用捏音色功能', 26 | '设置混合音色': '设置混合音色', 27 | '普通设置': '普通设置', 28 | '音频切分大小': '音频切分大小', 29 | '交叉淡化时长': '交叉淡化时长', 30 | '额外推理时长': '额外推理时长', 31 | 'f0预测模式': 'f0预测模式', 32 | '启用相位声码器': '启用相位声码器', 33 | '性能设置': '性能设置', 34 | '开始音频转换': '开始音频转换', 35 | '停止音频转换': '停止音频转换', 36 | '推理所用时间(ms):': '推理所用时间(ms):', 37 | '不转换安全区(加速但损失效果)': '不转换安全区(加速但损失效果)', 38 | '采样步数': '采样步数', 39 | '时间起点': '时间起点', 40 | '采样算法': '采样算法', 41 | 'Reflow设置': 'Reflow设置', 42 | '共振峰偏移': '共振峰偏移', 43 | }, 44 | 'en_US': { 45 | 'SUPER': 'zh_CN', 46 | 'LANGUAGE': 'en_US', 47 | "选择模型文件": "Select model file", 48 | "选择配置文件所在目录": "Select configuration file directory", 49 | "模型文件:.pt格式(自动识别同目录下config.yaml)": "Model file: .pt format (automatically detects config.yaml in the same directory)", 50 | "打开文件夹": "Open folder", 51 | "读取配置文件": "Load configuration file", 52 | "保存配置文件": "Save configuration file", 53 | "快速配置文件": "Quick configuration file", 54 | "输入设备": "Input device", 55 | "输出设备": "Output device", 56 | "音频设备": "Audio device", 
57 | "说话人id": "Speaker ID", 58 | "响应阈值": "Response threshold", 59 | "变调": "Pitch shift", 60 | "采样率": "Sampling rate", 61 | "启用捏音色功能": "Enable Mix Speaker", 62 | "设置混合音色": "Mix speaker", 63 | "普通设置": "General settings", 64 | "音频切分大小": "Segmentation size", 65 | "交叉淡化时长": "Crossfade duration", 66 | "额外推理时长": "Extra inference time", 67 | "f0预测模式": "f0Extractor", 68 | "启用相位声码器": "Enable phase vocoder", 69 | "性能设置": "Performance settings", 70 | "开始音频转换": "Start conversion", 71 | "停止音频转换": "Stop conversion", 72 | "推理所用时间(ms):": "Inference time(ms):", 73 | "不转换安全区(加速但损失效果)": "Ignore safe zone (faster but less effective)", 74 | "采样步数": "Sampling steps", 75 | "时间起点": "t_start", 76 | "采样算法": "Sampling method", 77 | "Reflow设置": "Reflow settings", 78 | "共振峰偏移": "Formant shift" 79 | }, 80 | 'ja_JP': { 81 | 'SUPER': 'zh_CN', 82 | 'LANGUAGE': 'ja_JP', 83 | '选择模型文件': 'モデルを選択', 84 | '选择配置文件所在目录': '設定ファイルを選択', 85 | "模型文件:.pt格式(自动识别同目录下config.yaml)": "モデルファイル: .pt形式(同じディレクトリ内のconfig.yamlを自動認識)", 86 | '打开文件夹': 'フォルダを開く', 87 | '读取配置文件': '設定ファイルを読み込む', 88 | '保存配置文件': '設定ファイルを保存', 89 | '快速配置文件': '設定プロファイル', 90 | '输入设备': '入力デバイス', 91 | '输出设备': '出力デバイス', 92 | '音频设备': '音声デバイス', 93 | '说话人id': '話者ID', 94 | '响应阈值': '応答時の閾値', 95 | '变调': '音程', 96 | '采样率': 'サンプリングレート', 97 | '启用捏音色功能': 'ミキシングを有効化', 98 | '设置混合音色': 'ミキシング', 99 | '普通设置': '通常設定', 100 | '音频切分大小': 'セグメンテーションのサイズ', 101 | '交叉淡化时长': 'クロスフェードの間隔', 102 | '额外推理时长': '追加推論時間', 103 | 'f0预测模式': 'f0予測モデル', 104 | '启用相位声码器': 'Phase Vocoderを有効化', 105 | '性能设置': 'パフォーマンスの設定', 106 | '开始音频转换': '変換開始', 107 | '停止音频转换': '変換停止', 108 | '推理所用时间(ms):': '推論時間(ms):', 109 | "不转换安全区(加速但损失效果)": "安全エリアを変換しない(高速化されますが、効果が損なわれます)", 110 | "采样步数": "Sampling steps", 111 | "时间起点": "t_start", 112 | "采样算法": "Sampling methods", 113 | "Reflow设置": "Reflow設定", 114 | "共振峰偏移": "フォルマントシフト" 115 | } 116 | } 117 | 118 | 119 | class I18nAuto: 120 | def __init__(self, language=None): 121 | self.language_list = LANGUAGE_LIST 122 | self.language_all = LANGUAGE_ALL 123 | self.language_map = {} 124 | if language is None: 125 | language = 'auto' 126 | if language == 'auto': 127 | language = locale.getdefaultlocale()[0] 128 | if language not in self.language_list: 129 | language = 'zh_CN' 130 | self.language = language 131 | super_language_list = [] 132 | while self.language_all[language]['SUPER'] != 'END': 133 | super_language_list.append(language) 134 | language = self.language_all[language]['SUPER'] 135 | super_language_list.append('zh_CN') 136 | super_language_list.reverse() 137 | for _lang in super_language_list: 138 | self.read_language(self.language_all[_lang]) 139 | 140 | def read_language(self, lang_dict: dict): 141 | for _key in lang_dict.keys(): 142 | self.language_map[_key] = lang_dict[_key] 143 | 144 | def __call__(self, key): 145 | return self.language_map[key] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 yxlllc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReFlow-VAE-SVC 2 | 3 | Installing dependencies, preparing the dataset, and configuring the encoder (hubert or contentvec), the vocoder (nsf-hifigan) and the pitch extractor (RMVPE) are the same as in the DDSP-SVC project. 4 | 5 | 6 | (1) Preprocessing: 7 | 8 | ```bash 9 | python preprocess.py -c configs/reflow-vae-wavenet.yaml 10 | ``` 11 | 12 | (2) Training (without a pretrained model): 13 | 14 | ```bash 15 | python train.py -c configs/reflow-vae-wavenet.yaml 16 | ``` 17 | A beta pretrained model for the wavenet backbone can be downloaded here: https://huggingface.co/OOPPEENN/pretrained_model 18 | A beta pretrained model for the lynxnet backbone can be downloaded here: https://huggingface.co/tepetst3033/Reflow_VAE_SVC_retrained_model_with_lynxnet 19 | 20 | (3) Non-real-time inference: 21 | 22 | ```bash 23 | # Normal mode: requires a semantic encoder such as contentvec 24 | python main.py -i <input.wav> -m <model_ckpt.pt> -o <output.wav> -k <keychange (semitones)> -tid <target_speaker_id> -step <infer_step> -method <method> 25 | # VAE mode: no semantic encoder needed; specialized voice conversion from sid to tid (or pitch editing if sid == tid) 26 | python main.py -i <input.wav> -m <model_ckpt.pt> -o <output.wav> -k <keychange (semitones)> -sid <source_speaker_id> -tid <target_speaker_id> -step <infer_step> -method <method> 27 | ``` 28 | -------------------------------------------------------------------------------- /configs/reflow-vae-lynxnet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | f0_extractor: 'rmvpe' # 'parselmouth', 'dio', 'harvest', 'crepe' or 'rmvpe' 3 | f0_min: 65 # about C2 4 | f0_max: 800 # about G5 5 | sampling_rate: 44100 6 | block_size: 512 # Equal to hop_length 7 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip 8 | encoder: 'contentvec768l12' # 'hubertsoft', 'hubertbase', 'hubertbase768', 'contentvec', 'contentvec768' or 'contentvec768l12' or 'cnhubertsoftfish' 9 | cnhubertsoft_gate: 10 10 | encoder_sample_rate: 16000 11 | encoder_hop_size: 320 12 | encoder_out_channels: 768 # 256 if using 'hubertsoft' 13 | encoder_ckpt: pretrain/contentvec/checkpoint_best_legacy_500.pt 14 | train_path: data/train # Create a folder named "audio" under this path and put the audio clip in it 15 | valid_path: data/val # Create a folder named "audio" under this path and put the audio clip in it 16 | extensions: # List of extension included in the data collection 17 | - wav 18 | model: 19 | type: 'RectifiedFlow_VAE' 20 | back_bone: 'lynxnet' 21 | n_layers: 6 22 | n_chans: 512 23 | n_hidden: 256 24 | use_pitch_aug: true 25 | use_attention: true 26 | n_spk: 1 # max number of different speakers 27 | device: cuda # training device 28 | vocoder: 29 | type: 'nsf-hifigan' 30 | ckpt: 'pretrain/nsf_hifigan/model' 31 | infer: 32 | infer_step: 50 33 | method: 'euler' # 'euler', 'rk4' 34 | env: 35 | expdir: exp/reflowvae-test 36 | gpu_id: 0 37 | train: 38 | num_workers: 2 # If your cpu and gpu are both very strong, set to 0 may be faster!
39 | amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu) 40 | batch_size: 48 41 | cache_all_data: true # Save Internal-Memory or Graphics-Memory if it is false, but may be slow 42 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu 43 | cache_fp16: true 44 | epochs: 100000 45 | interval_log: 1 46 | interval_val: 2000 47 | interval_force_save: 10000 48 | lr: 0.0002 49 | decay_step: 100000 50 | gamma: 0.5 51 | weight_decay: 0 52 | save_opt: false 53 | -------------------------------------------------------------------------------- /configs/reflow-vae-musa.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | f0_extractor: 'fcpe' # 'parselmouth', 'dio', 'harvest', 'crepe' or 'rmvpe' 3 | f0_min: 65 # about C2 4 | f0_max: 800 # about G5 5 | sampling_rate: 44100 6 | block_size: 512 # Equal to hop_length 7 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip 8 | encoder: 'hubertsoft' # 'hubertsoft', 'hubertbase', 'hubertbase768', 'contentvec', 'contentvec768' or 'contentvec768l12' or 'cnhubertsoftfish' 9 | cnhubertsoft_gate: 10 10 | encoder_sample_rate: 16000 11 | encoder_hop_size: 320 12 | encoder_out_channels: 256 # 256 if using 'hubertsoft' 13 | encoder_ckpt: pretrain/hubert/hubert-soft-0d54a1f4.pt 14 | train_path: data/train # Create a folder named "audio" under this path and put the audio clip in it 15 | valid_path: data/val # Create a folder named "audio" under this path and put the audio clip in it 16 | extensions: # List of extension included in the data collection 17 | - wav 18 | model: 19 | type: 'RectifiedFlow_VAE' 20 | back_bone: 'wavenet' 21 | n_layers: 20 22 | n_chans: 512 23 | n_hidden: 256 24 | use_pitch_aug: true 25 | use_attention: false 26 | n_spk: 1 # max number of different speakers 27 | device: musa # training device 28 | vocoder: 29 | type: 'nsf-hifigan' 30 | ckpt: 'pretrain/nsf_hifigan/model' 31 | infer: 32 | infer_step: 50 33 | method: 'euler' # 'euler', 'rk4' 34 | env: 35 | expdir: exp/reflowvae-test-musa 36 | gpu_id: 0 37 | train: 38 | num_workers: 2 # If your cpu and gpu are both very strong, set to 0 may be faster! 
39 | amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu) 40 | batch_size: 48 41 | cache_all_data: true # Save Internal-Memory or Graphics-Memory if it is false, but may be slow 42 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu 43 | cache_fp16: true 44 | epochs: 100000 45 | interval_log: 1 46 | interval_val: 2000 47 | interval_force_save: 10000 48 | lr: 0.0002 49 | decay_step: 100000 50 | gamma: 0.5 51 | weight_decay: 0 52 | save_opt: false 53 | -------------------------------------------------------------------------------- /configs/reflow-vae-wavenet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | f0_extractor: 'rmvpe' # 'parselmouth', 'dio', 'harvest', 'crepe' or 'rmvpe' 3 | f0_min: 65 # about C2 4 | f0_max: 800 # about G5 5 | sampling_rate: 44100 6 | block_size: 512 # Equal to hop_length 7 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip 8 | encoder: 'contentvec768l12' # 'hubertsoft', 'hubertbase', 'hubertbase768', 'contentvec', 'contentvec768' or 'contentvec768l12' or 'cnhubertsoftfish' 9 | cnhubertsoft_gate: 10 10 | encoder_sample_rate: 16000 11 | encoder_hop_size: 320 12 | encoder_out_channels: 768 # 256 if using 'hubertsoft' 13 | encoder_ckpt: pretrain/contentvec/checkpoint_best_legacy_500.pt 14 | train_path: data/train # Create a folder named "audio" under this path and put the audio clip in it 15 | valid_path: data/val # Create a folder named "audio" under this path and put the audio clip in it 16 | extensions: # List of extension included in the data collection 17 | - wav 18 | model: 19 | type: 'RectifiedFlow_VAE' 20 | back_bone: 'wavenet' 21 | n_layers: 20 22 | n_chans: 512 23 | n_hidden: 256 24 | use_pitch_aug: true 25 | use_attention: true 26 | n_spk: 1 # max number of different speakers 27 | device: cuda # training device 28 | vocoder: 29 | type: 'nsf-hifigan' 30 | ckpt: 'pretrain/nsf_hifigan/model' 31 | infer: 32 | infer_step: 50 33 | method: 'euler' # 'euler', 'rk4' 34 | env: 35 | expdir: exp/reflowvae-test 36 | gpu_id: 0 37 | train: 38 | num_workers: 2 # If your cpu and gpu are both very strong, set to 0 may be faster! 
39 | amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu) 40 | batch_size: 48 41 | cache_all_data: true # Save Internal-Memory or Graphics-Memory if it is false, but may be slow 42 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu 43 | cache_fp16: true 44 | epochs: 100000 45 | interval_log: 1 46 | interval_val: 2000 47 | interval_force_save: 10000 48 | lr: 0.0002 49 | decay_step: 100000 50 | gamma: 0.5 51 | weight_decay: 0 52 | save_opt: false 53 | -------------------------------------------------------------------------------- /data/train/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !audio -------------------------------------------------------------------------------- /data/train/audio/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /data/val/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !audio -------------------------------------------------------------------------------- /data/val/audio/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /draw.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tqdm 4 | import os 5 | import shutil 6 | 7 | import soundfile as sf 8 | 9 | WAV_MIN_LENGTH = 2 # wav文件的最短时长 / The minimum duration of wav files 10 | SAMPLE_MIN = 2 # 抽取的文件数量下限 / The lower limit of the number of files to be extracted 11 | SAMPLE_MAX = 10 # 抽取的文件数量上限 / The upper limit of the number of files to be extracted 12 | 13 | 14 | def parse_args(args=None, namespace=None): 15 | """Parse command-line arguments.""" 16 | root_dir = os.path.abspath('.') 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "-t", 20 | "--train", 21 | type=str, 22 | default=root_dir + "/data/train/audio", # the source directory is fixed to /data/train/audio under the project root 23 | help="directory containing the training dataset" 24 | ) 25 | parser.add_argument( 26 | "-v", 27 | "--val", 28 | type=str, 29 | default=root_dir + "/data/val/audio", 30 | help="directory containing the validation dataset" 31 | ) 32 | parser.add_argument( 33 | "-r", 34 | "--sample_rate", 35 | type=float, 36 | default=1, 37 | help="The percentage of files to be extracted" # percentage of files to extract 38 | ) 39 | parser.add_argument( 40 | "-e", 41 | "--extensions", 42 | type=str, 43 | required=False, 44 | nargs="*", 45 | default=["wav", "flac"], 46 | help="list of file extensions to include, e.g. -e wav flac ..."
47 | ) 48 | return parser.parse_args(args=args, namespace=namespace) 49 | 50 | 51 | # Check whether a wav file is longer than the minimum duration 52 | def check_duration(wav_file): 53 | # open the wav file 54 | f = sf.SoundFile(wav_file) 55 | # get the number of frames and the sample rate 56 | frames = f.frames 57 | rate = f.samplerate 58 | # compute the duration in seconds 59 | duration = frames / float(rate) 60 | # close the file 61 | f.close() 62 | # return whether the duration exceeds the minimum length 63 | return duration > WAV_MIN_LENGTH 64 | 65 | # Randomly sample a proportion of wav files from a directory and move them to another directory, preserving the directory structure 66 | def split_data(src_dir, dst_dir, ratio, extensions): 67 | # create the destination directory if it does not exist 68 | if not os.path.exists(dst_dir): 69 | os.makedirs(dst_dir) 70 | 71 | # collect all subdirectories and file names under the source directory 72 | subdirs, files, subfiles = [], [], [] 73 | for item in os.listdir(src_dir): 74 | item_path = os.path.join(src_dir, item) 75 | if os.path.isdir(item_path): 76 | subdirs.append(item) 77 | for subitem in os.listdir(item_path): 78 | subitem_path = os.path.join(item_path, subitem) 79 | if os.path.isfile(subitem_path) and any([subitem.endswith(f".{ext}") for ext in extensions]): 80 | subfiles.append(subitem) 81 | elif os.path.isfile(item_path) and any([item.endswith(f".{ext}") for ext in extensions]): 82 | files.append(item) 83 | 84 | # if the source directory contains no matching wav files at all, report an error and return 85 | if len(files) == 0: 86 | if len(subfiles) == 0: 87 | print(f"Error: No wav files found in {src_dir}") 88 | return 89 | 90 | # compute the number of wav files to extract 91 | num_files = int(len(files) * ratio) 92 | num_files = max(SAMPLE_MIN, min(SAMPLE_MAX, num_files)) 93 | 94 | # shuffle the file list and take the first num_files entries as the sample 95 | np.random.shuffle(files) 96 | selected_files = files[:num_files] 97 | 98 | # create a progress bar to show progress 99 | pbar = tqdm.tqdm(total=num_files) 100 | 101 | # iterate over the sampled files and check that each is longer than 2 seconds 102 | for file in selected_files: 103 | src_file = os.path.join(src_dir, file) 104 | # if the source file is shorter than 2 seconds, print its name and skip it 105 | if not check_duration(src_file): 106 | print(f"Skipped {src_file} because its duration is less than 2 seconds.") 107 | continue 108 | # build the full destination path, move the file, and update the progress bar 109 | dst_file = os.path.join(dst_dir, file) 110 | shutil.move(src_file, dst_file) 111 | pbar.update(1) 112 | 113 | pbar.close() 114 | 115 | # recurse into all subdirectories of the source directory (if any) 116 | for subdir in subdirs: 117 | # build the full paths of the subdirectory in the source and destination directories 118 | src_subdir = os.path.join(src_dir, subdir) 119 | dst_subdir = os.path.join(dst_dir, subdir) 120 | # recursively apply the same operation to the wav files in the subdirectory, preserving the data structure 121 | split_data(src_subdir, dst_subdir, ratio, extensions) 122 | 123 | # main entry: read the user's arguments and call the functions above 124 | 125 | def main(cmd): 126 | dst_dir = cmd.val 127 | # sampling percentage, defaults to 1 128 | ratio = cmd.sample_rate / 100 129 | 130 | src_dir = cmd.train 131 | 132 | extensions = cmd.extensions 133 | 134 | # sample wav files from the source directory and move them to the destination directory, preserving the data structure 135 | split_data(src_dir, dst_dir, ratio, extensions) 136 | 137 | # run the main function when this module is executed as a script 138 | if __name__ == "__main__": 139 | # parse commands 140 | cmd = parse_args() 141 | 142 | main(cmd) 143 | -------------------------------------------------------------------------------- /encoder/hubert/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, Tuple 3 | import random 4 | 5 | from sklearn.cluster import KMeans 6 | 7 | import torch 8 | try: 9 | import torch_musa 10 | except ImportError: 11 | pass 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 15 | 16 | URLS = { 17 | "hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt", 18 | "hubert-soft": 
"https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt", 19 | "kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt", 20 | } 21 | 22 | 23 | class Hubert(nn.Module): 24 | def __init__(self, num_label_embeddings: int = 100, mask: bool = True): 25 | super().__init__() 26 | self._mask = mask 27 | self.feature_extractor = FeatureExtractor() 28 | self.feature_projection = FeatureProjection() 29 | self.positional_embedding = PositionalConvEmbedding() 30 | self.norm = nn.LayerNorm(768) 31 | self.dropout = nn.Dropout(0.1) 32 | self.encoder = TransformerEncoder( 33 | nn.TransformerEncoderLayer( 34 | 768, 12, 3072, activation="gelu", batch_first=True 35 | ), 36 | 12, 37 | ) 38 | self.proj = nn.Linear(768, 256) 39 | 40 | self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) 41 | self.label_embedding = nn.Embedding(num_label_embeddings, 256) 42 | 43 | def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 44 | mask = None 45 | if self.training and self._mask: 46 | mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) 47 | x[mask] = self.masked_spec_embed.to(x.dtype) 48 | return x, mask 49 | 50 | def encode( 51 | self, x: torch.Tensor, layer: Optional[int] = None 52 | ) -> Tuple[torch.Tensor, torch.Tensor]: 53 | x = self.feature_extractor(x) 54 | x = self.feature_projection(x.transpose(1, 2)) 55 | x, mask = self.mask(x) 56 | x = x + self.positional_embedding(x) 57 | x = self.dropout(self.norm(x)) 58 | x = self.encoder(x, output_layer=layer) 59 | return x, mask 60 | 61 | def logits(self, x: torch.Tensor) -> torch.Tensor: 62 | logits = torch.cosine_similarity( 63 | x.unsqueeze(2), 64 | self.label_embedding.weight.unsqueeze(0).unsqueeze(0), 65 | dim=-1, 66 | ) 67 | return logits / 0.1 68 | 69 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 70 | x, mask = self.encode(x) 71 | x = self.proj(x) 72 | logits = self.logits(x) 73 | return logits, mask 74 | 75 | 76 | class HubertSoft(Hubert): 77 | def __init__(self): 78 | super().__init__() 79 | 80 | @torch.inference_mode() 81 | def units(self, wav: torch.Tensor) -> torch.Tensor: 82 | wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) 83 | x, _ = self.encode(wav) 84 | return self.proj(x) 85 | 86 | 87 | class HubertDiscrete(Hubert): 88 | def __init__(self, kmeans): 89 | super().__init__(504) 90 | self.kmeans = kmeans 91 | 92 | @torch.inference_mode() 93 | def units(self, wav: torch.Tensor) -> torch.LongTensor: 94 | wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) 95 | x, _ = self.encode(wav, layer=7) 96 | x = self.kmeans.predict(x.squeeze().cpu().numpy()) 97 | return torch.tensor(x, dtype=torch.long, device=wav.device) 98 | 99 | 100 | class FeatureExtractor(nn.Module): 101 | def __init__(self): 102 | super().__init__() 103 | self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) 104 | self.norm0 = nn.GroupNorm(512, 512) 105 | self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) 106 | self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) 107 | self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) 108 | self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) 109 | self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) 110 | self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) 111 | 112 | def forward(self, x: torch.Tensor) -> torch.Tensor: 113 | x = F.gelu(self.norm0(self.conv0(x))) 114 | x = F.gelu(self.conv1(x)) 115 | x = F.gelu(self.conv2(x)) 116 | x = F.gelu(self.conv3(x)) 117 | x = F.gelu(self.conv4(x)) 118 | x = F.gelu(self.conv5(x)) 119 | x 
= F.gelu(self.conv6(x)) 120 | return x 121 | 122 | 123 | class FeatureProjection(nn.Module): 124 | def __init__(self): 125 | super().__init__() 126 | self.norm = nn.LayerNorm(512) 127 | self.projection = nn.Linear(512, 768) 128 | self.dropout = nn.Dropout(0.1) 129 | 130 | def forward(self, x: torch.Tensor) -> torch.Tensor: 131 | x = self.norm(x) 132 | x = self.projection(x) 133 | x = self.dropout(x) 134 | return x 135 | 136 | 137 | class PositionalConvEmbedding(nn.Module): 138 | def __init__(self): 139 | super().__init__() 140 | self.conv = nn.Conv1d( 141 | 768, 142 | 768, 143 | kernel_size=128, 144 | padding=128 // 2, 145 | groups=16, 146 | ) 147 | self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) 148 | 149 | def forward(self, x: torch.Tensor) -> torch.Tensor: 150 | x = self.conv(x.transpose(1, 2)) 151 | x = F.gelu(x[:, :, :-1]) 152 | return x.transpose(1, 2) 153 | 154 | 155 | class TransformerEncoder(nn.Module): 156 | def __init__( 157 | self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int 158 | ) -> None: 159 | super(TransformerEncoder, self).__init__() 160 | self.layers = nn.ModuleList( 161 | [copy.deepcopy(encoder_layer) for _ in range(num_layers)] 162 | ) 163 | self.num_layers = num_layers 164 | 165 | def forward( 166 | self, 167 | src: torch.Tensor, 168 | mask: torch.Tensor = None, 169 | src_key_padding_mask: torch.Tensor = None, 170 | output_layer: Optional[int] = None, 171 | ) -> torch.Tensor: 172 | output = src 173 | for layer in self.layers[:output_layer]: 174 | output = layer( 175 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask 176 | ) 177 | return output 178 | 179 | 180 | def _compute_mask( 181 | shape: Tuple[int, int], 182 | mask_prob: float, 183 | mask_length: int, 184 | device: torch.device, 185 | min_masks: int = 0, 186 | ) -> torch.Tensor: 187 | batch_size, sequence_length = shape 188 | 189 | if mask_length < 1: 190 | raise ValueError("`mask_length` has to be bigger than 0.") 191 | 192 | if mask_length > sequence_length: 193 | raise ValueError( 194 | f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" 195 | ) 196 | 197 | # compute number of masked spans in batch 198 | num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) 199 | num_masked_spans = max(num_masked_spans, min_masks) 200 | 201 | # make sure num masked indices <= sequence_length 202 | if num_masked_spans * mask_length > sequence_length: 203 | num_masked_spans = sequence_length // mask_length 204 | 205 | # SpecAugment mask to fill 206 | mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) 207 | 208 | # uniform distribution to sample from, make sure that offset samples are < sequence_length 209 | uniform_dist = torch.ones( 210 | (batch_size, sequence_length - (mask_length - 1)), device=device 211 | ) 212 | 213 | # get random indices to mask 214 | mask_indices = torch.multinomial(uniform_dist, num_masked_spans) 215 | 216 | # expand masked indices to masked spans 217 | mask_indices = ( 218 | mask_indices.unsqueeze(dim=-1) 219 | .expand((batch_size, num_masked_spans, mask_length)) 220 | .reshape(batch_size, num_masked_spans * mask_length) 221 | ) 222 | offsets = ( 223 | torch.arange(mask_length, device=device)[None, None, :] 224 | .expand((batch_size, num_masked_spans, mask_length)) 225 | .reshape(batch_size, num_masked_spans * mask_length) 226 | ) 227 | mask_idxs = mask_indices + offsets 228 | 229 | # scatter indices to mask 
230 | mask = mask.scatter(1, mask_idxs, True) 231 | 232 | return mask 233 | 234 | 235 | def hubert_discrete( 236 | pretrained: bool = True, 237 | progress: bool = True, 238 | ) -> HubertDiscrete: 239 | r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 240 | Args: 241 | pretrained (bool): load pretrained weights into the model 242 | progress (bool): show progress bar when downloading model 243 | """ 244 | kmeans = kmeans100(pretrained=pretrained, progress=progress) 245 | hubert = HubertDiscrete(kmeans) 246 | if pretrained: 247 | checkpoint = torch.hub.load_state_dict_from_url( 248 | URLS["hubert-discrete"], progress=progress 249 | ) 250 | consume_prefix_in_state_dict_if_present(checkpoint, "module.") 251 | hubert.load_state_dict(checkpoint) 252 | hubert.eval() 253 | return hubert 254 | 255 | 256 | def hubert_soft( 257 | pretrained: bool = True, 258 | progress: bool = True, 259 | ) -> HubertSoft: 260 | r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 261 | Args: 262 | pretrained (bool): load pretrained weights into the model 263 | progress (bool): show progress bar when downloading model 264 | """ 265 | hubert = HubertSoft() 266 | if pretrained: 267 | checkpoint = torch.hub.load_state_dict_from_url( 268 | URLS["hubert-soft"], progress=progress 269 | ) 270 | consume_prefix_in_state_dict_if_present(checkpoint, "module.") 271 | hubert.load_state_dict(checkpoint) 272 | hubert.eval() 273 | return hubert 274 | 275 | 276 | def _kmeans( 277 | num_clusters: int, pretrained: bool = True, progress: bool = True 278 | ) -> KMeans: 279 | kmeans = KMeans(num_clusters) 280 | if pretrained: 281 | checkpoint = torch.hub.load_state_dict_from_url( 282 | URLS[f"kmeans{num_clusters}"], progress=progress 283 | ) 284 | kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"] 285 | kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"] 286 | kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy() 287 | return kmeans 288 | 289 | 290 | def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans: 291 | r""" 292 | k-means checkpoint for HuBERT-Discrete with 100 clusters. 
293 | Args: 294 | pretrained (bool): load pretrained weights into the model 295 | progress (bool): show progress bar when downloading model 296 | """ 297 | return _kmeans(100, pretrained, progress) -------------------------------------------------------------------------------- /encoder/rmvpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .model import E2E, E2E0 3 | from .utils import to_local_average_f0, to_viterbi_f0 4 | from .inference import RMVPE 5 | from .spec import MelSpectrogram -------------------------------------------------------------------------------- /encoder/rmvpe/constants.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | 3 | N_CLASS = 360 4 | 5 | N_MELS = 128 6 | MEL_FMIN = 30 7 | MEL_FMAX = 8000 8 | WINDOW_LENGTH = 1024 9 | CONST = 1997.3794084376191 10 | -------------------------------------------------------------------------------- /encoder/rmvpe/deepunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .constants import N_MELS 4 | 5 | 6 | class ConvBlockRes(nn.Module): 7 | def __init__(self, in_channels, out_channels, momentum=0.01): 8 | super(ConvBlockRes, self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(in_channels=in_channels, 11 | out_channels=out_channels, 12 | kernel_size=(3, 3), 13 | stride=(1, 1), 14 | padding=(1, 1), 15 | bias=False), 16 | nn.BatchNorm2d(out_channels, momentum=momentum), 17 | nn.ReLU(), 18 | 19 | nn.Conv2d(in_channels=out_channels, 20 | out_channels=out_channels, 21 | kernel_size=(3, 3), 22 | stride=(1, 1), 23 | padding=(1, 1), 24 | bias=False), 25 | nn.BatchNorm2d(out_channels, momentum=momentum), 26 | nn.ReLU(), 27 | ) 28 | if in_channels != out_channels: 29 | self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) 30 | self.is_shortcut = True 31 | else: 32 | self.is_shortcut = False 33 | 34 | def forward(self, x): 35 | if self.is_shortcut: 36 | return self.conv(x) + self.shortcut(x) 37 | else: 38 | return self.conv(x) + x 39 | 40 | 41 | class ResEncoderBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): 43 | super(ResEncoderBlock, self).__init__() 44 | self.n_blocks = n_blocks 45 | self.conv = nn.ModuleList() 46 | self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) 47 | for i in range(n_blocks - 1): 48 | self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) 49 | self.kernel_size = kernel_size 50 | if self.kernel_size is not None: 51 | self.pool = nn.AvgPool2d(kernel_size=kernel_size) 52 | 53 | def forward(self, x): 54 | for i in range(self.n_blocks): 55 | x = self.conv[i](x) 56 | if self.kernel_size is not None: 57 | return x, self.pool(x) 58 | else: 59 | return x 60 | 61 | 62 | class ResDecoderBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): 64 | super(ResDecoderBlock, self).__init__() 65 | out_padding = (0, 1) if stride == (1, 2) else (1, 1) 66 | self.n_blocks = n_blocks 67 | self.conv1 = nn.Sequential( 68 | nn.ConvTranspose2d(in_channels=in_channels, 69 | out_channels=out_channels, 70 | kernel_size=(3, 3), 71 | stride=stride, 72 | padding=(1, 1), 73 | output_padding=out_padding, 74 | bias=False), 75 | nn.BatchNorm2d(out_channels, momentum=momentum), 76 | nn.ReLU(), 77 | ) 78 | self.conv2 = nn.ModuleList() 79 | 
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) 80 | for i in range(n_blocks-1): 81 | self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) 82 | 83 | def forward(self, x, concat_tensor): 84 | x = self.conv1(x) 85 | x = torch.cat((x, concat_tensor), dim=1) 86 | for i in range(self.n_blocks): 87 | x = self.conv2[i](x) 88 | return x 89 | 90 | 91 | class Encoder(nn.Module): 92 | def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): 93 | super(Encoder, self).__init__() 94 | self.n_encoders = n_encoders 95 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) 96 | self.layers = nn.ModuleList() 97 | self.latent_channels = [] 98 | for i in range(self.n_encoders): 99 | self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) 100 | self.latent_channels.append([out_channels, in_size]) 101 | in_channels = out_channels 102 | out_channels *= 2 103 | in_size //= 2 104 | self.out_size = in_size 105 | self.out_channel = out_channels 106 | 107 | def forward(self, x): 108 | concat_tensors = [] 109 | x = self.bn(x) 110 | for i in range(self.n_encoders): 111 | _, x = self.layers[i](x) 112 | concat_tensors.append(_) 113 | return x, concat_tensors 114 | 115 | 116 | class Intermediate(nn.Module): 117 | def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): 118 | super(Intermediate, self).__init__() 119 | self.n_inters = n_inters 120 | self.layers = nn.ModuleList() 121 | self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) 122 | for i in range(self.n_inters-1): 123 | self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) 124 | 125 | def forward(self, x): 126 | for i in range(self.n_inters): 127 | x = self.layers[i](x) 128 | return x 129 | 130 | 131 | class Decoder(nn.Module): 132 | def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): 133 | super(Decoder, self).__init__() 134 | self.layers = nn.ModuleList() 135 | self.n_decoders = n_decoders 136 | for i in range(self.n_decoders): 137 | out_channels = in_channels // 2 138 | self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) 139 | in_channels = out_channels 140 | 141 | def forward(self, x, concat_tensors): 142 | for i in range(self.n_decoders): 143 | x = self.layers[i](x, concat_tensors[-1-i]) 144 | return x 145 | 146 | 147 | class TimbreFilter(nn.Module): 148 | def __init__(self, latent_rep_channels): 149 | super(TimbreFilter, self).__init__() 150 | self.layers = nn.ModuleList() 151 | for latent_rep in latent_rep_channels: 152 | self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) 153 | 154 | def forward(self, x_tensors): 155 | out_tensors = [] 156 | for i, layer in enumerate(self.layers): 157 | out_tensors.append(layer(x_tensors[i])) 158 | return out_tensors 159 | 160 | 161 | class DeepUnet(nn.Module): 162 | def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): 163 | super(DeepUnet, self).__init__() 164 | self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) 165 | self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) 166 | self.tf = TimbreFilter(self.encoder.latent_channels) 167 | self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) 168 | 169 | def 
forward(self, x): 170 | x, concat_tensors = self.encoder(x) 171 | x = self.intermediate(x) 172 | concat_tensors = self.tf(concat_tensors) 173 | x = self.decoder(x, concat_tensors) 174 | return x 175 | 176 | 177 | class DeepUnet0(nn.Module): 178 | def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): 179 | super(DeepUnet0, self).__init__() 180 | self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) 181 | self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) 182 | self.tf = TimbreFilter(self.encoder.latent_channels) 183 | self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) 184 | 185 | def forward(self, x): 186 | x, concat_tensors = self.encoder(x) 187 | x = self.intermediate(x) 188 | x = self.decoder(x, concat_tensors) 189 | return x 190 | -------------------------------------------------------------------------------- /encoder/rmvpe/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchaudio.transforms import Resample 5 | from .constants import * 6 | from .model import E2E0, E2E 7 | from .spec import MelSpectrogram 8 | from .utils import to_local_average_f0, to_viterbi_f0 9 | 10 | class RMVPE: 11 | def __init__(self, model_path, hop_length=160): 12 | self.resample_kernel = {} 13 | model = E2E0(4, 1, (2, 2)) 14 | ckpt = torch.load(model_path) 15 | model.load_state_dict(ckpt['model'], strict=False) 16 | model.eval() 17 | self.model = model 18 | self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) 19 | self.resample_kernel = {} 20 | 21 | def mel2hidden(self, mel): 22 | with torch.no_grad(): 23 | n_frames = mel.shape[-1] 24 | mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') 25 | hidden = self.model(mel) 26 | return hidden[:, :n_frames] 27 | 28 | def decode(self, hidden, thred=0.03, use_viterbi=False): 29 | if use_viterbi: 30 | f0 = to_viterbi_f0(hidden, thred=thred) 31 | else: 32 | f0 = to_local_average_f0(hidden, thred=thred) 33 | return f0 34 | 35 | def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, use_viterbi=False): 36 | if device is None: 37 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | audio = torch.from_numpy(audio).float().unsqueeze(0).to(device) 39 | if sample_rate == 16000: 40 | audio_res = audio 41 | else: 42 | key_str = str(sample_rate) 43 | if key_str not in self.resample_kernel: 44 | self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) 45 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to(device) 46 | audio_res = self.resample_kernel[key_str](audio) 47 | mel_extractor = self.mel_extractor.to(device) 48 | self.model = self.model.to(device) 49 | mel = mel_extractor(audio_res, center=True) 50 | hidden = self.mel2hidden(mel) 51 | f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) 52 | return f0 -------------------------------------------------------------------------------- /encoder/rmvpe/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .deepunet import DeepUnet, DeepUnet0 4 | from .constants import * 5 | from .spec import MelSpectrogram 6 | from .seq import BiGRU 7 | 8 | 9 | 
class E2E(nn.Module): 10 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 11 | en_out_channels=16): 12 | super(E2E, self).__init__() 13 | self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 14 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 15 | if n_gru: 16 | self.fc = nn.Sequential( 17 | BiGRU(3 * N_MELS, 256, n_gru), 18 | nn.Linear(512, N_CLASS), 19 | nn.Dropout(0.25), 20 | nn.Sigmoid() 21 | ) 22 | else: 23 | self.fc = nn.Sequential( 24 | nn.Linear(3 * N_MELS, N_CLASS), 25 | nn.Dropout(0.25), 26 | nn.Sigmoid() 27 | ) 28 | 29 | def forward(self, mel): 30 | mel = mel.transpose(-1, -2).unsqueeze(1) 31 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 32 | x = self.fc(x) 33 | return x 34 | 35 | 36 | class E2E0(nn.Module): 37 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 38 | en_out_channels=16): 39 | super(E2E0, self).__init__() 40 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 41 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 42 | if n_gru: 43 | self.fc = nn.Sequential( 44 | BiGRU(3 * N_MELS, 256, n_gru), 45 | nn.Linear(512, N_CLASS), 46 | nn.Dropout(0.25), 47 | nn.Sigmoid() 48 | ) 49 | else: 50 | self.fc = nn.Sequential( 51 | nn.Linear(3 * N_MELS, N_CLASS), 52 | nn.Dropout(0.25), 53 | nn.Sigmoid() 54 | ) 55 | 56 | def forward(self, mel): 57 | mel = mel.transpose(-1, -2).unsqueeze(1) 58 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 59 | x = self.fc(x) 60 | return x 61 | -------------------------------------------------------------------------------- /encoder/rmvpe/seq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BiGRU(nn.Module): 5 | def __init__(self, input_features, hidden_features, num_layers): 6 | super(BiGRU, self).__init__() 7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 8 | 9 | def forward(self, x): 10 | return self.gru(x)[0] 11 | 12 | 13 | class BiLSTM(nn.Module): 14 | def __init__(self, input_features, hidden_features, num_layers): 15 | super(BiLSTM, self).__init__() 16 | self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 17 | 18 | def forward(self, x): 19 | return self.lstm(x)[0] 20 | 21 | -------------------------------------------------------------------------------- /encoder/rmvpe/spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from librosa.filters import mel 5 | 6 | class MelSpectrogram(torch.nn.Module): 7 | def __init__( 8 | self, 9 | n_mel_channels, 10 | sampling_rate, 11 | win_length, 12 | hop_length, 13 | n_fft=None, 14 | mel_fmin=0, 15 | mel_fmax=None, 16 | clamp = 1e-5 17 | ): 18 | super().__init__() 19 | n_fft = win_length if n_fft is None else n_fft 20 | self.hann_window = {} 21 | mel_basis = mel( 22 | sr=sampling_rate, 23 | n_fft=n_fft, 24 | n_mels=n_mel_channels, 25 | fmin=mel_fmin, 26 | fmax=mel_fmax, 27 | htk=True) 28 | mel_basis = torch.from_numpy(mel_basis).float() 29 | self.register_buffer("mel_basis", mel_basis) 30 | self.n_fft = win_length if n_fft is None else n_fft 31 | self.hop_length = hop_length 32 | self.win_length = win_length 33 | self.sampling_rate = 
sampling_rate 34 | self.n_mel_channels = n_mel_channels 35 | self.clamp = clamp 36 | 37 | def forward(self, audio, keyshift=0, speed=1, center=True): 38 | factor = 2 ** (keyshift / 12) 39 | n_fft_new = int(np.round(self.n_fft * factor)) 40 | win_length_new = int(np.round(self.win_length * factor)) 41 | hop_length_new = int(np.round(self.hop_length * speed)) 42 | 43 | keyshift_key = str(keyshift)+'_'+str(audio.device) 44 | if keyshift_key not in self.hann_window: 45 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) 46 | 47 | fft = torch.stft( 48 | audio, 49 | n_fft=n_fft_new, 50 | hop_length=hop_length_new, 51 | win_length=win_length_new, 52 | window=self.hann_window[keyshift_key], 53 | center=center, 54 | return_complex=True) 55 | magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) 56 | 57 | if keyshift != 0: 58 | size = self.n_fft // 2 + 1 59 | resize = magnitude.size(1) 60 | if resize < size: 61 | magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) 62 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new 63 | 64 | mel_output = torch.matmul(self.mel_basis, magnitude) 65 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 66 | return log_mel_spec -------------------------------------------------------------------------------- /encoder/rmvpe/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import librosa 4 | import torch 5 | from functools import reduce 6 | from .constants import * 7 | from torch.nn.modules.module import _addindent 8 | 9 | 10 | def cycle(iterable): 11 | while True: 12 | for item in iterable: 13 | yield item 14 | 15 | 16 | def summary(model, file=sys.stdout): 17 | def repr(model): 18 | # We treat the extra repr like the sub-module, one item per line 19 | extra_lines = [] 20 | extra_repr = model.extra_repr() 21 | # empty string will be split into list [''] 22 | if extra_repr: 23 | extra_lines = extra_repr.split('\n') 24 | child_lines = [] 25 | total_params = 0 26 | for key, module in model._modules.items(): 27 | mod_str, num_params = repr(module) 28 | mod_str = _addindent(mod_str, 2) 29 | child_lines.append('(' + key + '): ' + mod_str) 30 | total_params += num_params 31 | lines = extra_lines + child_lines 32 | 33 | for name, p in model._parameters.items(): 34 | if hasattr(p, 'shape'): 35 | total_params += reduce(lambda x, y: x * y, p.shape) 36 | 37 | main_str = model._get_name() + '(' 38 | if lines: 39 | # simple one-liner info, which most builtin Modules will use 40 | if len(extra_lines) == 1 and not child_lines: 41 | main_str += extra_lines[0] 42 | else: 43 | main_str += '\n ' + '\n '.join(lines) + '\n' 44 | 45 | main_str += ')' 46 | if file is sys.stdout: 47 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params) 48 | else: 49 | main_str += ', {:,} params'.format(total_params) 50 | return main_str, total_params 51 | 52 | string, count = repr(model) 53 | if file is not None: 54 | if isinstance(file, str): 55 | file = open(file, 'w') 56 | print(string, file=file) 57 | file.flush() 58 | 59 | return count 60 | 61 | 62 | def to_local_average_cents(salience, center=None, thred=0.03): 63 | """ 64 | find the weighted average cents near the argmax bin 65 | """ 66 | 67 | if not hasattr(to_local_average_cents, 'cents_mapping'): 68 | # the bin number-to-cents mapping 69 | to_local_average_cents.cents_mapping = ( 70 | 20 * np.arange(N_CLASS) + CONST) 71 | 72 | if salience.ndim == 1: 73 | if center is None: 74 | center 
= int(np.argmax(salience)) 75 | start = max(0, center - 4) 76 | end = min(len(salience), center + 5) 77 | salience = salience[start:end] 78 | product_sum = np.sum( 79 | salience * to_local_average_cents.cents_mapping[start:end]) 80 | weight_sum = np.sum(salience) 81 | return product_sum / weight_sum if np.max(salience) > thred else 0 82 | if salience.ndim == 2: 83 | return np.array([to_local_average_cents(salience[i, :], None, thred) for i in 84 | range(salience.shape[0])]) 85 | 86 | raise Exception("label should be either 1d or 2d ndarray") 87 | 88 | def to_viterbi_cents(salience, thred=0.03): 89 | # Create viterbi transition matrix 90 | if not hasattr(to_viterbi_cents, 'transition'): 91 | xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) 92 | transition = np.maximum(30 - abs(xx - yy), 0) 93 | transition = transition / transition.sum(axis=1, keepdims=True) 94 | to_viterbi_cents.transition = transition 95 | 96 | # Convert to probability 97 | prob = salience.T 98 | prob = prob / prob.sum(axis=0) 99 | 100 | # Perform viterbi decoding 101 | path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64) 102 | 103 | return np.array([to_local_average_cents(salience[i, :], path[i], thred) for i in 104 | range(len(path))]) 105 | 106 | def to_local_average_f0(hidden, center=None, thred=0.03): 107 | idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] 108 | idx_cents = idx * 20 + CONST # [B=1, N] 109 | if center is None: 110 | center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] 111 | start = torch.clip(center - 4, min=0) # [B, T, 1] 112 | end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] 113 | idx_mask = (idx >= start) & (idx < end) # [B, T, N] 114 | weights = hidden * idx_mask # [B, T, N] 115 | product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] 116 | weight_sum = torch.sum(weights, dim=2) # [B, T] 117 | cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] 118 | f0 = 10 * 2 ** (cents / 1200) 119 | uv = hidden.max(dim=2)[0] < thred # [B, T] 120 | f0 = f0 * ~uv 121 | return f0.squeeze(0).cpu().numpy() 122 | 123 | def to_viterbi_f0(hidden, thred=0.03): 124 | # Create viterbi transition matrix 125 | if not hasattr(to_viterbi_cents, 'transition'): 126 | xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) 127 | transition = np.maximum(30 - abs(xx - yy), 0) 128 | transition = transition / transition.sum(axis=1, keepdims=True) 129 | to_viterbi_cents.transition = transition 130 | 131 | # Convert to probability 132 | prob = hidden.squeeze(0).cpu().numpy() 133 | prob = prob.T 134 | prob = prob / prob.sum(axis=0) 135 | 136 | # Perform viterbi decoding 137 | path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64) 138 | center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) 139 | 140 | return to_local_average_f0(hidden, center=center, thred=thred) 141 | 142 | -------------------------------------------------------------------------------- /exp/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxlllc/ReFlow-VAE-SVC/0550d0efd84027e4547f698b49c58f19c52ab984/logger/__init__.py -------------------------------------------------------------------------------- /logger/saver.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | author: wayn391@mastertones 3 | ''' 4 | 5 | import os 6 | import json 7 | import time 8 | import yaml 9 | import datetime 10 | import torch 11 | try: 12 | import torch_musa 13 | except ImportError: 14 | pass 15 | import matplotlib 16 | matplotlib.use('Agg') 17 | import matplotlib.pyplot as plt 18 | from . import utils 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | class Saver(object): 22 | def __init__( 23 | self, 24 | args, 25 | initial_global_step=-1): 26 | 27 | self.expdir = args.env.expdir 28 | self.sample_rate = args.data.sampling_rate 29 | 30 | # cold start 31 | self.global_step = initial_global_step 32 | self.init_time = time.time() 33 | self.last_time = time.time() 34 | 35 | # makedirs 36 | os.makedirs(self.expdir, exist_ok=True) 37 | 38 | # path 39 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt') 40 | 41 | # ckpt 42 | os.makedirs(self.expdir, exist_ok=True) 43 | 44 | # writer 45 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) 46 | 47 | # save config 48 | path_config = os.path.join(self.expdir, 'config.yaml') 49 | with open(path_config, "w") as out_config: 50 | yaml.dump(dict(args), out_config) 51 | 52 | 53 | def log_info(self, msg): 54 | '''log method''' 55 | if isinstance(msg, dict): 56 | msg_list = [] 57 | for k, v in msg.items(): 58 | tmp_str = '' 59 | if isinstance(v, int): 60 | tmp_str = '{}: {:,}'.format(k, v) 61 | else: 62 | tmp_str = '{}: {}'.format(k, v) 63 | 64 | msg_list.append(tmp_str) 65 | msg_str = '\n'.join(msg_list) 66 | else: 67 | msg_str = msg 68 | 69 | # dsplay 70 | print(msg_str) 71 | 72 | # save 73 | with open(self.path_log_info, 'a') as fp: 74 | fp.write(msg_str+'\n') 75 | 76 | def log_value(self, dict): 77 | for k, v in dict.items(): 78 | self.writer.add_scalar(k, v, self.global_step) 79 | 80 | def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): 81 | spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) 82 | spec = spec_cat[0] 83 | if isinstance(spec, torch.Tensor): 84 | spec = spec.cpu().numpy() 85 | fig = plt.figure(figsize=(12, 9)) 86 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 87 | plt.tight_layout() 88 | self.writer.add_figure(name, fig, self.global_step) 89 | 90 | def log_audio(self, dict): 91 | for k, v in dict.items(): 92 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) 93 | 94 | def get_interval_time(self, update=True): 95 | cur_time = time.time() 96 | time_interval = cur_time - self.last_time 97 | if update: 98 | self.last_time = cur_time 99 | return time_interval 100 | 101 | def get_total_time(self, to_str=True): 102 | total_time = time.time() - self.init_time 103 | if to_str: 104 | total_time = str(datetime.timedelta( 105 | seconds=total_time))[:-5] 106 | return total_time 107 | 108 | def save_model( 109 | self, 110 | model, 111 | optimizer, 112 | name='model', 113 | postfix='', 114 | to_json=False): 115 | # path 116 | if postfix: 117 | postfix = '_' + postfix 118 | path_pt = os.path.join( 119 | self.expdir , name+postfix+'.pt') 120 | 121 | # check 122 | print(' [*] model checkpoint saved: {}'.format(path_pt)) 123 | 124 | # save 125 | if optimizer is not None: 126 | torch.save({ 127 | 'global_step': self.global_step, 128 | 'model': model.state_dict(), 129 | 'optimizer': optimizer.state_dict()}, path_pt) 130 | else: 131 | torch.save({ 132 | 'global_step': self.global_step, 133 | 'model': model.state_dict()}, path_pt) 134 | 135 | # to 
json 136 | if to_json: 137 | path_json = os.path.join( 138 | self.expdir , name+'.json') 139 | utils.to_json(path_params, path_json) 140 | 141 | def delete_model(self, name='model', postfix=''): 142 | # path 143 | if postfix: 144 | postfix = '_' + postfix 145 | path_pt = os.path.join( 146 | self.expdir , name+postfix+'.pt') 147 | 148 | # delete 149 | if os.path.exists(path_pt): 150 | os.remove(path_pt) 151 | print(' [*] model checkpoint deleted: {}'.format(path_pt)) 152 | 153 | def global_step_increment(self): 154 | self.global_step += 1 155 | 156 | 157 | -------------------------------------------------------------------------------- /logger/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import pickle 5 | import torch 6 | try: 7 | import torch_musa 8 | except ImportError: 9 | pass 10 | 11 | def traverse_dir( 12 | root_dir, 13 | extensions, 14 | amount=None, 15 | str_include=None, 16 | str_exclude=None, 17 | is_pure=False, 18 | is_sort=False, 19 | is_ext=True): 20 | 21 | file_list = [] 22 | cnt = 0 23 | for root, _, files in os.walk(root_dir): 24 | for file in files: 25 | if any([file.endswith(f".{ext}") for ext in extensions]): 26 | # path 27 | mix_path = os.path.join(root, file) 28 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 29 | 30 | # amount 31 | if (amount is not None) and (cnt == amount): 32 | if is_sort: 33 | file_list.sort() 34 | return file_list 35 | 36 | # check string 37 | if (str_include is not None) and (str_include not in pure_path): 38 | continue 39 | if (str_exclude is not None) and (str_exclude in pure_path): 40 | continue 41 | 42 | if not is_ext: 43 | ext = pure_path.split('.')[-1] 44 | pure_path = pure_path[:-(len(ext)+1)] 45 | file_list.append(pure_path) 46 | cnt += 1 47 | if is_sort: 48 | file_list.sort() 49 | return file_list 50 | 51 | 52 | 53 | class DotDict(dict): 54 | def __getattr__(*args): 55 | val = dict.get(*args) 56 | return DotDict(val) if type(val) is dict else val 57 | 58 | __setattr__ = dict.__setitem__ 59 | __delattr__ = dict.__delitem__ 60 | 61 | 62 | def get_network_paras_amount(model_dict): 63 | info = dict() 64 | for model_name, model in model_dict.items(): 65 | # all_params = sum(p.numel() for p in model.parameters()) 66 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 67 | 68 | info[model_name] = trainable_params 69 | return info 70 | 71 | 72 | def load_config(path_config): 73 | with open(path_config, "r") as config: 74 | args = yaml.safe_load(config) 75 | args = DotDict(args) 76 | # print(args) 77 | return args 78 | 79 | 80 | def to_json(path_params, path_json): 81 | params = torch.load(path_params, map_location=torch.device('cpu')) 82 | raw_state_dict = {} 83 | for k, v in params.items(): 84 | val = v.flatten().numpy().tolist() 85 | raw_state_dict[k] = val 86 | 87 | with open(path_json, 'w') as outfile: 88 | json.dump(raw_state_dict, outfile,indent= "\t") 89 | 90 | 91 | def convert_tensor_to_numpy(tensor, is_squeeze=True): 92 | if is_squeeze: 93 | tensor = tensor.squeeze() 94 | if tensor.requires_grad: 95 | tensor = tensor.detach() 96 | if tensor.is_cuda: 97 | tensor = tensor.cpu() 98 | return tensor.numpy() 99 | 100 | 101 | def load_model( 102 | expdir, 103 | model, 104 | optimizer, 105 | name='model', 106 | postfix='', 107 | device='cpu'): 108 | if postfix == '': 109 | postfix = '_' + postfix 110 | path = os.path.join(expdir, name+postfix) 111 | path_pt = traverse_dir(expdir, ['pt'], is_ext=False) 
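    # Note: checkpoints written by logger.saver.Saver.save_model above are named
    # '<name><postfix>.pt' (typically 'model_<global_step>.pt', or 'model_best.pt'
    # for the best checkpoint). The code below therefore strips the '<name>_'
    # prefix from every .pt file found in expdir, restores the one with the
    # largest numeric suffix, and falls back to '<name>_best.pt' otherwise.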
112 | global_step = 0 113 | if len(path_pt) > 0: 114 | steps = [s[len(path):] for s in path_pt] 115 | maxstep = max([int(s) if s.isdigit() else 0 for s in steps]) 116 | if maxstep >= 0: 117 | path_pt = path+str(maxstep)+'.pt' 118 | else: 119 | path_pt = path+'best.pt' 120 | print(' [*] restoring model from', path_pt) 121 | ckpt = torch.load(path_pt, map_location=torch.device(device)) 122 | global_step = ckpt['global_step'] 123 | model.load_state_dict(ckpt['model'], strict=False) 124 | if ckpt.get('optimizer') != None: 125 | optimizer.load_state_dict(ckpt['optimizer']) 126 | return global_step, model, optimizer 127 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | try: 4 | import torch_musa 5 | use_torch_musa = True 6 | except ImportError: 7 | use_torch_musa = False 8 | import librosa 9 | import argparse 10 | import numpy as np 11 | import soundfile as sf 12 | import pyworld as pw 13 | import parselmouth 14 | import hashlib 15 | import torch.nn.functional as F 16 | from ast import literal_eval 17 | from slicer import Slicer 18 | from reflow.extractors import F0_Extractor, Volume_Extractor, Units_Encoder 19 | from reflow.vocoder import load_model_vocoder 20 | from tqdm import tqdm 21 | 22 | 23 | def parse_args(args=None, namespace=None): 24 | """Parse command-line arguments.""" 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "-m", 28 | "--model_ckpt", 29 | type=str, 30 | required=True, 31 | help="path to the model checkpoint", 32 | ) 33 | parser.add_argument( 34 | "-d", 35 | "--device", 36 | type=str, 37 | default=None, 38 | required=False, 39 | help="cpu/cuda/musa, auto if not set") 40 | parser.add_argument( 41 | "-i", 42 | "--input", 43 | type=str, 44 | required=True, 45 | help="path to the input audio file", 46 | ) 47 | parser.add_argument( 48 | "-o", 49 | "--output", 50 | type=str, 51 | required=True, 52 | help="path to the output audio file", 53 | ) 54 | parser.add_argument( 55 | "-sid", 56 | "--source_spk_id", 57 | type=str, 58 | required=False, 59 | default='none', 60 | help="source speaker id (for multi-speaker model) | default: none", 61 | ) 62 | parser.add_argument( 63 | "-tid", 64 | "--target_spk_id", 65 | type=str, 66 | required=False, 67 | default=1, 68 | help="target speaker id (for multi-speaker model) | default: 1", 69 | ) 70 | parser.add_argument( 71 | "-mix", 72 | "--spk_mix_dict", 73 | type=str, 74 | required=False, 75 | default="None", 76 | help="mix-speaker dictionary (for multi-speaker model) | default: None", 77 | ) 78 | parser.add_argument( 79 | "-k", 80 | "--key", 81 | type=str, 82 | required=False, 83 | default=0, 84 | help="key changed (number of semitones) | default: 0", 85 | ) 86 | parser.add_argument( 87 | "-f", 88 | "--formant_shift_key", 89 | type=str, 90 | required=False, 91 | default=0, 92 | help="formant changed (number of semitones) , only for pitch-augmented model| default: 0", 93 | ) 94 | parser.add_argument( 95 | "-pe", 96 | "--pitch_extractor", 97 | type=str, 98 | required=False, 99 | default='rmvpe', 100 | help="pitch extrator type: parselmouth, dio, harvest, crepe, fcpe, rmvpe (default)", 101 | ) 102 | parser.add_argument( 103 | "-fmin", 104 | "--f0_min", 105 | type=str, 106 | required=False, 107 | default=50, 108 | help="min f0 (Hz) | default: 50", 109 | ) 110 | parser.add_argument( 111 | "-fmax", 112 | "--f0_max", 113 | type=str, 114 | required=False, 115 | 
default=1100, 116 | help="max f0 (Hz) | default: 1100", 117 | ) 118 | parser.add_argument( 119 | "-th", 120 | "--threhold", 121 | type=str, 122 | required=False, 123 | default=-60, 124 | help="response threhold (dB) | default: -60", 125 | ) 126 | parser.add_argument( 127 | "-step", 128 | "--infer_step", 129 | type=str, 130 | required=False, 131 | default='auto', 132 | help="sample steps | default: auto", 133 | ) 134 | parser.add_argument( 135 | "-method", 136 | "--method", 137 | type=str, 138 | required=False, 139 | default='auto', 140 | help="euler or rk4 | default: auto", 141 | ) 142 | return parser.parse_args(args=args, namespace=namespace) 143 | 144 | 145 | def upsample(signal, factor): 146 | signal = signal.permute(0, 2, 1) 147 | signal = F.interpolate(torch.cat((signal,signal[:,:,-1:]),2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True) 148 | signal = signal[:,:,:-1] 149 | return signal.permute(0, 2, 1) 150 | 151 | 152 | def split(audio, sample_rate, hop_size, db_thresh = -40, min_len = 5000): 153 | slicer = Slicer( 154 | sr=sample_rate, 155 | threshold=db_thresh, 156 | min_length=min_len) 157 | chunks = dict(slicer.slice(audio)) 158 | result = [] 159 | for k, v in chunks.items(): 160 | tag = v["split_time"].split(",") 161 | if tag[0] != tag[1]: 162 | start_frame = int(int(tag[0]) // hop_size) 163 | end_frame = int(int(tag[1]) // hop_size) 164 | if end_frame > start_frame: 165 | result.append(( 166 | start_frame, 167 | audio[int(start_frame * hop_size) : int(end_frame * hop_size)])) 168 | return result 169 | 170 | 171 | def cross_fade(a: np.ndarray, b: np.ndarray, idx: int): 172 | result = np.zeros(idx + b.shape[0]) 173 | fade_len = a.shape[0] - idx 174 | np.copyto(dst=result[:idx], src=a[:idx]) 175 | k = np.linspace(0, 1.0, num=fade_len, endpoint=True) 176 | result[idx: a.shape[0]] = (1 - k) * a[idx:] + k * b[: fade_len] 177 | np.copyto(dst=result[a.shape[0]:], src=b[fade_len:]) 178 | return result 179 | 180 | 181 | if __name__ == '__main__': 182 | # parse commands 183 | cmd = parse_args() 184 | 185 | #device = 'cpu' 186 | device = cmd.device 187 | if device is None: 188 | if torch.cuda.is_available(): 189 | device = 'cuda' 190 | elif use_torch_musa: 191 | if torch.musa.is_available(): 192 | device = 'musa' 193 | else: 194 | device = 'cpu' 195 | else: 196 | device = 'cpu' 197 | 198 | # load reflow model 199 | model, vocoder, args = load_model_vocoder(cmd.model_ckpt, device=device) 200 | 201 | # load input 202 | audio, sample_rate = librosa.load(cmd.input, sr=None) 203 | if len(audio.shape) > 1: 204 | audio = librosa.to_mono(audio) 205 | hop_size = args.data.block_size * sample_rate / args.data.sampling_rate 206 | 207 | # get MD5 hash from wav file 208 | md5_hash = "" 209 | with open(cmd.input, 'rb') as f: 210 | data = f.read() 211 | md5_hash = hashlib.md5(data).hexdigest() 212 | print("MD5: " + md5_hash) 213 | 214 | cache_dir_path = os.path.join(os.path.dirname(__file__), "cache") 215 | cache_file_path = os.path.join(cache_dir_path, f"{cmd.pitch_extractor}_{hop_size}_{cmd.f0_min}_{cmd.f0_max}_{md5_hash}.npy") 216 | 217 | is_cache_available = os.path.exists(cache_file_path) 218 | if is_cache_available: 219 | # f0 cache load 220 | print('Loading pitch curves for input audio from cache directory...') 221 | f0 = np.load(cache_file_path, allow_pickle=False) 222 | else: 223 | # extract f0 224 | print('Pitch extractor type: ' + cmd.pitch_extractor) 225 | pitch_extractor = F0_Extractor( 226 | cmd.pitch_extractor, 227 | sample_rate, 228 | hop_size, 229 | 
float(cmd.f0_min), 230 | float(cmd.f0_max)) 231 | print('Extracting the pitch curve of the input audio...') 232 | f0 = pitch_extractor.extract(audio, uv_interp = True, device = device) 233 | 234 | # f0 cache save 235 | os.makedirs(cache_dir_path, exist_ok=True) 236 | np.save(cache_file_path, f0, allow_pickle=False) 237 | 238 | # key change 239 | input_f0 = torch.from_numpy(f0).float().to(device).unsqueeze(-1).unsqueeze(0) 240 | output_f0 = input_f0 * 2 ** (float(cmd.key) / 12) 241 | 242 | # formant change 243 | formant_shift_key = torch.from_numpy(np.array([[float(cmd.formant_shift_key)]])).float().to(device) 244 | 245 | # source speaker id 246 | if cmd.source_spk_id == 'none': 247 | # load units encoder 248 | if args.data.encoder == 'cnhubertsoftfish': 249 | cnhubertsoft_gate = args.data.cnhubertsoft_gate 250 | else: 251 | cnhubertsoft_gate = 10 252 | units_encoder = Units_Encoder( 253 | args.data.encoder, 254 | args.data.encoder_ckpt, 255 | args.data.encoder_sample_rate, 256 | args.data.encoder_hop_size, 257 | cnhubertsoft_gate=cnhubertsoft_gate, 258 | device = device) 259 | # extract volume 260 | print('Extracting the volume envelope of the input audio...') 261 | volume_extractor = Volume_Extractor(hop_size) 262 | volume = volume_extractor.extract(audio) 263 | mask = (volume > 10 ** (float(cmd.threhold) / 20)).astype('float') 264 | mask = np.pad(mask, (4, 4), constant_values=(mask[0], mask[-1])) 265 | mask = np.array([np.max(mask[n : n + 9]) for n in range(len(mask) - 8)]) 266 | mask = torch.from_numpy(mask).float().to(device).unsqueeze(-1).unsqueeze(0) 267 | mask = upsample(mask, args.data.block_size).squeeze(-1) 268 | volume = torch.from_numpy(volume).float().to(device).unsqueeze(-1).unsqueeze(0) 269 | 270 | else: 271 | source_spk_id = torch.LongTensor(np.array([[int(cmd.source_spk_id)]])).to(device) 272 | print('Using VAE mode...') 273 | print('Source Speaker ID: '+ str(int(cmd.source_spk_id))) 274 | 275 | # targer speaker id or mix-speaker dictionary 276 | spk_mix_dict = literal_eval(cmd.spk_mix_dict) 277 | target_spk_id = torch.LongTensor(np.array([[int(cmd.target_spk_id)]])).to(device) 278 | if spk_mix_dict is not None: 279 | print('Mix-speaker mode') 280 | else: 281 | print('Target Speaker ID: '+ str(int(cmd.target_spk_id))) 282 | 283 | # sampling method 284 | if cmd.method == 'auto': 285 | method = args.infer.method 286 | else: 287 | method = cmd.method 288 | 289 | # infer step 290 | if cmd.infer_step == 'auto': 291 | infer_step = args.infer.infer_step 292 | else: 293 | infer_step = int(cmd.infer_step) 294 | 295 | if infer_step < 0: 296 | print('infer step cannot be negative!') 297 | exit(0) 298 | 299 | # forward and save the output 300 | result = np.zeros(0) 301 | current_length = 0 302 | segments = split(audio, sample_rate, hop_size) 303 | print('Cut the input audio into ' + str(len(segments)) + ' slices') 304 | with torch.no_grad(): 305 | for segment in tqdm(segments): 306 | start_frame = segment[0] 307 | seg_input = torch.from_numpy(segment[1]).float().unsqueeze(0).to(device) 308 | if cmd.source_spk_id == 'none': 309 | seg_units = units_encoder.encode(seg_input, sample_rate, hop_size) 310 | seg_f0 = output_f0[:, start_frame : start_frame + seg_units.size(1), :] 311 | seg_volume = volume[:, start_frame : start_frame + seg_units.size(1), :] 312 | 313 | seg_output = model( 314 | seg_units, 315 | seg_f0, 316 | seg_volume, 317 | spk_id = target_spk_id, 318 | spk_mix_dict = spk_mix_dict, 319 | aug_shift = formant_shift_key, 320 | vocoder=vocoder, 321 | infer=True, 322 | 
return_wav=True, 323 | infer_step=infer_step, 324 | method=method) 325 | seg_output *= mask[:, start_frame * args.data.block_size : (start_frame + seg_units.size(1)) * args.data.block_size] 326 | else: 327 | seg_input_mel = vocoder.extract(seg_input, sample_rate) 328 | seg_input_mel = torch.cat((seg_input_mel, seg_input_mel[:,-1:,:]), 1) 329 | seg_input_f0 = input_f0[:, start_frame : start_frame + seg_input_mel.size(1), :] 330 | seg_output_f0 = output_f0[:, start_frame : start_frame + seg_input_mel.size(1), :] 331 | 332 | seg_output_mel = model.vae_infer( 333 | seg_input_mel, 334 | seg_input_f0, 335 | source_spk_id, 336 | seg_output_f0, 337 | target_spk_id, 338 | spk_mix_dict, 339 | formant_shift_key, 340 | infer_step, 341 | method) 342 | seg_output = vocoder.infer(seg_output_mel, seg_output_f0) 343 | 344 | seg_output = seg_output.squeeze().cpu().numpy() 345 | 346 | silent_length = round(start_frame * args.data.block_size) - current_length 347 | if silent_length >= 0: 348 | result = np.append(result, np.zeros(silent_length)) 349 | result = np.append(result, seg_output) 350 | else: 351 | result = cross_fade(result, seg_output, current_length + silent_length) 352 | current_length = current_length + silent_length + len(seg_output) 353 | sf.write(cmd.output, result, args.data.sampling_rate) 354 | -------------------------------------------------------------------------------- /nsf_hifigan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /nsf_hifigan/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch 5 | try: 6 | import torch_musa 7 | use_torch_musa = True 8 | except ImportError: 9 | use_torch_musa = False 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d 13 | # from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm 14 | 15 | LRELU_SLOPE = 0.1 16 | _OLD_WEIGHT_NORM = False 17 | try: 18 | from torch.nn.utils.parametrizations import weight_norm 19 | except ImportError: 20 | from torch.nn.utils import weight_norm 21 | from torch.nn.utils import remove_weight_norm 22 | _OLD_WEIGHT_NORM = True 23 | 24 | try: 25 | from torch.nn.utils.parametrizations import spectral_norm 26 | except ImportError: 27 | from torch.nn.utils import spectral_norm 28 | 29 | 30 | class AttrDict(dict): 31 | def __init__(self, *args, **kwargs): 32 | super(AttrDict, self).__init__(*args, **kwargs) 33 | self.__dict__ = self 34 | 35 | def load_model(model_path, device='cuda'): 36 | h = load_config(model_path) 37 | 38 | generator = Generator(h).to(device) 39 | 40 | cp_dict = torch.load(model_path, map_location=device) 41 | generator.load_state_dict(cp_dict['generator']) 42 | generator.eval() 43 | generator.remove_weight_norm() 44 | del cp_dict 45 | return generator, h 46 | 47 | def load_config(model_path): 48 | config_file = os.path.join(os.path.split(model_path)[0], 'config.json') 49 | with 
open(config_file) as f: 50 | data = f.read() 51 | 52 | json_config = json.loads(data) 53 | h = AttrDict(json_config) 54 | return h 55 | 56 | def init_weights(m, mean=0.0, std=0.01): 57 | classname = m.__class__.__name__ 58 | if classname.find("Conv") != -1: 59 | m.weight.data.normal_(mean, std) 60 | 61 | 62 | def get_padding(kernel_size, dilation=1): 63 | return int((kernel_size * dilation - dilation) / 2) 64 | 65 | class ResBlock1(torch.nn.Module): 66 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 67 | super(ResBlock1, self).__init__() 68 | self.h = h 69 | self.convs1 = nn.ModuleList([ 70 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 71 | padding=get_padding(kernel_size, dilation[0]))), 72 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 73 | padding=get_padding(kernel_size, dilation[1]))), 74 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 75 | padding=get_padding(kernel_size, dilation[2]))) 76 | ]) 77 | self.convs1.apply(init_weights) 78 | 79 | self.convs2 = nn.ModuleList([ 80 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 81 | padding=get_padding(kernel_size, 1))), 82 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 83 | padding=get_padding(kernel_size, 1))), 84 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 85 | padding=get_padding(kernel_size, 1))) 86 | ]) 87 | self.convs2.apply(init_weights) 88 | 89 | def forward(self, x): 90 | for c1, c2 in zip(self.convs1, self.convs2): 91 | xt = F.leaky_relu(x, LRELU_SLOPE) 92 | xt = c1(xt) 93 | xt = F.leaky_relu(xt, LRELU_SLOPE) 94 | xt = c2(xt) 95 | x = xt + x 96 | return x 97 | 98 | def remove_weight_norm(self): 99 | global _OLD_WEIGHT_NORM 100 | if _OLD_WEIGHT_NORM: 101 | for l in self.convs1: 102 | remove_weight_norm(l) 103 | for l in self.convs2: 104 | remove_weight_norm(l) 105 | else: 106 | for l in self.convs1: 107 | torch.nn.utils.parametrize.remove_parametrizations(l) 108 | for l in self.convs2: 109 | torch.nn.utils.parametrize.remove_parametrizations(l) 110 | 111 | 112 | 113 | class ResBlock2(torch.nn.Module): 114 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 115 | super(ResBlock2, self).__init__() 116 | self.h = h 117 | self.convs = nn.ModuleList([ 118 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 119 | padding=get_padding(kernel_size, dilation[0]))), 120 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 121 | padding=get_padding(kernel_size, dilation[1]))) 122 | ]) 123 | self.convs.apply(init_weights) 124 | 125 | def forward(self, x): 126 | for c in self.convs: 127 | xt = F.leaky_relu(x, LRELU_SLOPE) 128 | xt = c(xt) 129 | x = xt + x 130 | return x 131 | 132 | def remove_weight_norm(self): 133 | 134 | global _OLD_WEIGHT_NORM 135 | if _OLD_WEIGHT_NORM: 136 | for l in self.convs: 137 | remove_weight_norm(l) 138 | 139 | else: 140 | for l in self.convs: 141 | torch.nn.utils.parametrize.remove_parametrizations(l) 142 | 143 | 144 | 145 | class SineGen(torch.nn.Module): 146 | """ Definition of sine generator 147 | SineGen(samp_rate, harmonic_num = 0, 148 | sine_amp = 0.1, noise_std = 0.003, 149 | voiced_threshold = 0, 150 | flag_for_pulse=False) 151 | samp_rate: sampling rate in Hz 152 | harmonic_num: number of harmonic overtones (default 0) 153 | sine_amp: amplitude of sine-waveform (default 0.1) 154 | noise_std: std of Gaussian noise (default 0.003) 155 | voiced_threshold: F0 
threshold for U/V classification (default 0) 156 | flag_for_pulse: this SinGen is used inside PulseGen (default False) 157 | Note: when flag_for_pulse is True, the first time step of a voiced 158 | segment is always sin(np.pi) or cos(0) 159 | """ 160 | 161 | def __init__(self, samp_rate, harmonic_num=0, 162 | sine_amp=0.1, noise_std=0.003, 163 | voiced_threshold=0): 164 | super(SineGen, self).__init__() 165 | self.sine_amp = sine_amp 166 | self.noise_std = noise_std 167 | self.harmonic_num = harmonic_num 168 | self.dim = self.harmonic_num + 1 169 | self.sampling_rate = samp_rate 170 | self.voiced_threshold = voiced_threshold 171 | 172 | def _f02uv(self, f0): 173 | # generate uv signal 174 | uv = torch.ones_like(f0) 175 | uv = uv * (f0 > self.voiced_threshold) 176 | return uv 177 | 178 | def _f02sine(self, f0, upp): 179 | """ f0: (batchsize, length, dim) 180 | where dim indicates fundamental tone and overtones 181 | """ 182 | rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, device=f0.device) 183 | rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5 184 | rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0) 185 | rad += F.pad(rad_acc, (0, 0, 1, -1)) 186 | rad = rad.reshape(f0.shape[0], -1, 1) 187 | rad = torch.multiply(rad, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1)) 188 | rand_ini = torch.rand(1, 1, self.dim, device=f0.device) 189 | rand_ini[..., 0] = 0 190 | rad += rand_ini 191 | sines = torch.sin(2 * np.pi * rad) 192 | return sines 193 | 194 | @torch.no_grad() 195 | def forward(self, f0, upp): 196 | """ sine_tensor, uv = forward(f0) 197 | input F0: tensor(batchsize=1, length, dim=1) 198 | f0 for unvoiced steps should be 0 199 | output sine_tensor: tensor(batchsize=1, length, dim) 200 | output uv: tensor(batchsize=1, length, 1) 201 | """ 202 | f0 = f0.unsqueeze(-1) 203 | sine_waves = self._f02sine(f0, upp) * self.sine_amp 204 | uv = (f0 > self.voiced_threshold).float() 205 | uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) 206 | noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 207 | noise = noise_amp * torch.randn_like(sine_waves) 208 | sine_waves = sine_waves * uv + noise 209 | return sine_waves 210 | 211 | 212 | class SourceModuleHnNSF(torch.nn.Module): 213 | """ SourceModule for hn-nsf 214 | SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, 215 | add_noise_std=0.003, voiced_threshod=0) 216 | sampling_rate: sampling_rate in Hz 217 | harmonic_num: number of harmonic above F0 (default: 0) 218 | sine_amp: amplitude of sine source signal (default: 0.1) 219 | add_noise_std: std of additive Gaussian noise (default: 0.003) 220 | note that amplitude of noise in unvoiced is decided 221 | by sine_amp 222 | voiced_threshold: threhold to set U/V given F0 (default: 0) 223 | Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) 224 | F0_sampled (batchsize, length, 1) 225 | Sine_source (batchsize, length, 1) 226 | noise_source (batchsize, length 1) 227 | uv (batchsize, length, 1) 228 | """ 229 | 230 | def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, 231 | add_noise_std=0.003, voiced_threshold=0): 232 | super(SourceModuleHnNSF, self).__init__() 233 | 234 | self.sine_amp = sine_amp 235 | self.noise_std = add_noise_std 236 | 237 | # to produce sine waveforms 238 | self.l_sin_gen = SineGen(sampling_rate, harmonic_num, 239 | sine_amp, add_noise_std, voiced_threshold) 240 | 241 | # to merge source harmonics into a single excitation 242 | self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) 243 | 
self.l_tanh = torch.nn.Tanh() 244 | 245 | def forward(self, x, upp): 246 | sine_wavs = self.l_sin_gen(x, upp) 247 | sine_merge = self.l_tanh(self.l_linear(sine_wavs)) 248 | return sine_merge 249 | 250 | 251 | class Generator(torch.nn.Module): 252 | def __init__(self, h): 253 | super(Generator, self).__init__() 254 | self.h = h 255 | self.num_kernels = len(h.resblock_kernel_sizes) 256 | self.num_upsamples = len(h.upsample_rates) 257 | self.m_source = SourceModuleHnNSF( 258 | sampling_rate=h.sampling_rate, 259 | harmonic_num=8 260 | ) 261 | self.noise_convs = nn.ModuleList() 262 | self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) 263 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 264 | 265 | self.ups = nn.ModuleList() 266 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 267 | c_cur = h.upsample_initial_channel // (2 ** (i + 1)) 268 | self.ups.append(weight_norm( 269 | ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), 270 | k, u, padding=(k - u) // 2))) 271 | if i + 1 < len(h.upsample_rates): # 272 | stride_f0 = int(np.prod(h.upsample_rates[i + 1:])) 273 | self.noise_convs.append(Conv1d( 274 | 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) 275 | else: 276 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) 277 | self.resblocks = nn.ModuleList() 278 | ch = h.upsample_initial_channel 279 | for i in range(len(self.ups)): 280 | ch //= 2 281 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 282 | self.resblocks.append(resblock(h, ch, k, d)) 283 | 284 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 285 | self.ups.apply(init_weights) 286 | self.conv_post.apply(init_weights) 287 | self.upp = int(np.prod(h.upsample_rates)) 288 | 289 | def forward(self, x, f0): 290 | har_source = self.m_source(f0, self.upp).transpose(1, 2) 291 | x = self.conv_pre(x) 292 | for i in range(self.num_upsamples): 293 | x = F.leaky_relu(x, LRELU_SLOPE) 294 | x = self.ups[i](x) 295 | x_source = self.noise_convs[i](har_source) 296 | x = x + x_source 297 | xs = None 298 | for j in range(self.num_kernels): 299 | if xs is None: 300 | xs = self.resblocks[i * self.num_kernels + j](x) 301 | else: 302 | xs += self.resblocks[i * self.num_kernels + j](x) 303 | x = xs / self.num_kernels 304 | x = F.leaky_relu(x) 305 | x = self.conv_post(x) 306 | x = torch.tanh(x) 307 | 308 | return x 309 | 310 | def remove_weight_norm(self): 311 | # rank_zero_info('Removing weight norm...') 312 | print('Removing weight norm...') 313 | global _OLD_WEIGHT_NORM 314 | if _OLD_WEIGHT_NORM: 315 | for l in self.ups: 316 | remove_weight_norm(l) 317 | for l in self.resblocks: 318 | l.remove_weight_norm() 319 | 320 | remove_weight_norm(self.conv_pre) 321 | remove_weight_norm(self.conv_post) 322 | #else: 323 | # for l in self.ups: 324 | # torch.nn.utils.parametrize.remove_parametrizations(l) 325 | # for l in self.resblocks: 326 | # l.remove_weight_norm() 327 | # 328 | # torch.nn.utils.parametrize.remove_parametrizations(self.conv_pre) 329 | # torch.nn.utils.parametrize.remove_parametrizations(self.conv_post) 330 | 331 | 332 | 333 | class DiscriminatorP(torch.nn.Module): 334 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 335 | super(DiscriminatorP, self).__init__() 336 | self.period = period 337 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 338 | self.convs = nn.ModuleList( 339 | [ 
340 | norm_f( 341 | Conv2d( 342 | 1, 343 | 32, 344 | (kernel_size, 1), 345 | (stride, 1), 346 | padding=(get_padding(5, 1), 0), 347 | ) 348 | ), 349 | norm_f( 350 | Conv2d( 351 | 32, 352 | 128, 353 | (kernel_size, 1), 354 | (stride, 1), 355 | padding=(get_padding(5, 1), 0), 356 | ) 357 | ), 358 | norm_f( 359 | Conv2d( 360 | 128, 361 | 512, 362 | (kernel_size, 1), 363 | (stride, 1), 364 | padding=(get_padding(5, 1), 0), 365 | ) 366 | ), 367 | norm_f( 368 | Conv2d( 369 | 512, 370 | 1024, 371 | (kernel_size, 1), 372 | (stride, 1), 373 | padding=(get_padding(5, 1), 0), 374 | ) 375 | ), 376 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 377 | ] 378 | ) 379 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 380 | 381 | def forward(self, x): 382 | fmap = [] 383 | 384 | # 1d to 2d 385 | b, c, t = x.shape 386 | if t % self.period != 0: # pad first 387 | n_pad = self.period - (t % self.period) 388 | x = F.pad(x, (0, n_pad), "reflect") 389 | t = t + n_pad 390 | x = x.view(b, c, t // self.period, self.period) 391 | 392 | for l in self.convs: 393 | x = l(x) 394 | x = F.leaky_relu(x, LRELU_SLOPE, inplace=True) 395 | x = torch.nan_to_num(x) 396 | 397 | fmap.append(x) 398 | 399 | x = self.conv_post(x) 400 | x = torch.nan_to_num(x) 401 | fmap.append(x) 402 | x = torch.flatten(x, 1, -1) 403 | 404 | return x, fmap 405 | 406 | 407 | class MultiPeriodDiscriminator(torch.nn.Module): 408 | def __init__(self, periods=None): 409 | super(MultiPeriodDiscriminator, self).__init__() 410 | self.periods = periods if periods is not None else [2, 3, 5, 7, 11] 411 | self.discriminators = nn.ModuleList() 412 | for period in self.periods: 413 | self.discriminators.append(DiscriminatorP(period)) 414 | 415 | def forward(self, y): 416 | y_d_rs = [] 417 | 418 | fmap_rs = [] 419 | 420 | 421 | for i, d in enumerate(self.discriminators): 422 | y_d_r, fmap_r = d(y) 423 | 424 | y_d_rs.append(y_d_r) 425 | fmap_rs.append(fmap_r) 426 | 427 | 428 | return y_d_rs, fmap_rs, 429 | 430 | 431 | class DiscriminatorS(torch.nn.Module): 432 | def __init__(self, use_spectral_norm=False): 433 | super(DiscriminatorS, self).__init__() 434 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 435 | self.convs = nn.ModuleList( 436 | [ 437 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 438 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 439 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 440 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 441 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 442 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 443 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 444 | ] 445 | ) 446 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 447 | 448 | def forward(self, x): 449 | fmap = [] 450 | for l in self.convs: 451 | x = l(x) 452 | x = F.leaky_relu(x, LRELU_SLOPE, inplace=True) 453 | x = torch.nan_to_num(x) 454 | fmap.append(x) 455 | 456 | x = self.conv_post(x) 457 | x = torch.nan_to_num(x) 458 | fmap.append(x) 459 | x = torch.flatten(x, 1, -1) 460 | 461 | return x, fmap 462 | 463 | 464 | class MultiScaleDiscriminator(torch.nn.Module): 465 | def __init__(self): 466 | super(MultiScaleDiscriminator, self).__init__() 467 | self.discriminators = nn.ModuleList( 468 | [ 469 | DiscriminatorS(use_spectral_norm=True), 470 | DiscriminatorS(), 471 | DiscriminatorS(), 472 | ] 473 | ) 474 | self.meanpools = nn.ModuleList( 475 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] 476 | ) 477 | 478 | def forward(self, y): 479 
| y_d_rs = [] 480 | 481 | fmap_rs = [] 482 | 483 | for i, d in enumerate(self.discriminators): 484 | if i != 0: 485 | y = self.meanpools[i - 1](y) 486 | 487 | y_d_r, fmap_r = d(y) 488 | 489 | y_d_rs.append(y_d_r) 490 | fmap_rs.append(fmap_r) 491 | 492 | 493 | return y_d_rs, fmap_rs, 494 | 495 | 496 | def feature_loss(fmap_r, fmap_g): 497 | loss = 0 498 | for dr, dg in zip(fmap_r, fmap_g): 499 | for rl, gl in zip(dr, dg): 500 | loss += torch.mean(torch.abs(rl - gl)) 501 | 502 | return loss * 2 503 | 504 | 505 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 506 | loss = 0 507 | r_losses = [] 508 | g_losses = [] 509 | 510 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 511 | r_loss = torch.mean((1 - dr) ** 2) 512 | g_loss = torch.mean(dg**2) 513 | loss += r_loss + g_loss 514 | r_losses.append(r_loss.item()) 515 | g_losses.append(g_loss.item()) 516 | 517 | return loss, r_losses, g_losses 518 | 519 | 520 | def generator_loss(disc_outputs): 521 | loss = 0 522 | gen_losses = [] 523 | 524 | for dg in disc_outputs: 525 | l = torch.mean((1 - dg) ** 2) 526 | gen_losses.append(l) 527 | loss += l 528 | 529 | return loss, gen_losses 530 | -------------------------------------------------------------------------------- /nsf_hifigan/nvSTFT.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | os.environ["LRU_CACHE_CAPACITY"] = "3" 4 | import random 5 | import torch 6 | try: 7 | import torch_musa 8 | except ImportError: 9 | pass 10 | import torch.utils.data 11 | import numpy as np 12 | import librosa 13 | from librosa.util import normalize 14 | from librosa.filters import mel as librosa_mel_fn 15 | from scipy.io.wavfile import read 16 | import soundfile as sf 17 | import torch.nn.functional as F 18 | 19 | def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): 20 | sampling_rate = None 21 | try: 22 | data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 23 | except Exception as ex: 24 | print(f"'{full_path}' failed to load.\nException:") 25 | print(ex) 26 | if return_empty_on_exception: 27 | return [], sampling_rate or target_sr or 48000 28 | else: 29 | raise Exception(ex) 30 | 31 | if len(data.shape) > 1: 32 | data = data[:, 0] 33 | assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) 34 | 35 | if np.issubdtype(data.dtype, np.integer): # if audio data is type int 36 | max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX 37 | else: # if audio data is type fp32 38 | max_mag = max(np.amax(data), -np.amin(data)) 39 | max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 40 | 41 | data = torch.FloatTensor(data.astype(np.float32))/max_mag 42 | 43 | if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. 
return_empty_on_exception will return empty arr instead of except 44 | return [], sampling_rate or target_sr or 48000 45 | if target_sr is not None and sampling_rate != target_sr: 46 | data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) 47 | sampling_rate = target_sr 48 | 49 | return data, sampling_rate 50 | 51 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 52 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 53 | 54 | def dynamic_range_decompression(x, C=1): 55 | return np.exp(x) / C 56 | 57 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 58 | return torch.log(torch.clamp(x, min=clip_val) * C) 59 | 60 | def dynamic_range_decompression_torch(x, C=1): 61 | return torch.exp(x) / C 62 | 63 | class STFT(): 64 | def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): 65 | self.target_sr = sr 66 | 67 | self.n_mels = n_mels 68 | self.n_fft = n_fft 69 | self.win_size = win_size 70 | self.hop_length = hop_length 71 | self.fmin = fmin 72 | self.fmax = fmax 73 | self.clip_val = clip_val 74 | self.mel_basis = {} 75 | self.hann_window = {} 76 | 77 | def get_mel(self, y, keyshift=0, speed=1, center=False): 78 | sampling_rate = self.target_sr 79 | n_mels = self.n_mels 80 | n_fft = self.n_fft 81 | win_size = self.win_size 82 | hop_length = self.hop_length 83 | fmin = self.fmin 84 | fmax = self.fmax 85 | clip_val = self.clip_val 86 | 87 | factor = 2 ** (keyshift / 12) 88 | n_fft_new = int(np.round(n_fft * factor)) 89 | win_size_new = int(np.round(win_size * factor)) 90 | hop_length_new = int(np.round(hop_length * speed)) 91 | 92 | mel_basis_key = str(fmax)+'_'+str(y.device) 93 | if mel_basis_key not in self.mel_basis: 94 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 95 | self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) 96 | 97 | keyshift_key = str(keyshift)+'_'+str(y.device) 98 | if keyshift_key not in self.hann_window: 99 | self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) 100 | 101 | pad_left = (win_size_new - hop_length_new) //2 102 | pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) 103 | if pad_right < y.size(-1): 104 | mode = 'reflect' 105 | else: 106 | mode = 'constant' 107 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) 108 | y = y.squeeze(1) 109 | 110 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key], 111 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 112 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9)) 113 | if keyshift != 0: 114 | size = n_fft // 2 + 1 115 | resize = spec.size(1) 116 | if resize < size: 117 | spec = F.pad(spec, (0, 0, 0, size-resize)) 118 | spec = spec[:, :size, :] * win_size / win_size_new 119 | spec = torch.matmul(self.mel_basis[mel_basis_key], spec) 120 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 121 | return spec 122 | 123 | def __call__(self, audiopath): 124 | audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) 125 | spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) 126 | return spect 127 | 128 | stft = STFT() 129 | -------------------------------------------------------------------------------- /nsf_hifigan/utils.py: 
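A note on the keyshift handling in STFT.get_mel above: pitch/formant augmentation is implemented by scaling the FFT and window lengths by 2 ** (keyshift / 12) before the STFT, then cropping or padding the resulting spectrum back to n_fft // 2 + 1 bins and rescaling it by win_size / win_size_new, so the mel filterbank itself never changes. The sketch below only reproduces the parameter scaling; the concrete sizes in the example are illustrative, not values taken from a config in this repository.

import numpy as np

def scaled_stft_params(n_fft, win_size, hop_length, keyshift, speed=1.0):
    # shifting by k semitones scales the analysis window by 2 ** (k / 12),
    # which moves spectral content by the same factor along the frequency axis
    factor = 2 ** (keyshift / 12)
    return (int(np.round(n_fft * factor)),
            int(np.round(win_size * factor)),
            int(np.round(hop_length * speed)))

# e.g. a +5 semitone shift with n_fft=2048, win_size=2048, hop_length=512:
# scaled_stft_params(2048, 2048, 512, 5) -> (2734, 2734, 512)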
-------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | try: 6 | import torch_musa 7 | except ImportError: 8 | pass 9 | from torch.nn.utils import weight_norm 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 17 | interpolation='none') 18 | plt.colorbar(im, ax=ax) 19 | 20 | fig.canvas.draw() 21 | plt.close() 22 | 23 | return fig 24 | 25 | 26 | def init_weights(m, mean=0.0, std=0.01): 27 | classname = m.__class__.__name__ 28 | if classname.find("Conv") != -1: 29 | m.weight.data.normal_(mean, std) 30 | 31 | 32 | def apply_weight_norm(m): 33 | classname = m.__class__.__name__ 34 | if classname.find("Conv") != -1: 35 | weight_norm(m) 36 | 37 | 38 | def get_padding(kernel_size, dilation=1): 39 | return int((kernel_size*dilation - dilation)/2) 40 | 41 | 42 | def load_checkpoint(filepath, device): 43 | assert os.path.isfile(filepath) 44 | print("Loading '{}'".format(filepath)) 45 | checkpoint_dict = torch.load(filepath, map_location=device) 46 | print("Complete.") 47 | return checkpoint_dict 48 | 49 | 50 | def save_checkpoint(filepath, obj): 51 | print("Saving checkpoint to {}".format(filepath)) 52 | torch.save(obj, filepath) 53 | print("Complete.") 54 | 55 | 56 | def del_old_checkpoints(cp_dir, prefix, n_models=2): 57 | pattern = os.path.join(cp_dir, prefix + '????????') 58 | cp_list = glob.glob(pattern) # get checkpoint paths 59 | cp_list = sorted(cp_list)# sort by iter 60 | if len(cp_list) > n_models: # if more than n_models models are found 61 | for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models 62 | open(cp, 'w').close()# empty file contents 63 | os.unlink(cp)# delete file (move to trash when using Colab) 64 | 65 | 66 | def scan_checkpoint(cp_dir, prefix): 67 | pattern = os.path.join(cp_dir, prefix + '????????') 68 | cp_list = glob.glob(pattern) 69 | if len(cp_list) == 0: 70 | return None 71 | return sorted(cp_list)[-1] 72 | 73 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import librosa 5 | import torch 6 | try: 7 | import torch_musa 8 | use_torch_musa = True 9 | except ImportError: 10 | use_torch_musa = False 11 | import pyworld as pw 12 | import parselmouth 13 | import argparse 14 | import shutil 15 | from logger import utils 16 | from tqdm import tqdm 17 | from reflow.extractors import F0_Extractor, Volume_Extractor, Units_Encoder 18 | from reflow.vocoder import Vocoder 19 | from logger.utils import traverse_dir 20 | import concurrent.futures 21 | 22 | def parse_args(args=None, namespace=None): 23 | """Parse command-line arguments.""" 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "-c", 27 | "--config", 28 | type=str, 29 | required=True, 30 | help="path to the config file") 31 | parser.add_argument( 32 | "-d", 33 | "--device", 34 | type=str, 35 | default=None, 36 | required=False, 37 | help="cpu/cuda/musa, auto if not set") 38 | return parser.parse_args(args=args, namespace=namespace) 39 | 40 | def preprocess(path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device = 'cuda', use_pitch_aug = False, extensions = ['wav']): 41 | 42 | 
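    # Preprocessing writes one .npy per input clip into sibling folders of 'audio'
    # under the same dataset path:
    #   units/    content-encoder features      f0/       pitch curves
    #   volume/   volume envelopes              mel/      ground-truth mel spectrograms
    #   aug_mel/  volume/pitch-augmented mels   aug_vol/  augmented volume envelopes
    #   skip/     clips whose f0 extraction failed are moved here
    # plus a pitch_aug_dict.npy mapping each file to its random key shift.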
path_srcdir = os.path.join(path, 'audio') 43 | path_unitsdir = os.path.join(path, 'units') 44 | path_f0dir = os.path.join(path, 'f0') 45 | path_volumedir = os.path.join(path, 'volume') 46 | path_augvoldir = os.path.join(path, 'aug_vol') 47 | path_meldir = os.path.join(path, 'mel') 48 | path_augmeldir = os.path.join(path, 'aug_mel') 49 | path_skipdir = os.path.join(path, 'skip') 50 | 51 | # list files 52 | filelist = traverse_dir( 53 | path_srcdir, 54 | extensions=extensions, 55 | is_pure=True, 56 | is_sort=True, 57 | is_ext=True) 58 | 59 | # pitch augmentation dictionary 60 | pitch_aug_dict = {} 61 | 62 | # run 63 | def process(file): 64 | binfile = file+'.npy' 65 | path_srcfile = os.path.join(path_srcdir, file) 66 | path_unitsfile = os.path.join(path_unitsdir, binfile) 67 | path_f0file = os.path.join(path_f0dir, binfile) 68 | path_volumefile = os.path.join(path_volumedir, binfile) 69 | path_augvolfile = os.path.join(path_augvoldir, binfile) 70 | path_melfile = os.path.join(path_meldir, binfile) 71 | path_augmelfile = os.path.join(path_augmeldir, binfile) 72 | path_skipfile = os.path.join(path_skipdir, file) 73 | 74 | # load audio 75 | audio, _ = librosa.load(path_srcfile, sr=sample_rate) 76 | if len(audio.shape) > 1: 77 | audio = librosa.to_mono(audio) 78 | audio_t = torch.from_numpy(audio).float().to(device) 79 | audio_t = audio_t.unsqueeze(0) 80 | 81 | # extract volume 82 | volume = volume_extractor.extract(audio) 83 | 84 | # extract mel and volume augmentaion 85 | if mel_extractor is not None: 86 | mel_t = mel_extractor.extract(audio_t, sample_rate) 87 | mel = mel_t.squeeze().to('cpu').numpy() 88 | 89 | max_amp = float(torch.max(torch.abs(audio_t))) + 1e-5 90 | max_shift = min(1, np.log10(1/max_amp)) 91 | log10_vol_shift = random.uniform(-1, max_shift) 92 | if use_pitch_aug: 93 | keyshift = random.uniform(-5, 5) 94 | else: 95 | keyshift = 0 96 | 97 | aug_mel_t = mel_extractor.extract(audio_t * (10 ** log10_vol_shift), sample_rate, keyshift = keyshift) 98 | aug_mel = aug_mel_t.squeeze().to('cpu').numpy() 99 | aug_vol = volume_extractor.extract(audio * (10 ** log10_vol_shift)) 100 | 101 | # units encode 102 | units_t = units_encoder.encode(audio_t, sample_rate, hop_size) 103 | units = units_t.squeeze().to('cpu').numpy() 104 | 105 | # extract f0 106 | f0 = f0_extractor.extract(audio, uv_interp = False) 107 | 108 | uv = f0 == 0 109 | if len(f0[~uv]) > 0: 110 | # interpolate the unvoiced f0 111 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 112 | 113 | # save npy 114 | os.makedirs(os.path.dirname(path_unitsfile), exist_ok=True) 115 | np.save(path_unitsfile, units) 116 | os.makedirs(os.path.dirname(path_f0file), exist_ok=True) 117 | np.save(path_f0file, f0) 118 | os.makedirs(os.path.dirname(path_volumefile), exist_ok=True) 119 | np.save(path_volumefile, volume) 120 | if mel_extractor is not None: 121 | pitch_aug_dict[file] = keyshift 122 | os.makedirs(os.path.dirname(path_melfile), exist_ok=True) 123 | np.save(path_melfile, mel) 124 | os.makedirs(os.path.dirname(path_augmelfile), exist_ok=True) 125 | np.save(path_augmelfile, aug_mel) 126 | os.makedirs(os.path.dirname(path_augvolfile), exist_ok=True) 127 | np.save(path_augvolfile, aug_vol) 128 | else: 129 | print('\n[Error] F0 extraction failed: ' + path_srcfile) 130 | os.makedirs(os.path.dirname(path_skipfile), exist_ok=True) 131 | shutil.move(path_srcfile, os.path.dirname(path_skipfile)) 132 | print('This file has been moved to ' + path_skipfile) 133 | print('Preprocess the audio clips in :', path_srcdir) 134 | 135 | # 
single process 136 | for file in tqdm(filelist, total=len(filelist)): 137 | process(file) 138 | 139 | if mel_extractor is not None: 140 | path_pitchaugdict = os.path.join(path, 'pitch_aug_dict.npy') 141 | np.save(path_pitchaugdict, pitch_aug_dict) 142 | # multi-process (have bugs) 143 | ''' 144 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: 145 | list(tqdm(executor.map(process, filelist), total=len(filelist))) 146 | ''' 147 | 148 | if __name__ == '__main__': 149 | # parse commands 150 | cmd = parse_args() 151 | 152 | device = cmd.device 153 | if device is None: 154 | if torch.cuda.is_available(): 155 | device = 'cuda' 156 | elif use_torch_musa: 157 | if torch.musa.is_available(): 158 | device = 'musa' 159 | else: 160 | device = 'cpu' 161 | else: 162 | device = 'cpu' 163 | 164 | # load config 165 | args = utils.load_config(cmd.config) 166 | sample_rate = args.data.sampling_rate 167 | hop_size = args.data.block_size 168 | 169 | extensions = args.data.extensions 170 | 171 | # initialize f0 extractor 172 | f0_extractor = F0_Extractor( 173 | args.data.f0_extractor, 174 | args.data.sampling_rate, 175 | args.data.block_size, 176 | args.data.f0_min, 177 | args.data.f0_max) 178 | 179 | # initialize volume extractor 180 | volume_extractor = Volume_Extractor(args.data.block_size) 181 | 182 | # initialize mel extractor 183 | mel_extractor = None 184 | use_pitch_aug = False 185 | if args.model.type in ['RectifiedFlow_VAE']: 186 | mel_extractor = Vocoder(args.vocoder.type, args.vocoder.ckpt, device = device) 187 | if mel_extractor.vocoder_sample_rate != sample_rate or mel_extractor.vocoder_hop_size != hop_size: 188 | mel_extractor = None 189 | print('Unmatch vocoder parameters, mel extraction is ignored!') 190 | elif args.model.use_pitch_aug: 191 | use_pitch_aug = True 192 | 193 | # initialize units encoder 194 | if args.data.encoder == 'cnhubertsoftfish': 195 | cnhubertsoft_gate = args.data.cnhubertsoft_gate 196 | else: 197 | cnhubertsoft_gate = 10 198 | units_encoder = Units_Encoder( 199 | args.data.encoder, 200 | args.data.encoder_ckpt, 201 | args.data.encoder_sample_rate, 202 | args.data.encoder_hop_size, 203 | cnhubertsoft_gate=cnhubertsoft_gate, 204 | device = device) 205 | 206 | # preprocess training set 207 | preprocess(args.data.train_path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device = device, use_pitch_aug = use_pitch_aug, extensions = extensions) 208 | 209 | # preprocess validation set 210 | preprocess(args.data.valid_path, f0_extractor, volume_extractor, mel_extractor, units_encoder, sample_rate, hop_size, device = device, use_pitch_aug = False, extensions = extensions) 211 | 212 | -------------------------------------------------------------------------------- /pretrain/contentvec/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /pretrain/nsf_hifigan/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /pretrain/rmvpe/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /reflow/data_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
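# Dataset layout assumed by AudioDataset below: audio clips live under
# '<train_path>/audio', and for a multi-speaker model (n_spk > 1) each clip must
# sit in a subdirectory whose name starts with its 1-based speaker id followed
# by '_' or '-' (illustrative example: 'audio/1_speakerA/xxx.wav',
# 'audio/2-speakerB/yyy.wav'). A prefix that is not a digit in 1..n_spk raises
# the multi-speaker training error further down.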
import random 3 | import re 4 | import numpy as np 5 | import librosa 6 | import torch 7 | try: 8 | import torch_musa 9 | except ImportError: 10 | pass 11 | import random 12 | from tqdm import tqdm 13 | from torch.utils.data import Dataset 14 | 15 | def traverse_dir( 16 | root_dir, 17 | extensions, 18 | amount=None, 19 | str_include=None, 20 | str_exclude=None, 21 | is_pure=False, 22 | is_sort=False, 23 | is_ext=True): 24 | 25 | file_list = [] 26 | cnt = 0 27 | for root, _, files in os.walk(root_dir): 28 | for file in files: 29 | if any([file.endswith(f".{ext}") for ext in extensions]): 30 | # path 31 | mix_path = os.path.join(root, file) 32 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 33 | 34 | # amount 35 | if (amount is not None) and (cnt == amount): 36 | if is_sort: 37 | file_list.sort() 38 | return file_list 39 | 40 | # check string 41 | if (str_include is not None) and (str_include not in pure_path): 42 | continue 43 | if (str_exclude is not None) and (str_exclude in pure_path): 44 | continue 45 | 46 | if not is_ext: 47 | ext = pure_path.split('.')[-1] 48 | pure_path = pure_path[:-(len(ext)+1)] 49 | file_list.append(pure_path) 50 | cnt += 1 51 | if is_sort: 52 | file_list.sort() 53 | return file_list 54 | 55 | 56 | def get_data_loaders(args, whole_audio=False): 57 | data_train = AudioDataset( 58 | args.data.train_path, 59 | waveform_sec=args.data.duration, 60 | hop_size=args.data.block_size, 61 | sample_rate=args.data.sampling_rate, 62 | load_all_data=args.train.cache_all_data, 63 | whole_audio=whole_audio, 64 | extensions=args.data.extensions, 65 | n_spk=args.model.n_spk, 66 | device=args.train.cache_device, 67 | fp16=args.train.cache_fp16, 68 | use_aug=True) 69 | loader_train = torch.utils.data.DataLoader( 70 | data_train , 71 | batch_size=args.train.batch_size if not whole_audio else 1, 72 | shuffle=True, 73 | num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0, 74 | persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False, 75 | pin_memory=True if args.train.cache_device=='cpu' else False 76 | ) 77 | data_valid = AudioDataset( 78 | args.data.valid_path, 79 | waveform_sec=args.data.duration, 80 | hop_size=args.data.block_size, 81 | sample_rate=args.data.sampling_rate, 82 | load_all_data=args.train.cache_all_data, 83 | whole_audio=True, 84 | extensions=args.data.extensions, 85 | n_spk=args.model.n_spk) 86 | loader_valid = torch.utils.data.DataLoader( 87 | data_valid, 88 | batch_size=1, 89 | shuffle=False, 90 | num_workers=0, 91 | pin_memory=True 92 | ) 93 | return loader_train, loader_valid 94 | 95 | 96 | class AudioDataset(Dataset): 97 | def __init__( 98 | self, 99 | path_root, 100 | waveform_sec, 101 | hop_size, 102 | sample_rate, 103 | load_all_data=True, 104 | whole_audio=False, 105 | extensions=['wav'], 106 | n_spk=1, 107 | device='cpu', 108 | fp16=False, 109 | use_aug=False, 110 | ): 111 | super().__init__() 112 | 113 | self.waveform_sec = waveform_sec 114 | self.sample_rate = sample_rate 115 | self.hop_size = hop_size 116 | self.path_root = path_root 117 | self.paths = traverse_dir( 118 | os.path.join(path_root, 'audio'), 119 | extensions=extensions, 120 | is_pure=True, 121 | is_sort=True, 122 | is_ext=True 123 | ) 124 | self.whole_audio = whole_audio 125 | self.use_aug = use_aug 126 | self.data_buffer={} 127 | self.pitch_aug_dict = np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item() 128 | if load_all_data: 129 | print('Load all the data from :', path_root) 130 | 
else: 131 | print('Load the f0, volume data from :', path_root) 132 | for name_ext in tqdm(self.paths, total=len(self.paths)): 133 | name = os.path.splitext(name_ext)[0] 134 | path_audio = os.path.join(self.path_root, 'audio', name_ext) 135 | duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate) 136 | 137 | path_f0 = os.path.join(self.path_root, 'f0', name_ext) + '.npy' 138 | f0 = np.load(path_f0) 139 | f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(device) 140 | 141 | path_volume = os.path.join(self.path_root, 'volume', name_ext) + '.npy' 142 | volume = np.load(path_volume) 143 | volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device) 144 | 145 | path_augvol = os.path.join(self.path_root, 'aug_vol', name_ext) + '.npy' 146 | aug_vol = np.load(path_augvol) 147 | aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device) 148 | 149 | if n_spk is not None and n_spk > 1: 150 | dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0] 151 | spk_id = int(dirname_split) if str.isdigit(dirname_split) else 0 152 | if spk_id < 1 or spk_id > n_spk: 153 | raise ValueError(' [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ') 154 | else: 155 | spk_id = 1 156 | spk_id = torch.LongTensor(np.array([spk_id])).to(device) 157 | 158 | if load_all_data: 159 | ''' 160 | audio, sr = librosa.load(path_audio, sr=self.sample_rate) 161 | if len(audio.shape) > 1: 162 | audio = librosa.to_mono(audio) 163 | audio = torch.from_numpy(audio).to(device) 164 | ''' 165 | path_mel = os.path.join(self.path_root, 'mel', name_ext) + '.npy' 166 | mel = np.load(path_mel) 167 | mel = torch.from_numpy(mel).to(device) 168 | 169 | path_augmel = os.path.join(self.path_root, 'aug_mel', name_ext) + '.npy' 170 | aug_mel = np.load(path_augmel) 171 | aug_mel = torch.from_numpy(aug_mel).to(device) 172 | 173 | path_units = os.path.join(self.path_root, 'units', name_ext) + '.npy' 174 | units = np.load(path_units) 175 | units = torch.from_numpy(units).to(device) 176 | 177 | if fp16: 178 | mel = mel.half() 179 | aug_mel = aug_mel.half() 180 | units = units.half() 181 | 182 | self.data_buffer[name_ext] = { 183 | 'duration': duration, 184 | 'mel': mel, 185 | 'aug_mel': aug_mel, 186 | 'units': units, 187 | 'f0': f0, 188 | 'volume': volume, 189 | 'aug_vol': aug_vol, 190 | 'spk_id': spk_id 191 | } 192 | else: 193 | self.data_buffer[name_ext] = { 194 | 'duration': duration, 195 | 'f0': f0, 196 | 'volume': volume, 197 | 'aug_vol': aug_vol, 198 | 'spk_id': spk_id 199 | } 200 | 201 | 202 | def __getitem__(self, file_idx): 203 | name_ext = self.paths[file_idx] 204 | data_buffer = self.data_buffer[name_ext] 205 | # check duration. 
if too short, then skip 206 | if data_buffer['duration'] < (self.waveform_sec + 0.1): 207 | return self.__getitem__( (file_idx + 1) % len(self.paths)) 208 | 209 | # get item 210 | return self.get_data(name_ext, data_buffer) 211 | 212 | def get_data(self, name_ext, data_buffer): 213 | name = os.path.splitext(name_ext)[0] 214 | frame_resolution = self.hop_size / self.sample_rate 215 | duration = data_buffer['duration'] 216 | waveform_sec = duration if self.whole_audio else self.waveform_sec 217 | 218 | # load audio 219 | idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1) 220 | start_frame = int(idx_from / frame_resolution) 221 | units_frame_len = int(waveform_sec / frame_resolution) 222 | aug_flag = random.choice([True, False]) and self.use_aug 223 | ''' 224 | audio = data_buffer.get('audio') 225 | if audio is None: 226 | path_audio = os.path.join(self.path_root, 'audio', name) + '.wav' 227 | audio, sr = librosa.load( 228 | path_audio, 229 | sr = self.sample_rate, 230 | offset = start_frame * frame_resolution, 231 | duration = waveform_sec) 232 | if len(audio.shape) > 1: 233 | audio = librosa.to_mono(audio) 234 | # clip audio into N seconds 235 | audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size] 236 | audio = torch.from_numpy(audio).float() 237 | else: 238 | audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size] 239 | ''' 240 | # load mel 241 | mel_key = 'aug_mel' if aug_flag else 'mel' 242 | mel = data_buffer.get(mel_key) 243 | if mel is None: 244 | mel = os.path.join(self.path_root, mel_key, name_ext) + '.npy' 245 | mel = np.load(mel) 246 | mel = mel[start_frame : start_frame + units_frame_len] 247 | mel = torch.from_numpy(mel).float() 248 | else: 249 | mel = mel[start_frame : start_frame + units_frame_len] 250 | 251 | # load units 252 | units = data_buffer.get('units') 253 | if units is None: 254 | units = os.path.join(self.path_root, 'units', name_ext) + '.npy' 255 | units = np.load(units) 256 | units = units[start_frame : start_frame + units_frame_len] 257 | units = torch.from_numpy(units).float() 258 | else: 259 | units = units[start_frame : start_frame + units_frame_len] 260 | 261 | # load f0 262 | f0 = data_buffer.get('f0') 263 | aug_shift = 0 264 | if aug_flag: 265 | aug_shift = self.pitch_aug_dict[name_ext] 266 | f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len] 267 | 268 | # load volume 269 | vol_key = 'aug_vol' if aug_flag else 'volume' 270 | volume = data_buffer.get(vol_key) 271 | volume_frames = volume[start_frame : start_frame + units_frame_len] 272 | 273 | # load spk_id 274 | spk_id = data_buffer.get('spk_id') 275 | 276 | # load shift 277 | aug_shift = torch.from_numpy(np.array([[aug_shift]])).float() 278 | 279 | return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext) 280 | 281 | def __len__(self): 282 | return len(self.paths) -------------------------------------------------------------------------------- /reflow/extractors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import yaml 4 | import torch 5 | try: 6 | import torch_musa 7 | use_torch_musa = True 8 | except ImportError: 9 | use_torch_musa = False 10 | import torch.nn.functional as F 11 | import pyworld as pw 12 | import parselmouth 13 | import torchcrepe 14 | import resampy 15 | from transformers import HubertModel, 
Wav2Vec2FeatureExtractor 16 | from fairseq import checkpoint_utils 17 | from encoder.hubert.model import HubertSoft 18 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 19 | from torchaudio.transforms import Resample 20 | import time 21 | 22 | CREPE_RESAMPLE_KERNEL = {} 23 | F0_KERNEL = {} 24 | 25 | 26 | def MaskedAvgPool1d(x, kernel_size): 27 | x = x.unsqueeze(1) 28 | x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect") 29 | mask = ~torch.isnan(x) 30 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 31 | ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device) 32 | 33 | # Perform sum pooling 34 | sum_pooled = F.conv1d( 35 | masked_x, 36 | ones_kernel, 37 | stride=1, 38 | padding=0, 39 | groups=x.size(1), 40 | ) 41 | 42 | # Count the non-masked (valid) elements in each pooling window 43 | valid_count = F.conv1d( 44 | mask.float(), 45 | ones_kernel, 46 | stride=1, 47 | padding=0, 48 | groups=x.size(1), 49 | ) 50 | valid_count = valid_count.clamp(min=1) # Avoid division by zero 51 | 52 | # Perform masked average pooling 53 | avg_pooled = sum_pooled / valid_count 54 | 55 | return avg_pooled.squeeze(1) 56 | 57 | 58 | def MedianPool1d(x, kernel_size): 59 | x = x.unsqueeze(1) 60 | x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect") 61 | x = x.squeeze(1) 62 | x = x.unfold(1, kernel_size, 1) 63 | x, _ = torch.sort(x, dim=-1) 64 | return x[:, :, (kernel_size - 1) // 2] 65 | 66 | 67 | class F0_Extractor: 68 | def __init__(self, f0_extractor, sample_rate = 44100, hop_size = 512, f0_min = 65, f0_max = 800): 69 | self.f0_extractor = f0_extractor 70 | self.sample_rate = sample_rate 71 | self.hop_size = hop_size 72 | self.f0_min = f0_min 73 | self.f0_max = f0_max 74 | if f0_extractor == 'crepe': 75 | key_str = str(sample_rate) 76 | if key_str not in CREPE_RESAMPLE_KERNEL: 77 | CREPE_RESAMPLE_KERNEL[key_str] = Resample(sample_rate, 16000, lowpass_filter_width = 128) 78 | self.resample_kernel = CREPE_RESAMPLE_KERNEL[key_str] 79 | if f0_extractor == 'rmvpe': 80 | if 'rmvpe' not in F0_KERNEL : 81 | from encoder.rmvpe import RMVPE 82 | F0_KERNEL['rmvpe'] = RMVPE('pretrain/rmvpe/model.pt', hop_length=160) 83 | self.rmvpe = F0_KERNEL['rmvpe'] 84 | if f0_extractor == 'fcpe': 85 | if torch.cuda.is_available(): 86 | self.device_fcpe = 'cuda' 87 | elif use_torch_musa: 88 | if torch.musa.is_available(): 89 | self.device_fcpe = 'musa' 90 | else: 91 | self.device_fcpe = 'cpu' 92 | else: 93 | self.device_fcpe = 'cpu' 94 | if 'fcpe' not in F0_KERNEL : 95 | from torchfcpe import spawn_bundled_infer_model 96 | F0_KERNEL['fcpe'] = spawn_bundled_infer_model(device=self.device_fcpe) 97 | self.fcpe = F0_KERNEL['fcpe'] 98 | 99 | def extract(self, audio, uv_interp = False, device = None, silence_front = 0): # audio: 1d numpy array 100 | # extractor start time 101 | n_frames = int(len(audio) // self.hop_size) + 1 102 | 103 | start_frame = int(silence_front * self.sample_rate / self.hop_size) 104 | real_silence_front = start_frame * self.hop_size / self.sample_rate 105 | audio = audio[int(np.round(real_silence_front * self.sample_rate)) : ] 106 | 107 | # extract f0 using parselmouth 108 | if self.f0_extractor == 'parselmouth': 109 | l_pad = int(np.ceil(1.5 / self.f0_min * self.sample_rate)) 110 | r_pad = int(self.hop_size * ((len(audio) - 1) // self.hop_size + 1) - len(audio) + l_pad + 1) 111 | s = parselmouth.Sound(np.pad(audio, (l_pad, r_pad)), self.sample_rate).to_pitch_ac( 112 | time_step = self.hop_size / self.sample_rate, 113 | 
voicing_threshold = 0.6, 114 | pitch_floor = self.f0_min, 115 | pitch_ceiling = self.f0_max) 116 | assert np.abs(s.t1 - 1.5 / self.f0_min) < 0.001 117 | f0 = np.pad(s.selected_array['frequency'], (start_frame, 0)) 118 | if len(f0) < n_frames: 119 | f0 = np.pad(f0, (0, n_frames - len(f0))) 120 | f0 = f0[: n_frames] 121 | 122 | # extract f0 using dio 123 | elif self.f0_extractor == 'dio': 124 | _f0, t = pw.dio( 125 | audio.astype('double'), 126 | self.sample_rate, 127 | f0_floor = self.f0_min, 128 | f0_ceil = self.f0_max, 129 | channels_in_octave=2, 130 | frame_period = (1000 * self.hop_size / self.sample_rate)) 131 | f0 = pw.stonemask(audio.astype('double'), _f0, t, self.sample_rate) 132 | f0 = np.pad(f0.astype('float'), (start_frame, n_frames - len(f0) - start_frame)) 133 | 134 | # extract f0 using harvest 135 | elif self.f0_extractor == 'harvest': 136 | f0, _ = pw.harvest( 137 | audio.astype('double'), 138 | self.sample_rate, 139 | f0_floor = self.f0_min, 140 | f0_ceil = self.f0_max, 141 | frame_period = (1000 * self.hop_size / self.sample_rate)) 142 | f0 = np.pad(f0.astype('float'), (start_frame, n_frames - len(f0) - start_frame)) 143 | 144 | # extract f0 using crepe 145 | elif self.f0_extractor == 'crepe': 146 | if device is None: 147 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 148 | resample_kernel = self.resample_kernel.to(device) 149 | wav16k_torch = resample_kernel(torch.FloatTensor(audio).unsqueeze(0).to(device)) 150 | 151 | f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, self.f0_min, self.f0_max, pad=True, model='full', batch_size=512, device=device, return_periodicity=True) 152 | pd = MedianPool1d(pd, 4) 153 | f0 = torchcrepe.threshold.At(0.05)(f0, pd) 154 | f0 = MaskedAvgPool1d(f0, 4) 155 | 156 | f0 = f0.squeeze(0).cpu().numpy() 157 | f0 = np.array([f0[int(min(int(np.round(n * self.hop_size / self.sample_rate / 0.005)), len(f0) - 1))] for n in range(n_frames - start_frame)]) 158 | f0 = np.pad(f0, (start_frame, 0)) 159 | 160 | # extract f0 using rmvpe 161 | elif self.f0_extractor == "rmvpe": 162 | f0 = self.rmvpe.infer_from_audio(audio, self.sample_rate, device=device, thred=0.03, use_viterbi=False) 163 | uv = f0 == 0 164 | if len(f0[~uv]) > 0: 165 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 166 | origin_time = 0.01 * np.arange(len(f0)) 167 | target_time = self.hop_size / self.sample_rate * np.arange(n_frames - start_frame) 168 | f0 = np.interp(target_time, origin_time, f0) 169 | uv = np.interp(target_time, origin_time, uv.astype(float)) > 0.5 170 | f0[uv] = 0 171 | f0 = np.pad(f0, (start_frame, 0)) 172 | 173 | # extract f0 using fcpe 174 | elif self.f0_extractor == "fcpe": 175 | _audio = torch.from_numpy(audio).to(self.device_fcpe).unsqueeze(0) 176 | f0 = self.fcpe(_audio, sr=self.sample_rate, decoder_mode="local_argmax", threshold=0.006) 177 | f0 = f0.squeeze().cpu().numpy() 178 | uv = f0 == 0 179 | if len(f0[~uv]) > 0: 180 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 181 | origin_time = 0.01 * np.arange(len(f0)) 182 | target_time = self.hop_size / self.sample_rate * np.arange(n_frames - start_frame) 183 | f0 = np.interp(target_time, origin_time, f0) 184 | uv = np.interp(target_time, origin_time, uv.astype(float)) > 0.5 185 | f0[uv] = 0 186 | f0 = np.pad(f0, (start_frame, 0)) 187 | 188 | else: 189 | raise ValueError(f" [x] Unknown f0 extractor: {self.f0_extractor}") 190 | 191 | # interpolate the unvoiced f0 192 | if uv_interp: 193 | uv = f0 == 0 194 | if len(f0[~uv]) > 0: 195 | f0[uv] = np.interp(np.where(uv)[0], 
np.where(~uv)[0], f0[~uv]) 196 | f0[f0 < self.f0_min] = self.f0_min 197 | return f0 198 | 199 | 200 | class Volume_Extractor: 201 | def __init__(self, hop_size = 512): 202 | self.hop_size = hop_size 203 | 204 | def extract(self, audio): # audio: 1d numpy array 205 | n_frames = int(len(audio) // self.hop_size) + 1 206 | audio2 = audio ** 2 207 | audio2 = np.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect') 208 | volume = np.array([np.mean(audio2[int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)]) 209 | volume = np.sqrt(volume) 210 | return volume 211 | 212 | 213 | class Units_Encoder: 214 | def __init__(self, encoder, encoder_ckpt, encoder_sample_rate = 16000, encoder_hop_size = 320, device = None, 215 | cnhubertsoft_gate=10): 216 | if device is None: 217 | if torch.cuda.is_available(): 218 | device = 'cuda' 219 | elif use_torch_musa: 220 | if torch.musa.is_available(): 221 | device = 'musa' 222 | else: 223 | device = 'cpu' 224 | else: 225 | device = 'cpu' 226 | self.device = device 227 | 228 | is_loaded_encoder = False 229 | if encoder == 'hubertsoft': 230 | self.model = Audio2HubertSoft(encoder_ckpt).to(device) 231 | is_loaded_encoder = True 232 | if encoder == 'hubertbase': 233 | self.model = Audio2HubertBase(encoder_ckpt, device=device) 234 | is_loaded_encoder = True 235 | if encoder == 'hubertbase768': 236 | self.model = Audio2HubertBase768(encoder_ckpt, device=device) 237 | is_loaded_encoder = True 238 | if encoder == 'hubertbase768l12': 239 | self.model = Audio2HubertBase768L12(encoder_ckpt, device=device) 240 | is_loaded_encoder = True 241 | if encoder == 'hubertlarge1024l24': 242 | self.model = Audio2HubertLarge1024L24(encoder_ckpt, device=device) 243 | is_loaded_encoder = True 244 | if encoder == 'contentvec': 245 | self.model = Audio2ContentVec(encoder_ckpt, device=device) 246 | is_loaded_encoder = True 247 | if encoder == 'contentvec768': 248 | self.model = Audio2ContentVec768(encoder_ckpt, device=device) 249 | is_loaded_encoder = True 250 | if encoder == 'contentvec768l12': 251 | self.model = Audio2ContentVec768L12(encoder_ckpt, device=device) 252 | is_loaded_encoder = True 253 | if encoder == 'cnhubertsoftfish': 254 | self.model = CNHubertSoftFish(encoder_ckpt, device=device, gate_size=cnhubertsoft_gate) 255 | is_loaded_encoder = True 256 | if not is_loaded_encoder: 257 | raise ValueError(f" [x] Unknown units encoder: {encoder}") 258 | 259 | self.resample_kernel = {} 260 | self.encoder_sample_rate = encoder_sample_rate 261 | self.encoder_hop_size = encoder_hop_size 262 | 263 | def encode(self, 264 | audio, # B, T 265 | sample_rate, 266 | hop_size): 267 | 268 | # resample 269 | if sample_rate == self.encoder_sample_rate: 270 | audio_res = audio 271 | else: 272 | key_str = str(sample_rate) 273 | if key_str not in self.resample_kernel: 274 | self.resample_kernel[key_str] = Resample(sample_rate, self.encoder_sample_rate, lowpass_filter_width = 128).to(self.device) 275 | audio_res = self.resample_kernel[key_str](audio) 276 | 277 | # encode 278 | if audio_res.size(-1) < 400: 279 | audio_res = torch.nn.functional.pad(audio, (0, 400 - audio_res.size(-1))) 280 | units = self.model(audio_res) 281 | 282 | # alignment 283 | n_frames = audio.size(-1) // hop_size + 1 284 | ratio = (hop_size / sample_rate) / (self.encoder_hop_size / self.encoder_sample_rate) 285 | index = torch.clamp(torch.round(ratio * torch.arange(n_frames).to(self.device)).long(), max = units.size(1) - 1) 286 | units_aligned = torch.gather(units, 1, 
index.unsqueeze(0).unsqueeze(-1).repeat([1, 1, units.size(-1)])) 287 | return units_aligned 288 | 289 | class Audio2HubertSoft(torch.nn.Module): 290 | def __init__(self, path, h_sample_rate = 16000, h_hop_size = 320): 291 | super().__init__() 292 | print(' [Encoder Model] HuBERT Soft') 293 | self.hubert = HubertSoft() 294 | print(' [Loading] ' + path) 295 | checkpoint = torch.load(path) 296 | consume_prefix_in_state_dict_if_present(checkpoint, "module.") 297 | self.hubert.load_state_dict(checkpoint) 298 | self.hubert.eval() 299 | 300 | def forward(self, 301 | audio): # B, T 302 | with torch.inference_mode(): 303 | units = self.hubert.units(audio.unsqueeze(1)) 304 | return units 305 | 306 | 307 | class Audio2ContentVec(): 308 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 309 | self.device = device 310 | print(' [Encoder Model] Content Vec') 311 | print(' [Loading] ' + path) 312 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 313 | self.hubert = self.models[0] 314 | self.hubert = self.hubert.to(self.device) 315 | self.hubert.eval() 316 | 317 | def __call__(self, 318 | audio): # B, T 319 | # wav_tensor = torch.from_numpy(audio).to(self.device) 320 | wav_tensor = audio 321 | feats = wav_tensor.view(1, -1) 322 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 323 | inputs = { 324 | "source": feats.to(wav_tensor.device), 325 | "padding_mask": padding_mask.to(wav_tensor.device), 326 | "output_layer": 9, # layer 9 327 | } 328 | with torch.no_grad(): 329 | logits = self.hubert.extract_features(**inputs) 330 | feats = self.hubert.final_proj(logits[0]) 331 | units = feats # .transpose(2, 1) 332 | return units 333 | 334 | 335 | class Audio2ContentVec768(): 336 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 337 | self.device = device 338 | print(' [Encoder Model] Content Vec') 339 | print(' [Loading] ' + path) 340 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 341 | self.hubert = self.models[0] 342 | self.hubert = self.hubert.to(self.device) 343 | self.hubert.eval() 344 | 345 | def __call__(self, 346 | audio): # B, T 347 | # wav_tensor = torch.from_numpy(audio).to(self.device) 348 | wav_tensor = audio 349 | feats = wav_tensor.view(1, -1) 350 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 351 | inputs = { 352 | "source": feats.to(wav_tensor.device), 353 | "padding_mask": padding_mask.to(wav_tensor.device), 354 | "output_layer": 9, # layer 9 355 | } 356 | with torch.no_grad(): 357 | logits = self.hubert.extract_features(**inputs) 358 | feats = logits[0] 359 | units = feats # .transpose(2, 1) 360 | return units 361 | 362 | 363 | class Audio2ContentVec768L12(): 364 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 365 | self.device = device 366 | print(' [Encoder Model] Content Vec') 367 | print(' [Loading] ' + path) 368 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 369 | self.hubert = self.models[0] 370 | self.hubert = self.hubert.to(self.device) 371 | self.hubert.eval() 372 | 373 | def __call__(self, 374 | audio): # B, T 375 | # wav_tensor = torch.from_numpy(audio).to(self.device) 376 | wav_tensor = audio 377 | feats = wav_tensor.view(1, -1) 378 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 379 | inputs = { 380 | "source": feats.to(wav_tensor.device), 381 | "padding_mask": 
padding_mask.to(wav_tensor.device), 382 | "output_layer": 12, # layer 12 383 | } 384 | with torch.no_grad(): 385 | logits = self.hubert.extract_features(**inputs) 386 | feats = logits[0] 387 | units = feats # .transpose(2, 1) 388 | return units 389 | 390 | 391 | class CNHubertSoftFish(torch.nn.Module): 392 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu', gate_size=10): 393 | super().__init__() 394 | self.device = device 395 | self.gate_size = gate_size 396 | 397 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( 398 | "./pretrain/TencentGameMate/chinese-hubert-base") 399 | self.model = HubertModel.from_pretrained("./pretrain/TencentGameMate/chinese-hubert-base") 400 | self.proj = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(768, 256)) 401 | # self.label_embedding = nn.Embedding(128, 256) 402 | 403 | state_dict = torch.load(path, map_location=device) 404 | self.load_state_dict(state_dict) 405 | 406 | @torch.no_grad() 407 | def forward(self, audio): 408 | input_values = self.feature_extractor( 409 | audio, sampling_rate=16000, return_tensors="pt" 410 | ).input_values 411 | input_values = input_values.to(self.model.device) 412 | 413 | return self._forward(input_values[0]) 414 | 415 | @torch.no_grad() 416 | def _forward(self, input_values): 417 | features = self.model(input_values) 418 | features = self.proj(features.last_hidden_state) 419 | 420 | # Top-k gating 421 | topk, indices = torch.topk(features, self.gate_size, dim=2) 422 | features = torch.zeros_like(features).scatter(2, indices, topk) 423 | features = features / features.sum(2, keepdim=True) 424 | 425 | return features.to(self.device) # .transpose(1, 2) 426 | 427 | 428 | class Audio2HubertBase(): 429 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 430 | self.device = device 431 | print(' [Encoder Model] HuBERT Base') 432 | print(' [Loading] ' + path) 433 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 434 | self.hubert = self.models[0] 435 | self.hubert = self.hubert.to(self.device) 436 | self.hubert = self.hubert.float() 437 | self.hubert.eval() 438 | 439 | def __call__(self, 440 | audio): # B, T 441 | with torch.no_grad(): 442 | padding_mask = torch.BoolTensor(audio.shape).fill_(False) 443 | inputs = { 444 | "source": audio.to(self.device), 445 | "padding_mask": padding_mask.to(self.device), 446 | "output_layer": 9, # layer 9 447 | } 448 | logits = self.hubert.extract_features(**inputs) 449 | units = self.hubert.final_proj(logits[0]) 450 | return units 451 | 452 | 453 | class Audio2HubertBase768(): 454 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 455 | self.device = device 456 | print(' [Encoder Model] HuBERT Base') 457 | print(' [Loading] ' + path) 458 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 459 | self.hubert = self.models[0] 460 | self.hubert = self.hubert.to(self.device) 461 | self.hubert = self.hubert.float() 462 | self.hubert.eval() 463 | 464 | def __call__(self, 465 | audio): # B, T 466 | with torch.no_grad(): 467 | padding_mask = torch.BoolTensor(audio.shape).fill_(False) 468 | inputs = { 469 | "source": audio.to(self.device), 470 | "padding_mask": padding_mask.to(self.device), 471 | "output_layer": 9, # layer 9 472 | } 473 | logits = self.hubert.extract_features(**inputs) 474 | units = logits[0] 475 | return units 476 | 477 | 478 | class Audio2HubertBase768L12(): 479 | 
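    # Like Audio2HubertBase, this wraps a fairseq HuBERT-Base checkpoint, but it
    # returns the raw 768-dim hidden states from transformer layer 12
    # ("output_layer": 12 below) instead of the final_proj output of layer 9.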
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 480 | self.device = device 481 | print(' [Encoder Model] HuBERT Base') 482 | print(' [Loading] ' + path) 483 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 484 | self.hubert = self.models[0] 485 | self.hubert = self.hubert.to(self.device) 486 | self.hubert = self.hubert.float() 487 | self.hubert.eval() 488 | 489 | def __call__(self, 490 | audio): # B, T 491 | with torch.no_grad(): 492 | padding_mask = torch.BoolTensor(audio.shape).fill_(False) 493 | inputs = { 494 | "source": audio.to(self.device), 495 | "padding_mask": padding_mask.to(self.device), 496 | "output_layer": 12, # layer 12 497 | } 498 | logits = self.hubert.extract_features(**inputs) 499 | units = logits[0] 500 | return units 501 | 502 | 503 | class Audio2HubertLarge1024L24(): 504 | def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'): 505 | self.device = device 506 | print(' [Encoder Model] HuBERT Base') 507 | print(' [Loading] ' + path) 508 | self.models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", ) 509 | self.hubert = self.models[0] 510 | self.hubert = self.hubert.to(self.device) 511 | self.hubert = self.hubert.float() 512 | self.hubert.eval() 513 | 514 | def __call__(self, 515 | audio): # B, T 516 | with torch.no_grad(): 517 | padding_mask = torch.BoolTensor(audio.shape).fill_(False) 518 | inputs = { 519 | "source": audio.to(self.device), 520 | "padding_mask": padding_mask.to(self.device), 521 | "output_layer": 24, # layer 24 522 | } 523 | logits = self.hubert.extract_features(**inputs) 524 | units = logits[0] 525 | return units 526 | -------------------------------------------------------------------------------- /reflow/model_conformer_naive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | try: 3 | import torch_musa 4 | except ImportError: 5 | pass 6 | from torch import nn 7 | 8 | # From https://github.com/CNChTu/Diffusion-SVC/ by CNChTu 9 | # License: MIT 10 | 11 | 12 | class ConformerNaiveEncoder(nn.Module): 13 | """ 14 | Conformer Naive Encoder 15 | 16 | Args: 17 | dim_model (int): Dimension of model 18 | num_layers (int): Number of layers 19 | num_heads (int): Number of heads 20 | use_norm (bool): Whether to use norm for FastAttention, only True can use bf16/fp16, default False 21 | conv_only (bool): Whether to use only conv module without attention, default False 22 | conv_dropout (float): Dropout rate of conv module, default 0. 23 | atten_dropout (float): Dropout rate of attention module, default 0. 24 | """ 25 | 26 | def __init__(self, 27 | num_layers: int, 28 | num_heads: int, 29 | dim_model: int, 30 | use_norm: bool = False, 31 | conv_only: bool = False, 32 | conv_dropout: float = 0., 33 | atten_dropout: float = 0. 
34 | ): 35 | super().__init__() 36 | self.num_layers = num_layers 37 | self.num_heads = num_heads 38 | self.dim_model = dim_model 39 | self.use_norm = use_norm 40 | self.residual_dropout = 0.1 # 废弃代码,仅做兼容性保留 41 | self.attention_dropout = 0.1 # 废弃代码,仅做兼容性保留 42 | 43 | self.encoder_layers = nn.ModuleList( 44 | [ 45 | CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) 46 | for _ in range(num_layers) 47 | ] 48 | ) 49 | 50 | def forward(self, x, mask=None) -> torch.Tensor: 51 | """ 52 | Args: 53 | x (torch.Tensor): Input tensor (#batch, length, dim_model) 54 | mask (torch.Tensor): Mask tensor, default None 55 | return: 56 | torch.Tensor: Output tensor (#batch, length, dim_model) 57 | """ 58 | 59 | for (i, layer) in enumerate(self.encoder_layers): 60 | x = layer(x, mask) 61 | return x # (#batch, length, dim_model) 62 | 63 | 64 | class CFNEncoderLayer(nn.Module): 65 | """ 66 | Conformer Naive Encoder Layer 67 | 68 | Args: 69 | dim_model (int): Dimension of model 70 | num_heads (int): Number of heads 71 | use_norm (bool): Whether to use norm for FastAttention, only True can use bf16/fp16, default False 72 | conv_only (bool): Whether to use only conv module without attention, default False 73 | conv_dropout (float): Dropout rate of conv module, default 0.1 74 | atten_dropout (float): Dropout rate of attention module, default 0.1 75 | """ 76 | 77 | def __init__(self, 78 | dim_model: int, 79 | num_heads: int = 8, 80 | use_norm: bool = False, 81 | conv_only: bool = False, 82 | conv_dropout: float = 0., 83 | atten_dropout: float = 0.1 84 | ): 85 | super().__init__() 86 | 87 | self.conformer = ConformerConvModule(dim_model, use_norm=use_norm, dropout=conv_dropout) 88 | 89 | self.norm = nn.LayerNorm(dim_model) 90 | 91 | self.dropout = nn.Dropout(0.1) # 废弃代码,仅做兼容性保留 92 | 93 | # selfatt -> fastatt: performer! 
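        # When conv_only is False, the attention sub-block is a standard
        # nn.TransformerEncoderLayer (GELU activation, feed-forward width
        # 4 * dim_model); when conv_only is True, self.attn stays None and
        # forward() applies only the Conformer convolution module.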
94 | if not conv_only: 95 | self.attn = nn.TransformerEncoderLayer( 96 | d_model=dim_model, 97 | nhead=num_heads, 98 | dim_feedforward=dim_model * 4, 99 | dropout=atten_dropout, 100 | activation='gelu' 101 | ) 102 | else: 103 | self.attn = None 104 | 105 | def forward(self, x, mask=None) -> torch.Tensor: 106 | """ 107 | Args: 108 | x (torch.Tensor): Input tensor (#batch, length, dim_model) 109 | mask (torch.Tensor): Mask tensor, default None 110 | return: 111 | torch.Tensor: Output tensor (#batch, length, dim_model) 112 | """ 113 | if self.attn is not None: 114 | x = x + (self.attn(self.norm(x), mask=mask)) 115 | 116 | x = x + (self.conformer(x)) 117 | 118 | return x # (#batch, length, dim_model) 119 | 120 | 121 | class ConformerConvModule(nn.Module): 122 | def __init__( 123 | self, 124 | dim, 125 | expansion_factor=2, 126 | kernel_size=31, 127 | dropout=0., 128 | use_norm=False, 129 | conv_model_type='mode1' 130 | ): 131 | super().__init__() 132 | 133 | inner_dim = dim * expansion_factor 134 | padding = calc_same_padding(kernel_size) 135 | 136 | if conv_model_type == 'mode1': 137 | self.net = nn.Sequential( 138 | nn.LayerNorm(dim) if use_norm else nn.Identity(), 139 | Transpose((1, 2)), 140 | nn.Conv1d(dim, inner_dim * 2, 1), 141 | nn.GLU(dim=1), 142 | nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding[0], groups=inner_dim), 143 | nn.SiLU(), 144 | nn.Conv1d(inner_dim, dim, 1), 145 | Transpose((1, 2)), 146 | nn.Dropout(dropout) 147 | ) 148 | elif conv_model_type == 'mode2': 149 | raise NotImplementedError('mode2 not implemented yet') 150 | else: 151 | raise ValueError(f'{conv_model_type} is not a valid conv_model_type') 152 | 153 | def forward(self, x): 154 | return self.net(x) 155 | 156 | 157 | def calc_same_padding(kernel_size): 158 | pad = kernel_size // 2 159 | return (pad, pad - (kernel_size + 1) % 2) 160 | 161 | 162 | class Transpose(nn.Module): 163 | def __init__(self, dims): 164 | super().__init__() 165 | assert len(dims) == 2, 'dims must be a tuple of two dimensions' 166 | self.dims = dims 167 | 168 | def forward(self, x): 169 | return x.transpose(*self.dims) 170 | -------------------------------------------------------------------------------- /reflow/naive_v2_diff.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | try: 6 | import torch_musa 7 | except ImportError: 8 | pass 9 | import torch.nn.functional as F 10 | from torch import nn 11 | from .model_conformer_naive import ConformerConvModule 12 | import random 13 | 14 | 15 | # from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py 16 | # 参考了这个 17 | 18 | 19 | class DiffusionEmbedding(nn.Module): 20 | """Diffusion Step Embedding""" 21 | 22 | def __init__(self, d_denoiser): 23 | super(DiffusionEmbedding, self).__init__() 24 | self.dim = d_denoiser 25 | 26 | def forward(self, x): 27 | device = x.device 28 | half_dim = self.dim // 2 29 | emb = math.log(10000) / (half_dim - 1) 30 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 31 | emb = x[:, None] * emb[None, :] 32 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 33 | return emb 34 | 35 | 36 | class NaiveV2DiffLayer(nn.Module): 37 | 38 | def __init__(self, 39 | dim_model: int, 40 | dim_cond: int, 41 | num_heads: int = 4, 42 | use_norm: bool = False, 43 | conv_only: bool = True, 44 | conv_dropout: float = 0., 45 | atten_dropout: float = 0.1, 46 | use_mlp=True, 47 | expansion_factor=2, 48 | 
kernel_size=31, 49 | wavenet_like=False, 50 | conv_model_type='mode1', 51 | ): 52 | super().__init__() 53 | 54 | self.conformer = ConformerConvModule( 55 | dim_model, 56 | expansion_factor=expansion_factor, 57 | kernel_size=kernel_size, 58 | dropout=conv_dropout, 59 | use_norm=use_norm, 60 | conv_model_type=conv_model_type, 61 | ) 62 | #self.norm = nn.LayerNorm(dim_model) 63 | 64 | self.dropout = nn.Dropout(0.1) # 废弃代码,仅做兼容性保留 65 | if wavenet_like: 66 | self.wavenet_like_proj = nn.Conv1d(dim_model, 2 * dim_model, 1) 67 | else: 68 | self.wavenet_like_proj = None 69 | 70 | self.diffusion_step_projection = nn.Conv1d(dim_model, dim_model, 1) 71 | if dim_cond > 0: 72 | self.condition_projection = nn.Conv1d(dim_cond, dim_model, 1) 73 | else: 74 | self.condition_projection = nn.Identity() 75 | 76 | # selfatt -> fastatt: performer! 77 | if not conv_only: 78 | self.attn = nn.TransformerEncoderLayer( 79 | d_model=dim_model, 80 | nhead=num_heads, 81 | dim_feedforward=dim_model * 4, 82 | dropout=atten_dropout, 83 | activation='gelu' 84 | ) 85 | self.norm = nn.LayerNorm(dim_model) 86 | else: 87 | self.attn = None 88 | 89 | def forward(self, x, condition=None, diffusion_step=None) -> torch.Tensor: 90 | res_x = x.transpose(1, 2) 91 | x = x + self.diffusion_step_projection(diffusion_step) + self.condition_projection(condition) 92 | x = x.transpose(1, 2) 93 | 94 | if self.attn is not None: 95 | x = (self.attn(self.norm(x))) 96 | 97 | x = self.conformer(x) # (#batch, dim_model, length) 98 | 99 | if self.wavenet_like_proj is not None: 100 | x = self.wavenet_like_proj(x.transpose(1, 2)).transpose(1, 2) 101 | x = F.glu(x, dim=-1) 102 | return ((x + res_x)/math.sqrt(2.0)).transpose(1, 2), res_x.transpose(1, 2) 103 | else: 104 | x = x + res_x 105 | x = x.transpose(1, 2) 106 | return x # (#batch, length, dim_model) 107 | 108 | 109 | class NaiveV2Diff(nn.Module): 110 | def __init__( 111 | self, 112 | mel_channels=128, 113 | dim=512, 114 | use_mlp=True, 115 | mlp_factor=4, 116 | condition_dim=256, 117 | num_layers=20, 118 | expansion_factor=2, 119 | kernel_size=31, 120 | conv_only=True, 121 | wavenet_like=False, 122 | use_norm=False, 123 | conv_model_type='mode1', 124 | conv_dropout=0.0, 125 | atten_dropout=0.1, 126 | ): 127 | super(NaiveV2Diff, self).__init__() 128 | self.wavenet_like = wavenet_like 129 | self.mask_cond_ratio = None 130 | 131 | self.input_projection = nn.Conv1d(mel_channels, dim, 1) 132 | self.diffusion_embedding = nn.Sequential( 133 | DiffusionEmbedding(dim), 134 | nn.Linear(dim, dim * mlp_factor), 135 | nn.GELU(), 136 | nn.Linear(dim * mlp_factor, dim), 137 | ) 138 | 139 | if use_mlp and condition_dim > 0: 140 | self.conditioner_projection = nn.Sequential( 141 | nn.Conv1d(condition_dim, dim * mlp_factor, 1), 142 | nn.GELU(), 143 | nn.Conv1d(dim * mlp_factor, dim, 1), 144 | ) 145 | else: 146 | self.conditioner_projection = nn.Identity() 147 | 148 | self.residual_layers = nn.ModuleList( 149 | [ 150 | NaiveV2DiffLayer( 151 | dim_model=dim, 152 | dim_cond=dim if use_mlp else condition_dim, 153 | num_heads=8, 154 | use_norm=use_norm, 155 | conv_only=conv_only, 156 | conv_dropout=conv_dropout, 157 | atten_dropout=atten_dropout, 158 | use_mlp=use_mlp, 159 | expansion_factor=expansion_factor, 160 | kernel_size=kernel_size, 161 | wavenet_like=wavenet_like, 162 | conv_model_type=conv_model_type, 163 | ) 164 | for i in range(num_layers) 165 | ] 166 | ) 167 | 168 | if use_mlp: 169 | _ = nn.Conv1d(dim * mlp_factor, mel_channels, kernel_size=1) 170 | nn.init.zeros_(_.weight) 171 | self.output_projection = 
nn.Sequential( 172 | nn.Conv1d(dim, dim * mlp_factor, kernel_size=1), 173 | nn.GELU(), 174 | _, 175 | ) 176 | else: 177 | self.output_projection = nn.Conv1d(dim, mel_channels, kernel_size=1) 178 | nn.init.zeros_(self.output_projection.weight) 179 | 180 | def forward(self, spec, diffusion_step, cond=None): 181 | x = spec 182 | conditioner = cond 183 | """ 184 | 185 | :param x: [B, M, T] 186 | :param diffusion_step: [B,] 187 | :param conditioner: [B, M, T] 188 | :return: 189 | """ 190 | 191 | # To keep compatibility with DiffSVC, [B, 1, M, T] 192 | use_4_dim = False 193 | if x.dim() == 4: 194 | x = x[:, 0] 195 | use_4_dim = True 196 | 197 | assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}" 198 | 199 | x = self.input_projection(x) # x [B, residual_channel, T] 200 | x = F.gelu(x) 201 | 202 | diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(-1) 203 | condition = self.conditioner_projection(conditioner) 204 | 205 | if self.wavenet_like: 206 | _sk = [] 207 | for layer in self.residual_layers: 208 | # conditional mask 209 | if self.mask_cond_ratio is not None: 210 | _mask_cond_ratio = random.choice([True, True, False]) 211 | if _mask_cond_ratio: 212 | # 随机从0到mask_cond_ratio中选择一个数 213 | _mask_cond_ratio = random.uniform(0, self.mask_cond_ratio) 214 | _conditioner = F.dropout(conditioner, _mask_cond_ratio) 215 | else: 216 | _conditioner = conditioner 217 | # forward 218 | x, sk = layer(x, _conditioner, diffusion_step) 219 | _sk.append(sk) 220 | x = torch.sum(torch.stack(_sk), dim=0) / math.sqrt(len(self.residual_layers)) 221 | 222 | else: 223 | for layer in self.residual_layers: 224 | # conditional mask 225 | if self.mask_cond_ratio is not None: 226 | _mask_cond_ratio = random.choice([True, True, False]) 227 | if _mask_cond_ratio: 228 | # 随机从0到mask_cond_ratio中选择一个数 229 | _mask_cond_ratio = random.uniform(0, self.mask_cond_ratio) 230 | _conditioner = F.dropout(conditioner, _mask_cond_ratio) 231 | else: 232 | _conditioner = conditioner 233 | # forward 234 | x = layer(x, condition, diffusion_step) 235 | 236 | # MLP and GLU 237 | x = self.output_projection(x) # [B, 128, T] 238 | 239 | return x[:, None] if use_4_dim else x 240 | -------------------------------------------------------------------------------- /reflow/reflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | try: 4 | import torch_musa 5 | except ImportError: 6 | pass 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from tqdm import tqdm 10 | 11 | 12 | class Bi_RectifiedFlow(nn.Module): 13 | def __init__(self, 14 | velocity_fn, 15 | spec_min=-12, 16 | spec_max=2): 17 | super().__init__() 18 | self.velocity_fn = velocity_fn 19 | self.spec_min = spec_min 20 | self.spec_max = spec_max 21 | 22 | def reflow_loss(self, x_1, x_0, t, cond=None, loss_type='l2_lognorm'): 23 | x_t = x_0 + t[:, None, None, None] * (x_1 - x_0) 24 | v_pred = self.velocity_fn(x_t, 1000 * t, cond=cond) 25 | 26 | if loss_type == 'l1': 27 | loss = (x_1 - x_0 - v_pred).abs().mean() 28 | elif loss_type == 'l2': 29 | loss = F.mse_loss(x_1 - x_0, v_pred) 30 | elif loss_type == 'l2_lognorm': 31 | weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / ( 1 - t)) ** 2) 32 | loss = torch.mean(weights[:, None, None, None] * F.mse_loss(x_1 - x_0, v_pred, reduction='none')) 33 | else: 34 | raise NotImplementedError() 35 | 36 | return loss 37 | 38 | def sample_euler(self, x, t, dt, cond=None): 39 | x += self.velocity_fn(x, 1000 * t, cond=cond) * dt 
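        # One explicit Euler step of the rectified-flow ODE: x <- x + v(x, t) * dt.
        # The time passed to velocity_fn is scaled by 1000, matching the step
        # convention used in reflow_loss above.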
40 | t += dt 41 | return x, t 42 | 43 | def sample_rk4(self, x, t, dt, cond=None): 44 | k_1 = self.velocity_fn(x, 1000 * t, cond=cond) 45 | k_2 = self.velocity_fn(x + 0.5 * k_1 * dt, 1000 * (t + 0.5 * dt), cond=cond) 46 | k_3 = self.velocity_fn(x + 0.5 * k_2 * dt, 1000 * (t + 0.5 * dt), cond=cond) 47 | k_4 = self.velocity_fn(x + k_3 * dt, 1000 * (t + dt), cond=cond) 48 | x += (k_1 + 2 * k_2 + 2 * k_3 + k_4) * dt / 6 49 | t += dt 50 | return x, t 51 | 52 | def sample_heun(self, x, t, dt, cond=None): 53 | # Predict 54 | k_1 = self.velocity_fn(x, 1000 * t, cond=cond) 55 | x_pred = x + k_1 * dt 56 | t_pred = t + dt 57 | # Correct 58 | k_2 = self.velocity_fn(x_pred, 1000 * t_pred, cond=cond) 59 | x += (k_1 + k_2) / 2 * dt 60 | t += dt 61 | return x, t 62 | 63 | def sample_PECECE(self, x, t, dt, cond=None): 64 | # Predict1 65 | k_1 = self.velocity_fn(x, 1000 * t, cond=cond) 66 | x_pred1 = x + k_1 * dt 67 | t_pred1 = t + dt 68 | # Correct1 69 | k_2 = self.velocity_fn(x_pred1, 1000 * t_pred1, cond=cond) 70 | x_corr1 = x + (k_1 + k_2) / 2 * dt 71 | # Predict2 72 | k_3 = self.velocity_fn(x_corr1, 1000 * (t + dt), cond=cond) 73 | x_pred2 = x_corr1 + k_3 * dt 74 | # Correct2 75 | k_4 = self.velocity_fn(x_pred2, 1000 * (t + 2*dt), cond=cond) 76 | x += (k_3 + k_4) / 2 * dt 77 | t += dt 78 | return x, t 79 | 80 | def forward(self, 81 | infer=True, 82 | x_start=None, 83 | x_end=None, 84 | cond=None, 85 | t_start=0.0, 86 | t_end=1.0, 87 | infer_step=10, 88 | method='euler', 89 | use_tqdm=True): 90 | if cond is not None: 91 | cond = cond.transpose(1, 2) # [B, H, T] 92 | if not infer: 93 | x_0 = x_start.transpose(1, 2).unsqueeze(1) # [B, 1, M, T] 94 | x_1 = self.norm_spec(x_end).transpose(1, 2).unsqueeze(1) # [B, 1, M, T] 95 | t = torch.rand(x_0.shape[0], device=x_0.device) 96 | t = torch.clip(t, 1e-7, 1-1e-7) 97 | return self.reflow_loss(x_1, x_0, t, cond=cond) 98 | else: 99 | # initial condition and step size of the ODE 100 | if t_start < 0.0: 101 | t_start = 0.0 102 | elif t_start > 1.0: 103 | t_start = 1.0 104 | if t_end < 0.0: 105 | t_end = 0.0 106 | elif t_end > 1.0: 107 | t_end = 1.0 108 | assert t_start < t_end 109 | 110 | if x_start is not None and x_end is None: 111 | x = x_start.transpose(1, 2).unsqueeze(1) # [B, 1, M, T] 112 | t = torch.full((x_start.shape[0],), t_start, device=x_start.device) 113 | dt = (t_end - t_start) / infer_step 114 | elif x_start is None and x_end is not None: 115 | x = self.norm_spec(x_end).transpose(1, 2).unsqueeze(1) # [B, 1, M, T] 116 | t = torch.full((x_end.shape[0],), t_end, device=x_end.device) 117 | dt = -(t_end - t_start) / infer_step 118 | 119 | # sampling 120 | if method == 'euler': 121 | if use_tqdm: 122 | for i in tqdm(range(infer_step), desc='sample time step', total=infer_step): 123 | x, t = self.sample_euler(x, t, dt, cond=cond) 124 | else: 125 | for i in range(infer_step): 126 | x, t = self.sample_euler(x, t, dt, cond=cond) 127 | 128 | elif method == 'rk4': 129 | if use_tqdm: 130 | for i in tqdm(range(infer_step), desc='sample time step', total=infer_step): 131 | x, t = self.sample_rk4(x, t, dt, cond=cond) 132 | else: 133 | for i in range(infer_step): 134 | x, t = self.sample_rk4(x, t, dt, cond=cond) 135 | 136 | elif method == 'heun': 137 | if use_tqdm: 138 | for i in tqdm(range(infer_step), desc='sample time step', total=infer_step): 139 | x, t = self.sample_heun(x, t, dt, cond=cond) 140 | else: 141 | for i in range(infer_step): 142 | x, t = self.sample_heun(x, t, dt, cond=cond) 143 | 144 | elif method == 'PECECE': 145 | if use_tqdm: 146 | for i in 
tqdm(range(infer_step), desc='sample time step', total=infer_step): 147 | x, t = self.sample_PECECE(x, t, dt, cond=cond) 148 | else: 149 | for i in range(infer_step): 150 | x, t = self.sample_PECECE(x, t, dt, cond=cond) 151 | 152 | else: 153 | raise NotImplementedError(method) 154 | 155 | x = x.squeeze(1).transpose(1, 2) # [B, T, M] 156 | 157 | if dt > 0: 158 | return self.denorm_spec(x) 159 | else: 160 | return x 161 | 162 | def norm_spec(self, x): 163 | return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 164 | 165 | def denorm_spec(self, x): 166 | return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min -------------------------------------------------------------------------------- /reflow/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | try: 6 | import torch_musa 7 | use_torch_musa = True 8 | except ImportError: 9 | use_torch_musa = False 10 | import librosa 11 | from logger.saver import Saver 12 | from logger import utils 13 | from torch import autocast 14 | from nsf_hifigan.nvSTFT import STFT 15 | # from torch.cuda.amp import GradScaler 16 | 17 | def calculate_mel_snr(gt_mel, pred_mel): 18 | # 计算误差图像 19 | error_image = gt_mel - pred_mel 20 | # 计算参考图像的平方均值 21 | mean_square_reference = torch.mean(gt_mel ** 2) 22 | # 计算误差图像的方差 23 | variance_error = torch.var(error_image) 24 | # 计算并返回SNR 25 | snr = 10 * torch.log10(mean_square_reference / variance_error) 26 | return snr 27 | 28 | 29 | def calculate_mel_si_snr(gt_mel, pred_mel): 30 | # 将测试图像按比例调整以最小化误差 31 | scale = torch.sum(gt_mel * pred_mel) / torch.sum(gt_mel ** 2) 32 | test_image_scaled = scale * pred_mel 33 | # 计算误差图像 34 | error_image = gt_mel - test_image_scaled 35 | # 计算参考图像的平方均值 36 | mean_square_reference = torch.mean(gt_mel ** 2) 37 | # 计算误差图像的方差 38 | variance_error = torch.var(error_image) 39 | # 计算并返回SI-SNR 40 | si_snr = 10 * torch.log10(mean_square_reference / variance_error) 41 | return si_snr 42 | 43 | 44 | def calculate_mel_psnr(gt_mel, pred_mel): 45 | # 计算误差图像 46 | error_image = gt_mel - pred_mel 47 | # 计算误差图像的均方误差 48 | mse = torch.mean(error_image ** 2) 49 | # 计算参考图像的最大可能功率 50 | max_power = torch.max(gt_mel) ** 2 51 | # 计算并返回PSNR 52 | psnr = 10 * torch.log10(max_power / mse) 53 | return psnr 54 | 55 | def clip_grad_value_(parameters, clip_value): 56 | if isinstance(parameters, torch.Tensor): 57 | parameters = [parameters] 58 | 59 | parameters_with_grad = [p for p in parameters if p.grad is not None] 60 | 61 | torch.nn.utils.clip_grad_value_(parameters_with_grad, clip_value) 62 | 63 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters_with_grad]), 2) 64 | return total_norm 65 | 66 | def test(args, model, vocoder, loader_test, saver): 67 | print(' [*] testing...') 68 | model.eval() 69 | 70 | # losses 71 | test_loss = 0. 
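    # Running sums for the validation metrics computed below: mel-domain MSE,
    # SNR, PSNR and SI-SNR between predicted and ground-truth mels, plus the
    # per-sample real-time factor (inference time / audio duration).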
72 | 73 | # mel mse val 74 | mel_val_mse_all = 0 75 | mel_val_mse_all_num = 0 76 | mel_val_snr_all = 0 77 | mel_val_psnr_all = 0 78 | mel_val_sisnr_all = 0 79 | 80 | # intialization 81 | num_batches = len(loader_test) 82 | rtf_all = [] 83 | spec_min = -2 84 | spec_max = 10 85 | spec_range = 12 86 | 87 | # run 88 | with torch.no_grad(): 89 | for bidx, data in enumerate(loader_test): 90 | fn = data['name'][0] 91 | print('--------') 92 | print('{}/{} - {}'.format(bidx, num_batches, fn)) 93 | 94 | # unpack data 95 | for k in data.keys(): 96 | if not k.startswith('name'): 97 | data[k] = data[k].to(args.device) 98 | print('>>', data['name'][0]) 99 | 100 | # forward 101 | st_time = time.time() 102 | mel = model( 103 | data['units'], 104 | data['f0'], 105 | data['volume'], 106 | data['spk_id'], 107 | vocoder=vocoder, 108 | infer=True, 109 | return_wav=False, 110 | infer_step=args.infer.infer_step, 111 | method=args.infer.method) 112 | signal = vocoder.infer(mel, data['f0']) 113 | ed_time = time.time() 114 | 115 | # RTF 116 | run_time = ed_time - st_time 117 | song_time = signal.shape[-1] / args.data.sampling_rate 118 | rtf = run_time / song_time 119 | print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) 120 | rtf_all.append(rtf) 121 | 122 | # loss 123 | loss = model( 124 | data['units'], 125 | data['f0'], 126 | data['volume'], 127 | data['spk_id'], 128 | vocoder=vocoder, 129 | gt_spec=data['mel'], 130 | infer=False) 131 | test_loss += loss.item() 132 | 133 | # log mel 134 | saver.log_spec(data['name'][0], data['mel'], mel) 135 | 136 | # log audio 137 | path_audio = os.path.join(args.data.valid_path, 'audio', data['name_ext'][0]) 138 | audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate) 139 | if len(audio.shape) > 1: 140 | audio = librosa.to_mono(audio) 141 | audio = torch.from_numpy(audio).unsqueeze(0).to(signal) 142 | saver.log_audio({fn+'/gt.wav': audio, fn+'/pred.wav': signal}) 143 | 144 | WAV2MEL = STFT( 145 | sr=args.data.sampling_rate, 146 | n_mels=128, 147 | n_fft=2048, 148 | win_size=2048, 149 | hop_length=512, 150 | fmin=40, 151 | fmax=22050, 152 | clip_val=1e-5) 153 | audio = audio.unsqueeze(0) 154 | pre_mel = WAV2MEL.get_mel(signal[0, ...]) 155 | pre_mel = pre_mel.transpose(-1, -2) 156 | gt_mel = WAV2MEL.get_mel(audio[0, ...]) 157 | gt_mel = gt_mel.transpose(-1, -2) 158 | # 如果形状不同,裁剪使得形状相同 159 | if pre_mel.shape[1] != gt_mel.shape[1]: 160 | gt_mel = gt_mel[:, :pre_mel.shape[1], :] 161 | saver.log_spec(data['name'][0], gt_mel, pre_mel) 162 | 163 | # 计算指标 164 | mel_val_mse_all += torch.nn.functional.mse_loss(mel, data['mel']).detach().cpu().numpy() 165 | gt_mel_norm = torch.clip(data['mel'], spec_min, spec_max) 166 | gt_mel_norm = gt_mel_norm / spec_range + spec_min 167 | pre_mel_norm = torch.clip(mel, spec_min, spec_max) 168 | pre_mel_norm = pre_mel_norm / spec_range + spec_min 169 | mel_val_snr_all += calculate_mel_snr(gt_mel_norm, pre_mel_norm).detach().cpu().numpy() 170 | mel_val_psnr_all += calculate_mel_psnr(gt_mel_norm, pre_mel_norm).detach().cpu().numpy() 171 | mel_val_sisnr_all += calculate_mel_si_snr(gt_mel_norm, pre_mel_norm).detach().cpu().numpy() 172 | mel_val_mse_all_num += 1 173 | 174 | # report 175 | test_loss /= num_batches 176 | mel_val_mse_all /= mel_val_mse_all_num 177 | mel_val_snr_all /= mel_val_mse_all_num 178 | mel_val_psnr_all /= mel_val_mse_all_num 179 | mel_val_sisnr_all /= mel_val_mse_all_num 180 | 181 | # check 182 | print(' [test_loss] test_loss:', test_loss) 183 | print(' Real Time Factor', np.mean(rtf_all)) 184 | saver.log_value({ 185 | 
'validation/mel_val_mse': mel_val_mse_all 186 | }) 187 | print(' Mel Val SNR', mel_val_snr_all) 188 | saver.log_value({ 189 | 'validation/mel_val_snr': mel_val_snr_all 190 | }) 191 | print(' Mel Val PSNR', mel_val_psnr_all) 192 | saver.log_value({ 193 | 'validation/mel_val_psnr': mel_val_psnr_all 194 | }) 195 | print(' Mel Val SI-SNR', mel_val_sisnr_all) 196 | saver.log_value({ 197 | 'validation/mel_val_sisnr': mel_val_sisnr_all 198 | }) 199 | return test_loss 200 | 201 | 202 | def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test): 203 | # saver 204 | saver = Saver(args, initial_global_step=initial_global_step) 205 | 206 | # model size 207 | params_count = utils.get_network_paras_amount({'model': model}) 208 | saver.log_info('--- model size ---') 209 | saver.log_info(params_count) 210 | 211 | # run 212 | num_batches = len(loader_train) 213 | start_epoch = initial_global_step // num_batches 214 | model.train() 215 | saver.log_info('======= start training =======') 216 | if use_torch_musa: 217 | scaler = torch.musa.amp.GradScaler() 218 | else: 219 | scaler = torch.cuda.amp.GradScaler() 220 | 221 | if args.train.amp_dtype == 'fp32': 222 | dtype = torch.float32 223 | elif args.train.amp_dtype == 'fp16': 224 | dtype = torch.float16 225 | elif args.train.amp_dtype == 'bf16': 226 | dtype = torch.bfloat16 227 | else: 228 | raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) 229 | for epoch in range(start_epoch, args.train.epochs): 230 | for batch_idx, data in enumerate(loader_train): 231 | saver.global_step_increment() 232 | optimizer.zero_grad() 233 | 234 | # unpack data 235 | for k in data.keys(): 236 | if not k.startswith('name'): 237 | data[k] = data[k].to(args.device) 238 | 239 | # forward 240 | if dtype == torch.float32: 241 | loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'], 242 | aug_shift=data['aug_shift'], vocoder=vocoder, gt_spec=data['mel'].float(), infer=False) 243 | else: 244 | if use_torch_musa: 245 | with torch.musa.amp.autocast(dtype=dtype): 246 | loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], 247 | aug_shift=data['aug_shift'], vocoder=vocoder, gt_spec=data['mel'].float(), infer=False) 248 | else: 249 | with autocast(device_type=args.device, dtype=dtype): 250 | loss = model(data['units'], data['f0'], data['volume'], data['spk_id'], 251 | aug_shift=data['aug_shift'], vocoder=vocoder, gt_spec=data['mel'].float(), infer=False) 252 | 253 | # handle nan loss 254 | if torch.isnan(loss): 255 | raise ValueError(' [x] nan loss ') 256 | else: 257 | # backpropagate 258 | if dtype == torch.float32: 259 | loss.backward() 260 | grad_norm = clip_grad_value_(model.parameters(), 1) 261 | optimizer.step() 262 | else: 263 | scaler.scale(loss).backward() 264 | scaler.unscale_(optimizer) 265 | grad_norm = clip_grad_value_(model.parameters(), 1) 266 | scaler.step(optimizer) 267 | scaler.update() 268 | scheduler.step() 269 | 270 | # log loss 271 | if saver.global_step % args.train.interval_log == 0: 272 | current_lr = optimizer.param_groups[0]['lr'] 273 | saver.log_info( 274 | 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {} | grad: {:.2f}'.format( 275 | epoch, 276 | batch_idx, 277 | num_batches, 278 | args.env.expdir, 279 | args.train.interval_log/saver.get_interval_time(), 280 | current_lr, 281 | loss.item(), 282 | saver.get_total_time(), 283 | saver.global_step, 284 | grad_norm 285 | ) 286 | ) 287 | 288 | saver.log_value({ 289 | 
'train/loss': loss.item(), 290 | 'train/lr': current_lr, 291 | 'train/grad_norm': grad_norm 292 | }) 293 | 294 | # validation 295 | if saver.global_step % args.train.interval_val == 0: 296 | optimizer_save = optimizer if args.train.save_opt else None 297 | 298 | # save latest 299 | saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}') 300 | last_val_step = saver.global_step - args.train.interval_val 301 | if last_val_step % args.train.interval_force_save != 0: 302 | saver.delete_model(postfix=f'{last_val_step}') 303 | 304 | # run testing set 305 | test_loss = test(args, model, vocoder, loader_test, saver) 306 | 307 | # log loss 308 | saver.log_info( 309 | ' --- --- \nloss: {:.3f}. '.format( 310 | test_loss, 311 | ) 312 | ) 313 | 314 | saver.log_value({ 315 | 'validation/loss': test_loss, 316 | }) 317 | 318 | model.train() 319 | 320 | 321 | -------------------------------------------------------------------------------- /reflow/vocoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | try: 5 | import torch_musa 6 | use_torch_musa = True 7 | except ImportError: 8 | use_torch_musa = False 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from nsf_hifigan.nvSTFT import STFT 13 | from nsf_hifigan.models import load_model,load_config 14 | from torchaudio.transforms import Resample 15 | from .reflow import Bi_RectifiedFlow 16 | from .naive_v2_diff import NaiveV2Diff 17 | from .wavenet import WaveNet 18 | 19 | class DotDict(dict): 20 | def __getattr__(*args): 21 | val = dict.get(*args) 22 | return DotDict(val) if type(val) is dict else val 23 | 24 | __setattr__ = dict.__setitem__ 25 | __delattr__ = dict.__delitem__ 26 | 27 | 28 | def load_model_vocoder( 29 | model_path, 30 | device='cpu'): 31 | config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml') 32 | with open(config_file, "r") as config: 33 | args = yaml.safe_load(config) 34 | args = DotDict(args) 35 | 36 | # load vocoder 37 | vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device) 38 | 39 | # load model 40 | if args.model.type == 'RectifiedFlow_VAE': 41 | model = Unit2Wav_VAE( 42 | args.data.sampling_rate, 43 | args.data.block_size, 44 | args.model.win_length, 45 | args.data.encoder_out_channels, 46 | args.model.n_spk, 47 | args.model.use_pitch_aug, 48 | vocoder.dimension, 49 | args.model.n_layers, 50 | args.model.n_chans, 51 | args.model.n_hidden, 52 | args.model.back_bone, 53 | args.model.use_attention) 54 | 55 | else: 56 | raise ValueError(f" [x] Unknown Model: {args.model.type}") 57 | 58 | print(' [Loading] ' + model_path) 59 | ckpt = torch.load(model_path, map_location=torch.device(device)) 60 | model.to(device) 61 | model.load_state_dict(ckpt['model']) 62 | model.eval() 63 | return model, vocoder, args 64 | 65 | 66 | class Vocoder: 67 | def __init__(self, vocoder_type, vocoder_ckpt, device = None): 68 | if device is None: 69 | if torch.cuda.is_available(): 70 | device = 'cuda' 71 | elif use_torch_musa: 72 | if torch.musa.is_available(): 73 | device = 'musa' 74 | else: 75 | device = 'cpu' 76 | else: 77 | device = 'cpu' 78 | self.device = device 79 | 80 | if vocoder_type == 'nsf-hifigan': 81 | self.vocoder = NsfHifiGAN(vocoder_ckpt, device = device) 82 | elif vocoder_type == 'nsf-hifigan-log10': 83 | self.vocoder = NsfHifiGANLog10(vocoder_ckpt, device = device) 84 | else: 85 | raise ValueError(f" [x] Unknown vocoder: {vocoder_type}") 86 | 87 | self.resample_kernel 
= {} 88 | self.vocoder_sample_rate = self.vocoder.sample_rate() 89 | self.vocoder_hop_size = self.vocoder.hop_size() 90 | self.dimension = self.vocoder.dimension() 91 | 92 | def extract(self, audio, sample_rate=0, keyshift=0): 93 | 94 | # resample 95 | if sample_rate == self.vocoder_sample_rate or sample_rate == 0: 96 | audio_res = audio 97 | else: 98 | key_str = str(sample_rate) 99 | if key_str not in self.resample_kernel: 100 | self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device) 101 | audio_res = self.resample_kernel[key_str](audio) 102 | 103 | # extract 104 | mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins 105 | return mel 106 | 107 | def infer(self, mel, f0): 108 | f0 = f0[:,:mel.size(1),0] # B, n_frames 109 | audio = self.vocoder(mel, f0) 110 | return audio 111 | 112 | 113 | class NsfHifiGAN(torch.nn.Module): 114 | def __init__(self, model_path, device=None): 115 | super().__init__() 116 | if device is None: 117 | if torch.cuda.is_available(): 118 | device = 'cuda' 119 | elif use_torch_musa: 120 | if torch.musa.is_available(): 121 | device = 'musa' 122 | else: 123 | device = 'cpu' 124 | else: 125 | device = 'cpu' 126 | self.device = device 127 | self.model_path = model_path 128 | self.model = None 129 | self.h = load_config(model_path) 130 | self.stft = STFT( 131 | self.h.sampling_rate, 132 | self.h.num_mels, 133 | self.h.n_fft, 134 | self.h.win_size, 135 | self.h.hop_size, 136 | self.h.fmin, 137 | self.h.fmax) 138 | 139 | def sample_rate(self): 140 | return self.h.sampling_rate 141 | 142 | def hop_size(self): 143 | return self.h.hop_size 144 | 145 | def dimension(self): 146 | return self.h.num_mels 147 | 148 | def extract(self, audio, keyshift=0): 149 | mel = self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2) # B, n_frames, bins 150 | return mel 151 | 152 | def forward(self, mel, f0): 153 | if self.model is None: 154 | print('| Load HifiGAN: ', self.model_path) 155 | self.model, self.h = load_model(self.model_path, device=self.device) 156 | with torch.no_grad(): 157 | c = mel.transpose(1, 2) 158 | audio = self.model(c, f0) 159 | return audio 160 | 161 | 162 | class NsfHifiGANLog10(NsfHifiGAN): 163 | def forward(self, mel, f0): 164 | if self.model is None: 165 | print('| Load HifiGAN: ', self.model_path) 166 | self.model, self.h = load_model(self.model_path, device=self.device) 167 | with torch.no_grad(): 168 | c = 0.434294 * mel.transpose(1, 2) 169 | audio = self.model(c, f0) 170 | return audio 171 | 172 | 173 | class Unit2Wav_VAE(nn.Module): 174 | def __init__( 175 | self, 176 | sampling_rate, 177 | block_size, 178 | win_length, 179 | n_unit, 180 | n_spk, 181 | use_pitch_aug=False, 182 | out_dims=128, 183 | n_layers=6, 184 | n_chans=512, 185 | n_hidden=256, 186 | back_bone='lynxnet', 187 | use_attention=False): 188 | super().__init__() 189 | self.f0_embed = nn.Linear(1, n_hidden) 190 | self.use_attention = use_attention 191 | if use_attention: 192 | self.unit_embed = nn.Linear(n_unit, n_hidden) 193 | self.volume_embed = nn.Linear(1, n_hidden) 194 | self.attention = nn.Sequential( 195 | nn.TransformerEncoderLayer( 196 | d_model=n_hidden, 197 | nhead=8, 198 | dim_feedforward=n_hidden * 4, 199 | dropout=0.1, 200 | activation='gelu', 201 | ), 202 | nn.Linear(n_hidden, out_dims), 203 | ) 204 | else: 205 | self.unit_embed = nn.Linear(n_unit, out_dims) 206 | self.volume_embed = nn.Linear(1, out_dims) 207 | if use_pitch_aug: 208 | self.aug_shift_embed = nn.Linear(1, n_hidden, 
bias=False) 209 | else: 210 | self.aug_shift_embed = None 211 | self.n_spk = n_spk 212 | if n_spk is not None and n_spk > 1: 213 | self.spk_embed = nn.Embedding(n_spk, n_hidden) 214 | if back_bone is None or back_bone == 'lynxnet': 215 | self.reflow_model = Bi_RectifiedFlow(NaiveV2Diff(mel_channels=out_dims, dim=n_chans, num_layers=n_layers, condition_dim=n_hidden, use_mlp=False)) 216 | elif back_bone == 'wavenet': 217 | self.reflow_model = Bi_RectifiedFlow(WaveNet(in_dims=out_dims, n_layers=n_layers, n_chans=n_chans, n_hidden=n_hidden)) 218 | else: 219 | raise ValueError(f" [x] Unknown Backbone: {back_bone}") 220 | 221 | def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None, vocoder=None, 222 | gt_spec=None, infer=True, return_wav=False, infer_step=10, method='euler', t_start=0.0, use_tqdm=True): 223 | 224 | ''' 225 | input: 226 | B x n_frames x n_unit 227 | return: 228 | dict of B x n_frames x feat 229 | ''' 230 | # condition 231 | cond = self.f0_embed((1+ f0 / 700).log()) 232 | if self.n_spk is not None and self.n_spk > 1: 233 | if spk_mix_dict is not None: 234 | for k, v in spk_mix_dict.items(): 235 | spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) 236 | cond = cond + v * self.spk_embed(spk_id_torch - 1) 237 | else: 238 | cond = cond + self.spk_embed(spk_id - 1) 239 | if self.aug_shift_embed is not None and aug_shift is not None: 240 | cond = cond + self.aug_shift_embed(aug_shift / 5) 241 | 242 | # vae mean 243 | x = self.unit_embed(units) + self.volume_embed(volume) 244 | if self.use_attention: 245 | x = self.attention(x) 246 | 247 | # vae noise 248 | x += torch.randn_like(x) 249 | 250 | x = self.reflow_model(infer=infer, x_start=x, x_end=gt_spec, cond=cond, infer_step=infer_step, method='euler', use_tqdm=True) 251 | 252 | if return_wav and infer: 253 | return vocoder.infer(x, f0) 254 | else: 255 | return x 256 | 257 | def vae_infer(self, input_mel, input_f0, input_spk_id, output_f0, output_spk_id=None, spk_mix_dict=None, aug_shift=None, 258 | infer_step=10, method='euler'): 259 | 260 | # source condition 261 | source_cond = self.f0_embed((1+ input_f0 / 700).log()) + self.spk_embed(input_spk_id - 1) 262 | 263 | # target condition 264 | target_cond = self.f0_embed((1+ output_f0 / 700).log()) 265 | if self.n_spk is not None and self.n_spk > 1: 266 | if spk_mix_dict is not None: 267 | for k, v in spk_mix_dict.items(): 268 | spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device) 269 | target_cond = target_cond + v * self.spk_embed(spk_id_torch - 1) 270 | else: 271 | target_cond = target_cond + self.spk_embed(output_spk_id - 1) 272 | if self.aug_shift_embed is not None and aug_shift is not None: 273 | target_cond = target_cond + self.aug_shift_embed(aug_shift / 5) 274 | 275 | print("\nExtracting features...") 276 | latent = self.reflow_model(infer=True, x_end=input_mel, cond=source_cond, infer_step=infer_step, method='euler', use_tqdm=True) 277 | print("\nSynthesizing...") 278 | output_mel = self.reflow_model(infer=True, x_start=latent, cond=target_cond, infer_step=infer_step, method='euler', use_tqdm=True) 279 | return output_mel 280 | -------------------------------------------------------------------------------- /reflow/wavenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from math import sqrt 3 | 4 | import torch 5 | try: 6 | import torch_musa 7 | use_torch_musa = True 8 | except ImportError: 9 | from torch.nn import Mish 10 | use_torch_musa = False 11 | import 
torch.nn as nn 12 | import torch.nn.functional as F 13 | from transformers.models.roformer.modeling_roformer import RoFormerEncoder, RoFormerConfig 14 | 15 | 16 | if use_torch_musa: 17 | class Mish(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | 21 | def forward(self, x): 22 | return x * (torch.tanh(F.softplus(x))) 23 | 24 | 25 | class Conv1d(torch.nn.Conv1d): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | nn.init.kaiming_normal_(self.weight) 29 | 30 | 31 | class SinusoidalPosEmb(nn.Module): 32 | def __init__(self, dim): 33 | super().__init__() 34 | self.dim = dim 35 | 36 | def forward(self, x): 37 | device = x.device 38 | half_dim = self.dim // 2 39 | emb = math.log(10000) / (half_dim - 1) 40 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 41 | emb = x[:, None] * emb[None, :] 42 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 43 | return emb 44 | 45 | 46 | class ResidualBlock(nn.Module): 47 | def __init__(self, encoder_hidden, residual_channels, dilation, kernel_size=3): 48 | super().__init__() 49 | self.residual_channels = residual_channels 50 | self.dilated_conv = nn.Conv1d( 51 | residual_channels, 52 | 2 * residual_channels, 53 | kernel_size=kernel_size, 54 | padding=dilation if (kernel_size == 3) else int((kernel_size-1) * dilation / 2), 55 | dilation=dilation 56 | ) 57 | self.diffusion_projection = nn.Linear(residual_channels, residual_channels) 58 | self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) 59 | self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) 60 | 61 | def forward(self, x, conditioner, diffusion_step): 62 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 63 | conditioner = self.conditioner_projection(conditioner) 64 | y = x + diffusion_step 65 | 66 | y = self.dilated_conv(y) + conditioner 67 | 68 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice 69 | gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) 70 | y = torch.sigmoid(gate) * torch.tanh(filter) 71 | 72 | y = self.output_projection(y) 73 | 74 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice 75 | residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) 76 | return (x + residual) / math.sqrt(2.0), skip 77 | 78 | 79 | class WaveNet(nn.Module): 80 | def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256, dilation=1, kernel_size=3, 81 | transformer_use=False, transformer_roformer_use=False, transformer_n_layers=2, transformer_n_head=4): 82 | super().__init__() 83 | self.input_projection = Conv1d(in_dims, n_chans, 1) 84 | self.diffusion_embedding = SinusoidalPosEmb(n_chans) 85 | self.mlp = nn.Sequential( 86 | nn.Linear(n_chans, n_chans * 4), 87 | Mish(), 88 | nn.Linear(n_chans * 4, n_chans) 89 | ) 90 | self.residual_layers = nn.ModuleList([ 91 | ResidualBlock( 92 | encoder_hidden=n_hidden, 93 | residual_channels=n_chans, 94 | dilation=(2 ** (i % dilation)) if (dilation != 1) else 1, 95 | kernel_size=kernel_size 96 | ) 97 | for i in range(n_layers) 98 | ]) 99 | self.transformer_roformer_use = transformer_roformer_use if (transformer_roformer_use is not None) else False 100 | if transformer_use: 101 | if transformer_roformer_use: 102 | self.transformer = RoFormerEncoder( 103 | RoFormerConfig( 104 | hidden_size=n_chans, 105 | max_position_embeddings=4096, 106 | num_attention_heads=transformer_n_head, 107 | 
num_hidden_layers=transformer_n_layers, 108 | add_cross_attention=False 109 | ) 110 | ) 111 | else: 112 | transformer_layer = nn.TransformerEncoderLayer( 113 | d_model=n_chans, 114 | nhead=transformer_n_head, 115 | dim_feedforward=n_chans * 4, 116 | dropout=0.1, 117 | activation='gelu' 118 | ) 119 | self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=transformer_n_layers) 120 | else: 121 | self.transformer = None 122 | 123 | self.skip_projection = Conv1d(n_chans, n_chans, 1) 124 | self.output_projection = Conv1d(n_chans, in_dims, 1) 125 | nn.init.zeros_(self.output_projection.weight) 126 | 127 | def forward(self, spec, diffusion_step, cond): 128 | """ 129 | :param spec: [B, 1, M, T] 130 | :param diffusion_step: [B, 1] 131 | :param cond: [B, M, T] 132 | :return: 133 | """ 134 | x = spec.squeeze(1) 135 | x = self.input_projection(x) # [B, residual_channel, T] 136 | 137 | x = F.relu(x) 138 | diffusion_step = self.diffusion_embedding(diffusion_step) 139 | diffusion_step = self.mlp(diffusion_step) 140 | skip = [] 141 | for layer in self.residual_layers: 142 | x, skip_connection = layer(x, cond, diffusion_step) 143 | skip.append(skip_connection) 144 | 145 | x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) 146 | x = self.skip_projection(x) 147 | x = F.relu(x) 148 | if self.transformer is not None: 149 | if self.transformer_roformer_use: 150 | x = self.transformer(x.transpose(1, 2))[0].transpose(1, 2) 151 | else: 152 | x = self.transformer(x.transpose(1, 2)).transpose(1, 2) 153 | x = self.output_projection(x) # [B, mel_bins, T] 154 | return x[:, None, :, :] 155 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fairseq 2 | librosa 3 | matplotlib 4 | numpy==1.26.4 5 | praat-parselmouth 6 | pyworld 7 | PyYAML 8 | resampy 9 | scikit_learn 10 | scipy 11 | SoundFile 12 | torchcrepe 13 | torchfcpe 14 | tqdm 15 | transformers 16 | tensorboard 17 | -------------------------------------------------------------------------------- /slicer.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | try: 4 | import torch_musa 5 | except ImportError: 6 | pass 7 | import torchaudio 8 | 9 | 10 | class Slicer: 11 | def __init__(self, 12 | sr: int, 13 | threshold: float = -40., 14 | min_length: int = 5000, 15 | min_interval: int = 300, 16 | hop_size: int = 20, 17 | max_sil_kept: int = 5000): 18 | if not min_length >= min_interval >= hop_size: 19 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') 20 | if not max_sil_kept >= hop_size: 21 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') 22 | min_interval = sr * min_interval / 1000 23 | self.threshold = 10 ** (threshold / 20.) 
24 | self.hop_size = round(sr * hop_size / 1000) 25 | self.win_size = min(round(min_interval), 4 * self.hop_size) 26 | self.min_length = round(sr * min_length / 1000 / self.hop_size) 27 | self.min_interval = round(min_interval / self.hop_size) 28 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) 29 | 30 | def _apply_slice(self, waveform, begin, end): 31 | if len(waveform.shape) > 1: 32 | return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] 33 | else: 34 | return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] 35 | 36 | # @timeit 37 | def slice(self, waveform): 38 | if len(waveform.shape) > 1: 39 | samples = librosa.to_mono(waveform) 40 | else: 41 | samples = waveform 42 | if samples.shape[0] <= self.min_length: 43 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} 44 | rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) 45 | sil_tags = [] 46 | silence_start = None 47 | clip_start = 0 48 | for i, rms in enumerate(rms_list): 49 | # Keep looping while frame is silent. 50 | if rms < self.threshold: 51 | # Record start of silent frames. 52 | if silence_start is None: 53 | silence_start = i 54 | continue 55 | # Keep looping while frame is not silent and silence start has not been recorded. 56 | if silence_start is None: 57 | continue 58 | # Clear recorded silence start if interval is not enough or clip is too short 59 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept 60 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length 61 | if not is_leading_silence and not need_slice_middle: 62 | silence_start = None 63 | continue 64 | # Need slicing. Record the range of silent frames to be removed. 65 | if i - silence_start <= self.max_sil_kept: 66 | pos = rms_list[silence_start: i + 1].argmin() + silence_start 67 | if silence_start == 0: 68 | sil_tags.append((0, pos)) 69 | else: 70 | sil_tags.append((pos, pos)) 71 | clip_start = pos 72 | elif i - silence_start <= self.max_sil_kept * 2: 73 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() 74 | pos += i - self.max_sil_kept 75 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 76 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 77 | if silence_start == 0: 78 | sil_tags.append((0, pos_r)) 79 | clip_start = pos_r 80 | else: 81 | sil_tags.append((min(pos_l, pos), max(pos_r, pos))) 82 | clip_start = max(pos_r, pos) 83 | else: 84 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 85 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 86 | if silence_start == 0: 87 | sil_tags.append((0, pos_r)) 88 | else: 89 | sil_tags.append((pos_l, pos_r)) 90 | clip_start = pos_r 91 | silence_start = None 92 | # Deal with trailing silence. 93 | total_frames = rms_list.shape[0] 94 | if silence_start is not None and total_frames - silence_start >= self.min_interval: 95 | silence_end = min(total_frames, silence_start + self.max_sil_kept) 96 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start 97 | sil_tags.append((pos, total_frames + 1)) 98 | # Apply and return slices. 
99 |         if len(sil_tags) == 0:
100 |             return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
101 |         else:
102 |             chunks = []
103 |             # The first silent segment does not start at frame 0, so prepend the leading voiced segment
104 |             if sil_tags[0][0]:
105 |                 chunks.append(
106 |                     {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
107 |             for i in range(0, len(sil_tags)):
108 |                 # Mark the voiced segments (skipping the first one)
109 |                 if i:
110 |                     chunks.append({"slice": False,
111 |                                    "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
112 |                 # Mark all silent segments
113 |                 chunks.append({"slice": True,
114 |                                "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
115 |             # The last silent segment does not reach the end, so append the trailing voiced segment
116 |             if sil_tags[-1][1] * self.hop_size < len(waveform):
117 |                 chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
118 |             chunk_dict = {}
119 |             for i in range(len(chunks)):
120 |                 chunk_dict[str(i)] = chunks[i]
121 |             return chunk_dict
122 | 
123 | 
124 | def cut(audio_path, db_thresh=-30, min_len=5000, flask_mode=False, flask_sr=None):
125 |     if not flask_mode:
126 |         audio, sr = librosa.load(audio_path, sr=None)
127 |     else:
128 |         audio = audio_path
129 |         sr = flask_sr
130 |     slicer = Slicer(
131 |         sr=sr,
132 |         threshold=db_thresh,
133 |         min_length=min_len
134 |     )
135 |     chunks = slicer.slice(audio)
136 |     return chunks
137 | 
138 | 
139 | def chunks2audio(audio_path, chunks):
140 |     chunks = dict(chunks)
141 |     audio, sr = torchaudio.load(audio_path)
142 |     if len(audio.shape) == 2 and audio.shape[0] >= 2:
143 |         audio = torch.mean(audio, dim=0).unsqueeze(0)
144 |     audio = audio.cpu().numpy()[0]
145 |     result = []
146 |     for k, v in chunks.items():
147 |         tag = v["split_time"].split(",")
148 |         if tag[0] != tag[1]:
149 |             result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
150 |     return result, sr
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | try:
5 |     import torch_musa
6 | except ImportError:
7 |     pass
8 | from torch.optim import lr_scheduler
9 | from logger import utils
10 | from reflow.data_loaders import get_data_loaders
11 | from reflow.vocoder import Vocoder, Unit2Wav_VAE
12 | from reflow.solver import train
13 | 
14 | def parse_args(args=None, namespace=None):
15 |     """Parse command-line arguments."""
16 |     parser = argparse.ArgumentParser()
17 |     parser.add_argument(
18 |         "-c",
19 |         "--config",
20 |         type=str,
21 |         required=True,
22 |         help="path to the config file")
23 |     return parser.parse_args(args=args, namespace=namespace)
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     # parse commands
28 |     cmd = parse_args()
29 | 
30 |     # load config
31 |     args = utils.load_config(cmd.config)
32 |     print(' > config:', cmd.config)
33 |     print(' > exp:', args.env.expdir)
34 | 
35 |     # load vocoder
36 |     vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
37 | 
38 |     # load model
39 |     if args.model.type == 'RectifiedFlow_VAE':
40 |         model = Unit2Wav_VAE(
41 |             args.data.sampling_rate,
42 |             args.data.block_size,
43 |             args.model.win_length,
44 |             args.data.encoder_out_channels,
45 |             args.model.n_spk,
46 |             args.model.use_pitch_aug,
47 |             vocoder.dimension,
48 |             args.model.n_layers,
49 |             args.model.n_chans,
50 |             args.model.n_hidden,
51 |             args.model.back_bone,
52 |             args.model.use_attention)
53 |     else:
54 |         raise ValueError(f" [x] Unknown Model: {args.model.type}")
55 | 
56 |     # load parameters
57 |     optimizer = torch.optim.AdamW(model.parameters())
58 |     initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
59 |     for param_group in optimizer.param_groups:
60 |         param_group['initial_lr'] = args.train.lr
61 |         param_group['lr'] = args.train.lr * args.train.gamma ** max((initial_global_step - 2) // args.train.decay_step, 0)
62 |         param_group['weight_decay'] = args.train.weight_decay
63 |     scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma, last_epoch=initial_global_step-2)
64 | 
65 |     # device
66 |     if args.device == 'cuda':
67 |         torch.cuda.set_device(args.env.gpu_id)
68 |     model.to(args.device)
69 | 
70 |     for state in optimizer.state.values():
71 |         for k, v in state.items():
72 |             if torch.is_tensor(v):
73 |                 state[k] = v.to(args.device)
74 | 
75 |     # data
76 |     loader_train, loader_valid = get_data_loaders(args, whole_audio=False)
77 | 
78 |     # run
79 |     train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid)
80 | 
81 | 
--------------------------------------------------------------------------------
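A quick usage note on the files above (not part of the repository listing itself): training is launched by pointing train.py at one of the YAML files under configs, e.g. python train.py -c configs/reflow-vae-wavenet.yaml. The short sketch below is a hypothetical example of running slicer.py on its own to drop silent regions from a recording before preprocessing; the path input.wav and the threshold values are placeholders, not values prescribed by the project.

import numpy as np

from slicer import cut, chunks2audio

# cut() loads the placeholder file with librosa at its native sample rate and
# returns a dict of {"slice": bool, "split_time": "start,end"} entries,
# where "slice": True marks a silent region (split_time is in samples).
chunks = cut('input.wav', db_thresh=-30, min_len=5000)

# chunks2audio() reloads the file with torchaudio, downmixes it to mono, and
# returns (is_silence, samples) pairs together with the sample rate.
segments, sr = chunks2audio('input.wav', chunks)
voiced = np.concatenate([samples for is_silence, samples in segments if not is_silence])
print(f'Kept {len(voiced) / sr:.2f} s of voiced audio at {sr} Hz')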